Add everything

dx-tan 2023-11-30 18:22:16 +07:00
parent 4e83776907
commit 7e9a8e2d4b
277 changed files with 36106 additions and 1 deletion

.gitignore vendored (2 changed lines)

@@ -12,3 +12,5 @@ backup/
*.log
__pycache__
migrations/
test/
._git/

.gitmodules vendored Normal file (3 changed lines)

@@ -0,0 +1,3 @@
[submodule "cope2n-ai-fi/modules/sdsvkvu"]
path = cope2n-ai-fi/modules/sdsvkvu
url = https://code.sdsdev.co.kr/tuanlv/sdsvkvu

@@ -0,0 +1,6 @@
[submodule "modules/sdsvkvu"]
path = modules/sdsvkvu
url = https://code.sdsdev.co.kr/tuanlv/sdsvkvu.git
[submodule "modules/ocr_engine"]
path = modules/ocr_engine
url = https://code.sdsdev.co.kr/tuanlv/IDP-BasicOCR.git

cope2n-ai-fi/.dockerignore Executable file (7 changed lines)

@@ -0,0 +1,7 @@
.github
.git
.vscode
__pycache__
DataBase/image_temp/
DataBase/json_temp/
DataBase/template.db

cope2n-ai-fi/.gitignore vendored Executable file (21 changed lines)

@@ -0,0 +1,21 @@
.vscode
__pycache__
DataBase/image_temp/
DataBase/json_temp/
DataBase/template.db
sdsvtd/
sdsvtr/
sdsvkie/
detectron2/
output/
data/
models/
server/
image_logs/
experiments/
weights/
packages/
tmp_results/
.env
.zip
.json

cope2n-ai-fi/Dockerfile Executable file (43 changed lines)

@@ -0,0 +1,43 @@
# FROM thucpd2408/env-cope2n:v1
FROM thucpd2408/env-deskew
COPY ./packages/cudnn-linux*.tar.xz /tmp/cudnn-linux*.tar.xz
RUN tar -xvf /tmp/cudnn-linux*.tar.xz -C /tmp/ \
&& cp /tmp/cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include \
&& cp -P /tmp/cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64 \
&& chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn* \
&& rm -rf /tmp/cudnn-*-archive
RUN apt-get update && apt-get install -y gcc g++ ffmpeg libsm6 libxext6
# RUN pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
WORKDIR /workspace
COPY ./modules/ocr_engine/externals/ /workspace/cope2n-ai-fi/modules/ocr_engine/externals/
COPY ./modules/ocr_engine/requirements.txt /workspace/cope2n-ai-fi/modules/ocr_engine/requirements.txt
COPY ./modules/sdsvkie/ /workspace/cope2n-ai-fi/modules/sdsvkie/
COPY ./modules/sdsvkvu/ /workspace/cope2n-ai-fi/modules/sdsvkvu/
COPY ./requirements.txt /workspace/cope2n-ai-fi/requirements.txt
RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp && pip3 install -v -e .
RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtd && pip3 install -v -e .
RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtr && pip3 install -v -e .
# RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/ && pip3 install -r requirements.txt
RUN cd /workspace/cope2n-ai-fi/modules/sdsvkie && pip3 install -v -e .
RUN cd /workspace/cope2n-ai-fi/modules/sdsvkvu && pip3 install -v -e .
RUN cd /workspace/cope2n-ai-fi && pip3 install -r requirements.txt
RUN rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublasLt.so.11 && \
rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublas.so.11 && \
rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libnvblas.so.11 && \
ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublasLt.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublasLt.so.11 && \
ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublas.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublas.so.11 && \
ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvblas.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libnvblas.so.11
ENV PYTHONPATH="."
CMD [ "sh", "run.sh"]
# CMD ["tail -f > /dev/null"]

cope2n-ai-fi/Dockerfile-dev Executable file (21 changed lines)

@@ -0,0 +1,21 @@
FROM thucpd2408/env-cope2n:v1
RUN apt-get update && apt-get install -y gcc g++ ffmpeg libsm6 libxext6
WORKDIR /workspace
COPY ./requirements.txt /workspace/cope2n-ai-fi/requirements.txt
COPY ./sdsvkie/ /workspace/cope2n-ai-fi/sdsvkie/
COPY ./sdsvtd /workspace/cope2n-ai-fi/sdsvtd/
COPY ./sdsvtr/ /workspace/cope2n-ai-fi/sdsvtr/
COPY ./models/ /models
RUN cd /workspace/cope2n-ai-fi && pip3 install -r requirements.txt
RUN cd /workspace/cope2n-ai-fi/sdsvkie && pip3 install -v -e .
RUN cd /workspace/cope2n-ai-fi/sdsvtd && pip3 install -v -e .
RUN cd /workspace/cope2n-ai-fi/sdsvtr && pip3 install -v -e .
ENV PYTHONPATH="."
CMD [ "sh", "run.sh"]
# CMD ["tail -f > /dev/null"]

@@ -0,0 +1,8 @@
FROM hisiter/fwd_env:1.0.0
RUN apt-get update && apt-get install -y \
libssl-dev \
libxml2-dev \
libxslt-dev \
&& rm -rf /var/lib/apt/lists/*
RUN pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
RUN pip install "paddleocr>=2.0.1"

cope2n-ai-fi/LICENSE Executable file (674 changed lines)

@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

cope2n-ai-fi/NOTE.md Executable file (7 changed lines)

@@ -0,0 +1,7 @@
#### Environment
```
docker run -itd --privileged --name=TannedCungnoCope2n-ai-fi \
-v /mnt/hdd2T/dxtan/TannedCung/OCR/cope2n-ai-fi:/workspace \
tannedcung/mmocr:latest \
tail -f > /dev/null
```

cope2n-ai-fi/README.md Executable file (36 changed lines)

@@ -0,0 +1,36 @@
# AI-core
## Add your files
- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
```bash
cd existing_repo
git remote add origin http://code.sdsrv.vn/c-ope2n/ai-core.git
git branch -M main
git push -uf origin main
```
## Develop
Assume you are at the root folder with the following structure:
```bash
.
├── cope2n-ai-fi
├── cope2n-api
├── cope2n-fe
├── .env
└── docker-compose-dev.yml
```
Run `docker-compose -f docker-compose-dev.yml up --build -d` to bring the whole project up.
## Environment Variables
Variable | Default | Usage | Note
------------- | ------- | ----- | ----
CELERY_BROKER | $250 | | |
SAP_KIE_MODEL | $80 | | |
FI_KIE_MODEL | $420 | | |
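These variables are read from the environment at container start. Below is a minimal Python sketch of how a worker might pick them up; the fallback values are placeholders for illustration, not the project's real settings.
```python
import os

# Hypothetical fallbacks for illustration only; real values come from .env /
# docker-compose and are deployment-specific.
CELERY_BROKER = os.environ.get("CELERY_BROKER", "amqp://guest:guest@rabbitmq:5672//")
SAP_KIE_MODEL = os.environ.get("SAP_KIE_MODEL", "/models/kie_invoice_sap")
FI_KIE_MODEL = os.environ.get("FI_KIE_MODEL", "/models/Kie_invoice_fi")
```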

cope2n-ai-fi/TODO.md Normal file (26 changed lines)

@@ -0,0 +1,26 @@
## Bring abs path to relative
- [x] save_dir `Kie_Invoice_AP/prediction_sap.py:18`
- [x] detector `Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml` [_fixed_](#refactor)
- [x] rotator_version `Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml` [_fixed_](#refactor)
- [x] cfg `Kie_Invoice_AP/prediction_fi.py`
- [x] weight `Kie_Invoice_AP/prediction_fi.py`
- [x] save_dir `Kie_Invoice_AP/prediction_fi.py:18`
## Bring abs path to .env
- [x] CELERY_BROKER:
- [ ] SAP_KIE_MODEL: `Kie_Invoice_AP/prediction_sap.py:20` [_NEED_REFACTOR_](#refactor)
- [ ] FI_KIE_MODEL: `Kie_Invoice_AP/prediction_fi.py:20` [_NEED_REFACTOR_](#refactor)
## Possible logic conflicts
### Refactor
- [ ] Each model should be loaded in its own Docker container and served as a service
- [ ] Some files (weights, ...) should be mounted into the container in a layout that survives rebuilds
- [ ] `Kie_Invoice_AP/prediction_sap.py` and `Kie_Invoice_AP/prediction_fi.py` should be merged into a single file, since they share resources and differ only in logic
- [ ] `Kie_Invoice_AP/prediction.py` seems to be the base entry point; it should act as a proxy that imports all the other `predict_*` functions
- [ ] There should be a single folder that keeps all model versions and is mounted as `/models` in the container. Currently, `fi` loads from `/models/Kie_invoice_fi` while `sap` loads from `Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20231003-171748`. Another model weight sits in `sdsvtd/hub` for an unknown reason
- [ ] Environment variables should each have a description in the README
- [ ]

@@ -0,0 +1,149 @@
from PIL import Image
import cv2
import numpy as np
from transformers import (
LayoutXLMTokenizer,
LayoutLMv2FeatureExtractor,
LayoutXLMProcessor,
LayoutLMv2ForTokenClassification,
)
from common.utils.word_formation import *
from common.utils.global_variables import *
from common.utils.process_label import *
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
os.environ["CURL_CA_BUNDLE"] = ""
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# config
IGNORE_KIE_LABEL = "others"
KIE_LABELS = [
"Number",
"Name",
"Birthday",
"Home Town",
"Address",
"Sex",
"Nationality",
"Expiry Date",
"Nation",
"Religion",
"Date Range",
"Issued By",
IGNORE_KIE_LABEL,
"Rank"
]
DEVICE = "cuda:0"
# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
# tokenizer = LayoutXLMTokenizer.from_pretrained(
# "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
# )
# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)
model = LayoutLMv2ForTokenClassification.from_pretrained(
"Kie_AHung/model/driver_license", num_labels=len(KIE_LABELS), local_files_only=True
).to(
DEVICE
) # TODO FIX this hard code
def load_ocr_labels(list_lines):
words, boxes, labels = [], [], []
for line in list_lines:
for word_group in line.list_word_groups:
for word in word_group.list_words:
xmin, ymin, xmax, ymax = (
word.boundingbox[0],
word.boundingbox[1],
word.boundingbox[2],
word.boundingbox[3],
)
text = word.text
label = "seller_name_value" # TODO ??? fix this
x1, y1, x2, y2 = float(xmin), float(ymin), float(xmax), float(ymax)
if text != " ":
words.append(text)
boxes.append([x1, y1, x2, y2])
labels.append(label)
return words, boxes, labels
def _normalize_box(box, width, height):
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
int(1000 * (box[2] / width)),
int(1000 * (box[3] / height)),
]
def infer_driving_license(image_crop, list_lines, max_n_words, processor):
# Load inputs
# image = Image.open(image_path)
image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
batch_preds, batch_true_boxes = [], []
list_words = []
for i in range(0, len(batch_words), max_n_words):
words = batch_words[i : i + max_n_words]
boxes = batch_boxes[i : i + max_n_words]
boxes_norm = [
_normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes
]
# Preprocess
dummy_word_labels = [0] * len(words)
encoding = processor(
image,
text=words,
boxes=boxes_norm,
word_labels=dummy_word_labels,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=512,
)
# Run model
for k, v in encoding.items():
encoding[k] = v.to(DEVICE)
outputs = model(**encoding)
predictions = outputs.logits.argmax(-1).squeeze().tolist()
# Postprocess
is_subword = (
(encoding["labels"] == -100).detach().cpu().numpy()[0]
) # remove padding
true_predictions = [
pred for idx, pred in enumerate(predictions) if not is_subword[idx]
]
true_boxes = (
boxes # TODO check assumption that layoutlm does not change box order
)
for i, word in enumerate(words):
bndbox = [int(j) for j in true_boxes[i]]
list_words.append(
Word(
text=word, bndbox=bndbox, kie_label=KIE_LABELS[true_predictions[i]]
)
)
batch_preds.extend(true_predictions)
batch_true_boxes.extend(true_boxes)
batch_preds = np.array(batch_preds)
batch_true_boxes = np.array(batch_true_boxes)
return batch_words, batch_preds, batch_true_boxes, list_words
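The commented-out lines near the top of this module show how the `processor` argument of `infer_driving_license` is meant to be built. A minimal sketch under that assumption follows; the pretrained checkpoint name is a stand-in for whatever local tokenizer directory the deployment actually uses.
```python
# Sketch only: builds the LayoutXLM processor that infer_driving_license() expects,
# following the commented-out setup above. "microsoft/layoutxlm-base" is an assumed
# stand-in for the local "layoutxlm-base/tokenizer" directory.
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
)

MAX_SEQ_LENGTH = 512
tokenizer = LayoutXLMTokenizer.from_pretrained(
    "microsoft/layoutxlm-base", model_max_length=MAX_SEQ_LENGTH
)
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)  # boxes come from the OCR engine
processor = LayoutXLMProcessor(feature_extractor, tokenizer)

# Example call (max_n_words of 150 is a placeholder):
# words, preds, boxes, list_words = infer_driving_license(image_crop, list_lines, 150, processor)
```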

@@ -0,0 +1,145 @@
from PIL import Image
import numpy as np
import cv2
from transformers import LayoutLMv2ForTokenClassification
from common.utils.word_formation import *
from common.utils.global_variables import *
from common.utils.process_label import *
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
os.environ["CURL_CA_BUNDLE"] = ""
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# config
IGNORE_KIE_LABEL = "others"
KIE_LABELS = [
"Number",
"Name",
"Birthday",
"Home Town",
"Address",
"Sex",
"Nationality",
"Expiry Date",
"Nation",
"Religion",
"Date Range",
"Issued By",
IGNORE_KIE_LABEL
]
DEVICE = "cuda:0"
# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
# tokenizer = LayoutXLMTokenizer.from_pretrained(
# "Kie_AHung_ID/model/pretrained/layoutxlm-base/tokenizer",
# model_max_length=MAX_SEQ_LENGTH,
# )
# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)
model = LayoutLMv2ForTokenClassification.from_pretrained(
"Kie_AHung_ID/model/ID_CARD_145_train_300_val_0.02_char_0.06_word",
num_labels=len(KIE_LABELS),
local_files_only=True,
).to(
DEVICE
) # TODO FIX this hard code
def load_ocr_labels(list_lines):
words, boxes, labels = [], [], []
for line in list_lines:
for word_group in line.list_word_groups:
for word in word_group.list_words:
xmin, ymin, xmax, ymax = (
word.boundingbox[0],
word.boundingbox[1],
word.boundingbox[2],
word.boundingbox[3],
)
text = word.text
label = "seller_name_value" # TODO ??? fix this
x1, y1, x2, y2 = float(xmin), float(ymin), float(xmax), float(ymax)
if text != " ":
words.append(text)
boxes.append([x1, y1, x2, y2])
labels.append(label)
return words, boxes, labels
def _normalize_box(box, width, height):
return [
int(1000 * (box[0] / width)),
int(1000 * (box[1] / height)),
int(1000 * (box[2] / width)),
int(1000 * (box[3] / height)),
]
def infer_id_card(image_crop, list_lines, max_n_words, processor):
# Load inputs
image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
batch_preds, batch_true_boxes = [], []
list_words = []
for i in range(0, len(batch_words), max_n_words):
words = batch_words[i : i + max_n_words]
boxes = batch_boxes[i : i + max_n_words]
boxes_norm = [
_normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes
]
# Preprocess
dummy_word_labels = [0] * len(words)
encoding = processor(
image,
text=words,
boxes=boxes_norm,
word_labels=dummy_word_labels,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=512,
)
# Run model
for k, v in encoding.items():
encoding[k] = v.to(DEVICE)
outputs = model(**encoding)
predictions = outputs.logits.argmax(-1).squeeze().tolist()
# Postprocess
is_subword = (
(encoding["labels"] == -100).detach().cpu().numpy()[0]
) # remove padding
true_predictions = [
pred for idx, pred in enumerate(predictions) if not is_subword[idx]
]
true_boxes = (
boxes # TODO check assumption that layoutlm does not change box order
)
for i, word in enumerate(words):
bndbox = [int(j) for j in true_boxes[i]]
list_words.append(
Word(
text=word, bndbox=bndbox, kie_label=KIE_LABELS[true_predictions[i]]
)
)
batch_preds.extend(true_predictions)
batch_true_boxes.extend(true_boxes)
batch_preds = np.array(batch_preds)
batch_true_boxes = np.array(batch_true_boxes)
return batch_words, batch_preds, batch_true_boxes, list_words

@@ -0,0 +1,57 @@
from sdsvkie import Predictor
import cv2
import numpy as np
import urllib
from common import serve_model
from common import ocr
model = Predictor(
cfg = "./models/kie_invoice/config.yaml",
device = "cuda:0",
weights = "./models/models/kie_invoice/last",
proccessor = serve_model.processor,
ocr_engine = ocr.engine
)
def predict(page_numb, image_url):
"""
module predict function
Args:
image_url (str): image url
Returns:
example output:
"data": {
"document_type": "invoice",
"fields": [
{
"label": "Invoice Number",
"value": "INV-12345",
"box": [0, 0, 0, 0],
"confidence": 0.98
},
...
]
}
dict: output of model
"""
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
image = cv2.imdecode(arr, -1)
out = model(image)
output = out["end2end_results"]
output_dict = {
"document_type": "invoice",
"fields": []
}
for key in output.keys():
field = {
"label": key,
"value": output[key]['value'] if output[key]['value'] else "",
"box": output[key]['box'],
"confidence": output[key]['conf'],
"page": page_numb
}
output_dict['fields'].append(field)
return output_dict
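A minimal usage sketch for the `predict` function above, assuming the module is importable as `prediction` and the URL points at a reachable image; both names are placeholders.
```python
# Hypothetical caller; the module name and image URL are placeholders.
from prediction import predict

result = predict(page_numb=0, image_url="https://example.com/sample_invoice.jpg")
print(result["document_type"])
for field in result["fields"]:
    # Each field follows the docstring's schema: label, value, box, confidence, page.
    print(field["label"], field["value"], field["confidence"], field["page"])
```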

@@ -0,0 +1,83 @@
from common.utils_invoice.run_ocr import ocr_predict
import os
from Kie_Hoanglv.prediction2 import KIEInvoiceInfer
from configs.config_invoice.layoutxlm_base_invoice import *
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import cv2
model = KIEInvoiceInfer(
weight_dir=TRAINED_DIR,
tokenizer_dir=TOKENIZER_DIR,
max_seq_len=MAX_SEQ_LENGTH,
classes=KIE_LABELS,
device=DEVICE,
outdir_visualize=VISUALIZE_DIR,
)
def format_result(result):
"""
return:
[
{
key: 'name',
value: 'Nguyen Hoang Hiep',
true_box: [
373,
113,
700,
420
]
},
{
key: 'name',
value: 'Nguyen Hoang Hiep 1',
true_box: [
10,
10,
20,
20,
]
},
]
"""
new_result = []
for i, item in enumerate(result[0]):
new_result.append(
{
"key": item,
"value": result[0][item],
"true_box": result[1][i],
}
)
return new_result
def predict(image_url):
if not os.path.exists(PRED_DIR):
os.makedirs(PRED_DIR, exist_ok=True)
if not os.path.exists(VISUALIZE_DIR):
os.makedirs(VISUALIZE_DIR, exist_ok=True)
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))
cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
bboxes, texts = ocr_predict(cv_image)
texts_replaced = []
for text in texts:
if "" in text:
text_replaced = text.replace("", " ")
texts_replaced.append(text_replaced)
else:
texts_replaced.append(text)
inputs = model.prepare_kie_inputs(image, ocr_info=[bboxes, texts_replaced])
result = model(inputs)
result = format_result(result)
return result

@@ -0,0 +1,137 @@
import os
import glob
import cv2
import argparse
from tqdm import tqdm
import urllib
import numpy as np
import imagesize
# from omegaconf import OmegaConf
import sys
cur_dir = os.path.dirname(__file__)
sys.path.append(cur_dir)
# sys.path.append('/cope2n-ai-fi/Kie_Invoice_AP/AnyKey_Value/')
from predictor import KVUPredictor
from preprocess import KVUProcess, DocumentKVUProcess
from utils.utils import create_dir, visualize, get_colormap, export_kvu_outputs, export_kvu_for_manulife
def get_args():
args = argparse.ArgumentParser(description='Main file')
args.add_argument('--img_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str,
help='Input image directory')
args.add_argument('--save_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str,
help='Save directory')
# args.add_argument('--exp_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900', type=str,
# help='Checkpoint and config of model')
args.add_argument('--exp_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900', type=str,
help='Checkpoint and config of model')
args.add_argument('--export_img', default=0, type=int,
help='Save visualize on image')
args.add_argument('--mode', default=3, type=int,
help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'")
args.add_argument('--dir_level', default=0, type=int,
help='Number of subfolders contains image')
return args.parse_args()
def load_engine(exp_dir: str, class_names: list, mode: int) -> KVUPredictor:
configs = {
'cfg': glob.glob(f'{exp_dir}/*.yaml')[0],
'ckpt': f'{exp_dir}/checkpoints/best_model.pth'
}
dummy_idx = 512
predictor = KVUPredictor(configs, class_names, dummy_idx, mode)
# processor = KVUProcess(predictor.net.tokenizer_layoutxlm,
# predictor.net.feature_extractor, predictor.backbone_type, class_names,
# predictor.slice_interval, predictor.window_size, run_ocr=1, mode=mode)
processor = DocumentKVUProcess(predictor.net.tokenizer, predictor.net.feature_extractor,
predictor.backbone_type, class_names,
predictor.max_window_count, predictor.slice_interval, predictor.window_size,
run_ocr=1, mode=mode)
return predictor, processor
def revert_box(box, width, height):
return [
int((box[0] / 1000) * width),
int((box[1] / 1000) * height),
int((box[2] / 1000) * width),
int((box[3] / 1000) * height)
]
def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor) -> None:
fname = os.path.basename(img_path)
img_ext = img_path.split('.')[-1]
inputs = processor(img_path, ocr_path='')
width, height = imagesize.get(img_path)
bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs)
# slide_window = False if len(bbox) == 1 else True
if len(bbox) == 0:
bbox, lwords, pr_class_words, pr_relations = [bbox], [lwords], [pr_class_words], [pr_relations]
for i in range(len(bbox)):
bbox[i] = [revert_box(bb, width, height) for bb in bbox[i]]
# vat_outputs_invoice = export_kvu_for_VAT_invoice(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names)
# vat_outputs_receipt = export_kvu_for_SDSAP(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names)
# vat_outputs_invoice = export_kvu_for_all(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names)
vat_outputs_invoice = export_kvu_for_manulife(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names)
print(vat_outputs_invoice)
return vat_outputs_invoice
def load_groundtruth(img_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor: KVUProcess, export_img: int) -> None:
fname = os.path.basename(img_path)
img_ext = img_path.split('.')[-1]
inputs = processor.load_ground_truth(os.path.join(json_dir, fname.replace(f".{img_ext}", ".json")))
bbox, lwords, pr_class_words, pr_relations = predictor.get_ground_truth_label(inputs)
export_kvu_outputs(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')), lwords, pr_class_words, pr_relations, predictor.class_names)
if export_img == 1:
save_path = os.path.join(save_dir, 'kvu_results')
create_dir(save_path)
color_map = get_colormap()
image = cv2.imread(img_path)
image = visualize(image, bbox, pr_class_words, pr_relations, color_map, class_names, thickness=1)
cv2.imwrite(os.path.join(save_path, fname), image)
def show_groundtruth(dir_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor, export_img: int) -> None:
list_images = []
for ext in ['JPG', 'PNG', 'jpeg', 'jpg', 'png']:
list_images += glob.glob(os.path.join(dir_path, f'*.{ext}'))
print('No. images:', len(list_images))
for img_path in tqdm(list_images):
load_groundtruth(img_path, json_dir, save_dir, predictor, processor, export_img)
def Predictor_KVU(image_url: str, save_dir: str, predictor: KVUPredictor, processor) -> None:
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
image_path = "./Kie_Invoice_AP/tmp_image/{image_url}.jpg"
cv2.imwrite(image_path, img)
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
return vat_outputs
if __name__ == "__main__":
args = get_args()
class_names = ['others', 'title', 'key', 'value', 'header']
predict_mode = {
'normal': 0,
'full_tokens': 1,
'sliding': 2,
'document': 3
}
predictor, processor = load_engine(args.exp_dir, class_names, args.mode)
create_dir(args.save_dir)
image_path = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
save_dir = "/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test"
predict_image(image_path, save_dir, predictor, processor)
print('[INFO] Done')

View File

@ -0,0 +1,133 @@
import time
import torch
import torch.utils.data
from overrides import overrides
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim import SGD, Adam, AdamW
from torch.optim.lr_scheduler import LambdaLR
from lightning_modules.schedulers import (
cosine_scheduler,
linear_scheduler,
multistep_scheduler,
)
from model import get_model
from utils import cfg_to_hparams, get_specific_pl_logger
class ClassifierModule(LightningModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.net = get_model(self.cfg)
self.ignore_index = -100
self.time_tracker = None
self.optimizer_types = {
"sgd": SGD,
"adam": Adam,
"adamw": AdamW,
}
@overrides
def setup(self, stage):
self.time_tracker = time.time()
@overrides
def configure_optimizers(self):
optimizer = self._get_optimizer()
scheduler = self._get_lr_scheduler(optimizer)
scheduler = {
"scheduler": scheduler,
"name": "learning_rate",
"interval": "step",
}
return [optimizer], [scheduler]
def _get_lr_scheduler(self, optimizer):
cfg_train = self.cfg.train
lr_schedule_method = cfg_train.optimizer.lr_schedule.method
lr_schedule_params = cfg_train.optimizer.lr_schedule.params
if lr_schedule_method is None:
scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1)
elif lr_schedule_method == "step":
scheduler = multistep_scheduler(optimizer, **lr_schedule_params)
elif lr_schedule_method == "cosine":
total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
total_batch_size = cfg_train.batch_size * self.trainer.world_size
max_iter = total_samples / total_batch_size
scheduler = cosine_scheduler(
optimizer, training_steps=max_iter, **lr_schedule_params
)
elif lr_schedule_method == "linear":
total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
total_batch_size = cfg_train.batch_size * self.trainer.world_size
max_iter = total_samples / total_batch_size
scheduler = linear_scheduler(
optimizer, training_steps=max_iter, **lr_schedule_params
)
else:
raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}")
return scheduler
def _get_optimizer(self):
opt_cfg = self.cfg.train.optimizer
method = opt_cfg.method.lower()
if method not in self.optimizer_types:
raise ValueError(f"Unknown optimizer method={method}")
kwargs = dict(opt_cfg.params)
kwargs["params"] = self.net.parameters()
optimizer = self.optimizer_types[method](**kwargs)
return optimizer
@rank_zero_only
@overrides
def on_fit_end(self):
hparam_dict = cfg_to_hparams(self.cfg, {})
metric_dict = {"metric/dummy": 0}
tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger)
if tb_logger:
tb_logger.log_hyperparams(hparam_dict, metric_dict)
@overrides
def training_epoch_end(self, training_step_outputs):
avg_loss = torch.tensor(0.0).to(self.device)
for step_out in training_step_outputs:
avg_loss += step_out["loss"]
log_dict = {"train_loss": avg_loss}
self._log_shell(log_dict, prefix="train ")
def _log_shell(self, log_info, prefix=""):
log_info_shell = {}
for k, v in log_info.items():
new_v = v
if type(new_v) is torch.Tensor:
new_v = new_v.item()
log_info_shell[k] = new_v
out_str = prefix.upper()
if prefix.upper().strip() in ["TRAIN", "VAL"]:
out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]"
if self.training:
lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"]
log_info_shell["lr"] = lr
for key, value in log_info_shell.items():
out_str += f" || {key}: {round(value, 5)}"
out_str += f" || time: {round(time.time() - self.time_tracker, 1)}"
out_str += " secs."
self.print(out_str)
self.time_tracker = time.time()

View File

@ -0,0 +1,390 @@
import numpy as np
import torch
import torch.utils.data
from overrides import overrides
from lightning_modules.classifier import ClassifierModule
from utils import get_class_names
class KVUClassifierModule(ClassifierModule):
def __init__(self, cfg):
super().__init__(cfg)
class_names = get_class_names(self.cfg.dataset_root_path)
self.window_size = cfg.train.max_num_words
self.slice_interval = cfg.train.slice_interval
self.eval_kwargs = {
"class_names": class_names,
"dummy_idx": self.cfg.train.max_seq_length, # update dummy_idx in next step
}
self.stage = cfg.stage
@overrides
def training_step(self, batch, batch_idx, *args):
if self.stage == 1:
_, loss = self.net(batch['windows'])
elif self.stage == 2:
_, loss = self.net(batch)
else:
raise ValueError(
f"Not supported stage: {self.stage}"
)
log_dict_input = {"train_loss": loss}
self.log_dict(log_dict_input, sync_dist=True)
return loss
@torch.no_grad()
@overrides
def validation_step(self, batch, batch_idx, *args):
if self.stage == 1:
step_out_total = {
"loss": 0,
"ee":{
"n_batch_gt": 0,
"n_batch_pr": 0,
"n_batch_correct": 0,
},
"el":{
"n_batch_gt": 0,
"n_batch_pr": 0,
"n_batch_correct": 0,
},
"el_from_key":{
"n_batch_gt": 0,
"n_batch_pr": 0,
"n_batch_correct": 0,
}}
for window in batch['windows']:
head_outputs, loss = self.net(window)
step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
for key in step_out_total:
if key == 'loss':
step_out_total[key] += step_out[key]
else:
for subkey in step_out_total[key]:
step_out_total[key][subkey] += step_out[key][subkey]
return step_out_total
elif self.stage == 2:
head_outputs, loss = self.net(batch)
# self.eval_kwargs['dummy_idx'] = batch['itc_labels'].shape[1]
# step_out = do_eval_step(batch, head_outputs, loss, self.eval_kwargs)
self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1]
step_out = do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs)
return step_out
@torch.no_grad()
@overrides
def validation_epoch_end(self, validation_step_outputs):
scores = do_eval_epoch_end(validation_step_outputs)
self.print(
f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}"
)
self.print(
f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}"
)
self.print(
f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}"
)
self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.)
tensorboard_logs = {'val_precision_ee': scores['ee']['precision'], 'val_recall_ee': scores['ee']['recall'], 'val_f1_ee': scores['ee']['f1'],
'val_precision_el': scores['el']['precision'], 'val_recall_el': scores['el']['recall'], 'val_f1_el': scores['el']['f1'],
'val_precision_el_from_key': scores['el_from_key']['precision'], 'val_recall_el_from_key': scores['el_from_key']['recall'], \
'val_f1_el_from_key': scores['el_from_key']['f1'],}
return {'log': tensorboard_logs}
def do_eval_step(batch, head_outputs, loss, eval_kwargs):
class_names = eval_kwargs["class_names"]
dummy_idx = eval_kwargs["dummy_idx"]
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_labels = torch.argmax(itc_outputs, -1)
pr_stc_labels = torch.argmax(stc_outputs, -1)
pr_el_labels = torch.argmax(el_outputs, -1)
pr_el_labels_from_key = torch.argmax(el_outputs_from_key, -1)
(
n_batch_gt_classes,
n_batch_pr_classes,
n_batch_correct_classes,
) = eval_ee_spade_batch(
pr_itc_labels,
batch["itc_labels"],
batch["are_box_first_tokens"],
pr_stc_labels,
batch["stc_labels"],
batch["attention_mask_layoutxlm"],
class_names,
dummy_idx,
)
n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch(
pr_el_labels,
batch["el_labels"],
batch["are_box_first_tokens"],
dummy_idx,
)
n_batch_gt_rel_from_key, n_batch_pr_rel_from_key, n_batch_correct_rel_from_key = eval_el_spade_batch(
pr_el_labels_from_key,
batch["el_labels_from_key"],
batch["are_box_first_tokens"],
dummy_idx,
)
step_out = {
"loss": loss,
"ee":{
"n_batch_gt": n_batch_gt_classes,
"n_batch_pr": n_batch_pr_classes,
"n_batch_correct": n_batch_correct_classes,
},
"el":{
"n_batch_gt": n_batch_gt_rel,
"n_batch_pr": n_batch_pr_rel,
"n_batch_correct": n_batch_correct_rel,
},
"el_from_key":{
"n_batch_gt": n_batch_gt_rel_from_key,
"n_batch_pr": n_batch_pr_rel_from_key,
"n_batch_correct": n_batch_correct_rel_from_key,
}
}
return step_out
def eval_ee_spade_batch(
pr_itc_labels,
gt_itc_labels,
are_box_first_tokens,
pr_stc_labels,
gt_stc_labels,
attention_mask,
class_names,
dummy_idx,
):
n_batch_gt_classes, n_batch_pr_classes, n_batch_correct_classes = 0, 0, 0
bsz = pr_itc_labels.shape[0]
for example_idx in range(bsz):
n_gt_classes, n_pr_classes, n_correct_classes = eval_ee_spade_example(
pr_itc_labels[example_idx],
gt_itc_labels[example_idx],
are_box_first_tokens[example_idx],
pr_stc_labels[example_idx],
gt_stc_labels[example_idx],
attention_mask[example_idx],
class_names,
dummy_idx,
)
n_batch_gt_classes += n_gt_classes
n_batch_pr_classes += n_pr_classes
n_batch_correct_classes += n_correct_classes
return (
n_batch_gt_classes,
n_batch_pr_classes,
n_batch_correct_classes,
)
def eval_ee_spade_example(
pr_itc_label,
gt_itc_label,
box_first_token_mask,
pr_stc_label,
gt_stc_label,
attention_mask,
class_names,
dummy_idx,
):
gt_first_words = parse_initial_words(
gt_itc_label, box_first_token_mask, class_names
)
gt_class_words = parse_subsequent_words(
gt_stc_label, attention_mask, gt_first_words, dummy_idx
)
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, dummy_idx
)
n_gt_classes, n_pr_classes, n_correct_classes = 0, 0, 0
for class_idx in range(len(class_names)):
# Evaluate by ID
gt_parse = set(gt_class_words[class_idx])
pr_parse = set(pr_class_words[class_idx])
n_gt_classes += len(gt_parse)
n_pr_classes += len(pr_parse)
n_correct_classes += len(gt_parse & pr_parse)
return n_gt_classes, n_pr_classes, n_correct_classes
def parse_initial_words(itc_label, box_first_token_mask, class_names):
itc_label_np = itc_label.cpu().numpy()
box_first_token_mask_np = box_first_token_mask.cpu().numpy()
outputs = [[] for _ in range(len(class_names))]
for token_idx, label in enumerate(itc_label_np):
if box_first_token_mask_np[token_idx] and label != 0:
outputs[label].append(token_idx)
return outputs
def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
max_connections = 50
valid_stc_label = stc_label * attention_mask.bool()
valid_stc_label = valid_stc_label.cpu().numpy()
stc_label_np = stc_label.cpu().numpy()
valid_token_indices = np.where(
(valid_stc_label != dummy_idx) * (valid_stc_label != 0)
)
next_token_idx_dict = {}
for token_idx in valid_token_indices[0]:
next_token_idx_dict[stc_label_np[token_idx]] = token_idx
outputs = []
for init_token_indices in init_words:
sub_outputs = []
for init_token_idx in init_token_indices:
cur_token_indices = [init_token_idx]
for _ in range(max_connections):
if cur_token_indices[-1] in next_token_idx_dict:
if (
next_token_idx_dict[cur_token_indices[-1]]
not in init_token_indices
):
cur_token_indices.append(
next_token_idx_dict[cur_token_indices[-1]]
)
else:
break
else:
break
sub_outputs.append(tuple(cur_token_indices))
outputs.append(sub_outputs)
return outputs
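# Illustrative sketch (example values assumed, not from the original source):
# parse_initial_words collects the first token of every field from the ITC labels,
# and parse_subsequent_words follows the STC labels (each storing the index of the
# previous token) to recover the rest of the field. With
# class_names = ['others', 'title', 'key', 'value', 'header'],
# itc_label = [0, 1, 0, 0] marks token 1 as the start of a 'title', and
# stc_label[2] = 1, stc_label[3] = 2 chain tokens 2 and 3 onto it, so the parsed
# result for the 'title' class is [(1, 2, 3)].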
def eval_el_spade_batch(
pr_el_labels,
gt_el_labels,
are_box_first_tokens,
dummy_idx,
):
n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = 0, 0, 0
bsz = pr_el_labels.shape[0]
for example_idx in range(bsz):
n_gt_rel, n_pr_rel, n_correct_rel = eval_el_spade_example(
pr_el_labels[example_idx],
gt_el_labels[example_idx],
are_box_first_tokens[example_idx],
dummy_idx,
)
n_batch_gt_rel += n_gt_rel
n_batch_pr_rel += n_pr_rel
n_batch_correct_rel += n_correct_rel
return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel
def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
gt_relations = parse_relations(gt_el_label, box_first_token_mask, dummy_idx)
pr_relations = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
gt_relations = set(gt_relations)
pr_relations = set(pr_relations)
n_gt_rel = len(gt_relations)
n_pr_rel = len(pr_relations)
n_correct_rel = len(gt_relations & pr_relations)
return n_gt_rel, n_pr_rel, n_correct_rel
def parse_relations(el_label, box_first_token_mask, dummy_idx):
valid_el_labels = el_label * box_first_token_mask
valid_el_labels = valid_el_labels.cpu().numpy()
el_label_np = el_label.cpu().numpy()
max_token = box_first_token_mask.shape[0] - 1
valid_token_indices = np.where(
((valid_el_labels != dummy_idx) * (valid_el_labels != 0)) ###
)
link_map_tuples = []
for token_idx in valid_token_indices[0]:
link_map_tuples.append((el_label_np[token_idx], token_idx))
return set(link_map_tuples)
def do_eval_epoch_end(step_outputs):
scores = {}
for task in ['ee', 'el', 'el_from_key']:
n_total_gt_classes, n_total_pr_classes, n_total_correct_classes = 0, 0, 0
for step_out in step_outputs:
n_total_gt_classes += step_out[task]["n_batch_gt"]
n_total_pr_classes += step_out[task]["n_batch_pr"]
n_total_correct_classes += step_out[task]["n_batch_correct"]
precision = (
0.0 if n_total_pr_classes == 0 else n_total_correct_classes / n_total_pr_classes
)
recall = (
0.0 if n_total_gt_classes == 0 else n_total_correct_classes / n_total_gt_classes
)
f1 = (
0.0
if recall * precision == 0
else 2.0 * recall * precision / (recall + precision)
)
scores[task] = {
"precision": precision,
"recall": recall,
"f1": f1,
}
return scores
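# Worked example (numbers assumed for illustration): if a task accumulates
# n_batch_gt = 10, n_batch_pr = 8 and n_batch_correct = 6 over all validation steps,
# the epoch-level scores are
#   precision = 6 / 8  = 0.75
#   recall    = 6 / 10 = 0.60
#   f1        = 2 * 0.75 * 0.60 / (0.75 + 0.60) ~= 0.667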
def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
class_names = get_class_names(dataset_root_path)
dummy_idx = max_seq_length
eval_kwargs = {"class_names": class_names, "dummy_idx": dummy_idx}
return eval_kwargs
def get_eval_kwargs_spade_rel(max_seq_length):
dummy_idx = max_seq_length
eval_kwargs = {"dummy_idx": dummy_idx}
return eval_kwargs

View File

@ -0,0 +1,135 @@
import time
import torch
import pytorch_lightning as pl
from overrides import overrides
from torch.utils.data.dataloader import DataLoader
from lightning_modules.data_modules.kvu_dataset import KVUDataset, KVUEmbeddingDataset
from lightning_modules.utils import _get_number_samples
class KVUDataModule(pl.LightningDataModule):
def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor):
super().__init__()
self.cfg = cfg
self.train_loader = None
self.val_loader = None
self.tokenizer_layoutxlm = tokenizer_layoutxlm
self.feature_extractor = feature_extractor
self.collate_fn = None
self.backbone_type = self.cfg.model.backbone
self.num_samples_per_epoch = _get_number_samples(cfg.dataset_root_path)
@overrides
def setup(self, stage=None):
self.train_loader = self._get_train_loader()
self.val_loader = self._get_val_loaders()
@overrides
def train_dataloader(self):
return self.train_loader
@overrides
def val_dataloader(self):
return self.val_loader
def _get_train_loader(self):
start_time = time.time()
if self.cfg.stage == 1:
dataset = KVUDataset(
self.cfg,
self.tokenizer_layoutxlm,
self.feature_extractor,
mode="train",
)
elif self.cfg.stage == 2:
# dataset = KVUEmbeddingDataset(
# self.cfg,
# mode="train",
# )
dataset = KVUDataset(
self.cfg,
self.tokenizer_layoutxlm,
self.feature_extractor,
mode="train",
)
else:
raise ValueError(
f"Not supported stage: {self.cfg.stage}"
)
print('No. training samples:', len(dataset))
data_loader = DataLoader(
dataset,
batch_size=self.cfg.train.batch_size,
shuffle=True,
num_workers=self.cfg.train.num_workers,
pin_memory=True,
)
elapsed_time = time.time() - start_time
print(f"Elapsed time for loading training data: {elapsed_time}")
return data_loader
def _get_val_loaders(self):
if self.cfg.stage == 1:
dataset = KVUDataset(
self.cfg,
self.tokenizer_layoutxlm,
self.feature_extractor,
mode="val",
)
elif self.cfg.stage == 2:
# dataset = KVUEmbeddingDataset(
# self.cfg,
# mode="val",
# )
dataset = KVUDataset(
self.cfg,
self.tokenizer_layoutxlm,
self.feature_extractor,
mode="val",
)
else:
raise ValueError(
f"Not supported stage: {self.cfg.stage}"
)
print('No. validation samples:', len(dataset))
data_loader = DataLoader(
dataset,
batch_size=self.cfg.val.batch_size,
shuffle=False,
num_workers=self.cfg.val.num_workers,
pin_memory=True,
drop_last=False,
)
return data_loader
@overrides
def transfer_batch_to_device(self, batch, device, dataloader_idx):
if isinstance(batch, list):
for sub_batch in batch:
for k in sub_batch.keys():
if isinstance(sub_batch[k], torch.Tensor):
sub_batch[k] = sub_batch[k].to(device)
else:
for k in batch.keys():
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device)
return batch

View File

@ -0,0 +1,728 @@
import os
import json
import omegaconf
import itertools
import copy
import numpy as np
from PIL import Image
import torch
from torch.utils.data.dataset import Dataset
from utils import get_class_names
from lightning_modules.utils import sliding_windows_by_words
class KVUDataset(Dataset):
def __init__(
self,
cfg,
tokenizer_layoutxlm,
feature_extractor,
mode=None,
):
super(KVUDataset, self).__init__()
self.dataset_root_path = cfg.dataset_root_path
if not isinstance(self.dataset_root_path, omegaconf.listconfig.ListConfig):
self.dataset_root_path = [self.dataset_root_path]
self.backbone_type = cfg.model.backbone
self.max_seq_length = cfg.train.max_seq_length
self.window_size = cfg.train.max_num_words
self.slice_interval = cfg.train.slice_interval
self.tokenizer_layoutxlm = tokenizer_layoutxlm
self.feature_extractor = feature_extractor
self.stage = cfg.stage
self.mode = mode
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._pad_token)
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._cls_token)
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._sep_token)
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)
self.examples = self._load_examples()
self.class_names = get_class_names(self.dataset_root_path)
self.class_idx_dic = dict(
[(class_name, idx) for idx, class_name in enumerate(self.class_names)]
)
def _load_examples(self):
examples = []
for dataset_dir in self.dataset_root_path:
with open(
os.path.join(dataset_dir, f"preprocessed_files_{self.mode}.txt"),
"r",
encoding="utf-8",
) as fp:
for line in fp.readlines():
preprocessed_file = os.path.join(dataset_dir, line.strip())
examples.append(
json.load(open(preprocessed_file, "r", encoding="utf-8"))
)
return examples
def __len__(self):
return len(self.examples)
def __getitem__(self, index):
json_obj = self.examples[index]
width = json_obj["meta"]["imageSize"]["width"]
height = json_obj["meta"]["imageSize"]["height"]
img_path = json_obj["meta"]["image_path"]
images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
word_windows, parse_class_windows, parse_relation_windows = sliding_windows_by_words(
json_obj["words"],
json_obj['parse']['class'],
json_obj['parse']['relations'],
self.window_size, self.slice_interval)
outputs = {}
if self.stage == 1:
if self.mode == 'train':
i = np.random.randint(0, len(word_windows), 1)[0]
outputs['windows'] = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i],
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
else:
outputs['windows'] = []
for i in range(len(word_windows)):
single_window = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i],
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
outputs['windows'].append(single_window)
elif self.stage == 2:
outputs['documents'] = self.preprocess(json_obj["words"], json_obj['parse']['class'], json_obj['parse']['relations'],
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=2048)
windows = []
for i in range(len(word_windows)):
_words = word_windows[i]
_parse_class = parse_class_windows[i]
_parse_relation = parse_relation_windows[i]
windows.append(
self.preprocess(_words, _parse_class, _parse_relation,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
self.max_seq_length)
)
outputs['windows'] = windows
return outputs
def preprocess(self, words, parse_class, parse_relation, feature_maps, max_seq_length):
input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int)
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
itc_labels = np.zeros(max_seq_length, dtype=int)
stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length
el_labels = np.ones((max_seq_length,), dtype=int) * max_seq_length
el_labels_from_key = np.ones((max_seq_length,), dtype=int) * max_seq_length
list_layoutxlm_tokens = []
list_bbs = []
box2token_span_map = []
box_to_token_indices = []
cum_token_idx = 0
cls_bbs = [0.0] * 8
len_overlap_tokens = 0
len_non_overlap_tokens = 0
len_valid_tokens = 0
for word_idx, word in enumerate(words):
this_box_token_indices = []
layoutxlm_tokens = word["layoutxlm_tokens"]
bb = word["boundingBox"]
len_valid_tokens += len(layoutxlm_tokens)
if word_idx < self.slice_interval:
len_non_overlap_tokens += len(layoutxlm_tokens)
# print(word_idx, layoutxlm_tokens, non_overlap_tokens)
if len(layoutxlm_tokens) == 0:
layoutxlm_tokens.append(self.unk_token_id_layoutxlm)
if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2:
break
box2token_span_map.append(
[len(list_layoutxlm_tokens) + 1, len(list_layoutxlm_tokens) + len(layoutxlm_tokens) + 1]
) # including st_idx
list_layoutxlm_tokens += layoutxlm_tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(layoutxlm_tokens))]
for _ in layoutxlm_tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [feature_maps['width'], feature_maps['height']] * 4
# For [CLS] and [SEP]
list_layoutxlm_tokens = (
[self.cls_token_id_layoutxlm]
+ list_layoutxlm_tokens[: max_seq_length - 2]
+ [self.sep_token_id_layoutxlm]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
len_list_layoutxlm_tokens = len(list_layoutxlm_tokens)
input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens
attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1
bbox[:len_list_layoutxlm_tokens, :] = list_bbs
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']
if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
else:
assert False
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < max_seq_length
]
are_box_first_tokens[st_indices] = True
# Label for entity extraction
classes_dic = parse_class
for class_name in self.class_names:
if class_name == "others":
continue
if class_name not in classes_dic:
continue
for word_list in classes_dic[class_name]:
is_first, last_word_idx = True, -1
for word_idx in word_list:
if word_idx >= len(box_to_token_indices):
break
box2token_list = box_to_token_indices[word_idx]
for converted_word_idx in box2token_list:
if converted_word_idx >= max_seq_length:
break # out of idx
if is_first:
itc_labels[converted_word_idx] = self.class_idx_dic[
class_name
]
is_first, last_word_idx = False, converted_word_idx
else:
stc_labels[converted_word_idx] = last_word_idx
last_word_idx = converted_word_idx
# Label for entity linking
relations = parse_relation
for relation in relations:
if relation[0] >= len(box2token_span_map) or relation[1] >= len(
box2token_span_map
):
continue
if (
box2token_span_map[relation[0]][0] >= max_seq_length
or box2token_span_map[relation[1]][0] >= max_seq_length
):
continue
word_from = box2token_span_map[relation[0]][0]
word_to = box2token_span_map[relation[1]][0]
# el_labels[word_to] = word_from
#### 1st relation => ['key', 'value']
#### 2nd relation => ['header', 'key' or 'value']
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
el_labels_from_key[word_to] = word_from # pair of (key-value)
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
assert len_list_layoutxlm_tokens == len_valid_tokens + 2
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
# overlap_tokens = max_seq_length - non_overlap_tokens
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
# ntokens = max_seq_length
input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm[:ntokens])
attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm[:ntokens])
bbox = torch.from_numpy(bbox[:ntokens])
are_box_first_tokens = torch.from_numpy(are_box_first_tokens[:ntokens])
itc_labels = itc_labels[:ntokens]
stc_labels = stc_labels[:ntokens]
el_labels = el_labels[:ntokens]
el_labels_from_key = el_labels_from_key[:ntokens]
itc_labels = np.where(itc_labels != max_seq_length, itc_labels, ntokens)
stc_labels = np.where(stc_labels != max_seq_length, stc_labels, ntokens)
el_labels = np.where(el_labels != max_seq_length, el_labels, ntokens)
el_labels_from_key = np.where(el_labels_from_key != max_seq_length, el_labels_from_key, ntokens)
itc_labels = torch.from_numpy(itc_labels)
stc_labels = torch.from_numpy(stc_labels)
el_labels = torch.from_numpy(el_labels)
el_labels_from_key = torch.from_numpy(el_labels_from_key)
return_dict = {
"img_path": feature_maps['img_path'],
"len_overlap_tokens": len_overlap_tokens,
'len_valid_tokens': len_valid_tokens,
"image": feature_maps['image'],
"input_ids_layoutxlm": input_ids_layoutxlm,
"attention_mask_layoutxlm": attention_mask_layoutxlm,
"bbox": bbox,
"itc_labels": itc_labels,
"are_box_first_tokens": are_box_first_tokens,
"stc_labels": stc_labels,
"el_labels": el_labels,
"el_labels_from_key": el_labels_from_key,
}
return return_dict
class KVUPredefinedDataset(KVUDataset):
def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor, mode=None):
super().__init__(cfg, tokenizer_layoutxlm, feature_extractor, mode)
self.max_windows = cfg.train.max_windows
def __getitem__(self, index):
json_obj = self.examples[index]
width = json_obj["meta"]["imageSize"]["width"]
height = json_obj["meta"]["imageSize"]["height"]
img_path = json_obj["meta"]["image_path"]
images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
word_windows, parse_class_windows, parse_relation_windows = sliding_windows_by_words(
json_obj["words"],
json_obj['parse']['class'],
json_obj['parse']['relations'],
self.window_size, self.slice_interval)
word_windows = word_windows[: self.max_windows] if len(word_windows) >= self.max_windows else word_windows + [[]] * (self.max_windows - len(word_windows))
parse_class_windows = parse_class_windows[: self.max_windows] if len(parse_class_windows) >= self.max_windows else parse_class_windows + [[]] * (self.max_windows - len(parse_class_windows))
parse_relation_windows = parse_relation_windows[: self.max_windows] if len(parse_relation_windows) >= self.max_windows else parse_relation_windows + [[]] * (self.max_windows - len(parse_relation_windows))
outputs = {}
# outputs['labels'] = self.preprocess(json_obj["words"], json_obj['parse']['class'], json_obj['parse']['relations'],
# {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
# max_seq_length=self.max_seq_length*self.max_windows)
outputs['windows'] = []
for i in range(self.max_windows):
single_window = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i],
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
outputs['windows'].append(single_window)
outputs['windows'] = torch.cat(outputs["windows"], dim=0)
return outputs
class KVUEmbeddingDataset(Dataset):
def __init__(
self,
cfg,
mode=None,
):
super(KVUEmbeddingDataset, self).__init__()
self.dataset_root_path = cfg.dataset_root_path
if not isinstance(self.dataset_root_path, omegaconf.listconfig.ListConfig):
self.dataset_root_path = [self.dataset_root_path]
self.stage = cfg.stage
self.mode = mode
self.examples = self._load_examples()
def _load_examples(self):
examples = []
for dataset_dir in self.dataset_root_path:
with open(
os.path.join(dataset_dir, f"preprocessed_files_{self.mode}.txt"),
"r",
encoding="utf-8",
) as fp:
for line in fp.readlines():
preprocessed_file = os.path.join(dataset_dir, line.strip())
examples.append(
preprocessed_file.replace('preprocessed', 'embedding_matrix').replace('.json', '.npz')
)
return examples
def __len__(self):
return len(self.examples)
def __getitem__(self, index):
input_embeddings = np.load(self.examples[index])
return_dict = {
"embeddings": torch.from_numpy(input_embeddings["embeddings"]).type(torch.HalfTensor),
"attention_mask_layoutxlm": torch.from_numpy(input_embeddings["attention_mask_layoutxlm"]),
"are_box_first_tokens": torch.from_numpy(input_embeddings["are_box_first_tokens"]),
"bbox": torch.from_numpy(input_embeddings["bbox"]),
"itc_labels": torch.from_numpy(input_embeddings["itc_labels"]),
"stc_labels": torch.from_numpy(input_embeddings["stc_labels"]),
"el_labels": torch.from_numpy(input_embeddings["el_labels"]),
"el_labels_from_key": torch.from_numpy(input_embeddings["el_labels_from_key"]),
}
return return_dict
class DocumentKVUDataset(KVUDataset):
def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor, mode=None):
super().__init__(cfg, tokenizer_layoutxlm, feature_extractor, mode)
self.max_window_count = cfg.train.max_window_count
def __getitem__(self, idx):
json_obj = self.examples[idx]
width = json_obj["meta"]["imageSize"]["width"]
height = json_obj["meta"]["imageSize"]["height"]
images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")]
feature_maps = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
n_words = len(json_obj['words'])
output_dicts = {'windows': [], 'documents': []}
box_to_token_indices_document = []
box2token_span_map_document = []
n_empty_windows = 0
for i in range(self.max_window_count):
input_ids = np.ones(self.max_seq_length, dtype=int) * 1
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(self.max_seq_length, dtype=int)
itc_labels = np.zeros(self.max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
# stc_labels stores the index of the previous token.
# A stored index of max_seq_length (512) indicates that
# this token is the initial token of a word box.
stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
el_labels = np.ones((self.max_seq_length,), dtype=int) * self.max_seq_length
el_labels_from_key = np.ones((self.max_seq_length,), dtype=int) * self.max_seq_length
start_word_idx = i * self.window_size
stop_word_idx = min(n_words, (i+1)*self.window_size)
if start_word_idx >= stop_word_idx:
n_empty_windows += 1
output_dicts['windows'].append(output_dicts['windows'][-1])
box_to_token_indices_to_mod = copy.deepcopy(box_to_token_indices)
for i_box in range(len(box_to_token_indices_to_mod)):
for j in range(len(box_to_token_indices_to_mod[i_box])):
box_to_token_indices_to_mod[i_box][j] += i * self.max_seq_length
for element in box_to_token_indices_to_mod:
box_to_token_indices_document.append(element)
box2token_span_map_to_mod = copy.deepcopy(box2token_span_map)
for i_box in range(len(box2token_span_map_to_mod)):
for j in range(len(box2token_span_map_to_mod[i_box])):
box2token_span_map_to_mod[i_box][j] += i * self.max_seq_length
for element in box2token_span_map_to_mod:
box2token_span_map_document.append(element)
continue
list_tokens = []
list_bbs = []
box2token_span_map = []
box_to_token_indices = []
cum_token_idx = 0
cls_bbs = [0.0] * 8
# Parse words
for word_idx, word in enumerate(json_obj["words"][start_word_idx:stop_word_idx]):
this_box_token_indices = []
tokens = word["tokens"]
bb = word["boundingBox"]
if len(tokens) == 0:
tokens.append(self.unk_token_id)
if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
break # a sliding window could be applied here instead of truncating
box2token_span_map.append(
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
) # including st_idx
list_tokens += tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(tokens))]
for _ in tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [width, height] * 4
# For [CLS] and [SEP]
list_tokens = (
[self.cls_token_id]
+ list_tokens[: self.max_seq_length - 2]
+ [self.sep_token_id]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
len_list_tokens = len(list_tokens)
input_ids[:len_list_tokens] = list_tokens
attention_mask[:len_list_tokens] = 1
bbox[:len_list_tokens, :] = list_bbs
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
if self.backbone_type in ('layoutlm', 'layoutxlm', 'xlm-roberta'):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
else:
assert False
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < self.max_seq_length
]
are_box_first_tokens[st_indices] = True
# Parse word groups
classes_dic = json_obj["parse"]["class"]
for class_name in self.class_names:
if class_name == "others":
continue
if class_name not in classes_dic:
continue
for word_list in classes_dic[class_name]:
word_list = [w for w in word_list if w >= start_word_idx and w < stop_word_idx]
if len(word_list) == 0:
continue # no more word left
word_list = [w - start_word_idx for w in word_list]
is_first, last_word_idx = True, -1
for word_idx in word_list:
if word_idx >= len(box_to_token_indices):
break
box2token_list = box_to_token_indices[word_idx]
for converted_word_idx in box2token_list:
if converted_word_idx >= self.max_seq_length:
break # out of idx
if is_first:
itc_labels[converted_word_idx] = self.class_idx_dic[
class_name
]
is_first, last_word_idx = False, converted_word_idx
else:
stc_labels[converted_word_idx] = last_word_idx
last_word_idx = converted_word_idx
# Parse relation
relations = json_obj["parse"]["relations"]
for relation in relations:
relation = [r for r in relation if r >= start_word_idx and r < stop_word_idx]
if len(relation) != 2:
continue # relation dropped: one endpoint falls outside this window
relation[0] -= start_word_idx
relation[1] -= start_word_idx
word_from = box2token_span_map[relation[0]][0]
word_to = box2token_span_map[relation[1]][0]
#### 1st relation => ['key', 'value']
#### 2nd relation => ['header', 'key' or 'value']
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
el_labels_from_key[word_to] = word_from # pair of (key-value)
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
input_ids = torch.from_numpy(input_ids)
bbox = torch.from_numpy(bbox)
attention_mask = torch.from_numpy(attention_mask)
itc_labels = torch.from_numpy(itc_labels)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
stc_labels = torch.from_numpy(stc_labels)
el_labels = torch.from_numpy(el_labels)
el_labels_from_key = torch.from_numpy(el_labels_from_key)
box_to_token_indices_to_mod = copy.deepcopy(box_to_token_indices)
for i_box in range(len(box_to_token_indices_to_mod)):
for j in range(len(box_to_token_indices_to_mod[i_box])):
box_to_token_indices_to_mod[i_box][j] += i * self.max_seq_length
for element in box_to_token_indices_to_mod:
box_to_token_indices_document.append(element)
box2token_span_map_to_mod = copy.deepcopy(box2token_span_map)
for i_box in range(len(box2token_span_map_to_mod)):
for j in range(len(box2token_span_map_to_mod[i_box])):
box2token_span_map_to_mod[i_box][j] += i * self.max_seq_length
for element in box2token_span_map_to_mod:
box2token_span_map_document.append(element)
return_dict = {
"image": feature_maps,
"input_ids": input_ids,
"bbox": bbox,
"attention_mask": attention_mask,
"itc_labels": itc_labels,
"are_box_first_tokens": are_box_first_tokens,
"stc_labels": stc_labels,
"el_labels": el_labels,
"el_labels_from_key": el_labels_from_key,
}
output_dicts["windows"].append(return_dict)
# Parse whole document labels
attention_mask = torch.cat([o['attention_mask'] for o in output_dicts['windows']])
are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts['windows']])
if n_empty_windows > 0:
attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
bbox = torch.cat([o['bbox'] for o in output_dicts['windows']])
self.max_seq_length_document = self.max_seq_length * self.max_window_count
itc_labels = np.zeros(self.max_seq_length_document, dtype=int)
stc_labels = np.ones(self.max_seq_length_document, dtype=np.int64) * self.max_seq_length_document
el_labels = np.ones((self.max_seq_length_document,), dtype=int) * self.max_seq_length_document
el_labels_from_key = np.ones((self.max_seq_length_document,), dtype=int) * self.max_seq_length_document
# Parse word groups
classes_dic = json_obj["parse"]["class"]
for class_name in self.class_names:
if class_name == "others":
continue
if class_name not in classes_dic:
continue
word_lists = classes_dic[class_name]
for word_list in word_lists:
is_first, last_word_idx = True, -1
for word_idx in word_list:
if word_idx >= len(box_to_token_indices_document):
break
box2token_list = box_to_token_indices_document[word_idx]
for converted_word_idx in box2token_list:
if converted_word_idx >= self.max_seq_length_document:
break # out of idx
if is_first:
itc_labels[converted_word_idx] = self.class_idx_dic[
class_name
]
is_first, last_word_idx = False, converted_word_idx
else:
stc_labels[converted_word_idx] = last_word_idx
last_word_idx = converted_word_idx
# Parse relation
relations = json_obj["parse"]["relations"]
for relation in relations:
if relation[0] >= len(box2token_span_map_document) or relation[1] >= len(
box2token_span_map_document
):
continue
if (
box2token_span_map_document[relation[0]][0] >= self.max_seq_length_document
or box2token_span_map_document[relation[1]][0] >= self.max_seq_length_document
):
continue
word_from = box2token_span_map_document[relation[0]][0]
word_to = box2token_span_map_document[relation[1]][0]
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
el_labels_from_key[word_to] = word_from # pair of (key-value)
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
itc_labels = torch.from_numpy(itc_labels)
stc_labels = torch.from_numpy(stc_labels)
el_labels = torch.from_numpy(el_labels)
el_labels_from_key = torch.from_numpy(el_labels_from_key)
return_dict = {
"img_path": json_obj["meta"]["image_path"],
"attention_mask": attention_mask,
"bbox": bbox,
"itc_labels": itc_labels,
"are_box_first_tokens": are_box_first_tokens,
"stc_labels": stc_labels,
"el_labels": el_labels,
"el_labels_from_key": el_labels_from_key,
"n_empty_windows": n_empty_windows
}
output_dicts['documents'] = return_dict
return output_dicts

View File

@ -0,0 +1,53 @@
"""
BROS
Copyright 2022-present NAVER Corp.
Apache License v2.0
"""
import math
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1):
"""linear_scheduler with warmup from huggingface"""
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return max(
0.0,
float(training_steps - current_step)
/ float(max(1, training_steps - warmup_steps)),
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def cosine_scheduler(
optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1
):
"""Cosine LR scheduler with warmup from huggingface"""
def lr_lambda(current_step):
if current_step < warmup_steps:
return current_step / max(1, warmup_steps)
progress = current_step - warmup_steps
progress /= max(1, training_steps - warmup_steps)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1):
def lr_lambda(current_step):
if current_step < warmup_steps:
# calculate a warmup ratio
return current_step / max(1, warmup_steps)
else:
# calculate a multistep lr scaling ratio
idx = np.searchsorted(milestones, current_step)
return gamma ** idx
return LambdaLR(optimizer, lr_lambda, last_epoch)
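# Usage sketch (illustrative only, hyper-parameters assumed): each factory above
# wraps the warmup/decay rule in a LambdaLR, so it is stepped like any other
# PyTorch scheduler:
#
#   import torch
#   model = torch.nn.Linear(4, 2)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = cosine_scheduler(optimizer, warmup_steps=100, training_steps=1000)
#   for _ in range(1000):
#       optimizer.step()
#       scheduler.step()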

View File

@ -0,0 +1,218 @@
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import math
def _get_number_samples(dataset_root_path):
n_samples = 0
for dataset_dir in dataset_root_path:
with open(
os.path.join(dataset_dir, f"preprocessed_files_train.txt"),
"r",
encoding="utf-8",
) as fp:
for line in fp.readlines():
n_samples += 1
return n_samples
def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
element_windows = []
if len(elements) > window_size:
max_step = math.ceil((len(elements) - window_size)/slice_interval)
for i in range(0, max_step + 1):
# element_windows.append(copy.deepcopy(elements[min(i, len(elements) - window_size): min(i+window_size, len(elements))]))
if (i*slice_interval+window_size) >= len(elements):
_window = copy.deepcopy(elements[i*slice_interval:])
else:
_window = copy.deepcopy(elements[i*slice_interval: i*slice_interval+window_size])
element_windows.append(_window)
return element_windows
else:
return [elements]
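# Illustrative example (values assumed): with window_size=3 and slice_interval=2,
# a 7-element list is split into overlapping windows:
#   sliding_windows([0, 1, 2, 3, 4, 5, 6], window_size=3, slice_interval=2)
#   -> [[0, 1, 2], [2, 3, 4], [4, 5, 6]]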
def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> list:
word_windows = []
parse_class_windows = []
parse_relation_windows = []
if len(lwords) > window_size:
max_step = math.ceil((len(lwords) - window_size)/slice_interval)
for i in range(0, max_step+1):
# _word_window = copy.deepcopy(lwords[min(i*slice_interval, len(lwords) - window_size): min(i*slice_interval+window_size, len(lwords))])
if (i*slice_interval+window_size) >= len(lwords):
_word_window = copy.deepcopy(lwords[i*slice_interval:])
else:
_word_window = copy.deepcopy(lwords[i*slice_interval: i*slice_interval+window_size])
if len(_word_window) < 2:
continue
first_word_id = _word_window[0]['word_id']
last_word_id = _word_window[-1]['word_id']
# assert (last_word_id - first_word_id == window_size - 1) or (first_word_id == 0 and last_word_id == len(lwords) - 1), [v['word_id'] for v in _word_window] #(last_word_id,first_word_id,len(lwords))
# word list
for _word in _word_window:
_word['word_id'] -= first_word_id
# Entity extraction
_class_window = {k: [] for k in list(parse_class.keys())}
for class_name, _parse_class in parse_class.items():
for group in _parse_class:
tmp = []
for idw in group:
idw -= first_word_id
if 0 <= idw <= (last_word_id - first_word_id):
tmp.append(idw)
_class_window[class_name].append(tmp)
# Entity Linking
_relation_window = []
for pair in parse_relation:
if all([0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair]):
_relation_window.append([idw - first_word_id for idw in pair])
word_windows.append(_word_window)
parse_class_windows.append(_class_window)
parse_relation_windows.append(_relation_window)
return word_windows, parse_class_windows, parse_relation_windows
else:
return [lwords], [parse_class], [parse_relation]
def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
start_pos = 1
end_pos = start_pos + lvalids[0]
embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])
for i in range(1, len(lpatches)):
start_pos = 1
end_pos = start_pos + lvalids[i]
overlap_gap = copy.deepcopy(loverlaps[i-1])
window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])
if overlap_gap != 0:
prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
if average:
avg_overlap = (
prev_overlap + curr_overlap
) / 2.
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens, window], dim=1
)
return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
def merged_token_embeddings2(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
start_pos = 1
end_pos = start_pos + lvalids[0]
embedding_tokens = lpatches[0][:, start_pos:end_pos, ...]
cls_token = lpatches[0][:, :1, ...]
sep_token = lpatches[0][:, -1:, ...]
for i in range(1, len(lpatches)):
start_pos = 1
end_pos = start_pos + lvalids[i]
overlap_gap = loverlaps[i-1]
window = lpatches[i][:, start_pos:end_pos, ...]
if overlap_gap != 0:
prev_overlap = embedding_tokens[:, -overlap_gap:, ...]
curr_overlap = window[:, :overlap_gap, ...]
assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
if average:
avg_overlap = (
prev_overlap + curr_overlap
) / 2.
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens, window], dim=1
)
return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
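# Illustrative sketch (shapes assumed): both merge helpers stitch per-window token
# embeddings of shape (1, seq_len, hidden) back into one document-level sequence.
# With two windows of 5 valid tokens each (lvalids = [5, 5]) and an overlap of 2
# tokens (loverlaps = [2]), the merged body keeps 5 + (5 - 2) = 8 token embeddings,
# framed by the [CLS] and [SEP] embeddings taken from the first window. With
# average=True the embeddings in the overlapped region are averaged; otherwise one
# window's copy of the overlap is kept as-is.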
# def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
# # start_pos = 1
# # end_pos = start_pos + lvalids[0]
# # embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
# embedding_tokens = np.zeros((1, 2046, 768))
# cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
# sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])
# token_apperance_count = np.zeros((2046,))
# start_idx = 1
# for loverlap, lvalid in zip(loverlaps, lvalids):
# token_apperance_count[start_idx:start_idx+lvalid] += 1
# start_idx = start_idx + lvalid - loverlap
# embedding_matrix_spos = 0
# for i in range(0, len(lpatches)):
# embedding_matrix_epos = embedding_matrix_spos + int(lvalids[i])
# # assert embedding_matrix_epos - embedding_matrix_spos == end_pos - 1, (embedding_matrix_spos, embedding_matrix_epos, lvalid[i], end_pos)
# overlap_gap = copy.deepcopy(loverlaps[i]).cpu().numpy()
# window = copy.deepcopy(lpatches[i][:, 1:int(lvalids[i])+1, ...]).cpu().numpy()
# # embedding_tokens[:,embedding_matrix_spos:embedding_matrix_epos:] = window * token_apperance_count[embedding_matrix_spos:embedding_matrix_epos]
# for i in range(len(window)):
# window[:,i,:] *= token_apperance_count[embedding_matrix_spos + i]
# embedding_tokens[: ,int(embedding_matrix_spos): int(embedding_matrix_epos), :] += window
# embedding_matrix_spos -= overlap_gap
# embedding_tokens = torch.tensor(embedding_tokens).type(torch.HalfTensor).cuda()
# # if overlap_gap != 0:
# # prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
# # curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
# # assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
# # if average:
# # avg_overlap = (
# # prev_overlap + curr_overlap
# # ) / 2.
# # embedding_tokens = torch.cat(
# # [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
# # )
# # else:
# # embedding_tokens = torch.cat(
# # [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
# # )
# # else:
# # embedding_tokens = torch.cat(
# # [embedding_tokens, window], dim=1
# # )
# return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)

View File

@ -0,0 +1,15 @@
from model.combined_model import CombinedKVUModel
from model.kvu_model import KVUModel
from model.document_kvu_model import DocumentKVUModel
def get_model(cfg):
if cfg.stage == 1:
model = CombinedKVUModel(cfg=cfg)
elif cfg.stage == 2:
model = KVUModel(cfg=cfg)
elif cfg.stage == 3:
model = DocumentKVUModel(cfg=cfg)
else:
raise AssertionError('[ERROR] Training stage is wrong')
return model

View File

@ -0,0 +1,224 @@
import os
import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from transformers import AutoTokenizer, XLMRobertaModel
from model.relation_extractor import RelationExtractor
from utils import load_checkpoint
class CombinedKVUModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.model_cfg = cfg.model
self.freeze = cfg.train.freeze
self.finetune_only = cfg.train.finetune_only
self._get_backbones(self.model_cfg.backbone)
self._create_head()
if os.path.exists(self.model_cfg.ckpt_model_file):
self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer')
self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer')
self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer')
self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key')
self.loss_func = nn.CrossEntropyLoss()
if self.freeze:
for name, param in self.named_parameters():
if 'backbone' in name:
param.requires_grad = False
if self.finetune_only == 'EE':
for name, param in self.named_parameters():
if 'itc_layer' not in name and 'stc_layer' not in name:
param.requires_grad = False
if self.finetune_only == 'EL':
for name, param in self.named_parameters():
if 'relation_layer' not in name or 'relation_layer_from_key' in name:
param.requires_grad = False
if self.finetune_only == 'ELK':
for name, param in self.named_parameters():
if 'relation_layer_from_key' not in name:
param.requires_grad = False
def _create_head(self):
self.backbone_hidden_size = 768
self.head_hidden_size = self.model_cfg.head_hidden_size
self.head_p_dropout = self.model_cfg.head_p_dropout
self.n_classes = self.model_cfg.n_classes + 1
# (1) Initial token classification
self.itc_layer = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
self.itc_layer.apply(self._init_weight)
self.stc_layer.apply(self._init_weight)
self.relation_layer.apply(self._init_weight)
def _get_backbones(self, config_type):
self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')
@staticmethod
def _init_weight(module):
init_std = 0.02
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, 0.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
elif isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight, 1.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
def forward(self, batch):
image = batch["image"]
input_ids_layoutxlm = batch["input_ids_layoutxlm"]
bbox = batch["bbox"]
attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]
backbone_outputs_layoutxlm = self.backbone_layoutxlm(
image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)
last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
"el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}
loss = 0.0
if any(['labels' in key for key in batch.keys()]):
loss = self._get_loss(head_outputs, batch)
return head_outputs, loss
def _get_loss(self, head_outputs, batch):
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
itc_loss = self._get_itc_loss(itc_outputs, batch)
stc_loss = self._get_stc_loss(stc_outputs, batch)
el_loss = self._get_el_loss(el_outputs, batch)
el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)
loss = itc_loss + stc_loss + el_loss + el_loss_from_key
return loss
def _get_itc_loss(self, itc_outputs, batch):
itc_mask = batch["are_box_first_tokens"].view(-1)
itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
itc_logits = itc_logits[itc_mask]
itc_labels = batch["itc_labels"].view(-1)
itc_labels = itc_labels[itc_mask]
itc_loss = self.loss_func(itc_logits, itc_labels)
return itc_loss
def _get_stc_loss(self, stc_outputs, batch):
inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]
bsz, max_seq_length = inv_attention_mask.shape
device = inv_attention_mask.device
invalid_token_mask = torch.cat(
[inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
).bool()
stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()
stc_logits = stc_outputs.view(-1, max_seq_length + 1)
stc_logits = stc_logits[stc_mask]
stc_labels = batch["stc_labels"].view(-1)
stc_labels = stc_labels[stc_mask]
stc_loss = self.loss_func(stc_logits, stc_labels)
return stc_loss
def _get_el_loss(self, el_outputs, batch, from_key=False):
bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
device = batch["attention_mask_layoutxlm"].device
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
box_first_token_mask = torch.cat(
[
(batch["are_box_first_tokens"] == False),
torch.zeros([bsz, 1], dtype=torch.bool).to(device),
],
axis=1,
)
el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
mask = batch["are_box_first_tokens"].view(-1)
logits = el_outputs.view(-1, max_seq_length + 1)
logits = logits[mask]
if from_key:
el_labels = batch["el_labels_from_key"]
else:
el_labels = batch["el_labels"]
labels = el_labels.view(-1)
labels = labels[mask]
loss = self.loss_func(logits, labels)
return loss
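# --- Illustrative note (added for this write-up, not part of the original model code) ---
# A minimal sketch of the masking trick used in _get_stc_loss / _get_el_loss above,
# assuming max_seq_length = 4. torch.eye(L, L + 1) marks position (i, i) so a token
# can never point to itself; the extra column is the dummy "no relation" node that
# RelationExtractor appends to the keys, and it is left unmasked.
#
#   import torch
#   self_token_mask = torch.eye(4, 5).bool()
#   scores = torch.zeros(1, 4, 5)                        # (batch, seq, seq + 1)
#   scores.masked_fill_(self_token_mask[None, :, :], -10000.0)
#   # the diagonal of every (seq, seq + 1) slice is now -10000; the dummy column is untouched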

View File

@ -0,0 +1,285 @@
import torch
from torch import nn
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2Config, LayoutLMv2Model
from transformers import LayoutXLMTokenizer
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel
from model.relation_extractor import RelationExtractor
from utils import load_checkpoint
class DocumentKVUModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.model_cfg = cfg.model
self.freeze = cfg.train.freeze
self.train_cfg = cfg.train
self._get_backbones(self.model_cfg.backbone)
# if 'pth' in self.model_cfg.ckpt_model_file:
# self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone, 'backbone')
self._create_head()
self.loss_func = nn.CrossEntropyLoss()
def _create_head(self):
self.backbone_hidden_size = self.backbone_config.hidden_size
self.head_hidden_size = self.model_cfg.head_hidden_size
self.head_p_dropout = self.model_cfg.head_p_dropout
self.n_classes = self.model_cfg.n_classes + 1
self.relations = self.model_cfg.n_relations
# self.repr_hiddent_size = self.backbone_hidden_size + self.n_classes + (self.train_cfg.max_seq_length + 1) * 3
self.repr_hiddent_size = self.backbone_hidden_size
# (1) Initial token classification
self.itc_layer = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer = RelationExtractor(
n_relations=self.relations, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key = RelationExtractor(
n_relations=self.relations, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
        # Classification layer for the whole document
# (1) Initial token classification
self.itc_layer_document = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.repr_hiddent_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer_document = RelationExtractor(
n_relations=1,
backbone_hidden_size=self.repr_hiddent_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer_document = RelationExtractor(
n_relations=self.relations,
backbone_hidden_size=self.repr_hiddent_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key_document = RelationExtractor(
n_relations=self.relations,
backbone_hidden_size=self.repr_hiddent_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
self.itc_layer.apply(self._init_weight)
self.stc_layer.apply(self._init_weight)
self.relation_layer.apply(self._init_weight)
self.relation_layer_from_key.apply(self._init_weight)
self.itc_layer_document.apply(self._init_weight)
self.stc_layer_document.apply(self._init_weight)
self.relation_layer_document.apply(self._init_weight)
self.relation_layer_from_key_document.apply(self._init_weight)
def _get_backbones(self, config_type):
configs = {
'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extrator': LayoutLMv2FeatureExtractor},
'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extrator': LayoutLMv2FeatureExtractor},
'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extrator': LayoutLMv2FeatureExtractor},
}
self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
if config_type != 'xlm-roberta':
self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
else:
self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
self.feature_extractor = configs[config_type]['feature_extrator'](apply_ocr=False)
self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)
@staticmethod
def _init_weight(module):
init_std = 0.02
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, 0.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
elif isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight, 1.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
def forward(self, batches):
head_outputs_list = []
loss = 0.0
for batch in batches["windows"]:
image = batch["image"]
input_ids = batch["input_ids"]
bbox = batch["bbox"]
attention_mask = batch["attention_mask"]
if self.freeze:
for param in self.backbone.parameters():
param.requires_grad = False
if self.model_cfg.backbone == 'layoutxlm':
backbone_outputs = self.backbone(
image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
)
else:
backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)
last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
window_repr = last_hidden_states.transpose(0, 1).contiguous()
head_outputs = {"window_repr": window_repr,
"itc_outputs": itc_outputs,
"stc_outputs": stc_outputs,
"el_outputs": el_outputs,
"el_outputs_from_key": el_outputs_from_key}
if any(['labels' in key for key in batch.keys()]):
loss += self._get_loss(head_outputs, batch)
head_outputs_list.append(head_outputs)
batch = batches["documents"]
document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
document_repr = document_repr.transpose(0, 1).contiguous()
itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)
head_outputs = {"itc_outputs": itc_outputs,
"stc_outputs": stc_outputs,
"el_outputs": el_outputs,
"el_outputs_from_key": el_outputs_from_key}
if any(['labels' in key for key in batch.keys()]):
loss += self._get_loss(head_outputs, batch)
return head_outputs, loss
def _get_loss(self, head_outputs, batch):
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
itc_loss = self._get_itc_loss(itc_outputs, batch)
stc_loss = self._get_stc_loss(stc_outputs, batch)
el_loss = self._get_el_loss(el_outputs, batch)
el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)
loss = itc_loss + stc_loss + el_loss + el_loss_from_key
return loss
def _get_itc_loss(self, itc_outputs, batch):
itc_mask = batch["are_box_first_tokens"].view(-1)
itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
itc_logits = itc_logits[itc_mask]
itc_labels = batch["itc_labels"].view(-1)
itc_labels = itc_labels[itc_mask]
itc_loss = self.loss_func(itc_logits, itc_labels)
return itc_loss
def _get_stc_loss(self, stc_outputs, batch):
inv_attention_mask = 1 - batch["attention_mask"]
bsz, max_seq_length = inv_attention_mask.shape
device = inv_attention_mask.device
invalid_token_mask = torch.cat(
[inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
).bool()
stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
stc_mask = batch["attention_mask"].view(-1).bool()
stc_logits = stc_outputs.view(-1, max_seq_length + 1)
stc_logits = stc_logits[stc_mask]
stc_labels = batch["stc_labels"].view(-1)
stc_labels = stc_labels[stc_mask]
stc_loss = self.loss_func(stc_logits, stc_labels)
return stc_loss
def _get_el_loss(self, el_outputs, batch, from_key=False):
bsz, max_seq_length = batch["attention_mask"].shape
device = batch["attention_mask"].device
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
box_first_token_mask = torch.cat(
[
(batch["are_box_first_tokens"] == False),
torch.zeros([bsz, 1], dtype=torch.bool).to(device),
],
axis=1,
)
el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
mask = batch["are_box_first_tokens"].view(-1)
logits = el_outputs.view(-1, max_seq_length + 1)
logits = logits[mask]
if from_key:
el_labels = batch["el_labels_from_key"]
else:
el_labels = batch["el_labels"]
labels = el_labels.view(-1)
labels = labels[mask]
loss = self.loss_func(logits, labels)
return loss
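# --- Illustrative note (added for this write-up, not part of the original model code) ---
# Rough shape walk-through of the two-level forward above, assuming a document split
# into 3 sliding windows of 512 tokens and a hidden size of 768:
#   per window     : last_hidden_states -> (batch, 512, 768), window_repr -> (batch, 512, 768)
#   document level : document_repr = torch.cat(window_reprs, dim=1) -> (batch, 3 * 512, 768)
# The *_document heads then score the concatenated sequence exactly like the
# per-window heads score a single 512-token window.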

View File

@ -0,0 +1,248 @@
import os
import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2
from model.relation_extractor import RelationExtractor
from utils import load_checkpoint
class KVUModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.device = 'cuda'
self.model_cfg = cfg.model
self.freeze = cfg.train.freeze
self.finetune_only = cfg.train.finetune_only
# if cfg.stage == 2:
# self.freeze = True
self._get_backbones(self.model_cfg.backbone)
self._create_head()
if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)):
self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
self._create_head()
self.loss_func = nn.CrossEntropyLoss()
if self.freeze:
for name, param in self.named_parameters():
if 'backbone' in name:
param.requires_grad = False
def _create_head(self):
self.backbone_hidden_size = 768
self.head_hidden_size = self.model_cfg.head_hidden_size
self.head_p_dropout = self.model_cfg.head_p_dropout
self.n_classes = self.model_cfg.n_classes + 1
# (1) Initial token classification
self.itc_layer = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
self.itc_layer.apply(self._init_weight)
self.stc_layer.apply(self._init_weight)
self.relation_layer.apply(self._init_weight)
def _get_backbones(self, config_type):
self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')
@staticmethod
def _init_weight(module):
init_std = 0.02
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, 0.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
elif isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight, 1.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
# def forward(self, inputs):
# token_embeddings = inputs['embeddings'].transpose(0, 1).contiguous().cuda()
# itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
# stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
# el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
# el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
# head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
# "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}
# loss = self._get_loss(head_outputs, inputs)
# return head_outputs, loss
# def forward_single_doccument(self, lbatches):
def forward(self, lbatches):
windows = lbatches['windows']
token_embeddings_windows = []
lvalids = []
loverlaps = []
for i, batch in enumerate(windows):
batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')}
image = batch["image"]
input_ids_layoutxlm = batch["input_ids_layoutxlm"]
bbox = batch["bbox"]
attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]
backbone_outputs_layoutxlm = self.backbone_layoutxlm(
image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)
last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
lvalids.append(batch['len_valid_tokens'])
loverlaps.append(batch['len_overlap_tokens'])
token_embeddings_windows.append(last_hidden_states_layoutxlm)
token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False)
# token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True)
token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()
itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
"el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key,
'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()}
loss = 0.0
if any(['labels' in key for key in lbatches.keys()]):
            labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k not in ('img_path',)}
loss = self._get_loss(head_outputs, labels)
return head_outputs, loss
def _get_loss(self, head_outputs, batch):
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
itc_loss = self._get_itc_loss(itc_outputs, batch)
stc_loss = self._get_stc_loss(stc_outputs, batch)
el_loss = self._get_el_loss(el_outputs, batch)
el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)
loss = itc_loss + stc_loss + el_loss + el_loss_from_key
return loss
def _get_itc_loss(self, itc_outputs, batch):
itc_mask = batch["are_box_first_tokens"].view(-1)
itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
itc_logits = itc_logits[itc_mask]
itc_labels = batch["itc_labels"].view(-1)
itc_labels = itc_labels[itc_mask]
itc_loss = self.loss_func(itc_logits, itc_labels)
return itc_loss
def _get_stc_loss(self, stc_outputs, batch):
inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]
bsz, max_seq_length = inv_attention_mask.shape
device = inv_attention_mask.device
invalid_token_mask = torch.cat(
[inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
).bool()
stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()
stc_logits = stc_outputs.view(-1, max_seq_length + 1)
stc_logits = stc_logits[stc_mask]
stc_labels = batch["stc_labels"].view(-1)
stc_labels = stc_labels[stc_mask]
stc_loss = self.loss_func(stc_logits, stc_labels)
return stc_loss
def _get_el_loss(self, el_outputs, batch, from_key=False):
bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
device = batch["attention_mask_layoutxlm"].device
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
box_first_token_mask = torch.cat(
[
(batch["are_box_first_tokens"] == False),
torch.zeros([bsz, 1], dtype=torch.bool).to(device),
],
axis=1,
)
el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
mask = batch["are_box_first_tokens"].view(-1)
logits = el_outputs.view(-1, max_seq_length + 1)
logits = logits[mask]
if from_key:
el_labels = batch["el_labels_from_key"]
else:
el_labels = batch["el_labels"]
labels = el_labels.view(-1)
labels = labels[mask]
loss = self.loss_func(logits, labels)
return loss

View File

@ -0,0 +1,48 @@
import torch
from torch import nn
class RelationExtractor(nn.Module):
def __init__(
self,
n_relations,
backbone_hidden_size,
head_hidden_size,
head_p_dropout=0.1,
):
super().__init__()
self.n_relations = n_relations
self.backbone_hidden_size = backbone_hidden_size
self.head_hidden_size = head_hidden_size
self.head_p_dropout = head_p_dropout
self.drop = nn.Dropout(head_p_dropout)
self.q_net = nn.Linear(
self.backbone_hidden_size, self.n_relations * self.head_hidden_size
)
self.k_net = nn.Linear(
self.backbone_hidden_size, self.n_relations * self.head_hidden_size
)
self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size))
nn.init.normal_(self.dummy_node)
def forward(self, h_q, h_k):
h_q = self.q_net(self.drop(h_q))
dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1)
h_k = torch.cat([h_k, dummy_vec], axis=0)
h_k = self.k_net(self.drop(h_k))
head_q = h_q.view(
h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size
)
head_k = h_k.view(
h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size
)
relation_score = torch.einsum("ibnd,jbnd->nbij", (head_q, head_k))
return relation_score
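# --- Illustrative usage sketch (added for this write-up, not part of the original module) ---
# Inputs are sequence-first: (seq_len, batch, backbone_hidden_size). A learned dummy
# key node is appended, so the score tensor has seq_len + 1 key positions.
#
#   import torch
#   extractor = RelationExtractor(n_relations=2, backbone_hidden_size=768, head_hidden_size=64)
#   h = torch.randn(512, 1, 768)     # (seq_len, batch, hidden)
#   scores = extractor(h, h)
#   print(scores.shape)              # torch.Size([2, 1, 512, 513]) -> (n_relations, batch, query, key + dummy)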

View File

@ -0,0 +1,133 @@
externals/sdsv_dewarp
test/
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*/__pycache__/
*.py[cod]
*$py.class
results/
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

View File

@ -0,0 +1,6 @@
[submodule "sdsvtr"]
path = externals/sdsvtr
url = https://github.com/mrlasdt/sdsvtr.git
[submodule "sdsvtd"]
path = externals/sdsvtd
url = https://github.com/mrlasdt/sdsvtd.git

View File

@ -0,0 +1,47 @@
# OCR Engine
OCR Engine is a Python package that combines text detection and recognition models from [mmdet](https://github.com/open-mmlab/mmdetection) and [mmocr](https://github.com/open-mmlab/mmocr) to perform Optical Character Recognition (OCR) on various inputs. The package currently supports three types of input: a single image, a recursive directory, or a csv file.
## Installation
To install OCR Engine, clone the repository and install the required packages:
```bash
git clone git@github.com:mrlasdt/ocr-engine.git
cd ocr-engine
pip install -r requirements.txt
```
## Usage
To use OCR Engine, simply run the `ocr_engine.py` script with the desired input type and input path. For example, to perform OCR on a single image:
```bash
python ocr_engine.py --input_type image --input_path /path/to/image.jpg
```
To perform OCR on a recursive directory:
```bash
python ocr_engine.py --input_type directory --input_path /path/to/directory/
```
To perform OCR on a csv file:
```bash
python ocr_engine.py --input_type csv --input_path /path/to/file.csv
```
OCR Engine will automatically detect and recognize text in the input and output the results in a CSV file named `ocr_results.csv`.
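For reference, a minimal Python-level sketch of calling the engine directly is shown below; the import path and keyword arguments are assumptions based on the code in this repository rather than a documented API:
```python
from src.ocr import OcrEngine  # import path assumed from this repository's layout

engine = OcrEngine(device="cuda:0")             # keyword arguments override the settings.yml defaults
page = engine("/path/to/image.jpg")             # a single image returns a Page object
page.write_to_file("word", "ocr_result.txt")    # tab-separated "x1 y1 x2 y2 text", one line per word
```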
## Contributing
If you would like to contribute to OCR Engine, please fork the repository and submit a pull request. We welcome contributions of all types, including bug fixes, new features, and documentation improvements.
## License
OCR Engine is released under the [MIT License](https://opensource.org/licenses/MIT). See the LICENSE file for more information.

View File

@ -0,0 +1,11 @@
# # Define package-level variables
# __version__ = '0.0'
# Import modules
from .src.ocr import OcrEngine
# from .src.word_formation import words_to_lines
from .src.word_formation import words_to_lines_tesseract as words_to_lines
from .src.utils import ImageReader, read_ocr_result_from_txt
from .src.dto import Word, Line, Page, Document, Box
# Expose package contents
__all__ = ["OcrEngine", "Box", "Word", "Line", "Page", "Document", "words_to_lines", "ImageReader", "read_ocr_result_from_txt"]

View File

@ -0,0 +1,82 @@
addict==2.4.0
asttokens==2.2.1
autopep8==1.6.0
backcall==0.2.0
backports.functools-lru-cache==1.6.4
brotlipy==0.7.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.0.4
click==8.1.3
colorama==0.4.6
cryptography==39.0.1
debugpy==1.5.1
decorator==5.1.1
docopt==0.6.2
entrypoints==0.4
executing==1.2.0
flit_core==3.6.0
idna==3.4
importlib-metadata==6.0.0
ipykernel==6.15.0
ipython==8.11.0
jedi==0.18.2
jupyter-client==7.0.6
jupyter_core==4.12.0
Markdown==3.4.1
markdown-it-py==2.2.0
matplotlib-inline==0.1.6
mdurl==0.1.2
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
mmcv-full==1.7.1
model-index==0.1.11
nest-asyncio==1.5.6
numpy==1.23.5
opencv-python==4.7.0.72
openmim==0.3.6
ordered-set==4.1.0
packaging==23.0
pandas==1.5.3
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.4.0
pip==22.3.1
pipdeptree==2.5.2
prompt-toolkit==3.0.38
psutil==5.9.0
ptyprocess==0.7.0
pure-eval==0.2.2
pycodestyle==2.10.0
pycparser==2.21
Pygments==2.14.0
pyOpenSSL==23.0.0
PySocks==1.7.1
python-dateutil==2.8.2
pytz==2022.7.1
PyYAML==6.0
pyzmq==19.0.2
requests==2.28.1
rich==13.3.1
sdsvtd==0.1.1
sdsvtr==0.0.5
setuptools==65.6.3
Shapely==1.8.4
six==1.16.0
stack-data==0.6.2
tabulate==0.9.0
toml==0.10.2
torch==1.13.1
torchvision==0.14.1
tornado==6.1
tqdm==4.65.0
traitlets==5.9.0
typing_extensions==4.4.0
urllib3==1.26.14
wcwidth==0.2.6
wheel==0.38.4
yapf==0.32.0
yarg==0.1.9
zipp==3.15.0

View File

@ -0,0 +1,143 @@
"""
see scripts/run_ocr.sh to run
"""
# from pathlib import Path # add parent path to run debugger
# import sys
# FILE = Path(__file__).absolute()
# sys.path.append(FILE.parents[2].as_posix())
from src.utils import construct_file_path, ImageReader
from src.dto import Line
from src.ocr import OcrEngine
import argparse
import tqdm
import pandas as pd
from pathlib import Path
import json
import os
import numpy as np
from typing import Union, Tuple, List
current_dir = os.getcwd()
def get_args():
parser = argparse.ArgumentParser()
# parser image
parser.add_argument("--image", type=str, required=True,
help="path to input image/directory/csv file")
parser.add_argument("--save_dir", type=str, required=True,
help="path to save directory")
parser.add_argument(
"--base_dir", type=str, required=False, default=current_dir,
help="used when --image and --save_dir are relative paths to a base directory, default to current directory")
parser.add_argument(
"--export_csv", type=str, required=False, default="",
help="used when --image is a directory. If set, a csv file contains image_path, ocr_path and label will be exported to save_dir.")
parser.add_argument(
"--export_img", type=bool, required=False, default=False, help="whether to save the visualize img")
parser.add_argument("--ocr_kwargs", type=str, required=False, default="")
opt = parser.parse_args()
return opt
def load_engine(opt) -> OcrEngine:
print("[INFO] Loading engine...")
kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {}
engine = OcrEngine(**kw)
print("[INFO] Engine loaded")
return engine
def convert_relative_path_to_positive_path(tgt_dir: Path, base_dir: Path) -> Path:
return tgt_dir if tgt_dir.is_absolute() else base_dir.joinpath(tgt_dir)
def get_paths_from_opt(opt) -> Tuple[Path, Path]:
# BC\ kiem\ tra\ y\ te -> BC kiem tra y te
img_path = opt.image.replace("\\ ", " ").strip()
save_dir = opt.save_dir.replace("\\ ", " ").strip()
base_dir = opt.base_dir.replace("\\ ", " ").strip()
input_image = convert_relative_path_to_positive_path(
Path(img_path), Path(base_dir))
save_dir = convert_relative_path_to_positive_path(
Path(save_dir), Path(base_dir))
if not save_dir.exists():
save_dir.mkdir()
print("[INFO]: Creating folder ", save_dir)
return input_image, save_dir
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
save_dir_or_path = Path(save_dir_or_path)
if isinstance(img, np.ndarray):
if save_dir_or_path.is_dir():
raise ValueError(
"numpy array input require a save path, not a save dir")
page = engine(img)
save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt")
) if save_dir_or_path.is_dir() else str(save_dir_or_path)
page.write_to_file('word', save_path)
if export_img:
page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, )
def process_dir(
dir_path: str, save_dir: str, engine: OcrEngine, export_img: bool, lskip_dir: List[str] = [],
ddata: dict = {"img_path": list(),
"ocr_path": list(),
"label": list()}) -> None:
dir_path = Path(dir_path)
# save_dir_sub = Path(construct_file_path(save_dir, dir_path, ext=""))
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True)
for img_path in (pbar := tqdm.tqdm(dir_path.iterdir())):
pbar.set_description(f"Processing {dir_path}")
if img_path.is_dir() and img_path not in lskip_dir:
save_dir_sub = save_dir.joinpath(img_path.stem)
            process_dir(img_path, str(save_dir_sub), engine, export_img, ddata=ddata)
elif img_path.suffix.lower() in ImageReader.supported_ext:
simg_path = str(img_path)
try:
img = ImageReader.read(
simg_path) if img_path.suffix != ".pdf" else ImageReader.read(simg_path)[0]
save_path = str(Path(save_dir).joinpath(
img_path.stem + ".txt"))
process_img(img, save_path, engine, export_img)
except Exception as e:
print('[ERROR]: ', e, ' at ', simg_path)
continue
ddata["img_path"].append(simg_path)
ddata["ocr_path"].append(save_path)
ddata["label"].append(dir_path.stem)
# ddata.update({"img_path": img_path, "save_path": save_path, "label": dir_path.stem})
return ddata
def process_csv(csv_path: str, engine: OcrEngine) -> None:
df = pd.read_csv(csv_path)
    if 'image_path' not in df.columns or 'ocr_path' not in df.columns:
        raise AssertionError('Cannot find image_path or ocr_path in df headers')
    for _, row in df.iterrows():
        process_img(row.image_path, row.ocr_path, engine, export_img=False)  # csv mode does not expose export_img
if __name__ == "__main__":
opt = get_args()
engine = load_engine(opt)
print("[INFO]: OCR engine settings:", engine.settings)
img, save_dir = get_paths_from_opt(opt)
lskip_dir = []
if img.is_dir():
ddata = process_dir(img, save_dir, engine, opt.export_img)
if opt.export_csv:
pd.DataFrame.from_dict(ddata).to_csv(
Path(save_dir).joinpath(opt.export_csv))
elif img.suffix in ImageReader.supported_ext:
process_img(str(img), save_dir, engine, opt.export_img)
elif img.suffix == '.csv':
print("[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used")
process_csv(img, engine)
else:
raise NotImplementedError('[ERROR]: Unsupported file {}'.format(img))

View File

@ -0,0 +1,23 @@
#bash scripts/run_ocr.sh -i /mnt/ssd1T/hungbnt/DocumentClassification/data/OCR040_043 -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr/OCR040_043 -e out.csv -k "{\"device\":\"cuda:1\"}" -x True
export PYTHONWARNINGS="ignore"
while getopts i:o:b:e:x:k: flag
do
case "${flag}" in
i) img=${OPTARG};;
o) out_dir=${OPTARG};;
b) base_dir=${OPTARG};;
e) export_csv=${OPTARG};;
x) export_img=${OPTARG};;
k) ocr_kwargs=${OPTARG};;
esac
done
echo "run.py --image=\"$img\" --save_dir \"$out_dir\" --base_dir \"$base_dir\" --export_csv \"$export_csv\" --export_img \"$export_img\" --ocr_kwargs \"$ocr_kwargs\""
python run.py \
--image="$img" \
--save_dir $out_dir \
--export_csv $export_csv\
--export_img $export_img\
--ocr_kwargs $ocr_kwargs\

View File

@ -0,0 +1,17 @@
detector: "/models/Kie_invoice_ap/wild_receipt_finetune_weights_c_lite.pth"
rotator_version: "/models/Kie_invoice_ap/best_bbox_mAP_epoch_30_lite.pth"
recog_max_seq_len: 25
recognizer: "satrn-lite-general-pretrain-20230106"
device: "cuda:0"
do_extend_bbox: True
margin_bbox: [0, 0.03, 0.02, 0.05]
batch_mode: False
batch_size: 16
auto_rotate: True
img_size: [] #[1920,1920] # text det default size: 1280x1280; [] = original size. TODO: fix the deskew code to resize the image only for detecting the angle; we want to feed the original-size image to the text detection pipeline so that the bounding boxes map back to the original size
deskew: False #True
words_to_lines: {
"gradient": 0.6,
"max_x_dist": 20,
"y_overlap_threshold": 0.5,
}
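# Note (illustrative, added for this write-up): every key above can be overridden per run
# through keyword arguments, e.g. OcrEngine(device="cpu", deskew=True); keys that are not
# defined in this file raise a ValueError in OcrEngine.__init__.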

View File

@ -0,0 +1,453 @@
import numpy as np
from typing import Optional, List
import cv2
from PIL import Image
from .utils import visualize_bbox_and_label
class Box:
def __init__(self, x1, y1, x2, y2, conf=-1., label=""):
self.x1 = x1
self.y1 = y1
self.x2 = x2
self.y2 = y2
self.conf = conf
self.label = label
def __repr__(self) -> str:
return str(self.bbox)
def __str__(self) -> str:
return str(self.bbox)
def get(self, return_confidence=False) -> list:
return self.bbox if not return_confidence else self.xyxyc
def __getitem__(self, key):
return self.bbox[key]
@property
def width(self):
return self.x2 - self.x1
@property
def height(self):
return self.y2 - self.y1
@property
def bbox(self) -> list:
return [self.x1, self.y1, self.x2, self.y2]
@bbox.setter
def bbox(self, bbox_: list):
self.x1, self.y1, self.x2, self.y2 = bbox_
@property
def xyxyc(self) -> list:
return [self.x1, self.y1, self.x2, self.y2, self.conf]
@staticmethod
def normalize_bbox(bbox: list):
return [int(b) for b in bbox]
def to_int(self):
self.x1, self.y1, self.x2, self.y2 = self.normalize_bbox([self.x1, self.y1, self.x2, self.y2])
return self
@staticmethod
def clamp_bbox_by_img_wh(bbox: list, width: int, height: int):
x1, y1, x2, y2 = bbox
x1 = min(max(0, x1), width)
x2 = min(max(0, x2), width)
y1 = min(max(0, y1), height)
y2 = min(max(0, y2), height)
return (x1, y1, x2, y2)
def clamp_by_img_wh(self, width: int, height: int):
self.x1, self.y1, self.x2, self.y2 = self.clamp_bbox_by_img_wh(
[self.x1, self.y1, self.x2, self.y2], width, height)
return self
@staticmethod
def extend_bbox(bbox: list, margin: list): # -> Self (python3.11)
        margin_l, margin_t, margin_r, margin_b = margin
        l, t, r, b = bbox  # left, top, right, bottom
        w, h = r - l, b - t  # use the original width/height so every side is extended consistently
        t = t - h * margin_t
        b = b + h * margin_b
        l = l - w * margin_l
        r = r + w * margin_r
        return [l, t, r, b]
def get_extend_bbox(self, margin: list):
extended_bbox = self.extend_bbox(self.bbox, margin)
return Box(*extended_bbox, label=self.label)
@staticmethod
def bbox_is_valid(bbox: list) -> bool:
l, t, r, b = bbox # left, top, right, bottom
return True if (b - t) * (r - l) > 0 else False
def is_valid(self) -> bool:
return self.bbox_is_valid(self.bbox)
@staticmethod
def crop_img_by_bbox(img: np.ndarray, bbox: list) -> np.ndarray:
l, t, r, b = bbox
return img[t:b, l:r]
def crop_img(self, img: np.ndarray) -> np.ndarray:
return self.crop_img_by_bbox(img, self.bbox)
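# --- Illustrative usage sketch (added for this write-up, not part of the original class) ---
#   box = Box(10.4, 20.7, 110.2, 60.9, conf=0.98)
#   box = box.get_extend_bbox([0, 0.03, 0.02, 0.05])        # widen by relative margins
#   box.clamp_by_img_wh(width=1280, height=720).to_int()    # keep inside the image, round to int
#   crop = box.crop_img(img)                                 # img: np.ndarray of shape (H, W, C)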
class Word:
def __init__(
self,
image=None,
text="",
conf_cls=-1.,
bndbox: Optional[Box] = None,
conf_detect=-1.,
kie_label="",
):
self.type = "word"
self.text = text
self.image = image
self.conf_detect = conf_detect
self.conf_cls = conf_cls
# [left, top,right,bot] coordinate of top-left and bottom-right point
self.boundingbox = bndbox
self.word_id = 0 # id of word
self.word_group_id = 0 # id of word_group which instance belongs to
self.line_id = 0 # id of line which instance belongs to
        self.paragraph_id = 0  # id of paragraph which instance belongs to
self.kie_label = kie_label
@property
def bbox(self):
return self.boundingbox.bbox
@property
def height(self):
return self.boundingbox.height
@property
def width(self):
return self.boundingbox.width
def __repr__(self) -> str:
return self.text
def __str__(self) -> str:
return self.text
def invalid_size(self):
return (self.boundingbox[2] - self.boundingbox[0]) * (
self.boundingbox[3] - self.boundingbox[1]
) > 0
def is_special_word(self):
left, top, right, bottom = self.boundingbox
width, height = right - left, bottom - top
text = self.text
if text is None:
return True
# if len(text) > 7:
# return True
if len(text) >= 7:
no_digits = sum(c.isdigit() for c in text)
return no_digits / len(text) >= 0.3
return False
class Word_group:
    def __init__(self, list_words_: Optional[List[Word]] = None,
                 text: str = '', boundingbox: Box = Box(-1, -1, -1, -1), conf_cls: float = -1):
        self.type = "word_group"
        self.list_words = list_words_ if list_words_ is not None else []  # list of word instances; avoid a shared mutable default
self.word_group_id = 0 # word group id
self.line_id = 0 # id of line which instance belongs to
self.paragraph_id = 0 # id of paragraph which instance belongs to
self.text = text
self.boundingbox = boundingbox
self.kie_label = ""
self.conf_cls = conf_cls
@property
def bbox(self):
return self.boundingbox
def __repr__(self) -> str:
return self.text
def __str__(self) -> str:
return self.text
def add_word(self, word: Word): # add a word instance to the word_group
if word.text != "":
for w in self.list_words:
if word.word_id == w.word_id:
print("Word id collision")
return False
word.word_group_id = self.word_group_id #
word.line_id = self.line_id
word.paragraph_id = self.paragraph_id
self.list_words.append(word)
self.text += " " + word.text
            if self.boundingbox[:] == [-1, -1, -1, -1]:  # compare coordinates, not the Box object
self.boundingbox = word.boundingbox
else:
self.boundingbox = [
min(self.boundingbox[0], word.boundingbox[0]),
min(self.boundingbox[1], word.boundingbox[1]),
max(self.boundingbox[2], word.boundingbox[2]),
max(self.boundingbox[3], word.boundingbox[3]),
]
return True
else:
return False
def update_word_group_id(self, new_word_group_id):
self.word_group_id = new_word_group_id
for i in range(len(self.list_words)):
self.list_words[i].word_group_id = new_word_group_id
def update_kie_label(self):
list_kie_label = [word.kie_label for word in self.list_words]
dict_kie = dict()
for label in list_kie_label:
if label not in dict_kie:
dict_kie[label] = 1
else:
dict_kie[label] += 1
total = len(list(dict_kie.values()))
max_value = max(list(dict_kie.values()))
list_keys = list(dict_kie.keys())
list_values = list(dict_kie.values())
self.kie_label = list_keys[list_values.index(max_value)]
def update_text(self): # update text after changing positions of words in list word
text = ""
for word in self.list_words:
text += " " + word.text
self.text = text
class Line:
    def __init__(self, list_word_groups: Optional[List[Word_group]] = None,
                 text: str = '', boundingbox: Box = Box(-1, -1, -1, -1), conf_cls: float = -1):
        self.type = "line"
        self.list_word_groups = list_word_groups if list_word_groups is not None else []  # list of Word_group instances in the line; avoid a shared mutable default
self.line_id = 0 # id of line in the paragraph
self.paragraph_id = 0 # id of paragraph which instance belongs to
self.text = text
self.boundingbox = boundingbox
self.conf_cls = conf_cls
@property
def bbox(self):
return self.boundingbox
def __repr__(self) -> str:
return self.text
def __str__(self) -> str:
return self.text
def add_group(self, word_group: Word_group): # add a word_group instance
if word_group.list_words is not None:
for wg in self.list_word_groups:
if word_group.word_group_id == wg.word_group_id:
print("Word_group id collision")
return False
self.list_word_groups.append(word_group)
self.text += word_group.text
word_group.paragraph_id = self.paragraph_id
word_group.line_id = self.line_id
for i in range(len(word_group.list_words)):
word_group.list_words[
i
].paragraph_id = self.paragraph_id # set paragraph_id for word
word_group.list_words[i].line_id = self.line_id # set line_id for word
return True
return False
def update_line_id(self, new_line_id):
self.line_id = new_line_id
for i in range(len(self.list_word_groups)):
self.list_word_groups[i].line_id = new_line_id
for j in range(len(self.list_word_groups[i].list_words)):
self.list_word_groups[i].list_words[j].line_id = new_line_id
def merge_word(self, word): # word can be a Word instance or a Word_group instance
if word.text != "":
            if self.boundingbox[:] == [-1, -1, -1, -1]:  # compare coordinates, not the Box object
self.boundingbox = word.boundingbox
else:
self.boundingbox = [
min(self.boundingbox[0], word.boundingbox[0]),
min(self.boundingbox[1], word.boundingbox[1]),
max(self.boundingbox[2], word.boundingbox[2]),
max(self.boundingbox[3], word.boundingbox[3]),
]
self.list_word_groups.append(word)
self.text += " " + word.text
return True
return False
def __cal_ratio(self, top1, bottom1, top2, bottom2):
sorted_vals = sorted([top1, bottom1, top2, bottom2])
intersection = sorted_vals[2] - sorted_vals[1]
min_height = min(bottom1 - top1, bottom2 - top2)
if min_height == 0:
return -1
ratio = intersection / min_height
return ratio
def __cal_ratio_height(self, top1, bottom1, top2, bottom2):
        height1, height2 = abs(bottom1 - top1), abs(bottom2 - top2)  # heights must be positive for the max/min ratio below
ratio_height = float(max(height1, height2)) / float(min(height1, height2))
return ratio_height
def in_same_line(self, input_line, thresh=0.7):
# calculate iou in vertical direction
_, top1, _, bottom1 = self.boundingbox
_, top2, _, bottom2 = input_line.boundingbox
ratio = self.__cal_ratio(top1, bottom1, top2, bottom2)
ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2)
        if (
            ((top2 <= top1 <= bottom2) or (top1 <= top2 <= bottom1))
            and ratio >= thresh
            and (ratio_height < 2)
        ):
            return True
        return False
class Paragraph:
def __init__(self, id=0, lines=None):
self.list_lines = lines if lines is not None else [] # list of all lines in the paragraph
        self.paragraph_id = id  # index of the paragraph in the list of paragraphs
self.text = ""
self.boundingbox = [-1, -1, -1, -1]
@property
def bbox(self):
return self.boundingbox
def __repr__(self) -> str:
return self.text
def __str__(self) -> str:
return self.text
def add_line(self, line: Line): # add a line instance
if line.list_word_groups is not None:
for l in self.list_lines:
if line.line_id == l.line_id:
print("Line id collision")
return False
for i in range(len(line.list_word_groups)):
line.list_word_groups[
i
].paragraph_id = (
self.paragraph_id
) # set paragraph id for every word group in line
for j in range(len(line.list_word_groups[i].list_words)):
line.list_word_groups[i].list_words[
j
].paragraph_id = (
self.paragraph_id
) # set paragraph id for every word in word groups
line.paragraph_id = self.paragraph_id # set paragraph id for line
self.list_lines.append(line) # add line to paragraph
self.text += " " + line.text
return True
else:
return False
def update_paragraph_id(
self, new_paragraph_id
): # update new paragraph_id for all lines, word_groups, words inside paragraph
self.paragraph_id = new_paragraph_id
for i in range(len(self.list_lines)):
self.list_lines[
i
].paragraph_id = new_paragraph_id # set new paragraph_id for line
for j in range(len(self.list_lines[i].list_word_groups)):
self.list_lines[i].list_word_groups[
j
].paragraph_id = new_paragraph_id # set new paragraph_id for word_group
for k in range(len(self.list_lines[i].list_word_groups[j].list_words)):
self.list_lines[i].list_word_groups[j].list_words[
k
].paragraph_id = new_paragraph_id # set new paragraph id for word
return True
class Page:
def __init__(self, llines: List[Line], image: np.ndarray) -> None:
self.__llines = llines
self.__image = image
self.__drawed_image = None
@property
def llines(self):
return self.__llines
@property
def image(self):
return self.__image
@property
def PIL_image(self):
return Image.fromarray(self.__image)
@property
def drawed_image(self):
        return self.__drawed_image
def visualize_bbox_and_label(self, **kwargs: dict):
if self.__drawed_image is not None:
return self.__drawed_image
bboxes = list()
texts = list()
for line in self.__llines:
for word_group in line.list_word_groups:
for word in word_group.list_words:
bboxes.append([int(float(b)) for b in word.bbox[:]])
texts.append(word.text)
img = visualize_bbox_and_label(self.__image, bboxes, texts, **kwargs)
self.__drawed_image = img
return self.__drawed_image
def save_img(self, save_path: str, **kwargs: dict) -> None:
img = self.visualize_bbox_and_label(**kwargs)
cv2.imwrite(save_path, img)
def write_to_file(self, mode: str, save_path: str) -> None:
f = open(save_path, "w+", encoding="utf-8")
for line in self.__llines:
if mode == 'line':
xmin, ymin, xmax, ymax = line.bbox[:]
f.write("{}\t{}\t{}\t{}\t{}\n".format(xmin, ymin, xmax, ymax, line.text))
elif mode == "word":
for word_group in line.list_word_groups:
for word in word_group.list_words:
# xmin, ymin, xmax, ymax = word.bbox[:]
xmin, ymin, xmax, ymax = [int(float(b)) for b in word.bbox[:]]
f.write("{}\t{}\t{}\t{}\t{}\n".format(xmin, ymin, xmax, ymax, word.text))
f.close()
class Document:
def __init__(self, lpages: List[Page]) -> None:
self.lpages = lpages

View File

@ -0,0 +1,207 @@
from typing import Union, overload, List, Optional, Tuple
from PIL import Image
import torch
import numpy as np
import yaml
from pathlib import Path
import mmcv
from sdsvtd import StandaloneYOLOXRunner
from sdsvtr import StandaloneSATRNRunner
from .utils import ImageReader, chunks, rotate_bbox, Timer
# from .utils import jdeskew as deskew
# from externals.deskew.sdsv_dewarp import pdeskew as deskew
from .utils import deskew, post_process_recog
from .dto import Word, Line, Page, Document, Box
# from .word_formation import words_to_lines as words_to_lines
# from .word_formation import words_to_lines_mmocr as words_to_lines
from .word_formation import words_to_lines_tesseract as words_to_lines
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml"
class OcrEngine:
def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs: dict):
""" Warper of text detection and text recognition
:param settings_file: path to default setting file
:param kwargs: keyword arguments to overwrite the default settings file
"""
with open(settings_file) as f:
# use safe_load instead load
self.__settings = yaml.safe_load(f)
for k, v in kwargs.items(): # overwrite default settings by keyword arguments
if k not in self.__settings:
raise ValueError("Invalid setting found in OcrEngine: ", k)
self.__settings[k] = v
if "cuda" in self.__settings["device"]:
if not torch.cuda.is_available():
print("[WARNING]: CUDA is not available, running with cpu instead")
self.__settings["device"] = "cpu"
self._detector = StandaloneYOLOXRunner(
version=self.__settings["detector"],
device=self.__settings["device"],
auto_rotate=self.__settings["auto_rotate"],
rotator_version=self.__settings["rotator_version"])
self._recognizer = StandaloneSATRNRunner(
version=self.__settings["recognizer"],
return_confident=True, device=self.__settings["device"],
max_seq_len_overwrite=self.__settings["recog_max_seq_len"]
)
        # extend the bbox to avoid losing accent marks in Vietnamese; if using OCR for English only, disable it
self._do_extend_bbox = self.__settings["do_extend_bbox"]
        # [left, top, right, bottom]
self._margin_bbox = self.__settings["margin_bbox"]
self._batch_mode = self.__settings["batch_mode"]
self._batch_size = self.__settings["batch_size"]
self._deskew = self.__settings["deskew"]
self._img_size = self.__settings["img_size"]
self.__version__ = {
"detector": self.__settings["detector"],
"recognizer": self.__settings["recognizer"],
}
@property
def version(self):
return self.__version__
@property
def settings(self):
return self.__settings
# @staticmethod
# def xyxyc_to_xyxy_c(xyxyc: np.ndarray) -> Tuple[List[list], list]:
# '''
# convert sdsvtd yoloX detection output to list of bboxes and list of confidences
# @param xyxyc: array of shape (n, 5)
# '''
# xyxy = xyxyc[:, :4].tolist()
# confs = xyxyc[:, 4].tolist()
# return xyxy, confs
# -> Tuple[np.ndarray, List[Box]]:
def preprocess(self, img: np.ndarray) -> np.ndarray:
img_ = img.copy()
if self.__settings["img_size"]:
img_ = mmcv.imrescale(
img, tuple(self.__settings["img_size"]),
return_scale=False, interpolation='bilinear', backend='cv2')
if self._deskew:
with Timer("deskew"):
img_, angle = deskew(img_)
# for i, bbox in enumerate(bboxes):
# rotated_bbox = rotate_bbox(bbox[:], angle, img.shape[:2])
# bboxes[i].bbox = rotated_bbox
return img_ # , bboxes
def run_detect(self, img: np.ndarray, return_raw: bool = False) -> Tuple[np.ndarray, Union[List[Box], List[list]]]:
'''
        run text detection; return the raw list of [x1, y1, x2, y2, conf] when return_raw is True, otherwise a list of Box objects
'''
pred_det = self._detector(img)
if self.__settings["auto_rotate"]:
img, pred_det = pred_det
pred_det = pred_det[0] # only image at a time
return (img, pred_det.tolist()) if return_raw else (img, [Box(*xyxyc) for xyxyc in pred_det.tolist()])
def run_recog(self, imgs: List[np.ndarray]) -> Union[List[str], List[Tuple[str, float]]]:
if len(imgs) == 0:
return list()
pred_rec = self._recognizer(imgs)
return [(post_process_recog(word), conf) for word, conf in zip(pred_rec[0], pred_rec[1])]
def read_img(self, img: str) -> np.ndarray:
return ImageReader.read(img)
def get_cropped_imgs(self, img: np.ndarray, bboxes: List[Union[Box, list]]) -> Tuple[List[np.ndarray], List[bool]]:
"""
img: np image
bboxes: list of xyxy
"""
lcropped_imgs = list()
mask = list()
for bbox in bboxes:
bbox = Box(*bbox) if isinstance(bbox, list) else bbox
bbox = bbox.get_extend_bbox(
self._margin_bbox) if self._do_extend_bbox else bbox
bbox.clamp_by_img_wh(img.shape[1], img.shape[0])
bbox.to_int()
if not bbox.is_valid():
mask.append(False)
continue
cropped_img = bbox.crop_img(img)
lcropped_imgs.append(cropped_img)
mask.append(True)
return lcropped_imgs, mask
def read_page(self, img: np.ndarray, bboxes: List[Union[Box, list]]) -> List[Line]:
if len(bboxes) == 0: # no bbox found
return list()
with Timer("cropped imgs"):
lcropped_imgs, mask = self.get_cropped_imgs(img, bboxes)
with Timer("recog"):
# batch_mode for efficiency
pred_recs = self.run_recog(lcropped_imgs)
with Timer("construct words"):
lwords = list()
for i in range(len(pred_recs)):
if not mask[i]:
continue
text, conf_rec = pred_recs[i][0], pred_recs[i][1]
bbox = Box(*bboxes[i]) if isinstance(bboxes[i],
list) else bboxes[i]
lwords.append(Word(
image=img, text=text, conf_cls=conf_rec, bndbox=bbox, conf_detect=bbox.conf))
with Timer("words to lines"):
return words_to_lines(
lwords, **self.__settings["words_to_lines"])[0]
# https://stackoverflow.com/questions/48127642/incompatible-types-in-assignment-on-union
@overload
def __call__(self, img: Union[str, np.ndarray, Image.Image]) -> Page: ...
@overload
def __call__(
self, img: List[Union[str, np.ndarray, Image.Image]]) -> Document: ...
def __call__(self, img):
"""
Accept an image or list of them, return ocr result as a page or document
"""
with Timer("read image"):
img = ImageReader.read(img)
if not self._batch_mode:
if isinstance(img, list):
if len(img) == 1:
img = img[0] # in case input type is a 1 page pdf
else:
raise AssertionError(
"list input can only be used with batch_mode enabled")
img = self.preprocess(img)
with Timer("detect"):
img, bboxes = self.run_detect(img)
with Timer("read_page"):
llines = self.read_page(img, bboxes)
return Page(llines, img)
else:
lpages = []
            # process in chunks to reduce the memory footprint
            for imgs in chunks(img, self._batch_size):
                # pred_dets = self._detector(imgs)
                # TEMP: loop over single images because sdsvtd does not support batched text detection
                for img_ in imgs:
                    img_ = self.preprocess(img_)
                    img_, bboxes_ = self.run_detect(img_)
                    llines = self.read_page(img_, bboxes_)
                    lpages.append(Page(llines, img_))
return Document(lpages)
if __name__ == "__main__":
img_path = "/mnt/ssd1T/hungbnt/Cello/data/PH/Sea7/Sea_7_1.jpg"
    engine = OcrEngine(device="cuda:0")  # note: settings.yml defines no return_confidence key, so it must not be passed here
# https://stackoverflow.com/questions/66435480/overload-following-optional-argument
page = engine(img_path) # type: ignore
    print(page.llines)

View File

@ -0,0 +1,346 @@
from PIL import ImageFont, ImageDraw, Image, ImageOps
# import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import time
from typing import Generator, Union, List, overload, Tuple, Callable
import glob
import math
from pathlib import Path
from pdf2image import convert_from_path
from deskew import determine_skew
from jdeskew.estimator import get_angle
from jdeskew.utility import rotate as jrotate
def post_process_recog(text: str) -> str:
text = text.replace("", " ")
return text
class Timer:
def __init__(self, name: str) -> None:
self.name = name
def __enter__(self):
self.start_time = time.perf_counter()
return self
def __exit__(self, func: Callable, *args):
self.end_time = time.perf_counter()
self.elapsed_time = self.end_time - self.start_time
print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds")
def rotate(
image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]]
) -> np.ndarray:
old_width, old_height = image.shape[:2]
angle_radian = math.radians(angle)
width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width)
height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height)
image_center = tuple(np.array(image.shape[1::-1]) / 2)
rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
rot_mat[1, 2] += (width - old_width) / 2
rot_mat[0, 2] += (height - old_height) / 2
return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background)
# def rotate_bbox(bbox: list, angle: float) -> list:
# # Compute the center point of the bounding box
# cx = bbox[0] + bbox[2] / 2
# cy = bbox[1] + bbox[3] / 2
# # Define the scale factor for the rotated bounding box
# scale = 1.0 # following the deskew and jdeskew function
# angle_radian = math.radians(angle)
# # Obtain the rotation matrix using cv2.getRotationMatrix2D()
# M = cv2.getRotationMatrix2D((cx, cy), angle_radian, scale)
# # Apply the rotation matrix to the four corners of the bounding box
# corners = np.array([[bbox[0], bbox[1]],
# [bbox[0] + bbox[2], bbox[1]],
# [bbox[0] + bbox[2], bbox[1] + bbox[3]],
# [bbox[0], bbox[1] + bbox[3]]], dtype=np.float32)
# rotated_corners = cv2.transform(np.array([corners]), M)[0]
# # Compute the bounding box of the rotated corners
# x = int(np.min(rotated_corners[:, 0]))
# y = int(np.min(rotated_corners[:, 1]))
# w = int(np.max(rotated_corners[:, 0]) - np.min(rotated_corners[:, 0]))
# h = int(np.max(rotated_corners[:, 1]) - np.min(rotated_corners[:, 1]))
# rotated_bbox = [x, y, w, h]
# return rotated_bbox
def rotate_bbox(bbox: List[int], angle: float, old_shape: Tuple[int, int]) -> List[int]:
# https://medium.com/@pokomaru/image-and-bounding-box-rotation-using-opencv-python-2def6c39453
bbox_ = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]]
h, w = old_shape
cx, cy = (int(w / 2), int(h / 2))
bbox_tuple = [
(bbox_[0], bbox_[1]),
(bbox_[2], bbox_[3]),
(bbox_[4], bbox_[5]),
(bbox_[6], bbox_[7]),
] # put x and y coordinates in tuples, we will iterate through the tuples and perform rotation
rotated_bbox = []
for i, coord in enumerate(bbox_tuple):
M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
cos, sin = abs(M[0, 0]), abs(M[0, 1])
newW = int((h * sin) + (w * cos))
newH = int((h * cos) + (w * sin))
M[0, 2] += (newW / 2) - cx
M[1, 2] += (newH / 2) - cy
v = [coord[0], coord[1], 1]
adjusted_coord = np.dot(M, v)
rotated_bbox.insert(i, (adjusted_coord[0], adjusted_coord[1]))
result = [int(x) for t in rotated_bbox for x in t]
return [result[i] for i in [0, 1, 2, -1]] # reformat to xyxy
def deskew(image: np.ndarray) -> Tuple[np.ndarray, float]:
grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
angle = 0.
try:
angle = determine_skew(grayscale)
except Exception:
pass
rotated = rotate(image, angle, (0, 0, 0)) if angle else image
return rotated, angle
def jdeskew(image: np.ndarray) -> Tuple[np.ndarray, float]:
angle = 0.
try:
angle = get_angle(image)
except Exception:
pass
# TODO: change resize = True and scale the bounding box
rotated = jrotate(image, angle, resize=False) if angle else image
return rotated, angle
class ImageReader:
"""
accept anything, return numpy array image
"""
supported_ext = [".png", ".jpg", ".jpeg", ".pdf", ".gif"]
@staticmethod
def validate_img_path(img_path: str) -> None:
if not os.path.exists(img_path):
raise FileNotFoundError(img_path)
if os.path.isdir(img_path):
raise IsADirectoryError(img_path)
if not Path(img_path).suffix.lower() in ImageReader.supported_ext:
raise NotImplementedError("Not supported extension at {}".format(img_path))
@overload
@staticmethod
def read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray: ...
@overload
@staticmethod
def read(img: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: ...
@overload
@staticmethod
def read(img: str) -> List[np.ndarray]: ... # for pdf or directory
@staticmethod
def read(img):
if isinstance(img, list):
return ImageReader.from_list(img)
elif isinstance(img, str) and os.path.isdir(img):
return ImageReader.from_dir(img)
elif isinstance(img, str) and img.endswith(".pdf"):
return ImageReader.from_pdf(img)
else:
return ImageReader._read(img)
@staticmethod
def from_dir(dir_path: str) -> List[np.ndarray]:
if os.path.isdir(dir_path):
image_files = glob.glob(os.path.join(dir_path, "*"))
return ImageReader.from_list(image_files)
else:
raise NotADirectoryError(dir_path)
@staticmethod
def from_str(img_path: str) -> np.ndarray:
ImageReader.validate_img_path(img_path)
return ImageReader.from_PIL(Image.open(img_path))
@staticmethod
def from_np(img_array: np.ndarray) -> np.ndarray:
return img_array
@staticmethod
def from_PIL(img_pil: Image.Image, transpose=True) -> np.ndarray:
# if img_pil.is_animated:
# raise NotImplementedError("Only static images are supported, animated image found")
if transpose:
img_pil = ImageOps.exif_transpose(img_pil)
if img_pil.mode != "RGB":
img_pil = img_pil.convert("RGB")
return np.array(img_pil)
@staticmethod
def from_list(img_list: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]:
limgs = list()
for img_path in img_list:
try:
if isinstance(img_path, str):
ImageReader.validate_img_path(img_path)
limgs.append(ImageReader._read(img_path))
except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e:
print("[ERROR]: ", e)
print("[INFO]: Skipping image {}".format(img_path))
return limgs
@staticmethod
def from_pdf(pdf_path: str, start_page: int = 0, end_page: int = 0) -> List[np.ndarray]:
pdf_file = convert_from_path(pdf_path)
if end_page is not None:
end_page = min(len(pdf_file), end_page + 1)
limgs = [np.array(pdf_page) for pdf_page in pdf_file[start_page:end_page]]
return limgs
@staticmethod
def _read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray:
if isinstance(img, str):
return ImageReader.from_str(img)
elif isinstance(img, Image.Image):
return ImageReader.from_PIL(img)
elif isinstance(img, np.ndarray):
return ImageReader.from_np(img)
else:
raise ValueError(f"Invalid img argument type: {type(img)}")
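# Illustrative usage sketch (not part of the original module; file names are placeholders):
# >>> page = ImageReader.read("scan.jpg")                           # np.ndarray (H, W, 3), RGB
# >>> pages = ImageReader.read("invoice.pdf")                       # list of np.ndarray, one per page
# >>> same = ImageReader.read(np.zeros((4, 4, 3), dtype=np.uint8))  # ndarray passed through unchanged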
def get_name(file_path, ext: bool = True):
file_path_ = os.path.basename(file_path)
return file_path_ if ext else os.path.splitext(file_path_)[0]
def construct_file_path(dir, file_path, ext=''):
'''
args:
dir: /path/to/dir
file_path /example_path/to/file.txt
ext = '.json'
return
/path/to/dir/file.json
'''
return os.path.join(
dir, get_name(file_path,
True)) if ext == '' else os.path.join(
dir, get_name(file_path,
False)) + ext
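# Examples mirroring the docstring above:
# >>> construct_file_path('/path/to/dir', '/example_path/to/file.txt', '.json')
# '/path/to/dir/file.json'
# >>> construct_file_path('/path/to/dir', '/example_path/to/file.txt')
# '/path/to/dir/file.txt'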
def chunks(lst: list, n: int) -> Generator:
"""
Yield successive n-sized chunks from lst.
https://stackoverflow.com/questions/312443/how-do-i-split-a-list-into-equally-sized-chunks
"""
for i in range(0, len(lst), n):
yield lst[i:i + n]
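# Example: the final chunk may be shorter than n.
# >>> list(chunks([1, 2, 3, 4, 5], 2))
# [[1, 2], [3, 4], [5]]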
def read_ocr_result_from_txt(file_path: str) -> Tuple[list, list]:
'''
return list of bounding boxes, list of words
'''
with open(file_path, 'r') as f:
lines = f.read().splitlines()
boxes, words = [], []
for line in lines:
if line == "":
continue
x1, y1, x2, y2, text = line.split("\t")
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
if text and text != " ":
words.append(text)
boxes.append((x1, y1, x2, y2))
return boxes, words
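# Expected input format (illustrative): one word per line, tab-separated as
# x1<TAB>y1<TAB>x2<TAB>y2<TAB>text
# e.g. the line "100\t200\t350\t240\tInvoice" yields boxes == [(100, 200, 350, 240)]
# and words == ["Invoice"]; blank lines and empty or single-space texts are skipped.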
def get_xyxywh_base_on_format(bbox, format):
if format == "xywh":
x1, y1, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
x2, y2 = x1 + w, y1 + h
elif format == "xyxy":
x1, y1, x2, y2 = bbox
w, h = x2 - x1, y2 - y1
else:
raise NotImplementedError("Invalid format {}".format(format))
return (x1, y1, x2, y2, w, h)
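# Examples:
# >>> get_xyxywh_base_on_format((10, 20, 30, 60), "xywh")
# (10, 20, 40, 80, 30, 60)
# >>> get_xyxywh_base_on_format((10, 20, 40, 80), "xyxy")
# (10, 20, 40, 80, 30, 60)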
def get_dynamic_params_for_bbox_of_label(text, x1, y1, w, h, img_h, img_w, font):
font_scale_factor = img_h / (img_w + img_h)
font_scale = w / (w + h) * font_scale_factor # adjust font scale by width height
thickness = int(font_scale_factor) + 1
(text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0]
text_offset_x = x1
text_offset_y = y1 - thickness
box_coords = ((text_offset_x, text_offset_y + 1), (text_offset_x + text_width - 2, text_offset_y - text_height - 2))
return (font_scale, thickness, text_height, box_coords)
def visualize_bbox_and_label(
img, bboxes, texts, bbox_color=(200, 180, 60),
text_color=(0, 0, 0),
format="xyxy", is_vnese=False, draw_text=True):
ori_img_type = type(img)
if is_vnese:
img = Image.fromarray(img) if ori_img_type is np.ndarray else img
draw = ImageDraw.Draw(img)
img_w, img_h = img.size
font_pil_str = "fonts/arial.ttf"
font_cv2 = cv2.FONT_HERSHEY_SIMPLEX
else:
img_h, img_w = img.shape[0], img.shape[1]
font_cv2 = cv2.FONT_HERSHEY_SIMPLEX
for i in range(len(bboxes)):
text = texts[i] # text = "{}: {:.0f}%".format(LABELS[classIDs[i]], confidences[i]*100)
x1, y1, x2, y2, w, h = get_xyxywh_base_on_format(bboxes[i], format)
font_scale, thickness, text_height, box_coords = get_dynamic_params_for_bbox_of_label(
text, x1, y1, w, h, img_h, img_w, font=font_cv2)
if is_vnese:
font_pil = ImageFont.truetype(font_pil_str, size=text_height) # type: ignore
fdraw_text = draw.text # type: ignore
fdraw_bbox = draw.rectangle # type: ignore
# Pil use different coordinate => y = y+thickness = y-thickness + 2*thickness
arg_text = ((box_coords[0][0], box_coords[1][1]), text)
kwarg_text = {"font": font_pil, "fill": text_color, "width": thickness}
arg_rec = ((x1, y1, x2, y2),)
kwarg_rec = {"outline": bbox_color, "width": thickness}
arg_rec_text = ((box_coords[0], box_coords[1]),)
kwarg_rec_text = {"fill": bbox_color, "width": thickness}
else:
# cv2.rectangle(img, box_coords[0], box_coords[1], color, cv2.FILLED)
# cv2.putText(img, text, (text_offset_x, text_offset_y), font, fontScale=font_scale, color=(50, 0,0), thickness=thickness)
# cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
fdraw_text = cv2.putText
fdraw_bbox = cv2.rectangle
arg_text = (img, text, box_coords[0])
kwarg_text = {"fontFace": font_cv2, "fontScale": font_scale, "color": text_color, "thickness": thickness}
arg_rec = (img, (x1, y1), (x2, y2))
kwarg_rec = {"color": bbox_color, "thickness": thickness}
arg_rec_text = (img, box_coords[0], box_coords[1])
kwarg_rec_text = {"color": bbox_color, "thickness": cv2.FILLED}
# draw a bounding box rectangle and label on the img
fdraw_bbox(*arg_rec, **kwarg_rec) # type: ignore
if draw_text:
fdraw_bbox(*arg_rec_text, **kwarg_rec_text) # type: ignore
fdraw_text(*arg_text, **kwarg_text) # type: ignore # text have to put in front of rec_text
return np.array(img) if ori_img_type is np.ndarray and is_vnese else img
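# Illustrative usage sketch (assumes a BGR np.ndarray "img" is already loaded; the label and
# output path are placeholders):
# >>> vis = visualize_bbox_and_label(img, [(50, 40, 220, 90)], ["seller_name"], format="xyxy")
# >>> cv2.imwrite("debug_vis.jpg", vis)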

View File

@ -0,0 +1,673 @@
from builtins import dict
from .dto import Word, Line, Word_group, Box
import numpy as np
from typing import Optional, List, Tuple, Union
MIN_IOU_HEIGHT = 0.7
MIN_WIDTH_LINE_RATIO = 0.05
def resize_to_original(
boundingbox, scale
): # resize coordinates to match size of original image
left, top, right, bottom = boundingbox
left *= scale[1]
right *= scale[1]
top *= scale[0]
bottom *= scale[0]
return [left, top, right, bottom]
def check_iomin(word: Word, word_group: Word_group):
min_height = min(
word.boundingbox[3] - word.boundingbox[1],
word_group.boundingbox[3] - word_group.boundingbox[1],
)
intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max(
word.boundingbox[1], word_group.boundingbox[1]
)
if intersect / min_height > MIN_IOU_HEIGHT:
return True
return False
def prepare_line(words):
lines = []
visited = [False] * len(words)
for id_word, word in enumerate(words):
if word.invalid_size() == 0:
continue
new_line = True
for i in range(len(lines)):
if (
lines[i].in_same_line(word) and not visited[id_word]
): # check if word is in the same line with lines[i]
lines[i].merge_word(word)
new_line = False
visited[id_word] = True
if new_line:
line = Line()
line.merge_word(word)
lines.append(line)
# print(len(lines))
# sort line from top to bottom according top coordinate
lines.sort(key=lambda x: x.boundingbox[1])
return lines
def __create_word_group(word, word_group_id):
new_word_group_ = Word_group()
new_word_group_.list_words = list()
new_word_group_.word_group_id = word_group_id
new_word_group_.add_word(word)
return new_word_group_
def __sort_line(line):
line.list_word_groups.sort(
key=lambda x: x.boundingbox[0]
) # sort word in lines from left to right
return line
def __merge_text_for_line(line):
line.text = ""
for word in line.list_word_groups:
line.text += " " + word.text
return line
def __update_list_word_groups(line, word_group_id, word_id, line_width):
old_list_word_group = line.list_word_groups
list_word_groups = []
inital_word_group = __create_word_group(
old_list_word_group[0], word_group_id)
old_list_word_group[0].word_id = word_id
list_word_groups.append(inital_word_group)
word_group_id += 1
word_id += 1
for word in old_list_word_group[1:]:
check_word_group = True
word.word_id = word_id
word_id += 1
if (
(not list_word_groups[-1].text.endswith(":"))
and (
(word.boundingbox[0] - list_word_groups[-1].boundingbox[2])
/ line_width
< MIN_WIDTH_LINE_RATIO
)
and check_iomin(word, list_word_groups[-1])
):
list_word_groups[-1].add_word(word)
check_word_group = False
if check_word_group:
new_word_group = __create_word_group(word, word_group_id)
list_word_groups.append(new_word_group)
word_group_id += 1
line.list_word_groups = list_word_groups
return line, word_group_id, word_id
def construct_word_groups_in_each_line(lines):
line_id = 0
word_group_id = 0
word_id = 0
for i in range(len(lines)):
if len(lines[i].list_word_groups) == 0:
continue
# left, top ,right, bottom
line_width = lines[i].boundingbox[2] - \
lines[i].boundingbox[0] # right - left
line_width = 1 # TODO: to remove
lines[i] = __sort_line(lines[i])
# update text for lines after sorting
lines[i] = __merge_text_for_line(lines[i])
lines[i], word_group_id, word_id = __update_list_word_groups(
lines[i],
word_group_id,
word_id,
line_width)
lines[i].update_line_id(line_id)
line_id += 1
return lines
def words_to_lines(words, check_special_lines=True): # words is list of Word instance
# sort word by top
words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0]))
# words.sort(key=lambda x: (sum(x.bbox[:])))
number_of_word = len(words)
# print(number_of_word)
# sort list words to list lines, which have not contained word_group yet
lines = prepare_line(words)
# construct word_groups in each line
lines = construct_word_groups_in_each_line(lines)
return lines, number_of_word
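# Illustrative usage sketch: "words" are Word instances produced upstream (e.g. by the OCR
# engine wrapper); Word construction itself is project-specific and not shown here.
# >>> lines, n_words = words_to_lines(words)
# >>> print(n_words, lines[0].text)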
def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8):
"""Check if two boxes are on the same line by their y-axis coordinates.
Two boxes are on the same line if they overlap vertically, and the length
of the overlapping line segment is greater than min_y_overlap_ratio * the
height of either of the boxes.
Args:
box_a (list), box_b (list): Two bounding boxes to be checked
min_y_overlap_ratio (float): The minimum vertical overlapping ratio
allowed for boxes in the same line
Returns:
The bool flag indicating if they are on the same line
"""
a_y_min = np.min(box_a[1::2])
b_y_min = np.min(box_b[1::2])
a_y_max = np.max(box_a[1::2])
b_y_max = np.max(box_b[1::2])
# Make sure that box a is always the box above another
if a_y_min > b_y_min:
a_y_min, b_y_min = b_y_min, a_y_min
a_y_max, b_y_max = b_y_max, a_y_max
if b_y_min <= a_y_max:
if min_y_overlap_ratio is not None:
sorted_y = sorted([b_y_min, b_y_max, a_y_max])
overlap = sorted_y[1] - sorted_y[0]
min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio
min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio
return overlap >= min_a_overlap or \
overlap >= min_b_overlap
else:
return True
return False
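# Worked example with 8-value quad boxes (as passed by the callers below):
# box_a spans y in [0, 10], box_b spans y in [2, 9]; overlap = 7 >= 0.8 * 7 (height of box_b).
# >>> is_on_same_line([0, 0, 10, 0, 10, 10, 0, 10], [12, 2, 20, 2, 20, 9, 12, 9])
# True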
def merge_bboxes_to_group(bboxes_group, x_sorted_boxes):
merged_bboxes = []
for box_group in bboxes_group:
merged_box = {}
merged_box['text'] = ' '.join(
[x_sorted_boxes[idx]['text'] for idx in box_group])
x_min, y_min = float('inf'), float('inf')
x_max, y_max = float('-inf'), float('-inf')
for idx in box_group:
x_max = max(np.max(x_sorted_boxes[idx]['box'][::2]), x_max)
x_min = min(np.min(x_sorted_boxes[idx]['box'][::2]), x_min)
y_max = max(np.max(x_sorted_boxes[idx]['box'][1::2]), y_max)
y_min = min(np.min(x_sorted_boxes[idx]['box'][1::2]), y_min)
merged_box['box'] = [
x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max
]
merged_box['list_words'] = [x_sorted_boxes[idx]['word']
for idx in box_group]
merged_bboxes.append(merged_box)
return merged_bboxes
def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.3):
"""Stitch fragmented boxes of words into lines.
Note: part of its logic is inspired by @Johndirr
(https://github.com/faustomorales/keras-ocr/issues/22)
Args:
boxes (list): List of ocr results to be stitched
max_x_dist (int): The maximum horizontal distance between the closest
edges of neighboring boxes in the same line
min_y_overlap_ratio (float): The minimum vertical overlapping ratio
allowed for any pairs of neighboring boxes in the same line
Returns:
merged_boxes(List[dict]): List of merged boxes and texts
"""
if len(boxes) <= 1:
if len(boxes) == 1:
boxes[0]["list_words"] = [boxes[0]["word"]]
return boxes
# merged_groups = []
merged_lines = []
# sort groups based on the x_min coordinate of boxes
x_sorted_boxes = sorted(boxes, key=lambda x: np.min(x['box'][::2]))
# store indexes of boxes which are already parts of other lines
skip_idxs = set()
i = 0
# locate lines of boxes starting from the leftmost one
for i in range(len(x_sorted_boxes)):
if i in skip_idxs:
continue
# the rightmost box in the current line
rightmost_box_idx = i
line = [rightmost_box_idx]
for j in range(i + 1, len(x_sorted_boxes)):
if j in skip_idxs:
continue
if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'],
x_sorted_boxes[j]['box'], min_y_overlap_ratio):
line.append(j)
skip_idxs.add(j)
rightmost_box_idx = j
# optionally split the line into sub-groups where the horizontal gap between
# neighboring boxes exceeds max_x_dist (currently disabled, see the commented-out block below)
# groups = []
# line_idx = 0
# groups.append([line[0]])
# for k in range(1, len(line)):
# curr_box = x_sorted_boxes[line[k]]
# prev_box = x_sorted_boxes[line[k - 1]]
# dist = np.min(curr_box['box'][::2]) - np.max(prev_box['box'][::2])
# if dist > max_x_dist:
# line_idx += 1
# groups.append([])
# groups[line_idx].append(line[k])
# # Get merged boxes
merged_line = merge_bboxes_to_group([line], x_sorted_boxes)
merged_lines.extend(merged_line)
# merged_group = merge_bboxes_to_group(groups,x_sorted_boxes)
# merged_groups.extend(merged_group)
merged_lines = sorted(merged_lines, key=lambda x: np.min(x['box'][1::2]))
# merged_groups = sorted(merged_groups, key=lambda x: np.min(x['box'][1::2]))
return merged_lines # , merged_groups
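# Illustrative usage sketch: each entry carries an 8-value quad "box", its "text" and the
# originating "word" object (None here for brevity).
# >>> boxes = [
# ...     {"box": [0, 0, 50, 0, 50, 20, 0, 20], "text": "Total", "word": None},
# ...     {"box": [60, 2, 120, 2, 120, 22, 60, 22], "text": "518.00", "word": None},
# ... ]
# >>> stitch_boxes_into_lines(boxes)[0]["text"]
# 'Total 518.00'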
# REFERENCE
# https://vigneshgig.medium.com/bounding-box-sorting-algorithm-for-text-detection-and-object-detection-from-left-to-right-and-top-cf2c523c8a85
# https://huggingface.co/spaces/tomofi/MMOCR/blame/main/mmocr/utils/box_util.py
def words_to_lines_mmocr(words: List[Word], *args) -> Tuple[List[Line], Optional[int]]:
bboxes = [{"box": [w.bbox[0], w.bbox[1], w.bbox[2], w.bbox[1], w.bbox[2], w.bbox[3], w.bbox[0], w.bbox[3]],
"text":w.text, "word":w} for w in words]
merged_lines = stitch_boxes_into_lines(bboxes)
merged_groups = merged_lines # TODO: fix code to return both word group and line
lwords_groups = [Word_group(list_words_=merged_box["list_words"],
text=merged_box["text"],
boundingbox=[merged_box["box"][i] for i in [0, 1, 2, -1]])
for merged_box in merged_groups]
llines = [Line(text=word_group.text, list_word_groups=[word_group], boundingbox=word_group.boundingbox)
for word_group in lwords_groups]
return llines, None # same format with the origin words_to_lines
# lines = [Line() for merged]
# def most_overlapping_row(rows, top, bottom, y_shift):
# max_overlap = -1
# max_overlap_idx = -1
# for i, row in enumerate(rows):
# row_top, row_bottom = row
# overlap = min(top + y_shift, row_top) - max(bottom + y_shift, row_bottom)
# if overlap > max_overlap:
# max_overlap = overlap
# max_overlap_idx = i
# return max_overlap_idx
def most_overlapping_row(rows, row_words, top, bottom, y_shift, max_row_size, y_overlap_threshold=0.5):
max_overlap = -1
max_overlap_idx = -1
overlapping_rows = []
for i, row in enumerate(rows):
row_top, row_bottom = row
overlap = min(top - y_shift[i], row_top) - \
max(bottom - y_shift[i], row_bottom)
if overlap > max_overlap:
max_overlap = overlap
max_overlap_idx = i
# if at least overlap 1 pixel and not (overlap too much and overlap too little)
if (row_bottom <= top and row_top >= bottom) and not (top - bottom - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold):
overlapping_rows.append(i)
# Merge overlapping rows if necessary
if len(overlapping_rows) > 1:
merge_top = max(rows[i][0] for i in overlapping_rows)
merge_bottom = min(rows[i][1] for i in overlapping_rows)
if merge_top - merge_bottom <= max_row_size:
# Merge rows
merged_row = (merge_top, merge_bottom)
merged_words = []
# Remove other overlapping rows
for row_idx in overlapping_rows[:0:-1]: # iterate in reverse, skipping the first (kept) row, e.g. [1, 2, 3] -> 3, 2
merged_words.extend(row_words[row_idx])
del rows[row_idx]
del row_words[row_idx]
rows[overlapping_rows[0]] = merged_row
row_words[overlapping_rows[0]].extend(merged_words[::-1])
max_overlap_idx = overlapping_rows[0]
if top - bottom - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold:
max_overlap_idx = -1
return max_overlap_idx
def stitch_boxes_into_lines_tesseract(words: list[Word],
gradient: float, y_overlap_threshold: float) -> Tuple[list[list[Word]],
float]:
sorted_words = sorted(words, key=lambda x: x.bbox[0])
if not sorted_words: # guard against empty OCR output to avoid IndexError below
return [], 0.0
rows = []
row_words = []
max_row_size = sorted_words[0].height
running_y_shift = []
for _i, word in enumerate(sorted_words):
# if word.bbox[1] > 340 and word.bbox[3] < 450:
# print("DEBUG")
# if word.text == "Lực":
# print("DEBUG")
bbox, _text = word.bbox[:], word.text
_x1, y1, _x2, y2 = bbox
top, bottom = y2, y1
max_row_size = max(max_row_size, top - bottom)
overlap_row_idx = most_overlapping_row(
rows, row_words, top, bottom, running_y_shift, max_row_size, y_overlap_threshold)
if overlap_row_idx == -1: # No overlapping row found
new_row = (top, bottom)
rows.append(new_row)
row_words.append([word])
running_y_shift.append(0)
else: # Overlapping row found
row_top, row_bottom = rows[overlap_row_idx]
new_top = max(row_top, top)
new_bottom = min(row_bottom, bottom)
rows[overlap_row_idx] = (new_top, new_bottom)
row_words[overlap_row_idx].append(word)
new_shift = (bottom + top) / 2 - (row_bottom + row_top) / 2
running_y_shift[overlap_row_idx] = gradient * \
running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift
# Sort rows and row_texts based on the top y-coordinate
sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][0])
_sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data)
# the mean running y-shift approximates the perpendicular offset between the horizontal and the page's skew line
page_skew_dist = sum(running_y_shift) / len(running_y_shift)
return sorted_row_words, page_skew_dist
def construct_word_groups_tesseract(sorted_row_words: list[list[Word]],
max_x_dist: int, page_skew_dist: float) -> list[list[list[Word]]]:
# approximate page_skew_angle by page_skew_dist
corrected_max_x_dist = max_x_dist * abs(np.cos(page_skew_dist / 180 * 3.14))
constructed_row_word_groups = []
for row_words in sorted_row_words:
lword_groups = []
line_idx = 0
lword_groups.append([row_words[0]])
for k in range(1, len(row_words)):
curr_box = row_words[k].bbox[:]
prev_box = row_words[k - 1].bbox[:]
dist = curr_box[0] - prev_box[2]
if dist > corrected_max_x_dist:
line_idx += 1
lword_groups.append([])
lword_groups[line_idx].append(row_words[k])
constructed_row_word_groups.append(lword_groups)
return constructed_row_word_groups
def group_bbox_and_text(lwords: list[Word]) -> tuple[Box, tuple[str, float]]:
text = ' '.join([word.text for word in lwords])
x_min, y_min = float('inf'), float('inf')
x_max, y_max = float('-inf'), float('-inf')
conf_det = 0
conf_cls = 0
for word in lwords:
x_max = max(np.max(word.bbox[::2]), x_max)
x_min = min(np.min(word.bbox[::2]), x_min)
y_max = max(np.max(word.bbox[1::2]), y_max)
y_min = min(np.min(word.bbox[1::2]), y_min)
conf_det += word.conf_detect
conf_cls += word.conf_cls
bbox = Box(x_min, y_min, x_max, y_max, conf=conf_det / len(lwords))
return bbox, (text, conf_cls / len(lwords))
def words_to_lines_tesseract(words: List[Word],
gradient: float, max_x_dist: int, y_overlap_threshold: float) -> Tuple[List[Line],
Optional[int]]:
sorted_row_words, page_skew_dist = stitch_boxes_into_lines_tesseract(
words, gradient, y_overlap_threshold)
constructed_row_word_groups = construct_word_groups_tesseract(
sorted_row_words, max_x_dist, page_skew_dist)
llines = []
for row in constructed_row_word_groups:
lwords_row = []
lword_groups = []
for word_group in row:
bbox_word_group, text_word_group = group_bbox_and_text(word_group)
lwords_row.extend(word_group)
lword_groups.append(
Word_group(
list_words_=word_group, text=text_word_group[0],
conf_cls=text_word_group[1],
boundingbox=bbox_word_group))
bbox_line, text_line = group_bbox_and_text(lwords_row)
llines.append(
Line(
list_word_groups=lword_groups, text=text_line[0],
boundingbox=bbox_line, conf_cls=text_line[1]))
return llines, None
def near(word_group1: Word_group, word_group2: Word_group):
min_height = min(
word_group1.boundingbox[3] - word_group1.boundingbox[1],
word_group2.boundingbox[3] - word_group2.boundingbox[1],
)
overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max(
word_group1.boundingbox[1], word_group2.boundingbox[1]
)
if overlap > 0:
return True
if abs(overlap / min_height) < 1.5:
print("near enough", abs(overlap / min_height), overlap, min_height)
return True
return False
def calculate_iou_and_near(wg1: Word_group, wg2: Word_group):
min_height = min(
wg1.boundingbox[3] -
wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1]
)
overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max(
wg1.boundingbox[1], wg2.boundingbox[1]
)
iou = overlap / min_height
distance = min(
abs(wg1.boundingbox[0] - wg2.boundingbox[2]),
abs(wg1.boundingbox[2] - wg2.boundingbox[0]),
)
if iou > 0.7 and distance < 0.5 * (wg1.boundingbox[2] - wg1.boundingbox[0]):
return True
return False
def construct_word_groups_to_kie_label(list_word_groups: list):
kie_dict = dict()
for wg in list_word_groups:
if wg.kie_label == "other":
continue
if wg.kie_label not in kie_dict:
kie_dict[wg.kie_label] = [wg]
else:
kie_dict[wg.kie_label].append(wg)
new_dict = dict()
for key, value in kie_dict.items():
if len(value) == 1:
new_dict[key] = value
continue
value.sort(key=lambda x: x.boundingbox[1])
new_dict[key] = value
return new_dict
def invoice_construct_word_groups_to_kie_label(list_word_groups: list):
kie_dict = dict()
for wg in list_word_groups:
if wg.kie_label == "other":
continue
if wg.kie_label not in kie_dict:
kie_dict[wg.kie_label] = [wg]
else:
kie_dict[wg.kie_label].append(wg)
return kie_dict
def postprocess_total_value(kie_dict):
if "total_in_words_value" not in kie_dict:
return kie_dict
for k, value in kie_dict.items():
if k == "total_in_words_value":
continue
l = []
for v in value:
if v.boundingbox[3] <= kie_dict["total_in_words_value"][0].boundingbox[3]:
l.append(v)
if len(l) != 0:
kie_dict[k] = l
return kie_dict
def postprocess_tax_code_value(kie_dict):
if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict:
return kie_dict
kie_dict["buyer_tax_code_value"] = []
for v in kie_dict["seller_tax_code_value"]:
if "buyer_name_key" in kie_dict and (
v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3]
or near(v, kie_dict["buyer_name_key"][0])
):
kie_dict["buyer_tax_code_value"].append(v)
continue
if "buyer_name_value" in kie_dict and (
v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3]
or near(v, kie_dict["buyer_name_value"][0])
):
kie_dict["buyer_tax_code_value"].append(v)
continue
if "buyer_address_value" in kie_dict and near(
kie_dict["buyer_address_value"][0], v
):
kie_dict["buyer_tax_code_value"].append(v)
return kie_dict
def postprocess_tax_code_key(kie_dict):
if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict:
return kie_dict
kie_dict["buyer_tax_code_key"] = []
for v in kie_dict["seller_tax_code_key"]:
if "buyer_name_key" in kie_dict and (
v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3]
or near(v, kie_dict["buyer_name_key"][0])
):
kie_dict["buyer_tax_code_key"].append(v)
continue
if "buyer_name_value" in kie_dict and (
v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3]
or near(v, kie_dict["buyer_name_value"][0])
):
kie_dict["buyer_tax_code_key"].append(v)
continue
if "buyer_address_value" in kie_dict and near(
kie_dict["buyer_address_value"][0], v
):
kie_dict["buyer_tax_code_key"].append(v)
return kie_dict
def invoice_postprocess(kie_dict: dict):
# all keys or values which are below total_in_words_value will be thrown away
kie_dict = postprocess_total_value(kie_dict)
kie_dict = postprocess_tax_code_value(kie_dict)
kie_dict = postprocess_tax_code_key(kie_dict)
return kie_dict
def throw_overlapping_words(list_words):
if not list_words: # guard against empty input
return []
new_list = [list_words[0]]
for word in list_words:
overlap = False
area = (word.boundingbox[2] - word.boundingbox[0]) * (
word.boundingbox[3] - word.boundingbox[1]
)
for word2 in new_list:
area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * (
word2.boundingbox[3] - word2.boundingbox[1]
)
xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0])
xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2])
ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1])
ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3])
if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
continue
area_intersect = (xmax_intersect - xmin_intersect) * (
ymax_intersect - ymin_intersect
)
if area_intersect / area > 0.7 or area_intersect / area2 > 0.7:
overlap = True
if overlap == False:
new_list.append(word)
return new_list
def check_iou(box1: Word, box2: Box, threshold=0.9):
area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * (
box1.boundingbox[3] - box1.boundingbox[1]
)
area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin)
xmin_intersect = max(box1.boundingbox[0], box2.xmin)
ymin_intersect = max(box1.boundingbox[1], box2.ymin)
xmax_intersect = min(box1.boundingbox[2], box2.xmax)
ymax_intersect = min(box1.boundingbox[3], box2.ymax)
if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
area_intersect = 0
else:
area_intersect = (xmax_intersect - xmin_intersect) * (
ymax_intersect - ymin_intersect
)
union = area1 + area2 - area_intersect
iou = area_intersect / union
if iou > threshold:
return True
return False

View File

@ -0,0 +1,230 @@
from omegaconf import OmegaConf
import os
import cv2
import torch
# from functions import get_colormap, visualize
import sys
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') # TODO: replace this hard-coded path with a configurable module root
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
from model import get_model
from utils import load_model_weight
class KVUPredictor:
def __init__(self, configs, class_names, dummy_idx, mode=0):
cfg_path = configs['cfg']
ckpt_path = configs['ckpt']
self.class_names = class_names
self.dummy_idx = dummy_idx
self.mode = mode
print('[INFO] Loading Key-Value Understanding model ...')
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
print("[INFO] Loaded model")
if mode == 3:
self.max_window_count = cfg.train.max_window_count
self.window_size = cfg.train.window_size
self.slice_interval = 0
self.dummy_idx = dummy_idx * self.max_window_count
else:
self.slice_interval = cfg.train.slice_interval
self.window_size = cfg.train.max_num_words
self.device = 'cuda'
def _load_model(self, cfg_path, ckpt_path):
cfg = OmegaConf.load(cfg_path)
cfg.stage = self.mode
backbone_type = cfg.model.backbone
print('[INFO] Checkpoint:', ckpt_path)
net = get_model(cfg)
load_model_weight(net, ckpt_path)
net.to('cuda')
net.eval()
return net, cfg, backbone_type
def predict(self, input_sample):
if self.mode == 0:
if len(input_sample['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
return [bbox], [lwords], [pr_class_words], [pr_relations]
elif self.mode == 1:
if len(input_sample['documents']['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
return [bbox], [lwords], [pr_class_words], [pr_relations]
elif self.mode == 2:
if len(input_sample['windows'][0]['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = [], [], [], []
for window in input_sample['windows']:
_bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
bbox.append(_bbox)
lwords.append(_lwords)
pr_class_words.append(_pr_class_words)
pr_relations.append(_pr_relations)
return bbox, lwords, pr_class_words, pr_relations
elif self.mode == 3:
if len(input_sample["documents"]['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
return [bbox], [lwords], [pr_class_words], [pr_relations]
else:
raise ValueError(
f"Not supported mode: {self.mode}"
)
def doc_predict(self, input_sample):
lwords = input_sample['documents']['words']
for idx, window in enumerate(input_sample['windows']):
input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')}
# input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')}
with torch.no_grad():
head_outputs, _ = self.net(input_sample)
head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
input_sample = input_sample['documents']
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)
box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
attention_mask = input_sample['attention_mask'].squeeze(0)
bbox = input_sample['bbox'].squeeze(0)
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
)
pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
pr_relations = pr_relations_from_header | pr_relations_from_key
return bbox, lwords, pr_class_words, pr_relations
def combined_predict(self, input_sample):
lwords = input_sample['words']
input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}
input_sample = {k: v.to(self.device) for k, v in input_sample.items()}
with torch.no_grad():
head_outputs, _ = self.net(input_sample)
head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)
box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
bbox = input_sample['bbox'].squeeze(0)
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
)
pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
pr_relations = pr_relations_from_header | pr_relations_from_key
return bbox, lwords, pr_class_words, pr_relations
def cat_predict(self, input_sample):
lwords = input_sample['documents']['words']
inputs = []
for window in input_sample['windows']:
inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')})
input_sample['windows'] = inputs
with torch.no_grad():
head_outputs, _ = self.net(input_sample)
head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)} # single-element tuple, not a substring check
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)
box_first_token_mask = input_sample['documents']['are_box_first_tokens']
attention_mask = input_sample['documents']['attention_mask_layoutxlm']
bbox = input_sample['documents']['bbox']
dummy_idx = input_sample['documents']['bbox'].shape[0]
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, dummy_idx
)
pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
pr_relations = pr_relations_from_header | pr_relations_from_key
return bbox, lwords, pr_class_words, pr_relations
def get_ground_truth_label(self, ground_truth):
# ground_truth = self.preprocessor.load_ground_truth(json_file)
gt_itc_label = ground_truth['itc_labels'].squeeze(0) # [1, 512] => [512]
gt_stc_label = ground_truth['stc_labels'].squeeze(0) # [1, 512] => [512]
gt_el_label = ground_truth['el_labels'].squeeze(0)
gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
lwords = ground_truth["words"]
box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
attention_mask = ground_truth['attention_mask'].squeeze(0)
bbox = ground_truth['bbox'].squeeze(0)
gt_first_words = parse_initial_words(
gt_itc_label, box_first_token_mask, self.class_names
)
gt_class_words = parse_subsequent_words(
gt_stc_label, attention_mask, gt_first_words, self.dummy_idx
)
gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
gt_relations = gt_relations_from_header | gt_relations_from_key
return bbox, lwords, gt_class_words, gt_relations
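# Illustrative usage sketch (config/checkpoint paths and the class list are placeholders,
# not taken from this repository):
# >>> configs = {"cfg": "experiments/kvu/config.yaml", "ckpt": "experiments/kvu/best_model.pth"}
# >>> predictor = KVUPredictor(configs, class_names=["others", "key", "value", "header"], dummy_idx=512, mode=0)
# >>> bbox, lwords, pr_class_words, pr_relations = predictor.predict(input_sample)
# where input_sample is produced by the matching KVUProcess preprocessor for the same mode.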

View File

@ -0,0 +1,601 @@
import os
from typing import Any
import numpy as np
import pandas as pd
import imagesize
import itertools
from PIL import Image
import argparse
import torch
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
from utils.run_ocr import load_ocr_engine, process_img
from lightning_modules.utils import sliding_windows
class KVUProcess:
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
self.tokenizer_layoutxlm = tokenizer_layoutxlm
self.feature_extractor = feature_extractor
self.max_seq_length = max_seq_length
self.backbone_type = backbone_type
self.class_names = class_names
self.slice_interval = slice_interval
self.window_size = window_size
self.run_ocr = run_ocr
self.mode = mode
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)
self.class_idx_dic = dict(
[(class_name, idx) for idx, class_name in enumerate(self.class_names)]
)
self.ocr_engine = None
if self.run_ocr == 1:
self.ocr_engine = load_ocr_engine()
def __call__(self, img_path: str, ocr_path: str) -> list:
if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
ocr_path = "tmp.txt"
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
lwords = post_process_basic_ocr(lwords)
bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same: {len(bbox_windows)} vs {len(word_windows)}"
width, height = imagesize.get(img_path)
images = [Image.open(img_path).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
if self.mode == 0:
output = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
elif self.mode == 1:
output = {}
windows = []
for i in range(len(bbox_windows)):
_words = word_windows[i]
_bboxes = bbox_windows[i]
windows.append(
self.preprocess(
_bboxes, _words,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
)
output['windows'] = windows
elif self.mode == 2:
output = {}
windows = []
output['documents'] = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=2048)
for i in range(len(bbox_windows)):
_words = word_windows[i]
_bboxes = bbox_windows[i]
windows.append(
self.preprocess(
_bboxes, _words,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
)
output['windows'] = windows
else:
raise ValueError(
f"Not supported mode: {self.mode }"
)
return output
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int)
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
list_layoutxlm_tokens = []
list_bbs = []
list_words = []
lwords = [''] * max_seq_length
box_to_token_indices = []
cum_token_idx = 0
cls_bbs = [0.0] * 8
len_overlap_tokens = 0
len_non_overlap_tokens = 0
len_valid_tokens = 0
for word_idx, (bounding_box, word) in enumerate(zip(bounding_boxes, words)):
bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
layoutxlm_tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(word))
this_box_token_indices = []
len_valid_tokens += len(layoutxlm_tokens)
if word_idx < self.slice_interval:
len_non_overlap_tokens += len(layoutxlm_tokens)
if len(layoutxlm_tokens) == 0:
layoutxlm_tokens.append(self.unk_token_id_layoutxlm)
if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2:
break
list_layoutxlm_tokens += layoutxlm_tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(layoutxlm_tokens))]
texts = [word for _ in range(len(layoutxlm_tokens))]
for _ in layoutxlm_tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
list_words.extend(texts) ###
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [feature_maps['width'], feature_maps['height']] * 4
# For [CLS] and [SEP]
list_layoutxlm_tokens = (
[self.cls_token_id_layoutxlm]
+ list_layoutxlm_tokens[: max_seq_length - 2]
+ [self.sep_token_id_layoutxlm]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
# list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP']
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
len_list_layoutxlm_tokens = len(list_layoutxlm_tokens)
input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens
attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1
bbox[:len_list_layoutxlm_tokens, :] = list_bbs
lwords[:len_list_layoutxlm_tokens] = list_words ###
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']
if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
else:
assert False
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < max_seq_length
]
are_box_first_tokens[st_indices] = True
assert len_list_layoutxlm_tokens == len_valid_tokens + 2
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
input_ids_layoutxlm = input_ids_layoutxlm[:ntokens]
attention_mask_layoutxlm = attention_mask_layoutxlm[:ntokens]
bbox = bbox[:ntokens]
are_box_first_tokens = are_box_first_tokens[:ntokens]
input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm)
attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm)
bbox = torch.from_numpy(bbox)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
len_valid_tokens = torch.tensor(len_valid_tokens)
len_overlap_tokens = torch.tensor(len_overlap_tokens)
return_dict = {
"img_path": feature_maps['img_path'],
"words": list_words,
"len_overlap_tokens": len_overlap_tokens,
'len_valid_tokens': len_valid_tokens,
"image": feature_maps['image'],
"input_ids_layoutxlm": input_ids_layoutxlm,
"attention_mask_layoutxlm": attention_mask_layoutxlm,
"are_box_first_tokens": are_box_first_tokens,
"bbox": bbox,
}
return return_dict
def load_ground_truth(self, json_file):
json_obj = read_json(json_file)
width = json_obj["meta"]["imageSize"]["width"]
height = json_obj["meta"]["imageSize"]["height"]
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(self.max_seq_length, dtype=int)
itc_labels = np.zeros(self.max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
# stc_labels stores the index of the previous token.
# A stored index of max_seq_length (512) indicates that
# this token is the initial token of a word box.
stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
el_labels_from_key = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
list_tokens = []
list_bbs = []
list_words = []
box2token_span_map = []
lwords = [''] * self.max_seq_length
box_to_token_indices = []
cum_token_idx = 0
cls_bbs = [0.0] * 8
for word_idx, word in enumerate(json_obj["words"]):
this_box_token_indices = []
tokens = word["layoutxlm_tokens"]
bb = word["boundingBox"]
text = word["text"]
if len(tokens) == 0:
tokens.append(self.unk_token_id_layoutxlm)
if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
break
box2token_span_map.append(
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
) # including st_idx
list_tokens += tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(tokens))]
texts = [text for _ in range(len(tokens))]
for _ in tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
list_words.extend(texts) ####
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [width, height] * 4
# For [CLS] and [SEP]
list_tokens = (
[self.cls_token_id_layoutxlm]
+ list_tokens[: self.max_seq_length - 2]
+ [self.sep_token_id_layoutxlm]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
# list_words = ['CLS'] + list_words[: self.max_seq_length - 2] + ['SEP'] ###
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
len_list_tokens = len(list_tokens)
input_ids[:len_list_tokens] = list_tokens
attention_mask[:len_list_tokens] = 1
bbox[:len_list_tokens, :] = list_bbs
lwords[:len_list_tokens] = list_words
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
if self.backbone_type in ("layoutlm", "layoutxlm"):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
else:
assert False
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < self.max_seq_length
]
are_box_first_tokens[st_indices] = True
# Label
classes_dic = json_obj["parse"]["class"]
for class_name in self.class_names:
if class_name == "others":
continue
if class_name not in classes_dic:
continue
for word_list in classes_dic[class_name]:
is_first, last_word_idx = True, -1
for word_idx in word_list:
if word_idx >= len(box_to_token_indices):
break
box2token_list = box_to_token_indices[word_idx]
for converted_word_idx in box2token_list:
if converted_word_idx >= self.max_seq_length:
break # out of idx
if is_first:
itc_labels[converted_word_idx] = self.class_idx_dic[
class_name
]
is_first, last_word_idx = False, converted_word_idx
else:
stc_labels[converted_word_idx] = last_word_idx
last_word_idx = converted_word_idx
# Label
relations = json_obj["parse"]["relations"]
for relation in relations:
if relation[0] >= len(box2token_span_map) or relation[1] >= len(
box2token_span_map
):
continue
if (
box2token_span_map[relation[0]][0] >= self.max_seq_length
or box2token_span_map[relation[1]][0] >= self.max_seq_length
):
continue
word_from = box2token_span_map[relation[0]][0]
word_to = box2token_span_map[relation[1]][0]
# el_labels[word_to] = word_from
if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512:
continue
# if self.second_relations == 1:
# if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
# el_labels[word_to] = word_from # pair of (header, key) or (header-value)
# else:
#### 1st relation => ['key', 'value']
#### 2nd relation => ['header', 'key' or 'value']
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
el_labels_from_key[word_to] = word_from # pair of (key-value)
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
input_ids = torch.from_numpy(input_ids)
bbox = torch.from_numpy(bbox)
attention_mask = torch.from_numpy(attention_mask)
itc_labels = torch.from_numpy(itc_labels)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
stc_labels = torch.from_numpy(stc_labels)
el_labels = torch.from_numpy(el_labels)
el_labels_from_key = torch.from_numpy(el_labels_from_key)
return_dict = {
# "image": feature_maps,
"input_ids": input_ids,
"bbox": bbox,
"words": lwords,
"attention_mask": attention_mask,
"itc_labels": itc_labels,
"are_box_first_tokens": are_box_first_tokens,
"stc_labels": stc_labels,
"el_labels": el_labels,
"el_labels_from_key": el_labels_from_key
}
return return_dict
class DocumentKVUProcess(KVUProcess):
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
self.max_window_count = max_window_count
self.pad_token_id = self.pad_token_id_layoutxlm
self.cls_token_id = self.cls_token_id_layoutxlm
self.sep_token_id = self.sep_token_id_layoutxlm
self.unk_token_id = self.unk_token_id_layoutxlm
self.tokenizer = self.tokenizer_layoutxlm
def __call__(self, img_path: str, ocr_path: str) -> list:
if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
ocr_path = "tmp.txt"
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
lwords = post_process_basic_ocr(lwords)
width, height = imagesize.get(img_path)
images = [Image.open(img_path).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
output = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path})
return output
def preprocess(self, bounding_boxes, words, feature_maps):
n_words = len(words)
output_dicts = {'windows': [], 'documents': []}
n_empty_windows = 0
for i in range(self.max_window_count):
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(self.max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
if n_words == 0:
n_empty_windows += 1
output_dicts['windows'].append({
"image": feature_maps['image'],
"input_ids": torch.from_numpy(input_ids),
"bbox": torch.from_numpy(bbox),
"words": [],
"attention_mask": torch.from_numpy(attention_mask),
"are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
})
continue
start_word_idx = i * self.window_size
stop_word_idx = min(n_words, (i+1)*self.window_size)
if start_word_idx >= stop_word_idx:
n_empty_windows += 1
output_dicts['windows'].append(output_dicts['windows'][-1])
continue
list_tokens = []
list_bbs = []
list_words = []
lwords = [''] * self.max_seq_length
box_to_token_indices = []
cum_token_idx = 0
cls_bbs = [0.0] * 8
for _, (bounding_box, word) in enumerate(zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx])):
bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word))
this_box_token_indices = []
if len(tokens) == 0:
tokens.append(self.unk_token_id)
if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
break
list_tokens += tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(tokens))]
texts = [word for _ in range(len(tokens))]
for _ in tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
list_words.extend(texts) ###
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [feature_maps['width'], feature_maps['height']] * 4
# For [CLS] and [SEP]
list_tokens = (
[self.cls_token_id]
+ list_tokens[: self.max_seq_length - 2]
+ [self.sep_token_id]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
if len(list_words) < 510:
list_words.extend(['</p>' for _ in range(510 - len(list_words))])
list_words = [self.tokenizer._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer._sep_token]
len_list_tokens = len(list_tokens)
input_ids[:len_list_tokens] = list_tokens
attention_mask[:len_list_tokens] = 1
bbox[:len_list_tokens, :] = list_bbs
lwords[:len_list_tokens] = list_words ###
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < self.max_seq_length
]
are_box_first_tokens[st_indices] = True
input_ids = torch.from_numpy(input_ids)
bbox = torch.from_numpy(bbox)
attention_mask = torch.from_numpy(attention_mask)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
return_dict = {
"image": feature_maps['image'],
"input_ids": input_ids,
"bbox": bbox,
"words": list_words,
"attention_mask": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
}
output_dicts["windows"].append(return_dict)
attention_mask = torch.cat([o['attention_mask'] for o in output_dicts["windows"]])
are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
if n_empty_windows > 0:
attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])
words = []
for o in output_dicts['windows']:
words.extend(o['words'])
return_dict = {
"attention_mask": attention_mask,
"bbox": bbox,
"are_box_first_tokens": are_box_first_tokens,
"n_empty_windows": n_empty_windows,
"words": words
}
output_dicts['documents'] = return_dict
return output_dicts

View File

@ -0,0 +1,19 @@
nptyping==1.4.2
numpy==1.20.3
pytorch-lightning==1.5.6
omegaconf
pillow
six
overrides==4.1.2
transformers==4.11.3
seqeval==0.0.12
imagesize
pandas==2.0.1
xmltodict
dicttoxml
tensorboard>=2.2.0
# code-style
isort==5.9.3
black==21.9b0

View File

@ -0,0 +1 @@
python anyKeyValue.py --img_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --save_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --exp_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900 --export_img 1 --mode 3 --dir_level 0

View File

@ -0,0 +1,106 @@
1113 773 1220 825 BEST
1243 759 1378 808 DENKI
1410 752 1487 799 (S)
1430 707 1515 748 TAX
1511 745 1598 790 PTE
1542 700 1725 740 TNVOICE
1618 742 1706 783 LTD
1783 725 1920 773 FUNAN
1943 723 2054 767 MALL
1434 797 1576 843 WORTH
1599 785 1760 831 BRIDGE
1784 778 1846 822 RD
1277 846 1632 897 #02-16/#03-1
1655 832 1795 877 FUNAN
1817 822 1931 869 MALL
1272 897 1518 956 S(179105)
1548 890 1655 943 TEL:
1686 877 1911 928 69046183
1247 1011 1334 1068 GST
1358 1006 1447 1059 REG
1360 1063 1449 1115 RCB
1473 1003 1575 1055 NO.:
1474 1059 1555 1110 NO.
1595 1042 1868 1096 198202199E
1607 985 1944 1040 M2-0053813-7
1056 1134 1254 1194 Opening
1276 1127 1391 1181 Hrs:
1425 1112 1647 1170 10:00:00
1672 1102 1735 1161 AN
1755 1101 1819 1157 to
1846 1090 2067 1147 10:00:00
2090 1080 2156 1141 PH
1061 1308 1228 1366 Staff:
1258 1300 1378 1357 3296
1710 1283 1880 1337 Trans:
1936 1266 2192 1322 262152554
1060 1372 1201 1429 Date:
1260 1358 1494 1419 22-03-23
1540 1344 1664 1409 9:05
1712 1339 1856 1407 Slip:
1917 1328 2196 1387 2000130286
1124 1487 1439 1545 SALESPERSON
1465 1477 1601 1537 CODE.
1633 1471 1752 1530 6043
1777 1462 2004 1519 HUHAHHAD
2032 1451 2177 1509 RAZIH
1070 1558 1187 1617 Item
1211 1554 1276 1615 No
1439 1542 1585 1601 Price
1750 1530 1841 1597 Qty
1951 1517 2120 1579 Amount
1076 1683 1276 1741 ANDROID
1304 1673 1477 1733 TABLET
1080 1746 1280 1804 2105976
1509 1729 1705 1784 SAMSUNG
1734 1719 1931 1776 SH-P613
1964 1709 2101 1768 128GB
1082 1809 1285 1869 SM-P613
1316 1802 1454 1860 12838
1429 1859 1600 1919 518.00
1481 1794 1596 1855 WIFI
1622 1790 1656 1850 G
1797 1845 1824 1904 1
1993 1832 2165 1892 518.00
1088 1935 1347 1995 PROMOTION
1091 2000 1294 2062 2105664
1520 1983 1717 2039 SAMSUNG
1743 1963 2106 2030 F-Sam-Redeen
1439 2111 1557 2173 0.00
1806 2095 1832 2156 1
2053 2081 2174 2144 0.00
1106 2248 1250 2312 Total
1974 2206 2146 2266 518.00
1107 2312 1204 2377 UOB
1448 2291 1567 2355 CARD
1978 2268 2147 2327 518.00
1253 2424 1375 2497 GST%
1456 2411 1655 2475 Net.Amt
1818 2393 1912 2460 GST
2023 2387 2192 2445 Amount
1106 2494 1231 2560 GST8
1486 2472 1661 2537 479.63
1770 2458 1916 2523 38.37
2027 2448 2203 2511 518.00
1553 2601 1699 2666 THANK
1721 2592 1821 2661 YOU
1436 2678 1616 2749 please
1644 2682 1764 2732 come
1790 2660 1942 2729 again
1191 2862 1391 2931 Those
1426 2870 2018 2945 facebook.com
1565 2809 1690 2884 join
1709 2816 1777 2870 us
1799 2811 1868 2865 on
1838 2946 2024 3003 com .89
1533 3006 2070 3088 ar.com/askbe
1300 3326 1659 3446 That's
1696 3308 1905 3424 not
1937 3289 2131 3408 all!
1450 3511 1633 3573 SCAN
1392 3589 1489 3645 QR
1509 3577 1698 3635 CODE
1321 3656 1370 3714 &
1517 3638 1768 3699 updates
1643 3882 1769 3932 Scan
1789 3868 1859 3926 Me

View File

@ -0,0 +1,127 @@
import os
import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from utils.ema_callbacks import EMA
def _update_config(cfg):
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")
# set per-gpu batch size
num_devices = torch.cuda.device_count()
print('No. devices:', num_devices)
for mode in ["train", "val"]:
new_batch_size = cfg[mode].batch_size // num_devices
cfg[mode].batch_size = new_batch_size
def _get_config_from_cli():
cfg_cli = OmegaConf.from_cli()
cli_keys = list(cfg_cli.keys())
for cli_key in cli_keys:
if "--" in cli_key:
cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key]
del cfg_cli[cli_key]
return cfg_cli
def get_callbacks(cfg):
callback_list = []
checkpoint_callback = ModelCheckpoint(dirpath=cfg.save_weight_dir,
filename='best_model',
save_last=True,
save_top_k=1,
save_weights_only=True,
verbose=True,
monitor='val_f1', mode='max')
checkpoint_callback.FILE_EXTENSION = ".pth"
checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model"
callback_list.append(checkpoint_callback)
if cfg.callbacks.ema.decay != -1:
ema_callback = EMA(decay=0.9999)
callback_list.append(ema_callback)
return callback_list if len(callback_list) > 1 else checkpoint_callback
def get_plugins(cfg):
plugins = []
if cfg.train.strategy.type == "ddp":
plugins.append(DDPPlugin())
return plugins
def get_loggers(cfg):
loggers = []
loggers.append(
TensorBoardLogger(
cfg.tensorboard_dir, name="", version="", default_hp_metric=False
)
)
return loggers
def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
for key, val in cfg.items():
if isinstance(val, DictConfig):
hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__")
else:
hparam_dict[parent_str + key] = str(val)
return hparam_dict
def get_specific_pl_logger(pl_loggers, logger_type):
for pl_logger in pl_loggers:
if isinstance(pl_logger, logger_type):
return pl_logger
return None
def get_class_names(dataset_root_path):
class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
class_names = (
open(class_names_file, "r", encoding="utf-8").read().strip().split("\n")
)
return class_names
def create_exp_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Experiment dir : {}'.format(save_dir))
def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Save dir : {}'.format(save_dir))
def load_checkpoint(ckpt_path, model, key_include):
assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
state_dict = torch.load(ckpt_path, 'cpu')['state_dict']
for key in list(state_dict.keys()):
if f'.{key_include}.' not in key:
del state_dict[key]
else:
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
del state_dict[key]
model.load_state_dict(state_dict, strict=True)
print(f"Load checkpoint at {ckpt_path}")
return model
def load_model_weight(net, pretrained_model_file):
pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[
"state_dict"
]
new_state_dict = {}
for k, v in pretrained_model_state_dict.items():
new_k = k
if new_k.startswith("net."):
new_k = new_k[len("net.") :]
new_state_dict[new_k] = v
net.load_state_dict(new_state_dict)
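These helpers are meant to be wired into a pytorch_lightning.Trainer. A minimal sketch of that wiring, assuming a config exposing the fields referenced above (workspace, per-mode batch_size, callbacks.ema.decay, train.strategy.type); the config path and the cfg.train.max_epochs field are assumptions, not part of this file:

import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/train.yaml")            # hypothetical config file
cfg = OmegaConf.merge(cfg, _get_config_from_cli())    # apply "--key=value" CLI overrides
_update_config(cfg)                                    # derive ckpt/log dirs, per-GPU batch size

trainer = pl.Trainer(
    gpus=torch.cuda.device_count(),
    callbacks=get_callbacks(cfg),                      # single ModelCheckpoint or a list incl. EMA
    plugins=get_plugins(cfg),
    logger=get_loggers(cfg),
    max_epochs=cfg.train.max_epochs,                   # assumed field
)
# trainer.fit(model, datamodule)                       # model / datamodule defined elsewhere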

View File

@ -0,0 +1,346 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import os
import threading
from typing import Any, Dict, Iterable
import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_info
class EMA(Callback):
"""
Implements Exponential Moving Averaging (EMA).
When training a model, this callback will maintain moving averages of the trained parameters.
When evaluating, we use the moving averages copy of the trained parameters.
When saving, we save an additional set of parameters with the prefix `ema`.
Args:
decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
validate_original_weights: Validate the original weights, as opposed to the EMA weights.
every_n_steps: Apply EMA every N steps.
cpu_offload: Offload weights to CPU.
"""
def __init__(
self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False,
):
if not (0 <= decay <= 1):
raise MisconfigurationException("EMA decay value must be between 0 and 1")
self.decay = decay
self.validate_original_weights = validate_original_weights
self.every_n_steps = every_n_steps
self.cpu_offload = cpu_offload
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
device = pl_module.device if not self.cpu_offload else torch.device('cpu')
trainer.optimizers = [
EMAOptimizer(
optim,
device=device,
decay=self.decay,
every_n_steps=self.every_n_steps,
current_step=trainer.global_step,
)
for optim in trainer.optimizers
if not isinstance(optim, EMAOptimizer)
]
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
return not self.validate_original_weights and self._ema_initialized(trainer)
def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)
def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
for optimizer in trainer.optimizers:
assert isinstance(optimizer, EMAOptimizer)
optimizer.switch_main_parameter_weights(saving_ema_model)
@contextlib.contextmanager
def save_ema_model(self, trainer: "pl.Trainer"):
"""
Saves an EMA copy of the model + EMA optimizer states for resume.
"""
self.swap_model_weights(trainer, saving_ema_model=True)
try:
yield
finally:
self.swap_model_weights(trainer, saving_ema_model=False)
@contextlib.contextmanager
def save_original_optimizer_state(self, trainer: "pl.Trainer"):
for optimizer in trainer.optimizers:
assert isinstance(optimizer, EMAOptimizer)
optimizer.save_original_optimizer_state = True
try:
yield
finally:
for optimizer in trainer.optimizers:
optimizer.save_original_optimizer_state = False
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
checkpoint_callback = trainer.checkpoint_callback
# use the connector as NeMo calls the connector directly in the exp_manager when restoring.
connector = trainer._checkpoint_connector
ckpt_path = connector.resume_checkpoint_path
if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__:
ext = checkpoint_callback.FILE_EXTENSION
if ckpt_path.endswith(f'-EMA{ext}'):
rank_zero_info(
"loading EMA based weights. "
"The callback will treat the loaded EMA weights as the main weights"
" and create a new EMA copy when training."
)
return
ema_path = ckpt_path.replace(ext, f'-EMA{ext}')
if os.path.exists(ema_path):
ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))
checkpoint['optimizer_states'] = ema_state_dict['optimizer_states']
del ema_state_dict
rank_zero_info("EMA state has been restored.")
else:
raise MisconfigurationException(
"Unable to find the associated EMA weights when re-loading, "
f"training will start with new EMA weights. Expected them to be at: {ema_path}",
)
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
torch._foreach_mul_(ema_model_tuple, decay)
torch._foreach_add_(
ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
)
def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
if pre_sync_stream is not None:
pre_sync_stream.synchronize()
ema_update(ema_model_tuple, current_model_tuple, decay)
class EMAOptimizer(torch.optim.Optimizer):
r"""
EMAOptimizer is a wrapper for torch.optim.Optimizer that computes
Exponential Moving Average of parameters registered in the optimizer.
EMA parameters are automatically updated after every step of the optimizer
with the following formula:
ema_weight = decay * ema_weight + (1 - decay) * training_weight
To access EMA parameters, use ``swap_ema_weights()`` context manager to
perform a temporary in-place swap of regular parameters with EMA
parameters.
Notes:
- EMAOptimizer is not compatible with APEX AMP O2.
Args:
optimizer (torch.optim.Optimizer): optimizer to wrap
device (torch.device): device for EMA parameters
decay (float): decay factor
Returns:
returns an instance of torch.optim.Optimizer that computes EMA of
parameters
Example:
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)
for epoch in range(epochs):
training_loop(model, opt)
regular_eval_accuracy = evaluate(model)
with opt.swap_ema_weights():
ema_eval_accuracy = evaluate(model)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
device: torch.device,
decay: float = 0.9999,
every_n_steps: int = 1,
current_step: int = 0,
):
self.optimizer = optimizer
self.decay = decay
self.device = device
self.current_step = current_step
self.every_n_steps = every_n_steps
self.save_original_optimizer_state = False
self.first_iteration = True
self.rebuild_ema_params = True
self.stream = None
self.thread = None
self.ema_params = ()
self.in_saving_ema_model_context = False
def all_parameters(self) -> Iterable[torch.Tensor]:
return (param for group in self.param_groups for param in group['params'])
def step(self, closure=None, **kwargs):
self.join()
if self.first_iteration:
if any(p.is_cuda for p in self.all_parameters()):
self.stream = torch.cuda.Stream()
self.first_iteration = False
if self.rebuild_ema_params:
opt_params = list(self.all_parameters())
self.ema_params += tuple(
copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :]
)
self.rebuild_ema_params = False
loss = self.optimizer.step(closure)
if self._should_update_at_step():
self.update()
self.current_step += 1
return loss
def _should_update_at_step(self) -> bool:
return self.current_step % self.every_n_steps == 0
@torch.no_grad()
def update(self):
if self.stream is not None:
self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
current_model_state = tuple(
param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
)
if self.device.type == 'cuda':
ema_update(self.ema_params, current_model_state, self.decay)
if self.device.type == 'cpu':
self.thread = threading.Thread(
target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,),
)
self.thread.start()
def swap_tensors(self, tensor1, tensor2):
tmp = torch.empty_like(tensor1)
tmp.copy_(tensor1)
tensor1.copy_(tensor2)
tensor2.copy_(tmp)
def switch_main_parameter_weights(self, saving_ema_model: bool = False):
self.join()
self.in_saving_ema_model_context = saving_ema_model
for param, ema_param in zip(self.all_parameters(), self.ema_params):
self.swap_tensors(param.data, ema_param)
@contextlib.contextmanager
def swap_ema_weights(self, enabled: bool = True):
r"""
A context manager to in-place swap regular parameters with EMA
parameters.
It swaps back to the original regular parameters on context manager
exit.
Args:
enabled (bool): whether the swap should be performed
"""
if enabled:
self.switch_main_parameter_weights()
try:
yield
finally:
if enabled:
self.switch_main_parameter_weights()
def __getattr__(self, name):
return getattr(self.optimizer, name)
def join(self):
if self.stream is not None:
self.stream.synchronize()
if self.thread is not None:
self.thread.join()
def state_dict(self):
self.join()
if self.save_original_optimizer_state:
return self.optimizer.state_dict()
# if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters())
state_dict = {
'opt': self.optimizer.state_dict(),
'ema': ema_params,
'current_step': self.current_step,
'decay': self.decay,
'every_n_steps': self.every_n_steps,
}
return state_dict
def load_state_dict(self, state_dict):
self.join()
self.optimizer.load_state_dict(state_dict['opt'])
self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema']))
self.current_step = state_dict['current_step']
self.decay = state_dict['decay']
self.every_n_steps = state_dict['every_n_steps']
self.rebuild_ema_params = False
def add_param_group(self, param_group):
self.optimizer.add_param_group(param_group)
self.rebuild_ema_params = True
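For context, a minimal usage sketch of the EMA callback defined above inside a Lightning run; the trainer arguments are illustrative:

import pytorch_lightning as pl

ema_cb = EMA(decay=0.9999, validate_original_weights=False, every_n_steps=1)
trainer = pl.Trainer(max_epochs=10, callbacks=[ema_cb])
# trainer.fit(model, datamodule)   # validation and test then run on the EMA weights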

View File

@ -0,0 +1,459 @@
import os
import cv2
import json
import torch
import glob
import re
import numpy as np
from tqdm import tqdm
from pdf2image import convert_from_path
from dicttoxml import dicttoxml
from word_preprocess import vat_standardizer, get_string, ap_standardizer
from kvu_dictionary import vat_dictionary, ap_dictionary
def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Save dir : {}'.format(save_dir))
def pdf2image(pdf_dir, save_dir):
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
print('No. pdf files:', len(pdf_files))
for file in tqdm(pdf_files):
pages = convert_from_path(file, 500)
for i, page in enumerate(pages):
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
print('Done!!!')
def xyxy2xywh(bbox):
return [
float(bbox[0]),
float(bbox[1]),
float(bbox[2]) - float(bbox[0]),
float(bbox[3]) - float(bbox[1]),
]
def write_to_json(file_path, content):
with open(file_path, mode='w', encoding='utf8') as f:
json.dump(content, f, ensure_ascii=False)
def read_json(file_path):
with open(file_path, 'r') as f:
return json.load(f)
def read_xml(file_path):
with open(file_path, 'r') as xml_file:
return xml_file.read()
def write_to_xml(file_path, content):
with open(file_path, mode="w", encoding='utf8') as f:
f.write(content)
def write_to_xml_from_dict(file_path, content):
xml = dicttoxml(content)
xml_decode = xml.decode()
with open(file_path, mode="w") as f:
f.write(xml_decode)
def load_ocr_result(ocr_path):
with open(ocr_path, 'r') as f:
lines = f.read().splitlines()
preds = []
for line in lines:
preds.append(line.split('\t'))
return preds
def post_process_basic_ocr(lwords: list) -> list:
pp_lwords = []
for word in lwords:
pp_lwords.append(word.replace("", " "))
return pp_lwords
def read_ocr_result_from_txt(file_path: str):
'''
return list of bounding boxes, list of words
'''
with open(file_path, 'r') as f:
lines = f.read().splitlines()
boxes, words = [], []
for line in lines:
if line == "":
continue
word_info = line.split("\t")
if len(word_info) == 6:
x1, y1, x2, y2, text, _ = word_info
elif len(word_info) == 5:
x1, y1, x2, y2, text = word_info
x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
if text and text != " ":
words.append(text)
boxes.append((x1, y1, x2, y2))
return boxes, words
def get_colormap():
return {
'others': (0, 0, 255), # others: red
'title': (0, 255, 255), # title: yellow
'key': (255, 0, 0), # key: blue
'value': (0, 255, 0), # value: green
'header': (233, 197, 15), # header
'group': (0, 128, 128), # group
'relation': (0, 0, 255)# (128, 128, 128), # relation
}
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
exif = image._getexif()
orientation = None
if exif is not None:
orientation = exif.get(0x0112)
# Convert the PIL image to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# Rotate the image in OpenCV if necessary
if orientation == 3:
image = cv2.rotate(image, cv2.ROTATE_180)
elif orientation == 6:
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
elif orientation == 8:
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
image = np.asarray(image)
if len(image.shape) == 2:
image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
assert len(image.shape) == 3
if orientation is not None and orientation == 6:
width, height, _ = image.shape
else:
height, width, _ = image.shape
if len(pr_class_words) > 0:
id2label = {k: labels[k] for k in range(len(labels))}
for lb, groups in enumerate(pr_class_words):
if lb == 0:
continue
for group_id, group in enumerate(groups):
for i, word_id in enumerate(group):
x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000)
cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)
if i == 0:
x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2)
else:
x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
x_center0, y_center0 = x_center1, y_center1
if len(pr_relations) > 0:
for pair in pr_relations:
xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)
x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)
return image
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
outputs = {}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] in (rel_from, rel_to):
is_rel[element['class']]['status'] = 1
is_rel[element['class']]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
return outputs
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
list_keys = list(header_key_pairs.keys())
relations = {k: [] for k in list_keys}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] == rel_from and element['group_id'] in list_keys:
is_rel[rel_from]['status'] = 1
is_rel[rel_from]['value'] = element
if element['class'] == rel_to:
is_rel[rel_to]['status'] = 1
is_rel[rel_to]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
return relations
def get_key2values_relations(key_value_pairs: dict):
triple_linkings = {}
for value_group_id, key_value_pair in key_value_pairs.items():
key_group_id = key_value_pair[0]
if key_group_id not in list(triple_linkings.keys()):
triple_linkings[key_group_id] = []
triple_linkings[key_group_id].append(value_group_id)
return triple_linkings
def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict:
word_groups = {}
id2class = {i: labels[i] for i in range(len(labels))}
for class_id, lwgroups_in_class in enumerate(class_words):
for ltokens_in_wgroup in lwgroups_in_class:
group_id = ltokens_in_wgroup[0]
ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
text_string = get_string(ltokens_to_ltexts)
word_groups[group_id] = {
'group_id': group_id,
'text': text_string,
'class': id2class[class_id],
'tokens': ltokens_in_wgroup
}
return word_groups
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
if linking_id not in list(word_groups):
for wg_id, _word_group in word_groups.items():
if linking_id in _word_group['tokens']:
return wg_id
return linking_id
def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
outputs = []
for pair in lrelations:
wg_from = verify_linking_id(word_groups, pair[0])
wg_to = verify_linking_id(word_groups, pair[1])
try:
outputs.append([word_groups[wg_from], word_groups[wg_to]])
except:
print('Not valid pair:', wg_from, wg_to)
return outputs
def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
word_groups = merged_token_to_wordgroup(class_words, lwords, labels)
linking_pairs = matched_wordgroup_relations(word_groups, lrelations)
header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]}
header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]}
key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]}
# table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
triplet_pairs = []
single_pairs = []
table = []
# print('key2values_relations', key2values_relations)
for key_group_id, list_value_group_ids in key2values_relations.items():
if len(list_value_group_ids) == 0: continue
elif len(list_value_group_ids) == 1:
value_group_id = list_value_group_ids[0]
single_pairs.append({word_groups[key_group_id]['text']: {
'text': word_groups[value_group_id]['text'],
'id': value_group_id,
'class': "value"
}})
else:
item = []
for value_group_id in list_value_group_ids:
if value_group_id not in header_value.keys():
header_name_for_value = "non-header"
else:
header_group_id = header_value[value_group_id][0]
header_name_for_value = word_groups[header_group_id]['text']
item.append({
'text': word_groups[value_group_id]['text'],
'header': header_name_for_value,
'id': value_group_id,
'class': 'value'
})
if key_group_id not in list(header_key.keys()):
triplet_pairs.append({
word_groups[key_group_id]['text']: item
})
else:
header_group_id = header_key[key_group_id][0]
header_name_for_key = word_groups[header_group_id]['text']
item.append({
'text': word_groups[key_group_id]['text'],
'header': header_name_for_key,
'id': key_group_id,
'class': 'key'
})
table.append({key_group_id: item})
if len(table) > 0:
table = sorted(table, key=lambda x: list(x.keys())[0])
table = [v for item in table for k, v in item.items()]
outputs = {}
outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
outputs['triplet'] = triplet_pairs
outputs['table'] = table
file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
write_to_json(file_path, outputs)
return outputs
def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
vat_outputs = {}
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(vat_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
if header_name in list(item.keys()):
# item[header_name] = value['text']
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score
table.append(item)
# VAT Information
single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
for pair in outputs['single']:
for key_name, value in pair.items():
key_name, score, proceessed_text = vat_standardizer(key_name, threshold=0.8, header=False)
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
# print('='*10, file_path)
# print(vat_info)
# Combine VAT information and table
vat_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if key_name in ("Ngày, tháng, năm lập hóa đơn"):
if len(list_potential_value) == 1:
vat_outputs[key_name] = list_potential_value[0]['content']
else:
date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
for value in list_potential_value:
date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
else:
if len(list_potential_value) == 0: continue
if key_name in ("Mã số thuế người bán"):
selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code
vat_outputs[key_name] = selected_value['content'].replace(' ', '')
else:
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
vat_outputs[key_name] = selected_value['content']
vat_outputs['table'] = table
write_to_json(file_path, vat_outputs)
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
if header_name in list(item.keys()):
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score
table.append(item)
triplet_pairs = []
for single_item in outputs['triplet']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
is_item_valid = 0
for key_name, list_value in single_item.items():
for value in list_value:
if value['header'] == "non-header":
continue
header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
if header_name in list(item.keys()):
is_item_valid = 1
item[header_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
if is_item_valid == 1:
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score
item['productname'] = key_name
# triplet_pairs.append({key_name: new_item})
triplet_pairs.append(item)
single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
for pair in outputs['single']:
for key_name, value in pair.items():
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
ap_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if len(list_potential_value) == 0: continue
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
ap_outputs[key_name] = selected_value['content']
table = table + triplet_pairs
ap_outputs['table'] = table
# ap_outputs['triplet'] = triplet_pairs
write_to_json(file_path, ap_outputs)
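For reference, the dictionary written by export_kvu_outputs above carries three groups; a sketch of its shape with made-up values:

example_outputs = {
    "single": [   # key linked to exactly one value
        {"Invoice No.": {"text": "12345", "id": 7, "class": "value"}},
    ],
    "triplet": [  # key linked to several values, where the key has no header
        {"Item": [{"text": "Tablet", "header": "Description", "id": 9, "class": "value"}]},
    ],
    "table": [    # one list of header/value (and key) cells per key group
        [{"text": "518.00", "header": "Amount", "id": 12, "class": "value"}],
    ],
}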

View File

@ -0,0 +1,138 @@
DKVU2XML = {
"Ký hiệu mẫu hóa đơn": "form_no",
"Ký hiệu hóa đơn": "serial_no",
"Số hóa đơn": "invoice_no",
"Ngày, tháng, năm lập hóa đơn": "issue_date",
"Tên người bán": "seller_name",
"Mã số thuế người bán": "seller_tax_code",
"Thuế suất": "tax_rate",
"Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount",
"Mặt hàng": "item",
"Đơn vị tính": "unit",
"Số lượng": "quantity",
"Đơn giá": "unit_price",
"Doanh số mua chưa có thuế": "amount"
}
def ap_dictionary(header: bool):
header_dictionary = {
'productname': ['description', 'paticulars', 'articledescription', 'descriptionofgood', 'itemdescription', 'product', 'productdescription', 'modelname', 'device', 'items', 'itemno'],
'modelnumber': ['serialno', 'model', 'code', 'mcode', 'simimeiserial', 'serial', 'productcode', 'product', 'imeiccid', 'articles', 'article', 'articlenumber', 'articleidmaterialcode', 'transaction', 'itemcode'],
'qty': ['quantity', 'invoicequantity']
}
key_dictionary = {
'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'],
'retailername': ['retailer', 'retailername', 'ownedoperatedby'],
'serial_number': ['serialnumber', 'serialno'],
'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2']
}
return header_dictionary if header else key_dictionary
def vat_dictionary(header: bool):
header_dictionary = {
'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'sanpham', 'tenquycachhanghoa','description', 'descriptionofgood', 'itemdescription'],
'Đơn vị tính': ['dvt', 'donvitinh'],
'Số lượng': ['soluong', 'sl','qty', 'quantity', 'invoicequantity'],
'Đơn giá': ['dongia'],
'Doanh số mua chưa có thuế': ['thanhtien', 'thanhtientruocthuegtgt', 'tienchuathue'],
# 'Số sản phẩm': ['serialno', 'model', 'mcode', 'simimeiserial', 'serial', 'sku', 'sn', 'productcode', 'product', 'particulars', 'imeiccid', 'articles', 'article', 'articleidmaterialcode', 'transaction', 'imei', 'articlenumber']
}
key_dictionary = {
'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'],
'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'],
'Số hóa đơn': ['soinvoiceno', 'invoiceno'],
'Ngày, tháng, năm lập hóa đơn': [],
'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'],
'Mã số thuế người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'],
'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'],
'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'],
# 'Ghi chú': [],
# 'Ngày': ['ngayday', 'ngay', 'day'],
# 'Tháng': ['thangmonth', 'thang', 'month'],
# 'Năm': ['namyear', 'nam', 'year']
}
# exact_dictionary = {
# 'Số hóa đơn': ['sono', 'so'],
# 'Mã số thuế người bán': ['mst'],
# 'Tên người bán': ['kyboi'],
# 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
# }
return header_dictionary if header else key_dictionary
def manulife_dictionary(type: str):
key_dict = {
"Document type": ["documenttype", "loaichungtu"],
"Document name": ["documentname", "tenchungtu"],
"Patient Name": ["patientname", "tenbenhnhan"],
"Date of Birth/Year of birth": [
"dateofbirth",
"yearofbirth",
"ngaythangnamsinh",
"namsinh",
],
"Age": ["age", "tuoi"],
"Gender": ["gender", "gioitinh"],
"Social insurance card No.": ["socialinsurancecardno", "sothebhyt"],
"Medical service provider name": ["medicalserviceprovidername", "tencosoyte"],
"Department name": ["departmentname", "tenkhoadieutri"],
"Diagnosis description": ["diagnosisdescription", "motachandoan"],
"Diagnosis code": ["diagnosiscode", "machandoan"],
"Admission date": ["admissiondate", "ngaynhapvien"],
"Discharge date": ["dischargedate", "ngayxuatvien"],
"Treatment method": ["treatmentmethod", "phuongphapdieutri"],
"Treatment date": ["treatmentdate", "ngaydieutri", "ngaykham"],
"Follow up treatment date": ["followuptreatmentdate", "ngaytaikham"],
# "Name of precribed medicine": [],
# "Quantity of prescribed medicine": [],
# "Dosage for each medicine": []
"Medical expense": ["Medical expense", "chiphiyte"],
"Invoice No.": ["invoiceno", "sohoadon"],
}
title_dict = {
"Chứng từ y tế": [
"giayravien",
"giaychungnhanphauthuat",
"cachthucphauthuat",
"phauthuat",
"tomtathosobenhan",
"donthuoc",
"toathuoc",
"donbosung",
"ketquaconghuongtu"
"ketqua",
"phieuchidinh",
"phieudangkykham",
"giayhenkhamlai",
"phieukhambenh",
"phieukhambenhvaovien",
"phieuxetnghiem",
"phieuketquaxetnghiem",
"phieuchidinhxetnghiem",
"ketquasieuam",
"phieuchidinhxetnghiem"
],
"Chứng từ thanh toán": [
"hoadon",
"hoadongiatrigiatang",
"hoadongiatrigiatangchuyendoituhoadondientu",
"bangkechiphibaohiem",
"bienlaithutien",
"bangkechiphidieutrinoitru"
],
}
if type == "key":
return key_dict
elif type == "title":
return title_dict
else:
raise ValueError(f"[ERROR] Dictionary type of {type} is not supported")
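These dictionaries map normalized (lower-cased, punctuation-stripped) key strings to canonical field names; the *_standardizer helpers in word_preprocess (not included here) score raw OCR keys against them with an LCS-style similarity. A rough, illustrative lookup using difflib as a stand-in scorer (match_key and normalize are hypothetical names, not the project's API):

import re
from difflib import SequenceMatcher

def normalize(text: str) -> str:
    # lower-case and keep only letters/digits, mirroring the dictionary variants
    return re.sub(r"[^a-z0-9]", "", text.lower())

def match_key(raw_key: str, dictionary: dict, threshold: float = 0.8):
    """Return (canonical_name, score, normalized_key); falls back to raw_key below threshold."""
    norm = normalize(raw_key)
    best_name, best_score = raw_key, 0.0
    for canonical, variants in dictionary.items():
        for variant in variants:
            score = SequenceMatcher(None, norm, variant).ratio()  # stand-in for the LCS score
            if score > best_score:
                best_name, best_score = canonical, score
    return (best_name if best_score >= threshold else raw_key), best_score, norm

# match_key("Serial No.", ap_dictionary(header=False)) -> ("serial_number", 1.0, "serialno")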

View File

@ -0,0 +1,30 @@
import numpy as np
from pathlib import Path
from typing import Union, Tuple, List
import sys, os
cur_dir = os.path.dirname(__file__)
sys.path.append(os.path.join(os.path.dirname(cur_dir), "ocr-engine"))
from src.ocr import OcrEngine
def load_ocr_engine() -> OcrEngine:
print("[INFO] Loading engine...")
engine = OcrEngine()
print("[INFO] Engine loaded")
return engine
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
save_dir_or_path = Path(save_dir_or_path)
if isinstance(img, np.ndarray):
if save_dir_or_path.is_dir():
raise ValueError("numpy array input require a save path, not a save dir")
page = engine(img)
save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt")
) if save_dir_or_path.is_dir() else str(save_dir_or_path)
page.write_to_file('word', save_path)
if export_img:
page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, )
def read_img(img: Union[str, np.ndarray], engine: OcrEngine):
page = engine(img)
return ' '.join([f.text for f in page.llines])
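A small usage sketch for these wrappers, assuming the ocr-engine package resolves from this location and that the output directory already exists; the file paths are illustrative:

engine = load_ocr_engine()
process_img("samples/receipt_0.jpg", "outputs/", engine, export_img=True)  # writes outputs/receipt_0.txt (+ .jpg)
text = read_img("samples/receipt_0.jpg", engine)                           # plain concatenated text
print(text[:200])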

View File

@ -0,0 +1,808 @@
import os
import cv2
import json
import random
import glob
import re
import numpy as np
from tqdm import tqdm
from pdf2image import convert_from_path
from dicttoxml import dicttoxml
from word_preprocess import (
vat_standardizer,
ap_standardizer,
get_string_with_word2line,
split_key_value_by_colon,
normalize_kvu_output,
normalize_kvu_output_for_manulife,
manulife_standardizer
)
from utils.kvu_dictionary import (
vat_dictionary,
ap_dictionary,
manulife_dictionary
)
def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
# else:
# print("DIR already existed.")
# print('Save dir : {}'.format(save_dir))
def convert_pdf2img(pdf_dir, save_dir):
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
print('No. pdf files:', len(pdf_files))
print(pdf_files)
for file in tqdm(pdf_files):
pdf2img(file, save_dir, n_pages=-1, return_fname=False)
# pages = convert_from_path(file, 500)
# for i, page in enumerate(pages):
# page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
print('Done!!!')
def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False):
file_names = []
pages = convert_from_path(pdf_path)
if n_pages != -1:
pages = pages[:n_pages]
for i, page in enumerate(pages):
_save_path = os.path.join(save_dir, os.path.basename(pdf_path).replace('.pdf', f'_{i}.jpg'))
page.save(_save_path, 'JPEG')
file_names.append(_save_path)
if return_fname:
return file_names
def xyxy2xywh(bbox):
return [
float(bbox[0]),
float(bbox[1]),
float(bbox[2]) - float(bbox[0]),
float(bbox[3]) - float(bbox[1]),
]
def write_to_json(file_path, content):
with open(file_path, mode='w', encoding='utf8') as f:
json.dump(content, f, ensure_ascii=False)
def read_json(file_path):
with open(file_path, 'r') as f:
return json.load(f)
def read_xml(file_path):
with open(file_path, 'r') as xml_file:
return xml_file.read()
def write_to_xml(file_path, content):
with open(file_path, mode="w", encoding='utf8') as f:
f.write(content)
def write_to_xml_from_dict(file_path, content):
xml = dicttoxml(content)
xml_decode = xml.decode()
with open(file_path, mode="w") as f:
f.write(xml_decode)
def load_ocr_result(ocr_path):
with open(ocr_path, 'r') as f:
lines = f.read().splitlines()
preds = []
for line in lines:
preds.append(line.split('\t'))
return preds
def post_process_basic_ocr(lwords: list) -> list:
pp_lwords = []
for word in lwords:
pp_lwords.append(word.replace("", " "))
return pp_lwords
def read_ocr_result_from_txt(file_path: str):
'''
return list of bounding boxes, list of words
'''
with open(file_path, 'r') as f:
lines = f.read().splitlines()
boxes, words = [], []
for line in lines:
if line == "":
continue
word_info = line.split("\t")
if len(word_info) == 6:
x1, y1, x2, y2, text, _ = word_info
elif len(word_info) == 5:
x1, y1, x2, y2, text = word_info
x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
if text and text != " ":
words.append(text)
boxes.append((x1, y1, x2, y2))
return boxes, words
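The expected input is one word per line with tab-separated fields, either five or six of them; an illustrative pair of lines (values made up; the sixth field, possibly a confidence score, is simply discarded):

# "1599\t785\t1760\t831\tBRIDGE\t0.99"   -> 6 fields: x1, y1, x2, y2, text, _ (discarded)
# "1784\t778\t1846\t822\tRD"             -> 5 fields: x1, y1, x2, y2, text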
def get_colormap():
return {
'others': (0, 0, 255), # others: red
'title': (0, 255, 255), # title: yellow
'key': (255, 0, 0), # key: blue
'value': (0, 255, 0), # value: green
'header': (233, 197, 15), # header
'group': (0, 128, 128), # group
'relation': (0, 0, 255)# (128, 128, 128), # relation
}
def convert_image(image):
exif = image._getexif()
orientation = None
if exif is not None:
orientation = exif.get(0x0112)
# Convert the PIL image to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# Rotate the image in OpenCV if necessary
if orientation == 3:
image = cv2.rotate(image, cv2.ROTATE_180)
elif orientation == 6:
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
elif orientation == 8:
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
image = np.asarray(image)
if len(image.shape) == 2:
image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
assert len(image.shape) == 3
return image, orientation
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
image, orientation = convert_image(image)
# if orientation is not None and orientation == 6:
# width, height, _ = image.shape
# else:
# height, width, _ = image.shape
if len(pr_class_words) > 0:
id2label = {k: labels[k] for k in range(len(labels))}
for lb, groups in enumerate(pr_class_words):
if lb == 0:
continue
for group_id, group in enumerate(groups):
for i, word_id in enumerate(group):
# x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000)
# x0, y0, x1, y1 = revert_box(bbox[word_id], width, height)
x0, y0, x1, y1 = bbox[word_id]
cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)
if i == 0:
x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2)
else:
x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
x_center0, y_center0 = x_center1, y_center1
if len(pr_relations) > 0:
for pair in pr_relations:
# xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
# xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)
# xyxy0 = revert_box(bbox[pair[0]], width, height)
# xyxy1 = revert_box(bbox[pair[1]], width, height)
xyxy0 = bbox[pair[0]]
xyxy1 = bbox[pair[1]]
x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)
return image
def revert_box(box, width, height):
return [
int((box[0] / 1000) * width),
int((box[1] / 1000) * height),
int((box[2] / 1000) * width),
int((box[3] / 1000) * height)
]
def get_wordgroup_bbox(lbbox: list, lword_ids: list) -> list:
points = [lbbox[i] for i in lword_ids]
x_min, y_min = min(points, key=lambda x: x[0])[0], min(points, key=lambda x: x[1])[1]
x_max, y_max = max(points, key=lambda x: x[2])[2], max(points, key=lambda x: x[3])[3]
return [x_min, y_min, x_max, y_max]
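get_wordgroup_bbox simply takes the union of the member word boxes, for example:

# lbbox = [(10, 20, 50, 40), (55, 22, 90, 45)], lword_ids = [0, 1]  ->  [10, 20, 90, 45]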
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
outputs = {}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] in (rel_from, rel_to):
is_rel[element['class']]['status'] = 1
is_rel[element['class']]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
return outputs
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
list_keys = list(header_key_pairs.keys())
relations = {k: [] for k in list_keys}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] == rel_from and element['group_id'] in list_keys:
is_rel[rel_from]['status'] = 1
is_rel[rel_from]['value'] = element
if element['class'] == rel_to:
is_rel[rel_to]['status'] = 1
is_rel[rel_to]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
return relations
def get_key2values_relations(key_value_pairs: dict):
triple_linkings = {}
for value_group_id, key_value_pair in key_value_pairs.items():
key_group_id = key_value_pair[0]
if key_group_id not in list(triple_linkings.keys()):
triple_linkings[key_group_id] = []
triple_linkings[key_group_id].append(value_group_id)
return triple_linkings
def merged_token_to_wordgroup(class_words: list, lwords: list, lbboxes: list, labels: list) -> dict:
word_groups = {}
id2class = {i: labels[i] for i in range(len(labels))}
for class_id, lwgroups_in_class in enumerate(class_words):
for ltokens_in_wgroup in lwgroups_in_class:
group_id = ltokens_in_wgroup[0]
ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
ltokens_to_lbboxes = [lbboxes[token] for token in ltokens_in_wgroup]
# text_string = get_string(ltokens_to_ltexts)
# text_string= get_string_by_deduplicate_bbox(ltokens_to_ltexts, ltokens_to_lbboxes)
text_string = get_string_with_word2line(ltokens_to_ltexts, ltokens_to_lbboxes)
group_bbox = get_wordgroup_bbox(lbboxes, ltokens_in_wgroup)
word_groups[group_id] = {
'group_id': group_id,
'text': text_string,
'class': id2class[class_id],
'tokens': ltokens_in_wgroup,
'bbox': group_bbox
}
return word_groups
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
if linking_id not in list(word_groups):
for wg_id, _word_group in word_groups.items():
if linking_id in _word_group['tokens']:
return wg_id
return linking_id
def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
outputs = []
for pair in lrelations:
wg_from = verify_linking_id(word_groups, pair[0])
wg_to = verify_linking_id(word_groups, pair[1])
try:
outputs.append([word_groups[wg_from], word_groups[wg_to]])
except:
print('Not valid pair:', wg_from, wg_to)
return outputs
def get_single_entity(word_groups: dict, lrelations: list) -> list:
single_entity = {'title': [], 'key': [], 'value': [], 'header': []}
list_linked_ids = []
for pair in lrelations:
list_linked_ids.extend(pair)
for word_group_id, word_group in word_groups.items():
if word_group_id not in list_linked_ids:
single_entity[word_group['class']].append(word_group)
return single_entity
def export_kvu_outputs(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
word_groups = merged_token_to_wordgroup(class_words, lwords, lbboxes, labels)
linking_pairs = matched_wordgroup_relations(word_groups, lrelations)
header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]}
header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]}
key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]}
single_entity = get_single_entity(word_groups, lrelations)
# table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
triplet_pairs = []
single_pairs = []
table = []
# print('key2values_relations', key2values_relations)
for key_group_id, list_value_group_ids in key2values_relations.items():
if len(list_value_group_ids) == 0: continue
elif (len(list_value_group_ids) == 1) and (list_value_group_ids[0] not in list(header_value.keys())) and (key_group_id not in list(header_key.keys())):
value_group_id = list_value_group_ids[0]
single_pairs.append({word_groups[key_group_id]['text']: {
'text': word_groups[value_group_id]['text'],
'id': value_group_id,
'class': "value",
'bbox': word_groups[value_group_id]['bbox'],
'key_bbox': word_groups[key_group_id]['bbox']
}})
else:
item = []
for value_group_id in list_value_group_ids:
if value_group_id not in header_value.keys():
header_group_id = -1 # temp
header_name_for_value = "non-header"
else:
header_group_id = header_value[value_group_id][0]
header_name_for_value = word_groups[header_group_id]['text']
item.append({
'text': word_groups[value_group_id]['text'],
'header': header_name_for_value,
'id': value_group_id,
"key_id": key_group_id,
"header_id": header_group_id,
'class': 'value',
'bbox': word_groups[value_group_id]['bbox'],
'key_bbox': word_groups[key_group_id]['bbox'],
'header_bbox': word_groups[header_group_id]['bbox'] if header_group_id != -1 else [0, 0, 0, 0],
})
if key_group_id not in list(header_key.keys()):
triplet_pairs.append({
word_groups[key_group_id]['text']: item
})
else:
header_group_id = header_key[key_group_id][0]
header_name_for_key = word_groups[header_group_id]['text']
item.append({
'text': word_groups[key_group_id]['text'],
'header': header_name_for_key,
'id': key_group_id,
"key_id": key_group_id,
"header_id": header_group_id,
'class': 'key',
'bbox': word_groups[key_group_id]['bbox'],
'key_bbox': word_groups[key_group_id]['bbox'],
'header_bbox': word_groups[header_group_id]['bbox'],
})
table.append({key_group_id: item})
single_entity_dict = {}
for class_name, single_items in single_entity.items():
single_entity_dict[class_name] = []
for single_item in single_items:
single_entity_dict[class_name].append({
'text': single_item['text'],
'id': single_item['group_id'],
'class': class_name,
'bbox': single_item['bbox']
})
if len(table) > 0:
table = sorted(table, key=lambda x: list(x.keys())[0])
table = [v for item in table for k, v in item.items()]
outputs = {}
outputs['title'] = single_entity_dict['title']
outputs['key'] = single_entity_dict['key']
outputs['value'] = single_entity_dict['value']
outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
outputs['triplet'] = triplet_pairs
outputs['table'] = table
create_dir(os.path.join(os.path.dirname(file_path), 'kvu_results'))
file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
write_to_json(file_path, outputs)
return outputs
def export_kvu_for_all(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']) -> dict:
raw_outputs = export_kvu_outputs(
file_path, lwords, lbboxes, class_words, lrelations, labels
)
outputs = {}
# Title
outputs["title"] = (
raw_outputs["title"][0]["text"] if len(raw_outputs["title"]) > 0 else None
)
# Pairs of key-value
for pair in raw_outputs["single"]:
for key, values in pair.items():
# outputs[key] = values["text"]
elements = split_key_value_by_colon(key, values["text"])
outputs[elements[0]] = elements[1]
# Only key fields
for key in raw_outputs["key"]:
# outputs[key["text"]] = None
elements = split_key_value_by_colon(key["text"], None)
outputs[elements[0]] = elements[1]
# Triplet data
for triplet in raw_outputs["triplet"]:
for key, list_value in triplet.items():
outputs[key] = [value["text"] for value in list_value]
# Table data
table = []
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
if header_list:
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
print("Header_list:", header_list.keys())
for row in raw_outputs["table"]:
item = {header: None for header in list(header_list.keys())}
for cell in row:
item[cell["header"]] = cell["text"]
table.append(item)
outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}]
else:
outputs["tables"] = []
outputs = normalize_kvu_output(outputs)
# write_to_json(file_path, outputs)
return outputs
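The flattened result of export_kvu_for_all is roughly a single mapping from field name to text plus a tables list (after normalize_kvu_output); an illustrative shape with made-up keys and values:

example = {
    "title": "TAX INVOICE",
    "Invoice No": "12345",   # key-value pair, split on the colon if present
    "Due date": None,        # key detected with no linked value
    "tables": [{
        "headers": ["Description", "Qty", "Amount"],
        "data": [{"Description": "ANDROID TABLET", "Qty": "1", "Amount": "518.00"}],
    }],
}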
def export_kvu_for_manulife(
file_path,
lwords,
lbboxes,
class_words,
lrelations,
labels=["others", "title", "key", "value", "header"],
) -> dict:
raw_outputs = export_kvu_outputs(
file_path, lwords, lbboxes, class_words, lrelations, labels
)
outputs = {}
# Title
title_list = []
for title in raw_outputs["title"]:
is_match, title_name, score, proceessed_text = manulife_standardizer(title["text"], threshold=0.6, type_dict="title")
title_list.append({
'documment_type': title_name if is_match else None,
'content': title['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': title['id']
})
if len(title_list) > 0:
selected_element = max(title_list, key=lambda x: x['lcs_score'])
outputs["title"] = selected_element['content'].upper()
outputs["class_doc"] = selected_element['documment_type']
outputs["Loại chứng từ"] = selected_element['documment_type']
outputs["Tên chứng từ"] = selected_element['content']
else:
outputs["title"] = None
outputs["class_doc"] = None
outputs["Loại chứng từ"] = None
outputs["Tên chứng từ"] = None
# Pairs of key-value
for pair in raw_outputs["single"]:
for key, values in pair.items():
# outputs[key] = values["text"]
elements = split_key_value_by_colon(key, values["text"])
outputs[elements[0]] = elements[1]
# Only key fields
for key in raw_outputs["key"]:
# outputs[key["text"]] = None
elements = split_key_value_by_colon(key["text"], None)
outputs[elements[0]] = elements[1]
# Triplet data
for triplet in raw_outputs["triplet"]:
for key, list_value in triplet.items():
outputs[key] = [value["text"] for value in list_value]
# Table data
table = []
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
if header_list:
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
# print("Header_list:", header_list.keys())
for row in raw_outputs["table"]:
item = {header: None for header in list(header_list.keys())}
for cell in row:
item[cell["header"]] = cell["text"]
table.append(item)
outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}]
else:
outputs["tables"] = []
outputs = normalize_kvu_output_for_manulife(outputs)
# write_to_json(file_path, outputs)
return outputs
# For FI-VAT project
def get_vat_table_information(outputs):
table = []
for single_item in outputs['table']:
headers = [item['header'] for sublist in outputs['table'] for item in sublist if 'header' in item]
item = {k: [] for k in headers}
print(item)
for cell in single_item:
# header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
# if header_name in list(item.keys()):
# item[header_name] = value['text']
item[cell['header']].append({
'content': cell['text'],
'processed_key_name': cell['header'],
'lcs_score': random.uniform(0.75, 1.0),
'token_id': cell['id']
})
# for header_name, value in item.items():
# if len(value) == 0:
# if header_name in ("Số lượng", "Doanh số mua chưa có thuế"):
# item[header_name] = '0'
# else:
# item[header_name] = None
# continue
# item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score
# item = post_process_for_item(item)
# if item["Mặt hàng"] == None:
# continue
table.append(item)
print(table)
return table
def get_vat_information(outputs):
# VAT Information
single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
for pair in outputs['single']:
for raw_key_name, value in pair.items():
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id'],
})
for triplet in outputs['triplet']:
for key, value_list in triplet.items():
if len(value_list) == 1:
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': value_list[0]['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value_list[0]['id']
})
for pair in value_list:
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': pair['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': pair['id']
})
for table_row in outputs['table']:
for pair in table_row:
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': pair['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': pair['id']
})
return single_pairs
def post_process_vat_information(single_pairs):
vat_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if key_name in ("Ngày, tháng, năm lập hóa đơn"):
if len(list_potential_value) == 1:
vat_outputs[key_name] = list_potential_value[0]['content']
else:
date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
for value in list_potential_value:
date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
else:
if len(list_potential_value) == 0: continue
if key_name in ("Mã số thuế người bán"):
selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code
# tax_code_raw = selected_value['content'].replace(' ', '')
tax_code_raw = selected_value['content']
if len(tax_code_raw.replace(' ', '')) not in (10, 13): # to remove the duplicated first number
tax_code_raw = tax_code_raw.split(' ')
tax_code_raw = sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0]
vat_outputs[key_name] = tax_code_raw.replace(' ', '')
else:
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
vat_outputs[key_name] = selected_value['content']
return vat_outputs
def export_kvu_for_VAT_invoice(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
vat_outputs = {}
outputs = export_kvu_outputs(file_path, lwords, lbboxes, class_words, lrelations, labels)
# List of items in table
table = get_vat_table_information(outputs)
# table = outputs["table"]
for pair in outputs['single']:
for raw_key_name, value in pair.items():
vat_outputs[raw_key_name] = value['text']
# VAT Information
# single_pairs = get_vat_information(outputs)
# vat_outputs = post_process_vat_information(single_pairs)
# Combine VAT information and table
vat_outputs['table'] = table
write_to_json(file_path, vat_outputs)
print(vat_outputs)
return vat_outputs
# For SBT project
def get_ap_table_information(outputs):
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
if header_name in list(item.keys()):
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
table.append(item)
return table
def get_ap_triplet_information(outputs):
triplet_pairs = []
for single_item in outputs['triplet']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
is_item_valid = 0
for key_name, list_value in single_item.items():
for value in list_value:
if value['header'] == "non-header":
continue
header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
if header_name in list(item.keys()):
is_item_valid = 1
item[header_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
if is_item_valid == 1:
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
item['productname'] = key_name
# triplet_pairs.append({key_name: new_item})
triplet_pairs.append(item)
return triplet_pairs
def get_ap_information(outputs):
single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
for pair in outputs['single']:
for raw_key_name, value in pair.items():
key_name, score, proceessed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
## Get single_pair if it in a table (Product Information)
is_product_info = False
for table_row in outputs['table']:
pair = {"key": None, 'value': None}
for cell in table_row:
_, _, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=False)
if any(txt in proceessed_text for txt in ['product', 'information', 'productinformation']):
is_product_info = True
if cell['class'] in pair:
pair[cell['class']] = cell
if all(v is not None for k, v in pair.items()) and is_product_info == True:
key_name, score, proceessed_text = ap_standardizer(pair['key']['text'], threshold=0.8, header=False)
# print(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': pair['value']['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': pair['value']['id']
})
## end_block
ap_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if len(list_potential_value) == 0: continue
if key_name == "imei_number":
# print('list_potential_value', list_potential_value)
# ap_outputs[key_name] = [v['content'] for v in list_potential_value if v['content'].replace(' ', '').isdigit() and len(v['content'].replace(' ', '')) > 5]
ap_outputs[key_name] = []
for v in list_potential_value:
imei = v['content'].replace(' ', '')
if imei.isdigit() and len(imei) > 5: # imei is numeric and has more than 5 digits
ap_outputs[key_name].append(imei)
else:
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lcs score
ap_outputs[key_name] = selected_value['content']
return ap_outputs
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = get_ap_table_information(outputs)
triplet_pairs = get_ap_triplet_information(outputs)
table = table + triplet_pairs
ap_outputs = get_ap_information(outputs)
ap_outputs['table'] = table
# ap_outputs['triplet'] = triplet_pairs
write_to_json(file_path, ap_outputs)
return ap_outputs

View File

@ -0,0 +1,226 @@
class Word():
def __init__(self, text="",image=None, conf_detect=0.0, conf_cls=0.0, bndbox = [-1,-1,-1,-1], kie_label =""):
self.type = "word"
self.text =text
self.image = image
self.conf_detect = conf_detect
self.conf_cls = conf_cls
self.boundingbox = bndbox # [left, top,right,bot] coordinate of top-left and bottom-right point
self.word_id = 0 # id of word
self.word_group_id = 0 # id of word_group which instance belongs to
self.line_id = 0 #id of line which instance belongs to
self.paragraph_id = 0 # id of paragraph which instance belongs to
self.kie_label = kie_label
def invalid_size(self):
# NOTE: despite its name, this returns True when the box has a positive area
return (self.boundingbox[2] - self.boundingbox[0]) * (self.boundingbox[3] - self.boundingbox[1]) > 0
def is_special_word(self):
left, top, right, bottom = self.boundingbox
width, height = right - left, bottom - top
text = self.text
if text is None:
return True
# if len(text) > 7:
# return True
if len(text) >= 7:
no_digits = sum(c.isdigit() for c in text)
return no_digits / len(text) >= 0.3
return False
class Word_group():
def __init__(self):
self.type = "word_group"
self.list_words = [] # dict of word instances
self.word_group_id = 0 # word group id
self.line_id = 0 #id of line which instance belongs to
self.paragraph_id = 0# id of paragraph which instance belongs to
self.text =""
self.boundingbox = [-1,-1,-1,-1]
self.kie_label =""
def add_word(self, word:Word): #add a word instance to the word_group
if word.text != "":
for w in self.list_words:
if word.word_id == w.word_id:
print("Word id collision")
return False
word.word_group_id = self.word_group_id #
word.line_id = self.line_id
word.paragraph_id = self.paragraph_id
self.list_words.append(word)
self.text += ' '+ word.text
if self.boundingbox == [-1,-1,-1,-1]:
self.boundingbox = word.boundingbox
else:
self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]),
min(self.boundingbox[1], word.boundingbox[1]),
max(self.boundingbox[2], word.boundingbox[2]),
max(self.boundingbox[3], word.boundingbox[3])]
return True
else:
return False
def update_word_group_id(self, new_word_group_id):
self.word_group_id = new_word_group_id
for i in range(len(self.list_words)):
self.list_words[i].word_group_id = new_word_group_id
def update_kie_label(self):
list_kie_label = [word.kie_label for word in self.list_words]
dict_kie = dict()
for label in list_kie_label:
if label not in dict_kie:
dict_kie[label]=1
else:
dict_kie[label]+=1
total = len(list(dict_kie.values()))
max_value = max(list(dict_kie.values()))
list_keys = list(dict_kie.keys())
list_values = list(dict_kie.values())
self.kie_label = list_keys[list_values.index(max_value)]
class Line():
def __init__(self):
self.type = "line"
self.list_word_groups = [] # list of Word_group instances in the line
self.line_id = 0 #id of line in the paragraph
self.paragraph_id = 0 # id of paragraph which instance belongs to
self.text = ""
self.boundingbox = [-1,-1,-1,-1]
def add_group(self, word_group:Word_group): # add a word_group instance
if word_group.list_words is not None:
for wg in self.list_word_groups:
if word_group.word_group_id == wg.word_group_id:
print("Word_group id collision")
return False
self.list_word_groups.append(word_group)
self.text += word_group.text
word_group.paragraph_id = self.paragraph_id
word_group.line_id = self.line_id
for i in range(len(word_group.list_words)):
word_group.list_words[i].paragraph_id = self.paragraph_id #set paragraph_id for word
word_group.list_words[i].line_id = self.line_id #set line_id for word
return True
return False
def update_line_id(self, new_line_id):
self.line_id = new_line_id
for i in range(len(self.list_word_groups)):
self.list_word_groups[i].line_id = new_line_id
for j in range(len(self.list_word_groups[i].list_words)):
self.list_word_groups[i].list_words[j].line_id = new_line_id
def merge_word(self, word): # word can be a Word instance or a Word_group instance
if word.text != "":
if self.boundingbox == [-1,-1,-1,-1]:
self.boundingbox = word.boundingbox
else:
self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]),
min(self.boundingbox[1], word.boundingbox[1]),
max(self.boundingbox[2], word.boundingbox[2]),
max(self.boundingbox[3], word.boundingbox[3])]
self.list_word_groups.append(word)
self.text += ' ' + word.text
return True
return False
def in_same_line(self, input_line, thresh=0.7):
# calculate iou in vertical direction
left1, top1, right1, bottom1 = self.boundingbox
left2, top2, right2, bottom2 = input_line.boundingbox
sorted_vals = sorted([top1, bottom1, top2, bottom2])
intersection = sorted_vals[2] - sorted_vals[1]
union = sorted_vals[3]-sorted_vals[0]
min_height = min(bottom1-top1, bottom2-top2)
if min_height==0:
return False
ratio = intersection / min_height
height1, height2 = bottom1-top1, bottom2-top2 # heights must be positive (bottom > top)
ratio_height = float(max(height1, height2))/float(min(height1, height2))
if (top1 in range(top2, bottom2) or top2 in range(top1, bottom1)) and ratio >= thresh and (ratio_height<2):
return True
return False
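# Illustrative check (hypothetical boxes): two boxes whose vertical overlap covers ~95% of the
# smaller height, with comparable heights, are treated as belonging to the same text line.
demo_line_a, demo_line_b = Line(), Line()
demo_line_a.boundingbox = [0, 100, 200, 140]
demo_line_b.boundingbox = [250, 105, 400, 142]
demo_line_a.in_same_line(demo_line_b) # -> True (overlap 35 / min height 37 ~ 0.95 >= 0.7)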
def check_iomin(word:Word, word_group:Word_group):
min_height = min(word.boundingbox[3]-word.boundingbox[1],word_group.boundingbox[3]-word_group.boundingbox[1])
intersect = min(word.boundingbox[3],word_group.boundingbox[3]) - max(word.boundingbox[1],word_group.boundingbox[1])
if intersect/min_height > 0.7:
return True
return False
def words_to_lines(words, check_special_lines=True): #words is list of Word instance
#sort word by top
words.sort(key = lambda x: (x.boundingbox[1], x.boundingbox[0]))
number_of_word = len(words)
# group the sorted words into lines first; word_groups are constructed later
lines = []
for i, word in enumerate(words):
if word.invalid_size()==0:
continue
new_line = True
for i in range(len(lines)):
if lines[i].in_same_line(word): #check if word is in the same line with lines[i]
lines[i].merge_word(word)
new_line = False
if new_line ==True:
new_line = Line()
new_line.merge_word(word)
lines.append(new_line)
# print(len(lines))
#sort line from top to bottom according top coordinate
lines.sort(key = lambda x: x.boundingbox[1])
#construct word_groups in each line
line_id = 0
word_group_id =0
word_id = 0
for i in range(len(lines)):
if len(lines[i].list_word_groups)==0:
continue
#left, top ,right, bottom
line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left
# print("line_width",line_width)
lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right
#update text for lines after sorting
lines[i].text =""
for word in lines[i].list_word_groups:
lines[i].text += " "+word.text
list_word_groups=[]
initial_word_group = Word_group()
initial_word_group.word_group_id= word_group_id
word_group_id +=1
lines[i].list_word_groups[0].word_id=word_id
initial_word_group.add_word(lines[i].list_word_groups[0])
word_id+=1
list_word_groups.append(initial_word_group)
for word in lines[i].list_word_groups[1:]: # iterate through each word not yet assigned to a word_group
check_word_group= True
#set id for each word in each line
word.word_id = word_id
word_id+=1
if (not list_word_groups[-1].text.endswith(':')) and ((word.boundingbox[0]-list_word_groups[-1].boundingbox[2])/line_width <0.05) and check_iomin(word, list_word_groups[-1]):
list_word_groups[-1].add_word(word)
check_word_group=False
if check_word_group ==True:
new_word_group = Word_group()
new_word_group.word_group_id= word_group_id
word_group_id +=1
new_word_group.add_word(word)
list_word_groups.append(new_word_group)
lines[i].list_word_groups = list_word_groups
# set id for lines from top to bottom
lines[i].update_line_id(line_id)
line_id +=1
return lines, number_of_word
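# Minimal usage sketch (hypothetical boxes, assuming this module is imported as-is):
# words on roughly the same visual row are merged into one Line, read left to right.
demo_words = [Word(text="Total", bndbox=[100, 50, 180, 80]), Word(text="518.00", bndbox=[300, 52, 400, 82]), Word(text="Thank", bndbox=[100, 120, 170, 150])]
demo_lines, demo_word_count = words_to_lines(demo_words)
[line.text for line in demo_lines] # -> [' Total 518.00', ' Thank']; demo_word_count == 3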

View File

@ -0,0 +1,388 @@
import nltk
import re
import string
import copy
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, manulife_dictionary, DKVU2XML
from word2line import Word, words_to_lines
nltk.download('words')
words = set(nltk.corpus.words.words())
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
# def clean_text(text):
# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text)
def get_string(lwords: list):
unique_list = []
for item in lwords:
if item.isdigit() and len(item) == 1:
unique_list.append(item)
elif item not in unique_list:
unique_list.append(item)
return ' '.join(unique_list)
def remove_english_words(text):
_word = [w.lower() for w in nltk.wordpunct_tokenize(text) if w.lower() not in words]
return ' '.join(_word)
def remove_punctuation(text):
return text.translate(str.maketrans(" ", " ", string.punctuation))
def remove_accents(input_str, s0, s1):
s = ''
# print input_str.encode('utf-8')
for c in input_str:
if c in s1:
s += s0[s1.index(c)]
else:
s += c
return s
def remove_spaces(text):
return text.replace(' ', '')
def preprocessing(text: str):
# text = remove_english_words(text) if table else text
text = remove_punctuation(text)
text = remove_accents(text, s0, s1)
text = remove_spaces(text)
return text.lower()
def vat_standardize_outputs(vat_outputs: dict) -> dict:
outputs = {}
for key, value in vat_outputs.items():
if key != "table":
outputs[DKVU2XML[key]] = value
else:
list_items = []
for item in value:
list_items.append({
DKVU2XML[item_key]: item_value for item_key, item_value in item.items()
})
outputs['table'] = list_items
return outputs
def vat_standardizer(text: str, threshold: float, header: bool):
dictionary = vat_dictionary(header)
processed_text = preprocessing(text)
for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'), ('thangmonth', 'thang', 'month'), ('namyear', 'nam', 'year')]:
if any([processed_text in txt for txt in candidates]):
processed_text = candidates[-1]
return "Ngày, tháng, năm lập hóa đơn", 5, processed_text
_dictionary = copy.deepcopy(dictionary)
if not header:
exact_dictionary = {
'Số hóa đơn': ['sono', 'so'],
'Mã số thuế người bán': ['mst'],
'Tên người bán': ['kyboi'],
'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
}
for k, v in exact_dictionary.items():
_dictionary[k] = dictionary[k] + exact_dictionary[k]
for k, v in dictionary.items():
# if k in ("Ngày, tháng, năm lập hóa đơn"):
# continue
# Prioritize match completely
if k in ('Tên người bán') and processed_text == "kyboi":
return k, 8, processed_text
if any([processed_text == key for key in _dictionary[k]]):
return k, 10, processed_text
scores = {k: 0.0 for k in dictionary}
for k, v in dictionary.items():
if k in ("Ngày, tháng, năm lập hóa đơn"):
continue
scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]])
key, score = max(scores.items(), key=lambda x: x[1])
return key if score > threshold else text, score, processed_text
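# Illustrative calls (hedged: exact behaviour for fuzzy matches depends on the entries in vat_dictionary):
vat_standardizer("Ngày", threshold=0.8, header=False) # -> ("Ngày, tháng, năm lập hóa đơn", 5, 'day') via the date shortcut
vat_standardizer("Ký bởi", threshold=0.8, header=False) # -> ("Tên người bán", 8, 'kyboi') via the exact-match table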
def ap_standardizer(text: str, threshold: float, header: bool):
dictionary = ap_dictionary(header)
processed_text = preprocessing(text)
# Prioritize match completely
_dictionary = copy.deepcopy(dictionary)
if not header:
_dictionary['serial_number'] = dictionary['serial_number'] + ['sn']
_dictionary['imei_number'] = dictionary['imei_number'] + ['imel', 'imed', 'ime'] # text recog error
else:
_dictionary['modelnumber'] = dictionary['modelnumber'] + ['sku', 'sn', 'imei']
_dictionary['qty'] = dictionary['qty'] + ['qty']
for k, v in dictionary.items():
if any([processed_text == key for key in _dictionary[k]]):
return k, 10, processed_text
scores = {k: 0.0 for k in dictionary}
for k, v in dictionary.items():
scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]])
key, score = max(scores.items(), key=lambda x: x[1])
return key if score >= threshold else text, score, processed_text
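# Illustrative calls (hedged: assumes ap_dictionary contains the serial_number and imei_number keys referenced above):
ap_standardizer("S/N", threshold=0.8, header=False) # -> ('serial_number', 10, 'sn') after punctuation is stripped
ap_standardizer("imel", threshold=0.8, header=False) # -> ('imei_number', 10, 'imel') thanks to the OCR-error aliases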
def manulife_standardizer(text: str, threshold: float, type_dict: str):
dictionary = manulife_dictionary(type=type_dict)
processed_text = preprocessing(text)
for key, candidates in dictionary.items():
if any([txt == processed_text for txt in candidates]):
return True, key, 5 * (1 + len(processed_text)), processed_text
if any([txt in processed_text for txt in candidates]):
return True, key, 5, processed_text
scores = {k: 0.0 for k in dictionary}
for k, v in dictionary.items():
if len(v) == 0:
continue
scores[k] = max(
[
longestCommonSubsequence(processed_text, key) / len(key)
for key in dictionary[k]
]
)
key, score = max(scores.items(), key=lambda x: x[1])
return score > threshold, key if score > threshold else text, score, processed_text
def convert_format_number(s: str) -> float:
s = s.replace(' ', '').replace('O', '0').replace('o', '0')
if s.endswith(",00") or s.endswith(".00"):
s = s[:-3]
if all([delimiter in s for delimiter in [',', '.']]):
s = s.replace('.', '').split(',')
remain_value = s[1].split('0')[0]
return int(s[0]) + int(remain_value) * 1 / (10**len(remain_value))
else:
s = s.replace(',', '').replace('.', '')
return int(s)
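# Quick sanity checks (illustrative): thousands separators are dropped and a ",00"/".00" tail is trimmed.
convert_format_number("518.00") # -> 518
convert_format_number("1.234,5") # -> 1234.5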
def post_process_for_item(item: dict) -> dict:
check_keys = ['Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế']
mis_key = []
for key in check_keys:
if item[key] in (None, '0'):
mis_key.append(key)
if len(mis_key) == 1:
try:
if mis_key[0] == check_keys[0] and convert_format_number(item[check_keys[1]]) != 0:
item[mis_key[0]] = round(convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[1]])).__str__()
elif mis_key[0] == check_keys[1] and convert_format_number(item[check_keys[0]]) != 0:
item[mis_key[0]] = (convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[0]])).__str__()
elif mis_key[0] == check_keys[2]:
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
except Exception as e:
print("Cannot post process this item with error:", e)
return item
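# Sketch with hypothetical values: a missing quantity is recovered as round(amount / unit price).
post_process_for_item({'Số lượng': None, 'Đơn giá': '518.00', 'Doanh số mua chưa có thuế': '1.036,00'}) # -> {'Số lượng': '2', 'Đơn giá': '518.00', 'Doanh số mua chưa có thuế': '1.036,00'}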
def longestCommonSubsequence(text1: str, text2: str) -> int:
# https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis
dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)]
for i, c in enumerate(text1):
for j, d in enumerate(text2):
dp[i + 1][j + 1] = 1 + \
dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j])
return dp[-1][-1]
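# Example: the standardizers above divide this length by len(key) to get a score in [0, 1].
longestCommonSubsequence("sohoadon", "so") # -> 2, so the score against the key 'so' is 2 / 2 = 1.0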
def longest_common_subsequence_with_idx(X, Y):
"""
This implementation uses dynamic programming to calculate the length of the LCS, and uses a path array to keep track of the characters in the LCS.
The longest_common_subsequence function takes two strings as input, and returns a tuple with three values:
the length of the LCS,
the index of the first character of the LCS in the first string,
and the index of the last character of the LCS in the first string.
"""
m, n = len(X), len(Y)
L = [[0 for i in range(n + 1)] for j in range(m + 1)]
# Following steps build L[m+1][n+1] in bottom up fashion. Note
# that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1]
right_idx = 0
max_lcs = 0
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
L[i][j] = 0
elif X[i - 1] == Y[j - 1]:
L[i][j] = L[i - 1][j - 1] + 1
if L[i][j] > max_lcs:
max_lcs = L[i][j]
right_idx = i
else:
L[i][j] = max(L[i - 1][j], L[i][j - 1])
# The LCS length is the value in the bottom-right cell of the DP table
lcs = L[m][n]
# Start from the right-most-bottom-most corner and
# one by one store characters in lcs[]
i = m
j = n
# right_idx = 0
while i > 0 and j > 0:
# If current character in X[] and Y are same, then
# current character is part of LCS
if X[i - 1] == Y[j - 1]:
i -= 1
j -= 1
# If not same, then find the larger of two and
# go in the direction of larger value
elif L[i - 1][j] > L[i][j - 1]:
# right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs
i -= 1
else:
j -= 1
return lcs, i, max(i + lcs, right_idx)
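# Example (hypothetical strings): returns (length, start index in X, end index in X) of the match.
longest_common_subsequence_with_idx("mst01234", "mst") # -> (3, 0, 3), i.e. "mst01234"[0:3] == "mst"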
def get_string_by_deduplicate_bbox(lwords: list, lbboxes: list):
unique_list = []
prev_bbox = [-1, -1, -1, -1]
for word, bbox in zip(lwords, lbboxes):
if bbox != prev_bbox:
unique_list.append(word)
prev_bbox = bbox
return ' '.join(unique_list)
def get_string_with_word2line(lwords: list, lbboxes: list):
list_words = []
unique_list = []
list_sorted_words = []
prev_bbox = [-1, -1, -1, -1]
for word, bbox in zip(lwords, lbboxes):
if bbox != prev_bbox:
prev_bbox = bbox
list_words.append(Word(image=None, text=word, conf_cls=-1, bndbox=bbox, conf_detect=-1))
unique_list.append(word)
llines = words_to_lines(list_words)[0]
for line in llines:
for _word_group in line.list_word_groups:
for _word in _word_group.list_words:
list_sorted_words.append(_word.text)
string_from_model = ' '.join(unique_list)
string_after_word2line = ' '.join(list_sorted_words)
if string_from_model != string_after_word2line:
print("[Warning] Word group from model is different with word2line module")
print("Model: ", ' '.join(unique_list))
print("Word2line: ", ' '.join(list_sorted_words))
return string_after_word2line
def remove_bullet_points_and_punctuation(text):
# Remove bullet points (e.g., • or -)
text = re.sub(r'^\s*[\-\*]\s*', '', text, flags=re.MULTILINE) # strip a leading '-' or '*' bullet
text = re.sub(r"^\d+\s*", "", text)
text = text.strip()
# # Remove end-of-sentence punctuation (e.g., ., !, ?)
# text = re.sub(r'[.!?]', '', text)
if len(text) > 0 and text[0] in (',', '.', ':', ';', '?', '!'):
text = text[1:]
if len(text) > 0 and text[-1] in (',', '.', ':', ';', '?', '!'):
text = text[:-1]
return text.strip()
def split_key_value_by_colon(key: str, value: str) -> list:
text_string = key + " " + value if value is not None else key
elements = text_string.split(':')
if len(elements) > 1:
return elements[0], text_string[len(elements[0]):]
return key, value
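# Example: a colon inside the recognized key moves the trailing part into the value.
split_key_value_by_colon("Total: 518.00", None) # -> ('Total', ': 518.00')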
# def normalize_kvu_output(raw_outputs: dict) -> dict:
# outputs = {}
# for key, values in raw_outputs.items():
# if key == "table":
# table = []
# for row in values:
# item = {}
# for k, v in row.items():
# k = remove_bullet_points_and_punctuation(k)
# if v is not None and len(v) > 0:
# v = remove_bullet_points_and_punctuation(v)
# item[k] = v
# table.append(item)
# outputs[key] = table
# else:
# key = remove_bullet_points_and_punctuation(key)
# if isinstance(values, list):
# values = [remove_bullet_points_and_punctuation(v) for v in values]
# elif values is not None and len(values) > 0:
# values = remove_bullet_points_and_punctuation(values)
# outputs[key] = values
# return outputs
def normalize_kvu_output(raw_outputs: dict) -> dict:
outputs = {}
for key, values in raw_outputs.items():
if key == "tables" and len(values) > 0:
table_list = []
for table in values:
headers, data = [], []
headers = [remove_bullet_points_and_punctuation(header).upper() for header in table['headers']]
for row in table['data']:
item = []
for k, v in row.items():
if v is not None and len(v) > 0:
item.append(remove_bullet_points_and_punctuation(v))
else:
item.append(v)
data.append(item)
table_list.append({"headers": headers, "data": data})
outputs[key] = table_list
else:
key = remove_bullet_points_and_punctuation(key)
if isinstance(values, list):
values = [remove_bullet_points_and_punctuation(v) for v in values]
elif values is not None and len(values) > 0:
values = remove_bullet_points_and_punctuation(values)
outputs[key] = values
return outputs
def normalize_kvu_output_for_manulife(raw_outputs: dict) -> dict:
outputs = {}
for key, values in raw_outputs.items():
if key == "tables" and len(values) > 0:
table_list = []
for table in values:
headers, data = [], []
headers = [
remove_bullet_points_and_punctuation(header).upper()
for header in table["headers"]
]
for row in table["data"]:
item = []
for k, v in row.items():
if v is not None and len(v) > 0:
item.append(remove_bullet_points_and_punctuation(v))
else:
item.append(v)
data.append(item)
table_list.append({"headers": headers, "data": data})
outputs[key] = table_list
else:
if key not in ("title", "tables", "class_doc"):
key = remove_bullet_points_and_punctuation(key).capitalize()
if isinstance(values, list):
values = [remove_bullet_points_and_punctuation(v) for v in values]
elif values is not None and len(values) > 0:
values = remove_bullet_points_and_punctuation(values)
outputs[key] = values
return outputs

View File

@ -0,0 +1,58 @@
from sdsvkie import Predictor
import cv2
import numpy as np
import urllib
model = Predictor(
cfg = "/ai-core/models/Kie_invoice_ap/config.yaml",
device = "cuda:0",
weights = "/ai-core/models/Kie_invoice_ap/ep21"
)
def predict(image_url):
"""
module predict function
Args:
image_url (str): image url
Returns:
example output:
"data": {
"document_type": "invoice",
"fields": [
{
"label": "Invoice Number",
"value": "INV-12345",
"box": [0, 0, 0, 0],
"confidence": 0.98
},
...
]
}
dict: output of model
"""
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
out = model(img)
output = out["end2end_results"]
output_dict = {
"document_type": "invoice",
"fields": []
}
for key in output.keys():
field = {
"label": key if key != "id" else "Receipt Number",
"value": output[key]['value'] if output[key]['value'] else "",
"box": output[key]['box'],
"confidence": output[key]['conf']
}
output_dict['fields'].append(field)
print(output_dict)
return output_dict
if __name__ == "__main__":
image_url = "/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/demos/2022_07_25 farewell lunch.jpg"
output = predict(image_url)
print(output)

View File

@ -0,0 +1,77 @@
import os, sys
cur_dir = os.path.dirname(__file__)
KIE_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvkie")
TD_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvtd")
TR_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvtr")
sys.path.append(KIE_PATH)
sys.path.append(TD_PATH)
sys.path.append(TR_PATH)
from sdsvkie import Predictor
from .AnyKey_Value.anyKeyValue import load_engine, Predictor_KVU
import cv2
import numpy as np
import urllib
model = Predictor(
cfg = "/models/Kie_invoice_ap/06062023/config.yaml", # TODO: Better be scalable
device = "cuda:0",
weights = "/models/Kie_invoice_ap/06062023/best" # TODO: Better be scalable
)
class_names = ['others', 'title', 'key', 'value', 'header']
save_dir = os.path.join(cur_dir, "AnyKey_Value/visualize/test")
predictor, processor = load_engine(exp_dir="/models/Kie_invoice_fi/key_value_understanding-20230627-164536",
class_names=class_names,
mode=3)
def predict_fi(page_numb, image_url):
"""
module predict function
Args:
image_url (str): image url
Returns:
example output:
"data": {
"document_type": "invoice",
"fields": [
{
"label": "Invoice Number",
"value": "INV-12345",
"box": [0, 0, 0, 0],
"confidence": 0.98
},
...
]
}
dict: output of model
"""
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
# img = cv2.imread(image_url)
# Phan cua LeHoang
out = model(img)
output = out["end2end_results"]
output_kie = {
field_name: field_item['value'] for field_name, field_item in output.items()
}
# print("Hoangggggggggggggggggggggggggggggggggggggggggggggg")
# print(output_kie)
#Phan cua Tuan
kvu_result, _ = Predictor_KVU(image_url, save_dir, predictor, processor)
# print("TuanNnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn")
# print(kvu_result)
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
return kvu_result, output_kie
if __name__ == "__main__":
image_url = "/mnt/hdd2T/dxtan/TannedCung/OCR/workspace/Kie_Invoice_AP/tmp_image/{image_url}.jpg"
output = predict_fi(0, image_url)
print(output)

View File

@ -0,0 +1,146 @@
from sdsvkie import Predictor
import sys, os
cur_dir = os.path.dirname(__file__)
# sys.path.append("cope2n-ai-fi/Kie_Invoice_AP") # Better be relative
from .AnyKey_Value.anyKeyValue import load_engine, Predictor_KVU
import cv2
import numpy as np
import urllib
import random
# model = Predictor(
# cfg = "/cope2n-ai-fi/models/Kie_invoice_ap/config.yaml",
# device = "cuda:0",
# weights = "/cope2n-ai-fi/models/Kie_invoice_ap/best"
# )
class_names = ['others', 'title', 'key', 'value', 'header']
save_dir = os.path.join(cur_dir, "AnyKey_Value/visualize/test")
predictor, processor = load_engine(exp_dir=os.path.join(cur_dir, "AnyKey_Value/experiments/key_value_understanding-20231003-171748"),
class_names=class_names,
mode=3)
def predict(page_numb, image_url):
"""
module predict function
Args:
image_url (str): image url
Returns:
example output:
"data": {
"document_type": "invoice",
"fields": [
{
"label": "Invoice Number",
"value": "INV-12345",
"box": [0, 0, 0, 0],
"confidence": 0.98
},
...
]
}
dict: output of model
"""
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
# img = cv2.imread(image_url)
# Phan cua LeHoang
# out = model(img)
# output = out["end2end_results"]
#Phan cua Tuan
kvu_result = Predictor_KVU(image_url, save_dir, predictor, processor)
output_dict = {
"document_type": "invoice",
"fields": []
}
for key in kvu_result.keys():
field = {
"label": key,
"value": kvu_result[key],
"box": [0, 0, 0, 0],
"confidence": random.uniform(0.9, 1.0),
"page": page_numb
}
output_dict['fields'].append(field)
print(output_dict)
return output_dict
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
# output_dict = {
# "document_type": "invoice",
# "fields": []
# }
# for key in output.keys():
# field = {
# "label": key if key != "id" else "Receipt Number",
# "value": output[key]['value'] if output[key]['value'] else "",
# "box": output[key]['box'],
# "confidence": output[key]['conf'],
# "page": page_numb
# }
# output_dict['fields'].append(field)
# table = kvu_result['table']
# field_table = {
# "label": "table",
# "value": table,
# "box": [0, 0, 0, 0],
# "confidence": 0.98,
# "page": page_numb
# }
# output_dict['fields'].append(field_table)
# return output_dict
# else:
# output_dict = {
# "document_type": "KSU",
# "fields": []
# }
# # for key in output.keys():
# # field = {
# # "label": key if key != "id" else "Receipt Number",
# # "value": output[key]['value'] if output[key]['value'] else "",
# # "box": output[key]['box'],
# # "confidence": output[key]['conf'],
# # "page": page_numb
# # }
# # output_dict['fields'].append(field)
# # Serial Number
# serial_number = kvu_result['serial_number']
# field_serial = {
# "label" : "serial_number",
# "value": serial_number,
# "box": [0, 0, 0, 0],
# "confidence": 0.98,
# "page": page_numb
# }
# output_dict['fields'].append(field_serial)
# # IMEI Number
# imei_number = kvu_result['imei_number']
# if imei_number == None:
# return output_dict
# if imei_number != None:
# for i in range(len(imei_number)):
# field_imei = {
# "label": "imei_number_{}".format(i+1),
# "value": imei_number[i],
# "box": [0, 0, 0, 0],
# "confidence": 0.98,
# "page": page_numb
# }
# output_dict['fields'].append(field_imei)
# return output_dict
if __name__ == "__main__":
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
output = predict(0, image_url)
print(output)

View File

@ -0,0 +1,106 @@
1113 773 1220 825 BEST
1243 759 1378 808 DENKI
1410 752 1487 799 (S)
1430 707 1515 748 TAX
1511 745 1598 790 PTE
1542 700 1725 740 TNVOICE
1618 742 1706 783 LTD
1783 725 1920 773 FUNAN
1943 723 2054 767 MALL
1434 797 1576 843 WORTH
1599 785 1760 831 BRIDGE
1784 778 1846 822 RD
1277 846 1632 897 #02-16/#03-1
1655 832 1795 877 FUNAN
1817 822 1931 869 MALL
1272 897 1518 956 S(179105)
1548 890 1655 943 TEL:
1686 877 1911 928 69046183
1247 1011 1334 1068 GST
1358 1006 1447 1059 REG
1360 1063 1449 1115 RCB
1473 1003 1575 1055 NO.:
1474 1059 1555 1110 NO.
1595 1042 1868 1096 198202199E
1607 985 1944 1040 M2-0053813-7
1056 1134 1254 1194 Opening
1276 1127 1391 1181 Hrs:
1425 1112 1647 1170 10:00:00
1672 1102 1735 1161 AN
1755 1101 1819 1157 to
1846 1090 2067 1147 10:00:00
2090 1080 2156 1141 PH
1061 1308 1228 1366 Staff:
1258 1300 1378 1357 3296
1710 1283 1880 1337 Trans:
1936 1266 2192 1322 262152554
1060 1372 1201 1429 Date:
1260 1358 1494 1419 22-03-23
1540 1344 1664 1409 9:05
1712 1339 1856 1407 Slip:
1917 1328 2196 1387 2000130286
1124 1487 1439 1545 SALESPERSON
1465 1477 1601 1537 CODE.
1633 1471 1752 1530 6043
1777 1462 2004 1519 HUHAHHAD
2032 1451 2177 1509 RAZIH
1070 1558 1187 1617 Item
1211 1554 1276 1615 No
1439 1542 1585 1601 Price
1750 1530 1841 1597 Qty
1951 1517 2120 1579 Amount
1076 1683 1276 1741 ANDROID
1304 1673 1477 1733 TABLET
1080 1746 1280 1804 2105976
1509 1729 1705 1784 SAMSUNG
1734 1719 1931 1776 SH-P613
1964 1709 2101 1768 128GB
1082 1809 1285 1869 SM-P613
1316 1802 1454 1860 12838
1429 1859 1600 1919 518.00
1481 1794 1596 1855 WIFI
1622 1790 1656 1850 G
1797 1845 1824 1904 1
1993 1832 2165 1892 518.00
1088 1935 1347 1995 PROMOTION
1091 2000 1294 2062 2105664
1520 1983 1717 2039 SAMSUNG
1743 1963 2106 2030 F-Sam-Redeen
1439 2111 1557 2173 0.00
1806 2095 1832 2156 1
2053 2081 2174 2144 0.00
1106 2248 1250 2312 Total
1974 2206 2146 2266 518.00
1107 2312 1204 2377 UOB
1448 2291 1567 2355 CARD
1978 2268 2147 2327 518.00
1253 2424 1375 2497 GST%
1456 2411 1655 2475 Net.Amt
1818 2393 1912 2460 GST
2023 2387 2192 2445 Amount
1106 2494 1231 2560 GST8
1486 2472 1661 2537 479.63
1770 2458 1916 2523 38.37
2027 2448 2203 2511 518.00
1553 2601 1699 2666 THANK
1721 2592 1821 2661 YOU
1436 2678 1616 2749 please
1644 2682 1764 2732 come
1790 2660 1942 2729 again
1191 2862 1391 2931 Those
1426 2870 2018 2945 facebook.com
1565 2809 1690 2884 join
1709 2816 1777 2870 us
1799 2811 1868 2865 on
1838 2946 2024 3003 com .89
1533 3006 2070 3088 ar.com/askbe
1300 3326 1659 3446 That's
1696 3308 1905 3424 not
1937 3289 2131 3408 all!
1450 3511 1633 3573 SCAN
1392 3589 1489 3645 QR
1509 3577 1698 3635 CODE
1321 3656 1370 3714 &
1517 3638 1768 3699 updates
1643 3882 1769 3932 Scan
1789 3868 1859 3926 Me

Binary file not shown.


View File

@ -0,0 +1,155 @@
from OCRBase.text_recognition import ocr_predict
import cv2
from shapely.geometry import Polygon
import urllib
import numpy as np
def check_percent_overlap_bbox(boxA, boxB):
"""check percent box A in boxB
Args:
boxA (_type_): _description_
boxB (_type_): _description_
Returns:
Float: percent overlap bbox
"""
# determine the (x, y)-coordinates of the intersection rectangle
box_shape_1 = [
[boxA[0], boxA[1]],
[boxA[2], boxA[1]],
[boxA[2], boxA[3]],
[boxA[0], boxA[3]],
]
# Give dimensions of shape 2
box_shape_2 = [
[boxB[0], boxB[1]],
[boxB[2], boxB[1]],
[boxB[2], boxB[3]],
[boxB[0], boxB[3]],
]
# Draw polygon 1 from shape 1
# dimensions
polygon_1 = Polygon(box_shape_1)
# Draw polygon 2 from shape 2
# dimensions
polygon_2 = Polygon(box_shape_2)
# Calculate the intersection of
# bounding boxes
intersect = polygon_1.intersection(polygon_2).area / polygon_1.area
return intersect
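# Quick example (hypothetical boxes): half of boxA's area lies inside boxB.
check_percent_overlap_bbox([0, 0, 10, 10], [5, 0, 20, 10]) # -> 0.5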
def check_box_in_box(boxA, boxB):
"""check boxA in boxB
Args:
boxA (_type_): _description_
boxB (_type_): _description_
Returns:
Boolean: True if boxA in boxB
"""
if (
boxA[0] >= boxB[0]
and boxA[1] >= boxB[1]
and boxA[2] <= boxB[2]
and boxA[3] <= boxB[3]
):
return True
else:
return False
def word_to_line_image_origin(list_words, bbox):
"""use for predict image with bbox selected
Args:
list_words (_type_): _description_
bbox (_type_): _description_
Returns:
_type_: _description_
"""
texts, boundingboxes = [], []
for line in list_words:
if line.text == "":
continue
else:
# convert to bbox image original
boundingbox = line.boundingbox
boundingbox = list(boundingbox)
boundingbox[0] = boundingbox[0] + bbox[0]
boundingbox[1] = boundingbox[1] + bbox[1]
boundingbox[2] = boundingbox[2] + bbox[0]
boundingbox[3] = boundingbox[3] + bbox[1]
texts.append(line.text)
boundingboxes.append(boundingbox)
return texts, boundingboxes
def word_to_line(list_words):
"""use for predict full image
Args:
list_words (_type_): _description_
"""
texts, boundingboxes = [], []
for line in list_words:
print(line.text)
if line.text == "":
continue
else:
boundingbox = line.boundingbox
boundingbox = list(boundingbox)
texts.append(line.text)
boundingboxes.append(boundingbox)
return texts, boundingboxes
def predict(page_numb, image_url):
"""predict text from image
Args:
image_path (String): path image to predict
list_id (List): List id of bbox selected
list_bbox (List): List bbox selected
Returns:
Dict: Dict result of prediction
"""
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
image = cv2.imdecode(arr, -1)
list_lines = ocr_predict(image)
texts, boundingboxes = word_to_line(list_lines)
result = {}
texts_replace = []
for text in texts:
if "" in text:
text = text.replace("", " ")
texts_replace.append(text)
else:
texts_replace.append(text)
result["texts"] = texts_replace
result["boundingboxes"] = boundingboxes
output_dict = {
"document_type": "ocr-base",
"fields": []
}
field = {
"label": "Text",
"value": result["texts"],
"box": result["boundingboxes"],
"confidence": 0.98,
"page": page_numb
}
output_dict['fields'].append(field)
return output_dict

View File

@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from distutils.command.config import config
from mmdet.apis import inference_detector, init_detector
from PIL import Image
from io import BytesIO
config = "model/yolox_s_8x8_300e_cocotext_1280.py"
checkpoint = "model/best_bbox_mAP_epoch_294.pth"
device = "cpu"
def read_imagefile(file) -> Image.Image:
image = Image.open(BytesIO(file))
return image
def detection_predict(image):
model = init_detector(config, checkpoint, device=device)
# test a single image
result = inference_detector(model, image)
return result

View File

@ -0,0 +1,39 @@
from common.utils.ocr_yolox import OcrEngineForYoloX_Invoice
from common.utils.word_formation import Word, words_to_lines
det_ckpt = "/models/sdsvtd/hub/wild_receipt_finetune_weights_c_lite.pth"
cls_ckpt = "satrn-lite-general-pretrain-20230106"
engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt)
def ocr_predict(img):
"""Predict text from image
Args:
image_path (str): _description_
Returns:
list: list of words
"""
try:
lbboxes, lwords = engine.run_image(img)
lWords = [Word(text=word, bndbox=bbox) for word, bbox in zip(lwords, lbboxes)]
list_lines, _ = words_to_lines(lWords)
return list_lines
# return lbboxes, lwords
except AssertionError as e:
print(e)
list_lines = []
return list_lines
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, required=True)
args = parser.parse_args()
list_lines = ocr_predict(args.image)

View File

@ -0,0 +1,98 @@
import cv2
import urllib
import random
import numpy as np
from pathlib import Path
import sys, os
cur_dir = str(Path(__file__).parents[2])
sys.path.append(cur_dir)
from modules.sdsvkvu import load_engine, process_img
from modules.ocr_engine import OcrEngine
from configs.manulife import device, ocr_cfg, kvu_cfg
def load_ocr_engine(opt) -> OcrEngine:
print("[INFO] Loading engine...")
engine = OcrEngine(**opt)
print("[INFO] Engine loaded")
return engine
print("OCR engine configfs: \n", ocr_cfg)
print("KVU configfs: \n", kvu_cfg)
ocr_engine = load_ocr_engine(ocr_cfg)
kvu_cfg['ocr_engine'] = ocr_engine
option = kvu_cfg['option']
kvu_cfg.pop("option") # pop option
manulife_engine = load_engine(kvu_cfg)
def manulife_predict(image_url, engine) -> None:
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
save_dir = "./tmp_results"
# image_path = os.path.join(save_dir, f"{image_url}.jpg")
image_path = os.path.join(save_dir, "abc.jpg")
cv2.imwrite(image_path, img)
outputs = process_img(img_path=image_path,
save_dir=save_dir,
engine=engine,
export_all=False,
option=option)
return outputs
def predict(page_numb, image_url):
"""
module predict function
Args:
image_url (str): image url
Returns:
example output:
"data": {
"document_type": "invoice",
"fields": [
{
"label": "Invoice Number",
"value": "INV-12345",
"box": [0, 0, 0, 0],
"confidence": 0.98
},
...
]
}
dict: output of model
"""
kvu_result = manulife_predict(image_url, engine=manulife_engine)
output_dict = {
"document_type": kvu_result['title'] if kvu_result['title'] is not None else "unknown",
"document_class": kvu_result['class_doc'] if kvu_result['class_doc'] is not None else "unknown",
"page_number": page_numb,
"fields": []
}
for key in kvu_result.keys():
if key in ("title", "class_doc"):
continue
field = {
"label": key,
"value": kvu_result[key],
"box": [0, 0, 0, 0],
"confidence": random.uniform(0.9, 1.0),
"page": page_numb
}
output_dict['fields'].append(field)
print(output_dict)
return output_dict
if __name__ == "__main__":
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
output = predict(0, image_url)
print(output)

View File

@ -0,0 +1,94 @@
import cv2
import urllib
import random
import numpy as np
from pathlib import Path
import sys, os
cur_dir = str(Path(__file__).parents[2])
sys.path.append(cur_dir)
from modules.sdsvkvu import load_engine, process_img
from modules.ocr_engine import OcrEngine
from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
def load_ocr_engine(opt) -> OcrEngine:
print("[INFO] Loading engine...")
engine = OcrEngine(**opt)
print("[INFO] Engine loaded")
return engine
print("OCR engine configfs: \n", ocr_cfg)
print("KVU configfs: \n", kvu_cfg)
ocr_engine = load_ocr_engine(ocr_cfg)
kvu_cfg['ocr_engine'] = ocr_engine
option = kvu_cfg['option']
kvu_cfg.pop("option") # pop option
sbt_engine = load_engine(kvu_cfg)
def sbt_predict(image_url, engine) -> None:
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
save_dir = "./tmp_results"
# image_path = os.path.join(save_dir, f"{image_url}.jpg")
image_path = os.path.join(save_dir, "abc.jpg")
cv2.imwrite(image_path, img)
outputs = process_img(img_path=image_path,
save_dir=save_dir,
engine=engine,
export_all=False,
option=option)
return outputs
def predict(page_numb, image_url):
"""
module predict function
Args:
image_url (str): image url
Returns:
example output:
"data": {
"document_type": "invoice",
"fields": [
{
"label": "Invoice Number",
"value": "INV-12345",
"box": [0, 0, 0, 0],
"confidence": 0.98
},
...
]
}
dict: output of model
"""
sbt_result = sbt_predict(image_url, engine=sbt_engine)
output_dict = {
"document_type": "invoice",
"document_class": " ",
"page_number": page_numb,
"fields": []
}
for key in sbt_result.keys():
field = {
"label": key,
"value": sbt_result[key],
"box": [0, 0, 0, 0],
"confidence": random.uniform(0.9, 1.0),
"page": page_numb
}
output_dict['fields'].append(field)
return output_dict
if __name__ == "__main__":
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
output = predict(0, image_url)
print(output)

View File

View File

@ -0,0 +1,81 @@
from celery import Celery
import base64
import environ
env = environ.Env(
DEBUG=(bool, True)
)
class CeleryConnector:
task_routes = {
"process_id_result": {"queue": "id_card_rs"},
"process_driver_license_result": {"queue": "driver_license_rs"},
"process_invoice_result": {"queue": "invoice_rs"},
"process_ocr_with_box_result": {"queue": "ocr_with_box_rs"},
"process_template_matching_result": {"queue": "template_matching_rs"},
# mock task
"process_id": {"queue": "id_card"},
"process_driver_license": {"queue": "driver_license"},
"process_invoice": {"queue": "invoice"},
"process_ocr_with_box": {"queue": "ocr_with_box"},
"process_template_matching": {"queue": "template_matching"},
}
app = Celery(
"postman",
broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
# backend="rpc://",
)
def process_id_result(self, args):
return self.send_task("process_id_result", args)
def process_driver_license_result(self, args):
return self.send_task("process_driver_license_result", args)
def process_invoice_result(self, args):
return self.send_task("process_invoice_result", args)
def process_ocr_with_box_result(self, args):
return self.send_task("process_ocr_with_box_result", args)
def process_template_matching_result(self, args):
return self.send_task("process_template_matching_result", args)
def process_id(self, args):
return self.send_task("process_id", args)
def process_driver_license(self, args):
return self.send_task("process_driver_license", args)
def process_invoice(self, args):
return self.send_task("process_invoice", args)
def process_ocr_with_box(self, args):
return self.send_task("process_ocr_with_box", args)
def process_template_matching(self, args):
return self.send_task("process_template_matching", args)
def send_task(self, name=None, args=None):
if name not in self.task_routes or "queue" not in self.task_routes[name]:
return self.app.send_task(name, args)
return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
def main():
rq_id = 345
file_names = "abc.jpg"
list_data = []
with open("/home/sds/thucpd/aicr-2022/abc.jpg", "rb") as fs:
encoded_string = base64.b64encode(fs.read()).decode("utf-8")
list_data.append(encoded_string)
c_connector = CeleryConnector()
a = c_connector.process_id(args=(rq_id, list_data, file_names))
print(a)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,56 @@
from celery import Celery
import environ
env = environ.Env(
DEBUG=(bool, True)
)
class CeleryConnector:
task_routes = {
'process_fi_invoice_result': {'queue': 'invoice_fi_rs'},
'process_sap_invoice_result': {'queue': 'invoice_sap_rs'},
'process_manulife_invoice_result': {'queue': 'invoice_manulife_rs'},
'process_sbt_invoice_result': {'queue': 'invoice_sbt_rs'},
# mock task
'process_fi_invoice': {'queue': "invoice_fi"},
'process_sap_invoice': {'queue': "invoice_sap"},
'process_manulife_invoice': {'queue': "invoice_manulife"},
'process_sbt_invoice': {'queue': "invoice_sbt"},
}
app = Celery(
"postman",
broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
)
# mock task for FI
def process_fi_invoice_result(self, args):
return self.send_task("process_fi_invoice_result", args)
def process_fi_invoice(self, args):
return self.send_task("process_fi_invoice", args)
# mock task for SAP
def process_sap_invoice_result(self, args):
return self.send_task("process_sap_invoice_result", args)
def process_sap_invoice(self, args):
return self.send_task("process_sap_invoice", args)
# mock task for manulife
def process_manulife_invoice_result(self, args):
return self.send_task("process_manulife_invoice_result", args)
def process_manulife_invoice(self, args):
return self.send_task("process_manulife_invoice", args)
# mock task for manulife
def process_sbt_invoice_result(self, args):
return self.send_task("process_sbt_invoice_result", args)
def process_sbt_invoice(self, args):
return self.send_task("process_sbt_invoice", args)
def send_task(self, name=None, args=None):
if name not in self.task_routes or "queue" not in self.task_routes[name]:
return self.app.send_task(name, args)
return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])

View File

@ -0,0 +1,220 @@
from celery_worker.worker import app
import numpy as np
import cv2
@app.task(name="process_id")
def process_id(rq_id, sub_id, folder_name, list_url, user_id):
from common.serve_model import predict
from celery_worker.client_connector import CeleryConnector
c_connector = CeleryConnector()
try:
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="id_card")
print(result)
result = {
"status": 200,
"content": result,
"message": "Success",
}
c_connector.process_id_result((rq_id, result))
return {"rq_id": rq_id}
# if image_croped is not None:
# if result["data"] == []:
# result = {
# "status": 404,
# "content": {},
# }
# c_connector.process_id_result((rq_id, result, None))
# return {"rq_id": rq_id}
# else:
# result = {
# "status": 200,
# "content": result,
# "message": "Success",
# }
# c_connector.process_id_result((rq_id, result))
# return {"rq_id": rq_id}
# elif image_croped is None:
# result = {
# "status": 404,
# "content": {},
# }
# c_connector.process_id_result((rq_id, result, None))
# return {"rq_id": rq_id}
except Exception as e:
print(e)
result = {
"status": 404,
"content": {},
}
c_connector.process_id_result((rq_id, result, None))
return {"rq_id": rq_id}
@app.task(name="process_driver_license")
def process_driver_license(rq_id, sub_id, folder_name, list_url, user_id):
from common.serve_model import predict
from celery_worker.client_connector import CeleryConnector
c_connector = CeleryConnector()
try:
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="driving_license")
result = {
"status": 200,
"content": result,
"message": "Success",
}
c_connector.process_driver_license_result((rq_id, result))
return {"rq_id": rq_id}
# result, image_croped = predict(str(url), "driving_license")
# if image_croped is not None:
# if result["data"] == []:
# result = {
# "status": 404,
# "content": {},
# }
# c_connector.process_driver_license_result((rq_id, result, None))
# return {"rq_id": rq_id}
# else:
# result = {
# "status": 200,
# "content": result,
# "message": "Success",
# }
# path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
# cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_croped)
# c_connector.process_driver_license_result((rq_id, result, path_image_croped))
# return {"rq_id": rq_id}
# elif image_croped is None:
# result = {
# "status": 404,
# "content": {},
# }
# c_connector.process_driver_license_result((rq_id, result, None))
# return {"rq_id": rq_id}
except Exception as e:
print(e)
result = {
"status": 404,
"content": {},
}
c_connector.process_driver_license_result((rq_id, result, None))
return {"rq_id": rq_id}
@app.task(name="process_template_matching")
def process_template_matching(rq_id, sub_id, folder_name, url, tmp_json, user_id):
from TemplateMatching.src.ocr_master import Extractor
from celery_worker.client_connector import CeleryConnector
import urllib
c_connector = CeleryConnector()
extractor = Extractor()
try:
req = urllib.request.urlopen(url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
imgs = [img]
image_aliged = extractor.image_alige(imgs, tmp_json)
if image_aliged is None:
result = {
"status": 401,
"content": "Image is not match with template",
}
c_connector.process_template_matching_result(
(rq_id, result, None)
)
return {"rq_id": rq_id}
else:
output = extractor.extract_information(
image_aliged, tmp_json
)
path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_aliged)
if output == {}:
result = {"status": 404, "content": {}}
c_connector.process_template_matching_result((rq_id, result, None))
return {"rq_id": rq_id}
else:
result = {
"document_type": "template_matching",
"fields": []
}
print(output)
for field in tmp_json["fields"]:
print(field["label"])
field_value = {
"label": field["label"],
"value": output[field["label"]],
"box": [float(num) for num in field["box"]],
"confidence": 0.98 #TODO confidence
}
result["fields"].append(field_value)
print(result)
result = {"status": 200, "content": result}
c_connector.process_template_matching_result(
(rq_id, result, path_image_croped)
)
return {"rq_id": rq_id}
except Exception as e:
print(e)
result = {"status": 404, "content": {}}
c_connector.process_template_matching_result((rq_id, result, None))
return {"rq_id": rq_id}
# @app.task(name="process_invoice")
# def process_invoice(rq_id, url):
# from celery_worker.client_connector import CeleryConnector
# from Kie_Hoanglv.prediction import predict
# c_connector = CeleryConnector()
# try:
# print(url)
# result = predict(str(url))
# hoadon = {"status": 200, "content": result, "message": "Success"}
# c_connector.process_invoice_result((rq_id, hoadon))
# return {"rq_id": rq_id}
# except Exception as e:
# print(e)
# hoadon = {"status": 404, "content": {}}
# c_connector.process_invoice_result((rq_id, hoadon))
# return {"rq_id": rq_id}
@app.task(name="process_invoice")
def process_invoice(rq_id, list_url):
from celery_worker.client_connector import CeleryConnector
from common.process_pdf import compile_output
c_connector = CeleryConnector()
try:
result = compile_output(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
c_connector.process_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
@app.task(name="process_ocr_with_box")
def process_ocr_with_box(rq_id, list_url):
from celery_worker.client_connector import CeleryConnector
from common.process_pdf import compile_output_ocr_base
c_connector = CeleryConnector()
try:
result = compile_output_ocr_base(list_url)
result = {"status": 200, "content": result, "message": "Success"}
c_connector.process_ocr_with_box_result((rq_id, result))
return {"rq_id": rq_id}
except Exception as e:
print(e)
result = {"status": 404, "content": {}}
c_connector.process_ocr_with_box_result((rq_id, result))
return {"rq_id": rq_id}

View File

@ -0,0 +1,74 @@
from celery_worker.worker_fi import app
@app.task(name="process_fi_invoice")
def process_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output_fi
c_connector = CeleryConnector()
try:
result = compile_output_fi(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)
c_connector.process_fi_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_fi_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
@app.task(name="process_sap_invoice")
def process_sap_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output
print(list_url)
c_connector = CeleryConnector()
try:
result = compile_output(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
c_connector.process_sap_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_sap_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
@app.task(name="process_manulife_invoice")
def process_manulife_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output_manulife
# TODO: simply returning 200 and 404 doesn't make any sense
c_connector = CeleryConnector()
try:
result = compile_output_manulife(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)
c_connector.process_manulife_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_manulife_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
@app.task(name="process_sbt_invoice")
def process_sbt_invoice(rq_id, list_url):
from celery_worker.client_connector_fi import CeleryConnector
from common.process_pdf import compile_output_sbt
# TODO: simply returning 200 and 404 doesn't make any sense
c_connector = CeleryConnector()
try:
result = compile_output_sbt(list_url)
hoadon = {"status": 200, "content": result, "message": "Success"}
print(hoadon)
c_connector.process_sbt_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}
except Exception as e:
print(e)
hoadon = {"status": 404, "content": {}}
c_connector.process_sbt_invoice_result((rq_id, hoadon))
return {"rq_id": rq_id}

View File

@ -0,0 +1,40 @@
from celery import Celery
from kombu import Queue, Exchange
import environ
env = environ.Env(
DEBUG=(bool, True)
)
app: Celery = Celery(
"postman",
broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
# backend="rpc://",
include=[
"celery_worker.mock_process_tasks",
],
)
task_exchange = Exchange("default", type="direct")
task_create_missing_queues = False
app.conf.update(
{
"result_expires": 3600,
"task_queues": [
Queue("id_card"),
Queue("driver_license"),
Queue("invoice"),
Queue("ocr_with_box"),
Queue("template_matching"),
],
"task_routes": {
"process_id": {"queue": "id_card"},
"process_driver_license": {"queue": "driver_license"},
"process_invoice": {"queue": "invoice"},
"process_ocr_with_box": {"queue": "ocr_with_box"},
"process_template_matching": {"queue": "template_matching"},
},
}
)
if __name__ == "__main__":
argv = ["celery_worker.worker", "--loglevel=INFO", "--pool=solo"] # Window opts
app.worker_main(argv)

View File

@ -0,0 +1,37 @@
from celery import Celery
from kombu import Queue, Exchange
import environ
env = environ.Env(
DEBUG=(bool, True)
)
app: Celery = Celery(
"postman",
broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
include=[
"celery_worker.mock_process_tasks_fi",
],
)
task_exchange = Exchange("default", type="direct")
task_create_missing_queues = False
app.conf.update(
{
"result_expires": 3600,
"task_queues": [
Queue("invoice_fi"),
Queue("invoice_sap"),
Queue("invoice_manulife"),
Queue("invoice_sbt"),
],
"task_routes": {
'process_fi_invoice': {'queue': "invoice_fi"},
'process_fi_invoice_result': {'queue': 'invoice_fi_rs'},
'process_sap_invoice': {'queue': "invoice_sap"},
'process_sap_invoice_result': {'queue': 'invoice_sap_rs'},
'process_manulife_invoice': {'queue': 'invoice_manulife'},
'process_manulife_invoice_result': {'queue': 'invoice_manulife_rs'},
'process_sbt_invoice': {'queue': 'invoice_sbt'},
'process_sbt_invoice_result': {'queue': 'invoice_sbt_rs'},
},
}
)
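# Illustrative sketch only (an addition for clarity, not part of the original file;
# it assumes the RabbitMQ broker above is reachable and a worker is consuming the
# queues). Dispatching by task name lets the task_routes mapping above pick the
# queue, e.g. "process_sbt_invoice" lands on "invoice_sbt". The request id and the
# URL below are made-up placeholders.
if __name__ == "__main__":
    async_result = app.send_task(
        "process_sbt_invoice",
        args=["example-rq-id", ["https://example.com/invoice_page_1.jpg"]],
    )
    print("dispatched:", async_result.id)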

View File

@ -0,0 +1,101 @@
import os
import glob
import cv2
import argparse
from tqdm import tqdm
from datetime import datetime
# from omegaconf import OmegaConf
import sys
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/common/AnyKey_Value') # TODO: ????
from predictor import KVUPredictor
from preprocess import KVUProcess, DocumentKVUProcess
from utils.utils import create_dir, visualize, get_colormap, export_kvu_for_VAT_invoice, export_kvu_outputs
def get_args():
args = argparse.ArgumentParser(description='Main file')
args.add_argument('--img_dir', default='/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str,
help='Input image directory')
args.add_argument('--save_dir', default='/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str,
help='Save directory')
args.add_argument('--exp_dir', default='/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/experiments/key_value_understanding-20230608-171900', type=str,
help='Checkpoint and config of model')
args.add_argument('--export_img', default=0, type=int,
help='Save visualize on image')
args.add_argument('--mode', default=3, type=int,
help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'")
args.add_argument('--dir_level', default=0, type=int,
help='Number of subfolders contains image')
return args.parse_args()
def load_engine(exp_dir: str, class_names: list, mode: int) -> KVUPredictor:
configs = {
'cfg': glob.glob(f'{exp_dir}/*.yaml')[0],
'ckpt': f'{exp_dir}/checkpoints/best_model.pth'
}
dummy_idx = 512
predictor = KVUPredictor(configs, class_names, dummy_idx, mode)
# processor = KVUProcess(predictor.net.tokenizer_layoutxlm,
# predictor.net.feature_extractor, predictor.backbone_type, class_names,
# predictor.slice_interval, predictor.window_size, run_ocr=1, mode=mode)
processor = DocumentKVUProcess(predictor.net.tokenizer, predictor.net.feature_extractor,
predictor.backbone_type, class_names,
predictor.max_window_count, predictor.slice_interval, predictor.window_size,
run_ocr=1, mode=mode)
return predictor, processor
def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor):
fname = os.path.basename(img_path)
img_ext = os.path.splitext(img_path)[1]
output_ext = ".json"
inputs = processor(img_path, ocr_path='')
bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs)
slide_window = len(bbox) != 1
vat_outputs = None  # guard: the sliding-window branch below may never assign a result
if len(bbox) == 0:
vat_outputs = export_kvu_for_VAT_invoice(os.path.join(save_dir, fname.replace(img_ext, output_ext)), lwords, pr_class_words, pr_relations, predictor.class_names)
else:
for i in range(len(bbox)):
if not slide_window:
save_path = os.path.join(save_dir, 'kvu_results')
create_dir(save_path)
# export_kvu_for_SDSAP(os.path.join(save_dir, fname.replace(img_ext, output_ext)), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names)
vat_outputs = export_kvu_for_VAT_invoice(os.path.join(save_dir, fname.replace(img_ext, output_ext)), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names)
return vat_outputs
def Predictor_KVU(img, save_dir: str, predictor: KVUPredictor, processor):
# req = urllib.request.urlopen(image_url)
# arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
# img = cv2.imdecode(arr, -1)
curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S')
image_path = "/home/thucpd/thucpd/PV2-2023/tmp_image/{}.jpg".format(curr_datetime)
cv2.imwrite(image_path, img)
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
return vat_outputs
if __name__ == "__main__":
args = get_args()
class_names = ['others', 'title', 'key', 'value', 'header']
predict_mode = {
'normal': 0,
'full_tokens': 1,
'sliding': 2,
'document': 3
}
predictor, processor = load_engine(args.exp_dir, class_names, args.mode)
create_dir(args.save_dir)
image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1"
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
print('[INFO] Done')

View File

@ -0,0 +1,133 @@
import time
import torch
import torch.utils.data
from overrides import overrides
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim import SGD, Adam, AdamW
from torch.optim.lr_scheduler import LambdaLR
from lightning_modules.schedulers import (
cosine_scheduler,
linear_scheduler,
multistep_scheduler,
)
from model import get_model
from utils import cfg_to_hparams, get_specific_pl_logger
class ClassifierModule(LightningModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.net = get_model(self.cfg)
self.ignore_index = -100
self.time_tracker = None
self.optimizer_types = {
"sgd": SGD,
"adam": Adam,
"adamw": AdamW,
}
@overrides
def setup(self, stage):
self.time_tracker = time.time()
@overrides
def configure_optimizers(self):
optimizer = self._get_optimizer()
scheduler = self._get_lr_scheduler(optimizer)
scheduler = {
"scheduler": scheduler,
"name": "learning_rate",
"interval": "step",
}
return [optimizer], [scheduler]
def _get_lr_scheduler(self, optimizer):
cfg_train = self.cfg.train
lr_schedule_method = cfg_train.optimizer.lr_schedule.method
lr_schedule_params = cfg_train.optimizer.lr_schedule.params
if lr_schedule_method is None:
scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1)
elif lr_schedule_method == "step":
scheduler = multistep_scheduler(optimizer, **lr_schedule_params)
elif lr_schedule_method == "cosine":
total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
total_batch_size = cfg_train.batch_size * self.trainer.world_size
max_iter = total_samples / total_batch_size
scheduler = cosine_scheduler(
optimizer, training_steps=max_iter, **lr_schedule_params
)
elif lr_schedule_method == "linear":
total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
total_batch_size = cfg_train.batch_size * self.trainer.world_size
max_iter = total_samples / total_batch_size
scheduler = linear_scheduler(
optimizer, training_steps=max_iter, **lr_schedule_params
)
else:
raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}")
return scheduler
def _get_optimizer(self):
opt_cfg = self.cfg.train.optimizer
method = opt_cfg.method.lower()
if method not in self.optimizer_types:
raise ValueError(f"Unknown optimizer method={method}")
kwargs = dict(opt_cfg.params)
kwargs["params"] = self.net.parameters()
optimizer = self.optimizer_types[method](**kwargs)
return optimizer
@rank_zero_only
@overrides
def on_fit_end(self):
hparam_dict = cfg_to_hparams(self.cfg, {})
metric_dict = {"metric/dummy": 0}
tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger)
if tb_logger:
tb_logger.log_hyperparams(hparam_dict, metric_dict)
@overrides
def training_epoch_end(self, training_step_outputs):
avg_loss = torch.tensor(0.0).to(self.device)
for step_out in training_step_outputs:
avg_loss += step_out["loss"]
log_dict = {"train_loss": avg_loss}
self._log_shell(log_dict, prefix="train ")
def _log_shell(self, log_info, prefix=""):
log_info_shell = {}
for k, v in log_info.items():
new_v = v
if type(new_v) is torch.Tensor:
new_v = new_v.item()
log_info_shell[k] = new_v
out_str = prefix.upper()
if prefix.upper().strip() in ["TRAIN", "VAL"]:
out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]"
if self.training:
lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"]
log_info_shell["lr"] = lr
for key, value in log_info_shell.items():
out_str += f" || {key}: {round(value, 5)}"
out_str += f" || time: {round(time.time() - self.time_tracker, 1)}"
out_str += " secs."
# self.print(out_str)
self.time_tracker = time.time()

View File

@ -0,0 +1,390 @@
import numpy as np
import torch
import torch.utils.data
from overrides import overrides
from lightning_modules.classifier import ClassifierModule
from utils import get_class_names
class KVUClassifierModule(ClassifierModule):
def __init__(self, cfg):
super().__init__(cfg)
class_names = get_class_names(self.cfg.dataset_root_path)
self.window_size = cfg.train.max_num_words
self.slice_interval = cfg.train.slice_interval
self.eval_kwargs = {
"class_names": class_names,
"dummy_idx": self.cfg.train.max_seq_length, # update dummy_idx in next step
}
self.stage = cfg.stage
@overrides
def training_step(self, batch, batch_idx, *args):
if self.stage == 1:
_, loss = self.net(batch['windows'])
elif self.stage == 2:
_, loss = self.net(batch)
else:
raise ValueError(
f"Not supported stage: {self.stage}"
)
log_dict_input = {"train_loss": loss}
self.log_dict(log_dict_input, sync_dist=True)
return loss
@torch.no_grad()
@overrides
def validation_step(self, batch, batch_idx, *args):
if self.stage == 1:
step_out_total = {
"loss": 0,
"ee":{
"n_batch_gt": 0,
"n_batch_pr": 0,
"n_batch_correct": 0,
},
"el":{
"n_batch_gt": 0,
"n_batch_pr": 0,
"n_batch_correct": 0,
},
"el_from_key":{
"n_batch_gt": 0,
"n_batch_pr": 0,
"n_batch_correct": 0,
}}
for window in batch['windows']:
head_outputs, loss = self.net(window)
step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
for key in step_out_total:
if key == 'loss':
step_out_total[key] += step_out[key]
else:
for subkey in step_out_total[key]:
step_out_total[key][subkey] += step_out[key][subkey]
return step_out_total
elif self.stage == 2:
head_outputs, loss = self.net(batch)
# self.eval_kwargs['dummy_idx'] = batch['itc_labels'].shape[1]
# step_out = do_eval_step(batch, head_outputs, loss, self.eval_kwargs)
self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1]
step_out = do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs)
return step_out
@torch.no_grad()
@overrides
def validation_epoch_end(self, validation_step_outputs):
scores = do_eval_epoch_end(validation_step_outputs)
self.print(
f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}"
)
self.print(
f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}"
)
self.print(
f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}"
)
self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.)
tensorboard_logs = {'val_precision_ee': scores['ee']['precision'], 'val_recall_ee': scores['ee']['recall'], 'val_f1_ee': scores['ee']['f1'],
'val_precision_el': scores['el']['precision'], 'val_recall_el': scores['el']['recall'], 'val_f1_el': scores['el']['f1'],
'val_precision_el_from_key': scores['el_from_key']['precision'], 'val_recall_el_from_key': scores['el_from_key']['recall'], \
'val_f1_el_from_key': scores['el_from_key']['f1'],}
return {'log': tensorboard_logs}
def do_eval_step(batch, head_outputs, loss, eval_kwargs):
class_names = eval_kwargs["class_names"]
dummy_idx = eval_kwargs["dummy_idx"]
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_labels = torch.argmax(itc_outputs, -1)
pr_stc_labels = torch.argmax(stc_outputs, -1)
pr_el_labels = torch.argmax(el_outputs, -1)
pr_el_labels_from_key = torch.argmax(el_outputs_from_key, -1)
(
n_batch_gt_classes,
n_batch_pr_classes,
n_batch_correct_classes,
) = eval_ee_spade_batch(
pr_itc_labels,
batch["itc_labels"],
batch["are_box_first_tokens"],
pr_stc_labels,
batch["stc_labels"],
batch["attention_mask_layoutxlm"],
class_names,
dummy_idx,
)
n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch(
pr_el_labels,
batch["el_labels"],
batch["are_box_first_tokens"],
dummy_idx,
)
n_batch_gt_rel_from_key, n_batch_pr_rel_from_key, n_batch_correct_rel_from_key = eval_el_spade_batch(
pr_el_labels_from_key,
batch["el_labels_from_key"],
batch["are_box_first_tokens"],
dummy_idx,
)
step_out = {
"loss": loss,
"ee":{
"n_batch_gt": n_batch_gt_classes,
"n_batch_pr": n_batch_pr_classes,
"n_batch_correct": n_batch_correct_classes,
},
"el":{
"n_batch_gt": n_batch_gt_rel,
"n_batch_pr": n_batch_pr_rel,
"n_batch_correct": n_batch_correct_rel,
},
"el_from_key":{
"n_batch_gt": n_batch_gt_rel_from_key,
"n_batch_pr": n_batch_pr_rel_from_key,
"n_batch_correct": n_batch_correct_rel_from_key,
}
}
return step_out
def eval_ee_spade_batch(
pr_itc_labels,
gt_itc_labels,
are_box_first_tokens,
pr_stc_labels,
gt_stc_labels,
attention_mask,
class_names,
dummy_idx,
):
n_batch_gt_classes, n_batch_pr_classes, n_batch_correct_classes = 0, 0, 0
bsz = pr_itc_labels.shape[0]
for example_idx in range(bsz):
n_gt_classes, n_pr_classes, n_correct_classes = eval_ee_spade_example(
pr_itc_labels[example_idx],
gt_itc_labels[example_idx],
are_box_first_tokens[example_idx],
pr_stc_labels[example_idx],
gt_stc_labels[example_idx],
attention_mask[example_idx],
class_names,
dummy_idx,
)
n_batch_gt_classes += n_gt_classes
n_batch_pr_classes += n_pr_classes
n_batch_correct_classes += n_correct_classes
return (
n_batch_gt_classes,
n_batch_pr_classes,
n_batch_correct_classes,
)
def eval_ee_spade_example(
pr_itc_label,
gt_itc_label,
box_first_token_mask,
pr_stc_label,
gt_stc_label,
attention_mask,
class_names,
dummy_idx,
):
gt_first_words = parse_initial_words(
gt_itc_label, box_first_token_mask, class_names
)
gt_class_words = parse_subsequent_words(
gt_stc_label, attention_mask, gt_first_words, dummy_idx
)
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, dummy_idx
)
n_gt_classes, n_pr_classes, n_correct_classes = 0, 0, 0
for class_idx in range(len(class_names)):
# Evaluate by ID
gt_parse = set(gt_class_words[class_idx])
pr_parse = set(pr_class_words[class_idx])
n_gt_classes += len(gt_parse)
n_pr_classes += len(pr_parse)
n_correct_classes += len(gt_parse & pr_parse)
return n_gt_classes, n_pr_classes, n_correct_classes
def parse_initial_words(itc_label, box_first_token_mask, class_names):
itc_label_np = itc_label.cpu().numpy()
box_first_token_mask_np = box_first_token_mask.cpu().numpy()
outputs = [[] for _ in range(len(class_names))]
for token_idx, label in enumerate(itc_label_np):
if box_first_token_mask_np[token_idx] and label != 0:
outputs[label].append(token_idx)
return outputs
def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
max_connections = 50
valid_stc_label = stc_label * attention_mask.bool()
valid_stc_label = valid_stc_label.cpu().numpy()
stc_label_np = stc_label.cpu().numpy()
valid_token_indices = np.where(
(valid_stc_label != dummy_idx) * (valid_stc_label != 0)
)
next_token_idx_dict = {}
for token_idx in valid_token_indices[0]:
next_token_idx_dict[stc_label_np[token_idx]] = token_idx
outputs = []
for init_token_indices in init_words:
sub_outputs = []
for init_token_idx in init_token_indices:
cur_token_indices = [init_token_idx]
for _ in range(max_connections):
if cur_token_indices[-1] in next_token_idx_dict:
if (
next_token_idx_dict[cur_token_indices[-1]]
not in init_token_indices
):
cur_token_indices.append(
next_token_idx_dict[cur_token_indices[-1]]
)
else:
break
else:
break
sub_outputs.append(tuple(cur_token_indices))
outputs.append(sub_outputs)
return outputs
def eval_el_spade_batch(
pr_el_labels,
gt_el_labels,
are_box_first_tokens,
dummy_idx,
):
n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = 0, 0, 0
bsz = pr_el_labels.shape[0]
for example_idx in range(bsz):
n_gt_rel, n_pr_rel, n_correct_rel = eval_el_spade_example(
pr_el_labels[example_idx],
gt_el_labels[example_idx],
are_box_first_tokens[example_idx],
dummy_idx,
)
n_batch_gt_rel += n_gt_rel
n_batch_pr_rel += n_pr_rel
n_batch_correct_rel += n_correct_rel
return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel
def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
gt_relations = parse_relations(gt_el_label, box_first_token_mask, dummy_idx)
pr_relations = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
gt_relations = set(gt_relations)
pr_relations = set(pr_relations)
n_gt_rel = len(gt_relations)
n_pr_rel = len(pr_relations)
n_correct_rel = len(gt_relations & pr_relations)
return n_gt_rel, n_pr_rel, n_correct_rel
def parse_relations(el_label, box_first_token_mask, dummy_idx):
valid_el_labels = el_label * box_first_token_mask
valid_el_labels = valid_el_labels.cpu().numpy()
el_label_np = el_label.cpu().numpy()
max_token = box_first_token_mask.shape[0] - 1
valid_token_indices = np.where(
((valid_el_labels != dummy_idx) * (valid_el_labels != 0)) ###
)
link_map_tuples = []
for token_idx in valid_token_indices[0]:
link_map_tuples.append((el_label_np[token_idx], token_idx))
return set(link_map_tuples)
def do_eval_epoch_end(step_outputs):
scores = {}
for task in ['ee', 'el', 'el_from_key']:
n_total_gt_classes, n_total_pr_classes, n_total_correct_classes = 0, 0, 0
for step_out in step_outputs:
n_total_gt_classes += step_out[task]["n_batch_gt"]
n_total_pr_classes += step_out[task]["n_batch_pr"]
n_total_correct_classes += step_out[task]["n_batch_correct"]
precision = (
0.0 if n_total_pr_classes == 0 else n_total_correct_classes / n_total_pr_classes
)
recall = (
0.0 if n_total_gt_classes == 0 else n_total_correct_classes / n_total_gt_classes
)
f1 = (
0.0
if recall * precision == 0
else 2.0 * recall * precision / (recall + precision)
)
scores[task] = {
"precision": precision,
"recall": recall,
"f1": f1,
}
return scores
def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
class_names = get_class_names(dataset_root_path)
dummy_idx = max_seq_length
eval_kwargs = {"class_names": class_names, "dummy_idx": dummy_idx}
return eval_kwargs
def get_eval_kwargs_spade_rel(max_seq_length):
dummy_idx = max_seq_length
eval_kwargs = {"dummy_idx": dummy_idx}
return eval_kwargs
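# A small worked example (illustrative addition, not part of the original module):
# with 10 ground-truth entities, 8 predictions and 6 correct matches,
# do_eval_epoch_end gives precision 6/8 = 0.75, recall 6/10 = 0.60 and
# F1 = 2 * 0.75 * 0.60 / (0.75 + 0.60) ≈ 0.667 for each task.
if __name__ == "__main__":
    counts = {"n_batch_gt": 10, "n_batch_pr": 8, "n_batch_correct": 6}
    toy_step = {"loss": 0.0, "ee": dict(counts), "el": dict(counts), "el_from_key": dict(counts)}
    print(do_eval_epoch_end([toy_step]))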

View File

@ -0,0 +1,53 @@
"""
BROS
Copyright 2022-present NAVER Corp.
Apache License v2.0
"""
import math
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1):
"""linear_scheduler with warmup from huggingface"""
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return max(
0.0,
float(training_steps - current_step)
/ float(max(1, training_steps - warmup_steps)),
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def cosine_scheduler(
optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1
):
"""Cosine LR scheduler with warmup from huggingface"""
def lr_lambda(current_step):
if current_step < warmup_steps:
return current_step / max(1, warmup_steps)
progress = current_step - warmup_steps
progress /= max(1, training_steps - warmup_steps)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1):
def lr_lambda(current_step):
if current_step < warmup_steps:
# calculate a warmup ratio
return current_step / max(1, warmup_steps)
else:
# calculate a multistep lr scaling ratio
idx = np.searchsorted(milestones, current_step)
return gamma ** idx
return LambdaLR(optimizer, lr_lambda, last_epoch)
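# Minimal usage sketch (illustrative addition, not part of the original module):
# wrap any torch optimizer with the warmup + cosine schedule; the learning-rate
# factor ramps linearly over the first 100 steps, then decays with a half cosine.
if __name__ == "__main__":
    import torch
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = cosine_scheduler(optimizer, warmup_steps=100, training_steps=1000)
    for _ in range(10):
        optimizer.step()
        scheduler.step()
    print(optimizer.param_groups[0]["lr"])  # 0.1 * 10/100 = 0.01 after 10 warmup steps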

View File

@ -0,0 +1,161 @@
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import math
def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
element_windows = []
if len(elements) > window_size:
max_step = math.ceil((len(elements) - window_size)/slice_interval)
for i in range(0, max_step + 1):
# element_windows.append(copy.deepcopy(elements[min(i, len(elements) - window_size): min(i+window_size, len(elements))]))
if (i*slice_interval+window_size) >= len(elements):
_window = copy.deepcopy(elements[i*slice_interval:])
else:
_window = copy.deepcopy(elements[i*slice_interval: i*slice_interval+window_size])
element_windows.append(_window)
return element_windows
else:
return [elements]
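# Illustrative example (comment only, added for clarity): with window_size=4 and
# slice_interval=2, a 7-element list is split into overlapping windows and the
# final window keeps the remainder:
#   sliding_windows(list(range(7)), window_size=4, slice_interval=2)
#   -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6]]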
def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> list:
word_windows = []
parse_class_windows = []
parse_relation_windows = []
if len(lwords) > window_size:
max_step = math.ceil((len(lwords) - window_size)/slice_interval)
for i in range(0, max_step+1):
# _word_window = copy.deepcopy(lwords[min(i*slice_interval, len(lwords) - window_size): min(i*slice_interval+window_size, len(lwords))])
if (i*slice_interval+window_size) >= len(lwords):
_word_window = copy.deepcopy(lwords[i*slice_interval:])
else:
_word_window = copy.deepcopy(lwords[i*slice_interval: i*slice_interval+window_size])
if len(_word_window) < 2:
continue
first_word_id = _word_window[0]['word_id']
last_word_id = _word_window[-1]['word_id']
# assert (last_word_id - first_word_id == window_size - 1) or (first_word_id == 0 and last_word_id == len(lwords) - 1), [v['word_id'] for v in _word_window] #(last_word_id,first_word_id,len(lwords))
# word list
for _word in _word_window:
_word['word_id'] -= first_word_id
# Entity extraction
_class_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id)
# Entity Linking
_relation_window = entity_linking_by_words(parse_relation, first_word_id, last_word_id)
word_windows.append(_word_window)
parse_class_windows.append(_class_window)
parse_relation_windows.append(_relation_window)
return word_windows, parse_class_windows, parse_relation_windows
else:
return [lwords], [parse_class], [parse_relation]
def entity_extraction_by_words(parse_class, first_word_id, last_word_id):
_class_window = {k: [] for k in list(parse_class.keys())}
for class_name, _parse_class in parse_class.items():
for group in _parse_class:
tmp = []
for idw in group:
idw -= first_word_id
if 0 <= idw <= (last_word_id - first_word_id):
tmp.append(idw)
_class_window[class_name].append(tmp)
return _class_window
def entity_linking_by_words(parse_relation, first_word_id, last_word_id):
_relation_window = []
for pair in parse_relation:
if all([0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair]):
_relation_window.append([idw - first_word_id for idw in pair])
return _relation_window
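# Illustrative note (comment only, added for clarity): with first_word_id=10 and
# last_word_id=14, a relation pair [11, 13] stays inside the window and is
# remapped to [1, 3], while a pair such as [8, 12] that crosses the window
# boundary is dropped.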
def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
start_pos = 1
end_pos = start_pos + lvalids[0]
embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])
for i in range(1, len(lpatches)):
start_pos = 1
end_pos = start_pos + lvalids[i]
overlap_gap = copy.deepcopy(loverlaps[i-1])
window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])
if overlap_gap != 0:
prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
if average:
avg_overlap = (
prev_overlap + curr_overlap
) / 2.
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens, window], dim=1
)
return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
def merged_token_embeddings2(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
start_pos = 1
end_pos = start_pos + lvalids[0]
embedding_tokens = lpatches[0][:, start_pos:end_pos, ...]
cls_token = lpatches[0][:, :1, ...]
sep_token = lpatches[0][:, -1:, ...]
for i in range(1, len(lpatches)):
start_pos = 1
end_pos = start_pos + lvalids[i]
overlap_gap = loverlaps[i-1]
window = lpatches[i][:, start_pos:end_pos, ...]
if overlap_gap != 0:
prev_overlap = embedding_tokens[:, -overlap_gap:, ...]
curr_overlap = window[:, :overlap_gap, ...]
assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
if average:
avg_overlap = (
prev_overlap + curr_overlap
) / 2.
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1
)
else:
embedding_tokens = torch.cat(
[embedding_tokens, window], dim=1
)
return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
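# Small illustrative sketch (an addition, not part of the original file): two
# windows with 5 valid tokens each and an overlap of 2 merge into
# 1 CLS + (5 + 5 - 2) valid tokens + 1 SEP = 10 positions.
if __name__ == "__main__":
    w1 = torch.arange(7, dtype=torch.float32).reshape(1, 7, 1)      # [CLS] + 5 valid + [SEP]
    w2 = torch.arange(7, 14, dtype=torch.float32).reshape(1, 7, 1)
    merged = merged_token_embeddings2([w1, w2], loverlaps=[2], lvalids=[5, 5], average=False)
    print(merged.shape)  # torch.Size([1, 10, 1])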

View File

@ -0,0 +1,15 @@
from model.combined_model import CombinedKVUModel
from model.kvu_model import KVUModel
from model.document_kvu_model import DocumentKVUModel
def get_model(cfg):
if cfg.stage == 1:
model = CombinedKVUModel(cfg=cfg)
elif cfg.stage == 2:
model = KVUModel(cfg=cfg)
elif cfg.stage == 3:
model = DocumentKVUModel(cfg=cfg)
else:
raise Exception('[ERROR] Training stage is wrong')
return model

View File

@ -0,0 +1,76 @@
import os
import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from transformers import AutoTokenizer, XLMRobertaModel
from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint
class CombinedKVUModel(KVUModel):
def __init__(self, cfg):
super().__init__(cfg)
self.model_cfg = cfg.model
self.freeze = cfg.train.freeze
self.finetune_only = cfg.train.finetune_only
self._get_backbones(self.model_cfg.backbone)
self._create_head()
if os.path.exists(self.model_cfg.ckpt_model_file):
self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer')
self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer')
self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer')
self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key')
self.loss_func = nn.CrossEntropyLoss()
if self.freeze:
for name, param in self.named_parameters():
if 'backbone' in name:
param.requires_grad = False
if self.finetune_only == 'EE':
for name, param in self.named_parameters():
if 'itc_layer' not in name and 'stc_layer' not in name:
param.requires_grad = False
if self.finetune_only == 'EL':
for name, param in self.named_parameters():
if 'relation_layer' not in name or 'relation_layer_from_key' in name:
param.requires_grad = False
if self.finetune_only == 'ELK':
for name, param in self.named_parameters():
if 'relation_layer_from_key' not in name:
param.requires_grad = False
def forward(self, batch):
image = batch["image"]
input_ids_layoutxlm = batch["input_ids_layoutxlm"]
bbox = batch["bbox"]
attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]
backbone_outputs_layoutxlm = self.backbone_layoutxlm(
image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)
last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
"el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}
loss = 0.0
if any(['labels' in key for key in batch.keys()]):
loss = self._get_loss(head_outputs, batch)
return head_outputs, loss

View File

@ -0,0 +1,185 @@
import torch
from torch import nn
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2Config, LayoutLMv2Model
from transformers import LayoutXLMTokenizer
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel
from model.relation_extractor import RelationExtractor
from model.kvu_model import KVUModel
from utils import load_checkpoint
class DocumentKVUModel(KVUModel):
def __init__(self, cfg):
super().__init__(cfg)
self.model_cfg = cfg.model
self.freeze = cfg.train.freeze
self.train_cfg = cfg.train
self._get_backbones(self.model_cfg.backbone)
# if 'pth' in self.model_cfg.ckpt_model_file:
# self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone)
self._create_head()
self.loss_func = nn.CrossEntropyLoss()
def _create_head(self):
self.backbone_hidden_size = self.backbone_config.hidden_size
self.head_hidden_size = self.model_cfg.head_hidden_size
self.head_p_dropout = self.model_cfg.head_p_dropout
self.n_classes = self.model_cfg.n_classes + 1
self.repr_hiddent_size = self.backbone_hidden_size
# (1) Initial token classification
self.itc_layer = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# Classfication Layer for whole document
# (1) Initial token classification
self.itc_layer_document = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.repr_hiddent_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer_document = RelationExtractor(
n_relations=1,
backbone_hidden_size=self.repr_hiddent_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer_document = RelationExtractor(
n_relations=1,
backbone_hidden_size=self.repr_hiddent_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key_document = RelationExtractor(
n_relations=1,
backbone_hidden_size=self.repr_hiddent_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
self.itc_layer.apply(self._init_weight)
self.stc_layer.apply(self._init_weight)
self.relation_layer.apply(self._init_weight)
self.relation_layer_from_key.apply(self._init_weight)
self.itc_layer_document.apply(self._init_weight)
self.stc_layer_document.apply(self._init_weight)
self.relation_layer_document.apply(self._init_weight)
self.relation_layer_from_key_document.apply(self._init_weight)
def _get_backbones(self, config_type):
configs = {
'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extractor': LayoutLMv2FeatureExtractor},
'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extractor': LayoutLMv2FeatureExtractor},
'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extractor': LayoutLMv2FeatureExtractor},
}
self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
if config_type != 'xlm-roberta':
self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
else:
self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
self.feature_extractor = configs[config_type]['feature_extractor'](apply_ocr=False)
self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)
def forward(self, batches):
head_outputs_list = []
loss = 0.0
for batch in batches["windows"]:
image = batch["image"]
input_ids = batch["input_ids_layoutxlm"]
bbox = batch["bbox"]
attention_mask = batch["attention_mask_layoutxlm"]
if self.freeze:
for param in self.backbone.parameters():
param.requires_grad = False
if self.model_cfg.backbone == 'layoutxlm':
backbone_outputs = self.backbone(
image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
)
else:
backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)
last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
window_repr = last_hidden_states.transpose(0, 1).contiguous()
head_outputs = {"window_repr": window_repr,
"itc_outputs": itc_outputs,
"stc_outputs": stc_outputs,
"el_outputs": el_outputs,
"el_outputs_from_key": el_outputs_from_key}
if any(['labels' in key for key in batch.keys()]):
loss += self._get_loss(head_outputs, batch)
head_outputs_list.append(head_outputs)
batch = batches["documents"]
document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
document_repr = document_repr.transpose(0, 1).contiguous()
itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)
head_outputs = {"itc_outputs": itc_outputs,
"stc_outputs": stc_outputs,
"el_outputs": el_outputs,
"el_outputs_from_key": el_outputs_from_key}
if any(['labels' in key for key in batch.keys()]):
loss += self._get_loss(head_outputs, batch)
return head_outputs, loss

View File

@ -0,0 +1,248 @@
import os
import torch
from torch import nn
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
from transformers import LayoutXLMTokenizer
from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2
from model.relation_extractor import RelationExtractor
from utils import load_checkpoint
class KVUModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.device = 'cuda'
self.model_cfg = cfg.model
self.freeze = cfg.train.freeze
self.finetune_only = cfg.train.finetune_only
# if cfg.stage == 2:
# self.freeze = True
self._get_backbones(self.model_cfg.backbone)
self._create_head()
if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)):
self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
self._create_head()
self.loss_func = nn.CrossEntropyLoss()
if self.freeze:
for name, param in self.named_parameters():
if 'backbone' in name:
param.requires_grad = False
def _create_head(self):
self.backbone_hidden_size = 768
self.head_hidden_size = self.model_cfg.head_hidden_size
self.head_p_dropout = self.model_cfg.head_p_dropout
self.n_classes = self.model_cfg.n_classes + 1
# (1) Initial token classification
self.itc_layer = nn.Sequential(
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
nn.Dropout(self.head_p_dropout),
nn.Linear(self.backbone_hidden_size, self.n_classes),
)
# (2) Subsequent token classification
self.stc_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (3) Linking token classification
self.relation_layer = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
# (4) Linking token classification
self.relation_layer_from_key = RelationExtractor(
n_relations=1, #1
backbone_hidden_size=self.backbone_hidden_size,
head_hidden_size=self.head_hidden_size,
head_p_dropout=self.head_p_dropout,
)
self.itc_layer.apply(self._init_weight)
self.stc_layer.apply(self._init_weight)
self.relation_layer.apply(self._init_weight)
def _get_backbones(self, config_type):
self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')
@staticmethod
def _init_weight(module):
init_std = 0.02
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, 0.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
elif isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight, 1.0, init_std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
# def forward(self, inputs):
# token_embeddings = inputs['embeddings'].transpose(0, 1).contiguous().cuda()
# itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
# stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
# el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
# el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
# head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
# "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}
# loss = self._get_loss(head_outputs, inputs)
# return head_outputs, loss
# def forward_single_doccument(self, lbatches):
def forward(self, lbatches):
windows = lbatches['windows']
token_embeddings_windows = []
lvalids = []
loverlaps = []
for i, batch in enumerate(windows):
batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')}
image = batch["image"]
input_ids_layoutxlm = batch["input_ids_layoutxlm"]
bbox = batch["bbox"]
attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]
backbone_outputs_layoutxlm = self.backbone_layoutxlm(
image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)
last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
lvalids.append(batch['len_valid_tokens'])
loverlaps.append(batch['len_overlap_tokens'])
token_embeddings_windows.append(last_hidden_states_layoutxlm)
token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False)
# token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True)
token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()
itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
"el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key,
'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()}
loss = 0.0
if any(['labels' in key for key in lbatches.keys()]):
labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k not in ('img_path', 'words')}
loss = self._get_loss(head_outputs, labels)
return head_outputs, loss
def _get_loss(self, head_outputs, batch):
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
itc_loss = self._get_itc_loss(itc_outputs, batch)
stc_loss = self._get_stc_loss(stc_outputs, batch)
el_loss = self._get_el_loss(el_outputs, batch)
el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)
loss = itc_loss + stc_loss + el_loss + el_loss_from_key
return loss
def _get_itc_loss(self, itc_outputs, batch):
itc_mask = batch["are_box_first_tokens"].view(-1)
itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
itc_logits = itc_logits[itc_mask]
itc_labels = batch["itc_labels"].view(-1)
itc_labels = itc_labels[itc_mask]
itc_loss = self.loss_func(itc_logits, itc_labels)
return itc_loss
def _get_stc_loss(self, stc_outputs, batch):
inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]
bsz, max_seq_length = inv_attention_mask.shape
device = inv_attention_mask.device
invalid_token_mask = torch.cat(
[inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
).bool()
stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()
stc_logits = stc_outputs.view(-1, max_seq_length + 1)
stc_logits = stc_logits[stc_mask]
stc_labels = batch["stc_labels"].view(-1)
stc_labels = stc_labels[stc_mask]
stc_loss = self.loss_func(stc_logits, stc_labels)
return stc_loss
def _get_el_loss(self, el_outputs, batch, from_key=False):
bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
device = batch["attention_mask_layoutxlm"].device
self_token_mask = (
torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
)
box_first_token_mask = torch.cat(
[
(batch["are_box_first_tokens"] == False),
torch.zeros([bsz, 1], dtype=torch.bool).to(device),
],
axis=1,
)
el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)
mask = batch["are_box_first_tokens"].view(-1)
logits = el_outputs.view(-1, max_seq_length + 1)
logits = logits[mask]
if from_key:
el_labels = batch["el_labels_from_key"]
else:
el_labels = batch["el_labels"]
labels = el_labels.view(-1)
labels = labels[mask]
loss = self.loss_func(logits, labels)
return loss

View File

@ -0,0 +1,48 @@
import torch
from torch import nn
class RelationExtractor(nn.Module):
def __init__(
self,
n_relations,
backbone_hidden_size,
head_hidden_size,
head_p_dropout=0.1,
):
super().__init__()
self.n_relations = n_relations
self.backbone_hidden_size = backbone_hidden_size
self.head_hidden_size = head_hidden_size
self.head_p_dropout = head_p_dropout
self.drop = nn.Dropout(head_p_dropout)
self.q_net = nn.Linear(
self.backbone_hidden_size, self.n_relations * self.head_hidden_size
)
self.k_net = nn.Linear(
self.backbone_hidden_size, self.n_relations * self.head_hidden_size
)
self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size))
nn.init.normal_(self.dummy_node)
def forward(self, h_q, h_k):
h_q = self.q_net(self.drop(h_q))
dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1)
h_k = torch.cat([h_k, dummy_vec], axis=0)
h_k = self.k_net(self.drop(h_k))
head_q = h_q.view(
h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size
)
head_k = h_k.view(
h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size
)
relation_score = torch.einsum("ibnd,jbnd->nbij", (head_q, head_k))
return relation_score
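# Minimal usage sketch (an addition for illustration, not part of the original
# model code): relation scores have shape (n_relations, batch, seq_len, seq_len + 1);
# the extra key-side column comes from the appended dummy node.
if __name__ == "__main__":
    extractor = RelationExtractor(n_relations=1, backbone_hidden_size=768, head_hidden_size=128)
    h = torch.randn(16, 1, 768)  # (seq_len, batch, hidden)
    scores = extractor(h, h)
    print(scores.shape)  # torch.Size([1, 1, 16, 17])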

View File

@ -0,0 +1,228 @@
from omegaconf import OmegaConf
import torch
# from functions import get_colormap, visualize
import sys
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') #TODO: ??????
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
from model import get_model
from utils import load_model_weight
class KVUPredictor:
def __init__(self, configs, class_names, dummy_idx, mode=0):
cfg_path = configs['cfg']
ckpt_path = configs['ckpt']
self.class_names = class_names
self.dummy_idx = dummy_idx
self.mode = mode
print('[INFO] Loading Key-Value Understanding model ...')
self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
print("[INFO] Loaded model")
if mode == 3:
self.max_window_count = cfg.train.max_window_count
self.window_size = cfg.train.window_size
self.slice_interval = 0
self.dummy_idx = dummy_idx * self.max_window_count
else:
self.slice_interval = cfg.train.slice_interval
self.window_size = cfg.train.max_num_words
self.device = 'cuda'
def _load_model(self, cfg_path, ckpt_path):
cfg = OmegaConf.load(cfg_path)
cfg.stage = self.mode
backbone_type = cfg.model.backbone
print('[INFO] Checkpoint:', ckpt_path)
net = get_model(cfg)
load_model_weight(net, ckpt_path)
net.to('cuda')
net.eval()
return net, cfg, backbone_type
def predict(self, input_sample):
if self.mode == 0:
if len(input_sample['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
return [bbox], [lwords], [pr_class_words], [pr_relations]
elif self.mode == 1:
if len(input_sample['documents']['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
return [bbox], [lwords], [pr_class_words], [pr_relations]
elif self.mode == 2:
if len(input_sample['windows'][0]['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = [], [], [], []
for window in input_sample['windows']:
_bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
bbox.append(_bbox)
lwords.append(_lwords)
pr_class_words.append(_pr_class_words)
pr_relations.append(_pr_relations)
return bbox, lwords, pr_class_words, pr_relations
elif self.mode == 3:
if len(input_sample["documents"]['words']) == 0:
return [], [], [], []
bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
return [bbox], [lwords], [pr_class_words], [pr_relations]
else:
raise ValueError(
f"Not supported mode: {self.mode}"
)
def doc_predict(self, input_sample):
lwords = input_sample['documents']['words']
for idx, window in enumerate(input_sample['windows']):
input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')}
# input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')}
with torch.no_grad():
head_outputs, _ = self.net(input_sample)
head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
input_sample = input_sample['documents']
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)
box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
bbox = input_sample['bbox'].squeeze(0)
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
)
pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
pr_relations = pr_relations_from_header | pr_relations_from_key
return bbox, lwords, pr_class_words, pr_relations
def combined_predict(self, input_sample):
lwords = input_sample['words']
input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}
input_sample = {k: v.to(self.device) for k, v in input_sample.items()}
with torch.no_grad():
head_outputs, _ = self.net(input_sample)
head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)
box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
bbox = input_sample['bbox'].squeeze(0)
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
)
pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
pr_relations = pr_relations_from_header | pr_relations_from_key
return bbox, lwords, pr_class_words, pr_relations
def cat_predict(self, input_sample):
lwords = input_sample['documents']['words']
inputs = []
for window in input_sample['windows']:
inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')})
input_sample['windows'] = inputs
with torch.no_grad():
head_outputs, _ = self.net(input_sample)
head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)}
itc_outputs = head_outputs["itc_outputs"]
stc_outputs = head_outputs["stc_outputs"]
el_outputs = head_outputs["el_outputs"]
el_outputs_from_key = head_outputs["el_outputs_from_key"]
pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)
box_first_token_mask = input_sample['documents']['are_box_first_tokens']
attention_mask = input_sample['documents']['attention_mask_layoutxlm']
bbox = input_sample['documents']['bbox']
dummy_idx = input_sample['documents']['bbox'].shape[0]
pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
pr_class_words = parse_subsequent_words(
pr_stc_label, attention_mask, pr_init_words, dummy_idx
)
pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
pr_relations = pr_relations_from_header | pr_relations_from_key
return bbox, lwords, pr_class_words, pr_relations
def get_ground_truth_label(self, ground_truth):
# ground_truth = self.preprocessor.load_ground_truth(json_file)
gt_itc_label = ground_truth['itc_labels'].squeeze(0) # [1, 512] => [512]
gt_stc_label = ground_truth['stc_labels'].squeeze(0) # [1, 512] => [512]
gt_el_label = ground_truth['el_labels'].squeeze(0)
gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
lwords = ground_truth["words"]
box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
attention_mask = ground_truth['attention_mask'].squeeze(0)
bbox = ground_truth['bbox'].squeeze(0)
gt_first_words = parse_initial_words(
gt_itc_label, box_first_token_mask, self.class_names
)
gt_class_words = parse_subsequent_words(
gt_stc_label, attention_mask, gt_first_words, self.dummy_idx
)
gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
gt_relations = gt_relations_from_header | gt_relations_from_key
return bbox, lwords, gt_class_words, gt_relations

View File

@ -0,0 +1,456 @@
import os
from typing import Any
import numpy as np
import pandas as pd
import imagesize
import itertools
from PIL import Image
import argparse
import torch
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
from utils.run_ocr import load_ocr_engine, process_img
from lightning_modules.utils import sliding_windows
class KVUProcess:
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
self.tokenizer_layoutxlm = tokenizer_layoutxlm
self.feature_extractor = feature_extractor
self.max_seq_length = max_seq_length
self.backbone_type = backbone_type
self.class_names = class_names
self.slice_interval = slice_interval
self.window_size = window_size
self.run_ocr = run_ocr
self.mode = mode
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)
self.class_idx_dic = dict(
[(class_name, idx) for idx, class_name in enumerate(self.class_names)]
)
self.ocr_engine = None
if self.run_ocr == 1:
self.ocr_engine = load_ocr_engine()
def __call__(self, img_path: str, ocr_path: str) -> list:
if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
ocr_path = "tmp.txt"
process_img(img_path, ocr_path, self.ocr_engine, export_img=False)
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
lwords = post_process_basic_ocr(lwords)
bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}"
width, height = imagesize.get(img_path)
images = [Image.open(img_path).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
if self.mode == 0:
output = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
elif self.mode == 1:
output = {}
windows = []
for i in range(len(bbox_windows)):
_words = word_windows[i]
_bboxes = bbox_windows[i]
windows.append(
self.preprocess(
_bboxes, _words,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
)
output['windows'] = windows
elif self.mode == 2:
output = {}
windows = []
output['documents'] = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=2048)
for i in range(len(bbox_windows)):
_words = word_windows[i]
_bboxes = bbox_windows[i]
windows.append(
self.preprocess(
_bboxes, _words,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
max_seq_length=self.max_seq_length)
)
output['windows'] = windows
else:
raise ValueError(
f"Not supported mode: {self.mode }"
)
return output
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
list_word_objects = []
for bb, text in zip(bounding_boxes, words):
boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text))
list_word_objects.append({
"layoutxlm_tokens": tokens,
"boundingBox": boundingBox,
"text": text
})
(
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_tokens
) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"])
assert len_list_tokens == len_valid_tokens + 2
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
input_ids = input_ids[:ntokens]
attention_mask = attention_mask[:ntokens]
bbox = bbox[:ntokens]
are_box_first_tokens = are_box_first_tokens[:ntokens]
input_ids = torch.from_numpy(input_ids)
attention_mask = torch.from_numpy(attention_mask)
bbox = torch.from_numpy(bbox)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
len_valid_tokens = torch.tensor(len_valid_tokens)
len_overlap_tokens = torch.tensor(len_overlap_tokens)
return_dict = {
"img_path": feature_maps['img_path'],
"words": lwords,
"len_overlap_tokens": len_overlap_tokens,
'len_valid_tokens': len_valid_tokens,
"image": feature_maps['image'],
"input_ids_layoutxlm": input_ids,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
"bbox": bbox,
}
return return_dict
def parser_words(self, words, max_seq_length, width, height):
list_bbs = []
list_words = []
list_tokens = []
cls_bbs = [0.0] * 8
box2token_span_map = []
box_to_token_indices = []
lwords = [''] * max_seq_length
cum_token_idx = 0
len_valid_tokens = 0
len_non_overlap_tokens = 0
input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
for word_idx, word in enumerate(words):
this_box_token_indices = []
tokens = word["layoutxlm_tokens"]
bb = word["boundingBox"]
text = word["text"]
len_valid_tokens += len(tokens)
if word_idx < self.slice_interval:
len_non_overlap_tokens += len(tokens)
if len(tokens) == 0:
tokens.append(self.unk_token_id_layoutxlm)  # KVUProcess only defines the layoutxlm-suffixed id
if len(list_tokens) + len(tokens) > max_seq_length - 2:
break
box2token_span_map.append(
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
) # including st_idx
list_tokens += tokens
# min, max clipping
for coord_idx in range(4):
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
bb = list(itertools.chain(*bb))
bbs = [bb for _ in range(len(tokens))]
texts = [text for _ in range(len(tokens))]
for _ in tokens:
cum_token_idx += 1
this_box_token_indices.append(cum_token_idx)
list_bbs.extend(bbs)
list_words.extend(texts) ####
box_to_token_indices.append(this_box_token_indices)
sep_bbs = [width, height] * 4
# For [CLS] and [SEP]
list_tokens = (
[self.cls_token_id_layoutxlm]
+ list_tokens[: max_seq_length - 2]
+ [self.sep_token_id_layoutxlm]
)
if len(list_bbs) == 0:
# When len(json_obj["words"]) == 0 (no OCR result)
list_bbs = [cls_bbs] + [sep_bbs]
else: # len(list_bbs) > 0
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
# list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] ###
# if len(list_words) < 510:
# list_words.extend(['</p>' for _ in range(510 - len(list_words))])
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
len_list_tokens = len(list_tokens)
input_ids[:len_list_tokens] = list_tokens
attention_mask[:len_list_tokens] = 1
bbox[:len_list_tokens, :] = list_bbs
lwords[:len_list_tokens] = list_words
# Normalize bbox -> 0 ~ 1
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
if self.backbone_type in ("layoutlm", "layoutxlm"):
bbox = bbox[:, [0, 1, 4, 5]]
bbox = bbox * 1000
bbox = bbox.astype(int)
else:
raise ValueError(f"Unsupported backbone_type: {self.backbone_type}")
st_indices = [
indices[0]
for indices in box_to_token_indices
if indices[0] < max_seq_length
]
are_box_first_tokens[st_indices] = True
return (
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_tokens
)
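# Sketch of the fixed-length encoding built by parser_words for one window:
#   input_ids : [CLS] t1 ... tk [SEP] [PAD] ... [PAD]   (always max_seq_length long)
#   bbox      : one 8-value quad per token, clipped to the page, normalised to 0..1,
#               then scaled to 0..1000 and reduced to (x0, y0, x1, y1) for layoutxlm
#   are_box_first_tokens flags the first sub-token of every OCR word; entity (itc)
#   labels are attached to exactly those positions in parser_entity_extraction below.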
def parser_entity_extraction(self, parse_class, box_to_token_indices, max_seq_length):
itc_labels = np.zeros(max_seq_length, dtype=int)
stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length
classes_dic = parse_class
for class_name in self.class_names:
if class_name == "others":
continue
if class_name not in classes_dic:
continue
for word_list in classes_dic[class_name]:
is_first, last_word_idx = True, -1
for word_idx in word_list:
if word_idx >= len(box_to_token_indices):
break
box2token_list = box_to_token_indices[word_idx]
for converted_word_idx in box2token_list:
if converted_word_idx >= max_seq_length:
break # out of idx
if is_first:
itc_labels[converted_word_idx] = self.class_idx_dic[
class_name
]
is_first, last_word_idx = False, converted_word_idx
else:
stc_labels[converted_word_idx] = last_word_idx
last_word_idx = converted_word_idx
return itc_labels, stc_labels
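# Example of the itc/stc encoding (token indices and the class name are hypothetical):
#   a "seller_name" value whose word tokens map to positions [5, 6, 7] yields
#     itc_labels[5] = class_idx_dic["seller_name"], stc_labels[6] = 5, stc_labels[7] = 6
#   i.e. the first token carries the class id and every later token points back to its
#   predecessor; untouched positions keep the defaults (0 and max_seq_length).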
def parser_entity_linking(self, parse_relation, itc_labels, box2token_span_map, max_seq_length):
el_labels = np.ones(max_seq_length, dtype=int) * max_seq_length
el_labels_from_key = np.ones(max_seq_length, dtype=int) * max_seq_length
relations = parse_relation
for relation in relations:
if relation[0] >= len(box2token_span_map) or relation[1] >= len(
box2token_span_map
):
continue
if (
box2token_span_map[relation[0]][0] >= max_seq_length
or box2token_span_map[relation[1]][0] >= max_seq_length
):
continue
word_from = box2token_span_map[relation[0]][0]
word_to = box2token_span_map[relation[1]][0]
# el_labels[word_to] = word_from
if el_labels[word_to] != max_seq_length and el_labels_from_key[word_to] != max_seq_length:  # both arrays are initialised with max_seq_length
continue
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
el_labels_from_key[word_to] = word_from # pair of (key-value)
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
return el_labels, el_labels_from_key
class DocumentKVUProcess(KVUProcess):
def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
self.max_window_count = max_window_count
self.pad_token_id = self.pad_token_id_layoutxlm
self.cls_token_id = self.cls_token_id_layoutxlm
self.sep_token_id = self.sep_token_id_layoutxlm
self.unk_token_id = self.unk_token_id_layoutxlm
self.tokenizer = self.tokenizer_layoutxlm
def __call__(self, img_path: str, ocr_path: str) -> list:
if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
ocr_path = "tmp.txt"
process_img(img_path, ocr_path, self.ocr_engine, export_img=False)
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
lwords = post_process_basic_ocr(lwords)
width, height = imagesize.get(img_path)
images = [Image.open(img_path).convert("RGB")]
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
output = self.preprocess(lbboxes, lwords,
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
self.max_seq_length)
return output
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
n_words = len(words)
output_dicts = {'windows': [], 'documents': []}
n_empty_windows = 0
for i in range(self.max_window_count):
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
attention_mask = np.zeros(self.max_seq_length, dtype=int)
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
if n_words == 0:
n_empty_windows += 1
output_dicts['windows'].append({
"image": feature_maps['image'],
"input_ids_layoutxlm": torch.from_numpy(input_ids),
"bbox": torch.from_numpy(bbox),
"words": [],
"attention_mask_layoutxlm": torch.from_numpy(attention_mask),
"are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
})
continue
start_word_idx = i * self.window_size
stop_word_idx = min(n_words, (i+1)*self.window_size)
if start_word_idx >= stop_word_idx:
n_empty_windows += 1
output_dicts['windows'].append(output_dicts['windows'][-1])
continue
list_word_objects = []
for bb, text in zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx]):
boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text))
list_word_objects.append({
"layoutxlm_tokens": tokens,
"boundingBox": boundingBox,
"text": text
})
(
bbox,
input_ids,
attention_mask,
are_box_first_tokens,
box_to_token_indices,
box2token_span_map,
lwords,
len_valid_tokens,
len_non_overlap_tokens,
len_list_layoutxlm_tokens
) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"])
input_ids = torch.from_numpy(input_ids)
bbox = torch.from_numpy(bbox)
attention_mask = torch.from_numpy(attention_mask)
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
return_dict = {
"image": feature_maps['image'],
"input_ids_layoutxlm": input_ids,
"bbox": bbox,
"words": lwords,
"attention_mask_layoutxlm": attention_mask,
"are_box_first_tokens": are_box_first_tokens,
}
output_dicts["windows"].append(return_dict)
attention_mask = torch.cat([o['attention_mask_layoutxlm'] for o in output_dicts["windows"]])
are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
if n_empty_windows > 0:
attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])
words = []
for o in output_dicts['windows']:
words.extend(o['words'])
return_dict = {
"attention_mask_layoutxlm": attention_mask,
"bbox": bbox,
"are_box_first_tokens": are_box_first_tokens,
"n_empty_windows": n_empty_windows,
"words": words
}
output_dicts['documents'] = return_dict
return output_dicts
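# Shape note (numbers are illustrative): with max_window_count = 3 and max_seq_length = 512
# the concatenated document-level tensors are 1536 positions long, and the slots belonging
# to the n_empty_windows trailing windows are masked out above.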

View File

@ -0,0 +1,30 @@
nptyping==1.4.2
numpy==1.20.3
opencv-python-headless==4.5.4.60
pytorch-lightning==1.5.6
omegaconf
# pillow
six
overrides==4.1.2
# transformers==4.11.3
seqeval==0.0.12
imagesize
pandas==2.0.1
xmltodict
dicttoxml
tensorboard>=2.2.0
# code-style
isort==5.9.3
black==21.9b0
# # pytorch
# --find-links https://download.pytorch.org/whl/torch_stable.html
# torch==1.9.1+cu102
# torchvision==0.10.1+cu102
# pytorch
# --find-links https://download.pytorch.org/whl/torch_stable.html
# torch==1.10.0+cu113
# torchvision==0.11.1+cu113

View File

@ -0,0 +1 @@
python anyKeyValue.py --img_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --save_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --exp_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900 --export_img 1 --mode 3 --dir_level 0

View File

@ -0,0 +1,393 @@
550 14 644 53 HÓA
654 20 745 53 ĐƠN
755 13 833 53 GIÁ
842 20 916 60 TRỊ
925 22 1001 53 GIA
1012 14 1130 54 TĂNG
208 63 363 87 CHUNGHO
752 76 815 101 (VAT
821 76 940 101 INVOICE)
1193 80 1230 109 Ký
1235 80 1286 109 hiệu
1293 83 1390 108 (Serial):
1398 81 1519 105 1C23TCV
135 130 330 171 ChungHo
339 131 436 165 Vina
600 150 662 178 (BẢN
667 152 719 177 THỂ
678 119 743 147 Ngày
724 152 788 177 HIÊN
749 119 779 142 01
786 119 854 147 tháng
794 152 847 176 CỦA
853 151 908 176 HÓA
860 119 891 142 03
897 121 949 142 năm
915 156 973 176 ĐƠN
956 120 1014 142 2023
979 154 1043 179 ĐIỆN
1049 151 1093 181 TỬ)
1193 118 1228 148 Số
1234 125 1303 151 (No.):
1309 116 1372 148 829
187 179 270 197 HEALTH
276 179 383 197 SOLUTION
664 187 795 210 (EINVOICE
802 188 908 209 DISPLAY
915 188 1032 213 VERSION)
465 240 507 264 Mã
512 241 554 264 của
560 241 626 266 CQT:
655 239 1186 266 PURIBUZANAMOIC99C9CTROINS
90 286 141 308 Đơn
147 285 172 312 vị
177 285 221 308 bán
226 285 285 313 hàng
293 286 387 312 (Seller):
395 279 495 307 CÔNG
504 280 549 307 TY
558 280 657 307 TNHH
664 279 844 307 CHUNGHO
852 279 937 307 VINA
945 279 1091 309 HEALTH
1100 280 1274 309 SOLUTION
90 326 130 348 Mã
135 322 165 349 số
170 322 223 349 thuế
229 327 281 351 (Tax
288 327 358 351 code):
363 324 383 347 0
390 324 406 348 1
411 321 605 350 08118215
89 366 132 392 Địa
138 366 176 389 chỉ
182 368 298 393 (Address):
303 362 334 389 Số
340 367 373 392 11,
378 368 401 389 lô
407 368 477 392 N04A,
484 367 526 389 khu
530 367 560 389 đô
564 367 595 392 thị
599 367 643 390 mới
648 367 702 392 Dịch
707 368 773 393 Vọng,
779 367 863 394 phường
867 367 922 392 Dịch
927 368 993 394 Vọng,
999 368 1053 394 quận
1057 363 1103 390 Cầu
1108 364 1167 394 Giấy,
1173 368 1233 390 thành
1238 364 1281 394 phố
1285 367 1318 390 Hà
1323 368 1371 393 Nội,
1377 367 1424 393 Việt
1430 368 1483 390 Nam
89 406 146 431 Điện
152 407 211 432 thoại
218 407 279 431 (Tel):
286 407 331 428 024
339 406 397 428 7300
404 406 460 428 0891
826 406 878 430 Fax:
89 440 121 467 Số
127 444 158 467 tài
165 445 236 466 khoản
245 447 348 470 (Account
355 446 413 472 No.):
421 445 616 469 700-010-446490
622 445 653 471 tại
659 445 722 471 Ngân
728 445 792 472 Hàng
799 444 893 468 Shinhan
898 449 911 467 -
916 444 960 468 Chi
965 445 1036 468 nhánh
1043 445 1078 469 Hà
1084 444 1129 472 Nội
89 486 126 513 Họ
132 486 169 510 tên
174 486 244 513 người
251 493 302 509 mua
308 486 367 514 hàng
374 488 469 514 (Buyer):
90 529 136 550 Tên
142 529 188 551 đơn
193 529 219 555 vị
226 530 342 556 (Company
349 531 428 555 name):
435 524 518 552 CÔNG
525 529 562 552 TY
570 530 650 551 TNHH
657 529 796 552 SAMSUNG
803 529 856 552 SDS
862 527 930 556 VIỆT
937 529 1004 552 NAM
89 570 129 592 Mã
136 564 165 593 số
170 565 223 592 thuế
229 571 281 595 (Tax
287 571 358 596 code):
365 569 510 592 2300680991
88 611 133 638 Địa
138 612 175 634 chỉ
182 613 297 638 (Address):
303 611 339 633 Lô
346 612 424 636 CN05,
432 611 514 639 đường
521 612 583 638 YP6,
589 612 645 634 Khu
652 611 713 640 công
719 612 803 641 nghiệp
810 611 862 635 Yên
869 612 956 640 Phong,
962 612 1000 635 Xã
1006 612 1057 635 Yên
1064 612 1152 640 Trung,
1159 611 1242 640 Huyện
1249 611 1300 635 Yên
1307 612 1394 640 Phong,
1402 612 1463 635 Tỉnh
1470 605 1518 636 Bắc
89 654 158 678 Ninh,
165 654 219 679 Việt
225 654 286 676 Nam
89 681 122 705 Số
127 684 157 705 tài
164 682 236 705 khoản
244 685 348 709 (Account
354 684 413 709 No.):
90 724 149 747 Hình
157 724 208 747 thức
216 724 282 746 thanh
289 724 341 747 toán
349 726 457 749 (Payment
464 726 563 748 method):
572 725 663 747 TM/CK
98 789 148 812 STT
163 770 212 793 Tên
218 770 281 797 hàng
287 770 341 796 hóa,
348 769 405 796 dịch
436 770 491 792 Đơn
498 770 525 798 vị
569 783 603 811 Số
610 789 684 817 lượng
747 789 802 811 Đơn
808 788 850 818 giá
979 788 1063 812 Thành
1070 782 1119 812 tiền
1214 765 1282 793 Thuế
1287 764 1344 793 suất
1393 764 1452 792 Tiền
1459 765 1518 792 thuế
94 825 153 850 (No.)
206 844 356 870 (Description)
266 813 299 836 vụ
448 845 515 868 (Unit)
454 808 507 830 tính
568 825 685 852 (Quantity)
733 826 793 848 (Unit
798 826 865 851 price)
993 824 1103 851 (Amount)
1223 807 1287 830 (VAT
1262 843 1295 869 %)
1290 809 1335 829 rate
1378 843 1542 869 (VAT Amount)
1417 807 1491 830 GTGT
116 891 128 913 1
273 890 290 913 2
472 890 488 912 3
617 889 635 912 4
790 889 807 914 5
992 889 1103 916 6=4x5
1270 889 1287 913 7
1399 889 1510 914 8=6X7
158 939 200 961 Phí
207 939 259 961 thuê
266 939 316 966 máy
323 938 361 965 lọc
159 977 219 998 nước
225 977 285 1002 nóng
292 976 345 1001 lạnh
161 1014 236 1040 Digital
244 1014 307 1035 CHP-
114 1032 127 1055 1
159 1052 267 1073 3800ST1
276 1052 312 1076 (từ
318 1051 377 1078 ngày
453 1032 507 1060 Máy
678 1035 697 1059 4
795 1031 893 1057 800.000
1074 1031 1195 1057 3,200,000
1297 1031 1353 1057 10%
1452 1031 1551 1058 320,000
158 1089 292 1111 01/02/2023
301 1083 343 1110 đến
351 1084 388 1110 hết
159 1125 304 1151 28/02/2023)
159 1173 200 1195 Phí
207 1173 258 1195 thuê
265 1173 316 1200 máy
323 1173 361 1199 lọc
158 1213 218 1234 nước
225 1212 285 1238 nóng
292 1211 344 1237 lạnh
161 1249 236 1274 Digital
243 1248 307 1270 CHP-
112 1267 129 1291 2
160 1287 267 1309 3800ST1
276 1287 312 1312 (từ
318 1287 377 1313 ngày
454 1267 506 1293 Máy
664 1266 697 1292 20
793 1265 896 1294 876,800
1062 1265 1198 1297 17,536,000
1430 1266 1552 1295 1,753,600
159 1323 292 1346 29/01/2023
301 1319 343 1345 đến
351 1319 388 1345 hết
160 1360 304 1387 28/02/2023)
160 1409 200 1431 Phí
207 1409 258 1431 thuê
265 1409 316 1436 máy
323 1408 361 1435 lọc
158 1447 218 1468 nước
225 1446 285 1473 nóng
292 1446 344 1472 lạnh
113 1503 128 1527 3
160 1484 236 1511 Digital
243 1484 307 1505 CHP-
452 1503 506 1531 Máy
795 1502 897 1529 544,000
1074 1502 1197 1529 2,176,000
1450 1502 1550 1528 217,600
160 1522 267 1544 3800ST1
276 1522 312 1546 (từ
318 1522 377 1549 ngày
162 1559 292 1581 10/02/2023
301 1555 343 1581 đến
351 1555 388 1581 hết
159 1596 304 1623 28/02/2023)
160 1645 200 1667 Phí
207 1644 259 1667 thuê
265 1645 316 1672 máy
324 1645 362 1671 lọc
159 1683 219 1706 nước
226 1683 286 1709 nóng
293 1683 345 1708 lạnh
160 1720 237 1746 Digital
245 1720 307 1742 CHP-
112 1738 129 1760 4
159 1758 268 1780 3800ST1
276 1758 313 1783 (từ
319 1758 377 1785 ngày
453 1737 508 1767 Máy
677 1737 699 1764 4
795 1737 895 1764 256,000
1077 1737 1197 1764 1,024,000
1297 1737 1354 1762 10%
1453 1737 1552 1764 102,400
158 1795 293 1817 20/02/2023
301 1791 344 1817 đến
350 1791 388 1817 hết
158 1831 304 1859 28/02/2023)
93 2004 134 2027 Giá
139 2006 169 2031 trị
173 2007 221 2026 theo
226 2006 276 2026 mức
281 2002 331 2026 thuế
337 2005 412 2026 GTGT
676 1986 747 2008 Thành
752 1982 795 2008 tiền
800 1986 859 2008 trước
865 1983 915 2008 thuế
920 1986 992 2007 GTGT
1025 1983 1074 2008 Tiền
1079 1983 1127 2008 thuế
1132 1986 1201 2008 GTGT
1234 1985 1303 2008 Thành
1308 1982 1352 2008 tiền
1357 1985 1406 2012 gồm
1411 1982 1460 2008 thuế
1466 1985 1538 2008 GTGT
718 2023 813 2047 (Amount
819 2023 889 2048 before
895 2022 952 2048 VAT)
1035 2022 1096 2046 (VAT
1100 2023 1195 2048 Amount)
1248 2023 1345 2047 (Amount
1351 2023 1459 2049 including
1465 2022 1524 2048 VAT)
92 2071 132 2093 Giá
137 2072 162 2096 trị
166 2072 202 2092 với
207 2069 253 2093 thuế
259 2072 328 2092 GTGT
334 2071 371 2092 0%
378 2072 472 2094 (Amount
478 2074 502 2093 at
507 2071 545 2093 0%
552 2070 610 2097 VAT)
93 2119 132 2141 Giá
137 2120 162 2145 trị
167 2120 202 2141 với
207 2117 253 2141 thuế
259 2121 328 2140 GTGT
335 2120 371 2140 5%
378 2121 471 2141 (Amount
477 2122 502 2141 at
507 2119 546 2141 5%
552 2119 609 2145 VAT)
94 2167 132 2189 Giá
137 2168 162 2192 trị
166 2167 202 2189 với
207 2164 253 2189 thuế
259 2169 328 2188 GTGT
335 2168 372 2188 8%
379 2168 472 2189 (Amount
478 2171 502 2189 at
507 2168 545 2189 8%
552 2167 609 2193 VAT)
92 2216 132 2237 Giá
137 2217 162 2241 trị
167 2215 202 2236 với
207 2212 253 2237 thuế
259 2216 329 2236 GTGT
337 2216 384 2236 10%
391 2217 486 2237 (Amount
491 2219 516 2236 at
523 2216 572 2237 10%
579 2214 636 2240 VAT)
891 2213 1005 2239 23,936,000
1111 2213 1215 2239 2,393,600
1435 2213 1552 2240 26,329,600
94 2263 170 2285 TỔNG
176 2263 252 2288 CỘNG
260 2264 361 2287 (GRAND
368 2263 461 2288 TOTAL)
874 2262 1007 2288 23,936,000
1097 2262 1215 2288 2,393,600
1416 2261 1549 2288 26,329,600
87 2307 119 2333 Số
125 2307 171 2332 tiền
178 2304 223 2332 viết
230 2308 289 2337 bằng
295 2308 340 2332 chữ
347 2311 443 2332 (Amount
449 2310 473 2332 in
479 2309 564 2336 words):
571 2308 616 2331 Hai
621 2310 684 2331 mươi
690 2308 731 2331 sáu
736 2308 793 2336 triệu
797 2308 828 2331 ba
833 2309 888 2331 trăm
894 2308 932 2331 hai
938 2309 1001 2331 mươi
1007 2308 1059 2331 chín
1065 2308 1133 2336 nghìn
1139 2307 1180 2331 sáu
1186 2309 1241 2331 trăm
1247 2305 1308 2336 đồng

View File

@ -0,0 +1,127 @@
import os
import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from utils.ema_callbacks import EMA
def _update_config(cfg):
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")
# set per-gpu batch size
num_devices = max(torch.cuda.device_count(), 1)  # avoid division by zero on CPU-only machines
print('No. devices:', num_devices)
for mode in ["train", "val"]:
new_batch_size = cfg[mode].batch_size // num_devices
cfg[mode].batch_size = new_batch_size
def _get_config_from_cli():
cfg_cli = OmegaConf.from_cli()
cli_keys = list(cfg_cli.keys())
for cli_key in cli_keys:
if "--" in cli_key:
cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key]
del cfg_cli[cli_key]
return cfg_cli
def get_callbacks(cfg):
callback_list = []
checkpoint_callback = ModelCheckpoint(dirpath=cfg.save_weight_dir,
filename='best_model',
save_last=True,
save_top_k=1,
save_weights_only=True,
verbose=True,
monitor='val_f1', mode='max')
checkpoint_callback.FILE_EXTENSION = ".pth"
checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model"
callback_list.append(checkpoint_callback)
if cfg.callbacks.ema.decay != -1:
ema_callback = EMA(decay=0.9999)
callback_list.append(ema_callback)
return callback_list if len(callback_list) > 1 else checkpoint_callback
def get_plugins(cfg):
plugins = []
if cfg.train.strategy.type == "ddp":
plugins.append(DDPPlugin())
return plugins
def get_loggers(cfg):
loggers = []
loggers.append(
TensorBoardLogger(
cfg.tensorboard_dir, name="", version="", default_hp_metric=False
)
)
return loggers
def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
for key, val in cfg.items():
if isinstance(val, DictConfig):
hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__")
else:
hparam_dict[parent_str + key] = str(val)
return hparam_dict
def get_specific_pl_logger(pl_loggers, logger_type):
for pl_logger in pl_loggers:
if isinstance(pl_logger, logger_type):
return pl_logger
return None
def get_class_names(dataset_root_path):
class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
class_names = (
open(class_names_file, "r", encoding="utf-8").read().strip().split("\n")
)
return class_names
def create_exp_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Experiment dir : {}'.format(save_dir))
def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Save dir : {}'.format(save_dir))
def load_checkpoint(ckpt_path, model, key_include):
assert os.path.exists(ckpt_path), f"Ckpt path at {ckpt_path} does not exist!"
state_dict = torch.load(ckpt_path, 'cpu')['state_dict']
for key in list(state_dict.keys()):
if f'.{key_include}.' not in key:
del state_dict[key]
else:
state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something.
del state_dict[key]
model.load_state_dict(state_dict, strict=True)
print(f"Load checkpoint at {ckpt_path}")
return model
def load_model_weight(net, pretrained_model_file):
pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[
"state_dict"
]
new_state_dict = {}
for k, v in pretrained_model_state_dict.items():
new_k = k
if new_k.startswith("net."):
new_k = new_k[len("net.") :]
new_state_dict[new_k] = v
net.load_state_dict(new_state_dict)

View File

@ -0,0 +1,346 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import os
import threading
from typing import Any, Dict, Iterable
import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_info
class EMA(Callback):
"""
Implements Exponential Moving Averaging (EMA).
When training a model, this callback will maintain moving averages of the trained parameters.
When evaluating, we use the moving averages copy of the trained parameters.
When saving, we save an additional set of parameters with the prefix `ema`.
Args:
decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
validate_original_weights: Validate the original weights, as opposed to the EMA weights.
every_n_steps: Apply EMA every N steps.
cpu_offload: Offload weights to CPU.
"""
def __init__(
self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False,
):
if not (0 <= decay <= 1):
raise MisconfigurationException("EMA decay value must be between 0 and 1")
self.decay = decay
self.validate_original_weights = validate_original_weights
self.every_n_steps = every_n_steps
self.cpu_offload = cpu_offload
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
device = pl_module.device if not self.cpu_offload else torch.device('cpu')
trainer.optimizers = [
EMAOptimizer(
optim,
device=device,
decay=self.decay,
every_n_steps=self.every_n_steps,
current_step=trainer.global_step,
)
for optim in trainer.optimizers
if not isinstance(optim, EMAOptimizer)
]
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._should_validate_ema_weights(trainer):
self.swap_model_weights(trainer)
def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
return not self.validate_original_weights and self._ema_initialized(trainer)
def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)
def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
for optimizer in trainer.optimizers:
assert isinstance(optimizer, EMAOptimizer)
optimizer.switch_main_parameter_weights(saving_ema_model)
@contextlib.contextmanager
def save_ema_model(self, trainer: "pl.Trainer"):
"""
Saves an EMA copy of the model + EMA optimizer states for resume.
"""
self.swap_model_weights(trainer, saving_ema_model=True)
try:
yield
finally:
self.swap_model_weights(trainer, saving_ema_model=False)
@contextlib.contextmanager
def save_original_optimizer_state(self, trainer: "pl.Trainer"):
for optimizer in trainer.optimizers:
assert isinstance(optimizer, EMAOptimizer)
optimizer.save_original_optimizer_state = True
try:
yield
finally:
for optimizer in trainer.optimizers:
optimizer.save_original_optimizer_state = False
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
checkpoint_callback = trainer.checkpoint_callback
# use the connector as NeMo calls the connector directly in the exp_manager when restoring.
connector = trainer._checkpoint_connector
ckpt_path = connector.resume_checkpoint_path
if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__:
ext = checkpoint_callback.FILE_EXTENSION
if ckpt_path.endswith(f'-EMA{ext}'):
rank_zero_info(
"loading EMA based weights. "
"The callback will treat the loaded EMA weights as the main weights"
" and create a new EMA copy when training."
)
return
ema_path = ckpt_path.replace(ext, f'-EMA{ext}')
if os.path.exists(ema_path):
ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))
checkpoint['optimizer_states'] = ema_state_dict['optimizer_states']
del ema_state_dict
rank_zero_info("EMA state has been restored.")
else:
raise MisconfigurationException(
"Unable to find the associated EMA weights when re-loading, "
f"training will start with new EMA weights. Expected them to be at: {ema_path}",
)
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
torch._foreach_mul_(ema_model_tuple, decay)
torch._foreach_add_(
ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
)
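# Worked example of one EMA step with decay = 0.9 (illustrative values):
#   ema_new = 0.9 * ema + 0.1 * weight, so ema = 1.0 and weight = 2.0 give ema_new = 1.1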
def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
if pre_sync_stream is not None:
pre_sync_stream.synchronize()
ema_update(ema_model_tuple, current_model_tuple, decay)
class EMAOptimizer(torch.optim.Optimizer):
r"""
EMAOptimizer is a wrapper for torch.optim.Optimizer that computes
Exponential Moving Average of parameters registered in the optimizer.
EMA parameters are automatically updated after every step of the optimizer
with the following formula:
ema_weight = decay * ema_weight + (1 - decay) * training_weight
To access EMA parameters, use ``swap_ema_weights()`` context manager to
perform a temporary in-place swap of regular parameters with EMA
parameters.
Notes:
- EMAOptimizer is not compatible with APEX AMP O2.
Args:
optimizer (torch.optim.Optimizer): optimizer to wrap
device (torch.device): device for EMA parameters
decay (float): decay factor
Returns:
returns an instance of torch.optim.Optimizer that computes EMA of
parameters
Example:
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)
for epoch in range(epochs):
training_loop(model, opt)
regular_eval_accuracy = evaluate(model)
with opt.swap_ema_weights():
ema_eval_accuracy = evaluate(model)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
device: torch.device,
decay: float = 0.9999,
every_n_steps: int = 1,
current_step: int = 0,
):
self.optimizer = optimizer
self.decay = decay
self.device = device
self.current_step = current_step
self.every_n_steps = every_n_steps
self.save_original_optimizer_state = False
self.first_iteration = True
self.rebuild_ema_params = True
self.stream = None
self.thread = None
self.ema_params = ()
self.in_saving_ema_model_context = False
def all_parameters(self) -> Iterable[torch.Tensor]:
return (param for group in self.param_groups for param in group['params'])
def step(self, closure=None, **kwargs):
self.join()
if self.first_iteration:
if any(p.is_cuda for p in self.all_parameters()):
self.stream = torch.cuda.Stream()
self.first_iteration = False
if self.rebuild_ema_params:
opt_params = list(self.all_parameters())
self.ema_params += tuple(
copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :]
)
self.rebuild_ema_params = False
loss = self.optimizer.step(closure)
if self._should_update_at_step():
self.update()
self.current_step += 1
return loss
def _should_update_at_step(self) -> bool:
return self.current_step % self.every_n_steps == 0
@torch.no_grad()
def update(self):
if self.stream is not None:
self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
current_model_state = tuple(
param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
)
if self.device.type == 'cuda':
ema_update(self.ema_params, current_model_state, self.decay)
if self.device.type == 'cpu':
self.thread = threading.Thread(
target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,),
)
self.thread.start()
def swap_tensors(self, tensor1, tensor2):
tmp = torch.empty_like(tensor1)
tmp.copy_(tensor1)
tensor1.copy_(tensor2)
tensor2.copy_(tmp)
def switch_main_parameter_weights(self, saving_ema_model: bool = False):
self.join()
self.in_saving_ema_model_context = saving_ema_model
for param, ema_param in zip(self.all_parameters(), self.ema_params):
self.swap_tensors(param.data, ema_param)
@contextlib.contextmanager
def swap_ema_weights(self, enabled: bool = True):
r"""
A context manager to in-place swap regular parameters with EMA
parameters.
It swaps back to the original regular parameters on context manager
exit.
Args:
enabled (bool): whether the swap should be performed
"""
if enabled:
self.switch_main_parameter_weights()
try:
yield
finally:
if enabled:
self.switch_main_parameter_weights()
def __getattr__(self, name):
return getattr(self.optimizer, name)
def join(self):
if self.stream is not None:
self.stream.synchronize()
if self.thread is not None:
self.thread.join()
def state_dict(self):
self.join()
if self.save_original_optimizer_state:
return self.optimizer.state_dict()
# if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters())
state_dict = {
'opt': self.optimizer.state_dict(),
'ema': ema_params,
'current_step': self.current_step,
'decay': self.decay,
'every_n_steps': self.every_n_steps,
}
return state_dict
def load_state_dict(self, state_dict):
self.join()
self.optimizer.load_state_dict(state_dict['opt'])
self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema']))
self.current_step = state_dict['current_step']
self.decay = state_dict['decay']
self.every_n_steps = state_dict['every_n_steps']
self.rebuild_ema_params = False
def add_param_group(self, param_group):
self.optimizer.add_param_group(param_group)
self.rebuild_ema_params = True

View File

@ -0,0 +1,68 @@
DKVU2XML = {
"Ký hiệu mẫu hóa đơn": "form_no",
"Ký hiệu hóa đơn": "serial_no",
"Số hóa đơn": "invoice_no",
"Ngày, tháng, năm lập hóa đơn": "issue_date",
"Tên người bán": "seller_name",
"Mã số thuế người bán": "seller_tax_code",
"Thuế suất": "tax_rate",
"Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount",
"Mặt hàng": "item",
"Đơn vị tính": "unit",
"Số lượng": "quantity",
"Đơn giá": "unit_price",
"Doanh số mua chưa có thuế": "amount"
}
def ap_dictionary(header: bool):
header_dictionary = {
'productname': ['description', 'paticulars', 'articledescription', 'descriptionofgood', 'itemdescription', 'product', 'productdescription', 'modelname', 'device', 'items', 'itemno'],
'modelnumber': ['serialno', 'model', 'code', 'mcode', 'simimeiserial', 'serial', 'productcode', 'product', 'imeiccid', 'articles', 'article', 'articlenumber', 'articleidmaterialcode', 'transaction', 'itemcode'],
'qty': ['quantity', 'invoicequantity']
}
key_dictionary = {
'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'],
'retailername': ['retailer', 'retailername', 'ownedoperatedby'],
'serial_number': ['serialnumber', 'serialno'],
'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2']
}
return header_dictionary if header else key_dictionary
def vat_dictionary(header: bool):
header_dictionary = {
'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'sanpham', 'tenquycachhanghoa','description', 'descriptionofgood', 'itemdescription'],
'Đơn vị tính': ['dvt', 'donvitinh'],
'Số lượng': ['soluong', 'sl','qty', 'quantity', 'invoicequantity'],
'Đơn giá': ['dongia'],
'Doanh số mua chưa có thuế': ['thanhtien', 'thanhtientruocthuegtgt', 'tienchuathue'],
# 'Số sản phẩm': ['serialno', 'model', 'mcode', 'simimeiserial', 'serial', 'sku', 'sn', 'productcode', 'product', 'particulars', 'imeiccid', 'articles', 'article', 'articleidmaterialcode', 'transaction', 'imei', 'articlenumber']
}
key_dictionary = {
'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'],
'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'],
'Số hóa đơn': ['soinvoiceno', 'invoiceno'],
'Ngày, tháng, năm lập hóa đơn': [],
'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'],
'Mã số thuế người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'],
'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'],
'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'],
# 'Ghi chú': [],
# 'Ngày': ['ngayday', 'ngay', 'day'],
# 'Tháng': ['thangmonth', 'thang', 'month'],
# 'Năm': ['namyear', 'nam', 'year']
}
# exact_dictionary = {
# 'Số hóa đơn': ['sono', 'so'],
# 'Mã số thuế người bán': ['mst'],
# 'Tên người bán': ['kyboi'],
# 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
# }
return header_dictionary if header else key_dictionary

View File

@ -0,0 +1,33 @@
import numpy as np
from pathlib import Path
from typing import Union, Tuple, List
import sys
# sys.path.append('/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/ocr-engine')
# from src.ocr import OcrEngine
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/kie-invoice/components/prediction') # TODO: ??????
import serve_model
# def load_ocr_engine() -> OcrEngine:
def load_ocr_engine() -> "OcrEngine":  # string annotation: the OcrEngine import above is commented out
print("[INFO] Loading engine...")
# engine = OcrEngine()
engine = serve_model.engine
print("[INFO] Engine loaded")
return engine
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: "OcrEngine", export_img: bool) -> None:
save_dir_or_path = Path(save_dir_or_path)
if isinstance(img, np.ndarray):
if save_dir_or_path.is_dir():
raise ValueError("numpy array input require a save path, not a save dir")
page = engine(img)
save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt")) if save_dir_or_path.is_dir() else str(save_dir_or_path)
page.write_to_file('word', save_path)
if export_img:
page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, )
def read_img(img: Union[str, np.ndarray], engine: OcrEngine):
page = engine(img)
return ' '.join([f.text for f in page.llines])

View File

@ -0,0 +1,101 @@
import os
import glob
import json
from tqdm import tqdm
def longestCommonSubsequence(text1: str, text2: str) -> int:
# https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis
dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)]
for i, c in enumerate(text1):
for j, d in enumerate(text2):
dp[i + 1][j + 1] = 1 + \
dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j])
return dp[-1][-1]
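# Quick illustrative check: longestCommonSubsequence("hoa don gtgt", "hoa don") == 7,
# so with prev_title == "hoa don" the similarity score in split_docs below is 7 / 7 == 1.0,
# i.e. the two consecutive pages are treated as part of the same document.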
def write_to_json(file_path, content):
with open(file_path, mode="w", encoding="utf8") as f:
json.dump(content, f, ensure_ascii=False)
def read_json(file_path):
with open(file_path, "r") as f:
return json.load(f)
def check_label_exists(array, target_label):
for obj in array:
if obj["label"] == target_label:
return True # Label exists in the array
return False # Label does not exist in the array
def merged_kvu_outputs(loutputs: list) -> dict:
compiled = []
for output_model in loutputs:
for field in output_model:
if field['value'] != "" and not check_label_exists(compiled, field['label']):
element = {
'label': field['label'],
'value': field['value'],
}
compiled.append(element)
elif field['label'] == 'table' and check_label_exists(compiled, "table"):
for index, obj in enumerate(compiled):
if obj['label'] == 'table' and len(field['value']) > 0:
compiled[index]['value'].append(field['value'])
return compiled
def split_docs(doc_data: list, threshold: float=0.6) -> list:
num_pages = len(doc_data)
outputs = []
kvu_content = []
doc_data = sorted(doc_data, key=lambda x: int(x['page_number']))
for data in doc_data:
page_id = int(data['page_number'])
doc_type = data['document_type']
doc_class = data['document_class']
fields = data['fields']
if page_id == 0:
prev_title = doc_type
start_page_id = page_id
prev_class = doc_class
curr_title = doc_type if doc_type != "unknown" else prev_title
curr_class = doc_class if doc_class != "unknown" else "other"
kvu_content.append(fields)
similarity_score = longestCommonSubsequence(curr_title, prev_title) / len(prev_title)
if similarity_score < threshold:
end_page_id = page_id - 1
outputs.append({
"doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title,
"start_page": start_page_id,
"end_page": end_page_id,
"content": merged_kvu_outputs(kvu_content[:-1])
})
prev_title = curr_title
prev_class = curr_class
start_page_id = page_id
kvu_content = kvu_content[-1:]
if page_id == num_pages - 1: # end_page
outputs.append({
"doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title,
"start_page": start_page_id,
"end_page": page_id,
"content": merged_kvu_outputs(kvu_content)
})
elif page_id == num_pages - 1: # end_page
outputs.append({
"doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title,
"start_page": start_page_id,
"end_page": page_id,
"content": merged_kvu_outputs(kvu_content)
})
return outputs
if __name__ == "__main__":
threshold = 0.9
json_path = "/home/sds/tuanlv/02-KVU/02-KVU_test/visualize/manulife_v2/json_outputs/HS_YCBT_No_IP_HMTD.json"
doc_data = read_json(json_path)
outputs = split_docs(doc_data, threshold)
write_to_json(os.path.join(os.path.dirname(json_path), "splited_doc.json"), outputs)

View File

@ -0,0 +1,548 @@
import os
import cv2
import json
import torch
import glob
import re
import numpy as np
from tqdm import tqdm
from pdf2image import convert_from_path
from dicttoxml import dicttoxml
from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item
from utils.kvu_dictionary import vat_dictionary, ap_dictionary
def create_dir(save_dir=''):
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
else:
print("DIR already existed.")
print('Save dir : {}'.format(save_dir))
def pdf2image(pdf_dir, save_dir):
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
print('No. pdf files:', len(pdf_files))
for file in tqdm(pdf_files):
pages = convert_from_path(file, 500)
for i, page in enumerate(pages):
page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
print('Done!!!')
def xyxy2xywh(bbox):
return [
float(bbox[0]),
float(bbox[1]),
float(bbox[2]) - float(bbox[0]),
float(bbox[3]) - float(bbox[1]),
]
def write_to_json(file_path, content):
with open(file_path, mode='w', encoding='utf8') as f:
json.dump(content, f, ensure_ascii=False)
def read_json(file_path):
with open(file_path, 'r') as f:
return json.load(f)
def read_xml(file_path):
with open(file_path, 'r') as xml_file:
return xml_file.read()
def write_to_xml(file_path, content):
with open(file_path, mode="w", encoding='utf8') as f:
f.write(content)
def write_to_xml_from_dict(file_path, content):
xml = dicttoxml(content)
xml_decode = xml.decode()
with open(file_path, mode="w") as f:
f.write(xml_decode)
def load_ocr_result(ocr_path):
with open(ocr_path, 'r') as f:
lines = f.read().splitlines()
preds = []
for line in lines:
preds.append(line.split('\t'))
return preds
def post_process_basic_ocr(lwords: list) -> list:
pp_lwords = []
for word in lwords:
pp_lwords.append(word.replace("", " "))
return pp_lwords
def read_ocr_result_from_txt(file_path: str):
'''
return list of bounding boxes, list of words
'''
with open(file_path, 'r') as f:
lines = f.read().splitlines()
boxes, words = [], []
for line in lines:
if line == "":
continue
word_info = line.split("\t")
if len(word_info) == 6:
x1, y1, x2, y2, text, _ = word_info
elif len(word_info) == 5:
x1, y1, x2, y2, text = word_info
x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
if text and text != " ":
words.append(text)
boxes.append((x1, y1, x2, y2))
return boxes, words
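# Each OCR line is expected to be tab-separated as "x1\ty1\tx2\ty2\ttext" with an optional
# sixth field that is ignored, e.g. "550\t14\t644\t53\tHÓA" (coordinates are pixel values).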
def get_colormap():
return {
'others': (0, 0, 255), # others: red
'title': (0, 255, 255), # title: yellow
'key': (255, 0, 0), # key: blue
'value': (0, 255, 0), # value: green
'header': (233, 197, 15), # header
'group': (0, 128, 128), # group
'relation': (0, 0, 255)# (128, 128, 128), # relation
}
def convert_image(image):
exif = image._getexif()
orientation = None
if exif is not None:
orientation = exif.get(0x0112)
# Convert the PIL image to OpenCV format
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# Rotate the image in OpenCV if necessary
if orientation == 3:
image = cv2.rotate(image, cv2.ROTATE_180)
elif orientation == 6:
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
elif orientation == 8:
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
image = np.asarray(image)
if len(image.shape) == 2:
image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
assert len(image.shape) == 3
return image, orientation
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
image, orientation = convert_image(image)
if orientation is not None and orientation == 6:
width, height, _ = image.shape
else:
height, width, _ = image.shape
if len(pr_class_words) > 0:
id2label = {k: labels[k] for k in range(len(labels))}
for lb, groups in enumerate(pr_class_words):
if lb == 0:
continue
for group_id, group in enumerate(groups):
for i, word_id in enumerate(group):
x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000)
cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)
if i == 0:
x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2)
else:
x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
x_center0, y_center0 = x_center1, y_center1
if len(pr_relations) > 0:
for pair in pr_relations:
xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)
x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)
return image
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
outputs = {}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] in (rel_from, rel_to):
is_rel[element['class']]['status'] = 1
is_rel[element['class']]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
return outputs
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
list_keys = list(header_key_pairs.keys())
relations = {k: [] for k in list_keys}
for pair in json:
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
for element in pair:
if element['class'] == rel_from and element['group_id'] in list_keys:
is_rel[rel_from]['status'] = 1
is_rel[rel_from]['value'] = element
if element['class'] == rel_to:
is_rel[rel_to]['status'] = 1
is_rel[rel_to]['value'] = element
if all([v['status'] == 1 for _, v in is_rel.items()]):
relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
return relations
def get_key2values_relations(key_value_pairs: dict):
triple_linkings = {}
for value_group_id, key_value_pair in key_value_pairs.items():
key_group_id = key_value_pair[0]
if key_group_id not in list(triple_linkings.keys()):
triple_linkings[key_group_id] = []
triple_linkings[key_group_id].append(value_group_id)
return triple_linkings
def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict:
word_groups = {}
id2class = {i: labels[i] for i in range(len(labels))}
for class_id, lwgroups_in_class in enumerate(class_words):
for ltokens_in_wgroup in lwgroups_in_class:
group_id = ltokens_in_wgroup[0]
ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
text_string = get_string(ltokens_to_ltexts)
word_groups[group_id] = {
'group_id': group_id,
'text': text_string,
'class': id2class[class_id],
'tokens': ltokens_in_wgroup
}
return word_groups
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
if linking_id not in list(word_groups):
for wg_id, _word_group in word_groups.items():
if linking_id in _word_group['tokens']:
return wg_id
return linking_id
def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
outputs = []
for pair in lrelations:
wg_from = verify_linking_id(word_groups, pair[0])
wg_to = verify_linking_id(word_groups, pair[1])
try:
outputs.append([word_groups[wg_from], word_groups[wg_to]])
except Exception as e:
print('Not valid pair:', wg_from, wg_to)
return outputs
def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
word_groups = merged_token_to_wordgroup(class_words, lwords, labels)
linking_pairs = matched_wordgroup_relations(word_groups, lrelations)
header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]}
header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]}
key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]}
# table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
triplet_pairs = []
single_pairs = []
table = []
# print('key2values_relations', key2values_relations)
for key_group_id, list_value_group_ids in key2values_relations.items():
if len(list_value_group_ids) == 0: continue
elif len(list_value_group_ids) == 1:
value_group_id = list_value_group_ids[0]
single_pairs.append({word_groups[key_group_id]['text']: {
'text': word_groups[value_group_id]['text'],
'id': value_group_id,
'class': "value"
}})
else:
item = []
for value_group_id in list_value_group_ids:
if value_group_id not in header_value.keys():
header_name_for_value = "non-header"
else:
header_group_id = header_value[value_group_id][0]
header_name_for_value = word_groups[header_group_id]['text']
item.append({
'text': word_groups[value_group_id]['text'],
'header': header_name_for_value,
'id': value_group_id,
'class': 'value'
})
if key_group_id not in list(header_key.keys()):
triplet_pairs.append({
word_groups[key_group_id]['text']: item
})
else:
header_group_id = header_key[key_group_id][0]
header_name_for_key = word_groups[header_group_id]['text']
item.append({
'text': word_groups[key_group_id]['text'],
'header': header_name_for_key,
'id': key_group_id,
'class': 'key'
})
table.append({key_group_id: item})
if len(table) > 0:
table = sorted(table, key=lambda x: list(x.keys())[0])
table = [v for item in table for k, v in item.items()]
outputs = {}
outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
outputs['triplet'] = triplet_pairs
outputs['table'] = table
file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
write_to_json(file_path, outputs)
return outputs
# For FI-VAT project
def get_vat_table_information(outputs):
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(vat_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
if header_name in list(item.keys()):
# item[header_name] = value['text']
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
if header_name in ("Số lượng", "Doanh số mua chưa có thuế"):
item[header_name] = '0'
else:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
item = post_process_for_item(item)
if item["Mặt hàng"] == None:
continue
table.append(item)
return table
def get_vat_information(outputs):
# VAT Information
single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
for pair in outputs['single']:
for raw_key_name, value in pair.items():
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id'],
})
for triplet in outputs['triplet']:
for key, value_list in triplet.items():
if len(value_list) == 1:
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': value_list[0]['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value_list[0]['id']
})
for pair in value_list:
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': pair['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': pair['id']
})
for table_row in outputs['table']:
for pair in table_row:
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs.keys()):
single_pairs[key_name].append({
'content': pair['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': pair['id']
})
return single_pairs
def post_process_vat_information(single_pairs):
vat_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if key_name in ("Ngày, tháng, năm lập hóa đơn"):
if len(list_potential_value) == 1:
vat_outputs[key_name] = list_potential_value[0]['content']
else:
date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
for value in list_potential_value:
date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
else:
if len(list_potential_value) == 0: continue
if key_name in ("Mã số thuế người bán"):
selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code
# tax_code_raw = selected_value['content'].replace(' ', '')
tax_code_raw = selected_value['content']
if len(tax_code_raw.replace(' ', '')) not in (10, 13):  # handle a duplicated leading number by keeping the longest space-separated chunk
tax_code_raw = tax_code_raw.split(' ')
tax_code_raw = sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0]
vat_outputs[key_name] = tax_code_raw.replace(' ', '')
else:
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lcs score
vat_outputs[key_name] = selected_value['content']
return vat_outputs
def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
vat_outputs = {}
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = get_vat_table_information(outputs)
# VAT Information
single_pairs = get_vat_information(outputs)
vat_outputs = post_process_vat_information(single_pairs)
# Combine VAT information and table
vat_outputs['table'] = table
write_to_json(file_path, vat_outputs)
return vat_outputs
# For SBT project
def get_ap_table_information(outputs):
table = []
for single_item in outputs['table']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
for cell in single_item:
header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
# print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
if header_name in list(item.keys()):
item[header_name].append({
'content': cell['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': cell['id']
})
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
table.append(item)
return table
def get_ap_triplet_information(outputs):
triplet_pairs = []
for single_item in outputs['triplet']:
item = {k: [] for k in list(ap_dictionary(header=True).keys())}
is_item_valid = 0
for key_name, list_value in single_item.items():
for value in list_value:
if value['header'] == "non-header":
continue
header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
if header_name in list(item.keys()):
is_item_valid = 1
item[header_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
if is_item_valid == 1:
for header_name, value in item.items():
if len(value) == 0:
item[header_name] = None
continue
item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lcs score
item['productname'] = key_name
# triplet_pairs.append({key_name: new_item})
triplet_pairs.append(item)
return triplet_pairs
def get_ap_information(outputs):
single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
for pair in outputs['single']:
for key_name, value in pair.items():
key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
# print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
if key_name in list(single_pairs):
single_pairs[key_name].append({
'content': value['text'],
'processed_key_name': proceessed_text,
'lcs_score': score,
'token_id': value['id']
})
ap_outputs = {k: None for k in list(single_pairs)}
for key_name, list_potential_value in single_pairs.items():
if len(list_potential_value) == 0: continue
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lcs score
ap_outputs[key_name] = selected_value['content']
return ap_outputs
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
# List of items in table
table = get_ap_table_information(outputs)
triplet_pairs = get_ap_triplet_information(outputs)
table = table + triplet_pairs
ap_outputs = get_ap_information(outputs)
ap_outputs['table'] = table
# ap_outputs['triplet'] = triplet_pairs
write_to_json(file_path, ap_outputs)
return ap_outputs

View File

@ -0,0 +1,224 @@
import nltk
import re
import string
import copy
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML
nltk.download('words')
words = set(nltk.corpus.words.words())
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
# def clean_text(text):
# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text)
def get_string(lwords: list):
unique_list = []
for item in lwords:
if item.isdigit() and len(item) == 1:
unique_list.append(item)
elif item not in unique_list:
unique_list.append(item)
return ' '.join(unique_list)
def remove_english_words(text):
_word = [w.lower() for w in nltk.wordpunct_tokenize(text) if w.lower() not in words]
return ' '.join(_word)
def remove_punctuation(text):
return text.translate(str.maketrans(" ", " ", string.punctuation))
def remove_accents(input_str, s0, s1):
s = ''
# print input_str.encode('utf-8')
for c in input_str:
if c in s1:
s += s0[s1.index(c)]
else:
s += c
return s
def remove_spaces(text):
return text.replace(' ', '')
def preprocessing(text: str):
# text = remove_english_words(text) if table else text
text = remove_punctuation(text)
text = remove_accents(text, s0, s1)
text = remove_spaces(text)
return text.lower()
def vat_standardize_outputs(vat_outputs: dict) -> dict:
outputs = {}
for key, value in vat_outputs.items():
if key != "table":
outputs[DKVU2XML[key]] = value
else:
list_items = []
for item in value:
list_items.append({
DKVU2XML[item_key]: item_value for item_key, item_value in item.items()
})
outputs['table'] = list_items
return outputs
def vat_standardizer(text: str, threshold: float, header: bool):
dictionary = vat_dictionary(header)
processed_text = preprocessing(text)
for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'), ('thangmonth', 'thang', 'month'), ('namyear', 'nam', 'year')]:
if any([processed_text in txt for txt in candidates]):
processed_text = candidates[-1]
return "Ngày, tháng, năm lập hóa đơn", 5, processed_text
_dictionary = copy.deepcopy(dictionary)
if not header:
exact_dictionary = {
'Số hóa đơn': ['sono', 'so'],
'Mã số thuế người bán': ['mst'],
'Tên người bán': ['kyboi'],
'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
}
for k, v in exact_dictionary.items():
_dictionary[k] = dictionary[k] + exact_dictionary[k]
for k, v in dictionary.items():
# if k in ("Ngày, tháng, năm lập hóa đơn"):
# continue
# Prioritize match completely
if k == 'Tên người bán' and processed_text == "kyboi":
return k, 8, processed_text
if any([processed_text == key for key in _dictionary[k]]):
return k, 10, processed_text
scores = {k: 0.0 for k in dictionary}
for k, v in dictionary.items():
if k in ("Ngày, tháng, năm lập hóa đơn"):
continue
scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]])
key, score = max(scores.items(), key=lambda x: x[1])
return key if score > threshold else text, score, processed_text
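# Illustrative example (hedged; assumes 'Tên người bán' is a key of vat_dictionary(False),
# as the exact-match table above implies):
#   vat_standardizer("Ký bởi", threshold=0.8, header=False)
# preprocessing gives "kyboi", which hits the exact-match branch and should return
#   ("Tên người bán", 8, "kyboi")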
def ap_standardizer(text: str, threshold: float, header: bool):
dictionary = ap_dictionary(header)
processed_text = preprocessing(text)
# Prioritize match completely
_dictionary = copy.deepcopy(dictionary)
if not header:
_dictionary['serial_number'] = dictionary['serial_number'] + ['sn']
_dictionary['imei_number'] = dictionary['imei_number'] + ['imel']
else:
_dictionary['modelnumber'] = dictionary['modelnumber'] + ['sku', 'sn', 'imei']
_dictionary['qty'] = dictionary['qty'] + ['qty']
for k, v in dictionary.items():
if any([processed_text == key for key in _dictionary[k]]):
return k, 10, processed_text
scores = {k: 0.0 for k in dictionary}
for k, v in dictionary.items():
scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]])
key, score = max(scores.items(), key=lambda x: x[1])
return key if score >= threshold else text, score, processed_text
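# Illustrative example (hedged; assumes 'serial_number' is a key of ap_dictionary(False)):
#   ap_standardizer("S/N", threshold=0.8, header=False)
# preprocessing strips the punctuation and gives "sn", which is added to the exact-match
# list for 'serial_number' above, so the call should return ("serial_number", 10, "sn")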
def convert_format_number(s: str) -> float:
s = s.replace(' ', '').replace('O', '0').replace('o', '0')
if s.endswith(",00") or s.endswith(".00"):
s = s[:-3]
if all([delimiter in s for delimiter in [',', '.']]):
s = s.replace('.', '').split(',')
remain_value = s[1].split('0')[0]
return int(s[0]) + int(remain_value) * 1 / (10**len(remain_value))
else:
s = s.replace(',', '').replace('.', '')
return int(s)
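# Worked examples for convert_format_number (the sample strings are illustrative):
#   "1.234,56" -> thousands dot dropped, comma split -> 1234 + 56/100 = 1234.56
#   "1,234.00" -> trailing ".00" dropped, separators removed -> 1234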
def post_process_for_item(item: dict) -> dict:
check_keys = ['Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế']
mis_key = []
for key in check_keys:
if item[key] in (None, '0'):
mis_key.append(key)
if len(mis_key) == 1:
try:
if mis_key[0] == check_keys[0] and convert_format_number(item[check_keys[1]]) != 0:
item[mis_key[0]] = round(convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[1]])).__str__()
elif mis_key[0] == check_keys[1] and convert_format_number(item[check_keys[0]]) != 0:
item[mis_key[0]] = (convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[0]])).__str__()
elif mis_key[0] == check_keys[2]:
item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
except Exception as e:
print("Cannot post process this item with error:", e)
return item
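# Worked example (values are illustrative): if an item arrives as
#   {'Số lượng': None, 'Đơn giá': '1.000', 'Doanh số mua chưa có thuế': '3.000'}
# only the quantity is missing, so it is recovered as
#   round(convert_format_number('3.000') / convert_format_number('1.000')) = round(3000 / 1000) = 3 -> '3'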
def longestCommonSubsequence(text1: str, text2: str) -> int:
# https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis
dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)]
for i, c in enumerate(text1):
for j, d in enumerate(text2):
dp[i + 1][j + 1] = 1 + \
dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j])
return dp[-1][-1]
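# Quick check of the DP above (classic example): longestCommonSubsequence("abcde", "ace") == 3,
# i.e. the subsequence "ace". The standardizers above divide this length by len(key)
# to get the matching score that is compared against the threshold.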
def longest_common_subsequence_with_idx(X, Y):
"""
This implementation uses dynamic programming to calculate the length of the LCS, and uses a path array to keep track of the characters in the LCS.
The longest_common_subsequence function takes two strings as input, and returns a tuple with three values:
the length of the LCS,
the index of the first character of the LCS in the first string,
and the index of the last character of the LCS in the first string.
"""
m, n = len(X), len(Y)
L = [[0 for i in range(n + 1)] for j in range(m + 1)]
# Following steps build L[m+1][n+1] in bottom up fashion. Note
# that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1]
right_idx = 0
max_lcs = 0
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
L[i][j] = 0
elif X[i - 1] == Y[j - 1]:
L[i][j] = L[i - 1][j - 1] + 1
if L[i][j] > max_lcs:
max_lcs = L[i][j]
right_idx = i
else:
L[i][j] = max(L[i - 1][j], L[i][j - 1])
# Store the length of the LCS (at this point i == m and j == n, so L[i][j] == L[m][n])
lcs = L[i][j]
# Start from the right-most-bottom-most corner and
# one by one store characters in lcs[]
i = m
j = n
# right_idx = 0
while i > 0 and j > 0:
# If current character in X[] and Y are same, then
# current character is part of LCS
if X[i - 1] == Y[j - 1]:
i -= 1
j -= 1
# If not same, then find the larger of two and
# go in the direction of larger value
elif L[i - 1][j] > L[i][j - 1]:
# right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs
i -= 1
else:
j -= 1
return lcs, i, max(i + lcs, right_idx)
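# Illustrative result (hand-traced, so treat as a sketch): with X = "abcde", Y = "ace"
# the function returns (3, 0, 5): an LCS of length 3 whose span in X starts at index 0
# and whose right boundary is index 5.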

View File

@ -0,0 +1,172 @@
from mmdet.apis import inference_detector, init_detector
import cv2
import numpy as np
import urllib
def get_center(box):
xmin, ymin, xmax, ymax = box
x_center = int((xmin + xmax) / 2)
y_center = int((ymin + ymax) / 2)
return [x_center, y_center]
def cal_euclidean_dist(p1, p2):
return np.linalg.norm(p1 - p2)
def bbox_to_four_poinst(bbox):
"""convert one bouding box to 4 corner poinst
Args:
bbox (_type_): _description_
"""
xmin, ymin, xmax, ymax = bbox
poinst = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
return poinst
def find_closest_point(src_point, point_list):
"""
Args:
point (list): point format xy
point_list (list[list]): list of point xy
"""
point_list = np.array(point_list)
dist_list = np.array(
cal_euclidean_dist(src_point, target_point) for target_point in point_list
)
index_closest_point = np.argmin(dist_list)
return index_closest_point
def crop_align_card(img_src, corner_box_list):
"""Dewarp image based on four courners
Args:
corner_list (list): four points of corners
"""
img = img_src.copy()
if isinstance(corner_box_list[0], list):
poinst = [get_center(box) for box in corner_box_list]
else:
# print(corner_box_list)
xmin, ymin, xmax, ymax = corner_box_list
poinst = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
return dewarp(img, poinst)
def dewarp(image, poinst):
if isinstance(poinst, list):
poinst = np.array(poinst, dtype="float32")
(tl, tr, br, bl) = poinst
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
maxWidth = max(int(widthA), int(widthB))
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
maxHeight = max(int(heightA), int(heightB))
dst = np.array(
[[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
dtype="float32",
)
M = cv2.getPerspectiveTransform(poinst, dst)
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
return warped
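# Minimal usage sketch (assumes `image` is a BGR np.ndarray and the points are ordered
# top-left, top-right, bottom-right, bottom-left, as dewarp expects; the coordinates are made up):
#   points = [[10, 20], [410, 25], [420, 300], [5, 295]]
#   card = dewarp(image, points)   # perspective-corrected crop of the quadrilateral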
class MdetPredictor:
def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
self.model = init_detector(config, checkpoint, device=device)
self.class_names = self.model.CLASSES
def infer(self, image, threshold=0.2):
bbox_result = inference_detector(self.model, image)
bboxes = np.vstack(bbox_result)
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
res_bboxes = []
res_labels = []
for idx, box in enumerate(bboxes):
score = box[-1]
if score >= threshold:
label = labels[idx]
res_bboxes.append(box.tolist()[:4])
res_labels.append(self.class_names[label])
return res_bboxes, res_labels
class ImageTransformer:
def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
self.corner_detect_model = MdetPredictor(config, checkpoint, device)
def __call__(self, image, threshold=0.2):
"""
Args:
image (np.ndarray): BGR image
"""
corner_result = self.corner_detect_model.infer(image)
corners_dict = self.__extract_corners(corner_result)
card_image = self.__crop_image_based_on_corners(image, corners_dict)
return card_image
def __extract_corners(self, corner_result):
bboxes, labels = corner_result
# convert bbox to int
bboxes = [[int(x) for x in box] for box in bboxes]
output = {k: bboxes[labels.index(k)] for k in labels}
# print(output)
return output
def __crop_image_based_on_corners(self, image, corners_dict):
"""
Args:
corners_dict (_type_): _description_
"""
if "card" in corners_dict.keys():
if len(corners_dict.keys()) == 5:
points = [
corners_dict["top_left"],
corners_dict["top_right"],
corners_dict["bottom_right"],
corners_dict["bottom_left"],
]
else:
points = corners_dict["card"]
card_image = crop_align_card(image, points)
else:
card_image = None
return card_image
def crop_location(image_url):
transform_module = ImageTransformer(
config="./models/Kie_AHung/yolox_s_8x8_300e_idcard5_coco.py",
checkpoint="./models/Kie_AHung/best_bbox_mAP_epoch_100.pth",
device="cuda:0",
)
req = urllib.request.urlopen(image_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
card_image = transform_module(img)
if card_image is not None:
return card_image
else:
return img

View File

@ -0,0 +1,622 @@
{
"20221027_154840.json": {
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
"pred": "ngày date 05 tháng month 04 năm/year 2016"
},
"7ba0f6b2f2ff34a16dee18.json": {
"label": "ngày /date 01 tháng /month 04 năm /year 2022",
"pred": "ngày date 01 tháng Amount 04 năm year 2022"
},
"20221027_155646.json": {
"label": "ngày /date 01 tháng /month 04 năm/year✪2022",
"pred": "ngày date or tháng month 04 năm/year2022"
},
"799c679b63d6a588fcc717.json": {
"label": "ngày /date 30 tháng /month 07 năm/year 2015",
"pred": "ngày lone 30 thể 2 work 07 ndmyvar 2015"
},
"150094748_1728953773944481_6269983404281027305_n.json": {
"label": "ngày/da e 15✪tháng # #th 04 năm /year 2020",
"pred": "ngày/da - 15tháng % with 04 năm (year 2020"
},
"20221027_155711.json": {
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
"pred": "ngày date 16 tháng month 06 năm 'year 2020"
},
"20221027_155638.json": {
"label": "ngày /date 05 tháng /month 04 năm/year✪2016",
"pred": "ngày /date 05 tháng month 04 năm/year2016"
},
"20221027_155754.json": {
"label": "ngày/date 19 tháng /month 03 năm/year 2018",
"pred": "ngày/date 19 tháng month 03 năm/year 2018"
},
"201417393_4052045311588809_501501345369021923_n.json": {
"label": "ngày /date 27 tháng /month 07 năm/year 2020",
"pred": "ngày date 27 tháng month 07 năm/year 2020"
},
"c50de8e0e2ad24f37dbc28.json": {
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
"pred": "ngày date 24 tháng month 05 năm/year 2016"
},
"178033599_360318769019204_7688552975060615249_n.json": {
"label": "ngày /date 13 tháng month 08 năm /year 2020",
"pred": "ngày date 73 tháng month 08 năm year 2020"
},
"20221027_154755.json": {
"label": "ngày /date 22 tháng /month 04 năm/year 2022",
"pred": "ngày date 22 tháng month 04 năm/year 2022"
},
"20221027_154701.json": {
"label": "ngày/date 27 tháng /month 10 năm/year 2014",
"pred": "ngày/date 27 tháng month 10 năm/year 2014"
},
"20221027_154617.json": {
"label": "ngày/date 10 tháng /month 03 năm/year 2022",
"pred": "ngày/date 10 than Wmonth 03 năm/year 2022"
},
"20221027_155429.json": {
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
"pred": "ngày date 29 tháng month 10 năm/year 2020"
},
"38949066_1826046370807755_4672633932229902336_n.json": {
"label": "ngày /date 03 tháng /month 07 năm /year 2017",
"pred": "ngày date 03 tháng month 0Z năm year 2017"
},
"174353602_3093780914182491_6316781286680210887_n.json": {
"label": "ngày /date 09 tháng /month 09 năm/year 2019",
"pred": "ngày date 09 tháng month 09 năm/voce 2019"
},
"135575662_779633739578177_65454671165731184_n.json": {
"label": "ngày /date 02 tháng /month 10 năm /year 2016",
"pred": "ngày 'date 07 tháng month 10 năm Sear 2016"
},
"198291731_4210067705720978_7154894655460708366_n.json": {
"label": "ngày /date 05 tháng month 05 năm/year 2014",
"pred": "ngày date 05 tháng month 05 ndmivar 2014"
},
"20221027_155325.json": {
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
"pred": "ngày date 09 tháng month 07 năm 'year 2019"
},
"20221027_155526.json": {
"label": "ngày/date 14 tháng /month 01 năm/year✪2019",
"pred": "ngày/date 14 tháng month 01 năm/year2019"
},
"20221027_155759.json": {
"label": "ngày/date 24 tháng /month 05 năm/year 2016",
"pred": "ngày/date 24 tháng month 05 năm/year 2016"
},
"f40789388d754b2b126416.json": {
"label": "ngày /date 10 tháng/month 09 năm /year 2020",
"pred": "ngày /date 10thing month 09 năm year 2020"
},
"20221027_155453.json": {
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
"pred": "ngày date 27 tháng month 10 năm/year 2014"
},
"88bb0caa03e7c5b99cf64.json": {
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
"pred": "ngày dg h 6AR ag M ath 06 năm hear 2020"
},
"20221027_154408.json": {
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
"pred": "ngày date 28 tháng month 05 năm year 2019"
},
"200064855_1976213642526776_4676665588314498194_n.json": {
"label": "ngày /date 04 tháng /month 01 năm /year 2018",
"pred": "ngày date 04 tháng /month 01 năm year 2018"
},
"61c259385775912bc86410.json": {
"label": "ngày /date 20 tháng/month 09 năm/year 2017",
"pred": "ngày /de l 20 th and 4 onth 09 năm/year 2017"
},
"20221027_155630.json": {
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
"pred": "ngày date 10 tháng month 09 năm year 2020"
},
"20221027_155342.json": {
"label": "ngày/date 19 tháng /month 03 năm/year 2018",
"pred": "ngày/date 19 tháng month 03 năm/year 2018"
},
"165824171_1804443136392377_8891768953420682785_n.json": {
"label": "ngày /date 03 tháng month 04 năm /year 2018",
"pred": "ngày date 03 tháng month 04 năm year 2018"
},
"107005164_1003706713393058_3039921363490738151_n.json": {
"label": "ngày /date 1 8 tháng month 11 năm /year 2019",
"pred": "ngày 'date I 8 tháng month 11 năm vear 2019"
},
"20221027_155658.json": {
"label": "ngày/date 14 tháng /month 01 năm/year 2019",
"pred": "ngày/date 14 tháng month 01 năm/year 2019"
},
"20221027_154654.json": {
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
"pred": "ngày date 08 tháng month 12 năm/year 2015"
},
"f2badabfd1f217ac4ee327.json": {
"label": "ngày /date 19 tháng /month 03 năm/year 2018",
"pred": "ngày /date 19 their almonth 03 năm/year 2018"
},
"20221027_155501.json": {
"label": "ngày/date 11 tháng /month 06 năm/year 2014",
"pred": "ngày/date 1 l tháng month 06 năm/year. 2014"
},
"74179950_806804509749947_7322741604127604736_n.json": {
"label": "ngày /date 11 tháng /month 06 năm /year 2019",
"pred": "ngày date 11 tháng month 06 năm \\/gar 2019"
},
"197892118_1197184050744529_3186157591590303981_n.json": {
"label": "ngày /date 15 tháng /month 05 năm year 2015",
"pred": "ngày date IS tháng month 05 năm year 2015"
},
"8265ab8ba0c666983fd722.json": {
"label": "ngày /date 08 tháng /month 12 năm /year 2015",
"pred": "ngày /date 08 tháng month 12 năm year 2015"
},
"185421881_360318465685901_6968676669094190049_n.json": {
"label": "ngày /date 13 tháng month 08 năm /year 2020",
"pred": "ngày date 73 tháng month 08 năm year 2020"
},
"20221027_155142.json": {
"label": "ngày/date 30 tháng/month 07 năm/year 2015",
"pred": "ngày/date 30 tháng/month 07 năm/year 2015"
},
"39441751_326131871285503_8401816317220356096_n.json": {
"label": "ngày /date 03 tháng /month 08 năm/year✪2017",
"pred": "ngày 'date 03 tháng month 08 năm✪year✪2017"
},
"20221027_154912.json": {
"label": "ngày/date 30 tháng /month 07 năm/year 2015",
"pred": "ngày/date 30 tháng month 07 năm/year 2015"
},
"199235658_1994072707411372_221206969024509405_n.json": {
"label": "ngày /date 05 tháng month 06 năm year 2020",
"pred": "ngày date 05 tháng month 06 năm year 2020"
},
"168434217_1698404580364474_5439600436729489777_n.json": {
"label": "ngày /date 05 tháng /month 11 năm/year 2020",
"pred": "ngày date 05 tháng month 11 ndmhear 2020"
},
"20221027_154805.json": {
"label": "ngày /date 14 tháng /month 01 năm/year 2019",
"pred": "ngày date 14 tháng month 01 năm/year 2019"
},
"e043761b7256b408ed4712.json": {
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
"pred": "ngày date 27 tháng month 10 năm/year 2014"
},
"20221027_155401.json": {
"label": "ngày /date 29 tháng /month 10 năm/year✪2013",
"pred": "ngày date 29 tháng /uponth 10 năm/year2013"
},
"20221027_155316.json": {
"label": "ngày/date 20 tháng /month 09 năm/year 2017",
"pred": "ngày/date 20 tháng month 09 năm/year 2017"
},
"193926583_1626674417674644_309549447428666454_n.json": {
"label": "ngày /date 29 tháng /month 07 năm/year 2020",
"pred": "ngày (date 29 tháng /month 07 năm✪year 2020"
},
"20221027_154823.json": {
"label": "ngày /date 01 tháng /month 04 năm /year 2022",
"pred": "ngày date or tháng month 04 năm year 2022"
},
"138942425_242694630705291_5683978028617807264_n.json": {
"label": "ngày /date 30 tháng/month 01 năm/year✪2013",
"pred": "ngày date 30 tháng/month 01 năm/year2013"
},
"20221027_155334.json": {
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
"pred": "ngày date 24 tháng month 05 năm/year 2016"
},
"187421917_1668044206721318_779901369147309116_n.json": {
"label": "ngày date 04 # # 05 # 2022",
"pred": "ngày date 4 them Please 05 advise 2021"
},
"20221027_154716.json": {
"label": "ngày /date 11 tháng /month 06 năm/year 2014",
"pred": "ngày date 11 tháng month 06 năm/year 2014"
},
"5b1116c61d8bdbd5829a23.json": {
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
"pred": "ngày date 29 tháng /month 10 năm/year 2020"
},
"20221027_154850.json": {
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
"pred": "ngày date 10 tháng month 09 năm year 2020"
},
"1489864e88034e5d17122.json": {
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
"pred": "ngày date 08 hàng month 12 năm/year 2015"
},
"199990418_1443812262638998_8173300652488821384_n.json": {
"label": "ngày/date 10✪tháng /month 03 năm/year 2021",
"pred": "ngày/date 10.50m Noun th 03 ndm/year 2021"
},
"139073668_833869180522062_7998364448555134241_n.json": {
"label": "ngày/date 26 tháng /month 07 năm/year 2017",
"pred": "ngày/date 26 tháng /month 07 năm/year2 2017"
},
"20221027_154528.json": {
"label": "ngày date 19 tháng /month 03 năm/year 2018",
"pred": "ngày date 19 tháng month 03 năm/year 2018"
},
"20221027_154423.json": {
"label": "ngày/date 20 tháng /month 09 năm/year 2017",
"pred": "ngày/date 20 tháng month 09 năm/yew 2017"
},
"20221027_154722.json": {
"label": "ngày /date 11 tháng /month 06 năm/year 2014",
"pred": "ngày date 1 I tháng month 06 năm/year 2014"
},
"144003628_4199201026774245_6202264670231940239_n.json": {
"label": "ngày date 07 tháng month 03 ######## ### #",
"pred": "ngày date0 thôn Normmm 05 adortear 7"
},
"20221027_154434.json": {
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
"pred": "ngày date 09 tháng month 07 năm 'year 2019"
},
"20221027_154630.json": {
"label": "ngày/date 29 tháng /month 10 năm/year 2020",
"pred": "ngày/date 29 tháng month 10 năm/year 2020"
},
"20221027_155552.json": {
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
"pred": "ngày date 10 tháng month 09 năm 'year 2020"
},
"131114177_1027132027767399_411142190418396877_n.json": {
"label": "ngày /date 17 tháng /month 01 năm/year✪2018",
"pred": "ngày date 17 tháng /month 01 năm/year2018"
},
"164703297_455738728964651_5260332814645460915_n.json": {
"label": "ngày date 03 tháng month 11 năm/year 2020",
"pred": "ngày date 03 tháng month I I năm/year 2020"
},
"20221027_155926.json": {
"label": "ngày/date 20 tháng /month 09 năm /year 2017",
"pred": "ngày/date 20 tháng month 09 năm year 2017"
},
"20221027_154832.json": {
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
"pred": "ngày date 05 tháng month 04 năm/year 2016 6"
},
"dd09b9b6b2fb74a52dea24.json": {
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
"pred": "123) d 2010 Way Yount 06 năm year 2020"
},
"20221027_154646.json": {
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
"pred": "ngày date 08 2 tháng month 12 năm/year2 2015"
},
"180534342_1213803569050037_4381710158357942629_n.json": {
"label": "ngày /date 20 tháng /month 10 năm/year 21",
"pred": "ngày date 20 tháng month 10 năm/year 2"
},
"20221027_155443.json": {
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
"pred": "ngày date 08 tháng month 12 năm/year 2015"
},
"25a9717c7b31bd6fe42029.json": {
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
"pred": "ngày date 09 tháng month 07 năm hear 2019"
},
"a0a8281f2652e00cb9435.json": {
"label": "ngày/date 10 tháng /month 03 năm/year 2022",
"pred": "ngày/date 10 that abroath 03 năm/year 2022"
},
"48793dfd37b0f1eea8a130.json": {
"label": "tháng/month 09 năm/year\n ngày/date 20 tháng/month\n 163",
"pred": "ngày/date 20 chăn knowin năm/ya\n 163"
},
"20221027_154730.json": {
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
"pred": "ngày date to tháng month 06 năm year 2020"
},
"174102242_893537741194123_1381062036549019974_n.json": {
"label": "ngày /date 11 tháng /month 11 năm/year 2019",
"pred": "ngày date 17 tháng Thuonth 11 năm/year 2019"
},
"20221027_154541.json": {
"label": "ngày /date 29 tháng /month 10 năm/year✪2013",
"pred": "ngày date 29 tháng month 10 năm/year2013"
},
"20221027_155939.json": {
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
"pred": "ngày date 28 tháng month 05 năm 'year 2019"
},
"20221027_154452.json": {
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
"pred": "ngày date 24 tháng month 05 năm/year 2016"
},
"104353445_990772771353119_6131582365614146594_n.json": {
"label": "ngày date 18 tháng /month 03 năm /year 2019",
"pred": "ngày date If tháng month 03 năm year 2019"
},
"20221027_155418.json": {
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
"pred": "ngày/date 10than dmonth 03 năm/year 2022"
},
"20221027_155916.json": {
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
"pred": "ngày date 09 tháng month 07 năm 'year 2019"
},
"195887607_545276056640128_7265052621888807786_n.json": {
"label": "ngày /date 12 tháng /month 01 năm /year 2017",
"pred": "ngày /dote 12 tháng month 01 năm hear 201 7"
},
"168303942_358282189193092_4968412916165104911_n.json": {
"label": "ngày/date 19 tháng month 03 năm year 2015",
"pred": "ngày/date 19 tháng month 03 năm your 1019"
},
"6a70d15bd51613484a0713.json": {
"label": "ngày /date 22 tháng /month 04 năm /year 2022",
"pred": "ngày date n háu Z with 04 năm year 2022"
},
"20221027_155302.json": {
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
"pred": "ngày date 28 tháng month 05 năm 'year 2019"
},
"20221027_154815.json": {
"label": "ngày /date 01 tháng /month 04 năm/year 2022",
"pred": "ngày date or tháng month 04 năm/year 2022"
},
"c87b81298e64483a11757.json": {
"label": "ngày /date 19 tháng /month 03 năm/year 201",
"pred": "ngày date 19 thớ almonth 03 năm/year 201"
},
"20221027_155620.json": {
"label": "ngày/date 30 tháng /month 07 năm/year 2015",
"pred": "ngày/date 30 tháng month 07 năm/year 2015"
},
"20221027_155511.json": {
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
"pred": "ngày date to tháng month 06 năm 'year 2020"
},
"745c5d6b52269478cd379.json": {
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
"pred": "ngày date 09 thân g worth 07 năm hear 2019"
},
"9329511f5552930cca4315.json": {
"label": "ngày/date 11 tháng /month 06 năm/year 2014",
"pred": "ngày/date 11 tháng month 06 năm/year 2014"
},
"158882925_262850065433814_5526034984745996835_n.json": {
"label": "ngày /date 16 tháng month 12 năm/year 2015",
"pred": "ngày date 16 tháng month 12 năm/year 2015"
},
"140687263_3755683421155059_7637736837539526203_n.json": {
"label": "năm/year 2017\n ngày /date # # # # 08 năm/year",
"pred": "năm 2017\n ngày 'date 422 ins 1.1 ma 08 năm"
},
"20221027_155717.json": {
"label": "ngày/date 11 tháng /month 06 năm/year 2014",
"pred": "ngày/date 11 tháng month 06 năm/year 2014"
},
"148919455_2877898732481724_2579276238538203411_n.json": {
"label": "ngày /date 05 tháng /month 09 năm/year✪2018",
"pred": "ngày date 05 thán g /month 09 năm/year2018"
},
"c8b0dc9cd8d11e8f47c014.json": {
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
"pred": "ngày /date 0.5 tháng g/month 04 năm/year 2016"
},
"175913976_2827333254262221_2873818403698028020_n.json": {
"label": "ngày date 05 tháng month 01 năm year 2018",
"pred": "ngày dan us thông month 01 năm year 2018"
},
"20221027_155029.json": {
"label": "ngày/date 30 tháng/month\n 07 năm/year 2015",
"pred": "ngày/date\n năm/year 2015"
},
"196165776_1160925321042008_58817602967276351_n.json": {
"label": "ngày date 01 tháng month 09 năm year 2020",
"pred": "ngày date or tháng month 09 năm year 2020"
},
"162820484_451115505943597_8326385834717580925_n.json": {
"label": "ngày /date 27 tháng /month 08 thường 2014",
"pred": "ngày/ date 27 tháng furonth 08 năm/year 2014"
},
"41594446_1316256461838413_661251515624718336_n.json": {
"label": "ngày /date 21 tháng /month 06 năm/year✪2017",
"pred": "ngày /date 27 tháng month 06 năm/year201 7"
},
"20221027_155728.json": {
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
"pred": "ngày date 08 tháng month 12 năm/year 2015"
},
"142090201_2826170774286852_1233962294093312865_n.json": {
"label": "ngày /date 29 tháng month 07 năm /year 2019",
"pred": "ngày date 29 tháng month 07 năm year 2019"
},
"0dd2fe8bf5c633986ad726.json": {
"label": "ngày /date 29 tháng /month 10 năm/year✪2013",
"pred": "march date 29 tháng month 10✪năm✪volur✪2013"
},
"190919701_1913643422140446_6855763478065892825_n.json": {
"label": "ngày date 24 tháng /month 07 năm/year 2017",
"pred": "ngày /date 24 tháng month 07 năm/year 2017"
},
"147585615_2757487791230682_5515346433540820516_n.json": {
"label": "ngày /date # 9✪tháng /month 05 năm /year 2019",
"pred": "ngày date 2 Danate month 05 năm year 2019"
},
"5fcae09ee4d3228d7bc211.json": {
"label": "ngày /date 14 tháng /month 01 năm/year 2019",
"pred": "ngày 'date 14 tháng month 0.1 năm/year 2019"
},
"20221027_154417.json": {
"label": "ngày/date 20 tháng /month 09 năm/year 2017",
"pred": "ngày/date 20 thân ghi onth 09 năm/year 2017"
},
"20221027_154802.json": {
"label": "ngày /date 14 tháng /month 01 năm/year 2019",
"pred": "ngày date 14 tháng month 01 năm/year 2019"
},
"20221027_154749.json": {
"label": "ngày /date 22 tháng /month 04 năm /year 2022",
"pred": "ngày date 22 tháng month 04 năm year 2022"
},
"20221027_154900.json": {
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
"pred": "ngày date 10 tháng month 09 năm Year 2020"
},
"4b2b35453e08f856a11925.json": {
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
"pred": "ngày/date 1 in 03 năm/year 2022"
},
"20221027_155734.json": {
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
"pred": "ngày date 29 tháng month 10 năm/year 2020"
},
"198876248_2931797967062357_4287721016641237281_n.json": {
"label": "ngày /date 29 áng /month 10 năm /year 2019",
"pred": "ngày Warm 204 đón Quanth 10 năm vear 2019"
},
"20221027_154613.json": {
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
"pred": "ngày/date 10than Imonth 03 năm/year 2022"
},
"20221027_154600.json": {
"label": "ngày/date 29 tháng /month 10 năm/year✪2013",
"pred": "ngày/date 29 tháng month 10 năm/year2013"
},
"191389634_910736173104716_4923402486196996972_n.json": {
"label": "ngày /date 24 tháng /month 02 năm/year 2021",
"pred": "ngày date 24 tháng month 02 năm/year 2021"
},
"20221027_155723.json": {
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
"pred": "ngày /date 27 tháng month 10 năm/year 2014"
},
"184606042_1586323798376373_2179113485447088932_n.json": {
"label": "ngày /date 29 tháng /month 07 năm/year 2020",
"pred": "ngày (date 29 tháng /month 07 năm✪year 2020"
},
"6fb460d06a9dacc3f58c31.json": {
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
"pred": "ngày date 28 tháng month 05 năm year 2019"
},
"7e810ce502a8c4f69db93.json": {
"label": "ngày /date 29 tháng /month 10 năm /year 2020",
"pred": "ngày de it 29 thing month 10 năm Year 2020"
},
"20221027_154400.json": {
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
"pred": "ngày date 28 tháng month 05 năm year 2019"
},
"139579296_107198731349013_7325819456715999953_n.json": {
"label": "ngày /date 10tháng month 05 năm /year 2018",
"pred": "ngày date 1 month 05 năm hear 2018"
},
"20221027_155748.json": {
"label": "ngày/date 29 tháng/month 10 năm/year✪2013",
"pred": "ngày/date 29 tháng/month 10 năm/year2013"
},
"164359233_2788848161366629_6843431895499380423_n.json": {
"label": "ngày /date 25 tháng /month 12 năm/year 2015",
"pred": "ngày dute 25 tháng month 12 năm/year 2015"
},
"962650ff5eb298ecc1a36.json": {
"label": "ngày /date 29 tháng/month 10 năm/year 2013",
"pred": "10 năm/year 2013\n ngày the 29thanginonth\n TL"
},
"20221027_155600.json": {
"label": "ngày/date 30 tháng /month 07 năm/year 2015",
"pred": "ngày/date 30 tháng month 07 năm/year 2015"
},
"951970367f7bb925e06a1.json": {
"label": "ngày /date 28 tháng /month 05 năm /year 2019 -",
"pred": "ngày late 28 hàng worth 05 năm hear 2019"
},
"20221027_154627.json": {
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
"pred": "ngày date 29 tháng month 10 năm/year. 2020"
},
"20221027_155535.json": {
"label": "ngày /date 01 tháng /month 04 năm /year 2022",
"pred": "ngày date 01 tháng month 04 năm 'year 2022"
},
"106402928_1000018507095212_5438034148254460378_n.json": {
"label": "ngày date 04 tháng month 09 năm /year 2019",
"pred": "ngày dan 04 tháng month 09 năm Year 2019"
},
"20221027_154458.json": {
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
"pred": "ngày date 24 tháng month 03 năm/year 2016"
},
"190841312_3057702594458479_8551202571498845435_n.json": {
"label": "ngày /date 14 tháng month 12 năm /year 2018",
"pred": "ngày date 14 tháng month 12 năm year 2018"
},
"20221027_154907.json": {
"label": "ngày/date 30 tháng /month 07 năm/year✪2015",
"pred": "ngày/date 30 tháng month 07 năm/year2015"
},
"136098726_3413968628702123_4090292519699106839_n.json": {
"label": "ngày /date 20tháng /month 11 năm /year 2020",
"pred": "ngày date 70tháng month 11 năm year 2020"
},
"20221027_154440.json": {
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
"pred": "ngày date 09 tháng mc each 07 năm year 2019"
},
"07966c1b6256a408fd478.json": {
"label": "ngày/date 24 tháng /month 05 năm/year 2016",
"pred": "ngày/d te 24 thán day anth 05 năm/year 2016"
},
"20221027_155740.json": {
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
"pred": "ngày/date 10than almonth 03 năm/year 2022"
},
"130727988_1377512982599930_6481917606912865462_n.json": {
"label": "ngày/date 30 tháng /month 05 năm/year 2016",
"pred": "ngày/date 30 tháng month 05 năm/year 2016"
},
"194993073_539716147195165_8525378287933192246_n.json": {
"label": "ngày /date 12 tháng /month 06 năm /year 2020",
"pred": "ngày date 12 tháng /month 06 năm vear 2020"
},
"20221027_154521.json": {
"label": "ngày/date 19 tháng /month 03 năm/year 2018",
"pred": "ngày/dec 19 tháng month 03 năm/year 2018"
},
"174247511_900878714088455_7516565117828455890_n.json": {
"label": "ngày/date 1 0✪tháng /month 05 năm/year 2018",
"pred": "ngày/date 4 Other month 05 năm/year 2018"
},
"20221027_155519.json": {
"label": "ngày /date 22 tháng /month 04 năm /year 2022",
"pred": "ngày date 22 tháng month 04 năm year 2022"
},
"20221027_154706.json": {
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
"pred": "ngày date 27 tháng month 10 năm/year. 2014"
},
"a88ae66aed272b79723621.json": {
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
"pred": ""
},
"4378298-95583fbb1edb703a6c5bbc1744246058-1-1.json": {
"label": "ngày/d 24 tháng /month 1 năm/year 2015",
"pred": "ngày/d / 24 than chuonth 1 1 năm/year 2015"
},
"20221027_155704.json": {
"label": "ngày /date 22 tháng /month 04 năm/year 2022",
"pred": "ngày date 22 tháng month 04 năm hear 2022"
},
"179642128_1945301335636182_5557211235870766646_n.json": {
"label": "ngày/date #✪thá# # onth 10 năm/year 201",
"pred": "ngowdan (Ký - touth 10 năm/year 201"
},
"20221027_154734.json": {
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
"pred": "ngày date to tháng month 06 năm year 2020"
},
"20221027_155542.json": {
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
"pred": "ngày /date 05 tháng month 04 năm/year 2016"
}
}

196
cope2n-ai-fi/common/json2xml.py Executable file
View File

@ -0,0 +1,196 @@
import xml.etree.ElementTree as ET
from datetime import datetime
ET.register_namespace('', "http://www.w3.org/2000/09/xmldsig#")
xml_template3 = """
<HDon>
<DLHDon>
<TTChung>
<PBan>None</PBan>
<THDon>None</THDon>
<KHMSHDon>None</KHMSHDon>
<KHHDon>None</KHHDon>
<SHDon>None</SHDon>
<NLap>None</NLap>
<DVTTe>None</DVTTe>
<TGia>None</TGia>
<HTTToan>None</HTTToan>
<MSTTCGP>None</MSTTCGP>
</TTChung>
<NDHDon>
<NBan>
<Ten>None</Ten>
<MST>None</MST>
<DChi>None</DChi>
<SDThoai>None</SDThoai>
</NBan>
<NMua>
<Ten>None</Ten>
<MST>None</MST>
<DChi>None</DChi>
<SDThoai>None</SDThoai>
<HVTNMHang>None</HVTNMHang>
</NMua>
<DSHHDVu>
<HHDVu>
<TChat>None</TChat>
<STT>None</STT>
<THHDVu>None</THHDVu>
<DVTinh>None</DVTinh>
<SLuong>None</SLuong>
<DGia>None</DGia>
<TLCKhau>None</TLCKhau>
<STCKhau>None</STCKhau>
<ThTien>None</ThTien>
<TSuat>None</TSuat>
</HHDVu>
</DSHHDVu>
<TToan>
<THTTLTSuat>
<LTSuat>
<TSuat>None</TSuat>
<ThTien>None</ThTien>
<TThue>None</TThue>
</LTSuat>
</THTTLTSuat>
<TgTCThue>None</TgTCThue>
<TgTThue>None</TgTThue>
<TTCKTMai>None</TTCKTMai>
<TgTTTBSo>None</TgTTTBSo>
<TgTTTBChu>None</TgTTTBChu>
</TToan>
</NDHDon>
</DLHDon>
<MCCQT Id="None">None</MCCQT>
<DLQRCode>None</DLQRCode>
</HDon>
"""
def replace_xml_values(xml_str, replacement_dict):
""" replace xml values
Args:
xml_str (_type_): _description_
replacement_dict (_type_): _description_
Returns:
_type_: _description_
"""
try:
root = ET.fromstring(xml_str)
for key, value in replacement_dict.items():
if not value:
continue
if key == "TToan":
ttoan_element = root.find(".//TToan")
tsuat_element = ttoan_element.find(".//TgTThue")
tthue_element = ttoan_element.find(".//TgTTTBSo")
tthuebchu_element = ttoan_element.find(".//TgTTTBChu")
if value["TgTThue"]:
tsuat_element.text = value["TgTThue"]
if value["TgTTTBSo"]:
tthue_element.text = value["TgTTTBSo"]
if value["TgTTTBChu"]:
tthuebchu_element.text = value["TgTTTBChu"]
elif key == "NMua":
nmua_element = root.find(".//NMua")
for key_ in ["DChi", "SDThoai", "MST", "Ten", "HVTNMHang"]:
if value.get(key_, None):
nmua_element_key = nmua_element.find(f".//{key_}")
nmua_element_key.text = value[key_]
elif key == "NBan":
nban_element = root.find(".//NBan")
for key_ in ["DChi", "SDThoai", "MST", "Ten"]:
if value.get(key_, None):
nban_element_key = nban_element.find(f".//{key_}")
nban_element_key.text = value[key_]
elif key == "HHDVu":
dshhdvu_element = root.find(".//DSHHDVu")
hhdvu_template = root.find(".//HHDVu")
if hhdvu_template is not None and dshhdvu_element is not None:
dshhdvu_element.remove(hhdvu_template) # Remove the template
for hhdvu_data in value:
hhdvu_element = ET.SubElement(dshhdvu_element, "HHDVu")
for h_key, h_value in hhdvu_data.items():
h_element = ET.SubElement(hhdvu_element, h_key)
h_element.text = h_value if h_value is not None else "None"
elif key == "NLap":
nlap_element = root.find(".//NLap")
if nlap_element is not None:
# Convert the date to yyyy-mm-dd format
try:
date_obj = datetime.strptime(value, "%d/%m/%Y")
formatted_date = date_obj.strftime("%Y-%m-%d")
nlap_element.text = formatted_date
except ValueError:
print(f"Invalid date format for {key}: {value}")
nlap_element.text = value
else:
element = root.find(f".//{key}")
if element is not None:
element.text = value
ET.register_namespace("", "http://www.w3.org/2000/09/xmldsig#")
return ET.tostring(root, encoding="unicode")
except ET.ParseError as e:
print(f"Error parsing XML: {e}")
return None
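# Illustrative call (the sample values are made up): replacing the issue date converts it
# to ISO format inside the <NLap> tag:
#   replace_xml_values(xml_template3, {"NLap": "27/10/2014", "SHDon": "123"})
# yields an XML string whose <NLap> is "2014-10-27" and <SHDon> is "123".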
def convert_key_names(original_dict):
"""_summary_
Args:
original_dict (_type_): _description_
Returns:
_type_: _description_
"""
key_mapping = {
"table": "HHDVu",
"Mặt hàng": "THHDVu",
"Đơn vị tính": "DVTinh",
"Số lượng": "SLuong",
"Đơn giá": "DGia",
"Doanh số mua chưa có thuế": "ThTien",
"buyer_address_value": "NMua.DChi",
'buyer_company_name_value': 'NMua.Ten',
'buyer_personal_name_value': 'NMua.HVTNMHang',
'buyer_tax_code_value': 'NMua.MST',
'buyer_tel_value': 'NMua.SDThoai',
'seller_address_value': 'NBan.DChi',
'seller_company_name_value': 'NBan.Ten',
'seller_tax_code_value': 'NBan.MST',
'seller_tel_value': 'NBan.SDThoai',
'date_value': 'NLap',
'form_value': 'KHMSHDon',
'no_value': 'SHDon',
'serial_value': 'KHHDon',
'tax_amount_value': 'TToan.TgTThue',
'total_in_words_value': 'TToan.TgTTTBChu',
'total_value': 'TToan.TgTTTBSo'
}
converted_dict = {}
for key, value in original_dict.items():
new_key = key_mapping.get(key, key)
if "." in new_key:
parts = new_key.split(".")
current_dict = converted_dict
for i, part in enumerate(parts):
if i == len(parts) - 1:
current_dict[part] = value
else:
current_dict.setdefault(part, {})
current_dict = current_dict[part]
else:
if key == "table":
lconverted_table_values = []
for table_value in value:
converted_table_value = convert_key_names(table_value)
lconverted_table_values.append(converted_table_value)
converted_dict[new_key] = lconverted_table_values
else:
converted_dict[new_key] = value
return converted_dict
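# Worked example (sample values are made up):
#   convert_key_names({"buyer_address_value": "Hanoi", "no_value": "123"})
# returns {"NMua": {"DChi": "Hanoi"}, "SHDon": "123"}, i.e. dotted targets such as
# "NMua.DChi" are expanded into the nested structure that replace_xml_values expects.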

37
cope2n-ai-fi/common/ocr.py Executable file
View File

@ -0,0 +1,37 @@
from common.utils.ocr_yolox import OcrEngineForYoloX_ID_Driving
from common.utils.word_formation import Word, words_to_lines
det_ckpt = "yolox-s-general-text-pretrain-20221226"
cls_ckpt = "satrn-lite-general-pretrain-20230106"
engine = OcrEngineForYoloX_ID_Driving(det_ckpt, cls_ckpt)
def ocr_predict(image):
"""Predict text from image
Args:
image_path (str): _description_
Returns:
list: list of words
"""
try:
lbboxes, lwords = engine.run_image(image)
lWords = [Word(text=word, bndbox=bbox) for word, bbox in zip(lwords, lbboxes)]
list_lines, _ = words_to_lines(lWords)
return list_lines
except AssertionError as e:
print(e)
list_lines = []
return list_lines
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, required=True)
args = parser.parse_args()
list_lines = ocr_predict(args.image)

Some files were not shown because too many files have changed in this diff