Add everything
This commit is contained in:
parent
4e83776907
commit
7e9a8e2d4b
4
.gitignore
vendored
4
.gitignore
vendored
@ -11,4 +11,6 @@ backup/
|
||||
*.sqlite3
|
||||
*.log
|
||||
__pycache__
|
||||
migrations/
|
||||
migrations/
|
||||
test/
|
||||
._git/
|
||||
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "cope2n-ai-fi/modules/sdsvkvu"]
|
||||
path = cope2n-ai-fi/modules/sdsvkvu
|
||||
url = https://code.sdsdev.co.kr/tuanlv/sdsvkvu
|
6
cope2n-ai-fi/._gitmodules
Normal file
6
cope2n-ai-fi/._gitmodules
Normal file
@ -0,0 +1,6 @@
|
||||
[submodule "modules/sdsvkvu"]
|
||||
path = modules/sdsvkvu
|
||||
url = https://code.sdsdev.co.kr/tuanlv/sdsvkvu.git
|
||||
[submodule "modules/ocr_engine"]
|
||||
path = modules/ocr_engine
|
||||
url = https://code.sdsdev.co.kr/tuanlv/IDP-BasicOCR.git
|
7
cope2n-ai-fi/.dockerignore
Executable file
7
cope2n-ai-fi/.dockerignore
Executable file
@ -0,0 +1,7 @@
|
||||
.github
|
||||
.git
|
||||
.vscode
|
||||
__pycache__
|
||||
DataBase/image_temp/
|
||||
DataBase/json_temp/
|
||||
DataBase/template.db
|
21
cope2n-ai-fi/.gitignore
vendored
Executable file
21
cope2n-ai-fi/.gitignore
vendored
Executable file
@ -0,0 +1,21 @@
|
||||
.vscode
|
||||
__pycache__
|
||||
DataBase/image_temp/
|
||||
DataBase/json_temp/
|
||||
DataBase/template.db
|
||||
sdsvtd/
|
||||
sdsvtr/
|
||||
sdsvkie/
|
||||
detectron2/
|
||||
output/
|
||||
data/
|
||||
models/
|
||||
server/
|
||||
image_logs/
|
||||
experiments/
|
||||
weights/
|
||||
packages/
|
||||
tmp_results/
|
||||
.env
|
||||
.zip
|
||||
.json
|
43
cope2n-ai-fi/Dockerfile
Executable file
43
cope2n-ai-fi/Dockerfile
Executable file
@ -0,0 +1,43 @@
|
||||
# FROM thucpd2408/env-cope2n:v1
|
||||
FROM thucpd2408/env-deskew
|
||||
|
||||
COPY ./packages/cudnn-linux*.tar.xz /tmp/cudnn-linux*.tar.xz
|
||||
|
||||
RUN tar -xvf /tmp/cudnn-linux*.tar.xz -C /tmp/ \
|
||||
&& cp /tmp/cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include \
|
||||
&& cp -P /tmp/cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64 \
|
||||
&& chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn* \
|
||||
&& rm -rf /tmp/cudnn-*-archive
|
||||
|
||||
RUN apt-get update && apt-get install -y gcc g++ ffmpeg libsm6 libxext6
|
||||
# RUN pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
COPY ./modules/ocr_engine/externals/ /workspace/cope2n-ai-fi/modules/ocr_engine/externals/
|
||||
COPY ./modules/ocr_engine/requirements.txt /workspace/cope2n-ai-fi/modules/ocr_engine/requirements.txt
|
||||
COPY ./modules/sdsvkie/ /workspace/cope2n-ai-fi/modules/sdsvkie/
|
||||
COPY ./modules/sdsvkvu/ /workspace/cope2n-ai-fi/modules/sdsvkvu/
|
||||
COPY ./requirements.txt /workspace/cope2n-ai-fi/requirements.txt
|
||||
|
||||
RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp && pip3 install -v -e .
|
||||
RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtd && pip3 install -v -e .
|
||||
RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtr && pip3 install -v -e .
|
||||
|
||||
# RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/ && pip3 install -r requirements.txt
|
||||
RUN cd /workspace/cope2n-ai-fi/modules/sdsvkie && pip3 install -v -e .
|
||||
RUN cd /workspace/cope2n-ai-fi/modules/sdsvkvu && pip3 install -v -e .
|
||||
RUN cd /workspace/cope2n-ai-fi && pip3 install -r requirements.txt
|
||||
|
||||
RUN rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublasLt.so.11 && \
|
||||
rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublas.so.11 && \
|
||||
rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libnvblas.so.11 && \
|
||||
ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublasLt.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublasLt.so.11 && \
|
||||
ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublas.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublas.so.11 && \
|
||||
ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvblas.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libnvblas.so.11
|
||||
|
||||
ENV PYTHONPATH="."
|
||||
|
||||
CMD [ "sh", "run.sh"]
|
||||
# CMD ["tail -f > /dev/null"]
|
21
cope2n-ai-fi/Dockerfile-dev
Executable file
21
cope2n-ai-fi/Dockerfile-dev
Executable file
@ -0,0 +1,21 @@
|
||||
FROM thucpd2408/env-cope2n:v1
|
||||
|
||||
RUN apt-get update && apt-get install -y gcc g++ ffmpeg libsm6 libxext6
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY ./requirements.txt /workspace/cope2n-ai-fi/requirements.txt
|
||||
COPY ./sdsvkie/ /workspace/cope2n-ai-fi/sdsvkie/
|
||||
COPY ./sdsvtd /workspace/cope2n-ai-fi/sdsvtd/
|
||||
COPY ./sdsvtr/ /workspace/cope2n-ai-fi/sdsvtr/
|
||||
COPY ./models/ /models
|
||||
|
||||
RUN cd /workspace/cope2n-ai-fi && pip3 install -r requirements.txt
|
||||
RUN cd /workspace/cope2n-ai-fi/sdsvkie && pip3 install -v -e .
|
||||
RUN cd /workspace/cope2n-ai-fi/sdsvtd && pip3 install -v -e .
|
||||
RUN cd /workspace/cope2n-ai-fi/sdsvtr && pip3 install -v -e .
|
||||
|
||||
ENV PYTHONPATH="."
|
||||
|
||||
CMD [ "sh", "run.sh"]
|
||||
# CMD ["tail -f > /dev/null"]
|
8
cope2n-ai-fi/Dockerfile_fwd
Normal file
8
cope2n-ai-fi/Dockerfile_fwd
Normal file
@ -0,0 +1,8 @@
|
||||
FROM hisiter/fwd_env:1.0.0
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl-dev \
|
||||
libxml2-dev \
|
||||
libxslt-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
|
||||
RUN pip install paddleocr>=2.0.1
|
674
cope2n-ai-fi/LICENSE
Executable file
674
cope2n-ai-fi/LICENSE
Executable file
@ -0,0 +1,674 @@
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
7
cope2n-ai-fi/NOTE.md
Executable file
7
cope2n-ai-fi/NOTE.md
Executable file
@ -0,0 +1,7 @@
|
||||
#### Environment
|
||||
```
|
||||
docker run -itd --privileged --name=TannedCungnoCope2n-ai-fi \
|
||||
-v /mnt/hdd2T/dxtan/TannedCung/OCR/cope2n-ai-fi:/workspace \
|
||||
tannedcung/mmocr:latest \
|
||||
tail -f > /dev/null
|
||||
```
|
36
cope2n-ai-fi/README.md
Executable file
36
cope2n-ai-fi/README.md
Executable file
@ -0,0 +1,36 @@
|
||||
# AI-core
|
||||
|
||||
## Add your files
|
||||
|
||||
- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
|
||||
- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
|
||||
|
||||
```bash
|
||||
cd existing_repo
|
||||
git remote add origin http://code.sdsrv.vn/c-ope2n/ai-core.git
|
||||
git branch -M main
|
||||
git push -uf origin main
|
||||
```
|
||||
|
||||
## Develop
|
||||
|
||||
Assume you are at root folder with struct:
|
||||
|
||||
```bash
|
||||
.
|
||||
├── cope2n-ai-fi
|
||||
├── cope2n-api
|
||||
├── cope2n-fe
|
||||
├── .env
|
||||
└── docker-compose-dev.yml
|
||||
```
|
||||
|
||||
Run: `docker-compose -f docker-compose-dev.yml up --build -d` to bring the project alive
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Variable | Default | Usage | Note
|
||||
------------- | ------- | ----- | ----
|
||||
CELERY_BROKER | $250 | | |
|
||||
SAP_KIE_MODEL | $80 | | |
|
||||
FI_KIE_MODEL | $420 | | |
|
26
cope2n-ai-fi/TODO.md
Normal file
26
cope2n-ai-fi/TODO.md
Normal file
@ -0,0 +1,26 @@
|
||||
## Bring abs path to relative
|
||||
|
||||
- [x] save_dir `Kie_Invoice_AP/prediction_sap.py:18`
|
||||
- [x] detector `Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml` [_fixed_](#refactor)
|
||||
- [x] rotator_version `Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml` [_fixed_](#refactor)
|
||||
- [x] cfg `Kie_Invoice_AP/prediction_fi.py`
|
||||
- [x] weight `Kie_Invoice_AP/prediction_fi.py`
|
||||
- [x] save_dir `Kie_Invoice_AP/prediction_fi.py:18`
|
||||
|
||||
## Bring abs path to .env
|
||||
|
||||
- [x] CELERY_BROKER:
|
||||
- [ ] SAP_KIE_MODEL: `Kie_Invoice_AP/prediction_sap.py:20` [_NEED_REFACTOR_](#refactor)
|
||||
- [ ] FI_KIE_MODEL: `Kie_Invoice_AP/prediction_fi.py:20` [_NEED_REFACTOR_](#refactor)
|
||||
|
||||
## Possible logic conflict
|
||||
|
||||
### Refactor
|
||||
|
||||
- [ ] Each model should be loaded in a docker container and serve as a service
|
||||
- [ ] Some files (weights, ...) should be mounted in container in a format for endurability
|
||||
- [ ] `Kie_Invoice_AP/prediction_sap.py` and `Kie_Invoice_AP/prediction_fi.py` should be merged into a single file as they share resources with different logic
|
||||
- [ ] `Kie_Invoice_AP/prediction.py` seems to be the base function, this should act as a proxy which import all other `predict_{anything else}` functions
|
||||
- [ ] There should be a unique folder to keep all models with different versions then mount as /models in container. Currently, `fi` is loading from `/models/Kie_invoice_fi` while `sap` is loading from `Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20231003-171748`. Another model weight is at `sdsvtd/hub` for unknown reason
|
||||
- [ ] Env variables should have its description in README
|
||||
- [ ]
|
149
cope2n-ai-fi/api/Kie_AHung/prediction.py
Executable file
149
cope2n-ai-fi/api/Kie_AHung/prediction.py
Executable file
@ -0,0 +1,149 @@
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
from transformers import (
|
||||
LayoutXLMTokenizer,
|
||||
LayoutLMv2FeatureExtractor,
|
||||
LayoutXLMProcessor,
|
||||
LayoutLMv2ForTokenClassification,
|
||||
)
|
||||
|
||||
from common.utils.word_formation import *
|
||||
|
||||
from common.utils.global_variables import *
|
||||
from common.utils.process_label import *
|
||||
import ssl
|
||||
|
||||
ssl._create_default_https_context = ssl._create_unverified_context
|
||||
os.environ["CURL_CA_BUNDLE"] = ""
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
# config
|
||||
IGNORE_KIE_LABEL = "others"
|
||||
KIE_LABELS = [
|
||||
"Number",
|
||||
"Name",
|
||||
"Birthday",
|
||||
"Home Town",
|
||||
"Address",
|
||||
"Sex",
|
||||
"Nationality",
|
||||
"Expiry Date",
|
||||
"Nation",
|
||||
"Religion",
|
||||
"Date Range",
|
||||
"Issued By",
|
||||
IGNORE_KIE_LABEL,
|
||||
"Rank"
|
||||
]
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
|
||||
|
||||
# tokenizer = LayoutXLMTokenizer.from_pretrained(
|
||||
# "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
|
||||
# )
|
||||
|
||||
# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
|
||||
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)
|
||||
|
||||
model = LayoutLMv2ForTokenClassification.from_pretrained(
|
||||
"Kie_AHung/model/driver_license", num_labels=len(KIE_LABELS), local_files_only=True
|
||||
).to(
|
||||
DEVICE
|
||||
) # TODO FIX this hard code
|
||||
|
||||
|
||||
def load_ocr_labels(list_lines):
    """Flatten the OCR line/word-group hierarchy into parallel lists.

    Args:
        list_lines: iterable of OCR line objects; each exposes
            ``list_word_groups``, whose items expose ``list_words`` with
            ``text`` and a 4-element ``boundingbox``.

    Returns:
        tuple: (words, boxes, labels) where each box is [x1, y1, x2, y2]
        as floats and every label is the same placeholder string.
    """
    words, boxes, labels = [], [], []
    for ocr_line in list_lines:
        for group in ocr_line.list_word_groups:
            for token in group.list_words:
                # Skip bare single-space tokens emitted by the OCR engine.
                if token.text == " ":
                    continue
                x1, y1, x2, y2 = (float(c) for c in token.boundingbox[:4])
                words.append(token.text)
                boxes.append([x1, y1, x2, y2])
                labels.append("seller_name_value")  # TODO ??? fix this
    return words, boxes, labels
|
||||
|
||||
|
||||
def _normalize_box(box, width, height):
|
||||
return [
|
||||
int(1000 * (box[0] / width)),
|
||||
int(1000 * (box[1] / height)),
|
||||
int(1000 * (box[2] / width)),
|
||||
int(1000 * (box[3] / height)),
|
||||
]
|
||||
|
||||
|
||||
def infer_driving_license(image_crop, list_lines, max_n_words, processor):
    """Run token-classification KIE over a cropped driver-license image.

    OCR words are fed to the model in chunks of ``max_n_words`` so each
    chunk fits within the 512-token sequence limit used below.

    Args:
        image_crop: BGR image array (OpenCV convention) of the cropped document.
        list_lines: OCR result lines (shape expected by load_ocr_labels).
        max_n_words (int): maximum number of words per model call.
        processor: LayoutXLM processor (tokenizer + feature extractor).

    Returns:
        tuple: (batch_words, batch_preds, batch_true_boxes, list_words) —
        predictions and boxes accumulated over all chunks as numpy arrays,
        plus Word objects tagged with predicted KIE_LABELS entries.
    """
    # Load inputs
    # image = Image.open(image_path)
    image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)  # model expects RGB
    image = Image.fromarray(image)
    batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
    batch_preds, batch_true_boxes = [], []
    list_words = []
    for i in range(0, len(batch_words), max_n_words):
        words = batch_words[i : i + max_n_words]
        boxes = batch_boxes[i : i + max_n_words]
        boxes_norm = [
            _normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes
        ]

        # Preprocess
        # Dummy labels only exist so the processor emits a "labels" tensor,
        # later used to locate subword/padding positions (marked -100).
        dummy_word_labels = [0] * len(words)
        encoding = processor(
            image,
            text=words,
            boxes=boxes_norm,
            word_labels=dummy_word_labels,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        # Run model
        for k, v in encoding.items():
            encoding[k] = v.to(DEVICE)
        outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1).squeeze().tolist()

        # Postprocess
        # Positions labeled -100 are subword continuations / padding; dropping
        # them leaves (at most) one prediction per input word.
        is_subword = (
            (encoding["labels"] == -100).detach().cpu().numpy()[0]
        )  # remove padding
        true_predictions = [
            pred for idx, pred in enumerate(predictions) if not is_subword[idx]
        ]
        true_boxes = (
            boxes  # TODO check assumption that layourlm do not change box order
        )

        # NOTE(review): this inner loop reuses `i`, shadowing the chunk offset
        # above; harmless only because the outer `for` rebinds `i` from its
        # range iterator each iteration.  Also assumes len(true_predictions)
        # >= len(words) — if truncation at 512 tokens drops words this would
        # raise IndexError; confirm max_n_words keeps chunks under the limit.
        for i, word in enumerate(words):
            bndbox = [int(j) for j in true_boxes[i]]
            list_words.append(
                Word(
                    text=word, bndbox=bndbox, kie_label=KIE_LABELS[true_predictions[i]]
                )
            )

        batch_preds.extend(true_predictions)
        batch_true_boxes.extend(true_boxes)

    batch_preds = np.array(batch_preds)
    batch_true_boxes = np.array(batch_true_boxes)
    return batch_words, batch_preds, batch_true_boxes, list_words
|
145
cope2n-ai-fi/api/Kie_AHung_ID/prediction.py
Executable file
145
cope2n-ai-fi/api/Kie_AHung_ID/prediction.py
Executable file
@ -0,0 +1,145 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from transformers import LayoutLMv2ForTokenClassification
|
||||
|
||||
from common.utils.word_formation import *
|
||||
|
||||
from common.utils.global_variables import *
|
||||
from common.utils.process_label import *
|
||||
import ssl
|
||||
|
||||
ssl._create_default_https_context = ssl._create_unverified_context
|
||||
os.environ["CURL_CA_BUNDLE"] = ""
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
# config
|
||||
IGNORE_KIE_LABEL = "others"
|
||||
KIE_LABELS = [
|
||||
"Number",
|
||||
"Name",
|
||||
"Birthday",
|
||||
"Home Town",
|
||||
"Address",
|
||||
"Sex",
|
||||
"Nationality",
|
||||
"Expiry Date",
|
||||
"Nation",
|
||||
"Religion",
|
||||
"Date Range",
|
||||
"Issued By",
|
||||
IGNORE_KIE_LABEL
|
||||
]
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
|
||||
|
||||
# tokenizer = LayoutXLMTokenizer.from_pretrained(
|
||||
# "Kie_AHung_ID/model/pretrained/layoutxlm-base/tokenizer",
|
||||
# model_max_length=MAX_SEQ_LENGTH,
|
||||
# )
|
||||
|
||||
# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
|
||||
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)
|
||||
|
||||
model = LayoutLMv2ForTokenClassification.from_pretrained(
|
||||
"Kie_AHung_ID/model/ID_CARD_145_train_300_val_0.02_char_0.06_word",
|
||||
num_labels=len(KIE_LABELS),
|
||||
local_files_only=True,
|
||||
).to(
|
||||
DEVICE
|
||||
) # TODO FIX this hard code
|
||||
|
||||
|
||||
def load_ocr_labels(list_lines):
    """Collect every non-blank OCR word with its box and a placeholder label.

    Walks the line -> word_group -> word hierarchy and returns three
    parallel lists.

    Returns:
        tuple: (words, boxes, labels); boxes are [x1, y1, x2, y2] floats.
    """
    words, boxes, labels = [], [], []
    flattened = (
        word
        for line in list_lines
        for word_group in line.list_word_groups
        for word in word_group.list_words
    )
    for word in flattened:
        # Drop bare single-space tokens.
        if word.text == " ":
            continue
        words.append(word.text)
        boxes.append([float(v) for v in word.boundingbox[:4]])
        labels.append("seller_name_value")  # TODO ??? fix this
    return words, boxes, labels
|
||||
|
||||
|
||||
def _normalize_box(box, width, height):
|
||||
return [
|
||||
int(1000 * (box[0] / width)),
|
||||
int(1000 * (box[1] / height)),
|
||||
int(1000 * (box[2] / width)),
|
||||
int(1000 * (box[3] / height)),
|
||||
]
|
||||
|
||||
|
||||
def infer_id_card(image_crop, list_lines, max_n_words, processor):
    """Run token-classification KIE over a cropped ID-card image.

    OCR words are fed to the model in chunks of ``max_n_words`` so each
    chunk fits within the 512-token sequence limit used below.

    Args:
        image_crop: BGR image array (OpenCV convention) of the cropped card.
        list_lines: OCR result lines (shape expected by load_ocr_labels).
        max_n_words (int): maximum number of words per model call.
        processor: LayoutXLM processor (tokenizer + feature extractor).

    Returns:
        tuple: (batch_words, batch_preds, batch_true_boxes, list_words) —
        predictions and boxes accumulated over all chunks as numpy arrays,
        plus Word objects tagged with predicted KIE_LABELS entries.
    """
    # Load inputs
    image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)  # model expects RGB
    image = Image.fromarray(image)
    batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
    batch_preds, batch_true_boxes = [], []
    list_words = []
    for i in range(0, len(batch_words), max_n_words):
        words = batch_words[i : i + max_n_words]
        boxes = batch_boxes[i : i + max_n_words]
        boxes_norm = [
            _normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes
        ]

        # Preprocess
        # Dummy labels only exist so the processor emits a "labels" tensor,
        # later used to locate subword/padding positions (marked -100).
        dummy_word_labels = [0] * len(words)
        encoding = processor(
            image,
            text=words,
            boxes=boxes_norm,
            word_labels=dummy_word_labels,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        # Run model
        for k, v in encoding.items():
            encoding[k] = v.to(DEVICE)
        outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1).squeeze().tolist()

        # Postprocess
        # Positions labeled -100 are subword continuations / padding; dropping
        # them leaves (at most) one prediction per input word.
        is_subword = (
            (encoding["labels"] == -100).detach().cpu().numpy()[0]
        )  # remove padding
        true_predictions = [
            pred for idx, pred in enumerate(predictions) if not is_subword[idx]
        ]
        true_boxes = (
            boxes  # TODO check assumption that layourlm do not change box order
        )

        # NOTE(review): the inner loop reuses `i`, shadowing the chunk offset
        # above; harmless only because the outer `for` rebinds `i` from its
        # range iterator each iteration.  Also assumes len(true_predictions)
        # >= len(words) — if truncation at 512 tokens drops words this would
        # raise IndexError; confirm max_n_words keeps chunks under the limit.
        for i, word in enumerate(words):
            bndbox = [int(j) for j in true_boxes[i]]
            list_words.append(
                Word(
                    text=word, bndbox=bndbox, kie_label=KIE_LABELS[true_predictions[i]]
                )
            )

        batch_preds.extend(true_predictions)
        batch_true_boxes.extend(true_boxes)

    batch_preds = np.array(batch_preds)
    batch_true_boxes = np.array(batch_true_boxes)
    return batch_words, batch_preds, batch_true_boxes, list_words
|
57
cope2n-ai-fi/api/Kie_Hoanglv/prediction.py
Executable file
57
cope2n-ai-fi/api/Kie_Hoanglv/prediction.py
Executable file
@ -0,0 +1,57 @@
|
||||
from sdsvkie import Predictor
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib
|
||||
from common import serve_model
|
||||
from common import ocr
|
||||
|
||||
model = Predictor(
|
||||
cfg = "./models/kie_invoice/config.yaml",
|
||||
device = "cuda:0",
|
||||
weights = "./models/models/kie_invoice/last",
|
||||
proccessor = serve_model.processor,
|
||||
ocr_engine = ocr.engine
|
||||
)
|
||||
|
||||
def predict(page_numb, image_url):
    """Run the invoice KIE model on an image downloaded from a URL.

    Args:
        page_numb (int): page number stamped onto every extracted field.
        image_url (str): URL of the invoice image.

    Returns:
        dict: {"document_type": "invoice", "fields": [...]} where each field
        carries "label", "value", "box", "confidence" and "page" keys.
    """
    # Fix: the module only does `import urllib`, which does NOT make
    # `urllib.request` available — that previously relied on some other
    # library importing the submodule first.  Import it explicitly here.
    import urllib.request

    # Download the image bytes and decode them into an OpenCV array
    # (flag -1 keeps the source channel layout unchanged).
    response = urllib.request.urlopen(image_url)
    buffer = np.asarray(bytearray(response.read()), dtype=np.uint8)
    image = cv2.imdecode(buffer, -1)

    results = model(image)["end2end_results"]

    output_dict = {
        "document_type": "invoice",
        "fields": [],
    }
    for label, info in results.items():
        output_dict["fields"].append({
            "label": label,
            # Normalize falsy values (None, "") to an empty string.
            "value": info["value"] if info["value"] else "",
            "box": info["box"],
            "confidence": info["conf"],
            "page": page_numb,
        })
    return output_dict
|
83
cope2n-ai-fi/api/Kie_Hoanglv/serve_model.py
Executable file
83
cope2n-ai-fi/api/Kie_Hoanglv/serve_model.py
Executable file
@ -0,0 +1,83 @@
|
||||
from common.utils_invoice.run_ocr import ocr_predict
|
||||
import os
|
||||
from Kie_Hoanglv.prediction2 import KIEInvoiceInfer
|
||||
from configs.config_invoice.layoutxlm_base_invoice import *
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
model = KIEInvoiceInfer(
|
||||
weight_dir=TRAINED_DIR,
|
||||
tokenizer_dir=TOKENIZER_DIR,
|
||||
max_seq_len=MAX_SEQ_LENGTH,
|
||||
classes=KIE_LABELS,
|
||||
device=DEVICE,
|
||||
outdir_visualize=VISUALIZE_DIR,
|
||||
)
|
||||
|
||||
def format_result(result):
    """Pair each extracted KIE key/value with its bounding box.

    Args:
        result: two-element sequence — ``result[0]`` is a dict mapping field
            key to extracted value, ``result[1]`` is a list of boxes aligned
            with the dict's iteration order.

    Returns:
        list[dict]: e.g. [{"key": "name", "value": "Nguyen Hoang Hiep",
        "true_box": [373, 113, 700, 420]}, ...]
    """
    kv_map, box_list = result[0], result[1]
    return [
        {"key": key, "value": value, "true_box": box_list[idx]}
        for idx, (key, value) in enumerate(kv_map.items())
    ]
|
||||
|
||||
def predict(image_url):
    """Fetch an image by URL, OCR it, run the KIE model and format the output.

    Args:
        image_url (str): URL of the document image.

    Returns:
        list[dict]: [{"key", "value", "true_box"}, ...] as produced by
        format_result.
    """
    # Make sure the prediction / visualization output folders exist.
    for out_dir in (PRED_DIR, VISUALIZE_DIR):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir, exist_ok=True)

    # Download and decode: PIL image for the model, BGR ndarray for OCR.
    image = Image.open(BytesIO(requests.get(image_url).content))
    cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    bboxes, texts = ocr_predict(cv_image)

    # The OCR engine emits "✪" for blank glyphs; map those back to spaces.
    texts_replaced = [
        text.replace("✪", " ") if "✪" in text else text for text in texts
    ]

    inputs = model.prepare_kie_inputs(image, ocr_info=[bboxes, texts_replaced])
    return format_result(model(inputs))
|
0
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/__init__.py
Executable file
0
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/__init__.py
Executable file
137
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/anyKeyValue.py
Executable file
137
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/anyKeyValue.py
Executable file
@ -0,0 +1,137 @@
|
||||
import os
|
||||
import glob
|
||||
import cv2
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import urllib
|
||||
import numpy as np
|
||||
import imagesize
|
||||
# from omegaconf import OmegaConf
|
||||
import sys
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
sys.path.append(cur_dir)
|
||||
# sys.path.append('/cope2n-ai-fi/Kie_Invoice_AP/AnyKey_Value/')
|
||||
from predictor import KVUPredictor
|
||||
from preprocess import KVUProcess, DocumentKVUProcess
|
||||
from utils.utils import create_dir, visualize, get_colormap, export_kvu_outputs, export_kvu_for_manulife
|
||||
|
||||
|
||||
def get_args():
    """Parse command-line options for the standalone KVU inference script."""
    parser = argparse.ArgumentParser(description='Main file')
    parser.add_argument(
        '--img_dir',
        default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/',
        type=str, help='Input image directory')
    parser.add_argument(
        '--save_dir',
        default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/',
        type=str, help='Save directory')
    parser.add_argument(
        '--exp_dir',
        default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900',
        type=str, help='Checkpoint and config of model')
    parser.add_argument('--export_img', default=0, type=int,
                        help='Save visualize on image')
    parser.add_argument('--mode', default=3, type=int,
                        help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'")
    parser.add_argument('--dir_level', default=0, type=int,
                        help='Number of subfolders contains image')
    return parser.parse_args()
|
||||
|
||||
|
||||
def load_engine(exp_dir: str, class_names: list, mode: int) -> KVUPredictor:
    """Build the KVU predictor plus its document-level processor.

    The experiment directory is expected to hold one *.yaml config and a
    checkpoints/best_model.pth weights file. Returns (predictor, processor).
    """
    cfg_path = glob.glob(f'{exp_dir}/*.yaml')[0]
    configs = {
        'cfg': cfg_path,
        'ckpt': f'{exp_dir}/checkpoints/best_model.pth'
    }
    dummy_idx = 512  # padding/dummy index handed to the predictor
    predictor = KVUPredictor(configs, class_names, dummy_idx, mode)

    # Document-level processing (mode 3); OCR is run internally (run_ocr=1).
    processor = DocumentKVUProcess(
        predictor.net.tokenizer, predictor.net.feature_extractor,
        predictor.backbone_type, class_names,
        predictor.max_window_count, predictor.slice_interval,
        predictor.window_size, run_ocr=1, mode=mode,
    )
    return predictor, processor
|
||||
|
||||
def revert_box(box, width, height):
    """Map a box from normalized per-mille (0-1000) coordinates to pixels.

    Expects box = [x1, y1, x2, y2]; returns the integer pixel box.
    """
    dims = (width, height, width, height)
    return [int((coord / 1000) * dim) for coord, dim in zip(box, dims)]
|
||||
|
||||
def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor) -> None:
    """Run end-to-end KVU inference on one image and export Manulife-format JSON.

    Writes `<save_dir>/<image stem>.json` per window and returns the last
    exported output dict (printed for debugging as in the original flow).
    """
    fname = os.path.basename(img_path)
    img_ext = img_path.split('.')[-1]
    inputs = processor(img_path, ocr_path='')
    width, height = imagesize.get(img_path)

    bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs)

    # When the predictor returns flat (non-windowed) lists, wrap them so the
    # per-window loop below still executes exactly once.
    if len(bbox) == 0:
        bbox, lwords, pr_class_words, pr_relations = [bbox], [lwords], [pr_class_words], [pr_relations]

    vat_outputs_invoice = None
    for i in range(len(bbox)):
        # Boxes come back normalized to 0-1000; convert to pixel coordinates.
        bbox[i] = [revert_box(bb, width, height) for bb in bbox[i]]
        # BUGFIX: replace '.<ext>' including the dot. The original replaced the
        # bare extension, turning 'a.jpg' into 'a..json' and corrupting any
        # filename whose stem contains the extension string.
        json_path = os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json'))
        vat_outputs_invoice = export_kvu_for_manulife(
            json_path, lwords[i], bbox[i], pr_class_words[i], pr_relations[i],
            predictor.class_names)

    print(vat_outputs_invoice)
    return vat_outputs_invoice
|
||||
|
||||
|
||||
def load_groundtruth(img_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor: KVUProcess, export_img: int) -> None:
    """Export ground-truth KVU labels for one image; optionally render them.

    Reads `<json_dir>/<image stem>.json`, re-exports KVU outputs into
    save_dir, and when export_img == 1 draws the labels onto the image under
    `<save_dir>/kvu_results/`.
    """
    fname = os.path.basename(img_path)
    img_ext = img_path.split('.')[-1]
    inputs = processor.load_ground_truth(os.path.join(json_dir, fname.replace(f".{img_ext}", ".json")))
    bbox, lwords, pr_class_words, pr_relations = predictor.get_ground_truth_label(inputs)

    export_kvu_outputs(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')),
                       lwords, pr_class_words, pr_relations, predictor.class_names)

    if export_img == 1:
        save_path = os.path.join(save_dir, 'kvu_results')
        create_dir(save_path)
        color_map = get_colormap()
        image = cv2.imread(img_path)
        # BUGFIX: the original passed the module-global `class_names`, which is
        # defined only under __main__ — importing this function elsewhere raised
        # NameError. Use the predictor's own class list instead.
        image = visualize(image, bbox, pr_class_words, pr_relations, color_map,
                          predictor.class_names, thickness=1)
        cv2.imwrite(os.path.join(save_path, fname), image)
|
||||
|
||||
def show_groundtruth(dir_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor, export_img: int) -> None:
    """Export ground truth for every image found directly under dir_path."""
    list_images = []
    for ext in ('JPG', 'PNG', 'jpeg', 'jpg', 'png'):
        list_images.extend(glob.glob(os.path.join(dir_path, f'*.{ext}')))
    print('No. images:', len(list_images))

    for img_path in tqdm(list_images):
        load_groundtruth(img_path, json_dir, save_dir, predictor, processor, export_img)
|
||||
|
||||
def Predictor_KVU(image_url: str, save_dir: str, predictor: KVUPredictor, processor) -> None:
    """Download an image from a URL, cache it locally, and run KVU prediction.

    Returns whatever predict_image exports for the downloaded image.
    """
    # BUGFIX: `import urllib` alone does not make urllib.request available in
    # Python 3; import the submodule explicitly. hashlib is used below for a
    # filesystem-safe cache name.
    import hashlib
    import urllib.request

    req = urllib.request.urlopen(image_url)
    arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
    img = cv2.imdecode(arr, -1)

    # BUGFIX: the original path was a plain string containing the literal text
    # "{image_url}" (missing f-prefix), so every call overwrote the same file.
    # An f-string would also embed URL slashes into the path, so derive a
    # stable, filesystem-safe name from a hash of the URL instead.
    cache_name = hashlib.md5(image_url.encode("utf-8")).hexdigest()
    image_path = f"./Kie_Invoice_AP/tmp_image/{cache_name}.jpg"
    cv2.imwrite(image_path, img)

    vat_outputs = predict_image(image_path, save_dir, predictor, processor)
    return vat_outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = get_args()

    # KVU label set; kept at module level because helper functions reference it.
    class_names = ['others', 'title', 'key', 'value', 'header']
    # Mapping of mode names to the --mode integers accepted by get_args().
    predict_mode = dict(normal=0, full_tokens=1, sliding=2, document=3)

    predictor, processor = load_engine(args.exp_dir, class_names, args.mode)
    create_dir(args.save_dir)

    image_path = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
    save_dir = "/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test"
    predict_image(image_path, save_dir, predictor, processor)
    print('[INFO] Done')
|
133
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier.py
Executable file
133
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier.py
Executable file
@ -0,0 +1,133 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from overrides import overrides
|
||||
from pytorch_lightning import LightningModule
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from torch.optim import SGD, Adam, AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lightning_modules.schedulers import (
|
||||
cosine_scheduler,
|
||||
linear_scheduler,
|
||||
multistep_scheduler,
|
||||
)
|
||||
from model import get_model
|
||||
from utils import cfg_to_hparams, get_specific_pl_logger
|
||||
|
||||
|
||||
class ClassifierModule(LightningModule):
    """Base training module shared by the concrete classifiers.

    Wraps the network built from `cfg`, wires up the optimizer and LR
    scheduler, and provides plain-stdout ("shell") metric logging plus
    TensorBoard hyperparameter logging at fit end.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.net = get_model(self.cfg)
        self.ignore_index = -100  # label value excluded from the loss

        self.time_tracker = None  # wall-clock anchor, set in setup(), read by _log_shell

        # Optimizer classes selectable via cfg.train.optimizer.method.
        self.optimizer_types = {
            "sgd": SGD,
            "adam": Adam,
            "adamw": AdamW,
        }

    @overrides
    def setup(self, stage):
        # Start the shell-logging timer when the run begins.
        self.time_tracker = time.time()

    @overrides
    def configure_optimizers(self):
        optimizer = self._get_optimizer()
        scheduler = self._get_lr_scheduler(optimizer)
        # Step-wise scheduling, surfaced to loggers as "learning_rate".
        scheduler = {
            "scheduler": scheduler,
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def _get_lr_scheduler(self, optimizer):
        """Build the LR scheduler named by cfg.train.optimizer.lr_schedule.method.

        Supported methods: None (constant), "step", "cosine", "linear".
        Raises ValueError for anything else.
        """
        cfg_train = self.cfg.train
        lr_schedule_method = cfg_train.optimizer.lr_schedule.method
        lr_schedule_params = cfg_train.optimizer.lr_schedule.params

        if lr_schedule_method is None:
            # Constant LR: identity multiplier on every step.
            scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1)
        elif lr_schedule_method == "step":
            scheduler = multistep_scheduler(optimizer, **lr_schedule_params)
        elif lr_schedule_method == "cosine":
            # Total optimizer steps = total samples / global batch size.
            total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
            total_batch_size = cfg_train.batch_size * self.trainer.world_size
            max_iter = total_samples / total_batch_size
            scheduler = cosine_scheduler(
                optimizer, training_steps=max_iter, **lr_schedule_params
            )
        elif lr_schedule_method == "linear":
            total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
            total_batch_size = cfg_train.batch_size * self.trainer.world_size
            max_iter = total_samples / total_batch_size
            scheduler = linear_scheduler(
                optimizer, training_steps=max_iter, **lr_schedule_params
            )
        else:
            raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}")

        return scheduler

    def _get_optimizer(self):
        """Instantiate the optimizer chosen by cfg.train.optimizer.

        Raises ValueError if the method is not in self.optimizer_types.
        """
        opt_cfg = self.cfg.train.optimizer
        method = opt_cfg.method.lower()

        if method not in self.optimizer_types:
            raise ValueError(f"Unknown optimizer method={method}")

        kwargs = dict(opt_cfg.params)
        kwargs["params"] = self.net.parameters()
        optimizer = self.optimizer_types[method](**kwargs)

        return optimizer

    @rank_zero_only
    @overrides
    def on_fit_end(self):
        # Log the flattened config as hyperparameters (rank 0 only).
        hparam_dict = cfg_to_hparams(self.cfg, {})
        metric_dict = {"metric/dummy": 0}  # TB needs at least one metric entry

        tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger)

        if tb_logger:
            tb_logger.log_hyperparams(hparam_dict, metric_dict)

    @overrides
    def training_epoch_end(self, training_step_outputs):
        # NOTE: despite the name, this is the SUM of step losses, not a mean.
        avg_loss = torch.tensor(0.0).to(self.device)
        for step_out in training_step_outputs:
            avg_loss += step_out["loss"]

        log_dict = {"train_loss": avg_loss}
        self._log_shell(log_dict, prefix="train ")

    def _log_shell(self, log_info, prefix=""):
        """Pretty-print metrics (plus LR and elapsed seconds) to stdout."""
        log_info_shell = {}
        for k, v in log_info.items():
            new_v = v
            if type(new_v) is torch.Tensor:
                new_v = new_v.item()
            log_info_shell[k] = new_v

        out_str = prefix.upper()
        if prefix.upper().strip() in ["TRAIN", "VAL"]:
            out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]"

        if self.training:
            # Current LR from the first optimizer's first param group.
            lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"]
            log_info_shell["lr"] = lr

        for key, value in log_info_shell.items():
            out_str += f" || {key}: {round(value, 5)}"
        out_str += f" || time: {round(time.time() - self.time_tracker, 1)}"
        out_str += " secs."
        self.print(out_str)
        self.time_tracker = time.time()
|
@ -0,0 +1,390 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from overrides import overrides
|
||||
|
||||
from lightning_modules.classifier import ClassifierModule
|
||||
from utils import get_class_names
|
||||
|
||||
|
||||
class KVUClassifierModule(ClassifierModule):
    """Key-Value Understanding classifier.

    Scores three tasks: entity extraction ("ee"), entity linking ("el") and
    key-based linking ("el_from_key"). Stage 1 trains/evaluates on sliding
    windows; stage 2 on whole documents.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        class_names = get_class_names(self.cfg.dataset_root_path)

        self.window_size = cfg.train.max_num_words
        self.slice_interval = cfg.train.slice_interval
        self.eval_kwargs = {
            "class_names": class_names,
            "dummy_idx": self.cfg.train.max_seq_length,  # update dummy_idx in next step
        }
        self.stage = cfg.stage  # 1: window-level, 2: document-level

    @overrides
    def training_step(self, batch, batch_idx, *args):
        # Stage 1 consumes pre-sliced windows; stage 2 the full batch.
        if self.stage == 1:
            _, loss = self.net(batch['windows'])
        elif self.stage == 2:
            _, loss = self.net(batch)
        else:
            raise ValueError(
                f"Not supported stage: {self.stage}"
            )

        log_dict_input = {"train_loss": loss}
        self.log_dict(log_dict_input, sync_dist=True)
        return loss

    @torch.no_grad()
    @overrides
    def validation_step(self, batch, batch_idx, *args):
        if self.stage == 1:
            # Accumulate gt/pred/correct counts across every window.
            step_out_total = {
                "loss": 0,
                "ee": {
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
                "el": {
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
                "el_from_key": {
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                }}
            for window in batch['windows']:
                head_outputs, loss = self.net(window)
                step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
                for key in step_out_total:
                    if key == 'loss':
                        step_out_total[key] += step_out[key]
                    else:
                        for subkey in step_out_total[key]:
                            step_out_total[key][subkey] += step_out[key][subkey]
            return step_out_total

        elif self.stage == 2:
            head_outputs, loss = self.net(batch)
            # dummy_idx must match this document's padded sequence length.
            self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1]
            step_out = do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs)
            return step_out

    @torch.no_grad()
    @overrides
    def validation_epoch_end(self, validation_step_outputs):
        scores = do_eval_epoch_end(validation_step_outputs)
        self.print(
            f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}"
        )
        self.print(
            f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}"
        )
        self.print(
            f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}"
        )
        # Checkpoint metric: unweighted mean of the three task F1 scores.
        self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.)
        tensorboard_logs = {'val_precision_ee': scores['ee']['precision'], 'val_recall_ee': scores['ee']['recall'], 'val_f1_ee': scores['ee']['f1'],
                            'val_precision_el': scores['el']['precision'], 'val_recall_el': scores['el']['recall'], 'val_f1_el': scores['el']['f1'],
                            'val_precision_el_from_key': scores['el_from_key']['precision'], 'val_recall_el_from_key': scores['el_from_key']['recall'], \
                            'val_f1_el_from_key': scores['el_from_key']['f1'],}
        return {'log': tensorboard_logs}
|
||||
|
||||
|
||||
def do_eval_step(batch, head_outputs, loss, eval_kwargs):
    """Score one batch for all three KVU tasks.

    Takes the batch's gold labels, the model head logits and the loss, and
    returns a dict with gt / predicted / correct counts for entity extraction
    ("ee"), entity linking ("el") and key-based linking ("el_from_key"),
    plus the loss itself.
    """
    class_names = eval_kwargs["class_names"]
    dummy_idx = eval_kwargs["dummy_idx"]

    itc_outputs = head_outputs["itc_outputs"]
    stc_outputs = head_outputs["stc_outputs"]
    el_outputs = head_outputs["el_outputs"]
    el_outputs_from_key = head_outputs["el_outputs_from_key"]

    # Hard predictions: argmax over each head's final (class) dimension.
    pr_itc_labels = torch.argmax(itc_outputs, -1)
    pr_stc_labels = torch.argmax(stc_outputs, -1)
    pr_el_labels = torch.argmax(el_outputs, -1)
    pr_el_labels_from_key = torch.argmax(el_outputs_from_key, -1)

    # Entity extraction counts (initial-token + subsequent-token heads).
    (
        n_batch_gt_classes,
        n_batch_pr_classes,
        n_batch_correct_classes,
    ) = eval_ee_spade_batch(
        pr_itc_labels,
        batch["itc_labels"],
        batch["are_box_first_tokens"],
        pr_stc_labels,
        batch["stc_labels"],
        batch["attention_mask_layoutxlm"],
        class_names,
        dummy_idx,
    )

    # Entity-linking counts.
    n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch(
        pr_el_labels,
        batch["el_labels"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    # Key-based entity-linking counts.
    n_batch_gt_rel_from_key, n_batch_pr_rel_from_key, n_batch_correct_rel_from_key = eval_el_spade_batch(
        pr_el_labels_from_key,
        batch["el_labels_from_key"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    step_out = {
        "loss": loss,
        "ee": {
            "n_batch_gt": n_batch_gt_classes,
            "n_batch_pr": n_batch_pr_classes,
            "n_batch_correct": n_batch_correct_classes,
        },
        "el": {
            "n_batch_gt": n_batch_gt_rel,
            "n_batch_pr": n_batch_pr_rel,
            "n_batch_correct": n_batch_correct_rel,
        },
        "el_from_key": {
            "n_batch_gt": n_batch_gt_rel_from_key,
            "n_batch_pr": n_batch_pr_rel_from_key,
            "n_batch_correct": n_batch_correct_rel_from_key,
        }

    }

    return step_out
|
||||
|
||||
|
||||
def eval_ee_spade_batch(
    pr_itc_labels,
    gt_itc_labels,
    are_box_first_tokens,
    pr_stc_labels,
    gt_stc_labels,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Sum entity-extraction gt / predicted / correct counts over a batch."""
    n_gt_total, n_pr_total, n_correct_total = 0, 0, 0

    for idx in range(pr_itc_labels.shape[0]):
        n_gt, n_pr, n_correct = eval_ee_spade_example(
            pr_itc_labels[idx],
            gt_itc_labels[idx],
            are_box_first_tokens[idx],
            pr_stc_labels[idx],
            gt_stc_labels[idx],
            attention_mask[idx],
            class_names,
            dummy_idx,
        )
        n_gt_total += n_gt
        n_pr_total += n_pr
        n_correct_total += n_correct

    return (n_gt_total, n_pr_total, n_correct_total)
|
||||
|
||||
|
||||
def eval_ee_spade_example(
    pr_itc_label,
    gt_itc_label,
    box_first_token_mask,
    pr_stc_label,
    gt_stc_label,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Count per-class entity matches (exact token-index tuples) for one example."""
    gt_seeds = parse_initial_words(gt_itc_label, box_first_token_mask, class_names)
    gt_entities = parse_subsequent_words(
        gt_stc_label, attention_mask, gt_seeds, dummy_idx
    )

    pr_seeds = parse_initial_words(pr_itc_label, box_first_token_mask, class_names)
    pr_entities = parse_subsequent_words(
        pr_stc_label, attention_mask, pr_seeds, dummy_idx
    )

    n_gt, n_pr, n_correct = 0, 0, 0
    # Both entity lists are aligned per class index.
    for gt_for_class, pr_for_class in zip(gt_entities, pr_entities):
        gt_set = set(gt_for_class)
        pr_set = set(pr_for_class)
        n_gt += len(gt_set)
        n_pr += len(pr_set)
        n_correct += len(gt_set & pr_set)

    return n_gt, n_pr, n_correct
|
||||
|
||||
|
||||
def parse_initial_words(itc_label, box_first_token_mask, class_names):
    """Collect, per class, the token indices that start an entity.

    A token counts when it is the first token of a box (mask truthy) and its
    ITC label is non-zero (class 0 is background/'others').
    """
    labels = itc_label.cpu().numpy()
    first_token_flags = box_first_token_mask.cpu().numpy()

    per_class = [[] for _ in class_names]
    for idx, lbl in enumerate(labels):
        if first_token_flags[idx] and lbl != 0:
            per_class[lbl].append(idx)

    return per_class
|
||||
|
||||
|
||||
def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    """Follow STC next-token pointers from each initial token to build chains.

    Returns one list per class of token-index tuples; each tuple is the chain
    of tokens forming one entity, starting at its initial token.
    """
    max_connections = 50  # hard cap against cycles / runaway chains

    masked = (stc_label * attention_mask.bool()).cpu().numpy()
    raw = stc_label.cpu().numpy()

    # follow[p] = j means: token j is the successor of token p.
    follow = {}
    candidates = np.where((masked != dummy_idx) * (masked != 0))[0]
    for j in candidates:
        follow[raw[j]] = j

    outputs = []
    for seeds in init_words:
        chains = []
        for seed in seeds:
            chain = [seed]
            for _ in range(max_connections):
                if chain[-1] not in follow:
                    break
                successor = follow[chain[-1]]
                # Stop before walking into another entity's initial token.
                if successor in seeds:
                    break
                chain.append(successor)
            chains.append(tuple(chain))
        outputs.append(chains)

    return outputs
|
||||
|
||||
|
||||
def eval_el_spade_batch(
    pr_el_labels,
    gt_el_labels,
    are_box_first_tokens,
    dummy_idx,
):
    """Sum link-level gt / predicted / correct counts across a batch."""
    n_gt_total, n_pr_total, n_correct_total = 0, 0, 0

    for pr_row, gt_row, mask_row in zip(pr_el_labels, gt_el_labels, are_box_first_tokens):
        n_gt, n_pr, n_correct = eval_el_spade_example(
            pr_row, gt_row, mask_row, dummy_idx
        )
        n_gt_total += n_gt
        n_pr_total += n_pr
        n_correct_total += n_correct

    return n_gt_total, n_pr_total, n_correct_total
|
||||
|
||||
|
||||
def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
    """Count gt / predicted / correct (head, tail) links for one example."""
    gt_links = set(parse_relations(gt_el_label, box_first_token_mask, dummy_idx))
    pr_links = set(parse_relations(pr_el_label, box_first_token_mask, dummy_idx))

    return len(gt_links), len(pr_links), len(gt_links & pr_links)
|
||||
|
||||
|
||||
def parse_relations(el_label, box_first_token_mask, dummy_idx):
    """Extract (head_token, tail_token) links from an entity-linking label row.

    A tail token t contributes the link (el_label[t], t) when its masked label
    is neither 0 (no link) nor dummy_idx (padding).
    """
    masked = (el_label * box_first_token_mask).cpu().numpy()
    raw = el_label.cpu().numpy()

    tail_indices = np.where((masked != dummy_idx) * (masked != 0))[0]

    return {(raw[t], t) for t in tail_indices}
|
||||
|
||||
def do_eval_epoch_end(step_outputs):
    """Aggregate per-step counts into precision / recall / F1 per task.

    step_outputs is a list of dicts keyed by task ('ee', 'el', 'el_from_key'),
    each holding n_batch_gt / n_batch_pr / n_batch_correct counters.
    Zero denominators yield 0.0 scores.
    """
    scores = {}
    for task in ['ee', 'el', 'el_from_key']:
        n_gt = sum(step_out[task]["n_batch_gt"] for step_out in step_outputs)
        n_pr = sum(step_out[task]["n_batch_pr"] for step_out in step_outputs)
        n_correct = sum(step_out[task]["n_batch_correct"] for step_out in step_outputs)

        precision = n_correct / n_pr if n_pr != 0 else 0.0
        recall = n_correct / n_gt if n_gt != 0 else 0.0
        if recall * precision == 0:
            f1 = 0.0
        else:
            f1 = 2.0 * recall * precision / (recall + precision)

        scores[task] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    return scores
|
||||
|
||||
|
||||
def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
    """Build eval kwargs for SPADE EE: class names plus the dummy/padding index."""
    return {
        "class_names": get_class_names(dataset_root_path),
        "dummy_idx": max_seq_length,
    }
|
||||
|
||||
|
||||
def get_eval_kwargs_spade_rel(max_seq_length):
    """Build eval kwargs for SPADE relation eval: only the dummy/padding index."""
    return {"dummy_idx": max_seq_length}
|
@ -0,0 +1,135 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from overrides import overrides
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from lightning_modules.data_modules.kvu_dataset import KVUDataset, KVUEmbeddingDataset
|
||||
from lightning_modules.utils import _get_number_samples
|
||||
|
||||
class KVUDataModule(pl.LightningDataModule):
    """Data module wiring KVU train/val datasets into PyTorch DataLoaders."""

    def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor):
        super().__init__()
        self.cfg = cfg
        self.train_loader = None  # built lazily in setup()
        self.val_loader = None
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor
        self.collate_fn = None  # default DataLoader collation is used

        self.backbone_type = self.cfg.model.backbone
        self.num_samples_per_epoch = _get_number_samples(cfg.dataset_root_path)

    @overrides
    def setup(self, stage=None):
        self.train_loader = self._get_train_loader()
        self.val_loader = self._get_val_loaders()

    @overrides
    def train_dataloader(self):
        return self.train_loader

    @overrides
    def val_dataloader(self):
        return self.val_loader

    def _get_train_loader(self):
        """Build the training DataLoader; KVUDataset is used for both stages."""
        start_time = time.time()

        if self.cfg.stage == 1:
            dataset = KVUDataset(
                self.cfg,
                self.tokenizer_layoutxlm,
                self.feature_extractor,
                mode="train",
            )
        elif self.cfg.stage == 2:
            # NOTE: stage 2 formerly used KVUEmbeddingDataset; both stages now
            # load the raw KVUDataset.
            dataset = KVUDataset(
                self.cfg,
                self.tokenizer_layoutxlm,
                self.feature_extractor,
                mode="train",
            )
        else:
            raise ValueError(
                f"Not supported stage: {self.cfg.stage}"
            )

        print('No. training samples:', len(dataset))

        data_loader = DataLoader(
            dataset,
            batch_size=self.cfg.train.batch_size,
            shuffle=True,
            num_workers=self.cfg.train.num_workers,
            pin_memory=True,
        )

        elapsed_time = time.time() - start_time
        print(f"Elapsed time for loading training data: {elapsed_time}")

        return data_loader

    def _get_val_loaders(self):
        """Build the validation DataLoader (no shuffle, keep the last batch)."""
        if self.cfg.stage == 1:
            dataset = KVUDataset(
                self.cfg,
                self.tokenizer_layoutxlm,
                self.feature_extractor,
                mode="val",
            )
        elif self.cfg.stage == 2:
            # Same dataset as stage 1; see _get_train_loader note.
            dataset = KVUDataset(
                self.cfg,
                self.tokenizer_layoutxlm,
                self.feature_extractor,
                mode="val",
            )
        else:
            raise ValueError(
                f"Not supported stage: {self.cfg.stage}"
            )

        print('No. validation samples:', len(dataset))

        data_loader = DataLoader(
            dataset,
            batch_size=self.cfg.val.batch_size,
            shuffle=False,
            num_workers=self.cfg.val.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        return data_loader

    @overrides
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # A batch is either a list of window dicts or a single dict; move every
        # tensor value onto the target device in place.
        if isinstance(batch, list):
            for sub_batch in batch:
                for k in sub_batch.keys():
                    if isinstance(sub_batch[k], torch.Tensor):
                        sub_batch[k] = sub_batch[k].to(device)
        else:
            for k in batch.keys():
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(device)
        return batch
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,728 @@
|
||||
import os
|
||||
import json
|
||||
import omegaconf
|
||||
import itertools
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from utils import get_class_names
|
||||
from lightning_modules.utils import sliding_windows_by_words
|
||||
|
||||
class KVUDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer_layoutxlm,
|
||||
feature_extractor,
|
||||
mode=None,
|
||||
):
|
||||
super(KVUDataset, self).__init__()
|
||||
|
||||
self.dataset_root_path = cfg.dataset_root_path
|
||||
if not isinstance(self.dataset_root_path, omegaconf.listconfig.ListConfig):
|
||||
self.dataset_root_path = [self.dataset_root_path]
|
||||
|
||||
self.backbone_type = cfg.model.backbone
|
||||
self.max_seq_length = cfg.train.max_seq_length
|
||||
self.window_size = cfg.train.max_num_words
|
||||
self.slice_interval = cfg.train.slice_interval
|
||||
|
||||
self.tokenizer_layoutxlm = tokenizer_layoutxlm
|
||||
self.feature_extractor = feature_extractor
|
||||
|
||||
self.stage = cfg.stage
|
||||
self.mode = mode
|
||||
|
||||
self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._pad_token)
|
||||
self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._cls_token)
|
||||
self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._sep_token)
|
||||
self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)
|
||||
|
||||
self.examples = self._load_examples()
|
||||
|
||||
self.class_names = get_class_names(self.dataset_root_path)
|
||||
self.class_idx_dic = dict(
|
||||
[(class_name, idx) for idx, class_name in enumerate(self.class_names)]
|
||||
)
|
||||
|
||||
def _load_examples(self):
|
||||
examples = []
|
||||
for dataset_dir in self.dataset_root_path:
|
||||
with open(
|
||||
os.path.join(dataset_dir, f"preprocessed_files_{self.mode}.txt"),
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as fp:
|
||||
for line in fp.readlines():
|
||||
preprocessed_file = os.path.join(dataset_dir, line.strip())
|
||||
examples.append(
|
||||
json.load(open(preprocessed_file, "r", encoding="utf-8"))
|
||||
)
|
||||
|
||||
return examples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
json_obj = self.examples[index]
|
||||
|
||||
width = json_obj["meta"]["imageSize"]["width"]
|
||||
height = json_obj["meta"]["imageSize"]["height"]
|
||||
img_path = json_obj["meta"]["image_path"]
|
||||
|
||||
images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")]
|
||||
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
|
||||
|
||||
word_windows, parse_class_windows, parse_relation_windows = sliding_windows_by_words(
|
||||
json_obj["words"],
|
||||
json_obj['parse']['class'],
|
||||
json_obj['parse']['relations'],
|
||||
self.window_size, self.slice_interval)
|
||||
outputs = {}
|
||||
if self.stage == 1:
|
||||
if self.mode == 'train':
|
||||
i = np.random.randint(0, len(word_windows), 1)[0]
|
||||
outputs['windows'] = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i],
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=self.max_seq_length)
|
||||
else:
|
||||
outputs['windows'] = []
|
||||
for i in range(len(word_windows)):
|
||||
single_window = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i],
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=self.max_seq_length)
|
||||
outputs['windows'].append(single_window)
|
||||
|
||||
elif self.stage == 2:
|
||||
outputs['documents'] = self.preprocess(json_obj["words"], json_obj['parse']['class'], json_obj['parse']['relations'],
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=2048)
|
||||
|
||||
windows = []
|
||||
for i in range(len(word_windows)):
|
||||
_words = word_windows[i]
|
||||
_parse_class = parse_class_windows[i]
|
||||
_parse_relation = parse_relation_windows[i]
|
||||
windows.append(
|
||||
self.preprocess(_words, _parse_class, _parse_relation,
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
self.max_seq_length)
|
||||
)
|
||||
outputs['windows'] = windows
|
||||
return outputs
|
||||
|
||||
def preprocess(self, words, parse_class, parse_relation, feature_maps, max_seq_length):
    """Convert one (sub-)document into fixed-length model tensors.

    Encodes the per-word LayoutXLM subword tokens into padded input ids /
    attention mask / normalized bounding boxes and builds the supervision
    tensors:

    * ``itc_labels``   — class id on the FIRST token of each entity span.
    * ``stc_labels``   — index of the previous token inside an entity
      (a chain); ``ntokens`` marks "no predecessor".
    * ``el_labels`` / ``el_labels_from_key`` — entity-linking pointers
      (header->key/value and key->value respectively).

    Args:
        words: list of word dicts with ``"layoutxlm_tokens"`` (subword ids)
            and ``"boundingBox"`` (4 corner points) keys.
        parse_class: dict mapping class name -> list of word-index groups.
        parse_relation: list of ``[from_word_idx, to_word_idx]`` pairs.
        feature_maps: dict with ``image`` tensor, page ``width``/``height``
            and ``img_path``.
        max_seq_length: padded length of every output tensor.

    Returns:
        dict of tensors plus bookkeeping scalars (``len_valid_tokens``,
        ``len_overlap_tokens``) used later to merge sliding windows.
    """
    # Pad value for ids; everything past the real tokens stays PAD/0.
    input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
    attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int)
    bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
    are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

    # Label tensors; max_seq_length acts as the "no label / no link" sentinel
    # and is remapped to ``ntokens`` just before returning.
    itc_labels = np.zeros(max_seq_length, dtype=int)
    stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length
    el_labels = np.ones((max_seq_length,), dtype=int) * max_seq_length
    el_labels_from_key = np.ones((max_seq_length,), dtype=int) * max_seq_length

    list_layoutxlm_tokens = []
    list_bbs = []
    # word index -> [first_token_pos, last_token_pos + 1] (CLS-shifted).
    box2token_span_map = []

    # word index -> list of token positions belonging to that word.
    box_to_token_indices = []
    cum_token_idx = 0

    cls_bbs = [0.0] * 8
    len_overlap_tokens = 0
    len_non_overlap_tokens = 0
    len_valid_tokens = 0

    for word_idx, word in enumerate(words):
        this_box_token_indices = []

        layoutxlm_tokens = word["layoutxlm_tokens"]
        bb = word["boundingBox"]
        len_valid_tokens += len(layoutxlm_tokens)
        # Words in the leading slice are the non-overlapping part of this
        # sliding window; the remainder overlaps with the next window.
        if word_idx < self.slice_interval:
            len_non_overlap_tokens += len(layoutxlm_tokens)

        # A word the tokenizer produced nothing for still occupies one slot.
        if len(layoutxlm_tokens) == 0:
            layoutxlm_tokens.append(self.unk_token_id_layoutxlm)

        # Stop once the window would no longer fit next to [CLS]/[SEP].
        if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2:
            break

        box2token_span_map.append(
            [len(list_layoutxlm_tokens) + 1, len(list_layoutxlm_tokens) + len(layoutxlm_tokens) + 1]
        )  # including st_idx
        list_layoutxlm_tokens += layoutxlm_tokens

        # min, max clipping of the 4 corner points to the page rectangle.
        for coord_idx in range(4):
            bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
            bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))

        # Flatten [[x,y] x4] -> [x1,y1,...,x4,y4]; repeat per subword token.
        bb = list(itertools.chain(*bb))
        bbs = [bb for _ in range(len(layoutxlm_tokens))]

        for _ in layoutxlm_tokens:
            cum_token_idx += 1
            this_box_token_indices.append(cum_token_idx)

        list_bbs.extend(bbs)
        box_to_token_indices.append(this_box_token_indices)

    sep_bbs = [feature_maps['width'], feature_maps['height']] * 4

    # For [CLS] and [SEP]
    list_layoutxlm_tokens = (
        [self.cls_token_id_layoutxlm]
        + list_layoutxlm_tokens[: max_seq_length - 2]
        + [self.sep_token_id_layoutxlm]
    )

    if len(list_bbs) == 0:
        # When len(json_obj["words"]) == 0 (no OCR result)
        list_bbs = [cls_bbs] + [sep_bbs]
    else:  # len(list_bbs) > 0
        list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]

    len_list_layoutxlm_tokens = len(list_layoutxlm_tokens)
    input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens
    attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1

    bbox[:len_list_layoutxlm_tokens, :] = list_bbs

    # Normalize bbox -> 0 ~ 1
    bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
    bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']

    if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"):
        # These backbones take 2-point boxes scaled to 0..1000.
        bbox = bbox[:, [0, 1, 4, 5]]
        bbox = bbox * 1000
        bbox = bbox.astype(int)
    else:
        assert False

    st_indices = [
        indices[0]
        for indices in box_to_token_indices
        if indices[0] < max_seq_length
    ]
    are_box_first_tokens[st_indices] = True

    # Label for entity extraction
    classes_dic = parse_class
    for class_name in self.class_names:
        if class_name == "others":
            continue
        if class_name not in classes_dic:
            continue

        for word_list in classes_dic[class_name]:
            is_first, last_word_idx = True, -1
            for word_idx in word_list:
                if word_idx >= len(box_to_token_indices):
                    break
                box2token_list = box_to_token_indices[word_idx]
                for converted_word_idx in box2token_list:
                    if converted_word_idx >= max_seq_length:
                        break  # out of idx

                    if is_first:
                        # First token of the group carries the class id.
                        itc_labels[converted_word_idx] = self.class_idx_dic[
                            class_name
                        ]
                        is_first, last_word_idx = False, converted_word_idx
                    else:
                        # Subsequent tokens point back to their predecessor.
                        stc_labels[converted_word_idx] = last_word_idx
                        last_word_idx = converted_word_idx

    # Label for entity linking
    relations = parse_relation
    for relation in relations:
        # Drop relations touching words that fell outside this window.
        if relation[0] >= len(box2token_span_map) or relation[1] >= len(
            box2token_span_map
        ):
            continue
        if (
            box2token_span_map[relation[0]][0] >= max_seq_length
            or box2token_span_map[relation[1]][0] >= max_seq_length
        ):
            continue

        word_from = box2token_span_map[relation[0]][0]
        word_to = box2token_span_map[relation[1]][0]

        # Class ids 2/3/4 appear to mean key/value/header here —
        # NOTE(review): confirm against self.class_idx_dic.
        # 1st relation kind => [key, value]
        # 2nd relation kind => [header, key-or-value]
        if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
            el_labels_from_key[word_to] = word_from  # pair of (key-value)
        if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
            el_labels[word_to] = word_from  # pair of (header, key) or (header-value)

    # NOTE(review): holds only when no word was dropped by the length break
    # above, i.e. the caller sizes windows so everything fits.
    assert len_list_layoutxlm_tokens == len_valid_tokens + 2
    len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
    # Keep the full padded length for 512-token models; otherwise trim the
    # tensors down to the real token count (+ CLS/SEP).
    ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2

    input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm[:ntokens])
    attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm[:ntokens])

    bbox = torch.from_numpy(bbox[:ntokens])
    are_box_first_tokens = torch.from_numpy(are_box_first_tokens[:ntokens])

    itc_labels = itc_labels[:ntokens]
    stc_labels = stc_labels[:ntokens]
    el_labels = el_labels[:ntokens]
    el_labels_from_key = el_labels_from_key[:ntokens]

    # Re-point the "no label" sentinel from max_seq_length to ntokens so it
    # stays one-past-the-end after trimming.
    itc_labels = np.where(itc_labels != max_seq_length, itc_labels, ntokens)
    stc_labels = np.where(stc_labels != max_seq_length, stc_labels, ntokens)
    el_labels = np.where(el_labels != max_seq_length, el_labels, ntokens)
    el_labels_from_key = np.where(el_labels_from_key != max_seq_length, el_labels_from_key, ntokens)

    itc_labels = torch.from_numpy(itc_labels)
    stc_labels = torch.from_numpy(stc_labels)
    el_labels = torch.from_numpy(el_labels)
    el_labels_from_key = torch.from_numpy(el_labels_from_key)

    return_dict = {
        "img_path": feature_maps['img_path'],
        "len_overlap_tokens": len_overlap_tokens,
        'len_valid_tokens': len_valid_tokens,
        "image": feature_maps['image'],
        "input_ids_layoutxlm": input_ids_layoutxlm,
        "attention_mask_layoutxlm": attention_mask_layoutxlm,
        "bbox": bbox,
        "itc_labels": itc_labels,
        "are_box_first_tokens": are_box_first_tokens,
        "stc_labels": stc_labels,
        "el_labels": el_labels,
        "el_labels_from_key": el_labels_from_key,
    }
    return return_dict
|
||||
|
||||
|
||||
|
||||
class KVUPredefinedDataset(KVUDataset):
    """KVU dataset that always yields a fixed number of sliding windows.

    Unlike ``KVUDataset``, every sample is truncated or right-padded to
    exactly ``cfg.train.max_windows`` windows so batch shapes are static.
    """

    def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor, mode=None):
        super().__init__(cfg, tokenizer_layoutxlm, feature_extractor, mode)
        # Fixed window count every document is truncated/padded to.
        self.max_windows = cfg.train.max_windows

    def __getitem__(self, index):
        """Return the preprocessed windows of one document, padded/truncated
        to ``self.max_windows`` entries."""
        json_obj = self.examples[index]

        width = json_obj["meta"]["imageSize"]["width"]
        height = json_obj["meta"]["imageSize"]["height"]
        img_path = json_obj["meta"]["image_path"]

        images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        word_windows, parse_class_windows, parse_relation_windows = sliding_windows_by_words(
            json_obj["words"],
            json_obj['parse']['class'],
            json_obj['parse']['relations'],
            self.window_size, self.slice_interval)

        def _fit(windows):
            # Truncate to max_windows, or right-pad with empty windows.
            if len(windows) >= self.max_windows:
                return windows[: self.max_windows]
            return windows + [[]] * (self.max_windows - len(windows))

        word_windows = _fit(word_windows)
        parse_class_windows = _fit(parse_class_windows)
        parse_relation_windows = _fit(parse_relation_windows)

        feature_map = {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}

        outputs = {}
        outputs['windows'] = []
        # BUG FIX: the original iterated ``range(len(self.max_windows))``, but
        # ``max_windows`` is an int, so ``len()`` raised TypeError.
        for i in range(self.max_windows):
            single_window = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i],
                                            feature_map,
                                            max_seq_length=self.max_seq_length)
            outputs['windows'].append(single_window)

        # NOTE(review): ``preprocess`` returns a dict, so ``torch.cat`` over a
        # list of dicts raises TypeError at runtime — presumably the intent
        # was per-key concatenation; confirm before relying on this path.
        outputs['windows'] = torch.cat(outputs["windows"], dim=0)

        return outputs
|
||||
|
||||
|
||||
class KVUEmbeddingDataset(Dataset):
    """Dataset serving precomputed token-embedding matrices (.npz files).

    The list of samples comes from ``preprocessed_files_<mode>.txt`` under
    each dataset root; each listed path is rewritten from the preprocessed
    JSON to its matching embedding-matrix ``.npz`` file.
    """

    def __init__(self, cfg, mode=None):
        super().__init__()

        self.dataset_root_path = cfg.dataset_root_path
        # Normalize a single root into a one-element list.
        if not isinstance(self.dataset_root_path, omegaconf.listconfig.ListConfig):
            self.dataset_root_path = [self.dataset_root_path]

        self.stage = cfg.stage
        self.mode = mode

        self.examples = self._load_examples()

    def _load_examples(self):
        """Collect .npz embedding paths listed in each root's index file."""
        paths = []
        for root in self.dataset_root_path:
            index_file = os.path.join(root, f"preprocessed_files_{self.mode}.txt")
            with open(index_file, "r", encoding="utf-8") as fp:
                for line in fp:
                    preprocessed_file = os.path.join(root, line.strip())
                    paths.append(
                        preprocessed_file.replace('preprocessed', 'embedding_matrix').replace('.json', '.npz')
                    )
        return paths

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        """Load one .npz archive and convert its arrays to tensors.

        Only ``embeddings`` is cast to half precision; labels and masks keep
        their stored integer dtypes.
        """
        input_embeddings = np.load(self.examples[index])

        label_keys = (
            "attention_mask_layoutxlm",
            "are_box_first_tokens",
            "bbox",
            "itc_labels",
            "stc_labels",
            "el_labels",
            "el_labels_from_key",
        )
        return {
            "embeddings": torch.from_numpy(input_embeddings["embeddings"]).type(torch.HalfTensor),
            **{key: torch.from_numpy(input_embeddings[key]) for key in label_keys},
        }
|
||||
|
||||
|
||||
class DocumentKVUDataset(KVUDataset):
    """KVU dataset that encodes a whole document as a fixed grid of windows.

    Each document is split into ``cfg.train.max_window_count`` consecutive
    (non-overlapping) windows of ``self.window_size`` words; per-window
    tensors are emitted under ``'windows'`` and document-level labels
    (indices shifted into the concatenated token space) under
    ``'documents'``.
    """

    def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor, mode=None):
        super().__init__(cfg, tokenizer_layoutxlm, feature_extractor, mode)
        # BUG FIX: the original wrote ``self.self.max_window_count = ...``,
        # which raised AttributeError on construction.
        self.max_window_count = cfg.train.max_window_count

    @staticmethod
    def _shifted(index_lists, offset):
        """Deep-copy nested index lists with every index shifted by ``offset``.

        Used to move window-local token positions into the concatenated
        document token space (window ``i`` starts at ``i * max_seq_length``).
        """
        shifted = copy.deepcopy(index_lists)
        for group in shifted:
            for j in range(len(group)):
                group[j] += offset
        return shifted

    def __getitem__(self, idx):
        """Build the per-window and document-level tensors for one sample."""
        json_obj = self.examples[idx]

        width = json_obj["meta"]["imageSize"]["width"]
        height = json_obj["meta"]["imageSize"]["height"]

        images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")]
        feature_maps = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        n_words = len(json_obj['words'])
        output_dicts = {'windows': [], 'documents': []}
        # Document-level maps: word index -> token positions in the
        # concatenated (window-stacked) sequence.
        box_to_token_indices_document = []
        box2token_span_map_document = []
        n_empty_windows = 0

        for i in range(self.max_window_count):
            input_ids = np.ones(self.max_seq_length, dtype=int) * 1
            bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
            attention_mask = np.zeros(self.max_seq_length, dtype=int)

            itc_labels = np.zeros(self.max_seq_length, dtype=int)
            are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)

            # stc_labels stores the index of the previous token.
            # A stored index of max_seq_length (512) indicates that
            # this token is the initial token of a word box.
            stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
            el_labels = np.ones((self.max_seq_length,), dtype=int) * self.max_seq_length
            el_labels_from_key = np.ones((self.max_seq_length,), dtype=int) * self.max_seq_length

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i + 1) * self.window_size)

            if start_word_idx >= stop_word_idx:
                # Past the last real word: pad by repeating the previous
                # window (masked out again below).
                n_empty_windows += 1
                # BUG FIX: ``output_dicts`` is a dict; the original called
                # ``output_dicts.append(output_dicts[-1])`` (AttributeError).
                output_dicts['windows'].append(output_dicts['windows'][-1])

                # NOTE(review): reuses box_to_token_indices /
                # box2token_span_map from the previous iteration; a document
                # with zero words would hit this branch at i == 0 and fail —
                # confirm upstream guarantees n_words > 0.
                offset = i * self.max_seq_length
                box_to_token_indices_document.extend(self._shifted(box_to_token_indices, offset))
                box2token_span_map_document.extend(self._shifted(box2token_span_map, offset))
                continue

            list_tokens = []
            list_bbs = []
            box2token_span_map = []

            box_to_token_indices = []
            cum_token_idx = 0

            cls_bbs = [0.0] * 8

            # Parse words
            for word_idx, word in enumerate(json_obj["words"][start_word_idx:stop_word_idx]):
                this_box_token_indices = []

                tokens = word["tokens"]
                bb = word["boundingBox"]
                if len(tokens) == 0:
                    tokens.append(self.unk_token_id)

                if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
                    break  # window full; remaining words of this slice are dropped

                box2token_span_map.append(
                    [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
                )  # including st_idx
                list_tokens += tokens

                # min, max clipping
                for coord_idx in range(4):
                    bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
                    bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))

                bb = list(itertools.chain(*bb))
                bbs = [bb for _ in range(len(tokens))]

                for _ in tokens:
                    cum_token_idx += 1
                    this_box_token_indices.append(cum_token_idx)

                list_bbs.extend(bbs)
                box_to_token_indices.append(this_box_token_indices)

            sep_bbs = [width, height] * 4

            # For [CLS] and [SEP]
            list_tokens = (
                [self.cls_token_id]
                + list_tokens[: self.max_seq_length - 2]
                + [self.sep_token_id]
            )
            if len(list_bbs) == 0:
                # When len(json_obj["words"]) == 0 (no OCR result)
                list_bbs = [cls_bbs] + [sep_bbs]
            else:  # len(list_bbs) > 0
                list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]

            len_list_tokens = len(list_tokens)
            input_ids[:len_list_tokens] = list_tokens
            attention_mask[:len_list_tokens] = 1

            bbox[:len_list_tokens, :] = list_bbs

            # Normalize bbox -> 0 ~ 1
            bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
            bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height

            if self.backbone_type in ('layoutlm', 'layoutxlm', 'xlm-roberta'):
                # These backbones take 2-point boxes scaled to 0..1000.
                bbox = bbox[:, [0, 1, 4, 5]]
                bbox = bbox * 1000
                bbox = bbox.astype(int)
            else:
                assert False

            st_indices = [
                indices[0]
                for indices in box_to_token_indices
                if indices[0] < self.max_seq_length
            ]
            are_box_first_tokens[st_indices] = True

            # Parse word groups (entity-extraction labels, window-local)
            classes_dic = json_obj["parse"]["class"]
            for class_name in self.class_names:
                if class_name == "others":
                    continue
                if class_name not in classes_dic:
                    continue

                for word_list in classes_dic[class_name]:
                    word_list = [w for w in word_list if w >= start_word_idx and w < stop_word_idx]
                    if len(word_list) == 0:
                        continue  # no word of this group falls in the window
                    word_list = [w - start_word_idx for w in word_list]

                    is_first, last_word_idx = True, -1
                    for word_idx in word_list:
                        if word_idx >= len(box_to_token_indices):
                            break
                        box2token_list = box_to_token_indices[word_idx]
                        for converted_word_idx in box2token_list:
                            if converted_word_idx >= self.max_seq_length:
                                break  # out of idx

                            if is_first:
                                itc_labels[converted_word_idx] = self.class_idx_dic[
                                    class_name
                                ]
                                is_first, last_word_idx = False, converted_word_idx
                            else:
                                stc_labels[converted_word_idx] = last_word_idx
                                last_word_idx = converted_word_idx

            # Parse relation (entity-linking labels, window-local)
            relations = json_obj["parse"]["relations"]
            for relation in relations:
                relation = [r for r in relation if r >= start_word_idx and r < stop_word_idx]
                if len(relation) != 2:
                    continue  # relation popped due to window inconsistent
                relation[0] -= start_word_idx
                relation[1] -= start_word_idx

                word_from = box2token_span_map[relation[0]][0]
                word_to = box2token_span_map[relation[1]][0]

                # 1st relation kind => [key, value]
                # 2nd relation kind => [header, key-or-value]
                if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                    el_labels_from_key[word_to] = word_from  # pair of (key-value)
                if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                    el_labels[word_to] = word_from  # pair of (header, key) or (header-value)

            input_ids = torch.from_numpy(input_ids)
            bbox = torch.from_numpy(bbox)
            attention_mask = torch.from_numpy(attention_mask)

            itc_labels = torch.from_numpy(itc_labels)
            are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
            stc_labels = torch.from_numpy(stc_labels)
            el_labels = torch.from_numpy(el_labels)
            el_labels_from_key = torch.from_numpy(el_labels_from_key)

            # Shift window-local maps into the document token space.
            offset = i * self.max_seq_length
            box_to_token_indices_document.extend(self._shifted(box_to_token_indices, offset))
            box2token_span_map_document.extend(self._shifted(box2token_span_map, offset))

            return_dict = {
                "image": feature_maps,
                "input_ids": input_ids,
                "bbox": bbox,
                "attention_mask": attention_mask,
                "itc_labels": itc_labels,
                "are_box_first_tokens": are_box_first_tokens,
                "stc_labels": stc_labels,
                "el_labels": el_labels,
                "el_labels_from_key": el_labels_from_key,
            }

            output_dicts["windows"].append(return_dict)

        # Parse whole document labels.
        # BUG FIX: the original iterated ``output_dicts`` (a dict, yielding
        # key strings) — the per-window list under 'windows' was meant.
        attention_mask = torch.cat([o['attention_mask'] for o in output_dicts['windows']])
        are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts['windows']])
        if n_empty_windows > 0:
            # Mask out the padded (repeated) windows at the tail.
            attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
            are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
        bbox = torch.cat([o['bbox'] for o in output_dicts['windows']])

        self.max_seq_length_document = self.max_seq_length * self.max_window_count
        itc_labels = np.zeros(self.max_seq_length_document, dtype=int)
        stc_labels = np.ones(self.max_seq_length_document, dtype=np.int64) * self.max_seq_length_document
        el_labels = np.ones((self.max_seq_length_document,), dtype=int) * self.max_seq_length_document
        el_labels_from_key = np.ones((self.max_seq_length_document,), dtype=int) * self.max_seq_length_document

        # Parse word groups (document-level)
        classes_dic = json_obj["parse"]["class"]
        for class_name in self.class_names:
            if class_name == "others":
                continue
            if class_name not in classes_dic:
                continue

            word_lists = classes_dic[class_name]

            for word_list in word_lists:
                is_first, last_word_idx = True, -1
                for word_idx in word_list:
                    if word_idx >= len(box_to_token_indices_document):
                        break
                    box2token_list = box_to_token_indices_document[word_idx]
                    for converted_word_idx in box2token_list:
                        if converted_word_idx >= self.max_seq_length_document:
                            break  # out of idx

                        if is_first:
                            itc_labels[converted_word_idx] = self.class_idx_dic[
                                class_name
                            ]
                            is_first, last_word_idx = False, converted_word_idx
                        else:
                            stc_labels[converted_word_idx] = last_word_idx
                            last_word_idx = converted_word_idx

        # Parse relation (document-level)
        relations = json_obj["parse"]["relations"]

        for relation in relations:
            if relation[0] >= len(box2token_span_map_document) or relation[1] >= len(
                box2token_span_map_document
            ):
                continue
            if (
                box2token_span_map_document[relation[0]][0] >= self.max_seq_length_document
                or box2token_span_map_document[relation[1]][0] >= self.max_seq_length_document
            ):
                continue

            word_from = box2token_span_map_document[relation[0]][0]
            word_to = box2token_span_map_document[relation[1]][0]

            if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                el_labels_from_key[word_to] = word_from  # pair of (key-value)
            if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                el_labels[word_to] = word_from  # pair of (header, key) or (header-value)

        itc_labels = torch.from_numpy(itc_labels)
        stc_labels = torch.from_numpy(stc_labels)
        el_labels = torch.from_numpy(el_labels)
        el_labels_from_key = torch.from_numpy(el_labels_from_key)

        return_dict = {
            "img_path": json_obj["meta"]["image_path"],
            "attention_mask": attention_mask,
            "bbox": bbox,
            "itc_labels": itc_labels,
            "are_box_first_tokens": are_box_first_tokens,
            "stc_labels": stc_labels,
            "el_labels": el_labels,
            "el_labels_from_key": el_labels_from_key,
            "n_empty_windows": n_empty_windows
        }

        output_dicts['documents'] = return_dict
        return output_dicts
|
53
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/schedulers.py
Executable file
53
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/schedulers.py
Executable file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
BROS
|
||||
Copyright 2022-present NAVER Corp.
|
||||
Apache License v2.0
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
|
||||
def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """Linear warmup followed by linear decay to zero (huggingface-style).

    The multiplier ramps from 0 to 1 over ``warmup_steps``, then falls
    linearly to 0 at ``training_steps``.
    """

    def scale(step):
        # Warmup phase: linear ramp 0 -> 1.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        # Decay phase: linear fall to 0, clamped below at 0.
        remaining = float(training_steps - step)
        return max(0.0, remaining / float(max(1, training_steps - warmup_steps)))

    return LambdaLR(optimizer, scale, last_epoch)
|
||||
|
||||
|
||||
def cosine_scheduler(
    optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1
):
    """Cosine LR schedule with linear warmup (huggingface-style).

    After ``warmup_steps`` the multiplier follows half a cosine period
    (for the default ``cycles=0.5``) from 1 down to 0 at ``training_steps``.
    """

    def scale(step):
        # Warmup phase: linear ramp 0 -> 1.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Fraction of the post-warmup schedule completed, in [0, 1].
        progress = (step - warmup_steps) / max(1, training_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress)))

    return LambdaLR(optimizer, scale, last_epoch)
|
||||
|
||||
|
||||
def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1):
    """Multi-step LR decay with linear warmup.

    After warmup the multiplier is ``gamma ** k`` where ``k`` is the number
    of (sorted) ``milestones`` already passed.
    """

    def scale(step):
        # Warmup phase: linear ramp 0 -> 1.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Count milestones passed so far; each one multiplies by gamma.
        passed = np.searchsorted(milestones, step)
        return gamma ** passed

    return LambdaLR(optimizer, scale, last_epoch)
|
218
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/utils.py
Executable file
218
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/utils.py
Executable file
@ -0,0 +1,218 @@
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
def _get_number_samples(dataset_root_path):
    """Count training samples across all dataset roots.

    Each root must contain a ``preprocessed_files_train.txt`` index; the
    total number of lines over all index files is returned.
    """
    total = 0
    for dataset_dir in dataset_root_path:
        index_path = os.path.join(dataset_dir, "preprocessed_files_train.txt")
        with open(index_path, "r", encoding="utf-8") as fp:
            # One sample per listed line.
            total += sum(1 for _ in fp)
    return total
||||
|
||||
|
||||
def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    """Split ``elements`` into overlapping windows.

    Each window holds ``window_size`` items and the start advances by
    ``slice_interval`` per step; the final window is clipped to the end of
    the list. When no split is needed the original list is returned as the
    single window; otherwise every window is a deep copy.
    """
    if len(elements) <= window_size:
        return [elements]

    element_windows = []
    n_steps = math.ceil((len(elements) - window_size) / slice_interval)
    for step in range(n_steps + 1):
        start = step * slice_interval
        stop = start + window_size
        if stop >= len(elements):
            # Last window: take everything that remains.
            element_windows.append(copy.deepcopy(elements[start:]))
        else:
            element_windows.append(copy.deepcopy(elements[start:stop]))
    return element_windows
||||
|
||||
def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> list:
    """Split a labeled document into overlapping word windows.

    Produces three parallel lists (one entry per window): the word dicts
    (deep-copied, with ``word_id`` rebased to start at 0 inside each
    window), the class groups restricted to in-window words, and the
    relations whose BOTH endpoints fall inside the window. When the
    document is no longer than ``window_size`` the originals are returned
    untouched as single-element lists.

    Args:
        lwords: word dicts; each must carry a ``'word_id'`` key.
        parse_class: class name -> list of word-id groups.
        parse_relation: list of word-id pairs.
        window_size: number of words per window.
        slice_interval: start offset between consecutive windows.
    """
    word_windows = []
    parse_class_windows = []
    parse_relation_windows = []

    if len(lwords) > window_size:
        max_step = math.ceil((len(lwords) - window_size)/slice_interval)
        for i in range(0, max_step+1):
            # Deep-copy the slice: word_ids get mutated in place below.
            if (i*slice_interval+window_size) >= len(lwords):
                # Final window: clipped to the end of the document.
                _word_window = copy.deepcopy(lwords[i*slice_interval:])
            else:
                _word_window = copy.deepcopy(lwords[i*slice_interval: i*slice_interval+window_size])

            # A window of fewer than two words carries no useful structure.
            if len(_word_window) < 2:
                continue

            first_word_id = _word_window[0]['word_id']
            last_word_id = _word_window[-1]['word_id']

            # Rebase word ids so each window starts at 0.
            for _word in _word_window:
                _word['word_id'] -= first_word_id

            # Entity extraction: keep only the in-window members of each
            # group, shifted into window coordinates. NOTE(review): groups
            # entirely outside the window still contribute an empty list.
            _class_window = {k: [] for k in list(parse_class.keys())}
            for class_name, _parse_class in parse_class.items():
                for group in _parse_class:
                    tmp = []
                    for idw in group:
                        idw -= first_word_id
                        if 0 <= idw <= (last_word_id - first_word_id):
                            tmp.append(idw)
                    _class_window[class_name].append(tmp)

            # Entity linking: keep a relation only if both endpoints are
            # inside this window.
            _relation_window = []
            for pair in parse_relation:
                if all([0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair]):
                    _relation_window.append([idw - first_word_id for idw in pair])

            word_windows.append(_word_window)
            parse_class_windows.append(_class_window)
            parse_relation_windows.append(_relation_window)

        return word_windows, parse_class_windows, parse_relation_windows
    else:
        # Document fits in one window: return the inputs unchanged.
        return [lwords], [parse_class], [parse_relation]
|
||||
|
||||
def merged_token_embeddings(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.tensor:
    """Stitch per-window token embeddings into one document-level sequence.

    ``lpatches[i]`` is a ``(1, T, D)`` window; ``lvalids[i]`` counts its real
    tokens (positions ``1 .. lvalids[i]``, i.e. between [CLS] and [SEP]);
    ``loverlaps[i-1]`` is the number of tokens window ``i`` shares with the
    running sequence. Overlapping tokens are either averaged or taken from
    the CURRENT window. [CLS]/[SEP] of the first window wrap the result.
    """
    first = lpatches[0]
    cls_token = copy.deepcopy(first[:, :1, ...])
    sep_token = copy.deepcopy(first[:, -1:, ...])
    merged = copy.deepcopy(first[:, 1:1 + lvalids[0], ...])

    for idx in range(1, len(lpatches)):
        overlap_gap = copy.deepcopy(loverlaps[idx - 1])
        window = copy.deepcopy(lpatches[idx][:, 1:1 + lvalids[idx], ...])

        if overlap_gap == 0:
            merged = torch.cat([merged, window], dim=1)
            continue

        prev_overlap = copy.deepcopy(merged[:, -overlap_gap:, ...])
        curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
        assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"

        # Resolve the shared region: mean of both windows, or the newer one.
        if average:
            bridge = (prev_overlap + curr_overlap) / 2.
        else:
            bridge = curr_overlap
        merged = torch.cat(
            [merged[:, :-overlap_gap, ...], bridge, window[:, overlap_gap:, ...]], dim=1
        )

    return torch.cat([cls_token, merged, sep_token], dim=1)
|
||||
|
||||
|
||||
|
||||
def merged_token_embeddings2(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
    """Stitch per-window token embeddings into one document-level sequence.

    Windows overlap by ``loverlaps[i-1]`` tokens. Overlapping positions are
    either averaged (``average=True``) or kept from the *previous* window
    (``average=False``). The first window's CLS and SEP embeddings bracket
    the merged sequence.
    """
    first = lpatches[0]
    cls_token = first[:, :1, ...]
    sep_token = first[:, -1:, ...]
    merged = first[:, 1:1 + lvalids[0], ...]

    for idx in range(1, len(lpatches)):
        gap = loverlaps[idx - 1]
        window = lpatches[idx][:, 1:1 + lvalids[idx], ...]

        if gap == 0:
            # No shared tokens: simply append the whole window.
            merged = torch.cat([merged, window], dim=1)
            continue

        tail = merged[:, -gap:, ...]
        head = window[:, :gap, ...]
        assert tail.shape == head.shape, f"{tail.shape} # {head.shape} with overlap: {gap}"

        if average:
            joint = (tail + head) / 2.
        else:
            # Keep the previous window's embeddings for the overlap region.
            joint = tail
        merged = torch.cat([merged[:, :-gap, ...], joint, window[:, gap:, ...]], dim=1)

    return torch.cat([cls_token, merged, sep_token], dim=1)
||||
|
||||
|
||||
|
||||
# def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
|
||||
# # start_pos = 1
|
||||
# # end_pos = start_pos + lvalids[0]
|
||||
# # embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
|
||||
# embedding_tokens = np.zeros((1, 2046, 768))
|
||||
# cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
|
||||
# sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])
|
||||
|
||||
# token_apperance_count = np.zeros((2046,))
|
||||
# start_idx = 1
|
||||
# for loverlap, lvalid in zip(loverlaps, lvalids):
|
||||
# token_apperance_count[start_idx:start_idx+lvalid] += 1
|
||||
# start_idx = start_idx + lvalid - loverlap
|
||||
|
||||
# embedding_matrix_spos = 0
|
||||
# for i in range(0, len(lpatches)):
|
||||
# embedding_matrix_epos = embedding_matrix_spos + int(lvalids[i])
|
||||
|
||||
# # assert embedding_matrix_epos - embedding_matrix_spos == end_pos - 1, (embedding_matrix_spos, embedding_matrix_epos, lvalid[i], end_pos)
|
||||
|
||||
# overlap_gap = copy.deepcopy(loverlaps[i]).cpu().numpy()
|
||||
# window = copy.deepcopy(lpatches[i][:, 1:int(lvalids[i])+1, ...]).cpu().numpy()
|
||||
|
||||
# # embedding_tokens[:,embedding_matrix_spos:embedding_matrix_epos:] = window * token_apperance_count[embedding_matrix_spos:embedding_matrix_epos]
|
||||
# for i in range(len(window)):
|
||||
# window[:,i,:] *= token_apperance_count[embedding_matrix_spos + i]
|
||||
# embedding_tokens[: ,int(embedding_matrix_spos): int(embedding_matrix_epos), :] += window
|
||||
# embedding_matrix_spos -= overlap_gap
|
||||
|
||||
|
||||
# embedding_tokens = torch.tensor(embedding_tokens).type(torch.HalfTensor).cuda()
|
||||
|
||||
# # if overlap_gap != 0:
|
||||
# # prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
|
||||
# # curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
|
||||
# # assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
|
||||
|
||||
# # if average:
|
||||
# # avg_overlap = (
|
||||
# # prev_overlap + curr_overlap
|
||||
# # ) / 2.
|
||||
# # embedding_tokens = torch.cat(
|
||||
# # [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
|
||||
# # )
|
||||
# # else:
|
||||
# # embedding_tokens = torch.cat(
|
||||
# # [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
|
||||
# # )
|
||||
# # else:
|
||||
# # embedding_tokens = torch.cat(
|
||||
# # [embedding_tokens, window], dim=1
|
||||
# # )
|
||||
# return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
|
15
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/__init__.py
Executable file
15
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/__init__.py
Executable file
@ -0,0 +1,15 @@
|
||||
|
||||
from model.combined_model import CombinedKVUModel
|
||||
from model.kvu_model import KVUModel
|
||||
from model.document_kvu_model import DocumentKVUModel
|
||||
|
||||
def get_model(cfg):
    """Instantiate the KVU model variant for the configured training stage.

    Args:
        cfg: configuration object; ``cfg.stage`` selects the architecture
            (1 -> CombinedKVUModel, 2 -> KVUModel, 3 -> DocumentKVUModel).

    Returns:
        The instantiated model.

    Raises:
        AssertionError: if ``cfg.stage`` is not 1, 2 or 3.
    """
    if cfg.stage == 1:
        model = CombinedKVUModel(cfg=cfg)
    elif cfg.stage == 2:
        model = KVUModel(cfg=cfg)
    elif cfg.stage == 3:
        model = DocumentKVUModel(cfg=cfg)
    else:
        # The original built AssertionError(...) without raising it, which
        # then crashed with UnboundLocalError on `model` below. Raise it.
        raise AssertionError('[ERROR] Training stage is wrong')
    return model
224
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/combined_model.py
Executable file
224
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/combined_model.py
Executable file
@ -0,0 +1,224 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
|
||||
from transformers import LayoutXLMTokenizer
|
||||
from transformers import AutoTokenizer, XLMRobertaModel
|
||||
|
||||
|
||||
from model.relation_extractor import RelationExtractor
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
class CombinedKVUModel(nn.Module):
    """Stage-1 KVU model: a LayoutXLM backbone with four prediction heads.

    Heads: initial-token classification (ITC), subsequent-token
    classification (STC), entity linking (EL), and entity linking from key
    (EL-from-key). Supports freezing the backbone and fine-tuning only a
    subset of heads via ``cfg.train.finetune_only``.
    """

    def __init__(self, cfg):
        # cfg.model: backbone/head hyper-parameters; cfg.train: freeze flags.
        super().__init__()

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        self._get_backbones(self.model_cfg.backbone)

        self._create_head()

        # Restore backbone and all four heads from the checkpoint, if present.
        if os.path.exists(self.model_cfg.ckpt_model_file):
            self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
            self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer')
            self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer')
            self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer')
            self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key')

        self.loss_func = nn.CrossEntropyLoss()

        # Parameter freezing is done by substring match on parameter names,
        # so it must run after all submodules are registered above.
        if self.freeze:
            for name, param in self.named_parameters():
                if 'backbone' in name:
                    param.requires_grad = False
        if self.finetune_only == 'EE':
            # Entity extraction only: train just the ITC and STC heads.
            for name, param in self.named_parameters():
                if 'itc_layer' not in name and 'stc_layer' not in name:
                    param.requires_grad = False
        if self.finetune_only == 'EL':
            # Entity linking only: train relation_layer but NOT
            # relation_layer_from_key (its name also contains 'relation_layer',
            # hence the extra `or` clause).
            for name, param in self.named_parameters():
                if 'relation_layer' not in name or 'relation_layer_from_key' in name:
                    param.requires_grad = False
        if self.finetune_only == 'ELK':
            # Linking-from-key only: train relation_layer_from_key alone.
            for name, param in self.named_parameters():
                if 'relation_layer_from_key' not in name:
                    param.requires_grad = False

    def _create_head(self):
        """Create the four prediction heads on top of the 768-d backbone.

        NOTE: attribute creation order matters — it fixes the parameter
        registration order and the checkpoint/state_dict keys.
        """
        self.backbone_hidden_size = 768  # LayoutXLM-base hidden size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # +1 for the "no class / background" label.
        self.n_classes = self.model_cfg.n_classes + 1

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (4) Linking token classification (key -> value direction)
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        # NOTE(review): relation_layer_from_key is not re-initialised here,
        # unlike DocumentKVUModel._create_head — confirm this is intended.

    def _get_backbones(self, config_type):
        """Load the LayoutXLM tokenizer, feature extractor and backbone.

        ``config_type`` is currently unused: everything is hard-wired to
        'microsoft/layoutxlm-base' (downloaded on first use).
        """
        self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
        self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')

    @staticmethod
    def _init_weight(module):
        """Init Linear (N(0, 0.02)) and LayerNorm (N(1, 0.02)) weights; zero biases."""
        init_std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight, 1.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, batch):
        """Run the backbone + all four heads on one batch.

        Expects keys "image", "input_ids_layoutxlm", "bbox",
        "attention_mask_layoutxlm"; if any "...labels..." key is present the
        combined loss is also computed.

        Returns:
            (head_outputs dict, loss) — loss is 0.0 when no labels are given.
        """
        image = batch["image"]
        input_ids_layoutxlm = batch["input_ids_layoutxlm"]
        bbox = batch["bbox"]
        attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

        backbone_outputs_layoutxlm = self.backbone_layoutxlm(
            image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)

        # Keep only the first 512 positions (text tokens); the heads work on
        # (seq, batch, hidden), hence the transpose.
        last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
        last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
        head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}

        loss = 0.0
        if any(['labels' in key for key in batch.keys()]):
            loss = self._get_loss(head_outputs, batch)

        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
        """Sum of ITC + STC + EL + EL-from-key cross-entropy losses."""
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch)
        el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)

        loss = itc_loss + stc_loss + el_loss + el_loss_from_key

        return loss

    def _get_itc_loss(self, itc_outputs, batch):
        """Cross-entropy over positions that start a word box only."""
        itc_mask = batch["are_box_first_tokens"].view(-1)

        itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
        itc_logits = itc_logits[itc_mask]

        itc_labels = batch["itc_labels"].view(-1)
        itc_labels = itc_labels[itc_mask]

        itc_loss = self.loss_func(itc_logits, itc_labels)

        return itc_loss

    def _get_stc_loss(self, stc_outputs, batch):
        """Cross-entropy over "next sub-token" pointers.

        Padded positions and self-pointers are masked to -10000 before the
        softmax (in-place masked_fill_ on stc_outputs).
        """
        inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]

        bsz, max_seq_length = inv_attention_mask.shape
        device = inv_attention_mask.device

        # Extra column: the "points to nothing" slot at index max_seq_length.
        invalid_token_mask = torch.cat(
            [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
        ).bool()
        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)

        # Forbid a token from pointing at itself.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()

        stc_logits = stc_outputs.view(-1, max_seq_length + 1)
        stc_logits = stc_logits[stc_mask]

        stc_labels = batch["stc_labels"].view(-1)
        stc_labels = stc_labels[stc_mask]

        stc_loss = self.loss_func(stc_logits, stc_labels)

        return stc_loss

    def _get_el_loss(self, el_outputs, batch, from_key=False):
        """Cross-entropy for entity linking.

        Only box-first tokens may be link targets; self-links are masked out.
        ``from_key=True`` selects the key->value label matrix.
        """
        bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
        device = batch["attention_mask_layoutxlm"].device

        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )

        # Mask columns of tokens that do not start a box (plus the extra slot).
        box_first_token_mask = torch.cat(
            [
                (batch["are_box_first_tokens"] == False),
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            axis=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        mask = batch["are_box_first_tokens"].view(-1)

        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]

        if from_key:
            el_labels = batch["el_labels_from_key"]
        else:
            el_labels = batch["el_labels"]
        labels = el_labels.view(-1)
        labels = labels[mask]

        loss = self.loss_func(logits, labels)
        return loss
285
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/document_kvu_model.py
Executable file
285
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/document_kvu_model.py
Executable file
@ -0,0 +1,285 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
|
||||
from transformers import LayoutLMv2Config, LayoutLMv2Model
|
||||
from transformers import LayoutXLMTokenizer
|
||||
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel
|
||||
|
||||
|
||||
from model.relation_extractor import RelationExtractor
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
class DocumentKVUModel(nn.Module):
    """Stage-3 KVU model: per-window heads plus document-level heads.

    A configurable backbone (LayoutLM / LayoutXLM / XLM-RoBERTa) encodes each
    sliding window; window representations are concatenated into a document
    representation which is fed to a second set of ITC/STC/EL heads.
    """

    def __init__(self, cfg):
        # cfg.model: backbone/head hyper-parameters; cfg.train: freeze flag.
        super().__init__()

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train

        self._get_backbones(self.model_cfg.backbone)
        # if 'pth' in self.model_cfg.ckpt_model_file:
        #     self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone, 'backbone')

        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
        """Create window-level and document-level heads.

        NOTE: attribute creation order matters — it fixes the parameter
        registration order and the checkpoint/state_dict keys.
        """
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # +1 for the "no class / background" label.
        self.n_classes = self.model_cfg.n_classes + 1
        self.relations = self.model_cfg.n_relations
        # self.repr_hiddent_size = self.backbone_hidden_size + self.n_classes + (self.train_cfg.max_seq_length + 1) * 3
        self.repr_hiddent_size = self.backbone_hidden_size

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=self.relations, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (4) Linking token classification (key -> value direction)
        self.relation_layer_from_key = RelationExtractor(
            n_relations=self.relations, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # Classification layers for the whole document
        # (1) Initial token classification
        self.itc_layer_document = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hiddent_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer_document = RelationExtractor(
            n_relations=self.relations,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification (key -> value direction)
        self.relation_layer_from_key_document = RelationExtractor(
            n_relations=self.relations,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)

        self.itc_layer_document.apply(self._init_weight)
        self.stc_layer_document.apply(self._init_weight)
        self.relation_layer_document.apply(self._init_weight)
        self.relation_layer_from_key_document.apply(self._init_weight)

    def _get_backbones(self, config_type):
        """Load config, tokenizer, feature extractor and backbone.

        ``config_type`` selects one of 'layoutlm', 'layoutxlm', 'xlm-roberta';
        weights come from ``self.model_cfg.pretrained_model_path``.
        """
        configs = {
            'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extrator': LayoutLMv2FeatureExtractor},
            'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extrator': LayoutLMv2FeatureExtractor},
            'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extrator': LayoutLMv2FeatureExtractor},
        }

        self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
        if config_type != 'xlm-roberta':
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
        else:
            # AutoTokenizer needs use_fast=False here.
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
        self.feature_extractor = configs[config_type]['feature_extrator'](apply_ocr=False)
        self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)

    @staticmethod
    def _init_weight(module):
        """Init Linear (N(0, 0.02)) and LayerNorm (N(1, 0.02)) weights; zero biases."""
        init_std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight, 1.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, batches):
        """Encode each window, then the concatenated document representation.

        ``batches["windows"]`` is a list of per-window batches; the loss over
        ``batches["documents"]`` (document-level labels) is added when label
        keys are present.

        Returns:
            (document-level head_outputs dict, accumulated loss)
        """
        head_outputs_list = []
        loss = 0.0
        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask"]

            # NOTE(review): freezing inside the window loop re-runs every
            # iteration/forward call — could be hoisted to __init__.
            if self.freeze:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            if self.model_cfg.backbone == 'layoutxlm':
                backbone_outputs = self.backbone(
                    image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

            # Keep only the first 512 positions; heads expect (seq, batch, hidden).
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)

            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {"window_repr": window_repr,
                            "itc_outputs": itc_outputs,
                            "stc_outputs": stc_outputs,
                            "el_outputs": el_outputs,
                            "el_outputs_from_key": el_outputs_from_key}

            if any(['labels' in key for key in batch.keys()]):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

        batch = batches["documents"]

        # Concatenate window representations along the sequence dimension.
        document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)

        head_outputs = {"itc_outputs": itc_outputs,
                        "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs,
                        "el_outputs_from_key": el_outputs_from_key}

        if any(['labels' in key for key in batch.keys()]):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
        """Sum of ITC + STC + EL + EL-from-key cross-entropy losses."""
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch)
        el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)

        loss = itc_loss + stc_loss + el_loss + el_loss_from_key

        return loss

    def _get_itc_loss(self, itc_outputs, batch):
        """Cross-entropy over positions that start a word box only."""
        itc_mask = batch["are_box_first_tokens"].view(-1)

        itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
        itc_logits = itc_logits[itc_mask]

        itc_labels = batch["itc_labels"].view(-1)
        itc_labels = itc_labels[itc_mask]

        itc_loss = self.loss_func(itc_logits, itc_labels)

        return itc_loss

    def _get_stc_loss(self, stc_outputs, batch):
        """Cross-entropy over "next sub-token" pointers.

        Padded positions and self-pointers are masked to -10000 before the
        softmax (in-place masked_fill_ on stc_outputs).
        """
        inv_attention_mask = 1 - batch["attention_mask"]

        bsz, max_seq_length = inv_attention_mask.shape
        device = inv_attention_mask.device

        # Extra column: the "points to nothing" slot at index max_seq_length.
        invalid_token_mask = torch.cat(
            [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
        ).bool()
        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)

        # Forbid a token from pointing at itself.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        stc_mask = batch["attention_mask"].view(-1).bool()

        stc_logits = stc_outputs.view(-1, max_seq_length + 1)
        stc_logits = stc_logits[stc_mask]

        stc_labels = batch["stc_labels"].view(-1)
        stc_labels = stc_labels[stc_mask]

        stc_loss = self.loss_func(stc_logits, stc_labels)

        return stc_loss

    def _get_el_loss(self, el_outputs, batch, from_key=False):
        """Cross-entropy for entity linking.

        Only box-first tokens may be link targets; self-links are masked out.
        ``from_key=True`` selects the key->value label matrix.
        """
        bsz, max_seq_length = batch["attention_mask"].shape
        device = batch["attention_mask"].device

        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )

        # Mask columns of tokens that do not start a box (plus the extra slot).
        box_first_token_mask = torch.cat(
            [
                (batch["are_box_first_tokens"] == False),
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            axis=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        mask = batch["are_box_first_tokens"].view(-1)

        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]

        if from_key:
            el_labels = batch["el_labels_from_key"]
        else:
            el_labels = batch["el_labels"]
        labels = el_labels.view(-1)
        labels = labels[mask]

        loss = self.loss_func(logits, labels)
        return loss
248
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/kvu_model.py
Executable file
248
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/kvu_model.py
Executable file
@ -0,0 +1,248 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
|
||||
from transformers import LayoutXLMTokenizer
|
||||
from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2
|
||||
|
||||
|
||||
from model.relation_extractor import RelationExtractor
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
class KVUModel(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
self.device = 'cuda'
|
||||
self.model_cfg = cfg.model
|
||||
self.freeze = cfg.train.freeze
|
||||
self.finetune_only = cfg.train.finetune_only
|
||||
|
||||
# if cfg.stage == 2:
|
||||
# self.freeze = True
|
||||
|
||||
self._get_backbones(self.model_cfg.backbone)
|
||||
self._create_head()
|
||||
|
||||
if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)):
|
||||
self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
|
||||
|
||||
self._create_head()
|
||||
self.loss_func = nn.CrossEntropyLoss()
|
||||
|
||||
if self.freeze:
|
||||
for name, param in self.named_parameters():
|
||||
if 'backbone' in name:
|
||||
param.requires_grad = False
|
||||
|
||||
    def _create_head(self):
        """Create the four prediction heads on top of the 768-d backbone.

        NOTE: attribute creation order matters — it fixes the parameter
        registration order and the checkpoint/state_dict keys.
        """
        self.backbone_hidden_size = 768  # LayoutXLM-base hidden size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # +1 for the "no class / background" label.
        self.n_classes = self.model_cfg.n_classes + 1

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (4) Linking token classification (key -> value direction)
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        # NOTE(review): relation_layer_from_key is not re-initialised here,
        # unlike DocumentKVUModel._create_head — confirm this is intended.
||||
    def _get_backbones(self, config_type):
        """Load the LayoutXLM tokenizer, feature extractor and backbone.

        ``config_type`` is currently unused: everything is hard-wired to
        'microsoft/layoutxlm-base' (downloaded on first use).
        """
        self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
        self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')
||||
@staticmethod
|
||||
def _init_weight(module):
|
||||
init_std = 0.02
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, 0.0, init_std)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.normal_(module.weight, 1.0, init_std)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
|
||||
|
||||
# def forward(self, inputs):
|
||||
# token_embeddings = inputs['embeddings'].transpose(0, 1).contiguous().cuda()
|
||||
# itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
|
||||
# stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
|
||||
# el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
|
||||
# el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
|
||||
# head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
|
||||
# "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}
|
||||
|
||||
# loss = self._get_loss(head_outputs, inputs)
|
||||
# return head_outputs, loss
|
||||
|
||||
|
||||
# def forward_single_doccument(self, lbatches):
|
||||
    # def forward_single_doccument(self, lbatches):
    def forward(self, lbatches):
        """Encode all sliding windows, merge embeddings, run the four heads.

        Args:
            lbatches: dict with "windows" (list of per-window batches) and
                "documents" (document-level tensors/labels). All tensors are
                moved to CUDA here.

        Returns:
            (head_outputs dict — including the merged 'embedding_tokens' as a
            numpy array — , loss). Loss is 0.0 when no label keys are present.
        """
        windows = lbatches['windows']
        token_embeddings_windows = []
        lvalids = []     # valid-token counts per window
        loverlaps = []   # overlap sizes between consecutive windows

        for i, batch in enumerate(windows):
            batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')}
            image = batch["image"]
            input_ids_layoutxlm = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

            backbone_outputs_layoutxlm = self.backbone_layoutxlm(
                image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)

            # Keep only the first 512 positions (text tokens).
            last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]

            lvalids.append(batch['len_valid_tokens'])
            loverlaps.append(batch['len_overlap_tokens'])
            token_embeddings_windows.append(last_hidden_states_layoutxlm)

        # Merge per-window embeddings, keeping the earlier window's tokens in
        # the overlap regions (average=False).
        token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False)
        # token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True)

        # Heads operate on (seq, batch, hidden).
        token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()
        itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
        head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key,
                        'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()}

        loss = 0.0
        if any(['labels' in key for key in lbatches.keys()]):
            # NOTE(review): `k not in ('img_path')` is a substring test on the
            # string 'img_path' (the parentheses do not make a tuple) — it
            # also excludes keys like 'img'; confirm whether a tuple
            # ('img_path',) was intended.
            labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k not in ('img_path')}
            loss = self._get_loss(head_outputs, labels)

        return head_outputs, loss
||||
|
||||
def _get_loss(self, head_outputs, batch):
    """Aggregate the four head losses (ITC, STC, EL, EL-from-key) into one scalar.

    Each ``_get_*_loss`` helper applies its own masking/selection before
    calling ``self.loss_func``.
    """
    total = self._get_itc_loss(head_outputs["itc_outputs"], batch)
    total = total + self._get_stc_loss(head_outputs["stc_outputs"], batch)
    total = total + self._get_el_loss(head_outputs["el_outputs"], batch)
    total = total + self._get_el_loss(
        head_outputs["el_outputs_from_key"], batch, from_key=True
    )
    return total
|
||||
|
||||
def _get_itc_loss(self, itc_outputs, batch):
    """Initial-token-classification loss over box-first tokens only."""
    # Keep only positions flagged as the first token of a box; the same mask
    # selects matching rows from logits and labels.
    first_token_mask = batch["are_box_first_tokens"].view(-1)

    logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)[first_token_mask]
    labels = batch["itc_labels"].view(-1)[first_token_mask]

    return self.loss_func(logits, labels)
|
||||
|
||||
def _get_stc_loss(self, stc_outputs, batch):
    """Subsequent-token-classification loss.

    NOTE: mutates ``stc_outputs`` in place via ``masked_fill_`` — padded
    positions and the self-link diagonal are pushed to -10000 before the
    loss is computed, exactly as the caller expects.
    """
    attn = batch["attention_mask_layoutxlm"]
    bsz, max_seq_length = attn.shape
    device = attn.device

    # Invalidate padded key positions (plus the extra dummy column appended
    # at the end of the key axis).
    pad_mask = torch.cat(
        [1 - attn, torch.zeros([bsz, 1]).to(device)], axis=1
    ).bool()
    stc_outputs.masked_fill_(pad_mask[:, None, :], -10000.0)

    # A token must not point at itself: blank the diagonal.
    diag_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
    stc_outputs.masked_fill_(diag_mask[None, :, :], -10000.0)

    # Score only real (attended) query tokens.
    valid = attn.view(-1).bool()
    logits = stc_outputs.view(-1, max_seq_length + 1)[valid]
    labels = batch["stc_labels"].view(-1)[valid]

    return self.loss_func(logits, labels)
|
||||
|
||||
def _get_el_loss(self, el_outputs, batch, from_key=False):
    """Entity-linking loss; ``from_key`` selects the key->value label set.

    NOTE: mutates ``el_outputs`` in place via ``masked_fill_`` before the
    loss is computed.
    """
    bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape
    device = batch["attention_mask_layoutxlm"].device

    # Disallow self links by blanking the diagonal.
    diag_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()

    # Blank every key column that is not the first token of a box
    # (plus the appended dummy column).
    not_first_mask = torch.cat(
        [
            (batch["are_box_first_tokens"] == False),
            torch.zeros([bsz, 1], dtype=torch.bool).to(device),
        ],
        axis=1,
    )
    el_outputs.masked_fill_(not_first_mask[:, None, :], -10000.0)
    el_outputs.masked_fill_(diag_mask[None, :, :], -10000.0)

    # Score rows only at box-first query positions.
    row_mask = batch["are_box_first_tokens"].view(-1)
    logits = el_outputs.view(-1, max_seq_length + 1)[row_mask]

    label_key = "el_labels_from_key" if from_key else "el_labels"
    labels = batch[label_key].view(-1)[row_mask]

    return self.loss_func(logits, labels)
|
48
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/relation_extractor.py
Executable file
48
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/relation_extractor.py
Executable file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class RelationExtractor(nn.Module):
    """Multi-head bilinear relation scorer.

    Projects query and key token embeddings into ``n_relations`` heads and
    scores every (query, key) pair; a learned dummy node is appended to the
    key sequence so a query can link to "nothing".
    """

    def __init__(
        self,
        n_relations,
        backbone_hidden_size,
        head_hidden_size,
        head_p_dropout=0.1,
    ):
        super().__init__()

        self.n_relations = n_relations
        self.backbone_hidden_size = backbone_hidden_size
        self.head_hidden_size = head_hidden_size
        self.head_p_dropout = head_p_dropout

        self.drop = nn.Dropout(head_p_dropout)
        # Separate linear projections for the query side and the key side.
        self.q_net = nn.Linear(
            self.backbone_hidden_size, self.n_relations * self.head_hidden_size
        )
        self.k_net = nn.Linear(
            self.backbone_hidden_size, self.n_relations * self.head_hidden_size
        )

        # Learnable "no relation" target appended to the key sequence.
        self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size))
        nn.init.normal_(self.dummy_node)

    def forward(self, h_q, h_k):
        """Score all query/key pairs.

        Inputs are sequence-first: (seq, batch, hidden). Returns a tensor of
        shape (n_relations, batch, seq_q, seq_k + 1); the extra key column is
        the dummy node.
        """
        projected_q = self.q_net(self.drop(h_q))

        # Broadcast the dummy node across the batch and append it as one
        # extra key position.
        dummy = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1)
        keys = torch.cat([h_k, dummy], axis=0)
        projected_k = self.k_net(self.drop(keys))

        head_q = projected_q.view(
            projected_q.size(0), projected_q.size(1),
            self.n_relations, self.head_hidden_size
        )
        head_k = projected_k.view(
            projected_k.size(0), projected_k.size(1),
            self.n_relations, self.head_hidden_size
        )

        # (seq_q, b, n, d) x (seq_k+1, b, n, d) -> (n, b, seq_q, seq_k+1)
        return torch.einsum("ibnd,jbnd->nbij", (head_q, head_k))
|
133
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitignore
vendored
Executable file
133
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitignore
vendored
Executable file
@ -0,0 +1,133 @@
|
||||
externals/sdsv_dewarp
|
||||
test/
|
||||
.vscode/
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
results/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
6
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitmodules
vendored
Executable file
6
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitmodules
vendored
Executable file
@ -0,0 +1,6 @@
|
||||
[submodule "sdsvtr"]
|
||||
path = externals/sdsvtr
|
||||
url = https://github.com/mrlasdt/sdsvtr.git
|
||||
[submodule "sdsvtd"]
|
||||
path = externals/sdsvtd
|
||||
url = https://github.com/mrlasdt/sdsvtd.git
|
47
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/README.md
Executable file
47
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/README.md
Executable file
@ -0,0 +1,47 @@
|
||||
# OCR Engine
|
||||
|
||||
OCR Engine is a Python package that combines text detection and recognition models from [mmdet](https://github.com/open-mmlab/mmdetection) and [mmocr](https://github.com/open-mmlab/mmocr) to perform Optical Character Recognition (OCR) on various inputs. The package currently supports three types of input: a single image, a recursive directory, or a csv file.
|
||||
|
||||
## Installation
|
||||
|
||||
To install OCR Engine, clone the repository and install the required packages:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:mrlasdt/ocr-engine.git
|
||||
cd ocr-engine
|
||||
pip install -r requirements.txt
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
To use OCR Engine, simply run the `ocr_engine.py` script with the desired input type and input path. For example, to perform OCR on a single image:
|
||||
|
||||
```bash
|
||||
python ocr_engine.py --input_type image --input_path /path/to/image.jpg
|
||||
```
|
||||
|
||||
To perform OCR on a recursive directory:
|
||||
|
||||
```bash
|
||||
python ocr_engine.py --input_type directory --input_path /path/to/directory/
|
||||
|
||||
```
|
||||
|
||||
To perform OCR on a csv file:
|
||||
|
||||
|
||||
```
|
||||
python ocr_engine.py --input_type csv --input_path /path/to/file.csv
|
||||
```
|
||||
|
||||
OCR Engine will automatically detect and recognize text in the input and output the results in a CSV file named `ocr_results.csv`.
|
||||
|
||||
## Contributing
|
||||
|
||||
If you would like to contribute to OCR Engine, please fork the repository and submit a pull request. We welcome contributions of all types, including bug fixes, new features, and documentation improvements.
|
||||
|
||||
## License
|
||||
|
||||
OCR Engine is released under the [MIT License](https://opensource.org/licenses/MIT). See the LICENSE file for more information.
|
11
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/__init__.py
Executable file
11
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/__init__.py
Executable file
@ -0,0 +1,11 @@
|
||||
# # Define package-level variables
|
||||
# __version__ = '0.0'
|
||||
|
||||
# Import modules
|
||||
from .src.ocr import OcrEngine
|
||||
# from .src.word_formation import words_to_lines
|
||||
from .src.word_formation import words_to_lines_tesseract as words_to_lines
|
||||
from .src.utils import ImageReader, read_ocr_result_from_txt
|
||||
from .src.dto import Word, Line, Page, Document, Box
|
||||
# Expose package contents
|
||||
__all__ = ["OcrEngine", "Box", "Word", "Line", "Page", "Document", "words_to_lines", "ImageReader", "read_ocr_result_from_txt"]
|
82
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/requirements.txt
Executable file
82
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/requirements.txt
Executable file
@ -0,0 +1,82 @@
|
||||
addict==2.4.0
|
||||
asttokens==2.2.1
|
||||
autopep8==1.6.0
|
||||
backcall==0.2.0
|
||||
backports.functools-lru-cache==1.6.4
|
||||
brotlipy==0.7.0
|
||||
certifi==2022.12.7
|
||||
cffi==1.15.1
|
||||
charset-normalizer==2.0.4
|
||||
click==8.1.3
|
||||
colorama==0.4.6
|
||||
cryptography==39.0.1
|
||||
debugpy==1.5.1
|
||||
decorator==5.1.1
|
||||
docopt==0.6.2
|
||||
entrypoints==0.4
|
||||
executing==1.2.0
|
||||
flit_core==3.6.0
|
||||
idna==3.4
|
||||
importlib-metadata==6.0.0
|
||||
ipykernel==6.15.0
|
||||
ipython==8.11.0
|
||||
jedi==0.18.2
|
||||
jupyter-client==7.0.6
|
||||
jupyter_core==4.12.0
|
||||
Markdown==3.4.1
|
||||
markdown-it-py==2.2.0
|
||||
matplotlib-inline==0.1.6
|
||||
mdurl==0.1.2
|
||||
mkl-fft==1.3.1
|
||||
mkl-random==1.2.2
|
||||
mkl-service==2.4.0
|
||||
mmcv-full==1.7.1
|
||||
model-index==0.1.11
|
||||
nest-asyncio==1.5.6
|
||||
numpy==1.23.5
|
||||
opencv-python==4.7.0.72
|
||||
openmim==0.3.6
|
||||
ordered-set==4.1.0
|
||||
packaging==23.0
|
||||
pandas==1.5.3
|
||||
parso==0.8.3
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
Pillow==9.4.0
|
||||
pip==22.3.1
|
||||
pipdeptree==2.5.2
|
||||
prompt-toolkit==3.0.38
|
||||
psutil==5.9.0
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
pycodestyle==2.10.0
|
||||
pycparser==2.21
|
||||
Pygments==2.14.0
|
||||
pyOpenSSL==23.0.0
|
||||
PySocks==1.7.1
|
||||
python-dateutil==2.8.2
|
||||
pytz==2022.7.1
|
||||
PyYAML==6.0
|
||||
pyzmq==19.0.2
|
||||
requests==2.28.1
|
||||
rich==13.3.1
|
||||
sdsvtd==0.1.1
|
||||
sdsvtr==0.0.5
|
||||
setuptools==65.6.3
|
||||
Shapely==1.8.4
|
||||
six==1.16.0
|
||||
stack-data==0.6.2
|
||||
tabulate==0.9.0
|
||||
toml==0.10.2
|
||||
torch==1.13.1
|
||||
torchvision==0.14.1
|
||||
tornado==6.1
|
||||
tqdm==4.65.0
|
||||
traitlets==5.9.0
|
||||
typing_extensions==4.4.0
|
||||
urllib3==1.26.14
|
||||
wcwidth==0.2.6
|
||||
wheel==0.38.4
|
||||
yapf==0.32.0
|
||||
yarg==0.1.9
|
||||
zipp==3.15.0
|
143
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/run.py
Executable file
143
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/run.py
Executable file
@ -0,0 +1,143 @@
|
||||
"""
|
||||
see scripts/run_ocr.sh to run
|
||||
"""
|
||||
# from pathlib import Path # add parent path to run debugger
|
||||
# import sys
|
||||
# FILE = Path(__file__).absolute()
|
||||
# sys.path.append(FILE.parents[2].as_posix())
|
||||
|
||||
|
||||
from src.utils import construct_file_path, ImageReader
|
||||
from src.dto import Line
|
||||
from src.ocr import OcrEngine
|
||||
import argparse
|
||||
import tqdm
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import Union, Tuple, List
|
||||
current_dir = os.getcwd()


def _str2bool(value):
    """Parse common truthy/falsy CLI strings into a real bool.

    argparse's ``type=bool`` is broken for this purpose: ``bool("False")``
    is True, so any non-empty string — including "False" — parsed as True.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ("1", "true", "t", "yes", "y")


def get_args(argv=None):
    """Parse command-line options for the OCR runner.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse's behavior); passing an
    explicit list makes the function testable without touching sys.argv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True,
                        help="path to input image/directory/csv file")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="path to save directory")
    parser.add_argument(
        "--base_dir", type=str, required=False, default=current_dir,
        help="used when --image and --save_dir are relative paths to a base directory, default to current directory")
    parser.add_argument(
        "--export_csv", type=str, required=False, default="",
        help="used when --image is a directory. If set, a csv file contains image_path, ocr_path and label will be exported to save_dir.")
    # BUGFIX: was type=bool, which treated "--export_img False" as True.
    parser.add_argument(
        "--export_img", type=_str2bool, required=False, default=False,
        help="whether to save the visualize img")
    parser.add_argument("--ocr_kwargs", type=str, required=False, default="")
    opt = parser.parse_args(argv)
    return opt
|
||||
|
||||
|
||||
def load_engine(opt) -> OcrEngine:
    """Instantiate the OCR engine, forwarding any JSON-encoded kwargs.

    ``opt.ocr_kwargs`` is an optional JSON object string (e.g. '{"device":
    "cuda:1"}'); an empty string means default construction.
    """
    print("[INFO] Loading engine...")
    kwargs = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {}
    ocr_engine = OcrEngine(**kwargs)
    print("[INFO] Engine loaded")
    return ocr_engine
|
||||
|
||||
|
||||
def convert_relative_path_to_positive_path(tgt_dir: Path, base_dir: Path) -> Path:
    """Resolve ``tgt_dir`` against ``base_dir`` unless it is already absolute."""
    if tgt_dir.is_absolute():
        return tgt_dir
    return base_dir.joinpath(tgt_dir)
|
||||
|
||||
|
||||
def get_paths_from_opt(opt) -> Tuple[Path, Path]:
    """Return (input_image, save_dir) resolved against opt.base_dir.

    Creates ``save_dir`` when it does not exist. Shell-escaped spaces that
    survive in arguments are unescaped first ("BC\\ kiem" -> "BC kiem").
    """
    def _clean(raw: str) -> str:
        return raw.replace("\\ ", " ").strip()

    base_dir = Path(_clean(opt.base_dir))
    input_image = convert_relative_path_to_positive_path(
        Path(_clean(opt.image)), base_dir)
    save_dir = convert_relative_path_to_positive_path(
        Path(_clean(opt.save_dir)), base_dir)

    if not save_dir.exists():
        save_dir.mkdir()
        print("[INFO]: Creating folder ", save_dir)
    return input_image, save_dir
|
||||
|
||||
|
||||
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
    """Run OCR on one image and write word-level results to a txt file.

    ``img`` may be a path or an in-memory array. An array requires
    ``save_dir_or_path`` to be a file path, since no stem can be derived
    from it. When ``export_img`` is set, a visualization jpg is written
    next to the txt.
    """
    target = Path(save_dir_or_path)
    if isinstance(img, np.ndarray) and target.is_dir():
        raise ValueError(
            "numpy array input require a save path, not a save dir")
    page = engine(img)
    if target.is_dir():
        # Derive the output name from the input image's stem.
        save_path = str(target.joinpath(Path(img).stem + ".txt"))
    else:
        save_path = str(target)
    page.write_to_file('word', save_path)
    if export_img:
        page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, )
|
||||
|
||||
|
||||
def process_dir(
        dir_path: str, save_dir: str, engine: OcrEngine, export_img: bool, lskip_dir: List[str] = None,
        ddata: dict = None) -> dict:
    """Recursively OCR every supported image under ``dir_path``.

    Results are written under ``save_dir`` (mirroring sub-directories) and an
    index of (img_path, ocr_path, label) lists is accumulated in ``ddata``
    and returned. Sub-directories listed in ``lskip_dir`` are skipped.
    """
    # BUGFIX: the previous defaults (a shared [] and a shared dict of lists)
    # were created once at def time, so repeated top-level calls accumulated
    # earlier results. Use None sentinels and build fresh containers per call.
    if lskip_dir is None:
        lskip_dir = []
    if ddata is None:
        ddata = {"img_path": list(),
                 "ocr_path": list(),
                 "label": list()}

    dir_path = Path(dir_path)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    for img_path in (pbar := tqdm.tqdm(dir_path.iterdir())):
        pbar.set_description(f"Processing {dir_path}")
        if img_path.is_dir() and img_path not in lskip_dir:
            save_dir_sub = save_dir.joinpath(img_path.stem)
            # BUGFIX: the recursive call previously passed ``ddata`` in the
            # ``export_img`` positional slot; pass every argument explicitly
            # so the accumulator and the export flag both propagate.
            process_dir(img_path, str(save_dir_sub), engine, export_img,
                        lskip_dir=lskip_dir, ddata=ddata)
        elif img_path.suffix.lower() in ImageReader.supported_ext:
            simg_path = str(img_path)
            try:
                # For pdf inputs only the first page is processed.
                img = ImageReader.read(
                    simg_path) if img_path.suffix != ".pdf" else ImageReader.read(simg_path)[0]
                save_path = str(save_dir.joinpath(img_path.stem + ".txt"))
                process_img(img, save_path, engine, export_img)
            except Exception as e:
                # Best-effort batch processing: log and keep going.
                print('[ERROR]: ', e, ' at ', simg_path)
                continue
            ddata["img_path"].append(simg_path)
            ddata["ocr_path"].append(save_path)
            ddata["label"].append(dir_path.stem)
    return ddata
|
||||
|
||||
|
||||
def process_csv(csv_path: str, engine: OcrEngine) -> None:
    """OCR every row of a csv that has ``image_path`` and ``ocr_path`` columns."""
    df = pd.read_csv(csv_path)
    if 'image_path' not in df.columns or 'ocr_path' not in df.columns:
        # BUGFIX: message typo ("fing") and it now names both required columns.
        raise AssertionError('Cannot find image_path/ocr_path in df headers')
    # BUGFIX: iterrows() yields (index, row) pairs — the old loop treated the
    # tuple as the row. Also process_img requires export_img; csv mode never
    # exported visualizations, so pass False explicitly.
    for _, row in df.iterrows():
        process_img(row.image_path, row.ocr_path, engine, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = get_args()
|
||||
engine = load_engine(opt)
|
||||
print("[INFO]: OCR engine settings:", engine.settings)
|
||||
img, save_dir = get_paths_from_opt(opt)
|
||||
|
||||
lskip_dir = []
|
||||
if img.is_dir():
|
||||
ddata = process_dir(img, save_dir, engine, opt.export_img)
|
||||
if opt.export_csv:
|
||||
pd.DataFrame.from_dict(ddata).to_csv(
|
||||
Path(save_dir).joinpath(opt.export_csv))
|
||||
elif img.suffix in ImageReader.supported_ext:
|
||||
process_img(str(img), save_dir, engine, opt.export_img)
|
||||
elif img.suffix == '.csv':
|
||||
print("[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used")
|
||||
process_csv(img, engine)
|
||||
else:
|
||||
raise NotImplementedError('[ERROR]: Unsupported file {}'.format(img))
|
23
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/scripts/run_ocr.sh
Executable file
23
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/scripts/run_ocr.sh
Executable file
@ -0,0 +1,23 @@
|
||||
#bash scripts/run_ocr.sh -i /mnt/ssd1T/hungbnt/DocumentClassification/data/OCR040_043 -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr/OCR040_043 -e out.csv -k "{\"device\":\"cuda:1\"}" -x True
export PYTHONWARNINGS="ignore"

# Flags: -i image/dir, -o output dir, -b base dir, -e export csv name,
#        -x export visualization images, -k JSON kwargs for the engine.
while getopts i:o:b:e:x:k: flag
do
    case "${flag}" in
        i) img=${OPTARG};;
        o) out_dir=${OPTARG};;
        b) base_dir=${OPTARG};;
        e) export_csv=${OPTARG};;
        x) export_img=${OPTARG};;
        k) ocr_kwargs=${OPTARG};;
    esac
done
echo "run.py --image=\"$img\" --save_dir \"$out_dir\" --base_dir \"$base_dir\" --export_csv \"$export_csv\" --export_img \"$export_img\" --ocr_kwargs \"$ocr_kwargs\""

# BUGFIX: -b/base_dir was parsed and echoed but never forwarded to run.py;
# all arguments are now quoted so paths containing spaces survive word
# splitting. base_dir falls back to the current directory when unset, which
# matches run.py's own default.
python run.py \
    --image="$img" \
    --save_dir "$out_dir" \
    --base_dir "${base_dir:-$(pwd)}" \
    --export_csv "$export_csv" \
    --export_img "$export_img" \
    --ocr_kwargs "$ocr_kwargs"
|
||||
|
17
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml
Executable file
17
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml
Executable file
@ -0,0 +1,17 @@
|
||||
detector: "/models/Kie_invoice_ap/wild_receipt_finetune_weights_c_lite.pth"
|
||||
rotator_version: "/models/Kie_invoice_ap/best_bbox_mAP_epoch_30_lite.pth"
|
||||
recog_max_seq_len: 25
|
||||
recognizer: "satrn-lite-general-pretrain-20230106"
|
||||
device: "cuda:0"
|
||||
do_extend_bbox: True
|
||||
margin_bbox: [0, 0.03, 0.02, 0.05]
|
||||
batch_mode: False
|
||||
batch_size: 16
|
||||
auto_rotate: True
|
||||
img_size: [] #[1920,1920] #text det default size: 1280x1280; [] = original size. TODO: fix the deskew code to resize the image only when detecting the angle — the original-size image should be fed to the text detection pipeline so bounding boxes map back to the original size
|
||||
deskew: False #True
|
||||
words_to_lines: {
|
||||
"gradient": 0.6,
|
||||
"max_x_dist": 20,
|
||||
"y_overlap_threshold": 0.5,
|
||||
}
|
453
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/dto.py
Executable file
453
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/dto.py
Executable file
@ -0,0 +1,453 @@
|
||||
import numpy as np
|
||||
from typing import Optional, List
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from .utils import visualize_bbox_and_label
|
||||
|
||||
|
||||
class Box:
    """Axis-aligned bounding box [x1, y1, x2, y2] with optional confidence and label."""

    def __init__(self, x1, y1, x2, y2, conf=-1., label=""):
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2
        self.conf = conf    # detection confidence; -1 means "unknown"
        self.label = label  # optional class/KIE label

    def __repr__(self) -> str:
        return str(self.bbox)

    def __str__(self) -> str:
        return str(self.bbox)

    def get(self, return_confidence=False) -> list:
        """Return [x1, y1, x2, y2], with conf appended when requested."""
        return self.bbox if not return_confidence else self.xyxyc

    def __getitem__(self, key):
        # Allows indexing/unpacking a Box like its coordinate list.
        return self.bbox[key]

    @property
    def width(self):
        return self.x2 - self.x1

    @property
    def height(self):
        return self.y2 - self.y1

    @property
    def bbox(self) -> list:
        return [self.x1, self.y1, self.x2, self.y2]

    @bbox.setter
    def bbox(self, bbox_: list):
        self.x1, self.y1, self.x2, self.y2 = bbox_

    @property
    def xyxyc(self) -> list:
        return [self.x1, self.y1, self.x2, self.y2, self.conf]

    @staticmethod
    def normalize_bbox(bbox: list):
        """Truncate all coordinates to int."""
        return [int(b) for b in bbox]

    def to_int(self):
        """Truncate own coordinates to int; returns self for chaining."""
        self.x1, self.y1, self.x2, self.y2 = self.normalize_bbox([self.x1, self.y1, self.x2, self.y2])
        return self

    @staticmethod
    def clamp_bbox_by_img_wh(bbox: list, width: int, height: int):
        """Clip coordinates into [0, width] x [0, height]."""
        x1, y1, x2, y2 = bbox
        x1 = min(max(0, x1), width)
        x2 = min(max(0, x2), width)
        y1 = min(max(0, y1), height)
        y2 = min(max(0, y2), height)
        return (x1, y1, x2, y2)

    def clamp_by_img_wh(self, width: int, height: int):
        """Clip own coordinates to the image bounds; returns self for chaining."""
        self.x1, self.y1, self.x2, self.y2 = self.clamp_bbox_by_img_wh(
            [self.x1, self.y1, self.x2, self.y2], width, height)
        return self

    @staticmethod
    def extend_bbox(bbox: list, margin: list):
        """Grow the box on each side by margin * (original side length).

        ``margin`` is [left, top, right, bottom] as fractions of the box's
        width/height. BUGFIX: margins are now applied from the ORIGINAL
        dimensions; previously each side was updated sequentially, so the
        bottom/right margins were computed from the already-extended
        height/width and over-extended the box.
        """
        margin_l, margin_t, margin_r, margin_b = margin
        l, t, r, b = bbox  # left, top, right, bottom
        height = b - t
        width = r - l
        return [l - width * margin_l, t - height * margin_t,
                r + width * margin_r, b + height * margin_b]

    def get_extend_bbox(self, margin: list):
        """Return a new extended Box (label preserved, conf reset)."""
        extended_bbox = self.extend_bbox(self.bbox, margin)
        return Box(*extended_bbox, label=self.label)

    @staticmethod
    def bbox_is_valid(bbox: list) -> bool:
        """A box is valid when its area is strictly positive."""
        l, t, r, b = bbox  # left, top, right, bottom
        return True if (b - t) * (r - l) > 0 else False

    def is_valid(self) -> bool:
        return self.bbox_is_valid(self.bbox)

    @staticmethod
    def crop_img_by_bbox(img: np.ndarray, bbox: list) -> np.ndarray:
        """Crop an HxW(xC) array; coordinates must already be ints."""
        l, t, r, b = bbox
        return img[t:b, l:r]

    def crop_img(self, img: np.ndarray) -> np.ndarray:
        return self.crop_img_by_bbox(img, self.bbox)
|
||||
|
||||
|
||||
class Word:
    """A single recognized word with its box, confidences and layout ids."""

    def __init__(
        self,
        image=None,
        text="",
        conf_cls=-1.,
        bndbox: Optional[Box] = None,
        conf_detect=-1.,
        kie_label="",
    ):
        self.type = "word"
        self.text = text
        self.image = image
        self.conf_detect = conf_detect  # detector confidence
        self.conf_cls = conf_cls        # recognizer confidence
        # Box holding [left, top, right, bottom] corner coordinates.
        self.boundingbox = bndbox
        self.word_id = 0         # id of this word
        self.word_group_id = 0   # id of the word_group this word belongs to
        self.line_id = 0         # id of the line this word belongs to
        self.paragraph_id = 0    # id of the paragraph this word belongs to
        self.kie_label = kie_label

    @property
    def bbox(self):
        return self.boundingbox.bbox

    @property
    def height(self):
        return self.boundingbox.height

    @property
    def width(self):
        return self.boundingbox.width

    def __repr__(self) -> str:
        return self.text

    def __str__(self) -> str:
        return self.text

    def invalid_size(self):
        # NOTE(review): despite the name, this returns True when the box area
        # is POSITIVE (i.e. the size looks valid). Callers appear to rely on
        # this inverted meaning, so the behavior is preserved — confirm
        # before renaming.
        return (self.boundingbox[2] - self.boundingbox[0]) * (
            self.boundingbox[3] - self.boundingbox[1]
        ) > 0

    def is_special_word(self):
        """Heuristic: missing text, or a 7+ char digit-heavy word, is 'special'."""
        left, top, right, bottom = self.boundingbox
        width, height = right - left, bottom - top
        text = self.text

        if text is None:
            return True

        # Words of 7+ chars are special when >= 30% of their chars are digits.
        if len(text) >= 7:
            digit_count = sum(ch.isdigit() for ch in text)
            return digit_count / len(text) >= 0.3

        return False
|
||||
|
||||
|
||||
class Word_group:
    """A group of adjacent words forming one logical unit (e.g. a field value)."""

    def __init__(self, list_words_: List[Word] = None,
                 text: str = '', boundingbox: Box = None, conf_cls: float = -1):
        # BUGFIX: the old defaults (one shared list() and one shared Box
        # instance, created at def time) were mutated by every instance;
        # build fresh defaults on each call instead.
        self.type = "word_group"
        self.list_words = list_words_ if list_words_ is not None else []
        self.word_group_id = 0   # word group id
        self.line_id = 0         # id of the line this group belongs to
        self.paragraph_id = 0    # id of the paragraph this group belongs to
        self.text = text
        self.boundingbox = boundingbox if boundingbox is not None else Box(-1, -1, -1, -1)
        self.kie_label = ""
        self.conf_cls = conf_cls

    @property
    def bbox(self):
        return self.boundingbox

    def __repr__(self) -> str:
        return self.text

    def __str__(self) -> str:
        return self.text

    def add_word(self, word: Word):
        """Append ``word`` (id-checked), sync its ids, grow text/bbox.

        Returns True on success; False on an id collision or the "✪"
        placeholder glyph.
        """
        if word.text != "✪":
            for w in self.list_words:
                if word.word_id == w.word_id:
                    print("Word id collision")
                    return False
            word.word_group_id = self.word_group_id
            word.line_id = self.line_id
            word.paragraph_id = self.paragraph_id
            self.list_words.append(word)
            self.text += " " + word.text
            # NOTE(review): a Box instance never compares equal to a list, so
            # this branch only fires once boundingbox has been replaced by a
            # plain list — confirm the intended initial-merge behavior.
            if self.boundingbox == [-1, -1, -1, -1]:
                self.boundingbox = word.boundingbox
            else:
                self.boundingbox = [
                    min(self.boundingbox[0], word.boundingbox[0]),
                    min(self.boundingbox[1], word.boundingbox[1]),
                    max(self.boundingbox[2], word.boundingbox[2]),
                    max(self.boundingbox[3], word.boundingbox[3]),
                ]
            return True
        else:
            return False

    def update_word_group_id(self, new_word_group_id):
        """Propagate a new group id to the group and all member words."""
        self.word_group_id = new_word_group_id
        for i in range(len(self.list_words)):
            self.list_words[i].word_group_id = new_word_group_id

    def update_kie_label(self):
        """Set kie_label to the majority label among member words."""
        votes = dict()
        for word in self.list_words:
            votes[word.kie_label] = votes.get(word.kie_label, 0) + 1
        # First-seen label wins ties, matching the previous index-based pick.
        self.kie_label = max(votes, key=votes.get)

    def update_text(self):
        """Rebuild text after member words were reordered."""
        text = ""
        for word in self.list_words:
            text += " " + word.text
        self.text = text
|
||||
|
||||
|
||||
class Line:
|
||||
def __init__(self, list_word_groups: List[Word_group] = [],
|
||||
text: str = '', boundingbox: Box = Box(-1, -1, -1, -1), conf_cls: float = -1):
|
||||
self.type = "line"
|
||||
self.list_word_groups = list_word_groups # list of Word_group instances in the line
|
||||
self.line_id = 0 # id of line in the paragraph
|
||||
self.paragraph_id = 0 # id of paragraph which instance belongs to
|
||||
self.text = text
|
||||
self.boundingbox = boundingbox
|
||||
self.conf_cls = conf_cls
|
||||
|
||||
@property
|
||||
def bbox(self):
|
||||
return self.boundingbox
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.text
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
def add_group(self, word_group: Word_group): # add a word_group instance
|
||||
if word_group.list_words is not None:
|
||||
for wg in self.list_word_groups:
|
||||
if word_group.word_group_id == wg.word_group_id:
|
||||
print("Word_group id collision")
|
||||
return False
|
||||
|
||||
self.list_word_groups.append(word_group)
|
||||
self.text += word_group.text
|
||||
word_group.paragraph_id = self.paragraph_id
|
||||
word_group.line_id = self.line_id
|
||||
|
||||
for i in range(len(word_group.list_words)):
|
||||
word_group.list_words[
|
||||
i
|
||||
].paragraph_id = self.paragraph_id # set paragraph_id for word
|
||||
word_group.list_words[i].line_id = self.line_id # set line_id for word
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_line_id(self, new_line_id):
|
||||
self.line_id = new_line_id
|
||||
for i in range(len(self.list_word_groups)):
|
||||
self.list_word_groups[i].line_id = new_line_id
|
||||
for j in range(len(self.list_word_groups[i].list_words)):
|
||||
self.list_word_groups[i].list_words[j].line_id = new_line_id
|
||||
|
||||
def merge_word(self, word): # word can be a Word instance or a Word_group instance
|
||||
if word.text != "✪":
|
||||
if self.boundingbox == [-1, -1, -1, -1]:
|
||||
self.boundingbox = word.boundingbox
|
||||
else:
|
||||
self.boundingbox = [
|
||||
min(self.boundingbox[0], word.boundingbox[0]),
|
||||
min(self.boundingbox[1], word.boundingbox[1]),
|
||||
max(self.boundingbox[2], word.boundingbox[2]),
|
||||
max(self.boundingbox[3], word.boundingbox[3]),
|
||||
]
|
||||
self.list_word_groups.append(word)
|
||||
self.text += " " + word.text
|
||||
return True
|
||||
return False
|
||||
|
||||
def __cal_ratio(self, top1, bottom1, top2, bottom2):
|
||||
sorted_vals = sorted([top1, bottom1, top2, bottom2])
|
||||
intersection = sorted_vals[2] - sorted_vals[1]
|
||||
min_height = min(bottom1 - top1, bottom2 - top2)
|
||||
if min_height == 0:
|
||||
return -1
|
||||
ratio = intersection / min_height
|
||||
return ratio
|
||||
|
||||
def __cal_ratio_height(self, top1, bottom1, top2, bottom2):
|
||||
|
||||
height1, height2 = top1 - bottom1, top2 - bottom2
|
||||
ratio_height = float(max(height1, height2)) / float(min(height1, height2))
|
||||
return ratio_height
|
||||
|
||||
def in_same_line(self, input_line, thresh=0.7):
    """Decide whether *input_line* lies on the same text line as this one.

    Combines three conditions: the vertical ranges touch, the overlap ratio
    (against the shorter line) is at least *thresh*, and the heights differ by
    less than a factor of 2.

    Bug fix: the original expression was ``A or B and C and D``, which Python
    parses as ``A or (B and C and D)`` — whenever ``top2 <= top1 <= bottom2``
    the *thresh* and height checks were silently bypassed. Parentheses now
    enforce ``(A or B) and C and D``, matching the parameters' intent.
    """
    # calculate iou in vertical direction
    _, top1, _, bottom1 = self.boundingbox
    _, top2, _, bottom2 = input_line.boundingbox

    ratio = self.__cal_ratio(top1, bottom1, top2, bottom2)
    ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2)

    vertically_touching = (top2 <= top1 <= bottom2) or (top1 <= top2 <= bottom1)
    return vertically_touching and ratio >= thresh and ratio_height < 2
|
||||
|
||||
|
||||
class Paragraph:
    """A paragraph: an ordered collection of Line objects plus merged text and bbox."""

    def __init__(self, id=0, lines=None):
        # all lines belonging to this paragraph
        self.list_lines = lines if lines is not None else []
        # index of this paragraph in the list of paragraphs
        self.paragraph_id = id
        self.text = ""
        self.boundingbox = [-1, -1, -1, -1]

    @property
    def bbox(self):
        return self.boundingbox

    def __repr__(self) -> str:
        return self.text

    def __str__(self) -> str:
        return self.text

    def add_line(self, line: Line):  # add a line instance
        """Append *line*, stamping this paragraph's id on the line and all its contents.

        Returns False on a line-id collision or when the line has no word groups.
        """
        if line.list_word_groups is None:
            return False
        for existing in self.list_lines:
            if line.line_id == existing.line_id:
                print("Line id collision")
                return False
        # set paragraph id for every word group and word in the line
        for group in line.list_word_groups:
            group.paragraph_id = self.paragraph_id
            for word in group.list_words:
                word.paragraph_id = self.paragraph_id
        line.paragraph_id = self.paragraph_id  # set paragraph id for line
        self.list_lines.append(line)  # add line to paragraph
        self.text += " " + line.text
        return True

    def update_paragraph_id(self, new_paragraph_id):
        """Stamp *new_paragraph_id* on this paragraph and every line, word group and word inside."""
        self.paragraph_id = new_paragraph_id
        for line in self.list_lines:
            line.paragraph_id = new_paragraph_id
            for group in line.list_word_groups:
                group.paragraph_id = new_paragraph_id
                for word in group.list_words:
                    word.paragraph_id = new_paragraph_id
        return True
|
||||
|
||||
|
||||
class Page:
    """One OCR'd page: the recognized lines plus the source image."""

    def __init__(self, llines: List[Line], image: np.ndarray) -> None:
        self.__llines = llines
        self.__image = image
        self.__drawed_image = None  # lazy cache for the annotated image

    @property
    def llines(self):
        return self.__llines

    @property
    def image(self):
        return self.__image

    @property
    def PIL_image(self):
        return Image.fromarray(self.__image)

    @property
    def drawed_image(self):
        # Bug fix: the original body assigned ``self`` into the cache and fell
        # through without a return (always yielding None and corrupting the
        # cache). Simply expose the cached visualization; it stays None until
        # visualize_bbox_and_label() has been called.
        return self.__drawed_image

    def visualize_bbox_and_label(self, **kwargs: dict):
        """Draw every word's bbox and text on a copy of the image; result is cached."""
        if self.__drawed_image is not None:
            return self.__drawed_image
        bboxes = list()
        texts = list()
        for line in self.__llines:
            for word_group in line.list_word_groups:
                for word in word_group.list_words:
                    bboxes.append([int(float(b)) for b in word.bbox[:]])
                    texts.append(word.text)
        # delegates to the module-level visualize_bbox_and_label helper
        img = visualize_bbox_and_label(self.__image, bboxes, texts, **kwargs)
        self.__drawed_image = img
        return self.__drawed_image

    def save_img(self, save_path: str, **kwargs: dict) -> None:
        """Write the annotated page image to *save_path*."""
        img = self.visualize_bbox_and_label(**kwargs)
        cv2.imwrite(save_path, img)

    def write_to_file(self, mode: str, save_path: str) -> None:
        """Dump OCR results as tab-separated rows: xmin ymin xmax ymax text.

        :param mode: 'line' writes one row per line, 'word' one row per word;
                     any other value writes nothing.
        :param save_path: output text file path.
        """
        # context manager guarantees the handle is closed even if a write fails
        # (the original used open()/close() and leaked on error)
        with open(save_path, "w+", encoding="utf-8") as f:
            for line in self.__llines:
                if mode == 'line':
                    xmin, ymin, xmax, ymax = line.bbox[:]
                    f.write("{}\t{}\t{}\t{}\t{}\n".format(xmin, ymin, xmax, ymax, line.text))
                elif mode == "word":
                    for word_group in line.list_word_groups:
                        for word in word_group.list_words:
                            xmin, ymin, xmax, ymax = [int(float(b)) for b in word.bbox[:]]
                            f.write("{}\t{}\t{}\t{}\t{}\n".format(xmin, ymin, xmax, ymax, word.text))
|
||||
|
||||
|
||||
class Document:
    """An OCR'd multi-page document: a thin container over a list of Page objects."""

    def __init__(self, lpages: List[Page]) -> None:
        # pages in reading order
        self.lpages = lpages
|
207
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/ocr.py
Executable file
207
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/ocr.py
Executable file
@ -0,0 +1,207 @@
|
||||
from typing import Union, overload, List, Optional, Tuple
|
||||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
import mmcv
|
||||
from sdsvtd import StandaloneYOLOXRunner
|
||||
from sdsvtr import StandaloneSATRNRunner
|
||||
from .utils import ImageReader, chunks, rotate_bbox, Timer
|
||||
# from .utils import jdeskew as deskew
|
||||
# from externals.deskew.sdsv_dewarp import pdeskew as deskew
|
||||
from .utils import deskew, post_process_recog
|
||||
from .dto import Word, Line, Page, Document, Box
|
||||
# from .word_formation import words_to_lines as words_to_lines
|
||||
# from .word_formation import wo rds_to_lines_mmocr as words_to_lines
|
||||
from .word_formation import words_to_lines_tesseract as words_to_lines
|
||||
DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml"
|
||||
|
||||
|
||||
class OcrEngine:
    """Wrapper of text detection (sdsvtd YOLOX) and text recognition (sdsvtr SATRN)."""

    def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs: dict):
        """Load settings and instantiate the detector/recognizer runners.

        :param settings_file: path to default setting file
        :param kwargs: keyword arguments to overwrite the default settings file
        :raises ValueError: if a kwarg is not a known setting key
        """
        with open(settings_file) as f:
            # use safe_load instead load
            self.__settings = yaml.safe_load(f)
        for k, v in kwargs.items():  # overwrite default settings by keyword arguments
            if k not in self.__settings:
                raise ValueError("Invalid setting found in OcrEngine: ", k)
            self.__settings[k] = v

        if "cuda" in self.__settings["device"]:
            if not torch.cuda.is_available():
                print("[WARNING]: CUDA is not available, running with cpu instead")
                self.__settings["device"] = "cpu"
        self._detector = StandaloneYOLOXRunner(
            version=self.__settings["detector"],
            device=self.__settings["device"],
            auto_rotate=self.__settings["auto_rotate"],
            rotator_version=self.__settings["rotator_version"])
        self._recognizer = StandaloneSATRNRunner(
            version=self.__settings["recognizer"],
            return_confident=True, device=self.__settings["device"],
            max_seq_len_overwrite=self.__settings["recog_max_seq_len"]
        )
        # extend the bbox to avoid losing accent marks in Vietnamese;
        # if using OCR for English only, disable it
        self._do_extend_bbox = self.__settings["do_extend_bbox"]
        # margins are (left, top, right, bottom)
        self._margin_bbox = self.__settings["margin_bbox"]
        self._batch_mode = self.__settings["batch_mode"]
        self._batch_size = self.__settings["batch_size"]
        self._deskew = self.__settings["deskew"]
        self._img_size = self.__settings["img_size"]
        self.__version__ = {
            "detector": self.__settings["detector"],
            "recognizer": self.__settings["recognizer"],
        }

    @property
    def version(self):
        return self.__version__

    @property
    def settings(self):
        return self.__settings

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Optionally rescale the image to the configured size and deskew it."""
        img_ = img.copy()
        if self.__settings["img_size"]:
            img_ = mmcv.imrescale(
                img_, tuple(self.__settings["img_size"]),
                return_scale=False, interpolation='bilinear', backend='cv2')
        if self._deskew:
            with Timer("deskew"):
                img_, angle = deskew(img_)
        return img_

    def run_detect(self, img: np.ndarray, return_raw: bool = False) -> Tuple[np.ndarray, Union[List[Box], List[list]]]:
        '''
        run text detection and return list of xyxyc if return_raw is True, otherwise a list of Box
        '''
        pred_det = self._detector(img)
        if self.__settings["auto_rotate"]:
            # with auto_rotate the runner returns (rotated image, predictions)
            img, pred_det = pred_det
        pred_det = pred_det[0]  # only one image at a time
        return (img, pred_det.tolist()) if return_raw else (img, [Box(*xyxyc) for xyxyc in pred_det.tolist()])

    def run_recog(self, imgs: List[np.ndarray]) -> Union[List[str], List[Tuple[str, float]]]:
        """Recognize a batch of cropped word images; returns (text, confidence) pairs."""
        if len(imgs) == 0:
            return list()
        pred_rec = self._recognizer(imgs)
        return [(post_process_recog(word), conf) for word, conf in zip(pred_rec[0], pred_rec[1])]

    def read_img(self, img: str) -> np.ndarray:
        """Read a single image path into a numpy array."""
        return ImageReader.read(img)

    def get_cropped_imgs(self, img: np.ndarray, bboxes: List[Union[Box, list]]) -> Tuple[List[np.ndarray], List[bool]]:
        """Crop every bbox out of *img*.

        :param img: np image
        :param bboxes: list of xyxy lists or Box objects
        :return: (crops for the valid boxes only, validity mask aligned with *bboxes*)
        """
        lcropped_imgs = list()
        mask = list()
        for bbox in bboxes:
            bbox = Box(*bbox) if isinstance(bbox, list) else bbox
            bbox = bbox.get_extend_bbox(
                self._margin_bbox) if self._do_extend_bbox else bbox
            bbox.clamp_by_img_wh(img.shape[1], img.shape[0])
            bbox.to_int()
            if not bbox.is_valid():
                # invalid boxes yield no crop; the mask keeps positions aligned
                mask.append(False)
                continue
            cropped_img = bbox.crop_img(img)
            lcropped_imgs.append(cropped_img)
            mask.append(True)
        return lcropped_imgs, mask

    def read_page(self, img: np.ndarray, bboxes: List[Union[Box, list]]) -> List[Line]:
        """Recognize text inside every valid bbox and group the words into lines."""
        if len(bboxes) == 0:  # no bbox found
            return list()
        with Timer("cropped imgs"):
            lcropped_imgs, mask = self.get_cropped_imgs(img, bboxes)
        with Timer("recog"):
            # batch_mode for efficiency
            pred_recs = self.run_recog(lcropped_imgs)
        with Timer("construct words"):
            lwords = list()
            # Bug fix: pred_recs only covers bboxes whose mask entry is True, so
            # it must be walked with its own cursor. The original indexed
            # pred_recs, mask and bboxes with the same i, which desynchronized
            # text/bbox pairs as soon as any bbox was filtered out.
            rec_idx = 0
            for i, raw_bbox in enumerate(bboxes):
                if not mask[i]:
                    continue
                text, conf_rec = pred_recs[rec_idx]
                rec_idx += 1
                bbox = Box(*raw_bbox) if isinstance(raw_bbox, list) else raw_bbox
                lwords.append(Word(
                    image=img, text=text, conf_cls=conf_rec, bndbox=bbox, conf_detect=bbox.conf))
        with Timer("words to lines"):
            return words_to_lines(
                lwords, **self.__settings["words_to_lines"])[0]

    # https://stackoverflow.com/questions/48127642/incompatible-types-in-assignment-on-union

    @overload
    def __call__(self, img: Union[str, np.ndarray, Image.Image]) -> Page: ...

    @overload
    def __call__(
        self, img: List[Union[str, np.ndarray, Image.Image]]) -> Document: ...

    def __call__(self, img):
        """
        Accept an image or list of them, return ocr result as a page or document
        """
        with Timer("read image"):
            img = ImageReader.read(img)
        if not self._batch_mode:
            if isinstance(img, list):
                if len(img) == 1:
                    img = img[0]  # in case input type is a 1 page pdf
                else:
                    raise AssertionError(
                        "list input can only be used with batch_mode enabled")
            img = self.preprocess(img)
            with Timer("detect"):
                img, bboxes = self.run_detect(img)
            with Timer("read_page"):
                llines = self.read_page(img, bboxes)
            return Page(llines, img)
        else:
            lpages = []
            # chunks to reduce memory footprint
            for imgs in chunks(img, self._batch_size):
                # TEMP: loop image-by-image because sdsvtd does not support
                # batch mode for text detection.
                # Bug fix: the original called preprocess()/run_detect() on the
                # whole list object and then built every Page from that same
                # object instead of the per-image result.
                for one_img in imgs:
                    one_img = self.preprocess(one_img)
                    one_img, bboxes_ = self.run_detect(one_img)
                    llines = self.read_page(one_img, bboxes_)
                    lpages.append(Page(llines, one_img))
            return Document(lpages)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # smoke-test entry point: OCR one local image and print the lines
    img_path = "/mnt/ssd1T/hungbnt/Cello/data/PH/Sea7/Sea_7_1.jpg"
    engine = OcrEngine(device="cuda:0", return_confidence=True)
    # https://stackoverflow.com/questions/66435480/overload-following-optional-argument
    page = engine(img_path)  # type: ignore
    # Bug fix: ``page.__llines`` is name-mangled to ``page._Page__llines``
    # outside the class and raised AttributeError; use the public property.
    print(page.llines)
|
||||
|
346
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/utils.py
Executable file
346
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/utils.py
Executable file
@ -0,0 +1,346 @@
|
||||
from PIL import ImageFont, ImageDraw, Image, ImageOps
|
||||
# import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
import time
|
||||
from typing import Generator, Union, List, overload, Tuple, Callable
|
||||
import glob
|
||||
import math
|
||||
from pathlib import Path
|
||||
from pdf2image import convert_from_path
|
||||
from deskew import determine_skew
|
||||
from jdeskew.estimator import get_angle
|
||||
from jdeskew.utility import rotate as jrotate
|
||||
|
||||
|
||||
def post_process_recog(text: str) -> str:
    """Replace the recognizer's placeholder character with a plain space."""
    return text.replace("✪", " ")
|
||||
|
||||
|
||||
class Timer:
    """Context manager that measures and prints the wall-clock duration of its block."""

    def __init__(self, name: str) -> None:
        # label used in the printed report
        self.name = name

    def __enter__(self):
        # record the start timestamp; perf_counter is monotonic
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, func: Callable, *args):
        # compute and report the elapsed time on block exit
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds")
|
||||
|
||||
|
||||
def rotate(
    image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]]
) -> np.ndarray:
    """Rotate *image* by *angle* degrees, enlarging the canvas so nothing is cropped.

    :param image: source image as a numpy array (H x W [x C])
    :param angle: rotation angle in degrees
    :param background: fill value for the uncovered border pixels
    :return: the rotated image on an enlarged canvas
    """
    # NOTE(review): image.shape[:2] is (height, width), so these two names hold
    # swapped values; ``width``/``height`` below inherit the swap and the
    # warpAffine dsize receives (height, width). The swaps appear to cancel
    # out overall — confirm before renaming or "fixing" anything here.
    old_width, old_height = image.shape[:2]
    angle_radian = math.radians(angle)
    # bounding size of the rotated content
    width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width)
    height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height)
    # shape[1::-1] is (width, height), i.e. the (x, y) center OpenCV expects
    image_center = tuple(np.array(image.shape[1::-1]) / 2)
    rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
    # translate so the rotated content stays centered on the enlarged canvas
    rot_mat[1, 2] += (width - old_width) / 2
    rot_mat[0, 2] += (height - old_height) / 2
    return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background)
|
||||
|
||||
|
||||
# def rotate_bbox(bbox: list, angle: float) -> list:
|
||||
# # Compute the center point of the bounding box
|
||||
# cx = bbox[0] + bbox[2] / 2
|
||||
# cy = bbox[1] + bbox[3] / 2
|
||||
|
||||
# # Define the scale factor for the rotated bounding box
|
||||
# scale = 1.0 # following the deskew and jdeskew function
|
||||
# angle_radian = math.radians(angle)
|
||||
|
||||
# # Obtain the rotation matrix using cv2.getRotationMatrix2D()
|
||||
# M = cv2.getRotationMatrix2D((cx, cy), angle_radian, scale)
|
||||
|
||||
# # Apply the rotation matrix to the four corners of the bounding box
|
||||
# corners = np.array([[bbox[0], bbox[1]],
|
||||
# [bbox[0] + bbox[2], bbox[1]],
|
||||
# [bbox[0] + bbox[2], bbox[1] + bbox[3]],
|
||||
# [bbox[0], bbox[1] + bbox[3]]], dtype=np.float32)
|
||||
# rotated_corners = cv2.transform(np.array([corners]), M)[0]
|
||||
|
||||
# # Compute the bounding box of the rotated corners
|
||||
# x = int(np.min(rotated_corners[:, 0]))
|
||||
# y = int(np.min(rotated_corners[:, 1]))
|
||||
# w = int(np.max(rotated_corners[:, 0]) - np.min(rotated_corners[:, 0]))
|
||||
# h = int(np.max(rotated_corners[:, 1]) - np.min(rotated_corners[:, 1]))
|
||||
# rotated_bbox = [x, y, w, h]
|
||||
|
||||
# return rotated_bbox
|
||||
|
||||
def rotate_bbox(bbox: List[int], angle: float, old_shape: Tuple[int, int]) -> List[int]:
    """Rotate an axis-aligned xyxy bbox by *angle* degrees around the image center.

    https://medium.com/@pokomaru/image-and-bounding-box-rotation-using-opencv-python-2def6c39453

    :param bbox: [xmin, ymin, xmax, ymax]
    :param angle: rotation angle in degrees
    :param old_shape: (height, width) of the image before rotation
    :return: the rotated corners reformatted back to an xyxy list
    """
    # expand xyxy into the four corner points
    bbox_ = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]]
    h, w = old_shape
    cx, cy = (int(w / 2), int(h / 2))

    bbox_tuple = [
        (bbox_[0], bbox_[1]),
        (bbox_[2], bbox_[3]),
        (bbox_[4], bbox_[5]),
        (bbox_[6], bbox_[7]),
    ]  # put x and y coordinates in tuples; each corner is rotated below

    # Perf fix: the rotation matrix, cos/sin and new canvas size are identical
    # for every corner — the original recomputed all of them inside the loop.
    M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
    cos, sin = abs(M[0, 0]), abs(M[0, 1])
    newW = int((h * sin) + (w * cos))
    newH = int((h * cos) + (w * sin))
    M[0, 2] += (newW / 2) - cx
    M[1, 2] += (newH / 2) - cy

    rotated_bbox = []
    for coord in bbox_tuple:
        v = [coord[0], coord[1], 1]
        adjusted_coord = np.dot(M, v)
        rotated_bbox.append((adjusted_coord[0], adjusted_coord[1]))
    result = [int(x) for t in rotated_bbox for x in t]
    return [result[i] for i in [0, 1, 2, -1]]  # reformat to xyxy
|
||||
|
||||
|
||||
def deskew(image: np.ndarray) -> Tuple[np.ndarray, float]:
    """Estimate the skew of *image* with `determine_skew` and return (deskewed image, angle).

    Estimation failures are swallowed and the image is returned unrotated with angle 0.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    angle = 0.
    try:
        angle = determine_skew(grayscale)
    except Exception:
        # best-effort: fall back to no rotation on any estimator failure
        pass
    if not angle:
        return image, angle
    return rotate(image, angle, (0, 0, 0)), angle
|
||||
|
||||
|
||||
def jdeskew(image: np.ndarray) -> Tuple[np.ndarray, float]:
    """Deskew *image* using the jdeskew estimator; return (rotated image, angle).

    Estimation failures are swallowed and the image is returned unrotated with angle 0.
    """
    angle = 0.
    try:
        angle = get_angle(image)
    except Exception:
        # best-effort: keep the original image on any estimator failure
        pass
    # TODO: change resize = True and scale the bounding box
    if not angle:
        return image, angle
    return jrotate(image, angle, resize=False), angle
|
||||
|
||||
|
||||
class ImageReader:
    """
    accept anything, return numpy array image

    Dispatches on input type: path (image or .pdf), directory, PIL image,
    numpy array, or a list of any of these.
    """
    # extensions accepted by validate_img_path (lower-cased suffix match)
    supported_ext = [".png", ".jpg", ".jpeg", ".pdf", ".gif"]

    @staticmethod
    def validate_img_path(img_path: str) -> None:
        """Raise if *img_path* is missing, a directory, or has an unsupported extension."""
        if not os.path.exists(img_path):
            raise FileNotFoundError(img_path)
        if os.path.isdir(img_path):
            raise IsADirectoryError(img_path)
        if not Path(img_path).suffix.lower() in ImageReader.supported_ext:
            raise NotImplementedError("Not supported extension at {}".format(img_path))

    @overload
    @staticmethod
    def read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray: ...

    @overload
    @staticmethod
    def read(img: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: ...

    @overload
    @staticmethod
    def read(img: str) -> List[np.ndarray]: ...  # for pdf or directory

    @staticmethod
    def read(img):
        """Dispatch to the matching from_* reader based on the runtime type of *img*."""
        if isinstance(img, list):
            return ImageReader.from_list(img)
        elif isinstance(img, str) and os.path.isdir(img):
            return ImageReader.from_dir(img)
        elif isinstance(img, str) and img.endswith(".pdf"):
            return ImageReader.from_pdf(img)
        else:
            return ImageReader._read(img)

    @staticmethod
    def from_dir(dir_path: str) -> List[np.ndarray]:
        """Read every file directly inside *dir_path* (non-recursive glob)."""
        if os.path.isdir(dir_path):
            image_files = glob.glob(os.path.join(dir_path, "*"))
            return ImageReader.from_list(image_files)
        else:
            raise NotADirectoryError(dir_path)

    @staticmethod
    def from_str(img_path: str) -> np.ndarray:
        """Validate then open a single image path via PIL."""
        ImageReader.validate_img_path(img_path)
        return ImageReader.from_PIL(Image.open(img_path))

    @staticmethod
    def from_np(img_array: np.ndarray) -> np.ndarray:
        # already a numpy image: pass through unchanged
        return img_array

    @staticmethod
    def from_PIL(img_pil: Image.Image, transpose=True) -> np.ndarray:
        """Convert a PIL image to an RGB numpy array, honoring EXIF orientation by default."""
        # if img_pil.is_animated:
        #     raise NotImplementedError("Only static images are supported, animated image found")
        if transpose:
            # apply the EXIF orientation tag so the pixels match how the image is viewed
            img_pil = ImageOps.exif_transpose(img_pil)
        if img_pil.mode != "RGB":
            img_pil = img_pil.convert("RGB")

        return np.array(img_pil)

    @staticmethod
    def from_list(img_list: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]:
        """Read every entry of *img_list*, skipping (with a log) the unreadable ones."""
        limgs = list()
        for img_path in img_list:
            try:
                if isinstance(img_path, str):
                    ImageReader.validate_img_path(img_path)
                limgs.append(ImageReader._read(img_path))
            except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e:
                # best-effort batch read: report and continue with the rest
                print("[ERROR]: ", e)
                print("[INFO]: Skipping image {}".format(img_path))
        return limgs

    @staticmethod
    def from_pdf(pdf_path: str, start_page: int = 0, end_page: int = 0) -> List[np.ndarray]:
        """Convert pdf pages [start_page, end_page] to numpy images.

        NOTE(review): with the defaults (start_page=0, end_page=0) the slice
        becomes [0:1], i.e. only the FIRST page is returned — confirm this is
        the intended default before relying on it for multi-page PDFs.
        """
        pdf_file = convert_from_path(pdf_path)
        if end_page is not None:
            # end_page is inclusive; clamp to the document length
            end_page = min(len(pdf_file), end_page + 1)
        limgs = [np.array(pdf_page) for pdf_page in pdf_file[start_page:end_page]]
        return limgs

    @staticmethod
    def _read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray:
        """Convert a single path / PIL image / numpy array into a numpy image."""
        if isinstance(img, str):
            return ImageReader.from_str(img)
        elif isinstance(img, Image.Image):
            return ImageReader.from_PIL(img)
        elif isinstance(img, np.ndarray):
            return ImageReader.from_np(img)
        else:
            raise ValueError("Invalid img argument type: ", type(img))
|
||||
|
||||
|
||||
def get_name(file_path, ext: bool = True):
    """Return the basename of *file_path*, with (ext=True) or without its extension."""
    base = os.path.basename(file_path)
    if ext:
        return base
    return os.path.splitext(base)[0]
|
||||
|
||||
|
||||
def construct_file_path(dir, file_path, ext=''):
    '''
    Join *dir* with the basename of *file_path*.

    args:
        dir: /path/to/dir
        file_path /example_path/to/file.txt
        ext = '.json'
    return
        /path/to/dir/file.json  (original extension kept when ext == '')
    '''
    if ext == '':
        # keep the original extension
        return os.path.join(dir, get_name(file_path, True))
    # swap the extension for *ext*
    return os.path.join(dir, get_name(file_path, False)) + ext
|
||||
|
||||
|
||||
def chunks(lst: list, n: int) -> Generator:
    """
    Yield successive n-sized chunks from lst.
    https://stackoverflow.com/questions/312443/how-do-i-split-a-list-into-equally-sized-chunks
    """
    start = 0
    while start < len(lst):
        yield lst[start:start + n]
        start += n
|
||||
|
||||
|
||||
def read_ocr_result_from_txt(file_path: str) -> Tuple[list, list]:
    '''
    Parse a TSV OCR dump (x1 y1 x2 y2 text per row).

    return list of bounding boxes, list of words; rows with empty or
    single-space text are dropped.
    '''
    with open(file_path, 'r') as f:
        rows = f.read().splitlines()
    boxes, words = [], []
    for row in rows:
        if not row:
            continue
        x1, y1, x2, y2, text = row.split("\t")
        # coordinates are converted before the text filter, matching the
        # original's behavior of raising on malformed numbers
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        if text and text != " ":
            words.append(text)
            boxes.append((x1, y1, x2, y2))
    return boxes, words
|
||||
|
||||
|
||||
def get_xyxywh_base_on_format(bbox, format):
    """Normalize *bbox*, given its *format* ('xywh' or 'xyxy'), to (x1, y1, x2, y2, w, h)."""
    if format == "xyxy":
        x1, y1, x2, y2 = bbox
        w, h = x2 - x1, y2 - y1
    elif format == "xywh":
        x1, y1, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        x2, y2 = x1 + w, y1 + h
    else:
        raise NotImplementedError("Invalid format {}".format(format))
    return (x1, y1, x2, y2, w, h)
|
||||
|
||||
|
||||
def get_dynamic_params_for_bbox_of_label(text, x1, y1, w, h, img_h, img_w, font):
    """Derive drawing parameters (font scale, thickness, text height, label background
    box corners) for rendering *text* just above a bbox of size w x h."""
    scale_base = img_h / (img_w + img_h)
    # shrink the font for wide-and-short boxes, grow it for tall ones
    font_scale = w / (w + h) * scale_base
    thickness = int(scale_base) + 1
    (text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0]
    text_offset_x = x1
    text_offset_y = y1 - thickness
    # background rectangle sized to the rendered text, anchored at the bbox top-left
    box_coords = ((text_offset_x, text_offset_y + 1), (text_offset_x + text_width - 2, text_offset_y - text_height - 2))
    return (font_scale, thickness, text_height, box_coords)
|
||||
|
||||
|
||||
def visualize_bbox_and_label(
        img, bboxes, texts, bbox_color=(200, 180, 60),
        text_color=(0, 0, 0),
        format="xyxy", is_vnese=False, draw_text=True):
    """Draw each bbox and its text label on *img*.

    :param img: numpy image (converted to PIL internally when is_vnese)
    :param bboxes: list of boxes in *format* ('xyxy' or 'xywh')
    :param texts: one label string per box
    :param is_vnese: use PIL + a TTF font so Vietnamese diacritics render correctly
    :param draw_text: also draw the filled label background and text, not just the bbox
    :return: numpy array when the input was numpy and is_vnese, else the drawn image
    """
    ori_img_type = type(img)
    if is_vnese:
        img = Image.fromarray(img) if ori_img_type is np.ndarray else img
        draw = ImageDraw.Draw(img)
        img_w, img_h = img.size
        # NOTE(review): relative font path — resolution depends on the CWD; confirm.
        font_pil_str = "fonts/arial.ttf"
        font_cv2 = cv2.FONT_HERSHEY_SIMPLEX
    else:
        img_h, img_w = img.shape[0], img.shape[1]
        font_cv2 = cv2.FONT_HERSHEY_SIMPLEX
    for i in range(len(bboxes)):
        text = texts[i]  # text = "{}: {:.0f}%".format(LABELS[classIDs[i]], confidences[i]*100)
        x1, y1, x2, y2, w, h = get_xyxywh_base_on_format(bboxes[i], format)
        font_scale, thickness, text_height, box_coords = get_dynamic_params_for_bbox_of_label(
            text, x1, y1, w, h, img_h, img_w, font=font_cv2)
        # Both branches prepare (args, kwargs) for the same three draw calls
        # below, so the drawing code is shared between the PIL and cv2 paths.
        if is_vnese:
            font_pil = ImageFont.truetype(font_pil_str, size=text_height)  # type: ignore
            fdraw_text = draw.text  # type: ignore
            fdraw_bbox = draw.rectangle  # type: ignore
            # Pil use different coordinate => y = y+thickness = y-thickness + 2*thickness
            arg_text = ((box_coords[0][0], box_coords[1][1]), text)
            kwarg_text = {"font": font_pil, "fill": text_color, "width": thickness}
            arg_rec = ((x1, y1, x2, y2),)
            kwarg_rec = {"outline": bbox_color, "width": thickness}
            arg_rec_text = ((box_coords[0], box_coords[1]),)
            kwarg_rec_text = {"fill": bbox_color, "width": thickness}
        else:
            # cv2.rectangle(img, box_coords[0], box_coords[1], color, cv2.FILLED)
            # cv2.putText(img, text, (text_offset_x, text_offset_y), font, fontScale=font_scale, color=(50, 0,0), thickness=thickness)
            # cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
            fdraw_text = cv2.putText
            fdraw_bbox = cv2.rectangle
            arg_text = (img, text, box_coords[0])
            kwarg_text = {"fontFace": font_cv2, "fontScale": font_scale, "color": text_color, "thickness": thickness}
            arg_rec = (img, (x1, y1), (x2, y2))
            kwarg_rec = {"color": bbox_color, "thickness": thickness}
            arg_rec_text = (img, box_coords[0], box_coords[1])
            kwarg_rec_text = {"color": bbox_color, "thickness": cv2.FILLED}
        # draw a bounding box rectangle and label on the img
        fdraw_bbox(*arg_rec, **kwarg_rec)  # type: ignore
        if draw_text:
            fdraw_bbox(*arg_rec_text, **kwarg_rec_text)  # type: ignore
            fdraw_text(*arg_text, **kwarg_text)  # type: ignore # text have to put in front of rec_text
    return np.array(img) if ori_img_type is np.ndarray and is_vnese else img
|
673
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/word_formation.py
Executable file
673
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/word_formation.py
Executable file
@ -0,0 +1,673 @@
|
||||
from builtins import dict
|
||||
from .dto import Word, Line, Word_group, Box
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple, Union
|
||||
MIN_IOU_HEIGHT = 0.7
|
||||
MIN_WIDTH_LINE_RATIO = 0.05
|
||||
|
||||
|
||||
def resize_to_original(
    boundingbox, scale
):  # resize coordinates to match size of original image
    """Scale an (left, top, right, bottom) box back to original-image coordinates.

    *scale* is (vertical_ratio, horizontal_ratio): x coordinates are multiplied
    by scale[1], y coordinates by scale[0].
    """
    left, top, right, bottom = boundingbox
    return [left * scale[1], top * scale[0], right * scale[1], bottom * scale[0]]
|
||||
|
||||
|
||||
def check_iomin(word: Word, word_group: Word_group):
    """True when the vertical intersection over the shorter height of *word* and
    *word_group* exceeds 0.7 (i.e. they are well aligned vertically)."""
    word_height = word.boundingbox[3] - word.boundingbox[1]
    group_height = word_group.boundingbox[3] - word_group.boundingbox[1]
    min_height = min(word_height, group_height)
    intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max(
        word.boundingbox[1], word_group.boundingbox[1]
    )
    return intersect / min_height > 0.7
|
||||
|
||||
|
||||
def prepare_line(words):
    """Greedily cluster Word objects into Line objects, then sort lines top-to-bottom."""
    lines = []
    visited = [False] * len(words)
    for id_word, word in enumerate(words):
        if word.invalid_size() == 0:
            continue
        needs_new_line = True
        for line in lines:
            # merge the word into the first matching, not-yet-visited line
            if line.in_same_line(word) and not visited[id_word]:
                line.merge_word(word)
                needs_new_line = False
                visited[id_word] = True

        if needs_new_line:
            fresh_line = Line()
            fresh_line.merge_word(word)
            lines.append(fresh_line)

    # sort line from top to bottom according top coordinate
    lines.sort(key=lambda x: x.boundingbox[1])
    return lines
|
||||
|
||||
|
||||
def __create_word_group(word, word_group_id):
    """Build a fresh Word_group carrying *word_group_id*, seeded with *word*."""
    group = Word_group()
    group.list_words = list()
    group.word_group_id = word_group_id
    group.add_word(word)
    return group
|
||||
|
||||
|
||||
def __sort_line(line):
    """Sort the line's word groups left-to-right by their left x coordinate; return the line."""
    line.list_word_groups.sort(key=lambda group: group.boundingbox[0])
    return line
|
||||
|
||||
|
||||
def __merge_text_for_line(line):
    """Rebuild line.text from its word groups (space-separated; keeps the leading space)."""
    pieces = [""]
    for group in line.list_word_groups:
        pieces.append(group.text)
    line.text = " ".join(pieces)
    return line
|
||||
|
||||
|
||||
def __update_list_word_groups(line, word_group_id, word_id, line_width):
    """Re-segment a (left-to-right sorted) line into word groups.

    A word is merged into the previous group when: the group does not end with
    ':', the horizontal gap relative to *line_width* is below
    MIN_WIDTH_LINE_RATIO, and the vertical-overlap test (check_iomin) passes.
    Otherwise a new group is started.

    :param word_group_id: next globally-unique word-group id to assign
    :param word_id: next globally-unique word id to assign
    :return: (line, advanced word_group_id, advanced word_id) so ids stay
             unique across successive lines
    """

    old_list_word_group = line.list_word_groups
    list_word_groups = []

    # the first word always opens the first group
    inital_word_group = __create_word_group(
        old_list_word_group[0], word_group_id)
    old_list_word_group[0].word_id = word_id
    list_word_groups.append(inital_word_group)
    word_group_id += 1
    word_id += 1

    for word in old_list_word_group[1:]:
        check_word_group = True
        word.word_id = word_id
        word_id += 1

        # merge into the previous group when close enough and vertically aligned
        if (
            (not list_word_groups[-1].text.endswith(":"))
            and (
                (word.boundingbox[0] - list_word_groups[-1].boundingbox[2])
                / line_width
                < MIN_WIDTH_LINE_RATIO
            )
            and check_iomin(word, list_word_groups[-1])
        ):
            list_word_groups[-1].add_word(word)
            check_word_group = False

        if check_word_group:
            # start a new group for this word
            new_word_group = __create_word_group(word, word_group_id)
            list_word_groups.append(new_word_group)
            word_group_id += 1
    line.list_word_groups = list_word_groups
    return line, word_group_id, word_id
|
||||
|
||||
|
||||
def construct_word_groups_in_each_line(lines):
    """Sort, re-text and re-group every non-empty line, assigning fresh
    line / word-group / word ids that are unique across the whole page."""
    next_line_id = 0
    next_group_id = 0
    next_word_id = 0
    for idx, line in enumerate(lines):
        if not line.list_word_groups:
            continue

        # boundingbox is (left, top, right, bottom)
        line_width = line.boundingbox[2] - line.boundingbox[0]  # right - left
        line_width = 1  # TODO: to remove
        line = __sort_line(line)

        # update text for lines after sorting
        line = __merge_text_for_line(line)

        line, next_group_id, next_word_id = __update_list_word_groups(
            line,
            next_group_id,
            next_word_id,
            line_width)
        line.update_line_id(next_line_id)
        next_line_id += 1
        lines[idx] = line
    return lines
|
||||
|
||||
|
||||
def words_to_lines(words, check_special_lines=True):  # words is list of Word instance
    """Group a flat list of Word instances into Line objects.

    Words are ordered top-to-bottom then left-to-right, bucketed into lines
    by prepare_line() and finally clustered into word groups per line.

    Args:
        words: Word instances with a [left, top, right, bottom] boundingbox.
        check_special_lines (bool): accepted for compatibility; unused here.

    Returns:
        tuple: (list of Line objects, number of input words).
    """
    # Sort in place: top coordinate first, then left coordinate.
    words.sort(key=lambda w: (w.boundingbox[1], w.boundingbox[0]))
    word_count = len(words)
    lines = construct_word_groups_in_each_line(prepare_line(words))
    return lines, word_count
|
||||
|
||||
|
||||
def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8):
    """Decide whether two quadrilateral boxes share a text line.

    The boxes count as being on the same line when their vertical extents
    overlap and, if ``min_y_overlap_ratio`` is given, the overlapping segment
    covers at least that fraction of either box's height.

    Args:
        box_a (list), box_b (list): flat polygons whose odd indices are ys.
        min_y_overlap_ratio (float | None): minimum vertical overlap ratio;
            ``None`` accepts any vertical contact.

    Returns:
        The bool flag indicating if they are on the same line.
    """
    upper_min, upper_max = np.min(box_a[1::2]), np.max(box_a[1::2])
    lower_min, lower_max = np.min(box_b[1::2]), np.max(box_b[1::2])

    # Normalise so "upper" is the box that starts higher on the page.
    if upper_min > lower_min:
        upper_min, lower_min = lower_min, upper_min
        upper_max, lower_max = lower_max, upper_max

    if lower_min > upper_max:
        # No vertical contact at all.
        return False
    if min_y_overlap_ratio is None:
        return True

    middle = sorted([lower_min, lower_max, upper_max])
    overlap = middle[1] - middle[0]
    needed_upper = (upper_max - upper_min) * min_y_overlap_ratio
    needed_lower = (lower_max - lower_min) * min_y_overlap_ratio
    return overlap >= needed_upper or overlap >= needed_lower
|
||||
|
||||
|
||||
def merge_bboxes_to_group(bboxes_group, x_sorted_boxes):
    """Merge each group of box indices into a single line-level dict.

    Args:
        bboxes_group: list of index groups into ``x_sorted_boxes``.
        x_sorted_boxes: box dicts with 'box' (8-value polygon), 'text', 'word'.

    Returns:
        list[dict]: one dict per group with joined 'text', the axis-aligned
        enclosing 'box' and the member Words in 'list_words'.
    """
    merged = []
    for group in bboxes_group:
        x_lo = y_lo = float('inf')
        x_hi = y_hi = float('-inf')
        for idx in group:
            quad = x_sorted_boxes[idx]['box']
            x_hi = max(x_hi, np.max(quad[::2]))
            x_lo = min(x_lo, np.min(quad[::2]))
            y_hi = max(y_hi, np.max(quad[1::2]))
            y_lo = min(y_lo, np.min(quad[1::2]))
        merged.append({
            'text': ' '.join(x_sorted_boxes[idx]['text'] for idx in group),
            # Rebuild a rectangle polygon: TL, TR, BR, BL.
            'box': [x_lo, y_lo, x_hi, y_lo, x_hi, y_hi, x_lo, y_hi],
            'list_words': [x_sorted_boxes[idx]['word'] for idx in group],
        })
    return merged
|
||||
|
||||
|
||||
def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.3):
    """Stitch fragmented word boxes into whole text lines.

    Note: part of its logic is inspired by @Johndirr
    (https://github.com/faustomorales/keras-ocr/issues/22)

    Args:
        boxes (list): ocr result dicts with 'box', 'text' and 'word' keys.
        max_x_dist (int): kept for API compatibility; the sub-line split that
            used it is currently disabled.
        min_y_overlap_ratio (float): minimum vertical overlapping ratio
            allowed for any pair of neighbouring boxes in the same line.

    Returns:
        List[dict]: merged line dicts sorted top-to-bottom.
    """
    if len(boxes) <= 1:
        if len(boxes) == 1:
            boxes[0]["list_words"] = [boxes[0]["word"]]
        return boxes

    # Scan boxes left to right; chain onto the current line everything that
    # stays on the same line as the chain's current rightmost member.
    x_sorted_boxes = sorted(boxes, key=lambda b: np.min(b['box'][::2]))
    consumed = set()  # indexes already absorbed into a line
    stitched = []
    for start in range(len(x_sorted_boxes)):
        if start in consumed:
            continue
        chain = [start]
        anchor = start  # rightmost box of the chain so far
        for cand in range(start + 1, len(x_sorted_boxes)):
            if cand in consumed:
                continue
            if is_on_same_line(x_sorted_boxes[anchor]['box'],
                               x_sorted_boxes[cand]['box'], min_y_overlap_ratio):
                chain.append(cand)
                consumed.add(cand)
                anchor = cand
        stitched.extend(merge_bboxes_to_group([chain], x_sorted_boxes))

    stitched.sort(key=lambda b: np.min(b['box'][1::2]))
    return stitched
|
||||
|
||||
# REFERENCE
|
||||
# https://vigneshgig.medium.com/bounding-box-sorting-algorithm-for-text-detection-and-object-detection-from-left-to-right-and-top-cf2c523c8a85
|
||||
# https://huggingface.co/spaces/tomofi/MMOCR/blame/main/mmocr/utils/box_util.py
|
||||
|
||||
|
||||
def words_to_lines_mmocr(words: List[Word], *args) -> Tuple[List[Line], Optional[int]]:
    """Group words into Lines using the MMOCR-style stitching algorithm.

    Each stitched line becomes one Word_group wrapped in one Line, so the
    output matches the shape of words_to_lines().

    Args:
        words: Word instances with 4-value ``bbox`` attributes.
        *args: ignored; kept for signature compatibility.

    Returns:
        tuple: (list of Line objects, None).
    """
    # Expand each 4-value bbox into the 8-value polygon stitching expects.
    quads = [{"box": [w.bbox[0], w.bbox[1], w.bbox[2], w.bbox[1],
                      w.bbox[2], w.bbox[3], w.bbox[0], w.bbox[3]],
              "text": w.text, "word": w} for w in words]
    stitched = stitch_boxes_into_lines(quads)

    # TODO: fix code to return both word group and line
    word_groups = []
    for merged in stitched:
        rect = [merged["box"][i] for i in [0, 1, 2, -1]]
        word_groups.append(Word_group(list_words_=merged["list_words"],
                                      text=merged["text"],
                                      boundingbox=rect))

    llines = [Line(text=group.text,
                   list_word_groups=[group],
                   boundingbox=group.boundingbox)
              for group in word_groups]
    return llines, None  # same format with the origin words_to_lines
|
||||
# lines = [Line() for merged]
|
||||
|
||||
|
||||
# def most_overlapping_row(rows, top, bottom, y_shift):
|
||||
# max_overlap = -1
|
||||
# max_overlap_idx = -1
|
||||
# for i, row in enumerate(rows):
|
||||
# row_top, row_bottom = row
|
||||
# overlap = min(top + y_shift, row_top) - max(bottom + y_shift, row_bottom)
|
||||
# if overlap > max_overlap:
|
||||
# max_overlap = overlap
|
||||
# max_overlap_idx = i
|
||||
# return max_overlap_idx
|
||||
def most_overlapping_row(rows, row_words, top, bottom, y_shift, max_row_size, y_overlap_threshold=0.5):
    """Find the row that best overlaps a word's vertical span [bottom, top].

    Coordinates follow the caller's convention of top > bottom (larger y is
    "top").  ``y_shift[i]`` is the running skew estimate for row i and is
    subtracted from the word's span before comparing.  As a side effect, rows
    that all overlap the word acceptably may be merged in place — both
    ``rows`` and ``row_words`` are mutated.

    Args:
        rows: list of (top, bottom) spans; mutated in place on merge.
        row_words: per-row word lists, kept parallel with ``rows``.
        top, bottom: vertical extent of the word being placed.
        y_shift: per-row running vertical shift (skew compensation).
        max_row_size: tallest word seen so far; scales overlap thresholds.
        y_overlap_threshold (float): fraction of ``max_row_size`` an overlap
            must reach to be accepted.

    Returns:
        int: index of the matching (possibly merged) row, or -1 if none.
    """
    max_overlap = -1
    max_overlap_idx = -1
    overlapping_rows = []

    for i, row in enumerate(rows):
        row_top, row_bottom = row
        # Signed overlap between the shift-corrected word span and this row.
        overlap = min(top - y_shift[i], row_top) - \
            max(bottom - y_shift[i], row_bottom)

        if overlap > max_overlap:
            max_overlap = overlap
            max_overlap_idx = i

        # if at least overlap 1 pixel and not (overlap too much and overlap too little)
        if (row_bottom <= top and row_top >= bottom) and not (top - bottom - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold):
            overlapping_rows.append(i)

    # Merge overlapping rows if necessary
    if len(overlapping_rows) > 1:
        merge_top = max(rows[i][0] for i in overlapping_rows)
        merge_bottom = min(rows[i][1] for i in overlapping_rows)

        if merge_top - merge_bottom <= max_row_size:
            # Merge rows
            merged_row = (merge_top, merge_bottom)
            merged_words = []
            # Remove other overlapping rows

            # Delete from the end so the earlier indices stay valid.
            for row_idx in overlapping_rows[:0:-1]:  # [1,2,3] -> 3,2
                merged_words.extend(row_words[row_idx])
                del rows[row_idx]
                del row_words[row_idx]

            rows[overlapping_rows[0]] = merged_row
            row_words[overlapping_rows[0]].extend(merged_words[::-1])
            max_overlap_idx = overlapping_rows[0]

    # Reject the best match when its overlap is too small in absolute terms
    # AND leaves too much of the word uncovered.
    if top - bottom - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold:
        max_overlap_idx = -1
    return max_overlap_idx
|
||||
|
||||
|
||||
def stitch_boxes_into_lines_tesseract(words: list[Word],
                                      gradient: float, y_overlap_threshold: float) -> Tuple[list[list[Word]],
                                                                                            float]:
    """Cluster words into horizontal rows while tracking page skew.

    Words are scanned left to right.  Each word is either appended to the
    most-overlapping existing row (see most_overlapping_row, which may merge
    rows in place) or starts a new row.  Every row keeps an exponential
    moving average of the vertical shift of its members, blended with factor
    ``gradient``; the mean of those shifts approximates the page skew.

    Args:
        words: Word instances with ``bbox`` [x1, y1, x2, y2] and ``height``.
        gradient (float): EMA smoothing factor for the per-row y-shift.
        y_overlap_threshold (float): overlap acceptance ratio forwarded to
            most_overlapping_row().

    Returns:
        tuple: (rows of Words sorted by top coordinate, mean y-shift used as
        a page-skew estimate).

    NOTE(review): raises IndexError / ZeroDivisionError on an empty ``words``
    list — callers appear to guarantee at least one word; confirm.
    """
    sorted_words = sorted(words, key=lambda x: x.bbox[0])
    rows = []
    row_words = []
    max_row_size = sorted_words[0].height
    running_y_shift = []
    for _i, word in enumerate(sorted_words):
        bbox, _text = word.bbox[:], word.text
        _x1, y1, _x2, y2 = bbox
        # Caller convention: "top" is the larger y (y2), "bottom" is y1.
        top, bottom = y2, y1
        max_row_size = max(max_row_size, top - bottom)
        overlap_row_idx = most_overlapping_row(
            rows, row_words, top, bottom, running_y_shift, max_row_size, y_overlap_threshold)

        if overlap_row_idx == -1:  # No overlapping row found
            new_row = (top, bottom)
            rows.append(new_row)
            row_words.append([word])
            running_y_shift.append(0)
        else:  # Overlapping row found
            # Shrink the row span to the intersection and fold the word's
            # centre offset into the row's running skew estimate.
            row_top, row_bottom = rows[overlap_row_idx]
            new_top = max(row_top, top)
            new_bottom = min(row_bottom, bottom)
            rows[overlap_row_idx] = (new_top, new_bottom)
            row_words[overlap_row_idx].append(word)
            new_shift = (bottom + top) / 2 - (row_bottom + row_top) / 2
            running_y_shift[overlap_row_idx] = gradient * \
                running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift

    # Sort rows and row_texts based on the top y-coordinate
    sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][0])
    _sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data)
    # /_|<- the perpendicular line of the horizontal line and the skew line of the page
    page_skew_dist = sum(running_y_shift) / len(running_y_shift)
    return sorted_row_words, page_skew_dist
|
||||
|
||||
|
||||
def construct_word_groups_tesseract(sorted_row_words: list[list[Word]],
                                    max_x_dist: int, page_skew_dist: float) -> list[list[list[Word]]]:
    """Split each row of words into word groups at large horizontal gaps.

    Args:
        sorted_row_words: rows of Words, each row sorted left to right.
        max_x_dist: maximum horizontal gap (pixels) inside one group.
        page_skew_dist: page-skew estimate used to shrink ``max_x_dist``.

    Returns:
        Per row, a list of word groups (lists of Words).
    """
    # approximate page_skew_angle by page_skew_dist
    corrected_max_x_dist = max_x_dist * abs(np.cos(page_skew_dist / 180 * 3.14))
    grouped_rows = []
    for row in sorted_row_words:
        groups = [[row[0]]]
        for prev, curr in zip(row, row[1:]):
            # Gap between the previous word's right edge and this word's left
            # edge decides whether a new group starts.
            if curr.bbox[0] - prev.bbox[2] > corrected_max_x_dist:
                groups.append([])
            groups[-1].append(curr)
        grouped_rows.append(groups)
    return grouped_rows
|
||||
|
||||
|
||||
def group_bbox_and_text(lwords: list[Word]) -> tuple[Box, tuple[str, float]]:
    """Fuse a list of words into one enclosing Box plus their joined text.

    Detection confidences are averaged into the Box's ``conf``; the averaged
    classification confidence is returned alongside the text.

    Args:
        lwords: non-empty list of Words with ``bbox``, ``text``,
            ``conf_detect`` and ``conf_cls``.

    Returns:
        tuple: (Box, (joined text, mean classification confidence)).
    """
    text = ' '.join(word.text for word in lwords)

    x_lo = y_lo = float('inf')
    x_hi = y_hi = float('-inf')
    det_total = 0
    cls_total = 0
    for word in lwords:
        x_hi = max(np.max(word.bbox[::2]), x_hi)
        x_lo = min(np.min(word.bbox[::2]), x_lo)
        y_hi = max(np.max(word.bbox[1::2]), y_hi)
        y_lo = min(np.min(word.bbox[1::2]), y_lo)
        det_total += word.conf_detect
        cls_total += word.conf_cls

    box = Box(x_lo, y_lo, x_hi, y_hi, conf=det_total / len(lwords))
    return box, (text, cls_total / len(lwords))
|
||||
|
||||
|
||||
def words_to_lines_tesseract(words: List[Word],
                             gradient: float, max_x_dist: int, y_overlap_threshold: float) -> Tuple[List[Line],
                                                                                                    Optional[int]]:
    """Group words into Line objects using the tesseract-style row stitcher.

    Rows come from stitch_boxes_into_lines_tesseract(); each row is then cut
    into word groups at large horizontal gaps and wrapped into Line objects.

    Args:
        words: detected Word instances.
        gradient: EMA factor for the per-row skew estimate.
        max_x_dist: maximum in-group horizontal gap.
        y_overlap_threshold: row-membership overlap ratio.

    Returns:
        tuple: (list of Line objects, None) — same shape as words_to_lines().
    """
    rows, page_skew_dist = stitch_boxes_into_lines_tesseract(
        words, gradient, y_overlap_threshold)
    grouped_rows = construct_word_groups_tesseract(
        rows, max_x_dist, page_skew_dist)

    llines = []
    for row in grouped_rows:
        row_members = []
        groups = []
        for group_words in row:
            group_box, (group_text, group_conf) = group_bbox_and_text(group_words)
            row_members.extend(group_words)
            groups.append(
                Word_group(
                    list_words_=group_words, text=group_text,
                    conf_cls=group_conf,
                    boundingbox=group_box))
        line_box, (line_text, line_conf) = group_bbox_and_text(row_members)
        llines.append(
            Line(
                list_word_groups=groups, text=line_text,
                boundingbox=line_box, conf_cls=line_conf))
    return llines, None
|
||||
|
||||
|
||||
def near(word_group1: "Word_group", word_group2: "Word_group"):
    """Heuristically decide whether two word groups are vertically close.

    Returns True when their vertical extents overlap, or when the gap
    between them is smaller than 1.5x the shorter group's height.

    Args:
        word_group1, word_group2: objects with a [x1, y1, x2, y2]
            ``boundingbox`` attribute.

    Returns:
        bool: True when the groups overlap or are near enough vertically.
    """
    min_height = min(
        word_group1.boundingbox[3] - word_group1.boundingbox[1],
        word_group2.boundingbox[3] - word_group2.boundingbox[1],
    )
    # Signed vertical overlap: positive means the boxes intersect; a negative
    # value's magnitude is the size of the gap between them.
    overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max(
        word_group1.boundingbox[1], word_group2.boundingbox[1]
    )

    if overlap > 0:
        return True
    # Fix: removed a leftover debug print() that spammed stdout on every
    # near-miss comparison.
    if abs(overlap / min_height) < 1.5:
        return True
    return False
|
||||
|
||||
|
||||
def calculate_iou_and_near(wg1: "Word_group", wg2: "Word_group"):
    """True when two word groups overlap strongly and sit close horizontally.

    "Strong overlap" means the vertical intersection exceeds 70% of the
    shorter group's height; "close" means the smaller edge-to-edge horizontal
    distance is under half of wg1's width.

    Args:
        wg1, wg2: objects with a [x1, y1, x2, y2] ``boundingbox`` attribute.

    Returns:
        bool
    """
    min_height = min(
        wg1.boundingbox[3] -
        wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1]
    )
    overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max(
        wg1.boundingbox[1], wg2.boundingbox[1]
    )
    iou = overlap / min_height
    distance = min(
        abs(wg1.boundingbox[0] - wg2.boundingbox[2]),
        abs(wg1.boundingbox[2] - wg2.boundingbox[0]),
    )
    # Fix: ``wg1.boundingboxp`` was a typo that raised AttributeError whenever
    # iou > 0.7 — i.e. exactly in the case this function is meant to detect.
    if iou > 0.7 and distance < 0.5 * (wg1.boundingbox[2] - wg1.boundingbox[0]):
        return True
    return False
|
||||
|
||||
|
||||
def construct_word_groups_to_kie_label(list_word_groups: list):
    """Bucket word groups by KIE label, dropping "other" and sorting buckets.

    Buckets with more than one member are ordered top-to-bottom by their
    bounding box's top coordinate.

    Args:
        list_word_groups: objects with ``kie_label`` and ``boundingbox``.

    Returns:
        dict: kie_label -> list of word groups.
    """
    grouped = {}
    for word_group in list_word_groups:
        label = word_group.kie_label
        if label == "other":
            continue
        grouped.setdefault(label, []).append(word_group)

    result = {}
    for label, members in grouped.items():
        if len(members) > 1:
            members.sort(key=lambda wg: wg.boundingbox[1])
        result[label] = members
    return result
|
||||
|
||||
|
||||
def invoice_construct_word_groups_to_kie_label(list_word_groups: list):
    """Bucket word groups by their KIE label, dropping the "other" class.

    Unlike construct_word_groups_to_kie_label(), bucket members keep their
    original order and are not re-sorted.

    Args:
        list_word_groups: objects exposing a ``kie_label`` attribute.

    Returns:
        dict: kie_label -> list of word groups.
    """
    kie_dict = {}
    for word_group in list_word_groups:
        if word_group.kie_label == "other":
            continue
        kie_dict.setdefault(word_group.kie_label, []).append(word_group)

    return kie_dict
|
||||
|
||||
|
||||
def postprocess_total_value(kie_dict):
    """Drop KIE entries located below the "total in words" line.

    For every key except ``total_in_words_value``, keep only entries whose
    bottom coordinate does not exceed that of the first
    ``total_in_words_value`` entry.  A key whose filtered list would become
    empty is left untouched.

    Args:
        kie_dict: mapping kie_label -> list of word groups; mutated in place.

    Returns:
        The same dict.
    """
    if "total_in_words_value" not in kie_dict:
        return kie_dict

    cutoff = kie_dict["total_in_words_value"][0].boundingbox[3]
    for key, entries in kie_dict.items():
        if key == "total_in_words_value":
            continue
        kept = [entry for entry in entries if entry.boundingbox[3] <= cutoff]
        if kept:
            kie_dict[key] = kept

    return kie_dict
|
||||
|
||||
|
||||
def postprocess_tax_code_value(kie_dict):
    """Recover buyer tax-code values mislabelled as seller tax-code values.

    Only runs when no ``buyer_tax_code_value`` exists yet.  A seller
    candidate is reassigned when it sits below — or near() — the buyer name
    key/value, or near the buyer address.

    Args:
        kie_dict: mapping kie_label -> list of word groups; mutated in place
            (a ``buyer_tax_code_value`` entry is added).

    Returns:
        The same dict.
    """
    if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict:
        return kie_dict

    reassigned = []
    for candidate in kie_dict["seller_tax_code_value"]:
        belongs_to_buyer = False
        # Same precedence as before: name key first, then name value.
        for anchor_key in ("buyer_name_key", "buyer_name_value"):
            if anchor_key not in kie_dict:
                continue
            anchor = kie_dict[anchor_key][0]
            if candidate.boundingbox[3] > anchor.boundingbox[3] or near(candidate, anchor):
                belongs_to_buyer = True
                break
        if not belongs_to_buyer and "buyer_address_value" in kie_dict:
            belongs_to_buyer = near(kie_dict["buyer_address_value"][0], candidate)
        if belongs_to_buyer:
            reassigned.append(candidate)

    kie_dict["buyer_tax_code_value"] = reassigned
    return kie_dict
|
||||
|
||||
|
||||
def postprocess_tax_code_key(kie_dict):
    """Recover buyer tax-code keys mislabelled as seller tax-code keys.

    Mirror of postprocess_tax_code_value() for the key entries: only runs
    when no ``buyer_tax_code_key`` exists yet, and reassigns seller
    candidates that sit below — or near() — the buyer name key/value, or
    near the buyer address.

    Args:
        kie_dict: mapping kie_label -> list of word groups; mutated in place
            (a ``buyer_tax_code_key`` entry is added).

    Returns:
        The same dict.
    """
    if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict:
        return kie_dict

    reassigned = []
    for candidate in kie_dict["seller_tax_code_key"]:
        belongs_to_buyer = False
        # Same precedence as before: name key first, then name value.
        for anchor_key in ("buyer_name_key", "buyer_name_value"):
            if anchor_key not in kie_dict:
                continue
            anchor = kie_dict[anchor_key][0]
            if candidate.boundingbox[3] > anchor.boundingbox[3] or near(candidate, anchor):
                belongs_to_buyer = True
                break
        if not belongs_to_buyer and "buyer_address_value" in kie_dict:
            belongs_to_buyer = near(kie_dict["buyer_address_value"][0], candidate)
        if belongs_to_buyer:
            reassigned.append(candidate)

    kie_dict["buyer_tax_code_key"] = reassigned
    return kie_dict
|
||||
|
||||
|
||||
def invoice_postprocess(kie_dict: dict):
    """Run all invoice-level KIE post-processing passes in order.

    Drops keys/values below the total-in-words line, then reassigns buyer
    tax-code values and keys that were labelled as seller entries.

    Args:
        kie_dict: mapping kie_label -> list of word groups.

    Returns:
        The post-processed dict.
    """
    for post_step in (postprocess_total_value,
                      postprocess_tax_code_value,
                      postprocess_tax_code_key):
        kie_dict = post_step(kie_dict)
    return kie_dict
|
||||
|
||||
|
||||
def throw_overlapping_words(list_words):
    """Remove words whose box overlaps an already-kept word by more than 70%.

    A word is dropped when its intersection with any kept word exceeds 70%
    of either box's area.  The first word is always kept; original order is
    preserved.

    Args:
        list_words: objects with a [x1, y1, x2, y2] ``boundingbox``.
            Zero-area boxes would still divide by zero, as before.

    Returns:
        list: the de-duplicated words.
    """
    # Fix: the original indexed list_words[0] unconditionally and crashed on
    # an empty input list.
    if not list_words:
        return []

    new_list = [list_words[0]]
    # The original also compared the first word against itself (harmlessly);
    # starting at index 1 is equivalent.
    for word in list_words[1:]:
        area = (word.boundingbox[2] - word.boundingbox[0]) * (
            word.boundingbox[3] - word.boundingbox[1]
        )
        overlap = False
        for word2 in new_list:
            area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * (
                word2.boundingbox[3] - word2.boundingbox[1]
            )
            xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0])
            xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2])
            ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1])
            ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3])
            if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
                continue

            area_intersect = (xmax_intersect - xmin_intersect) * (
                ymax_intersect - ymin_intersect
            )
            if area_intersect / area > 0.7 or area_intersect / area2 > 0.7:
                overlap = True
                break  # one strong overlap is enough to discard the word
        if not overlap:
            new_list.append(word)
    return new_list
|
||||
|
||||
|
||||
def check_iou(box1: "Word", box2: "Box", threshold=0.9):
    """Return True when the IoU of a word box and a Box exceeds ``threshold``.

    Args:
        box1: object with a [x1, y1, x2, y2] ``boundingbox`` attribute.
        box2: object with ``xmin``/``ymin``/``xmax``/``ymax`` attributes.
        threshold (float): IoU cut-off; strictly-greater comparison.

    Returns:
        bool
    """
    area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * (
        box1.boundingbox[3] - box1.boundingbox[1]
    )
    area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin)
    xmin_intersect = max(box1.boundingbox[0], box2.xmin)
    ymin_intersect = max(box1.boundingbox[1], box2.ymin)
    xmax_intersect = min(box1.boundingbox[2], box2.xmax)
    ymax_intersect = min(box1.boundingbox[3], box2.ymax)
    if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
        area_intersect = 0
    else:
        area_intersect = (xmax_intersect - xmin_intersect) * (
            ymax_intersect - ymin_intersect
        )
    union = area1 + area2 - area_intersect
    # Fix: two degenerate (zero-area) boxes used to raise ZeroDivisionError;
    # with no area there is no meaningful overlap.
    if union == 0:
        return False
    iou = area_intersect / union
    return iou > threshold
|
230
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/predictor.py
Executable file
230
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/predictor.py
Executable file
@ -0,0 +1,230 @@
|
||||
from omegaconf import OmegaConf
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
# from functions import get_colormap, visualize
|
||||
import sys
|
||||
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') # TODO: ???????
|
||||
|
||||
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
||||
from model import get_model
|
||||
from utils import load_model_weight
|
||||
|
||||
|
||||
class KVUPredictor:
    """Inference wrapper around the Key-Value Understanding (KVU) network.

    ``mode`` selects how a document is fed to the model:
    0 - a single flat sample (``combined_predict``),
    1 - sliding windows concatenated into one pass (``cat_predict``),
    2 - each sliding window predicted independently,
    3 - whole-document prediction over windows (``doc_predict``).
    """

    def __init__(self, configs, class_names, dummy_idx, mode=0):
        """Load the network and per-mode windowing settings.

        Args:
            configs (dict): must contain 'cfg' (OmegaConf config path) and
                'ckpt' (checkpoint path).
            class_names: field classes the model predicts.
            dummy_idx (int): padding index used when parsing model outputs.
            mode (int): inference strategy, see class docstring.
        """
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']

        self.class_names = class_names
        self.dummy_idx = dummy_idx
        self.mode = mode

        print('[INFO] Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        print("[INFO] Loaded model")

        if mode == 3:
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            # Whole-document mode parses all windows at once, so the padding
            # index scales with the window count.
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words

        # NOTE(review): device is hard-coded — inference requires a GPU.
        self.device = 'cuda'

    def _load_model(self, cfg_path, ckpt_path):
        """Build the network from ``cfg_path``, load ``ckpt_path`` weights.

        Returns:
            tuple: (net in eval mode on cuda, loaded cfg, backbone type).
        """
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone

        print('[INFO] Checkpoint:', ckpt_path)
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to('cuda')
        net.eval()
        return net, cfg, backbone_type

    def predict(self, input_sample):
        """Dispatch ``input_sample`` to the predictor matching ``self.mode``.

        Returns:
            Four parallel lists (bbox, lwords, pr_class_words, pr_relations),
            one entry per window (single-entry lists for modes 0, 1 and 3);
            four empty lists when the sample contains no words.

        Raises:
            ValueError: if ``self.mode`` is not one of 0-3.
        """
        if self.mode == 0:
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 1:
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 2:
            if len(input_sample['windows'][0]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in input_sample['windows']:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
                bbox.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bbox, lwords, pr_class_words, pr_relations

        elif self.mode == 3:
            if len(input_sample["documents"]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        else:
            raise ValueError(
                f"Not supported mode: {self.mode}"
            )

    def doc_predict(self, input_sample):
        """Whole-document prediction over sliding windows (mode 3)."""
        lwords = input_sample['documents']['words']
        # Move every window's tensors to the device, adding a batch dim.
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        # Output parsing works on the document-level tensors only.
        input_sample = input_sample['documents']

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        """Single-sample prediction (modes 0 and 2)."""
        lwords = input_sample['words']
        input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}

        input_sample = {k: v.to(self.device) for k, v in input_sample.items()}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        """Prediction over concatenated sliding windows (mode 1)."""
        lwords = input_sample['documents']['words']

        inputs = []
        for window in input_sample['windows']:
            inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')})
        input_sample['windows'] = inputs

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        # Fix: the exclusion was written as ('embedding_tokens') — a plain
        # string, making ``in`` a substring test — instead of a 1-tuple.
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['documents']['are_box_first_tokens']
        attention_mask = input_sample['documents']['attention_mask_layoutxlm']
        bbox = input_sample['documents']['bbox']

        # Document-level padding index grows with the concatenated length.
        dummy_idx = input_sample['documents']['bbox'].shape[0]

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def get_ground_truth_label(self, ground_truth):
        """Parse a preprocessed ground-truth sample into the same shape as
        the prediction output (bbox, words, class words, relations)."""
        # ground_truth = self.preprocessor.load_ground_truth(json_file)
        gt_itc_label = ground_truth['itc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_stc_label = ground_truth['stc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_el_label = ground_truth['el_labels'].squeeze(0)

        gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
        lwords = ground_truth["words"]

        box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
        attention_mask = ground_truth['attention_mask'].squeeze(0)

        bbox = ground_truth['bbox'].squeeze(0)
        gt_first_words = parse_initial_words(
            gt_itc_label, box_first_token_mask, self.class_names
        )
        gt_class_words = parse_subsequent_words(
            gt_stc_label, attention_mask, gt_first_words, self.dummy_idx
        )

        gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
        gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
        gt_relations = gt_relations_from_header | gt_relations_from_key

        return bbox, lwords, gt_class_words, gt_relations
|
601
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py
Executable file
601
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py
Executable file
@ -0,0 +1,601 @@
|
||||
import os
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import imagesize
|
||||
import itertools
|
||||
from PIL import Image
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
|
||||
from utils.run_ocr import load_ocr_engine, process_img
|
||||
from lightning_modules.utils import sliding_windows
|
||||
|
||||
|
||||
class KVUProcess:
|
||||
    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        """Set up tokenizer, feature-extractor and OCR state for KVU preprocessing.

        Args:
            tokenizer_layoutxlm: LayoutXLM tokenizer, used here to resolve
                the special-token ids cached below.
            feature_extractor: image feature extractor producing pixel values.
            backbone_type: backbone identifier forwarded to preprocessing.
            class_names: KIE class names; their order defines class indices.
            slice_interval: stride between sliding windows.
            window_size: number of words per sliding window.
            run_ocr: 1 to run the bundled OCR engine instead of reading an
                existing OCR result file.
            max_seq_length (int): token budget per window.
            mode (int): preprocessing variant selected in ``__call__``.
        """
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor

        self.max_seq_length = max_seq_length
        self.backbone_type = backbone_type
        self.class_names = class_names

        self.slice_interval = slice_interval
        self.window_size = window_size
        self.run_ocr = run_ocr
        self.mode = mode

        # Cache the LayoutXLM special-token ids used when building sequences.
        self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
        self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
        self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
        self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)

        # Map class name -> index for label encoding.
        self.class_idx_dic = dict(
            [(class_name, idx) for idx, class_name in enumerate(self.class_names)]
        )
        # The OCR engine is only loaded when OCR is actually requested.
        self.ocr_engine = None
        if self.run_ocr == 1:
            self.ocr_engine = load_ocr_engine()
|
||||
|
||||
def __call__(self, img_path: str, ocr_path: str) -> list:
|
||||
if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
|
||||
process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
|
||||
ocr_path = "tmp.txt"
|
||||
lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
|
||||
lwords = post_process_basic_ocr(lwords)
|
||||
bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
|
||||
word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
|
||||
assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}"
|
||||
|
||||
width, height = imagesize.get(img_path)
|
||||
images = [Image.open(img_path).convert("RGB")]
|
||||
image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
|
||||
|
||||
|
||||
if self.mode == 0:
|
||||
output = self.preprocess(lbboxes, lwords,
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=self.max_seq_length)
|
||||
elif self.mode == 1:
|
||||
output = {}
|
||||
windows = []
|
||||
for i in range(len(bbox_windows)):
|
||||
_words = word_windows[i]
|
||||
_bboxes = bbox_windows[i]
|
||||
windows.append(
|
||||
self.preprocess(
|
||||
_bboxes, _words,
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=self.max_seq_length)
|
||||
)
|
||||
|
||||
output['windows'] = windows
|
||||
elif self.mode == 2:
|
||||
output = {}
|
||||
windows = []
|
||||
output['doduments'] = self.preprocess(lbboxes, lwords,
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=2048)
|
||||
for i in range(len(bbox_windows)):
|
||||
_words = word_windows[i]
|
||||
_bboxes = bbox_windows[i]
|
||||
windows.append(
|
||||
self.preprocess(
|
||||
_bboxes, _words,
|
||||
{'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
|
||||
max_seq_length=self.max_seq_length)
|
||||
)
|
||||
|
||||
output['windows'] = windows
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Not supported mode: {self.mode }"
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
|
||||
input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
|
||||
|
||||
attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int)
|
||||
|
||||
bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
|
||||
|
||||
are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)
|
||||
|
||||
list_layoutxlm_tokens = []
|
||||
|
||||
list_bbs = []
|
||||
list_words = []
|
||||
lwords = [''] * max_seq_length
|
||||
|
||||
box_to_token_indices = []
|
||||
cum_token_idx = 0
|
||||
|
||||
cls_bbs = [0.0] * 8
|
||||
len_overlap_tokens = 0
|
||||
len_non_overlap_tokens = 0
|
||||
len_valid_tokens = 0
|
||||
|
||||
for word_idx, (bounding_box, word) in enumerate(zip(bounding_boxes, words)):
|
||||
bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
|
||||
layoutxlm_tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(word))
|
||||
|
||||
this_box_token_indices = []
|
||||
|
||||
len_valid_tokens += len(layoutxlm_tokens)
|
||||
if word_idx < self.slice_interval:
|
||||
len_non_overlap_tokens += len(layoutxlm_tokens)
|
||||
|
||||
if len(layoutxlm_tokens) == 0:
|
||||
layoutxlm_tokens.append(self.unk_token_id)
|
||||
|
||||
|
||||
if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2:
|
||||
break
|
||||
|
||||
|
||||
list_layoutxlm_tokens += layoutxlm_tokens
|
||||
|
||||
# min, max clipping
|
||||
for coord_idx in range(4):
|
||||
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
|
||||
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))
|
||||
|
||||
bb = list(itertools.chain(*bb))
|
||||
bbs = [bb for _ in range(len(layoutxlm_tokens))]
|
||||
texts = [word for _ in range(len(layoutxlm_tokens))]
|
||||
|
||||
for _ in layoutxlm_tokens:
|
||||
cum_token_idx += 1
|
||||
this_box_token_indices.append(cum_token_idx)
|
||||
|
||||
list_bbs.extend(bbs)
|
||||
list_words.extend(texts) ###
|
||||
box_to_token_indices.append(this_box_token_indices)
|
||||
|
||||
sep_bbs = [feature_maps['width'], feature_maps['height']] * 4
|
||||
|
||||
# For [CLS] and [SEP]
|
||||
list_layoutxlm_tokens = (
|
||||
[self.cls_token_id_layoutxlm]
|
||||
+ list_layoutxlm_tokens[: max_seq_length - 2]
|
||||
+ [self.sep_token_id_layoutxlm]
|
||||
)
|
||||
|
||||
if len(list_bbs) == 0:
|
||||
# When len(json_obj["words"]) == 0 (no OCR result)
|
||||
list_bbs = [cls_bbs] + [sep_bbs]
|
||||
else: # len(list_bbs) > 0
|
||||
list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
|
||||
# list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP']
|
||||
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
|
||||
|
||||
len_list_layoutxlm_tokens = len(list_layoutxlm_tokens)
|
||||
input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens
|
||||
attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1
|
||||
|
||||
|
||||
bbox[:len_list_layoutxlm_tokens, :] = list_bbs
|
||||
lwords[:len_list_layoutxlm_tokens] = list_words ###
|
||||
|
||||
# Normalize bbox -> 0 ~ 1
|
||||
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
|
||||
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']
|
||||
|
||||
if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"):
|
||||
bbox = bbox[:, [0, 1, 4, 5]]
|
||||
bbox = bbox * 1000
|
||||
bbox = bbox.astype(int)
|
||||
else:
|
||||
assert False
|
||||
|
||||
st_indices = [
|
||||
indices[0]
|
||||
for indices in box_to_token_indices
|
||||
if indices[0] < max_seq_length
|
||||
]
|
||||
are_box_first_tokens[st_indices] = True
|
||||
|
||||
assert len_list_layoutxlm_tokens == len_valid_tokens + 2
|
||||
len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens
|
||||
|
||||
ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2
|
||||
|
||||
input_ids_layoutxlm = input_ids_layoutxlm[:ntokens]
|
||||
attention_mask_layoutxlm = attention_mask_layoutxlm[:ntokens]
|
||||
bbox = bbox[:ntokens]
|
||||
are_box_first_tokens = are_box_first_tokens[:ntokens]
|
||||
|
||||
|
||||
input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm)
|
||||
attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm)
|
||||
bbox = torch.from_numpy(bbox)
|
||||
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
|
||||
|
||||
len_valid_tokens = torch.tensor(len_valid_tokens)
|
||||
len_overlap_tokens = torch.tensor(len_overlap_tokens)
|
||||
return_dict = {
|
||||
"img_path": feature_maps['img_path'],
|
||||
"words": list_words,
|
||||
"len_overlap_tokens": len_overlap_tokens,
|
||||
'len_valid_tokens': len_valid_tokens,
|
||||
"image": feature_maps['image'],
|
||||
"input_ids_layoutxlm": input_ids_layoutxlm,
|
||||
"attention_mask_layoutxlm": attention_mask_layoutxlm,
|
||||
"are_box_first_tokens": are_box_first_tokens,
|
||||
"bbox": bbox,
|
||||
}
|
||||
return return_dict
|
||||
|
||||
def load_ground_truth(self, json_file):
|
||||
json_obj = read_json(json_file)
|
||||
width = json_obj["meta"]["imageSize"]["width"]
|
||||
height = json_obj["meta"]["imageSize"]["height"]
|
||||
|
||||
input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
|
||||
bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
|
||||
attention_mask = np.zeros(self.max_seq_length, dtype=int)
|
||||
|
||||
itc_labels = np.zeros(self.max_seq_length, dtype=int)
|
||||
are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
|
||||
|
||||
# stc_labels stores the index of the previous token.
|
||||
# A stored index of max_seq_length (512) indicates that
|
||||
# this token is the initial token of a word box.
|
||||
stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
|
||||
el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
|
||||
el_labels_from_key = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length
|
||||
|
||||
|
||||
list_tokens = []
|
||||
list_bbs = []
|
||||
list_words = []
|
||||
box2token_span_map = []
|
||||
lwords = [''] * self.max_seq_length
|
||||
|
||||
box_to_token_indices = []
|
||||
cum_token_idx = 0
|
||||
|
||||
cls_bbs = [0.0] * 8
|
||||
|
||||
for word_idx, word in enumerate(json_obj["words"]):
|
||||
this_box_token_indices = []
|
||||
|
||||
tokens = word["layoutxlm_tokens"]
|
||||
bb = word["boundingBox"]
|
||||
text = word["text"]
|
||||
|
||||
if len(tokens) == 0:
|
||||
tokens.append(self.unk_token_id)
|
||||
|
||||
if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
|
||||
break
|
||||
|
||||
box2token_span_map.append(
|
||||
[len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
|
||||
) # including st_idx
|
||||
list_tokens += tokens
|
||||
|
||||
# min, max clipping
|
||||
for coord_idx in range(4):
|
||||
bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
|
||||
bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))
|
||||
|
||||
bb = list(itertools.chain(*bb))
|
||||
bbs = [bb for _ in range(len(tokens))]
|
||||
texts = [text for _ in range(len(tokens))]
|
||||
|
||||
for _ in tokens:
|
||||
cum_token_idx += 1
|
||||
this_box_token_indices.append(cum_token_idx)
|
||||
|
||||
list_bbs.extend(bbs)
|
||||
list_words.extend(texts) ####
|
||||
box_to_token_indices.append(this_box_token_indices)
|
||||
|
||||
sep_bbs = [width, height] * 4
|
||||
|
||||
# For [CLS] and [SEP]
|
||||
list_tokens = (
|
||||
[self.cls_token_id_layoutxlm]
|
||||
+ list_tokens[: self.max_seq_length - 2]
|
||||
+ [self.sep_token_id_layoutxlm]
|
||||
)
|
||||
if len(list_bbs) == 0:
|
||||
# When len(json_obj["words"]) == 0 (no OCR result)
|
||||
list_bbs = [cls_bbs] + [sep_bbs]
|
||||
else: # len(list_bbs) > 0
|
||||
list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
|
||||
# list_words = ['CLS'] + list_words[: self.max_seq_length - 2] + ['SEP'] ###
|
||||
list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]
|
||||
|
||||
|
||||
len_list_tokens = len(list_tokens)
|
||||
input_ids[:len_list_tokens] = list_tokens
|
||||
attention_mask[:len_list_tokens] = 1
|
||||
|
||||
bbox[:len_list_tokens, :] = list_bbs
|
||||
lwords[:len_list_tokens] = list_words
|
||||
|
||||
# Normalize bbox -> 0 ~ 1
|
||||
bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
|
||||
bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
|
||||
|
||||
if self.backbone_type in ("layoutlm", "layoutxlm"):
|
||||
bbox = bbox[:, [0, 1, 4, 5]]
|
||||
bbox = bbox * 1000
|
||||
bbox = bbox.astype(int)
|
||||
else:
|
||||
assert False
|
||||
|
||||
st_indices = [
|
||||
indices[0]
|
||||
for indices in box_to_token_indices
|
||||
if indices[0] < self.max_seq_length
|
||||
]
|
||||
are_box_first_tokens[st_indices] = True
|
||||
|
||||
# Label
|
||||
classes_dic = json_obj["parse"]["class"]
|
||||
for class_name in self.class_names:
|
||||
if class_name == "others":
|
||||
continue
|
||||
if class_name not in classes_dic:
|
||||
continue
|
||||
|
||||
for word_list in classes_dic[class_name]:
|
||||
is_first, last_word_idx = True, -1
|
||||
for word_idx in word_list:
|
||||
if word_idx >= len(box_to_token_indices):
|
||||
break
|
||||
box2token_list = box_to_token_indices[word_idx]
|
||||
for converted_word_idx in box2token_list:
|
||||
if converted_word_idx >= self.max_seq_length:
|
||||
break # out of idx
|
||||
|
||||
if is_first:
|
||||
itc_labels[converted_word_idx] = self.class_idx_dic[
|
||||
class_name
|
||||
]
|
||||
is_first, last_word_idx = False, converted_word_idx
|
||||
else:
|
||||
stc_labels[converted_word_idx] = last_word_idx
|
||||
last_word_idx = converted_word_idx
|
||||
|
||||
# Label
|
||||
relations = json_obj["parse"]["relations"]
|
||||
for relation in relations:
|
||||
if relation[0] >= len(box2token_span_map) or relation[1] >= len(
|
||||
box2token_span_map
|
||||
):
|
||||
continue
|
||||
if (
|
||||
box2token_span_map[relation[0]][0] >= self.max_seq_length
|
||||
or box2token_span_map[relation[1]][0] >= self.max_seq_length
|
||||
):
|
||||
continue
|
||||
|
||||
word_from = box2token_span_map[relation[0]][0]
|
||||
word_to = box2token_span_map[relation[1]][0]
|
||||
# el_labels[word_to] = word_from
|
||||
|
||||
if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512:
|
||||
continue
|
||||
|
||||
# if self.second_relations == 1:
|
||||
# if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
|
||||
# el_labels[word_to] = word_from # pair of (header, key) or (header-value)
|
||||
# else:
|
||||
#### 1st relation => ['key, 'value']
|
||||
#### 2st relation => ['header', 'key'or'value']
|
||||
if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
|
||||
el_labels_from_key[word_to] = word_from # pair of (key-value)
|
||||
if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
|
||||
el_labels[word_to] = word_from # pair of (header, key) or (header-value)
|
||||
|
||||
|
||||
|
||||
input_ids = torch.from_numpy(input_ids)
|
||||
bbox = torch.from_numpy(bbox)
|
||||
attention_mask = torch.from_numpy(attention_mask)
|
||||
|
||||
itc_labels = torch.from_numpy(itc_labels)
|
||||
are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
|
||||
stc_labels = torch.from_numpy(stc_labels)
|
||||
el_labels = torch.from_numpy(el_labels)
|
||||
el_labels_from_key = torch.from_numpy(el_labels_from_key)
|
||||
|
||||
return_dict = {
|
||||
# "image": feature_maps,
|
||||
"input_ids": input_ids,
|
||||
"bbox": bbox,
|
||||
"words": lwords,
|
||||
"attention_mask": attention_mask,
|
||||
"itc_labels": itc_labels,
|
||||
"are_box_first_tokens": are_box_first_tokens,
|
||||
"stc_labels": stc_labels,
|
||||
"el_labels": el_labels,
|
||||
"el_labels_from_key": el_labels_from_key
|
||||
}
|
||||
|
||||
return return_dict
|
||||
|
||||
|
||||
class DocumentKVUProcess(KVUProcess):
    """KVUProcess variant that splits a document into a fixed number
    (``max_window_count``) of contiguous, non-overlapping windows of
    ``window_size`` words each, plus document-level aggregated tensors.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
        self.max_window_count = max_window_count
        # Short aliases for the layoutxlm tokenizer and its special-token ids.
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, img_path: str, ocr_path: str) -> list:
        if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
            process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False)
            ocr_path = "tmp.txt"
        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)

        width, height = imagesize.get(img_path)
        images = [Image.open(img_path).convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
        output = self.preprocess(lbboxes, lwords,
                                 {'image': image_features, 'width': width, 'height': height, 'img_path': img_path})
        return output

    def preprocess(self, bounding_boxes, words, feature_maps):
        """Pack up to ``max_window_count`` windows of model inputs, then
        concatenate per-window masks/boxes into document-level tensors.
        """
        n_words = len(words)
        output_dicts = {'windows': [], 'documents': []}
        n_empty_windows = 0

        for i in range(self.max_window_count):
            input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
            bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
            attention_mask = np.zeros(self.max_seq_length, dtype=int)
            are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)

            if n_words == 0:
                # No OCR result at all: emit a fully padded window.
                n_empty_windows += 1
                output_dicts['windows'].append({
                    "image": feature_maps['image'],
                    "input_ids": torch.from_numpy(input_ids),
                    "bbox": torch.from_numpy(bbox),
                    "words": [],
                    "attention_mask": torch.from_numpy(attention_mask),
                    "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
                })
                continue

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i+1)*self.window_size)

            if start_word_idx >= stop_word_idx:
                # Document shorter than this window: repeat the previous one.
                n_empty_windows += 1
                output_dicts['windows'].append(output_dicts['windows'][-1])
                continue

            list_tokens = []
            list_bbs = []
            list_words = []
            lwords = [''] * self.max_seq_length

            box_to_token_indices = []
            cum_token_idx = 0

            cls_bbs = [0.0] * 8

            for _, (bounding_box, word) in enumerate(zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx])):
                # Expand xyxy box into a 4-point quad (tl, tr, br, bl).
                bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]]
                tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word))

                this_box_token_indices = []

                if len(tokens) == 0:
                    tokens.append(self.unk_token_id)

                if len(list_tokens) + len(tokens) > self.max_seq_length - 2:
                    break

                list_tokens += tokens

                # min, max clipping
                for coord_idx in range(4):
                    bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width']))
                    bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height']))

                bb = list(itertools.chain(*bb))
                bbs = [bb for _ in range(len(tokens))]
                texts = [word for _ in range(len(tokens))]

                for _ in tokens:
                    cum_token_idx += 1
                    this_box_token_indices.append(cum_token_idx)

                list_bbs.extend(bbs)
                list_words.extend(texts)
                box_to_token_indices.append(this_box_token_indices)

            sep_bbs = [feature_maps['width'], feature_maps['height']] * 4

            # For [CLS] and [SEP]
            list_tokens = (
                [self.cls_token_id]
                + list_tokens[: self.max_seq_length - 2]
                + [self.sep_token_id]
            )
            if len(list_bbs) == 0:
                # When len(json_obj["words"]) == 0 (no OCR result)
                list_bbs = [cls_bbs] + [sep_bbs]
            else:  # len(list_bbs) > 0
                list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs]
            # Pad the word list with '</p>' up to the interior capacity.
            # Generalized: was a hard-coded 510, i.e. max_seq_length - 2 at
            # the default max_seq_length of 512.
            if len(list_words) < self.max_seq_length - 2:
                list_words.extend(['</p>' for _ in range(self.max_seq_length - 2 - len(list_words))])
            list_words = [self.tokenizer._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer._sep_token]

            len_list_tokens = len(list_tokens)
            input_ids[:len_list_tokens] = list_tokens
            attention_mask[:len_list_tokens] = 1

            bbox[:len_list_tokens, :] = list_bbs
            lwords[:len_list_tokens] = list_words

            # Normalize bbox -> 0 ~ 1
            bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width']
            bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height']

            # 2-point boxes scaled to 0..1000 (LayoutXLM convention).
            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)

            st_indices = [
                indices[0]
                for indices in box_to_token_indices
                if indices[0] < self.max_seq_length
            ]
            are_box_first_tokens[st_indices] = True

            input_ids = torch.from_numpy(input_ids)
            bbox = torch.from_numpy(bbox)
            attention_mask = torch.from_numpy(attention_mask)
            are_box_first_tokens = torch.from_numpy(are_box_first_tokens)

            return_dict = {
                "image": feature_maps['image'],
                "input_ids": input_ids,
                "bbox": bbox,
                "words": list_words,
                "attention_mask": attention_mask,
                "are_box_first_tokens": are_box_first_tokens,
            }
            output_dicts["windows"].append(return_dict)

        # Document-level tensors: concatenate windows and zero out the tail
        # contributed by empty/duplicated windows.
        attention_mask = torch.cat([o['attention_mask'] for o in output_dicts["windows"]])
        are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
        if n_empty_windows > 0:
            attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
            are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
        bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])
        words = []
        for o in output_dicts['windows']:
            words.extend(o['words'])

        return_dict = {
            "attention_mask": attention_mask,
            "bbox": bbox,
            "are_box_first_tokens": are_box_first_tokens,
            "n_empty_windows": n_empty_windows,
            "words": words
        }
        output_dicts['documents'] = return_dict

        return output_dicts
|
||||
|
||||
|
||||
|
19
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/requirements.txt
Executable file
19
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/requirements.txt
Executable file
@ -0,0 +1,19 @@
|
||||
nptyping==1.4.2
|
||||
numpy==1.20.3
|
||||
pytorch-lightning==1.5.6
|
||||
omegaconf
|
||||
pillow
|
||||
six
|
||||
overrides==4.1.2
|
||||
transformers==4.11.3
|
||||
seqeval==0.0.12
|
||||
imagesize
|
||||
pandas==2.0.1
|
||||
xmltodict
|
||||
dicttoxml
|
||||
|
||||
tensorboard>=2.2.0
|
||||
|
||||
# code-style
|
||||
isort==5.9.3
|
||||
black==21.9b0
|
1
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/run.sh
Executable file
1
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/run.sh
Executable file
@ -0,0 +1 @@
|
||||
python anyKeyValue.py --img_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --save_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --exp_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900 --export_img 1 --mode 3 --dir_level 0
|
106
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/tmp.txt
Executable file
106
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/tmp.txt
Executable file
@ -0,0 +1,106 @@
|
||||
1113 773 1220 825 BEST
|
||||
1243 759 1378 808 DENKI
|
||||
1410 752 1487 799 (S)
|
||||
1430 707 1515 748 TAX
|
||||
1511 745 1598 790 PTE
|
||||
1542 700 1725 740 TNVOICE
|
||||
1618 742 1706 783 LTD
|
||||
1783 725 1920 773 FUNAN
|
||||
1943 723 2054 767 MALL
|
||||
1434 797 1576 843 WORTH
|
||||
1599 785 1760 831 BRIDGE
|
||||
1784 778 1846 822 RD
|
||||
1277 846 1632 897 #02-16/#03-1
|
||||
1655 832 1795 877 FUNAN
|
||||
1817 822 1931 869 MALL
|
||||
1272 897 1518 956 S(179105)
|
||||
1548 890 1655 943 TEL:
|
||||
1686 877 1911 928 69046183
|
||||
1247 1011 1334 1068 GST
|
||||
1358 1006 1447 1059 REG
|
||||
1360 1063 1449 1115 RCB
|
||||
1473 1003 1575 1055 NO.:
|
||||
1474 1059 1555 1110 NO.
|
||||
1595 1042 1868 1096 198202199E
|
||||
1607 985 1944 1040 M2-0053813-7
|
||||
1056 1134 1254 1194 Opening
|
||||
1276 1127 1391 1181 Hrs:
|
||||
1425 1112 1647 1170 10:00:00
|
||||
1672 1102 1735 1161 AN
|
||||
1755 1101 1819 1157 to
|
||||
1846 1090 2067 1147 10:00:00
|
||||
2090 1080 2156 1141 PH
|
||||
1061 1308 1228 1366 Staff:
|
||||
1258 1300 1378 1357 3296
|
||||
1710 1283 1880 1337 Trans:
|
||||
1936 1266 2192 1322 262152554
|
||||
1060 1372 1201 1429 Date:
|
||||
1260 1358 1494 1419 22-03-23
|
||||
1540 1344 1664 1409 9:05
|
||||
1712 1339 1856 1407 Slip:
|
||||
1917 1328 2196 1387 2000130286
|
||||
1124 1487 1439 1545 SALESPERSON
|
||||
1465 1477 1601 1537 CODE.
|
||||
1633 1471 1752 1530 6043
|
||||
1777 1462 2004 1519 HUHAHHAD
|
||||
2032 1451 2177 1509 RAZIH
|
||||
1070 1558 1187 1617 Item
|
||||
1211 1554 1276 1615 No
|
||||
1439 1542 1585 1601 Price
|
||||
1750 1530 1841 1597 Qty
|
||||
1951 1517 2120 1579 Amount
|
||||
1076 1683 1276 1741 ANDROID
|
||||
1304 1673 1477 1733 TABLET
|
||||
1080 1746 1280 1804 2105976
|
||||
1509 1729 1705 1784 SAMSUNG
|
||||
1734 1719 1931 1776 SH-P613
|
||||
1964 1709 2101 1768 128GB
|
||||
1082 1809 1285 1869 SM-P613
|
||||
1316 1802 1454 1860 12838
|
||||
1429 1859 1600 1919 518.00
|
||||
1481 1794 1596 1855 WIFI
|
||||
1622 1790 1656 1850 G
|
||||
1797 1845 1824 1904 1
|
||||
1993 1832 2165 1892 518.00
|
||||
1088 1935 1347 1995 PROMOTION
|
||||
1091 2000 1294 2062 2105664
|
||||
1520 1983 1717 2039 SAMSUNG
|
||||
1743 1963 2106 2030 F-Sam-Redeen
|
||||
1439 2111 1557 2173 0.00
|
||||
1806 2095 1832 2156 1
|
||||
2053 2081 2174 2144 0.00
|
||||
1106 2248 1250 2312 Total
|
||||
1974 2206 2146 2266 518.00
|
||||
1107 2312 1204 2377 UOB
|
||||
1448 2291 1567 2355 CARD
|
||||
1978 2268 2147 2327 518.00
|
||||
1253 2424 1375 2497 GST%
|
||||
1456 2411 1655 2475 Net.Amt
|
||||
1818 2393 1912 2460 GST
|
||||
2023 2387 2192 2445 Amount
|
||||
1106 2494 1231 2560 GST8
|
||||
1486 2472 1661 2537 479.63
|
||||
1770 2458 1916 2523 38.37
|
||||
2027 2448 2203 2511 518.00
|
||||
1553 2601 1699 2666 THANK
|
||||
1721 2592 1821 2661 YOU
|
||||
1436 2678 1616 2749 please
|
||||
1644 2682 1764 2732 come
|
||||
1790 2660 1942 2729 again
|
||||
1191 2862 1391 2931 Those
|
||||
1426 2870 2018 2945 facebook.com
|
||||
1565 2809 1690 2884 join
|
||||
1709 2816 1777 2870 us
|
||||
1799 2811 1868 2865 on
|
||||
1838 2946 2024 3003 com .89
|
||||
1533 3006 2070 3088 ar.com/askbe
|
||||
1300 3326 1659 3446 That's
|
||||
1696 3308 1905 3424 not
|
||||
1937 3289 2131 3408 all!
|
||||
1450 3511 1633 3573 SCAN
|
||||
1392 3589 1489 3645 QR
|
||||
1509 3577 1698 3635 CODE
|
||||
1321 3656 1370 3714 &
|
||||
1517 3638 1768 3699 updates
|
||||
1643 3882 1769 3932 Scan
|
||||
1789 3868 1859 3926 Me
|
127
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/__init__.py
Executable file
127
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/__init__.py
Executable file
@ -0,0 +1,127 @@
|
||||
import os
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.plugins import DDPPlugin
|
||||
from utils.ema_callbacks import EMA
|
||||
|
||||
|
||||
def _update_config(cfg):
|
||||
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
||||
cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")
|
||||
|
||||
# set per-gpu batch size
|
||||
num_devices = torch.cuda.device_count()
|
||||
print('No. devices:', num_devices)
|
||||
for mode in ["train", "val"]:
|
||||
new_batch_size = cfg[mode].batch_size // num_devices
|
||||
cfg[mode].batch_size = new_batch_size
|
||||
|
||||
def _get_config_from_cli():
    """Read OmegaConf overrides from argv, stripping any '--' from keys."""
    cfg_cli = OmegaConf.from_cli()
    for key in list(cfg_cli.keys()):
        if "--" in key:
            # Re-key '--foo' style arguments as plain 'foo'.
            cfg_cli[key.replace("--", "")] = cfg_cli[key]
            del cfg_cli[key]
    return cfg_cli
|
||||
|
||||
def get_callbacks(cfg):
    """Build the Lightning callbacks: a best/last checkpoint writer plus an
    EMA callback when ``cfg.callbacks.ema.decay != -1``.

    Returns the bare checkpoint callback when it is the only entry
    (preserves the original return-type quirk).
    """
    ckpt_cb = ModelCheckpoint(
        dirpath=cfg.save_weight_dir,
        filename='best_model',
        save_last=True,
        save_top_k=1,
        save_weights_only=True,
        verbose=True,
        monitor='val_f1',
        mode='max',
    )
    ckpt_cb.FILE_EXTENSION = ".pth"
    ckpt_cb.CHECKPOINT_NAME_LAST = "last_model"

    callbacks = [ckpt_cb]
    if cfg.callbacks.ema.decay != -1:
        callbacks.append(EMA(decay=0.9999))
    return callbacks if len(callbacks) > 1 else ckpt_cb
|
||||
|
||||
def get_plugins(cfg):
    """Return trainer plugins: a DDPPlugin when the strategy is 'ddp'."""
    if cfg.train.strategy.type == "ddp":
        return [DDPPlugin()]
    return []
|
||||
|
||||
def get_loggers(cfg):
    """Return the experiment loggers (a single TensorBoardLogger)."""
    tb_logger = TensorBoardLogger(
        cfg.tensorboard_dir, name="", version="", default_hp_metric=False
    )
    return [tb_logger]
|
||||
|
||||
def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
    """Flatten a (possibly nested) config into *hparam_dict*.

    Nested keys are joined with '__'; every leaf value is stringified.
    Returns the (mutated) *hparam_dict*.
    """
    for key, val in cfg.items():
        if isinstance(val, DictConfig):
            hparam_dict = cfg_to_hparams(val, hparam_dict, f"{parent_str}{key}__")
        else:
            hparam_dict[f"{parent_str}{key}"] = str(val)
    return hparam_dict
|
||||
|
||||
def get_specific_pl_logger(pl_loggers, logger_type):
    """Return the first logger in *pl_loggers* that is an instance of
    *logger_type*, or None when there is none."""
    return next(
        (logger for logger in pl_loggers if isinstance(logger, logger_type)),
        None,
    )
|
||||
|
||||
def get_class_names(dataset_root_path):
    """Read class names (one per line) from ``class_names.txt`` under the
    first dataset root in *dataset_root_path* (a sequence of paths)."""
    class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
    # Fix: use a context manager so the file handle is closed
    # deterministically (the original left it open until GC).
    with open(class_names_file, "r", encoding="utf-8") as f:
        return f.read().strip().split("\n")
|
||||
|
||||
def create_exp_dir(save_dir=''):
    """Create *save_dir* (and parents) if missing; always log the path."""
    if os.path.exists(save_dir):
        print("DIR already existed.")
    else:
        os.makedirs(save_dir, exist_ok=True)
    print('Experiment dir : {}'.format(save_dir))
|
||||
|
||||
def create_dir(save_dir=''):
    """Create *save_dir* (and parents) if missing; always log the path."""
    if os.path.exists(save_dir):
        print("DIR already existed.")
    else:
        os.makedirs(save_dir, exist_ok=True)
    print('Save dir : {}'.format(save_dir))
|
||||
|
||||
def load_checkpoint(ckpt_path, model, key_include):
    """Load into *model* only the checkpoint entries whose key contains
    '.<key_include>.', stripping the leading 'net.' prefix and the
    sub-module name from each key. Returns the model."""
    assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!"
    raw = torch.load(ckpt_path, 'cpu')['state_dict']
    marker = f'.{key_include}.'
    # Keep only keys mentioning the sub-module; rename
    # 'net.<key_include>.rest' -> 'rest'.
    state_dict = {
        key[4:].replace(key_include + '.', ""): value
        for key, value in raw.items()
        if marker in key
    }
    model.load_state_dict(state_dict, strict=True)
    print(f"Load checkpoint at {ckpt_path}")
    return model
|
||||
|
||||
def load_model_weight(net, pretrained_model_file):
    """Load weights saved under 'state_dict' into *net*, stripping any
    leading 'net.' prefix from each key."""
    checkpoint_state = torch.load(pretrained_model_file, map_location="cpu")[
        "state_dict"
    ]
    renamed = {}
    for key, value in checkpoint_state.items():
        if key.startswith("net."):
            key = key[len("net.") :]
        renamed[key] = value
    net.load_state_dict(renamed)
|
||||
|
||||
|
346
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/ema_callbacks.py
Executable file
346
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/ema_callbacks.py
Executable file
@ -0,0 +1,346 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning import Callback
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_info
|
||||
|
||||
|
||||
class EMA(Callback):
    """
    Implements Exponential Moving Averaging (EMA).

    While training, maintains moving averages of the trained parameters by
    wrapping each optimizer in an :class:`EMAOptimizer`. During evaluation the
    averaged copy is swapped in (unless ``validate_original_weights``), and
    saving produces an additional ``-EMA`` parameter set.

    Args:
        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
        validate_original_weights: Validate the original weights, as opposed to the EMA weights.
        every_n_steps: Apply EMA every N steps.
        cpu_offload: Offload weights to CPU.
    """

    def __init__(
        self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False,
    ):
        if not (0 <= decay <= 1):
            raise MisconfigurationException("EMA decay value must be between 0 and 1")
        self.decay = decay
        self.validate_original_weights = validate_original_weights
        self.every_n_steps = every_n_steps
        self.cpu_offload = cpu_offload

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Wrap each optimizer once; already-wrapped optimizers are filtered out.
        ema_device = torch.device('cpu') if self.cpu_offload else pl_module.device
        wrapped = []
        for optim in trainer.optimizers:
            if isinstance(optim, EMAOptimizer):
                continue
            wrapped.append(
                EMAOptimizer(
                    optim,
                    device=ema_device,
                    decay=self.decay,
                    every_n_steps=self.every_n_steps,
                    current_step=trainer.global_step,
                )
            )
        trainer.optimizers = wrapped

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap EMA weights in for evaluation.
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap the original training weights back.
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
        return not self.validate_original_weights and self._ema_initialized(trainer)

    def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
        return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)

    def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
        # In-place swap between the live weights and the EMA copies.
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.switch_main_parameter_weights(saving_ema_model)

    @contextlib.contextmanager
    def save_ema_model(self, trainer: "pl.Trainer"):
        """
        Saves an EMA copy of the model + EMA optimizer states for resume.
        """
        self.swap_model_weights(trainer, saving_ema_model=True)
        try:
            yield
        finally:
            self.swap_model_weights(trainer, saving_ema_model=False)

    @contextlib.contextmanager
    def save_original_optimizer_state(self, trainer: "pl.Trainer"):
        # Temporarily make EMAOptimizer.state_dict() return the wrapped
        # optimizer's plain state instead of the EMA-augmented one.
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.save_original_optimizer_state = True
        try:
            yield
        finally:
            for optimizer in trainer.optimizers:
                optimizer.save_original_optimizer_state = False

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        checkpoint_callback = trainer.checkpoint_callback

        # use the connector as NeMo calls the connector directly in the exp_manager when restoring.
        connector = trainer._checkpoint_connector
        ckpt_path = connector.resume_checkpoint_path

        if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__:
            ext = checkpoint_callback.FILE_EXTENSION
            if ckpt_path.endswith(f'-EMA{ext}'):
                rank_zero_info(
                    "loading EMA based weights. "
                    "The callback will treat the loaded EMA weights as the main weights"
                    " and create a new EMA copy when training."
                )
                return
            ema_path = ckpt_path.replace(ext, f'-EMA{ext}')
            if os.path.exists(ema_path):
                ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))
                # Restore the EMA optimizer states alongside the main checkpoint.
                checkpoint['optimizer_states'] = ema_state_dict['optimizer_states']
                del ema_state_dict
                rank_zero_info("EMA state has been restored.")
            else:
                raise MisconfigurationException(
                    "Unable to find the associated EMA weights when re-loading, "
                    f"training will start with new EMA weights. Expected them to be at: {ema_path}",
                )
|
||||
|
||||
|
||||
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
    """In-place EMA step: ema = decay * ema + (1 - decay) * current."""
    torch._foreach_mul_(ema_model_tuple, decay)
    torch._foreach_add_(ema_model_tuple, current_model_tuple, alpha=(1.0 - decay))


def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
    """CPU worker for the EMA step; waits for the copy stream first if given."""
    if pre_sync_stream is not None:
        pre_sync_stream.synchronize()
    ema_update(ema_model_tuple, current_model_tuple, decay)
|
||||
|
||||
|
||||
class EMAOptimizer(torch.optim.Optimizer):
    r"""
    EMAOptimizer wraps a torch.optim.Optimizer and maintains an Exponential
    Moving Average of the parameters registered in the optimizer.

    After every optimizer step the EMA parameters are updated as:

        ema_weight = decay * ema_weight + (1 - decay) * training_weight

    Use the ``swap_ema_weights()`` context manager to temporarily swap the
    regular parameters with the EMA parameters in place.

    Notes:
        - EMAOptimizer is not compatible with APEX AMP O2.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap
        device (torch.device): device for EMA parameters
        decay (float): decay factor
        every_n_steps (int): run the EMA update every N optimizer steps
        current_step (int): initial step counter (useful when resuming)

    Returns:
        returns an instance of torch.optim.Optimizer that computes EMA of
        parameters

    Example:
        model = Model().to(device)
        opt = torch.optim.Adam(model.parameters())

        opt = EMAOptimizer(opt, device, 0.9999)

        for epoch in range(epochs):
            training_loop(model, opt)

            regular_eval_accuracy = evaluate(model)

            with opt.swap_ema_weights():
                ema_eval_accuracy = evaluate(model)
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        current_step: int = 0,
    ):
        # NOTE: the base Optimizer.__init__ is intentionally not called; any
        # attribute not found on self is delegated to the wrapped optimizer
        # through __getattr__ below.
        self.optimizer = optimizer
        self.decay = decay
        self.device = device
        self.current_step = current_step
        self.every_n_steps = every_n_steps
        self.save_original_optimizer_state = False

        self.first_iteration = True
        self.rebuild_ema_params = True
        self.stream = None   # CUDA stream for asynchronous EMA updates
        self.thread = None   # worker thread for CPU-offloaded updates

        self.ema_params = ()
        self.in_saving_ema_model_context = False

    def all_parameters(self) -> Iterable[torch.Tensor]:
        """Flat iterator over every parameter in every param group."""
        return (param for group in self.param_groups for param in group['params'])

    def step(self, closure=None, **kwargs):
        self.join()

        if self.first_iteration:
            # Only allocate a CUDA stream when there is at least one CUDA param.
            if any(p.is_cuda for p in self.all_parameters()):
                self.stream = torch.cuda.Stream()
            self.first_iteration = False

        if self.rebuild_ema_params:
            opt_params = list(self.all_parameters())
            # Create EMA copies only for parameters added since the last build.
            self.ema_params += tuple(
                copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params):]
            )
            self.rebuild_ema_params = False

        loss = self.optimizer.step(closure)

        if self._should_update_at_step():
            self.update()
        self.current_step += 1
        return loss

    def _should_update_at_step(self) -> bool:
        return self.current_step % self.every_n_steps == 0

    @torch.no_grad()
    def update(self):
        if self.stream is not None:
            # Don't start copying until pending work on the default stream is done.
            self.stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self.stream):
            current_model_state = tuple(
                param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
            )

            if self.device.type == 'cuda':
                ema_update(self.ema_params, current_model_state, self.decay)

        if self.device.type == 'cpu':
            # Run the CPU update off-thread so training is not blocked.
            self.thread = threading.Thread(
                target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,),
            )
            self.thread.start()

    def swap_tensors(self, tensor1, tensor2):
        """In-place value swap of two tensors via a temporary buffer."""
        tmp = torch.empty_like(tensor1)
        tmp.copy_(tensor1)
        tensor1.copy_(tensor2)
        tensor2.copy_(tmp)

    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
        self.join()
        self.in_saving_ema_model_context = saving_ema_model
        for param, ema_param in zip(self.all_parameters(), self.ema_params):
            self.swap_tensors(param.data, ema_param)

    @contextlib.contextmanager
    def swap_ema_weights(self, enabled: bool = True):
        r"""
        A context manager to in-place swap regular parameters with EMA
        parameters; the swap is undone on exit.

        Args:
            enabled (bool): whether the swap should be performed
        """
        if enabled:
            self.switch_main_parameter_weights()
        try:
            yield
        finally:
            if enabled:
                self.switch_main_parameter_weights()

    def __getattr__(self, name):
        # Delegate unknown attributes (param_groups, state, defaults, ...) to
        # the wrapped optimizer.
        return getattr(self.optimizer, name)

    def join(self):
        """Wait for any in-flight asynchronous EMA update to finish."""
        if self.stream is not None:
            self.stream.synchronize()
        if self.thread is not None:
            self.thread.join()

    def state_dict(self):
        self.join()

        if self.save_original_optimizer_state:
            return self.optimizer.state_dict()

        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
        ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters())
        return {
            'opt': self.optimizer.state_dict(),
            'ema': ema_params,
            'current_step': self.current_step,
            'decay': self.decay,
            'every_n_steps': self.every_n_steps,
        }

    def load_state_dict(self, state_dict):
        self.join()

        self.optimizer.load_state_dict(state_dict['opt'])
        self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema']))
        self.current_step = state_dict['current_step']
        self.decay = state_dict['decay']
        self.every_n_steps = state_dict['every_n_steps']
        self.rebuild_ema_params = False

    def add_param_group(self, param_group):
        self.optimizer.add_param_group(param_group)
        # New params get their EMA copies lazily, at the next step().
        self.rebuild_ema_params = True
|
459
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/functions.py
Executable file
459
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/functions.py
Executable file
@ -0,0 +1,459 @@
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import glob
|
||||
import re
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pdf2image import convert_from_path
|
||||
from dicttoxml import dicttoxml
|
||||
from word_preprocess import vat_standardizer, get_string, ap_standardizer
|
||||
from kvu_dictionary import vat_dictionary, ap_dictionary
|
||||
|
||||
|
||||
|
||||
def create_dir(save_dir=''):
    """Create `save_dir` (parents included) unless it already exists."""
    if os.path.exists(save_dir):
        print("DIR already existed.")
    else:
        os.makedirs(save_dir, exist_ok=True)
    print('Save dir : {}'.format(save_dir))
|
||||
|
||||
def pdf2image(pdf_dir, save_dir):
    """Render every page of each PDF in `pdf_dir` at 500 DPI into `save_dir` as JPEG.

    Output files are named `<pdf_basename>_<page_index>.jpg`.
    """
    pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
    print('No. pdf files:', len(pdf_files))

    for file in tqdm(pdf_files):
        for page_idx, page in enumerate(convert_from_path(file, 500)):
            out_name = os.path.basename(file).replace('.pdf', f'_{page_idx}.jpg')
            page.save(os.path.join(save_dir, out_name), 'JPEG')
    print('Done!!!')
|
||||
|
||||
def xyxy2xywh(bbox):
    """Convert a [x1, y1, x2, y2] box to [x, y, width, height] as floats."""
    x1, y1, x2, y2 = (float(v) for v in bbox[:4])
    return [x1, y1, x2 - x1, y2 - y1]
|
||||
|
||||
def write_to_json(file_path, content):
    """Dump `content` to `file_path` as UTF-8 JSON, keeping non-ASCII characters."""
    with open(file_path, mode='w', encoding='utf8') as fp:
        json.dump(content, fp, ensure_ascii=False)
|
||||
|
||||
|
||||
def read_json(file_path):
    """Load and return the JSON document stored at `file_path`."""
    with open(file_path, 'r') as fp:
        return json.load(fp)
|
||||
|
||||
def read_xml(file_path):
    """Return the raw text content of the XML file at `file_path`."""
    with open(file_path, 'r') as fp:
        return fp.read()
|
||||
|
||||
def write_to_xml(file_path, content):
    """Write an XML string `content` to `file_path` using UTF-8."""
    with open(file_path, mode="w", encoding='utf8') as fp:
        fp.write(content)
|
||||
|
||||
def write_to_xml_from_dict(file_path, content):
    """Serialize the dict `content` to XML via `dicttoxml` and write it to `file_path`.

    Bug fix: the original immediately overwrote the converted XML bytes with
    the raw dict (`xml = content`), so the subsequent `.decode()` raised
    AttributeError for any dict input and the conversion result was discarded.
    """
    xml = dicttoxml(content)
    xml_decode = xml.decode()  # dicttoxml returns bytes

    with open(file_path, mode="w") as f:
        f.write(xml_decode)
|
||||
|
||||
|
||||
def load_ocr_result(ocr_path):
    """Read a tab-separated OCR result file into a list of per-line field lists."""
    with open(ocr_path, 'r') as fp:
        return [line.split('\t') for line in fp.read().splitlines()]
|
||||
|
||||
def post_process_basic_ocr(lwords: list) -> list:
    """Replace the OCR placeholder character "✪" with a space in every word."""
    return [word.replace("✪", " ") for word in lwords]
|
||||
|
||||
def read_ocr_result_from_txt(file_path: str):
    '''
    Parse a TSV OCR dump where each line is "x1<TAB>y1<TAB>x2<TAB>y2<TAB>text[<TAB>conf]".

    Returns:
        (boxes, words): list of (x1, y1, x2, y2) int tuples and the matching texts.
        Blank texts and single-space texts are dropped.
    '''
    with open(file_path, 'r') as f:
        lines = f.read().splitlines()

    boxes, words = [], []
    for line in lines:
        if line == "":
            continue
        word_info = line.split("\t")
        if len(word_info) == 6:
            x1, y1, x2, y2, text, _ = word_info
        elif len(word_info) == 5:
            x1, y1, x2, y2, text = word_info
        else:
            # Bug fix: malformed rows previously fell through — raising
            # NameError on the first such line or silently reusing the
            # previous line's coordinates. Skip them instead.
            continue

        x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
        if text and text != " ":
            words.append(text)
            boxes.append((x1, y1, x2, y2))
    return boxes, words
|
||||
|
||||
def get_colormap():
    """BGR colors used by `visualize` for each semantic class / link type."""
    return {
        'others': (0, 0, 255),      # red
        'title': (0, 255, 255),     # yellow
        'key': (255, 0, 0),         # blue
        'value': (0, 255, 0),       # green
        'header': (233, 197, 15),   # header rows
        'group': (0, 128, 128),     # intra-group links
        'relation': (0, 0, 255),    # relation links (same red as 'others')
    }
|
||||
|
||||
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
    """Draw predicted word boxes, intra-group chains, and relation links.

    `bbox` coordinates are in a 0-1000 normalized space and are rescaled to
    pixels. The PIL input is converted to BGR and rotated per its EXIF
    orientation tag before drawing. Returns the annotated numpy (BGR) image.
    """
    # Read the EXIF orientation tag (0x0112) from the PIL image, if present.
    exif = image._getexif()
    orientation = exif.get(0x0112) if exif is not None else None

    # Convert the PIL image to OpenCV format
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Rotate the image in OpenCV if necessary
    if orientation == 3:
        image = cv2.rotate(image, cv2.ROTATE_180)
    elif orientation == 6:
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    elif orientation == 8:
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        image = np.asarray(image)

    # Promote grayscale to 3 channels so colored drawing works.
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    assert len(image.shape) == 3

    # NOTE(review): width/height are read swapped for orientation==6 —
    # presumably because the boxes were produced on the unrotated image; confirm.
    if orientation is not None and orientation == 6:
        width, height, _ = image.shape
    else:
        height, width, _ = image.shape

    if len(pr_class_words) > 0:
        id2label = {k: labels[k] for k in range(len(labels))}
        for lb, groups in enumerate(pr_class_words):
            if lb == 0:
                # class 0 ('others') is not drawn
                continue
            for group in groups:
                for i, word_id in enumerate(group):
                    x0 = int(bbox[word_id][0] * width / 1000)
                    y0 = int(bbox[word_id][1] * height / 1000)
                    x1 = int(bbox[word_id][2] * width / 1000)
                    y1 = int(bbox[word_id][3] * height / 1000)
                    cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)

                    # Chain consecutive words of a group with line segments.
                    if i == 0:
                        x_center0, y_center0 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                    else:
                        x_center1, y_center1 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                        cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
                        x_center0, y_center0 = x_center1, y_center1

    if len(pr_relations) > 0:
        # Connect the centers of each related pair of boxes.
        for pair in pr_relations:
            xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
            xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)

            x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
            x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)

            cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)

    return image
|
||||
|
||||
|
||||
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
    """Collect (rel_from, rel_to) linked group pairs.

    Returns {rel_to_group_id: [rel_from_group_id, rel_to_group_id]} for every
    relation pair in which both classes appear.
    """
    outputs = {}
    for pair in json:
        found = {}
        for element in pair:
            if element['class'] in (rel_from, rel_to):
                found[element['class']] = element
        if rel_from in found and rel_to in found:
            outputs[found[rel_to]['group_id']] = [found[rel_from]['group_id'], found[rel_to]['group_id']]
    return outputs
|
||||
|
||||
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
    """Group table value group-ids under their (header-linked) key group-ids.

    Only keys present in `header_key_pairs` are considered. Returns
    {key_group_id: [value_group_id, ...]}.
    """
    list_keys = list(header_key_pairs.keys())
    relations = {k: [] for k in list_keys}
    for pair in json:
        key_element = None
        value_element = None
        for element in pair:
            if element['class'] == rel_from and element['group_id'] in list_keys:
                key_element = element
            if element['class'] == rel_to:
                value_element = element
        if key_element is not None and value_element is not None:
            relations[key_element['group_id']].append(value_element['group_id'])
    return relations
|
||||
|
||||
def get_key2values_relations(key_value_pairs: dict):
    """Invert {value_gid: [key_gid, value_gid]} into {key_gid: [value_gid, ...]}."""
    triple_linkings = {}
    for value_group_id, key_value_pair in key_value_pairs.items():
        key_group_id = key_value_pair[0]
        triple_linkings.setdefault(key_group_id, []).append(value_group_id)
    return triple_linkings
|
||||
|
||||
|
||||
def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict:
    """Merge token ids into word groups.

    `class_words[class_id]` is a list of token-id groups; the first token id of
    each group is used as the group id. Returns
    {group_id: {'group_id', 'text', 'class', 'tokens'}}.
    """
    word_groups = {}
    id2class = dict(enumerate(labels))
    for class_id, lwgroups_in_class in enumerate(class_words):
        for ltokens_in_wgroup in lwgroups_in_class:
            group_id = ltokens_in_wgroup[0]
            group_texts = [lwords[token] for token in ltokens_in_wgroup]
            word_groups[group_id] = {
                'group_id': group_id,
                'text': get_string(group_texts),
                'class': id2class[class_id],
                'tokens': ltokens_in_wgroup,
            }
    return word_groups
|
||||
|
||||
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
    """Resolve a token id to its word-group id.

    If `linking_id` is already a group id it is returned as-is; otherwise the
    group whose token list contains it is returned. Falls back to the input id.
    """
    if linking_id in word_groups:
        return linking_id
    for wg_id, word_group in word_groups.items():
        if linking_id in word_group['tokens']:
            return wg_id
    return linking_id
|
||||
|
||||
def matched_wordgroup_relations(word_groups: dict, lrelations: list) -> list:
    """Resolve token-level relation pairs into word-group record pairs.

    Pairs whose endpoints cannot be resolved to a known word group are
    reported and skipped.
    """
    outputs = []
    for pair in lrelations:
        wg_from = verify_linking_id(word_groups, pair[0])
        wg_to = verify_linking_id(word_groups, pair[1])
        try:
            outputs.append([word_groups[wg_from], word_groups[wg_to]])
        except KeyError:
            # Bug fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; only a missing group id is expected.
            print('Not valid pair:', wg_from, wg_to)
    return outputs
|
||||
|
||||
|
||||
def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Assemble KVU predictions into 'single', 'triplet' and 'table' outputs.

    The result is also written to `<dirname>/kvu_results/<basename>` (the
    directory is assumed to exist) and returned.
    """
    word_groups = merged_token_to_wordgroup(class_words, lwords, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)

    # {key_gid: [header_gid, key_gid]}, {value_gid: [header_gid, value_gid]},
    # {value_gid: [key_gid, value_gid]}
    header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key')
    header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value')
    key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value')

    # {key_gid: [value_gid1, value_gid2, ...]}
    key2values_relations = get_key2values_relations(key_value)

    triplet_pairs = []
    single_pairs = []
    table = []
    for key_group_id, list_value_group_ids in key2values_relations.items():
        if len(list_value_group_ids) == 0:
            continue
        if len(list_value_group_ids) == 1:
            # Plain key/value pair.
            value_group_id = list_value_group_ids[0]
            single_pairs.append({word_groups[key_group_id]['text']: {
                'text': word_groups[value_group_id]['text'],
                'id': value_group_id,
                'class': "value"
            }})
        else:
            # One key linked to several values: either a triplet or a table row.
            item = []
            for value_group_id in list_value_group_ids:
                if value_group_id not in header_value.keys():
                    header_name_for_value = "non-header"
                else:
                    header_group_id = header_value[value_group_id][0]
                    header_name_for_value = word_groups[header_group_id]['text']
                item.append({
                    'text': word_groups[value_group_id]['text'],
                    'header': header_name_for_value,
                    'id': value_group_id,
                    'class': 'value'
                })
            if key_group_id not in list(header_key.keys()):
                triplet_pairs.append({
                    word_groups[key_group_id]['text']: item
                })
            else:
                # The key itself is under a header: treat the whole set as a table row.
                header_group_id = header_key[key_group_id][0]
                header_name_for_key = word_groups[header_group_id]['text']
                item.append({
                    'text': word_groups[key_group_id]['text'],
                    'header': header_name_for_key,
                    'id': key_group_id,
                    'class': 'key'
                })
                table.append({key_group_id: item})

    if len(table) > 0:
        # Order rows by key group id, then flatten to a list of cell lists.
        table = sorted(table, key=lambda x: list(x.keys())[0])
        table = [v for item in table for k, v in item.items()]

    outputs = {}
    outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
    outputs['triplet'] = triplet_pairs
    outputs['table'] = table

    # Mirror the results under a sibling 'kvu_results' directory.
    file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
    write_to_json(file_path, outputs)
    return outputs
|
||||
|
||||
|
||||
def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Post-process KVU outputs into the VAT-invoice schema and write it as JSON.

    Builds the line-item table via `vat_standardizer` header matching, then the
    single key/value fields, and finally writes `vat_outputs` to `file_path`.
    """
    vat_outputs = {}
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)

    # List of items in table
    table = []
    for single_item in outputs['table']:
        item = {k: [] for k in vat_dictionary(header=True)}
        for cell in single_item:
            header_name, score, processed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
            if header_name in item:
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': cell['id']
                })

        for header_name, candidates in item.items():
            if len(candidates) == 0:
                item[header_name] = None
                continue
            item[header_name] = max(candidates, key=lambda x: x['lcs_score'])['content']  # Get max lcs score

        table.append(item)

    # VAT Information (single key/value fields)
    single_pairs = {k: [] for k in vat_dictionary(header=False)}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            # NOTE(review): keys come from vat_dictionary but are standardized
            # with ap_standardizer — looks like it should be vat_standardizer;
            # behavior kept as-is, confirm with the standardizer owners.
            key_name, score, processed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
            if key_name in single_pairs:
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': value['id']
                })

    # Combine VAT information and table
    vat_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        # Bug fix: the original used `key_name in ("…")` — parenthesized strings
        # are NOT tuples, so this was a substring test; use equality instead.
        if key_name == "Ngày, tháng, năm lập hóa đơn":
            if len(list_potential_value) == 1:
                vat_outputs[key_name] = list_potential_value[0]['content']
            else:
                # Assemble the date from separate day/month/year fragments.
                date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
                for value in list_potential_value:
                    date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
                vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
        else:
            if len(list_potential_value) == 0:
                continue
            if key_name == "Mã số thuế người bán":
                selected_value = min(list_potential_value, key=lambda x: x['token_id'])  # Get first tax code
                vat_outputs[key_name] = selected_value['content'].replace(' ', '')
            else:
                selected_value = max(list_potential_value, key=lambda x: x['lcs_score'])  # Get max lcs score
                vat_outputs[key_name] = selected_value['content']

    vat_outputs['table'] = table

    write_to_json(file_path, vat_outputs)
|
||||
|
||||
|
||||
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Post-process KVU outputs into the SDSAP (accounts payable) schema.

    Builds table rows and triplet-derived rows via `ap_standardizer` header
    matching, collects single key/value fields, and writes the result to
    `file_path` as JSON.
    """
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)

    # List of items in table
    table = []
    for single_item in outputs['table']:
        item = {k: [] for k in ap_dictionary(header=True)}
        for cell in single_item:
            header_name, score, processed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
            if header_name in item:
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': cell['id']
                })
        for header_name, candidates in item.items():
            if len(candidates) == 0:
                item[header_name] = None
                continue
            item[header_name] = max(candidates, key=lambda x: x['lcs_score'])['content']  # Get max lcs score

        table.append(item)

    # Triplets whose values carry recognized headers become extra table rows.
    triplet_pairs = []
    for single_item in outputs['triplet']:
        item = {k: [] for k in ap_dictionary(header=True)}
        is_item_valid = 0
        for key_name, list_value in single_item.items():
            for value in list_value:
                if value['header'] == "non-header":
                    continue
                header_name, score, processed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
                if header_name in item:
                    is_item_valid = 1
                    item[header_name].append({
                        'content': value['text'],
                        'processed_key_name': processed_text,
                        'lcs_score': score,
                        'token_id': value['id']
                    })

        if is_item_valid == 1:
            for header_name, candidates in item.items():
                if len(candidates) == 0:
                    item[header_name] = None
                    continue
                item[header_name] = max(candidates, key=lambda x: x['lcs_score'])['content']  # Get max lcs score

            # NOTE(review): uses the last `key_name` seen in the loop above as
            # the product name — confirm each triplet holds a single key.
            item['productname'] = key_name
            triplet_pairs.append(item)

    # Single key/value fields.
    single_pairs = {k: [] for k in ap_dictionary(header=False)}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            key_name, score, processed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
            if key_name in single_pairs:
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': processed_text,
                    'lcs_score': score,
                    'token_id': value['id']
                })

    ap_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        if len(list_potential_value) == 0:
            continue
        selected_value = max(list_potential_value, key=lambda x: x['lcs_score'])  # Get max lcs score
        ap_outputs[key_name] = selected_value['content']

    table = table + triplet_pairs
    ap_outputs['table'] = table

    write_to_json(file_path, ap_outputs)
|
138
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/kvu_dictionary.py
Executable file
138
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/kvu_dictionary.py
Executable file
@ -0,0 +1,138 @@
|
||||
|
||||
# Maps Vietnamese KVU field names to the XML tag names used downstream.
DKVU2XML = {
    "Ký hiệu mẫu hóa đơn": "form_no",
    "Ký hiệu hóa đơn": "serial_no",
    "Số hóa đơn": "invoice_no",
    "Ngày, tháng, năm lập hóa đơn": "issue_date",
    "Tên người bán": "seller_name",
    "Mã số thuế người bán": "seller_tax_code",
    "Thuế suất": "tax_rate",
    "Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount",
    "Mặt hàng": "item",
    "Đơn vị tính": "unit",
    "Số lượng": "quantity",
    "Đơn giá": "unit_price",
    "Doanh số mua chưa có thuế": "amount",
}
|
||||
|
||||
|
||||
def ap_dictionary(header: bool):
    """Alias lists for AP (accounts payable) fields.

    Args:
        header: True for table-header field aliases, False for single
            key/value field aliases.
    """
    if header:
        return {
            'productname': ['description', 'paticulars', 'articledescription', 'descriptionofgood', 'itemdescription', 'product', 'productdescription', 'modelname', 'device', 'items', 'itemno'],
            'modelnumber': ['serialno', 'model', 'code', 'mcode', 'simimeiserial', 'serial', 'productcode', 'product', 'imeiccid', 'articles', 'article', 'articlenumber', 'articleidmaterialcode', 'transaction', 'itemcode'],
            'qty': ['quantity', 'invoicequantity'],
        }
    return {
        'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'],
        'retailername': ['retailer', 'retailername', 'ownedoperatedby'],
        'serial_number': ['serialnumber', 'serialno'],
        'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2'],
    }
|
||||
|
||||
|
||||
def vat_dictionary(header: bool):
|
||||
header_dictionary = {
|
||||
'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'sanpham', 'tenquycachhanghoa','description', 'descriptionofgood', 'itemdescription'],
|
||||
'Đơn vị tính': ['dvt', 'donvitinh'],
|
||||
'Số lượng': ['soluong', 'sl','qty', 'quantity', 'invoicequantity'],
|
||||
'Đơn giá': ['dongia'],
|
||||
'Doanh số mua chưa có thuế': ['thanhtien', 'thanhtientruocthuegtgt', 'tienchuathue'],
|
||||
# 'Số sản phẩm': ['serialno', 'model', 'mcode', 'simimeiserial', 'serial', 'sku', 'sn', 'productcode', 'product', 'particulars', 'imeiccid', 'articles', 'article', 'articleidmaterialcode', 'transaction', 'imei', 'articlenumber']
|
||||
}
|
||||
|
||||
key_dictionary = {
|
||||
'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'],
|
||||
'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'],
|
||||
'Số hóa đơn': ['soinvoiceno', 'invoiceno'],
|
||||
'Ngày, tháng, năm lập hóa đơn': [],
|
||||
'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'],
|
||||
'Mã số thuế người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'],
|
||||
'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'],
|
||||
'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'],
|
||||
# 'Ghi chú': [],
|
||||
# 'Ngày': ['ngayday', 'ngay', 'day'],
|
||||
# 'Tháng': ['thangmonth', 'thang', 'month'],
|
||||
# 'Năm': ['namyear', 'nam', 'year']
|
||||
}
|
||||
|
||||
# exact_dictionary = {
|
||||
# 'Số hóa đơn': ['sono', 'so'],
|
||||
# 'Mã số thuế người bán': ['mst'],
|
||||
# 'Tên người bán': ['kyboi'],
|
||||
# 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
|
||||
# }
|
||||
|
||||
return header_dictionary if header else key_dictionary
|
||||
|
||||
def manulife_dictionary(type: str):
|
||||
key_dict = {
|
||||
"Document type": ["documenttype", "loaichungtu"],
|
||||
"Document name": ["documentname", "tenchungtu"],
|
||||
"Patient Name": ["patientname", "tenbenhnhan"],
|
||||
"Date of Birth/Year of birth": [
|
||||
"dateofbirth",
|
||||
"yearofbirth",
|
||||
"ngaythangnamsinh",
|
||||
"namsinh",
|
||||
],
|
||||
"Age": ["age", "tuoi"],
|
||||
"Gender": ["gender", "gioitinh"],
|
||||
"Social insurance card No.": ["socialinsurancecardno", "sothebhyt"],
|
||||
"Medical service provider name": ["medicalserviceprovidername", "tencosoyte"],
|
||||
"Department name": ["departmentname", "tenkhoadieutri"],
|
||||
"Diagnosis description": ["diagnosisdescription", "motachandoan"],
|
||||
"Diagnosis code": ["diagnosiscode", "machandoan"],
|
||||
"Admission date": ["admissiondate", "ngaynhapvien"],
|
||||
"Discharge date": ["dischargedate", "ngayxuatvien"],
|
||||
"Treatment method": ["treatmentmethod", "phuongphapdieutri"],
|
||||
"Treatment date": ["treatmentdate", "ngaydieutri", "ngaykham"],
|
||||
"Follow up treatment date": ["followuptreatmentdate", "ngaytaikham"],
|
||||
# "Name of precribed medicine": [],
|
||||
# "Quantity of prescribed medicine": [],
|
||||
# "Dosage for each medicine": []
|
||||
"Medical expense": ["Medical expense", "chiphiyte"],
|
||||
"Invoice No.": ["invoiceno", "sohoadon"],
|
||||
}
|
||||
|
||||
title_dict = {
|
||||
"Chứng từ y tế": [
|
||||
"giayravien",
|
||||
"giaychungnhanphauthuat",
|
||||
"cachthucphauthuat",
|
||||
"phauthuat",
|
||||
"tomtathosobenhan",
|
||||
"donthuoc",
|
||||
"toathuoc",
|
||||
"donbosung",
|
||||
"ketquaconghuongtu"
|
||||
"ketqua",
|
||||
"phieuchidinh",
|
||||
"phieudangkykham",
|
||||
"giayhenkhamlai",
|
||||
"phieukhambenh",
|
||||
"phieukhambenhvaovien",
|
||||
"phieuxetnghiem",
|
||||
"phieuketquaxetnghiem",
|
||||
"phieuchidinhxetnghiem",
|
||||
"ketquasieuam",
|
||||
"phieuchidinhxetnghiem"
|
||||
],
|
||||
"Chứng từ thanh toán": [
|
||||
"hoadon",
|
||||
"hoadongiatrigiatang",
|
||||
"hoadongiatrigiatangchuyendoituhoadondientu",
|
||||
"bangkechiphibaohiem",
|
||||
"bienlaithutien",
|
||||
"bangkechiphidieutrinoitru"
|
||||
],
|
||||
}
|
||||
|
||||
if type == "key":
|
||||
return key_dict
|
||||
elif type == "title":
|
||||
return title_dict
|
||||
else:
|
||||
raise ValueError(f"[ERROR] Dictionary type of {type} is not supported")
|
30
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/run_ocr.py
Executable file
30
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/run_ocr.py
Executable file
@ -0,0 +1,30 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple, List
|
||||
import sys, os
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
sys.path.append(os.path.join(os.path.dirname(cur_dir), "ocr-engine"))
|
||||
from src.ocr import OcrEngine
|
||||
|
||||
|
||||
def load_ocr_engine() -> OcrEngine:
|
||||
print("[INFO] Loading engine...")
|
||||
engine = OcrEngine()
|
||||
print("[INFO] Engine loaded")
|
||||
return engine
|
||||
|
||||
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None:
|
||||
save_dir_or_path = Path(save_dir_or_path)
|
||||
if isinstance(img, np.ndarray):
|
||||
if save_dir_or_path.is_dir():
|
||||
raise ValueError("numpy array input require a save path, not a save dir")
|
||||
page = engine(img)
|
||||
save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt")
|
||||
) if save_dir_or_path.is_dir() else str(save_dir_or_path)
|
||||
page.write_to_file('word', save_path)
|
||||
if export_img:
|
||||
page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, )
|
||||
|
||||
def read_img(img: Union[str, np.ndarray], engine: OcrEngine):
|
||||
page = engine(img)
|
||||
return ' '.join([f.text for f in page.llines])
|
808
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/utils.py
Executable file
808
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/utils.py
Executable file
@ -0,0 +1,808 @@
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import random
|
||||
import glob
|
||||
import re
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pdf2image import convert_from_path
|
||||
from dicttoxml import dicttoxml
|
||||
from word_preprocess import (
|
||||
vat_standardizer,
|
||||
ap_standardizer,
|
||||
get_string_with_word2line,
|
||||
split_key_value_by_colon,
|
||||
normalize_kvu_output,
|
||||
normalize_kvu_output_for_manulife,
|
||||
manulife_standardizer
|
||||
)
|
||||
from utils.kvu_dictionary import (
|
||||
vat_dictionary,
|
||||
ap_dictionary,
|
||||
manulife_dictionary
|
||||
)
|
||||
|
||||
|
||||
|
||||
def create_dir(save_dir=''):
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# else:
|
||||
# print("DIR already existed.")
|
||||
# print('Save dir : {}'.format(save_dir))
|
||||
|
||||
def convert_pdf2img(pdf_dir, save_dir):
|
||||
pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
|
||||
print('No. pdf files:', len(pdf_files))
|
||||
print(pdf_files)
|
||||
|
||||
for file in tqdm(pdf_files):
|
||||
pdf2img(file, save_dir, n_pages=-1, return_fname=False)
|
||||
# pages = convert_from_path(file, 500)
|
||||
# for i, page in enumerate(pages):
|
||||
# page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
|
||||
print('Done!!!')
|
||||
|
||||
def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False):
|
||||
file_names = []
|
||||
pages = convert_from_path(pdf_path)
|
||||
if n_pages != -1:
|
||||
pages = pages[:n_pages]
|
||||
for i, page in enumerate(pages):
|
||||
_save_path = os.path.join(save_dir, os.path.basename(pdf_path).replace('.pdf', f'_{i}.jpg'))
|
||||
page.save(_save_path, 'JPEG')
|
||||
file_names.append(_save_path)
|
||||
if return_fname:
|
||||
return file_names
|
||||
|
||||
def xyxy2xywh(bbox):
|
||||
return [
|
||||
float(bbox[0]),
|
||||
float(bbox[1]),
|
||||
float(bbox[2]) - float(bbox[0]),
|
||||
float(bbox[3]) - float(bbox[1]),
|
||||
]
|
||||
|
||||
def write_to_json(file_path, content):
|
||||
with open(file_path, mode='w', encoding='utf8') as f:
|
||||
json.dump(content, f, ensure_ascii=False)
|
||||
|
||||
|
||||
def read_json(file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def read_xml(file_path):
|
||||
with open(file_path, 'r') as xml_file:
|
||||
return xml_file.read()
|
||||
|
||||
def write_to_xml(file_path, content):
|
||||
with open(file_path, mode="w", encoding='utf8') as f:
|
||||
f.write(content)
|
||||
|
||||
def write_to_xml_from_dict(file_path, content):
|
||||
xml = dicttoxml(content)
|
||||
xml = content
|
||||
xml_decode = xml.decode()
|
||||
|
||||
with open(file_path, mode="w") as f:
|
||||
f.write(xml_decode)
|
||||
|
||||
|
||||
def load_ocr_result(ocr_path):
|
||||
with open(ocr_path, 'r') as f:
|
||||
lines = f.read().splitlines()
|
||||
|
||||
preds = []
|
||||
for line in lines:
|
||||
preds.append(line.split('\t'))
|
||||
return preds
|
||||
|
||||
def post_process_basic_ocr(lwords: list) -> list:
|
||||
pp_lwords = []
|
||||
for word in lwords:
|
||||
pp_lwords.append(word.replace("✪", " "))
|
||||
return pp_lwords
|
||||
|
||||
def read_ocr_result_from_txt(file_path: str):
|
||||
'''
|
||||
return list of bounding boxes, list of words
|
||||
'''
|
||||
with open(file_path, 'r') as f:
|
||||
lines = f.read().splitlines()
|
||||
|
||||
boxes, words = [], []
|
||||
for line in lines:
|
||||
if line == "":
|
||||
continue
|
||||
word_info = line.split("\t")
|
||||
if len(word_info) == 6:
|
||||
x1, y1, x2, y2, text, _ = word_info
|
||||
elif len(word_info) == 5:
|
||||
x1, y1, x2, y2, text = word_info
|
||||
|
||||
x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
|
||||
if text and text != " ":
|
||||
words.append(text)
|
||||
boxes.append((x1, y1, x2, y2))
|
||||
return boxes, words
|
||||
|
||||
def get_colormap():
|
||||
return {
|
||||
'others': (0, 0, 255), # others: red
|
||||
'title': (0, 255, 255), # title: yellow
|
||||
'key': (255, 0, 0), # key: blue
|
||||
'value': (0, 255, 0), # value: green
|
||||
'header': (233, 197, 15), # header
|
||||
'group': (0, 128, 128), # group
|
||||
'relation': (0, 0, 255)# (128, 128, 128), # relation
|
||||
}
|
||||
|
||||
def convert_image(image):
|
||||
exif = image._getexif()
|
||||
orientation = None
|
||||
if exif is not None:
|
||||
orientation = exif.get(0x0112)
|
||||
# Convert the PIL image to OpenCV format
|
||||
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
# Rotate the image in OpenCV if necessary
|
||||
if orientation == 3:
|
||||
image = cv2.rotate(image, cv2.ROTATE_180)
|
||||
elif orientation == 6:
|
||||
image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
|
||||
elif orientation == 8:
|
||||
image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
image = np.asarray(image)
|
||||
|
||||
if len(image.shape) == 2:
|
||||
image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
|
||||
assert len(image.shape) == 3
|
||||
|
||||
return image, orientation
|
||||
|
||||
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
|
||||
image, orientation = convert_image(image)
|
||||
|
||||
# if orientation is not None and orientation == 6:
|
||||
# width, height, _ = image.shape
|
||||
# else:
|
||||
# height, width, _ = image.shape
|
||||
|
||||
if len(pr_class_words) > 0:
|
||||
id2label = {k: labels[k] for k in range(len(labels))}
|
||||
for lb, groups in enumerate(pr_class_words):
|
||||
if lb == 0:
|
||||
continue
|
||||
for group_id, group in enumerate(groups):
|
||||
for i, word_id in enumerate(group):
|
||||
# x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000)
|
||||
# x0, y0, x1, y1 = revert_box(bbox[word_id], width, height)
|
||||
x0, y0, x1, y1 = bbox[word_id]
|
||||
cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)
|
||||
|
||||
if i == 0:
|
||||
x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2)
|
||||
else:
|
||||
x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2)
|
||||
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
|
||||
x_center0, y_center0 = x_center1, y_center1
|
||||
|
||||
if len(pr_relations) > 0:
|
||||
for pair in pr_relations:
|
||||
# xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
|
||||
# xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)
|
||||
# xyxy0 = revert_box(bbox[pair[0]], width, height)
|
||||
# xyxy1 = revert_box(bbox[pair[1]], width, height)
|
||||
|
||||
xyxy0 = bbox[pair[0]]
|
||||
xyxy1 = bbox[pair[1]]
|
||||
|
||||
x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
|
||||
x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)
|
||||
|
||||
cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)
|
||||
|
||||
return image
|
||||
|
||||
def revert_box(box, width, height):
|
||||
return [
|
||||
int((box[0] / 1000) * width),
|
||||
int((box[1] / 1000) * height),
|
||||
int((box[2] / 1000) * width),
|
||||
int((box[3] / 1000) * height)
|
||||
]
|
||||
|
||||
|
||||
def get_wordgroup_bbox(lbbox: list, lword_ids: list) -> list:
|
||||
points = [lbbox[i] for i in lword_ids]
|
||||
x_min, y_min = min(points, key=lambda x: x[0])[0], min(points, key=lambda x: x[1])[1]
|
||||
x_max, y_max = max(points, key=lambda x: x[2])[2], max(points, key=lambda x: x[3])[3]
|
||||
return [x_min, y_min, x_max, y_max]
|
||||
|
||||
|
||||
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
|
||||
outputs = {}
|
||||
for pair in json:
|
||||
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
|
||||
for element in pair:
|
||||
if element['class'] in (rel_from, rel_to):
|
||||
is_rel[element['class']]['status'] = 1
|
||||
is_rel[element['class']]['value'] = element
|
||||
if all([v['status'] == 1 for _, v in is_rel.items()]):
|
||||
outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
|
||||
return outputs
|
||||
|
||||
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
|
||||
list_keys = list(header_key_pairs.keys())
|
||||
relations = {k: [] for k in list_keys}
|
||||
for pair in json:
|
||||
is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
|
||||
for element in pair:
|
||||
if element['class'] == rel_from and element['group_id'] in list_keys:
|
||||
is_rel[rel_from]['status'] = 1
|
||||
is_rel[rel_from]['value'] = element
|
||||
if element['class'] == rel_to:
|
||||
is_rel[rel_to]['status'] = 1
|
||||
is_rel[rel_to]['value'] = element
|
||||
if all([v['status'] == 1 for _, v in is_rel.items()]):
|
||||
relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
|
||||
return relations
|
||||
|
||||
def get_key2values_relations(key_value_pairs: dict):
|
||||
triple_linkings = {}
|
||||
for value_group_id, key_value_pair in key_value_pairs.items():
|
||||
key_group_id = key_value_pair[0]
|
||||
if key_group_id not in list(triple_linkings.keys()):
|
||||
triple_linkings[key_group_id] = []
|
||||
triple_linkings[key_group_id].append(value_group_id)
|
||||
return triple_linkings
|
||||
|
||||
|
||||
def merged_token_to_wordgroup(class_words: list, lwords: list, lbboxes: list, labels: list) -> dict:
|
||||
word_groups = {}
|
||||
id2class = {i: labels[i] for i in range(len(labels))}
|
||||
for class_id, lwgroups_in_class in enumerate(class_words):
|
||||
for ltokens_in_wgroup in lwgroups_in_class:
|
||||
group_id = ltokens_in_wgroup[0]
|
||||
ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
|
||||
ltokens_to_lbboxes = [lbboxes[token] for token in ltokens_in_wgroup]
|
||||
# text_string = get_string(ltokens_to_ltexts)
|
||||
# text_string= get_string_by_deduplicate_bbox(ltokens_to_ltexts, ltokens_to_lbboxes)
|
||||
text_string = get_string_with_word2line(ltokens_to_ltexts, ltokens_to_lbboxes)
|
||||
group_bbox = get_wordgroup_bbox(lbboxes, ltokens_in_wgroup)
|
||||
word_groups[group_id] = {
|
||||
'group_id': group_id,
|
||||
'text': text_string,
|
||||
'class': id2class[class_id],
|
||||
'tokens': ltokens_in_wgroup,
|
||||
'bbox': group_bbox
|
||||
}
|
||||
return word_groups
|
||||
|
||||
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
|
||||
if linking_id not in list(word_groups):
|
||||
for wg_id, _word_group in word_groups.items():
|
||||
if linking_id in _word_group['tokens']:
|
||||
return wg_id
|
||||
return linking_id
|
||||
|
||||
def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list:
|
||||
outputs = []
|
||||
for pair in lrelations:
|
||||
wg_from = verify_linking_id(word_groups, pair[0])
|
||||
wg_to = verify_linking_id(word_groups, pair[1])
|
||||
try:
|
||||
outputs.append([word_groups[wg_from], word_groups[wg_to]])
|
||||
except:
|
||||
print('Not valid pair:', wg_from, wg_to)
|
||||
return outputs
|
||||
|
||||
def get_single_entity(word_groups: dict, lrelations: list) -> list:
|
||||
single_entity = {'title': [], 'key': [], 'value': [], 'header': []}
|
||||
list_linked_ids = []
|
||||
for pair in lrelations:
|
||||
list_linked_ids.extend(pair)
|
||||
|
||||
for word_group_id, word_group in word_groups.items():
|
||||
if word_group_id not in list_linked_ids:
|
||||
single_entity[word_group['class']].append(word_group)
|
||||
return single_entity
|
||||
|
||||
|
||||
def export_kvu_outputs(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
|
||||
word_groups = merged_token_to_wordgroup(class_words, lwords, lbboxes, labels)
|
||||
linking_pairs = matched_wordgroup_relations(word_groups, lrelations)
|
||||
|
||||
header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]}
|
||||
header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]}
|
||||
key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]}
|
||||
single_entity = get_single_entity(word_groups, lrelations)
|
||||
# table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
|
||||
key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
|
||||
triplet_pairs = []
|
||||
single_pairs = []
|
||||
table = []
|
||||
# print('key2values_relations', key2values_relations)
|
||||
for key_group_id, list_value_group_ids in key2values_relations.items():
|
||||
if len(list_value_group_ids) == 0: continue
|
||||
elif (len(list_value_group_ids) == 1) and (list_value_group_ids[0] not in list(header_value.keys())) and (key_group_id not in list(header_key.keys())):
|
||||
value_group_id = list_value_group_ids[0]
|
||||
|
||||
single_pairs.append({word_groups[key_group_id]['text']: {
|
||||
'text': word_groups[value_group_id]['text'],
|
||||
'id': value_group_id,
|
||||
'class': "value",
|
||||
'bbox': word_groups[value_group_id]['bbox'],
|
||||
'key_bbox': word_groups[key_group_id]['bbox']
|
||||
}})
|
||||
else:
|
||||
item = []
|
||||
for value_group_id in list_value_group_ids:
|
||||
if value_group_id not in header_value.keys():
|
||||
header_group_id = -1 # temp
|
||||
header_name_for_value = "non-header"
|
||||
else:
|
||||
header_group_id = header_value[value_group_id][0]
|
||||
header_name_for_value = word_groups[header_group_id]['text']
|
||||
item.append({
|
||||
'text': word_groups[value_group_id]['text'],
|
||||
'header': header_name_for_value,
|
||||
'id': value_group_id,
|
||||
"key_id": key_group_id,
|
||||
"header_id": header_group_id,
|
||||
'class': 'value',
|
||||
'bbox': word_groups[value_group_id]['bbox'],
|
||||
'key_bbox': word_groups[key_group_id]['bbox'],
|
||||
'header_bbox': word_groups[header_group_id]['bbox'] if header_group_id != -1 else [0, 0, 0, 0],
|
||||
})
|
||||
if key_group_id not in list(header_key.keys()):
|
||||
triplet_pairs.append({
|
||||
word_groups[key_group_id]['text']: item
|
||||
})
|
||||
else:
|
||||
header_group_id = header_key[key_group_id][0]
|
||||
header_name_for_key = word_groups[header_group_id]['text']
|
||||
item.append({
|
||||
'text': word_groups[key_group_id]['text'],
|
||||
'header': header_name_for_key,
|
||||
'id': key_group_id,
|
||||
"key_id": key_group_id,
|
||||
"header_id": header_group_id,
|
||||
'class': 'key',
|
||||
'bbox': word_groups[value_group_id]['bbox'],
|
||||
'key_bbox': word_groups[key_group_id]['bbox'],
|
||||
'header_bbox': word_groups[header_group_id]['bbox'],
|
||||
})
|
||||
table.append({key_group_id: item})
|
||||
|
||||
|
||||
single_entity_dict = {}
|
||||
for class_name, single_items in single_entity.items():
|
||||
single_entity_dict[class_name] = []
|
||||
for single_item in single_items:
|
||||
single_entity_dict[class_name].append({
|
||||
'text': single_item['text'],
|
||||
'id': single_item['group_id'],
|
||||
'class': class_name,
|
||||
'bbox': single_item['bbox']
|
||||
})
|
||||
|
||||
|
||||
|
||||
if len(table) > 0:
|
||||
table = sorted(table, key=lambda x: list(x.keys())[0])
|
||||
table = [v for item in table for k, v in item.items()]
|
||||
|
||||
|
||||
outputs = {}
|
||||
outputs['title'] = single_entity_dict['title']
|
||||
outputs['key'] = single_entity_dict['key']
|
||||
outputs['value'] = single_entity_dict['value']
|
||||
outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
|
||||
outputs['triplet'] = triplet_pairs
|
||||
outputs['table'] = table
|
||||
|
||||
|
||||
create_dir(os.path.join(os.path.dirname(file_path), 'kvu_results'))
|
||||
file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
|
||||
write_to_json(file_path, outputs)
|
||||
return outputs
|
||||
|
||||
def export_kvu_for_all(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']) -> dict:
|
||||
raw_outputs = export_kvu_outputs(
|
||||
file_path, lwords, lbboxes, class_words, lrelations, labels
|
||||
)
|
||||
outputs = {}
|
||||
# Title
|
||||
outputs["title"] = (
|
||||
raw_outputs["title"][0]["text"] if len(raw_outputs["title"]) > 0 else None
|
||||
)
|
||||
|
||||
# Pairs of key-value
|
||||
for pair in raw_outputs["single"]:
|
||||
for key, values in pair.items():
|
||||
# outputs[key] = values["text"]
|
||||
elements = split_key_value_by_colon(key, values["text"])
|
||||
outputs[elements[0]] = elements[1]
|
||||
|
||||
# Only key fields
|
||||
for key in raw_outputs["key"]:
|
||||
# outputs[key["text"]] = None
|
||||
elements = split_key_value_by_colon(key["text"], None)
|
||||
outputs[elements[0]] = elements[1]
|
||||
|
||||
# Triplet data
|
||||
for triplet in raw_outputs["triplet"]:
|
||||
for key, list_value in triplet.items():
|
||||
outputs[key] = [value["text"] for value in list_value]
|
||||
|
||||
# Table data
|
||||
table = []
|
||||
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
||||
if header_list:
|
||||
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
||||
print("Header_list:", header_list.keys())
|
||||
|
||||
for row in raw_outputs["table"]:
|
||||
item = {header: None for header in list(header_list.keys())}
|
||||
for cell in row:
|
||||
item[cell["header"]] = cell["text"]
|
||||
table.append(item)
|
||||
outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}]
|
||||
else:
|
||||
outputs["tables"] = []
|
||||
outputs = normalize_kvu_output(outputs)
|
||||
# write_to_json(file_path, outputs)
|
||||
return outputs
|
||||
|
||||
def export_kvu_for_manulife(
|
||||
file_path,
|
||||
lwords,
|
||||
lbboxes,
|
||||
class_words,
|
||||
lrelations,
|
||||
labels=["others", "title", "key", "value", "header"],
|
||||
) -> dict:
|
||||
raw_outputs = export_kvu_outputs(
|
||||
file_path, lwords, lbboxes, class_words, lrelations, labels
|
||||
)
|
||||
outputs = {}
|
||||
# Title
|
||||
title_list = []
|
||||
for title in raw_outputs["title"]:
|
||||
is_match, title_name, score, proceessed_text = manulife_standardizer(title["text"], threshold=0.6, type_dict="title")
|
||||
title_list.append({
|
||||
'documment_type': title_name if is_match else None,
|
||||
'content': title['text'],
|
||||
'processed_key_name': proceessed_text,
|
||||
'lcs_score': score,
|
||||
'token_id': title['id']
|
||||
})
|
||||
|
||||
if len(title_list) > 0:
|
||||
selected_element = max(title_list, key=lambda x: x['lcs_score'])
|
||||
outputs["title"] = selected_element['content'].upper()
|
||||
outputs["class_doc"] = selected_element['documment_type']
|
||||
|
||||
outputs["Loại chứng từ"] = selected_element['documment_type']
|
||||
outputs["Tên chứng từ"] = selected_element['content']
|
||||
else:
|
||||
outputs["title"] = None
|
||||
outputs["class_doc"] = None
|
||||
outputs["Loại chứng từ"] = None
|
||||
outputs["Tên chứng từ"] = None
|
||||
|
||||
# Pairs of key-value
|
||||
for pair in raw_outputs["single"]:
|
||||
for key, values in pair.items():
|
||||
# outputs[key] = values["text"]
|
||||
elements = split_key_value_by_colon(key, values["text"])
|
||||
outputs[elements[0]] = elements[1]
|
||||
|
||||
# Only key fields
|
||||
for key in raw_outputs["key"]:
|
||||
# outputs[key["text"]] = None
|
||||
elements = split_key_value_by_colon(key["text"], None)
|
||||
outputs[elements[0]] = elements[1]
|
||||
|
||||
# Triplet data
|
||||
for triplet in raw_outputs["triplet"]:
|
||||
for key, list_value in triplet.items():
|
||||
outputs[key] = [value["text"] for value in list_value]
|
||||
|
||||
# Table data
|
||||
table = []
|
||||
header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row}
|
||||
if header_list:
|
||||
header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0])))
|
||||
# print("Header_list:", header_list.keys())
|
||||
|
||||
for row in raw_outputs["table"]:
|
||||
item = {header: None for header in list(header_list.keys())}
|
||||
for cell in row:
|
||||
item[cell["header"]] = cell["text"]
|
||||
table.append(item)
|
||||
outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}]
|
||||
else:
|
||||
outputs["tables"] = []
|
||||
outputs = normalize_kvu_output_for_manulife(outputs)
|
||||
# write_to_json(file_path, outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
# For FI-VAT project
|
||||
|
||||
def get_vat_table_information(outputs):
|
||||
table = []
|
||||
for single_item in outputs['table']:
|
||||
headers = [item['header'] for sublist in outputs['table'] for item in sublist if 'header' in item]
|
||||
item = {k: [] for k in headers}
|
||||
print(item)
|
||||
for cell in single_item:
|
||||
# header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
|
||||
# if header_name in list(item.keys()):
|
||||
# item[header_name] = value['text']
|
||||
item[cell['header']].append({
|
||||
'content': cell['text'],
|
||||
'processed_key_name': cell['header'],
|
||||
'lcs_score': random.uniform(0.75, 1.0),
|
||||
'token_id': cell['id']
|
||||
})
|
||||
|
||||
# for header_name, value in item.items():
|
||||
# if len(value) == 0:
|
||||
# if header_name in ("Số lượng", "Doanh số mua chưa có thuế"):
|
||||
# item[header_name] = '0'
|
||||
# else:
|
||||
# item[header_name] = None
|
||||
# continue
|
||||
# item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score
|
||||
|
||||
# item = post_process_for_item(item)
|
||||
|
||||
# if item["Mặt hàng"] == None:
|
||||
# continue
|
||||
table.append(item)
|
||||
print(table)
|
||||
return table
|
||||
|
||||
def get_vat_information(outputs):
|
||||
# VAT Information
|
||||
single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
|
||||
for pair in outputs['single']:
|
||||
for raw_key_name, value in pair.items():
|
||||
key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
'content': value['text'],
|
||||
'processed_key_name': proceessed_text,
|
||||
'lcs_score': score,
|
||||
'token_id': value['id'],
|
||||
})
|
||||
|
||||
for triplet in outputs['triplet']:
|
||||
for key, value_list in triplet.items():
|
||||
if len(value_list) == 1:
|
||||
key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
'content': value_list[0]['text'],
|
||||
'processed_key_name': proceessed_text,
|
||||
'lcs_score': score,
|
||||
'token_id': value_list[0]['id']
|
||||
})
|
||||
|
||||
for pair in value_list:
|
||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
'content': pair['text'],
|
||||
'processed_key_name': proceessed_text,
|
||||
'lcs_score': score,
|
||||
'token_id': pair['id']
|
||||
})
|
||||
|
||||
for table_row in outputs['table']:
|
||||
for pair in table_row:
|
||||
key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
|
||||
# print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")
|
||||
|
||||
if key_name in list(single_pairs.keys()):
|
||||
single_pairs[key_name].append({
|
||||
'content': pair['text'],
|
||||
'processed_key_name': proceessed_text,
|
||||
'lcs_score': score,
|
||||
'token_id': pair['id']
|
||||
})
|
||||
|
||||
return single_pairs
|
||||
|
||||
|
||||
def post_process_vat_information(single_pairs):
|
||||
vat_outputs = {k: None for k in list(single_pairs)}
|
||||
for key_name, list_potential_value in single_pairs.items():
|
||||
if key_name in ("Ngày, tháng, năm lập hóa đơn"):
|
||||
if len(list_potential_value) == 1:
|
||||
vat_outputs[key_name] = list_potential_value[0]['content']
|
||||
else:
|
||||
date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
|
||||
for value in list_potential_value:
|
||||
date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
|
||||
vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
|
||||
else:
|
||||
if len(list_potential_value) == 0: continue
|
||||
if key_name in ("Mã số thuế người bán"):
|
||||
selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code
|
||||
# tax_code_raw = selected_value['content'].replace(' ', '')
|
||||
tax_code_raw = selected_value['content']
|
||||
if len(tax_code_raw.replace(' ', '')) not in (10, 13): # to remove the first number dupicated
|
||||
tax_code_raw = tax_code_raw.split(' ')
|
||||
tax_code_raw = sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0]
|
||||
vat_outputs[key_name] = tax_code_raw.replace(' ', '')
|
||||
|
||||
else:
|
||||
selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
|
||||
vat_outputs[key_name] = selected_value['content']
|
||||
return vat_outputs
|
||||
|
||||
|
||||
def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Build the final VAT-invoice KVU result and write it to *file_path*.

    Combines the flattened single key/value pairs with the standardized
    item table and returns the resulting dict.
    """
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)

    # Standardized list of items found in the document table(s).
    table = get_vat_table_information(outputs)

    # Flatten the single key/value pairs into one dict (raw key -> text).
    vat_outputs = {}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            vat_outputs[raw_key_name] = value['text']

    # Combine VAT information and table.
    vat_outputs['table'] = table

    write_to_json(file_path, vat_outputs)
    print(vat_outputs)
    return vat_outputs
||||
def get_ap_table_information(outputs):
    """Standardize table rows for the AP/SBT document type.

    Each row's cells are matched against the ap_dictionary headers; for
    every standard header, the best (highest LCS score) candidate text is
    kept, or None when no cell matched.
    """
    table = []
    for row in outputs['table']:
        # One bucket of candidate cells per standard header name.
        item = {name: [] for name in ap_dictionary(header=True)}
        for cell in row:
            header_name, score, processed = ap_standardizer(cell['header'], threshold=0.8, header=True)
            if header_name in item:
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': processed,
                    'lcs_score': score,
                    'token_id': cell['id']
                })
        # Reduce each bucket to the single best candidate (or None).
        for header_name, candidates in item.items():
            if not candidates:
                item[header_name] = None
            else:
                item[header_name] = max(candidates, key=lambda c: c['lcs_score'])['content']

        table.append(item)
    return table
||||
def get_ap_triplet_information(outputs):
    """Convert 'triplet' groups (key -> list of header/value cells) into
    table-like items keyed by the standard ap_dictionary header names.

    An item is emitted only when at least one of its cells matched a known
    header; the triplet's key text becomes the item's 'productname'.
    """
    triplet_pairs = []
    for single_item in outputs['triplet']:
        # One bucket of candidate cells per standard header name.
        item = {k: [] for k in list(ap_dictionary(header=True).keys())}
        is_item_valid = 0
        for key_name, list_value in single_item.items():
            for value in list_value:
                if value['header'] == "non-header":
                    continue
                header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
                if header_name in list(item.keys()):
                    is_item_valid = 1
                    item[header_name].append({
                        'content': value['text'],
                        'processed_key_name': proceessed_text,
                        'lcs_score': score,
                        'token_id': value['id']
                    })

        if is_item_valid == 1:
            # Keep only the best-scoring candidate per header (or None).
            for header_name, value in item.items():
                if len(value) == 0:
                    item[header_name] = None
                    continue
                item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score

            # NOTE(review): key_name here is whatever key the loop above
            # ended on — assumes each triplet holds a single product key;
            # verify against the producer of outputs['triplet'].
            item['productname'] = key_name
            # triplet_pairs.append({key_name: new_item})
            triplet_pairs.append(item)
    return triplet_pairs
||||
def get_ap_information(outputs):
    """Collect the AP single key/value fields and pick a final value each.

    Candidates come from (1) the model's 'single' pairs and (2) key/value
    cells inside a 'Product Information' table. Keys are standardized via
    ap_standardizer; the best-scoring candidate wins per field, except
    'imei_number', which keeps every numeric candidate (> 5 digits) as a
    list.
    """
    single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            key_name, score, proceessed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False)
            # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

            if key_name in list(single_pairs):
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': value['id']
                })

    ## Get single_pair if it in a table (Product Information)
    # NOTE(review): is_product_info is sticky — once a product header is
    # seen, all later rows are treated as product rows; confirm intended.
    is_product_info = False
    for table_row in outputs['table']:
        pair = {"key": None, 'value': None}
        for cell in table_row:
            _, _, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=False)
            if any(txt in proceessed_text for txt in ['product', 'information', 'productinformation']):
                is_product_info = True
            # Keeps the LAST key/value cell of the row — assumes at most one
            # key and one value cell per row.
            if cell['class'] in pair:
                pair[cell['class']] = cell

        if all(v is not None for k, v in pair.items()) and is_product_info == True:
            key_name, score, proceessed_text = ap_standardizer(pair['key']['text'], threshold=0.8, header=False)
            # print(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}")

            if key_name in list(single_pairs):
                single_pairs[key_name].append({
                    'content': pair['value']['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': pair['value']['id']
                })
    ## end_block

    # Pick the final value for each field.
    ap_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        if len(list_potential_value) == 0: continue
        if key_name == "imei_number":
            # print('list_potential_value', list_potential_value)
            # ap_outputs[key_name] = [v['content'] for v in list_potential_value if v['content'].replace(' ', '').isdigit() and len(v['content'].replace(' ', '')) > 5]
            ap_outputs[key_name] = []
            for v in list_potential_value:
                imei = v['content'].replace(' ', '')
                if imei.isdigit() and len(imei) > 5: # imei is number and have more 5 digits
                    ap_outputs[key_name].append(imei)
        else:
            selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
            ap_outputs[key_name] = selected_value['content']

    return ap_outputs
||||
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Build the final SBT/AP KVU result and write it to *file_path*.

    Items coming from real tables and from key-triplet groups are merged
    into one flat item list under 'table'.
    """
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)

    # List of items in table, extended with triplet-derived items.
    table = get_ap_table_information(outputs)
    triplet_pairs = get_ap_triplet_information(outputs)
    table = table + triplet_pairs

    ap_outputs = get_ap_information(outputs)
    ap_outputs['table'] = table
    # ap_outputs['triplet'] = triplet_pairs

    write_to_json(file_path, ap_outputs)

    return ap_outputs
226
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word2line.py
Executable file
226
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word2line.py
Executable file
@ -0,0 +1,226 @@
|
||||
class Word():
    """A single OCR token plus detection metadata and layout ids.

    The *_id fields are assigned later, when tokens are grouped into word
    groups / lines / paragraphs by words_to_lines().
    """

    def __init__(self, text="", image=None, conf_detect=0.0, conf_cls=0.0,
                 bndbox=None, kie_label=""):
        self.type = "word"
        self.text = text
        self.image = image
        self.conf_detect = conf_detect
        self.conf_cls = conf_cls
        # BUG FIX: the default box used to be a shared mutable list
        # (bndbox=[-1,-1,-1,-1] is evaluated once at def time, so every
        # defaulted Word aliased the same list); build a fresh one instead.
        # Layout: [left, top, right, bottom].
        self.boundingbox = [-1, -1, -1, -1] if bndbox is None else bndbox
        self.word_id = 0          # id of word
        self.word_group_id = 0    # id of word_group this word belongs to
        self.line_id = 0          # id of line this word belongs to
        self.paragraph_id = 0     # id of paragraph this word belongs to
        self.kie_label = kie_label

    def invalid_size(self):
        """Return True when the box has positive area (i.e. the size is VALID).

        NOTE(review): the name is inverted w.r.t. the return value; callers
        test `invalid_size() == 0` to skip degenerate boxes, so the existing
        behavior is kept as-is.
        """
        return (self.boundingbox[2] - self.boundingbox[0]) * (self.boundingbox[3] - self.boundingbox[1]) > 0

    def is_special_word(self):
        """Heuristic: tokens of >= 7 chars with >= 30% digits are 'special'
        (e.g. serial numbers); a None text is always special."""
        text = self.text

        if text is None:
            return True

        if len(text) >= 7:
            no_digits = sum(c.isdigit() for c in text)
            return no_digits / len(text) >= 0.3

        return False
||||
class Word_group():
    """A horizontal run of Word instances forming one visual token group."""

    def __init__(self):
        self.type = "word_group"
        self.list_words = [] # dict of word instances
        self.word_group_id = 0 # word group id
        self.line_id = 0 #id of line which instance belongs to
        self.paragraph_id = 0# id of paragraph which instance belongs to
        self.text =""
        self.boundingbox = [-1,-1,-1,-1]
        self.kie_label =""

    def add_word(self, word:Word): #add a word instance to the word_group
        """Append *word*, propagating this group's ids onto it and growing
        the group's bounding box; rejects the '✪' placeholder token and
        duplicate word ids. Returns True on success."""
        if word.text != "✪":
            for w in self.list_words:
                if word.word_id == w.word_id:
                    print("Word id collision")
                    return False
            word.word_group_id = self.word_group_id #
            word.line_id = self.line_id
            word.paragraph_id = self.paragraph_id
            self.list_words.append(word)
            self.text += ' '+ word.text
            # First word initializes the box; later words expand it.
            if self.boundingbox == [-1,-1,-1,-1]:
                self.boundingbox = word.boundingbox
            else:
                self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]),
                                    min(self.boundingbox[1], word.boundingbox[1]),
                                    max(self.boundingbox[2], word.boundingbox[2]),
                                    max(self.boundingbox[3], word.boundingbox[3])]
            return True
        else:
            return False

    def update_word_group_id(self, new_word_group_id):
        """Set a new group id on the group and every contained word."""
        self.word_group_id = new_word_group_id
        for i in range(len(self.list_words)):
            self.list_words[i].word_group_id = new_word_group_id

    def update_kie_label(self):
        """Set kie_label to the most frequent label among contained words
        (first-seen wins on ties, by dict insertion order)."""
        list_kie_label = [word.kie_label for word in self.list_words]
        dict_kie = dict()
        for label in list_kie_label:
            if label not in dict_kie:
                dict_kie[label]=1
            else:
                dict_kie[label]+=1
        total = len(list(dict_kie.values()))  # NOTE(review): unused
        max_value = max(list(dict_kie.values()))
        list_keys = list(dict_kie.keys())
        list_values = list(dict_kie.values())
        self.kie_label = list_keys[list_values.index(max_value)]
||||
class Line():
    """A visual text line: an ordered collection of Word_group instances."""

    def __init__(self):
        self.type = "line"
        self.list_word_groups = [] # list of Word_group instances in the line
        self.line_id = 0 #id of line in the paragraph
        self.paragraph_id = 0 # id of paragraph which instance belongs to
        self.text = ""
        self.boundingbox = [-1,-1,-1,-1]

    def add_group(self, word_group:Word_group): # add a word_group instance
        """Append *word_group*, propagating this line's ids down to the group
        and its words; rejects duplicate group ids. Returns True on success."""
        if word_group.list_words is not None:
            for wg in self.list_word_groups:
                if word_group.word_group_id == wg.word_group_id:
                    print("Word_group id collision")
                    return False

            self.list_word_groups.append(word_group)
            self.text += word_group.text
            word_group.paragraph_id = self.paragraph_id
            word_group.line_id = self.line_id

            for i in range(len(word_group.list_words)):
                word_group.list_words[i].paragraph_id = self.paragraph_id #set paragraph_id for word
                word_group.list_words[i].line_id = self.line_id #set line_id for word
            return True
        return False

    def update_line_id(self, new_line_id):
        """Set a new line id on the line, its groups, and their words."""
        self.line_id = new_line_id
        for i in range(len(self.list_word_groups)):
            self.list_word_groups[i].line_id = new_line_id
            for j in range(len(self.list_word_groups[i].list_words)):
                self.list_word_groups[i].list_words[j].line_id = new_line_id

    def merge_word(self, word): # word can be a Word instance or a Word_group instance
        """Absorb *word* into this line (box union + text append), skipping
        the '✪' placeholder. Returns True on success."""
        if word.text != "✪":
            if self.boundingbox == [-1,-1,-1,-1]:
                self.boundingbox = word.boundingbox
            else:
                self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]),
                                    min(self.boundingbox[1], word.boundingbox[1]),
                                    max(self.boundingbox[2], word.boundingbox[2]),
                                    max(self.boundingbox[3], word.boundingbox[3])]
            self.list_word_groups.append(word)
            self.text += ' ' + word.text
            return True
        return False

    def in_same_line(self, input_line, thresh=0.7):
        """Return True when *input_line* overlaps this line vertically enough
        (intersection / min height >= thresh) and heights are comparable.

        NOTE(review): `top in range(top2, bottom2)` requires integer
        coordinates — assumes pixel-int boxes; verify upstream.
        """
        # calculate iou in vertical direction
        left1, top1, right1, bottom1 = self.boundingbox
        left2, top2, right2, bottom2 = input_line.boundingbox

        sorted_vals = sorted([top1, bottom1, top2, bottom2])
        intersection = sorted_vals[2] - sorted_vals[1]
        union = sorted_vals[3]-sorted_vals[0]
        min_height = min(bottom1-top1, bottom2-top2)
        if min_height==0:
            return False
        ratio = intersection / min_height
        # NOTE(review): heights are computed as top-bottom (negative), so
        # ratio_height is max/min of two negatives = the INVERSE of the
        # height ratio; behavior kept as-is — confirm intended.
        height1, height2 = top1-bottom1, top2-bottom2
        ratio_height = float(max(height1, height2))/float(min(height1, height2))
        # height_diff = (float(top1-bottom1))/(float(top2-bottom2))

        if (top1 in range(top2, bottom2) or top2 in range(top1, bottom1)) and ratio >= thresh and (ratio_height<2):
            return True
        return False
||||
def check_iomin(word:Word, word_group:Word_group):
    """Return True when the vertical intersection over the smaller height of
    the two boxes exceeds 0.7, i.e. they overlap enough to share a line."""
    box_w = word.boundingbox
    box_g = word_group.boundingbox
    min_height = min(box_w[3] - box_w[1], box_g[3] - box_g[1])
    intersect = min(box_w[3], box_g[3]) - max(box_w[1], box_g[1])
    return intersect / min_height > 0.7
||||
def words_to_lines(words, check_special_lines=True): #words is list of Word instance
    """Group OCR Word instances into visual Lines and Word_groups.

    Sorts words top-to-bottom, merges vertically-overlapping words into
    Line objects, then splits each line into word groups by horizontal
    gap (< 5% of line width) and vertical overlap, assigning sequential
    word/word-group/line ids.

    Returns:
        (lines, number_of_word): sorted Line list and the input word count.

    NOTE(review): the inner `for i in range(len(lines))` shadows the outer
    enumerate index, and a word that matches several lines is merged into
    ALL of them (no break) — kept as-is; verify intended.
    NOTE(review): check_special_lines is unused here.
    """
    #sort word by top
    words.sort(key = lambda x: (x.boundingbox[1], x.boundingbox[0]))
    number_of_word = len(words)
    #sort list words to list lines, which have not contained word_group yet
    lines = []
    for i, word in enumerate(words):
        # Skip degenerate (zero-area) boxes; invalid_size() is True for
        # VALID boxes despite its name.
        if word.invalid_size()==0:
            continue
        new_line = True
        for i in range(len(lines)):
            if lines[i].in_same_line(word): #check if word is in the same line with lines[i]
                lines[i].merge_word(word)
                new_line = False

        if new_line ==True:
            new_line = Line()
            new_line.merge_word(word)
            lines.append(new_line)

    # print(len(lines))
    #sort line from top to bottom according top coordinate
    lines.sort(key = lambda x: x.boundingbox[1])

    #construct word_groups in each line
    line_id = 0
    word_group_id =0
    word_id = 0
    for i in range(len(lines)):
        if len(lines[i].list_word_groups)==0:
            continue
        #left, top ,right, bottom
        line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left
        # print("line_width",line_width)
        lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right

        #update text for lines after sorting
        lines[i].text =""
        for word in lines[i].list_word_groups:
            lines[i].text += " "+word.text

        # Seed the first word group with the left-most word, then decide for
        # each following word whether it continues the current group or
        # starts a new one.
        list_word_groups=[]
        inital_word_group = Word_group()
        inital_word_group.word_group_id= word_group_id
        word_group_id +=1
        lines[i].list_word_groups[0].word_id=word_id
        inital_word_group.add_word(lines[i].list_word_groups[0])
        word_id+=1
        list_word_groups.append(inital_word_group)
        for word in lines[i].list_word_groups[1:]: #iterate through each word object in list_word_groups (has not been construted to word_group yet)
            check_word_group= True
            #set id for each word in each line
            word.word_id = word_id
            word_id+=1
            # Join the current group when the group doesn't end with ':'
            # (keys stay separate from values), the horizontal gap is small
            # relative to the line width, and vertical overlap is high.
            if (not list_word_groups[-1].text.endswith(':')) and ((word.boundingbox[0]-list_word_groups[-1].boundingbox[2])/line_width <0.05) and check_iomin(word, list_word_groups[-1]):
                list_word_groups[-1].add_word(word)
                check_word_group=False
            if check_word_group ==True:
                new_word_group = Word_group()
                new_word_group.word_group_id= word_group_id
                word_group_id +=1
                new_word_group.add_word(word)
                list_word_groups.append(new_word_group)
        lines[i].list_word_groups = list_word_groups
        # set id for lines from top to bottom
        lines[i].update_line_id(line_id)
        line_id +=1
    return lines, number_of_word
|
388
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word_preprocess.py
Executable file
388
cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word_preprocess.py
Executable file
@ -0,0 +1,388 @@
|
||||
import nltk
|
||||
import re
|
||||
import string
|
||||
import copy
|
||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, manulife_dictionary, DKVU2XML
|
||||
from word2line import Word, words_to_lines
|
||||
nltk.download('words')
|
||||
words = set(nltk.corpus.words.words())
|
||||
|
||||
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
|
||||
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
|
||||
|
||||
# def clean_text(text):
|
||||
# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text)
|
||||
|
||||
|
||||
def get_string(lwords: list):
    """Join tokens with spaces, dropping repeats except single digits.

    Single-digit tokens are always kept (they legitimately repeat, e.g. in
    ids); any other token is kept only on its first occurrence.
    """
    kept = []
    for token in lwords:
        keep_always = token.isdigit() and len(token) == 1
        if keep_always or token not in kept:
            kept.append(token)
    return ' '.join(kept)
||||
def remove_english_words(text):
    """Drop tokens that are dictionary English words, keeping everything
    else lower-cased and space-joined.

    Relies on the module-level `words` set built from the nltk corpus at
    import time.
    """
    _word = [w.lower() for w in nltk.wordpunct_tokenize(text) if w.lower() not in words]
    return ' '.join(_word)
|
||||
def remove_punctuation(text):
    """Strip every character in string.punctuation from *text*."""
    table = str.maketrans('', '', string.punctuation)
    return text.translate(table)
|
||||
def remove_accents(input_str, s0, s1):
    """Replace each character of *input_str* found in *s1* with the character
    at the same index in *s0* (used to strip Vietnamese diacritics); any
    other character passes through unchanged."""
    out_chars = []
    for ch in input_str:
        out_chars.append(s0[s1.index(ch)] if ch in s1 else ch)
    return ''.join(out_chars)
|
||||
def remove_spaces(text):
    """Delete every ASCII space character from *text* (tabs/newlines kept)."""
    return ''.join(text.split(' '))
|
||||
def preprocessing(text: str):
    """Normalize OCR text for dictionary matching: drop punctuation, strip
    Vietnamese accents (via module-level s0/s1 tables), remove spaces, and
    lower-case the result."""
    cleaned = remove_punctuation(text)
    cleaned = remove_accents(cleaned, s0, s1)
    cleaned = remove_spaces(cleaned)
    return cleaned.lower()
||||
def vat_standardize_outputs(vat_outputs: dict) -> dict:
    """Rename internal VAT field names to their XML export names via the
    DKVU2XML mapping; the 'table' entry is handled row by row."""
    outputs = {}
    for key, value in vat_outputs.items():
        if key == "table":
            outputs['table'] = [
                {DKVU2XML[item_key]: item_value for item_key, item_value in item.items()}
                for item in value
            ]
        else:
            outputs[DKVU2XML[key]] = value
    return outputs
||||
def vat_standardizer(text: str, threshold: float, header: bool):
    """Map a raw OCR key/header *text* to a standard VAT field name.

    Returns (field_name_or_original_text, score, processed_text). The score
    is a normalized-LCS similarity; the sentinel values 5/8/10 mark
    heuristic date / exact 'kyboi' / exact alias matches respectively.
    """
    dictionary = vat_dictionary(header)
    processed_text = preprocessing(text)

    # Date parts: when the normalized text is contained in any candidate,
    # the canonical part name (last tuple element) becomes processed_text.
    # NOTE(review): an empty processed_text would always match here —
    # verify inputs cannot normalize to "".
    for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'), ('thangmonth', 'thang', 'month'), ('namyear', 'nam', 'year')]:
        if any([processed_text in txt for txt in candidates]):
            processed_text = candidates[-1]
            return "Ngày, tháng, năm lập hóa đơn", 5, processed_text

    # Extra spellings that must match exactly (not by LCS).
    _dictionary = copy.deepcopy(dictionary)
    if not header:
        exact_dictionary = {
            'Số hóa đơn': ['sono', 'so'],
            'Mã số thuế người bán': ['mst'],
            'Tên người bán': ['kyboi'],
            'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
        }
        for k, v in exact_dictionary.items():
            _dictionary[k] = dictionary[k] + exact_dictionary[k]

    for k, v in dictionary.items():
        # if k in ("Ngày, tháng, năm lập hóa đơn"):
        #     continue
        # Prioritize match completely
        # NOTE(review): `k in ('Tên người bán')` is a substring test on a
        # string (not tuple membership); works only because k is compared
        # against the whole string — confirm.
        if k in ('Tên người bán') and processed_text == "kyboi":
            return k, 8, processed_text

        if any([processed_text == key for key in _dictionary[k]]):
            return k, 10, processed_text

    # Fuzzy fallback: best normalized LCS score against each field's aliases
    # (the date field is excluded — it is handled above).
    scores = {k: 0.0 for k in dictionary}
    for k, v in dictionary.items():
        if k in ("Ngày, tháng, năm lập hóa đơn"):
            continue

        scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]])

    key, score = max(scores.items(), key=lambda x: x[1])
    return key if score > threshold else text, score, processed_text
|
||||
def ap_standardizer(text: str, threshold: float, header: bool):
    """Map a raw OCR key/header *text* to a standard AP field name.

    Returns (field_name_or_original_text, score, processed_text). Score 10
    marks an exact alias match; otherwise a normalized-LCS score is used
    and accepted when >= threshold.
    """
    dictionary = ap_dictionary(header)
    processed_text = preprocessing(text)

    # Prioritize match completely
    _dictionary = copy.deepcopy(dictionary)
    if not header:
        _dictionary['serial_number'] = dictionary['serial_number'] + ['sn']
        _dictionary['imei_number'] = dictionary['imei_number'] + ['imel', 'imed', 'ime'] # text recog error
    else:
        _dictionary['modelnumber'] = dictionary['modelnumber'] + ['sku', 'sn', 'imei']
        _dictionary['qty'] = dictionary['qty'] + ['qty']
    for k, v in dictionary.items():
        if any([processed_text == key for key in _dictionary[k]]):
            return k, 10, processed_text

    # Fuzzy fallback on normalized LCS score.
    scores = {k: 0.0 for k in dictionary}
    for k, v in dictionary.items():
        scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]])

    key, score = max(scores.items(), key=lambda x: x[1])
    # NOTE(review): accepts score == threshold here, while vat_standardizer
    # requires score > threshold — confirm the asymmetry is intended.
    return key if score >= threshold else text, score, processed_text
||||
def manulife_standardizer(text: str, threshold: float, type_dict: str):
    """Map raw OCR *text* to a standard Manulife field name.

    Returns (matched: bool, field_name_or_original_text, score,
    processed_text). Exact alias equality scores 5*(1+len(processed_text));
    alias containment scores 5; otherwise a normalized-LCS score is
    compared against *threshold*.
    """
    dictionary = manulife_dictionary(type=type_dict)
    processed_text = preprocessing(text)

    for key, candidates in dictionary.items():

        if any([txt == processed_text for txt in candidates]):
            return True, key, 5 * (1 + len(processed_text)), processed_text

        if any([txt in processed_text for txt in candidates]):
            return True, key, 5, processed_text

    # Fuzzy fallback; fields with no aliases keep score 0.0.
    scores = {k: 0.0 for k in dictionary}
    for k, v in dictionary.items():
        if len(v) == 0:
            continue
        scores[k] = max(
            [
                longestCommonSubsequence(processed_text, key) / len(key)
                for key in dictionary[k]
            ]
        )
    key, score = max(scores.items(), key=lambda x: x[1])
    return score > threshold, key if score > threshold else text, score, processed_text
||||
|
||||
def convert_format_number(s: str) -> float:
    """Parse an OCR'd numeric string into a number.

    Handles thousands/decimal separators in the "1.234,56" (European) and
    "1234" styles, embedded spaces, and the common OCR 'O'/'o' -> '0'
    confusion.

    Args:
        s: raw numeric string, e.g. "1 234,5", "1.234,56", "12O0".

    Returns:
        int when the value has no fractional part, float otherwise.

    Raises:
        ValueError: when the cleaned string is not numeric.
    """
    s = s.replace(' ', '').replace('O', '0').replace('o', '0')
    # Drop an explicit zero fraction ("123,00" / "123.00").
    if s.endswith(",00") or s.endswith(".00"):
        s = s[:-3]
    if all(delimiter in s for delimiter in (',', '.')):
        # Both separators present: '.' groups thousands, ',' is the decimal
        # point (e.g. "1.234,56").
        integer_part, _, decimal_part = s.replace('.', '').partition(',')
        # BUG FIX: the old code used decimal_part.split('0')[0], which
        # crashed on fractions starting with '0' (e.g. "1.234,05") and
        # silently truncated at inner zeros ("1,405" -> 1.4).
        return int(integer_part) + int(decimal_part) / (10 ** len(decimal_part))
    else:
        # Single (or no) separator kind: treat every separator as grouping.
        return int(s.replace(',', '').replace('.', ''))
||||
|
||||
def post_process_for_item(item: dict) -> dict:
    """Fill in one missing numeric field of a line item using the identity
    quantity * unit_price = amount (keys are the Vietnamese column names).

    Only runs when exactly one of the three fields is missing ('0'/None);
    parsing or division errors are printed and the item is returned
    unchanged. The filled value is stored back as a string.
    """
    check_keys = ['Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế']
    mis_key = []
    for key in check_keys:
        if item[key] in (None, '0'):
            mis_key.append(key)
    if len(mis_key) == 1:
        try:
            # quantity = amount / unit price (rounded to the nearest int)
            if mis_key[0] == check_keys[0] and convert_format_number(item[check_keys[1]]) != 0:
                item[mis_key[0]] = round(convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[1]])).__str__()
            # unit price = amount / quantity
            elif mis_key[0] == check_keys[1] and convert_format_number(item[check_keys[0]]) != 0:
                item[mis_key[0]] = (convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[0]])).__str__()
            # amount = quantity * unit price
            elif mis_key[0] == check_keys[2]:
                item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__()
        except Exception as e:
            print("Cannot post process this item with error:", e)
    return item
||||
def longestCommonSubsequence(text1: str, text2: str) -> int:
    """Length of the longest common subsequence of two strings.

    Standard O(m*n) dynamic-programming table; table[r+1][c+1] holds the
    LCS length of text1[:r+1] and text2[:c+1].
    """
    rows, cols = len(text1), len(text2)
    table = [[0] * (cols + 1) for _ in range(rows + 1)]
    for r in range(rows):
        for c in range(cols):
            if text1[r] == text2[c]:
                table[r + 1][c + 1] = table[r][c] + 1
            else:
                table[r + 1][c + 1] = max(table[r][c + 1], table[r + 1][c])
    return table[rows][cols]
||||
|
||||
def longest_common_subsequence_with_idx(X, Y):
    """
    This implementation uses dynamic programming to calculate the length of the LCS, and uses a path array to keep track of the characters in the LCS.
    The longest_common_subsequence function takes two strings as input, and returns a tuple with three values:
    the length of the LCS,
    the index of the first character of the LCS in the first string,
    and the index of the last character of the LCS in the first string.

    NOTE(review): despite the variable name below, `lcs` is the LCS LENGTH
    (an int), not the subsequence string itself.
    """
    m, n = len(X), len(Y)
    L = [[0 for i in range(n + 1)] for j in range(m + 1)]

    # Following steps build L[m+1][n+1] in bottom up fashion. Note
    # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1]
    right_idx = 0
    max_lcs = 0
    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
                # Track where (in X) the longest match so far ends.
                if L[i][j] > max_lcs:
                    max_lcs = L[i][j]
                    right_idx = i
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    # Create a string variable to store the lcs string
    lcs = L[i][j]
    # Start from the right-most-bottom-most corner and
    # one by one store characters in lcs[]
    # Backtrack to find where the LCS starts in X; after this loop `i` is
    # the left index of the LCS in X.
    i = m
    j = n
    # right_idx = 0
    while i > 0 and j > 0:
        # If current character in X[] and Y are same, then
        # current character is part of LCS
        if X[i - 1] == Y[j - 1]:

            i -= 1
            j -= 1
        # If not same, then find the larger of two and
        # go in the direction of larger value
        elif L[i - 1][j] > L[i][j - 1]:
            # right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs
            i -= 1
        else:
            j -= 1
    return lcs, i, max(i + lcs, right_idx)
||||
def get_string_by_deduplicate_bbox(lwords: list, lbboxes: list):
    """Join words with spaces, skipping consecutive tokens that share the
    exact same bounding box (an OCR artifact where one token is emitted
    twice)."""
    kept = []
    last_box = [-1, -1, -1, -1]
    for token, box in zip(lwords, lbboxes):
        if box == last_box:
            continue
        kept.append(token)
        last_box = box
    return ' '.join(kept)
||||
def get_string_with_word2line(lwords: list, lbboxes: list):
    """Join words in visual reading order.

    Deduplicates consecutive tokens sharing the same bounding box, rebuilds
    lines/word groups via words_to_lines(), and returns the re-ordered
    text; prints a warning when the re-ordered text differs from the
    model's original token order.
    """
    list_words = []
    unique_list = []
    list_sorted_words = []

    prev_bbox = [-1, -1, -1, -1]
    for word, bbox in zip(lwords, lbboxes):
        if bbox != prev_bbox:
            prev_bbox = bbox
            list_words.append(Word(image=None, text=word, conf_cls=-1, bndbox=bbox, conf_detect=-1))
            unique_list.append(word)
    llines = words_to_lines(list_words)[0]

    # Flatten line -> word group -> word back into a token sequence.
    for line in llines:
        for _word_group in line.list_word_groups:
            for _word in _word_group.list_words:
                list_sorted_words.append(_word.text)

    string_from_model = ' '.join(unique_list)
    string_after_word2line = ' '.join(list_sorted_words)

    if string_from_model != string_after_word2line:
        print("[Warning] Word group from model is different with word2line module")
        print("Model: ", ' '.join(unique_list))
        print("Word2line: ", ' '.join(list_sorted_words))

    return string_after_word2line
|
||||
def remove_bullet_points_and_punctuation(text):
    """Strip leading bullet markers, a leading item number, and one leading
    plus one trailing punctuation character, then trim whitespace."""
    # Leading bullet glyphs (•, -, *) at the start of any line.
    text = re.sub(r'^\s*[\•\-\*]\s*', '', text, flags=re.MULTILINE)
    # Leading item number ("12 Foo" -> "Foo").
    text = re.sub("^\d+\s*", "", text)
    text = text.strip()
    edge_punct = (',', '.', ':', ';', '?', '!')
    if text.startswith(edge_punct):
        text = text[1:]
    if text.endswith(edge_punct):
        text = text[:-1]
    return text.strip()
||||
def split_key_value_by_colon(key: str, value: str) -> list:
    """Re-split a (key, value) pair on the first ':' of their concatenation.

    Returns (new_key, new_value); when no ':' is present, the inputs are
    returned unchanged. The returned value keeps the leading ':'.
    """
    joined = key if value is None else key + " " + value
    head, sep, _ = joined.partition(':')
    if sep:
        return head, joined[len(head):]
    return key, value
||||
|
||||
# def normalize_kvu_output(raw_outputs: dict) -> dict:
|
||||
# outputs = {}
|
||||
# for key, values in raw_outputs.items():
|
||||
# if key == "table":
|
||||
# table = []
|
||||
# for row in values:
|
||||
# item = {}
|
||||
# for k, v in row.items():
|
||||
# k = remove_bullet_points_and_punctuation(k)
|
||||
# if v is not None and len(v) > 0:
|
||||
# v = remove_bullet_points_and_punctuation(v)
|
||||
# item[k] = v
|
||||
# table.append(item)
|
||||
# outputs[key] = table
|
||||
# else:
|
||||
# key = remove_bullet_points_and_punctuation(key)
|
||||
# if isinstance(values, list):
|
||||
# values = [remove_bullet_points_and_punctuation(v) for v in values]
|
||||
# elif values is not None and len(values) > 0:
|
||||
# values = remove_bullet_points_and_punctuation(values)
|
||||
# outputs[key] = values
|
||||
# return outputs
|
||||
def normalize_kvu_output(raw_outputs: dict) -> dict:
    """Clean up a raw KVU prediction dict.

    Table headers are normalized and upper-cased; each table row dict is
    converted to a list of cleaned cell values; every other key/value is
    stripped of bullets and stray punctuation.
    """
    outputs = {}
    for key, values in raw_outputs.items():
        if key == "tables" and len(values) > 0:
            normalized_tables = []
            for table in values:
                headers = [remove_bullet_points_and_punctuation(header).upper()
                           for header in table['headers']]
                data = []
                for row in table['data']:
                    cleaned_row = []
                    for cell in row.values():
                        if cell is not None and len(cell) > 0:
                            cleaned_row.append(remove_bullet_points_and_punctuation(cell))
                        else:
                            cleaned_row.append(cell)
                    data.append(cleaned_row)
                normalized_tables.append({"headers": headers, "data": data})
            outputs[key] = normalized_tables
        else:
            cleaned_key = remove_bullet_points_and_punctuation(key)
            if isinstance(values, list):
                values = [remove_bullet_points_and_punctuation(v) for v in values]
            elif values is not None and len(values) > 0:
                values = remove_bullet_points_and_punctuation(values)
            outputs[cleaned_key] = values
    return outputs
|
||||
def normalize_kvu_output_for_manulife(raw_outputs: dict) -> dict:
    """Clean up a raw KVU prediction dict for the Manulife flow.

    Same as normalize_kvu_output, except top-level keys are additionally
    capitalized — with 'title', 'tables', and 'class_doc' kept verbatim.
    """
    outputs = {}
    for key, values in raw_outputs.items():
        if key == "tables" and len(values) > 0:
            table_list = []
            for table in values:
                headers, data = [], []
                headers = [
                    remove_bullet_points_and_punctuation(header).upper()
                    for header in table["headers"]
                ]
                for row in table["data"]:
                    item = []
                    for k, v in row.items():
                        if v is not None and len(v) > 0:
                            item.append(remove_bullet_points_and_punctuation(v))
                        else:
                            item.append(v)
                    data.append(item)
                table_list.append({"headers": headers, "data": data})
            outputs[key] = table_list
        else:
            # Reserved keys pass through unchanged; others are cleaned and
            # capitalized for display.
            if key not in ("title", "tables", "class_doc"):
                key = remove_bullet_points_and_punctuation(key).capitalize()
            if isinstance(values, list):
                values = [remove_bullet_points_and_punctuation(v) for v in values]
            elif values is not None and len(values) > 0:
                values = remove_bullet_points_and_punctuation(values)
            outputs[key] = values
    return outputs
58
cope2n-ai-fi/api/Kie_Invoice_AP/prediction.py
Executable file
58
cope2n-ai-fi/api/Kie_Invoice_AP/prediction.py
Executable file
@ -0,0 +1,58 @@
|
||||
from sdsvkie import Predictor
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib
|
||||
|
||||
# KIE model for AP invoices, loaded once as an import side effect.
# NOTE(review): paths and CUDA device are hard-coded — importing this
# module requires the model files and a GPU to be present; verify.
model = Predictor(
    cfg = "/ai-core/models/Kie_invoice_ap/config.yaml",
    device = "cuda:0",
    weights = "/ai-core/models/Kie_invoice_ap/ep21"
)
|
||||
def predict(image_url):
    """Run the KIE invoice model on an image downloaded from *image_url*.

    Args:
        image_url (str): image url

    Returns:
        example output:
        "data": {
            "document_type": "invoice",
            "fields": [
                {
                    "label": "Invoice Number",
                    "value": "INV-12345",
                    "box": [0, 0, 0, 0],
                    "confidence": 0.98
                },
                ...
            ]
        }
        dict: output of model
    """
    # BUG FIX: the module only does `import urllib`, which does not load the
    # `urllib.request` submodule; import it explicitly before use.
    import urllib.request
    req = urllib.request.urlopen(image_url)
    arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
    img = cv2.imdecode(arr, -1)  # decode as-is (preserves channels/depth)
    out = model(img)
    output = out["end2end_results"]
    output_dict = {
        "document_type": "invoice",
        "fields": []
    }
    # Flatten the model's per-key results into the API field list; the
    # model's "id" key is exposed as "Receipt Number".
    for key in output.keys():
        field = {
            "label": key if key != "id" else "Receipt Number",
            "value": output[key]['value'] if output[key]['value'] else "",
            "box": output[key]['box'],
            "confidence": output[key]['conf']
        }
        output_dict['fields'].append(field)
    print(output_dict)
    return output_dict
||||
if __name__ == "__main__":
    # Smoke test. NOTE(review): this "URL" is a bare local path —
    # urllib.request.urlopen needs a scheme (file:// or http://), so this
    # will likely fail as written; verify.
    image_url = "/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/demos/2022_07_25 farewell lunch.jpg"
    output = predict(image_url)
    print(output)
77
cope2n-ai-fi/api/Kie_Invoice_AP/prediction_fi.py
Executable file
77
cope2n-ai-fi/api/Kie_Invoice_AP/prediction_fi.py
Executable file
@ -0,0 +1,77 @@
|
||||
import os, sys
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
KIE_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvkie")
|
||||
TD_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvtd")
|
||||
TR_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvtr")
|
||||
sys.path.append(KIE_PATH)
|
||||
sys.path.append(TD_PATH)
|
||||
sys.path.append(TR_PATH)
|
||||
|
||||
from sdsvkie import Predictor
|
||||
from .AnyKey_Value.anyKeyValue import load_engine, Predictor_KVU
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib
|
||||
|
||||
model = Predictor(
|
||||
cfg = "/models/Kie_invoice_ap/06062023/config.yaml", # TODO: Better be scalable
|
||||
device = "cuda:0",
|
||||
weights = "/models/Kie_invoice_ap/06062023/best" # TODO: Better be scalable
|
||||
)
|
||||
|
||||
class_names = ['others', 'title', 'key', 'value', 'header']
|
||||
save_dir = os.path.join(cur_dir, "AnyKey_Value/visualize/test")
|
||||
|
||||
predictor, processor = load_engine(exp_dir="/models/Kie_invoice_fi/key_value_understanding-20230627-164536",
|
||||
class_names=class_names,
|
||||
mode=3)
|
||||
|
||||
def predict_fi(page_numb, image_url):
|
||||
"""
|
||||
module predict function
|
||||
|
||||
Args:
|
||||
image_url (str): image url
|
||||
|
||||
Returns:
|
||||
example output:
|
||||
"data": {
|
||||
"document_type": "invoice",
|
||||
"fields": [
|
||||
{
|
||||
"label": "Invoice Number",
|
||||
"value": "INV-12345",
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": 0.98
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
dict: output of model
|
||||
"""
|
||||
req = urllib.request.urlopen(image_url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, -1)
|
||||
# img = cv2.imread(image_url)
|
||||
|
||||
# Phan cua LeHoang
|
||||
out = model(img)
|
||||
output = out["end2end_results"]
|
||||
output_kie = {
|
||||
field_name: field_item['value'] for field_name, field_item in output.items()
|
||||
}
|
||||
# print("Hoangggggggggggggggggggggggggggggggggggggggggggggg")
|
||||
# print(output_kie)
|
||||
|
||||
|
||||
#Phan cua Tuan
|
||||
kvu_result, _ = Predictor_KVU(image_url, save_dir, predictor, processor)
|
||||
# print("TuanNnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn")
|
||||
# print(kvu_result)
|
||||
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
||||
return kvu_result, output_kie
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_url = "/mnt/hdd2T/dxtan/TannedCung/OCR/workspace/Kie_Invoice_AP/tmp_image/{image_url}.jpg"
|
||||
output = predict_fi(0, image_url)
|
||||
print(output)
|
146
cope2n-ai-fi/api/Kie_Invoice_AP/prediction_sap.py
Executable file
146
cope2n-ai-fi/api/Kie_Invoice_AP/prediction_sap.py
Executable file
@ -0,0 +1,146 @@
|
||||
from sdsvkie import Predictor
|
||||
import sys, os
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
# sys.path.append("cope2n-ai-fi/Kie_Invoice_AP") # Better be relative
|
||||
from .AnyKey_Value.anyKeyValue import load_engine, Predictor_KVU
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib
|
||||
import random
|
||||
|
||||
# model = Predictor(
|
||||
# cfg = "/cope2n-ai-fi/models/Kie_invoice_ap/config.yaml",
|
||||
# device = "cuda:0",
|
||||
# weights = "/cope2n-ai-fi/models/Kie_invoice_ap/best"
|
||||
# )
|
||||
|
||||
class_names = ['others', 'title', 'key', 'value', 'header']
|
||||
save_dir = os.path.join(cur_dir, "AnyKey_Value/visualize/test")
|
||||
|
||||
predictor, processor = load_engine(exp_dir=os.path.join(cur_dir, "AnyKey_Value/experiments/key_value_understanding-20231003-171748"),
|
||||
class_names=class_names,
|
||||
mode=3)
|
||||
|
||||
def predict(page_numb, image_url):
|
||||
"""
|
||||
module predict function
|
||||
|
||||
Args:
|
||||
image_url (str): image url
|
||||
|
||||
Returns:
|
||||
example output:
|
||||
"data": {
|
||||
"document_type": "invoice",
|
||||
"fields": [
|
||||
{
|
||||
"label": "Invoice Number",
|
||||
"value": "INV-12345",
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": 0.98
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
dict: output of model
|
||||
"""
|
||||
req = urllib.request.urlopen(image_url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, -1)
|
||||
# img = cv2.imread(image_url)
|
||||
|
||||
# Phan cua LeHoang
|
||||
# out = model(img)
|
||||
# output = out["end2end_results"]
|
||||
|
||||
|
||||
#Phan cua Tuan
|
||||
kvu_result = Predictor_KVU(image_url, save_dir, predictor, processor)
|
||||
output_dict = {
|
||||
"document_type": "invoice",
|
||||
"fields": []
|
||||
}
|
||||
for key in kvu_result.keys():
|
||||
field = {
|
||||
"label": key,
|
||||
"value": kvu_result[key],
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": random.uniform(0.9, 1.0),
|
||||
"page": page_numb
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
print(output_dict)
|
||||
return output_dict
|
||||
|
||||
# if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None:
|
||||
# output_dict = {
|
||||
# "document_type": "invoice",
|
||||
# "fields": []
|
||||
# }
|
||||
# for key in output.keys():
|
||||
# field = {
|
||||
# "label": key if key != "id" else "Receipt Number",
|
||||
# "value": output[key]['value'] if output[key]['value'] else "",
|
||||
# "box": output[key]['box'],
|
||||
# "confidence": output[key]['conf'],
|
||||
# "page": page_numb
|
||||
# }
|
||||
# output_dict['fields'].append(field)
|
||||
# table = kvu_result['table']
|
||||
# field_table = {
|
||||
# "label": "table",
|
||||
# "value": table,
|
||||
# "box": [0, 0, 0, 0],
|
||||
# "confidence": 0.98,
|
||||
# "page": page_numb
|
||||
# }
|
||||
# output_dict['fields'].append(field_table)
|
||||
# return output_dict
|
||||
|
||||
# else:
|
||||
# output_dict = {
|
||||
# "document_type": "KSU",
|
||||
# "fields": []
|
||||
# }
|
||||
# # for key in output.keys():
|
||||
# # field = {
|
||||
# # "label": key if key != "id" else "Receipt Number",
|
||||
# # "value": output[key]['value'] if output[key]['value'] else "",
|
||||
# # "box": output[key]['box'],
|
||||
# # "confidence": output[key]['conf'],
|
||||
# # "page": page_numb
|
||||
# # }
|
||||
# # output_dict['fields'].append(field)
|
||||
|
||||
# # Serial Number
|
||||
# serial_number = kvu_result['serial_number']
|
||||
# field_serial = {
|
||||
# "label" : "serial_number",
|
||||
# "value": serial_number,
|
||||
# "box": [0, 0, 0, 0],
|
||||
# "confidence": 0.98,
|
||||
# "page": page_numb
|
||||
# }
|
||||
# output_dict['fields'].append(field_serial)
|
||||
|
||||
# # IMEI Number
|
||||
# imei_number = kvu_result['imei_number']
|
||||
# if imei_number == None:
|
||||
# return output_dict
|
||||
# if imei_number != None:
|
||||
# for i in range(len(imei_number)):
|
||||
# field_imei = {
|
||||
# "label": "imei_number_{}".format(i+1),
|
||||
# "value": imei_number[i],
|
||||
# "box": [0, 0, 0, 0],
|
||||
# "confidence": 0.98,
|
||||
# "page": page_numb
|
||||
# }
|
||||
# output_dict['fields'].append(field_imei)
|
||||
|
||||
# return output_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||
output = predict(0, image_url)
|
||||
print(output)
|
106
cope2n-ai-fi/api/Kie_Invoice_AP/tmp.txt
Executable file
106
cope2n-ai-fi/api/Kie_Invoice_AP/tmp.txt
Executable file
@ -0,0 +1,106 @@
|
||||
1113 773 1220 825 BEST
|
||||
1243 759 1378 808 DENKI
|
||||
1410 752 1487 799 (S)
|
||||
1430 707 1515 748 TAX
|
||||
1511 745 1598 790 PTE
|
||||
1542 700 1725 740 TNVOICE
|
||||
1618 742 1706 783 LTD
|
||||
1783 725 1920 773 FUNAN
|
||||
1943 723 2054 767 MALL
|
||||
1434 797 1576 843 WORTH
|
||||
1599 785 1760 831 BRIDGE
|
||||
1784 778 1846 822 RD
|
||||
1277 846 1632 897 #02-16/#03-1
|
||||
1655 832 1795 877 FUNAN
|
||||
1817 822 1931 869 MALL
|
||||
1272 897 1518 956 S(179105)
|
||||
1548 890 1655 943 TEL:
|
||||
1686 877 1911 928 69046183
|
||||
1247 1011 1334 1068 GST
|
||||
1358 1006 1447 1059 REG
|
||||
1360 1063 1449 1115 RCB
|
||||
1473 1003 1575 1055 NO.:
|
||||
1474 1059 1555 1110 NO.
|
||||
1595 1042 1868 1096 198202199E
|
||||
1607 985 1944 1040 M2-0053813-7
|
||||
1056 1134 1254 1194 Opening
|
||||
1276 1127 1391 1181 Hrs:
|
||||
1425 1112 1647 1170 10:00:00
|
||||
1672 1102 1735 1161 AN
|
||||
1755 1101 1819 1157 to
|
||||
1846 1090 2067 1147 10:00:00
|
||||
2090 1080 2156 1141 PH
|
||||
1061 1308 1228 1366 Staff:
|
||||
1258 1300 1378 1357 3296
|
||||
1710 1283 1880 1337 Trans:
|
||||
1936 1266 2192 1322 262152554
|
||||
1060 1372 1201 1429 Date:
|
||||
1260 1358 1494 1419 22-03-23
|
||||
1540 1344 1664 1409 9:05
|
||||
1712 1339 1856 1407 Slip:
|
||||
1917 1328 2196 1387 2000130286
|
||||
1124 1487 1439 1545 SALESPERSON
|
||||
1465 1477 1601 1537 CODE.
|
||||
1633 1471 1752 1530 6043
|
||||
1777 1462 2004 1519 HUHAHHAD
|
||||
2032 1451 2177 1509 RAZIH
|
||||
1070 1558 1187 1617 Item
|
||||
1211 1554 1276 1615 No
|
||||
1439 1542 1585 1601 Price
|
||||
1750 1530 1841 1597 Qty
|
||||
1951 1517 2120 1579 Amount
|
||||
1076 1683 1276 1741 ANDROID
|
||||
1304 1673 1477 1733 TABLET
|
||||
1080 1746 1280 1804 2105976
|
||||
1509 1729 1705 1784 SAMSUNG
|
||||
1734 1719 1931 1776 SH-P613
|
||||
1964 1709 2101 1768 128GB
|
||||
1082 1809 1285 1869 SM-P613
|
||||
1316 1802 1454 1860 12838
|
||||
1429 1859 1600 1919 518.00
|
||||
1481 1794 1596 1855 WIFI
|
||||
1622 1790 1656 1850 G
|
||||
1797 1845 1824 1904 1
|
||||
1993 1832 2165 1892 518.00
|
||||
1088 1935 1347 1995 PROMOTION
|
||||
1091 2000 1294 2062 2105664
|
||||
1520 1983 1717 2039 SAMSUNG
|
||||
1743 1963 2106 2030 F-Sam-Redeen
|
||||
1439 2111 1557 2173 0.00
|
||||
1806 2095 1832 2156 1
|
||||
2053 2081 2174 2144 0.00
|
||||
1106 2248 1250 2312 Total
|
||||
1974 2206 2146 2266 518.00
|
||||
1107 2312 1204 2377 UOB
|
||||
1448 2291 1567 2355 CARD
|
||||
1978 2268 2147 2327 518.00
|
||||
1253 2424 1375 2497 GST%
|
||||
1456 2411 1655 2475 Net.Amt
|
||||
1818 2393 1912 2460 GST
|
||||
2023 2387 2192 2445 Amount
|
||||
1106 2494 1231 2560 GST8
|
||||
1486 2472 1661 2537 479.63
|
||||
1770 2458 1916 2523 38.37
|
||||
2027 2448 2203 2511 518.00
|
||||
1553 2601 1699 2666 THANK
|
||||
1721 2592 1821 2661 YOU
|
||||
1436 2678 1616 2749 please
|
||||
1644 2682 1764 2732 come
|
||||
1790 2660 1942 2729 again
|
||||
1191 2862 1391 2931 Those
|
||||
1426 2870 2018 2945 facebook.com
|
||||
1565 2809 1690 2884 join
|
||||
1709 2816 1777 2870 us
|
||||
1799 2811 1868 2865 on
|
||||
1838 2946 2024 3003 com .89
|
||||
1533 3006 2070 3088 ar.com/askbe
|
||||
1300 3326 1659 3446 That's
|
||||
1696 3308 1905 3424 not
|
||||
1937 3289 2131 3408 all!
|
||||
1450 3511 1633 3573 SCAN
|
||||
1392 3589 1489 3645 QR
|
||||
1509 3577 1698 3635 CODE
|
||||
1321 3656 1370 3714 &
|
||||
1517 3638 1768 3699 updates
|
||||
1643 3882 1769 3932 Scan
|
||||
1789 3868 1859 3926 Me
|
BIN
cope2n-ai-fi/api/Kie_Invoice_AP/tmp_image/{image_url}.jpg
Executable file
BIN
cope2n-ai-fi/api/Kie_Invoice_AP/tmp_image/{image_url}.jpg
Executable file
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
155
cope2n-ai-fi/api/OCRBase/prediction.py
Executable file
155
cope2n-ai-fi/api/OCRBase/prediction.py
Executable file
@ -0,0 +1,155 @@
|
||||
from OCRBase.text_recognition import ocr_predict
|
||||
import cv2
|
||||
from shapely.geometry import Polygon
|
||||
import urllib
|
||||
import numpy as np
|
||||
|
||||
def check_percent_overlap_bbox(boxA, boxB):
|
||||
"""check percent box A in boxB
|
||||
|
||||
Args:
|
||||
boxA (_type_): _description_
|
||||
boxB (_type_): _description_
|
||||
|
||||
Returns:
|
||||
Float: percent overlap bbox
|
||||
"""
|
||||
# determine the (x, y)-coordinates of the intersection rectangle
|
||||
box_shape_1 = [
|
||||
[boxA[0], boxA[1]],
|
||||
[boxA[2], boxA[1]],
|
||||
[boxA[2], boxA[3]],
|
||||
[boxA[0], boxA[3]],
|
||||
]
|
||||
|
||||
# Give dimensions of shape 2
|
||||
box_shape_2 = [
|
||||
[boxB[0], boxB[1]],
|
||||
[boxB[2], boxB[1]],
|
||||
[boxB[2], boxB[3]],
|
||||
[boxB[0], boxB[3]],
|
||||
]
|
||||
# Draw polygon 1 from shape 1
|
||||
# dimensions
|
||||
polygon_1 = Polygon(box_shape_1)
|
||||
|
||||
# Draw polygon 2 from shape 2
|
||||
# dimensions
|
||||
polygon_2 = Polygon(box_shape_2)
|
||||
|
||||
# Calculate the intersection of
|
||||
# bounding boxes
|
||||
intersect = polygon_1.intersection(polygon_2).area / polygon_1.area
|
||||
|
||||
return intersect
|
||||
|
||||
|
||||
def check_box_in_box(boxA, boxB):
|
||||
"""check boxA in boxB
|
||||
|
||||
Args:
|
||||
boxA (_type_): _description_
|
||||
boxB (_type_): _description_
|
||||
|
||||
Returns:
|
||||
Boolean: True if boxA in boxB
|
||||
"""
|
||||
if (
|
||||
boxA[0] >= boxB[0]
|
||||
and boxA[1] >= boxB[1]
|
||||
and boxA[2] <= boxB[2]
|
||||
and boxA[3] <= boxB[3]
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def word_to_line_image_origin(list_words, bbox):
|
||||
"""use for predict image with bbox selected
|
||||
|
||||
Args:
|
||||
list_words (_type_): _description_
|
||||
bbox (_type_): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
texts, boundingboxes = [], []
|
||||
for line in list_words:
|
||||
if line.text == "":
|
||||
continue
|
||||
else:
|
||||
# convert to bbox image original
|
||||
boundingbox = line.boundingbox
|
||||
boundingbox = list(boundingbox)
|
||||
boundingbox[0] = boundingbox[0] + bbox[0]
|
||||
boundingbox[1] = boundingbox[1] + bbox[1]
|
||||
boundingbox[2] = boundingbox[2] + bbox[0]
|
||||
boundingbox[3] = boundingbox[3] + bbox[1]
|
||||
texts.append(line.text)
|
||||
boundingboxes.append(boundingbox)
|
||||
return texts, boundingboxes
|
||||
|
||||
|
||||
def word_to_line(list_words):
|
||||
"""use for predict full image
|
||||
|
||||
Args:
|
||||
list_words (_type_): _description_
|
||||
"""
|
||||
texts, boundingboxes = [], []
|
||||
for line in list_words:
|
||||
print(line.text)
|
||||
if line.text == "":
|
||||
continue
|
||||
else:
|
||||
boundingbox = line.boundingbox
|
||||
boundingbox = list(boundingbox)
|
||||
texts.append(line.text)
|
||||
boundingboxes.append(boundingbox)
|
||||
return texts, boundingboxes
|
||||
|
||||
|
||||
def predict(page_numb, image_url):
|
||||
"""predict text from image
|
||||
|
||||
Args:
|
||||
image_path (String): path image to predict
|
||||
list_id (List): List id of bbox selected
|
||||
list_bbox (List): List bbox selected
|
||||
|
||||
Returns:
|
||||
Dict: Dict result of prediction
|
||||
"""
|
||||
|
||||
req = urllib.request.urlopen(image_url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
image = cv2.imdecode(arr, -1)
|
||||
list_lines = ocr_predict(image)
|
||||
texts, boundingboxes = word_to_line(list_lines)
|
||||
result = {}
|
||||
texts_replace = []
|
||||
for text in texts:
|
||||
if "✪" in text:
|
||||
text = text.replace("✪", " ")
|
||||
texts_replace.append(text)
|
||||
else:
|
||||
texts_replace.append(text)
|
||||
result["texts"] = texts_replace
|
||||
result["boundingboxes"] = boundingboxes
|
||||
|
||||
output_dict = {
|
||||
"document_type": "ocr-base",
|
||||
"fields": []
|
||||
}
|
||||
field = {
|
||||
"label": "Text",
|
||||
"value": result["texts"],
|
||||
"box": result["boundingboxes"],
|
||||
"confidence": 0.98,
|
||||
"page": page_numb
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
|
||||
return output_dict
|
22
cope2n-ai-fi/api/OCRBase/text_detection.py
Executable file
22
cope2n-ai-fi/api/OCRBase/text_detection.py
Executable file
@ -0,0 +1,22 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from distutils.command.config import config
|
||||
|
||||
from mmdet.apis import inference_detector, init_detector
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
config = "model/yolox_s_8x8_300e_cocotext_1280.py"
|
||||
checkpoint = "model/best_bbox_mAP_epoch_294.pth"
|
||||
device = "cpu"
|
||||
|
||||
|
||||
def read_imagefile(file) -> Image.Image:
|
||||
image = Image.open(BytesIO(file))
|
||||
return image
|
||||
|
||||
|
||||
def detection_predict(image):
|
||||
model = init_detector(config, checkpoint, device=device)
|
||||
# test a single image
|
||||
result = inference_detector(model, image)
|
||||
return result
|
39
cope2n-ai-fi/api/OCRBase/text_recognition.py
Executable file
39
cope2n-ai-fi/api/OCRBase/text_recognition.py
Executable file
@ -0,0 +1,39 @@
|
||||
from common.utils.ocr_yolox import OcrEngineForYoloX_Invoice
|
||||
from common.utils.word_formation import Word, words_to_lines
|
||||
|
||||
|
||||
det_ckpt = "/models/sdsvtd/hub/wild_receipt_finetune_weights_c_lite.pth"
|
||||
cls_ckpt = "satrn-lite-general-pretrain-20230106"
|
||||
|
||||
engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt)
|
||||
|
||||
|
||||
def ocr_predict(img):
|
||||
"""Predict text from image
|
||||
|
||||
Args:
|
||||
image_path (str): _description_
|
||||
|
||||
Returns:
|
||||
list: list of words
|
||||
"""
|
||||
try:
|
||||
lbboxes, lwords = engine.run_image(img)
|
||||
lWords = [Word(text=word, bndbox=bbox) for word, bbox in zip(lwords, lbboxes)]
|
||||
list_lines, _ = words_to_lines(lWords)
|
||||
return list_lines
|
||||
# return lbboxes, lwords
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
list_lines = []
|
||||
return list_lines
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--image", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
list_lines = ocr_predict(args.image)
|
98
cope2n-ai-fi/api/manulife/predict_manulife.py
Normal file
98
cope2n-ai-fi/api/manulife/predict_manulife.py
Normal file
@ -0,0 +1,98 @@
|
||||
import cv2
|
||||
import urllib
|
||||
import random
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys, os
|
||||
cur_dir = str(Path(__file__).parents[2])
|
||||
sys.path.append(cur_dir)
|
||||
from modules.sdsvkvu import load_engine, process_img
|
||||
from modules.ocr_engine import OcrEngine
|
||||
from configs.manulife import device, ocr_cfg, kvu_cfg
|
||||
|
||||
def load_ocr_engine(opt) -> OcrEngine:
|
||||
print("[INFO] Loading engine...")
|
||||
engine = OcrEngine(**opt)
|
||||
print("[INFO] Engine loaded")
|
||||
return engine
|
||||
|
||||
|
||||
print("OCR engine configfs: \n", ocr_cfg)
|
||||
print("KVU configfs: \n", kvu_cfg)
|
||||
|
||||
ocr_engine = load_ocr_engine(ocr_cfg)
|
||||
kvu_cfg['ocr_engine'] = ocr_engine
|
||||
option = kvu_cfg['option']
|
||||
kvu_cfg.pop("option") # pop option
|
||||
manulife_engine = load_engine(kvu_cfg)
|
||||
|
||||
|
||||
def manulife_predict(image_url, engine) -> None:
|
||||
req = urllib.request.urlopen(image_url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, -1)
|
||||
|
||||
save_dir = "./tmp_results"
|
||||
# image_path = os.path.join(save_dir, f"{image_url}.jpg")
|
||||
image_path = os.path.join(save_dir, "abc.jpg")
|
||||
cv2.imwrite(image_path, img)
|
||||
|
||||
outputs = process_img(img_path=image_path,
|
||||
save_dir=save_dir,
|
||||
engine=engine,
|
||||
export_all=False,
|
||||
option=option)
|
||||
return outputs
|
||||
|
||||
|
||||
def predict(page_numb, image_url):
|
||||
"""
|
||||
module predict function
|
||||
|
||||
Args:
|
||||
image_url (str): image url
|
||||
|
||||
Returns:
|
||||
example output:
|
||||
"data": {
|
||||
"document_type": "invoice",
|
||||
"fields": [
|
||||
{
|
||||
"label": "Invoice Number",
|
||||
"value": "INV-12345",
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": 0.98
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
dict: output of model
|
||||
"""
|
||||
kvu_result = manulife_predict(image_url, engine=manulife_engine)
|
||||
output_dict = {
|
||||
"document_type": kvu_result['title'] if kvu_result['title'] is not None else "unknown",
|
||||
"document_class": kvu_result['class_doc'] if kvu_result['class_doc'] is not None else "unknown",
|
||||
"page_number": page_numb,
|
||||
"fields": []
|
||||
}
|
||||
for key in kvu_result.keys():
|
||||
if key in ("title", "class_doc"):
|
||||
continue
|
||||
field = {
|
||||
"label": key,
|
||||
"value": kvu_result[key],
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": random.uniform(0.9, 1.0),
|
||||
"page": page_numb
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
print(output_dict)
|
||||
return output_dict
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||
output = predict(0, image_url)
|
||||
print(output)
|
94
cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py
Executable file
94
cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py
Executable file
@ -0,0 +1,94 @@
|
||||
import cv2
|
||||
import urllib
|
||||
import random
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys, os
|
||||
cur_dir = str(Path(__file__).parents[2])
|
||||
sys.path.append(cur_dir)
|
||||
from modules.sdsvkvu import load_engine, process_img
|
||||
from modules.ocr_engine import OcrEngine
|
||||
from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg
|
||||
|
||||
|
||||
def load_ocr_engine(opt) -> OcrEngine:
|
||||
print("[INFO] Loading engine...")
|
||||
engine = OcrEngine(**opt)
|
||||
print("[INFO] Engine loaded")
|
||||
return engine
|
||||
|
||||
|
||||
print("OCR engine configfs: \n", ocr_cfg)
|
||||
print("KVU configfs: \n", kvu_cfg)
|
||||
|
||||
ocr_engine = load_ocr_engine(ocr_cfg)
|
||||
kvu_cfg['ocr_engine'] = ocr_engine
|
||||
option = kvu_cfg['option']
|
||||
kvu_cfg.pop("option") # pop option
|
||||
sbt_engine = load_engine(kvu_cfg)
|
||||
|
||||
|
||||
def sbt_predict(image_url, engine) -> None:
|
||||
req = urllib.request.urlopen(image_url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, -1)
|
||||
|
||||
save_dir = "./tmp_results"
|
||||
# image_path = os.path.join(save_dir, f"{image_url}.jpg")
|
||||
image_path = os.path.join(save_dir, "abc.jpg")
|
||||
cv2.imwrite(image_path, img)
|
||||
|
||||
outputs = process_img(img_path=image_path,
|
||||
save_dir=save_dir,
|
||||
engine=engine,
|
||||
export_all=False,
|
||||
option=option)
|
||||
return outputs
|
||||
|
||||
def predict(page_numb, image_url):
|
||||
"""
|
||||
module predict function
|
||||
|
||||
Args:
|
||||
image_url (str): image url
|
||||
|
||||
Returns:
|
||||
example output:
|
||||
"data": {
|
||||
"document_type": "invoice",
|
||||
"fields": [
|
||||
{
|
||||
"label": "Invoice Number",
|
||||
"value": "INV-12345",
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": 0.98
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
dict: output of model
|
||||
"""
|
||||
|
||||
sbt_result = sbt_predict(image_url, engine=sbt_engine)
|
||||
output_dict = {
|
||||
"document_type": "invoice",
|
||||
"document_class": " ",
|
||||
"page_number": page_numb,
|
||||
"fields": []
|
||||
}
|
||||
for key in sbt_result.keys():
|
||||
field = {
|
||||
"label": key,
|
||||
"value": sbt_result[key],
|
||||
"box": [0, 0, 0, 0],
|
||||
"confidence": random.uniform(0.9, 1.0),
|
||||
"page": page_numb
|
||||
}
|
||||
output_dict['fields'].append(field)
|
||||
return output_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg"
|
||||
output = predict(0, image_url)
|
||||
print(output)
|
0
cope2n-ai-fi/celery_worker/__init__.py
Executable file
0
cope2n-ai-fi/celery_worker/__init__.py
Executable file
81
cope2n-ai-fi/celery_worker/client_connector.py
Executable file
81
cope2n-ai-fi/celery_worker/client_connector.py
Executable file
@ -0,0 +1,81 @@
|
||||
from celery import Celery
|
||||
import base64
|
||||
import environ
|
||||
env = environ.Env(
|
||||
DEBUG=(bool, True)
|
||||
)
|
||||
|
||||
class CeleryConnector:
|
||||
task_routes = {
|
||||
"process_id_result": {"queue": "id_card_rs"},
|
||||
"process_driver_license_result": {"queue": "driver_license_rs"},
|
||||
"process_invoice_result": {"queue": "invoice_rs"},
|
||||
"process_ocr_with_box_result": {"queue": "ocr_with_box_rs"},
|
||||
"process_template_matching_result": {"queue": "template_matching_rs"},
|
||||
# mock task
|
||||
"process_id": {"queue": "id_card"},
|
||||
"process_driver_license": {"queue": "driver_license"},
|
||||
"process_invoice": {"queue": "invoice"},
|
||||
"process_ocr_with_box": {"queue": "ocr_with_box"},
|
||||
"process_template_matching": {"queue": "template_matching"},
|
||||
}
|
||||
app = Celery(
|
||||
"postman",
|
||||
broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
|
||||
# backend="rpc://",
|
||||
)
|
||||
|
||||
def process_id_result(self, args):
|
||||
return self.send_task("process_id_result", args)
|
||||
|
||||
def process_driver_license_result(self, args):
|
||||
return self.send_task("process_driver_license_result", args)
|
||||
|
||||
def process_invoice_result(self, args):
|
||||
return self.send_task("process_invoice_result", args)
|
||||
|
||||
def process_ocr_with_box_result(self, args):
|
||||
return self.send_task("process_ocr_with_box_result", args)
|
||||
|
||||
def process_template_matching_result(self, args):
|
||||
return self.send_task("process_template_matching_result", args)
|
||||
|
||||
def process_id(self, args):
|
||||
return self.send_task("process_id", args)
|
||||
|
||||
def process_driver_license(self, args):
|
||||
return self.send_task("process_driver_license", args)
|
||||
|
||||
def process_invoice(self, args):
|
||||
return self.send_task("process_invoice", args)
|
||||
|
||||
def process_ocr_with_box(self, args):
|
||||
return self.send_task("process_ocr_with_box", args)
|
||||
|
||||
def process_template_matching(self, args):
|
||||
return self.send_task("process_template_matching", args)
|
||||
|
||||
def send_task(self, name=None, args=None):
|
||||
if name not in self.task_routes or "queue" not in self.task_routes[name]:
|
||||
return self.app.send_task(name, args)
|
||||
|
||||
return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
|
||||
|
||||
|
||||
def main():
|
||||
rq_id = 345
|
||||
file_names = "abc.jpg"
|
||||
list_data = []
|
||||
|
||||
with open("/home/sds/thucpd/aicr-2022/abc.jpg", "rb") as fs:
|
||||
encoded_string = base64.b64encode(fs.read()).decode("utf-8")
|
||||
list_data.append(encoded_string)
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
a = c_connector.process_id(args=(rq_id, list_data, file_names))
|
||||
|
||||
print(a)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
56
cope2n-ai-fi/celery_worker/client_connector_fi.py
Executable file
56
cope2n-ai-fi/celery_worker/client_connector_fi.py
Executable file
@ -0,0 +1,56 @@
|
||||
from celery import Celery
|
||||
import environ
|
||||
env = environ.Env(
|
||||
DEBUG=(bool, True)
|
||||
)
|
||||
|
||||
class CeleryConnector:
|
||||
task_routes = {
|
||||
'process_fi_invoice_result': {'queue': 'invoice_fi_rs'},
|
||||
'process_sap_invoice_result': {'queue': 'invoice_sap_rs'},
|
||||
'process_manulife_invoice_result': {'queue': 'invoice_manulife_rs'},
|
||||
'process_sbt_invoice_result': {'queue': 'invoice_sbt_rs'},
|
||||
# mock task
|
||||
'process_fi_invoice': {'queue': "invoice_fi"},
|
||||
'process_sap_invoice': {'queue': "invoice_sap"},
|
||||
'process_manulife_invoice': {'queue': "invoice_manulife"},
|
||||
'process_sbt_invoice': {'queue': "invoice_sbt"},
|
||||
}
|
||||
app = Celery(
|
||||
"postman",
|
||||
broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
|
||||
)
|
||||
|
||||
# mock task for FI
|
||||
def process_fi_invoice_result(self, args):
|
||||
return self.send_task("process_fi_invoice_result", args)
|
||||
|
||||
def process_fi_invoice(self, args):
|
||||
return self.send_task("process_fi_invoice", args)
|
||||
|
||||
# mock task for SAP
|
||||
def process_sap_invoice_result(self, args):
|
||||
return self.send_task("process_sap_invoice_result", args)
|
||||
|
||||
def process_sap_invoice(self, args):
|
||||
return self.send_task("process_sap_invoice", args)
|
||||
|
||||
# mock task for manulife
|
||||
def process_manulife_invoice_result(self, args):
|
||||
return self.send_task("process_manulife_invoice_result", args)
|
||||
|
||||
def process_manulife_invoice(self, args):
|
||||
return self.send_task("process_manulife_invoice", args)
|
||||
|
||||
# mock task for manulife
|
||||
def process_sbt_invoice_result(self, args):
|
||||
return self.send_task("process_sbt_invoice_result", args)
|
||||
|
||||
def process_sbt_invoice(self, args):
|
||||
return self.send_task("process_sbt_invoice", args)
|
||||
|
||||
def send_task(self, name=None, args=None):
|
||||
if name not in self.task_routes or "queue" not in self.task_routes[name]:
|
||||
return self.app.send_task(name, args)
|
||||
|
||||
return self.app.send_task(name, args, queue=self.task_routes[name]["queue"])
|
220
cope2n-ai-fi/celery_worker/mock_process_tasks.py
Executable file
220
cope2n-ai-fi/celery_worker/mock_process_tasks.py
Executable file
@ -0,0 +1,220 @@
|
||||
from celery_worker.worker import app
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
@app.task(name="process_id")
|
||||
def process_id(rq_id, sub_id, folder_name, list_url, user_id):
|
||||
from common.serve_model import predict
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="id_card")
|
||||
print(result)
|
||||
result = {
|
||||
"status": 200,
|
||||
"content": result,
|
||||
"message": "Success",
|
||||
}
|
||||
c_connector.process_id_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
||||
# if image_croped is not None:
|
||||
# if result["data"] == []:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_id_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
# else:
|
||||
# result = {
|
||||
# "status": 200,
|
||||
# "content": result,
|
||||
# "message": "Success",
|
||||
# }
|
||||
# c_connector.process_id_result((rq_id, result))
|
||||
# return {"rq_id": rq_id}
|
||||
# elif image_croped is None:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_id_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {
|
||||
"status": 404,
|
||||
"content": {},
|
||||
}
|
||||
c_connector.process_id_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_driver_license")
|
||||
def process_driver_license(rq_id, sub_id, folder_name, list_url, user_id):
|
||||
from common.serve_model import predict
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="driving_license")
|
||||
result = {
|
||||
"status": 200,
|
||||
"content": result,
|
||||
"message": "Success",
|
||||
}
|
||||
c_connector.process_driver_license_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
||||
# result, image_croped = predict(str(url), "driving_license")
|
||||
# if image_croped is not None:
|
||||
# if result["data"] == []:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_driver_license_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
# else:
|
||||
# result = {
|
||||
# "status": 200,
|
||||
# "content": result,
|
||||
# "message": "Success",
|
||||
# }
|
||||
# path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
|
||||
# cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_croped)
|
||||
# c_connector.process_driver_license_result((rq_id, result, path_image_croped))
|
||||
# return {"rq_id": rq_id}
|
||||
# elif image_croped is None:
|
||||
# result = {
|
||||
# "status": 404,
|
||||
# "content": {},
|
||||
# }
|
||||
# c_connector.process_driver_license_result((rq_id, result, None))
|
||||
# return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {
|
||||
"status": 404,
|
||||
"content": {},
|
||||
}
|
||||
c_connector.process_driver_license_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_template_matching")
|
||||
def process_template_matching(rq_id, sub_id, folder_name, url, tmp_json, user_id):
|
||||
from TemplateMatching.src.ocr_master import Extractor
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
import urllib
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
extractor = Extractor()
|
||||
try:
|
||||
req = urllib.request.urlopen(url)
|
||||
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, -1)
|
||||
imgs = [img]
|
||||
image_aliged = extractor.image_alige(imgs, tmp_json)
|
||||
if image_aliged is None:
|
||||
result = {
|
||||
"status": 401,
|
||||
"content": "Image is not match with template",
|
||||
}
|
||||
c_connector.process_template_matching_result(
|
||||
(rq_id, result, None)
|
||||
)
|
||||
return {"rq_id": rq_id}
|
||||
else:
|
||||
output = extractor.extract_information(
|
||||
image_aliged, tmp_json
|
||||
)
|
||||
path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id)
|
||||
cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_aliged)
|
||||
if output == {}:
|
||||
result = {"status": 404, "content": {}}
|
||||
c_connector.process_template_matching_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
else:
|
||||
result = {
|
||||
"document_type": "template_matching",
|
||||
"fields": []
|
||||
}
|
||||
print(output)
|
||||
for field in tmp_json["fields"]:
|
||||
print(field["label"])
|
||||
field_value = {
|
||||
"label": field["label"],
|
||||
"value": output[field["label"]],
|
||||
"box": [float(num) for num in field["box"]],
|
||||
"confidence": 0.98 #TODO confidence
|
||||
}
|
||||
result["fields"].append(field_value)
|
||||
|
||||
print(result)
|
||||
result = {"status": 200, "content": result}
|
||||
c_connector.process_template_matching_result(
|
||||
(rq_id, result, path_image_croped)
|
||||
)
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {"status": 404, "content": {}}
|
||||
c_connector.process_template_matching_result((rq_id, result, None))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
# @app.task(name="process_invoice")
|
||||
# def process_invoice(rq_id, url):
|
||||
# from celery_worker.client_connector import CeleryConnector
|
||||
# from Kie_Hoanglv.prediction import predict
|
||||
|
||||
# c_connector = CeleryConnector()
|
||||
# try:
|
||||
# print(url)
|
||||
# result = predict(str(url))
|
||||
# hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
# c_connector.process_invoice_result((rq_id, hoadon))
|
||||
# return {"rq_id": rq_id}
|
||||
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# hoadon = {"status": 404, "content": {}}
|
||||
# c_connector.process_invoice_result((rq_id, hoadon))
|
||||
# return {"rq_id": rq_id}
|
||||
|
||||
@app.task(name="process_invoice")
|
||||
def process_invoice(rq_id, list_url):
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
from common.process_pdf import compile_output
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output(list_url)
|
||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
c_connector.process_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
hoadon = {"status": 404, "content": {}}
|
||||
c_connector.process_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_ocr_with_box")
|
||||
def process_ocr_with_box(rq_id, list_url):
|
||||
from celery_worker.client_connector import CeleryConnector
|
||||
from common.process_pdf import compile_output_ocr_base
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output_ocr_base(list_url)
|
||||
result = {"status": 200, "content": result, "message": "Success"}
|
||||
c_connector.process_ocr_with_box_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
result = {"status": 404, "content": {}}
|
||||
c_connector.process_ocr_with_box_result((rq_id, result))
|
||||
return {"rq_id": rq_id}
|
74
cope2n-ai-fi/celery_worker/mock_process_tasks_fi.py
Executable file
74
cope2n-ai-fi/celery_worker/mock_process_tasks_fi.py
Executable file
@ -0,0 +1,74 @@
|
||||
from celery_worker.worker_fi import app
|
||||
|
||||
@app.task(name="process_fi_invoice")
|
||||
def process_invoice(rq_id, list_url):
|
||||
from celery_worker.client_connector_fi import CeleryConnector
|
||||
from common.process_pdf import compile_output_fi
|
||||
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output_fi(list_url)
|
||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
print(hoadon)
|
||||
c_connector.process_fi_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
hoadon = {"status": 404, "content": {}}
|
||||
c_connector.process_fi_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
|
||||
@app.task(name="process_sap_invoice")
|
||||
def process_sap_invoice(rq_id, list_url):
|
||||
from celery_worker.client_connector_fi import CeleryConnector
|
||||
from common.process_pdf import compile_output
|
||||
|
||||
print(list_url)
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output(list_url)
|
||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
c_connector.process_sap_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
hoadon = {"status": 404, "content": {}}
|
||||
c_connector.process_sap_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
@app.task(name="process_manulife_invoice")
|
||||
def process_manulife_invoice(rq_id, list_url):
|
||||
from celery_worker.client_connector_fi import CeleryConnector
|
||||
from common.process_pdf import compile_output_manulife
|
||||
# TODO: simply returning 200 and 404 doesn't make any sense
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output_manulife(list_url)
|
||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
print(hoadon)
|
||||
c_connector.process_manulife_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
hoadon = {"status": 404, "content": {}}
|
||||
c_connector.process_manulife_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
|
||||
@app.task(name="process_sbt_invoice")
|
||||
def process_sbt_invoice(rq_id, list_url):
|
||||
from celery_worker.client_connector_fi import CeleryConnector
|
||||
from common.process_pdf import compile_output_sbt
|
||||
# TODO: simply returning 200 and 404 doesn't make any sense
|
||||
c_connector = CeleryConnector()
|
||||
try:
|
||||
result = compile_output_sbt(list_url)
|
||||
hoadon = {"status": 200, "content": result, "message": "Success"}
|
||||
print(hoadon)
|
||||
c_connector.process_sbt_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
hoadon = {"status": 404, "content": {}}
|
||||
c_connector.process_sbt_invoice_result((rq_id, hoadon))
|
||||
return {"rq_id": rq_id}
|
40
cope2n-ai-fi/celery_worker/worker.py
Executable file
40
cope2n-ai-fi/celery_worker/worker.py
Executable file
@ -0,0 +1,40 @@
|
||||
from celery import Celery
from kombu import Queue, Exchange
import environ

# django-environ: reads CELERY_BROKER (and friends) from the environment.
env = environ.Env(
    DEBUG=(bool, True)
)

app: Celery = Celery(
    "postman",
    broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
    # backend="rpc://",
    include=[
        "celery_worker.mock_process_tasks",
    ],
)

# Direct exchange shared by the task queues.
task_exchange = Exchange("default", type="direct")
# Kept for backward compatibility only; as a bare module variable Celery
# never read it. The effective switch is in app.conf.update below.
task_create_missing_queues = False

app.conf.update(
    {
        "result_expires": 3600,
        # Bug fix: this setting must live in app.conf to take effect.
        "task_create_missing_queues": False,
        "task_queues": [
            Queue("id_card"),
            Queue("driver_license"),
            Queue("invoice"),
            Queue("ocr_with_box"),
            Queue("template_matching"),
        ],
        "task_routes": {
            "process_id": {"queue": "id_card"},
            "process_driver_license": {"queue": "driver_license"},
            "process_invoice": {"queue": "invoice"},
            "process_ocr_with_box": {"queue": "ocr_with_box"},
            "process_template_matching": {"queue": "template_matching"},
        },
    }
)

if __name__ == "__main__":
    # Solo pool: single in-process worker, also works on Windows.
    argv = ["celery_worker.worker", "--loglevel=INFO", "--pool=solo"]
    app.worker_main(argv)
|
37
cope2n-ai-fi/celery_worker/worker_fi.py
Executable file
37
cope2n-ai-fi/celery_worker/worker_fi.py
Executable file
@ -0,0 +1,37 @@
|
||||
from celery import Celery
from kombu import Queue, Exchange
import environ

# django-environ: reads CELERY_BROKER (and friends) from the environment.
env = environ.Env(
    DEBUG=(bool, True)
)

app: Celery = Celery(
    "postman",
    broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"),
    include=[
        "celery_worker.mock_process_tasks_fi",
    ],
)

# Direct exchange shared by the task queues.
task_exchange = Exchange("default", type="direct")
# Kept for backward compatibility only; as a bare module variable Celery
# never read it. The effective switch is in app.conf.update below.
task_create_missing_queues = False

app.conf.update(
    {
        "result_expires": 3600,
        # Bug fix: this setting must live in app.conf to take effect.
        "task_create_missing_queues": False,
        "task_queues": [
            Queue("invoice_fi"),
            Queue("invoice_sap"),
            Queue("invoice_manulife"),
            Queue("invoice_sbt"),
        ],
        "task_routes": {
            'process_fi_invoice': {'queue': "invoice_fi"},
            'process_fi_invoice_result': {'queue': 'invoice_fi_rs'},
            'process_sap_invoice': {'queue': "invoice_sap"},
            'process_sap_invoice_result': {'queue': 'invoice_sap_rs'},
            'process_manulife_invoice': {'queue': 'invoice_manulife'},
            'process_manulife_invoice_result': {'queue': 'invoice_manulife_rs'},
            'process_sbt_invoice': {'queue': 'invoice_sbt'},
            'process_sbt_invoice_result': {'queue': 'invoice_sbt_rs'},
        },
    }
)
|
101
cope2n-ai-fi/common/AnyKey_Value/anyKeyValue.py
Executable file
101
cope2n-ai-fi/common/AnyKey_Value/anyKeyValue.py
Executable file
@ -0,0 +1,101 @@
|
||||
import os
|
||||
import glob
|
||||
import cv2
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
from datetime import datetime
|
||||
# from omegaconf import OmegaConf
|
||||
import sys
|
||||
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/common/AnyKey_Value') # TODO: ????
|
||||
from predictor import KVUPredictor
|
||||
from preprocess import KVUProcess, DocumentKVUProcess
|
||||
from utils.utils import create_dir, visualize, get_colormap, export_kvu_for_VAT_invoice, export_kvu_outputs
|
||||
|
||||
|
||||
def get_args():
    """Build the CLI parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Main file')
    parser.add_argument('--img_dir', type=str,
                        default='/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/',
                        help='Input image directory')
    parser.add_argument('--save_dir', type=str,
                        default='/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/',
                        help='Save directory')
    parser.add_argument('--exp_dir', type=str,
                        default='/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/experiments/key_value_understanding-20230608-171900',
                        help='Checkpoint and config of model')
    parser.add_argument('--export_img', type=int, default=0,
                        help='Save visualize on image')
    parser.add_argument('--mode', type=int, default=3,
                        help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'")
    parser.add_argument('--dir_level', type=int, default=0,
                        help='Number of subfolders contains image')
    return parser.parse_args()
|
||||
|
||||
|
||||
def load_engine(exp_dir: str, class_names: list, mode: int) -> 'tuple[KVUPredictor, DocumentKVUProcess]':
    """Build the KVU predictor and its matching document pre-processor.

    Args:
        exp_dir: Experiment directory holding one ``*.yaml`` config and
            ``checkpoints/best_model.pth``.
        class_names: Token class names the model was trained with.
        mode: Inference mode (0 normal, 1 full_tokens, 2 sliding, 3 document).

    Returns:
        ``(predictor, processor)`` ready for inference.
        (The old ``-> KVUPredictor`` annotation was wrong: two values
        are returned.)
    """
    configs = {
        # Assumes exactly one yaml in exp_dir; IndexError if none is found.
        'cfg': glob.glob(f'{exp_dir}/*.yaml')[0],
        'ckpt': f'{exp_dir}/checkpoints/best_model.pth'
    }
    dummy_idx = 512  # max sequence length, used as the "no link" sentinel
    predictor = KVUPredictor(configs, class_names, dummy_idx, mode)

    # processor = KVUProcess(predictor.net.tokenizer_layoutxlm,
    #                     predictor.net.feature_extractor, predictor.backbone_type, class_names,
    #                     predictor.slice_interval, predictor.window_size, run_ocr=1, mode=mode)

    processor = DocumentKVUProcess(predictor.net.tokenizer, predictor.net.feature_extractor,
                                predictor.backbone_type, class_names,
                                predictor.max_window_count, predictor.slice_interval, predictor.window_size,
                                run_ocr=1, mode=mode)
    return predictor, processor
|
||||
|
||||
|
||||
def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor):
    """Run KVU inference on one image and export the VAT-invoice fields.

    Args:
        img_path: Path of the image to process.
        save_dir: Directory where the ``.json`` export is written.
        predictor: Loaded ``KVUPredictor``.
        processor: Pre-processor returned by ``load_engine``.

    Returns:
        The exported VAT-invoice outputs, or ``None`` when the prediction
        produced several sliding windows (nothing is exported then).
    """
    fname = os.path.basename(img_path)
    img_ext = os.path.splitext(img_path)[1]
    output_ext = ".json"
    inputs = processor(img_path, ocr_path='')

    bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs)

    # More than one window means the sliding-window path was used.
    slide_window = len(bbox) != 1

    # Bug fix: vat_outputs could be referenced without ever being assigned
    # (sliding-window case), raising UnboundLocalError on return.
    vat_outputs = None
    if len(bbox) == 0:
        vat_outputs = export_kvu_for_VAT_invoice(
            os.path.join(save_dir, fname.replace(img_ext, output_ext)),
            lwords, pr_class_words, pr_relations, predictor.class_names)
    else:
        for i in range(len(bbox)):
            if not slide_window:
                save_path = os.path.join(save_dir, 'kvu_results')
                create_dir(save_path)
                # export_kvu_for_SDSAP(...) was the previous exporter.
                vat_outputs = export_kvu_for_VAT_invoice(
                    os.path.join(save_dir, fname.replace(img_ext, output_ext)),
                    lwords[i], pr_class_words[i], pr_relations[i],
                    predictor.class_names)

    return vat_outputs
|
||||
|
||||
|
||||
def Predictor_KVU(img, save_dir: str, predictor: KVUPredictor, processor):
    """Persist *img* to a temp file and run ``predict_image`` on it.

    Args:
        img: Decoded image array (it is written via ``cv2.imwrite``) —
            the previous ``str`` annotation was incorrect.
        save_dir: Directory passed through to ``predict_image``.
        predictor: Loaded KVU predictor.
        processor: Matching pre-processor.

    Returns:
        Whatever ``predict_image`` returns — the previous ``-> None``
        annotation was incorrect.
    """
    # req = urllib.request.urlopen(image_url)
    # arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
    # img = cv2.imdecode(arr, -1)
    curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S')
    # NOTE(review): hard-coded machine-specific temp directory and the file
    # is never cleaned up — confirm whether tempfile should be used instead.
    image_path = "/home/thucpd/thucpd/PV2-2023/tmp_image/{}.jpg".format(curr_datetime)
    cv2.imwrite(image_path, img)
    vat_outputs = predict_image(image_path, save_dir, predictor, processor)
    return vat_outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
class_names = ['others', 'title', 'key', 'value', 'header']
|
||||
predict_mode = {
|
||||
'normal': 0,
|
||||
'full_tokens': 1,
|
||||
'sliding': 2,
|
||||
'document': 3
|
||||
}
|
||||
predictor, processor = load_engine(args.exp_dir, class_names, args.mode)
|
||||
create_dir(args.save_dir)
|
||||
image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
|
||||
save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1"
|
||||
vat_outputs = predict_image(image_path, save_dir, predictor, processor)
|
||||
print('[INFO] Done')
|
0
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/__init__.py
Executable file
0
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/__init__.py
Executable file
133
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier.py
Executable file
133
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier.py
Executable file
@ -0,0 +1,133 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from overrides import overrides
|
||||
from pytorch_lightning import LightningModule
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from torch.optim import SGD, Adam, AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lightning_modules.schedulers import (
|
||||
cosine_scheduler,
|
||||
linear_scheduler,
|
||||
multistep_scheduler,
|
||||
)
|
||||
from model import get_model
|
||||
from utils import cfg_to_hparams, get_specific_pl_logger
|
||||
|
||||
|
||||
class ClassifierModule(LightningModule):
    """Generic Lightning training module.

    Builds the network from ``cfg``, selects optimizer and LR scheduler
    from ``cfg.train.optimizer``, and formats simple shell log lines.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.net = get_model(self.cfg)
        # Label value ignored by loss computations.
        self.ignore_index = -100

        # Wall-clock anchor for _log_shell; initialised in setup().
        self.time_tracker = None

        # Supported optimizers, keyed by cfg.train.optimizer.method.
        self.optimizer_types = {
            "sgd": SGD,
            "adam": Adam,
            "adamw": AdamW,
        }

    @overrides
    def setup(self, stage):
        """Start the timer used by the shell logger."""
        self.time_tracker = time.time()

    @overrides
    def configure_optimizers(self):
        """Return the optimizer plus a per-step LR scheduler for Lightning."""
        optimizer = self._get_optimizer()
        scheduler = self._get_lr_scheduler(optimizer)
        scheduler = {
            "scheduler": scheduler,
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def _get_lr_scheduler(self, optimizer):
        """Build the scheduler named in ``cfg.train.optimizer.lr_schedule``.

        Supported methods: None (constant LR), "step", "cosine", "linear";
        raises ValueError for anything else.
        """
        cfg_train = self.cfg.train
        lr_schedule_method = cfg_train.optimizer.lr_schedule.method
        lr_schedule_params = cfg_train.optimizer.lr_schedule.params

        if lr_schedule_method is None:
            # Constant LR: identity multiplier.
            scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1)
        elif lr_schedule_method == "step":
            scheduler = multistep_scheduler(optimizer, **lr_schedule_params)
        elif lr_schedule_method == "cosine":
            # Total optimizer steps = total samples / global batch size.
            total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
            total_batch_size = cfg_train.batch_size * self.trainer.world_size
            max_iter = total_samples / total_batch_size
            scheduler = cosine_scheduler(
                optimizer, training_steps=max_iter, **lr_schedule_params
            )
        elif lr_schedule_method == "linear":
            total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
            total_batch_size = cfg_train.batch_size * self.trainer.world_size
            max_iter = total_samples / total_batch_size
            scheduler = linear_scheduler(
                optimizer, training_steps=max_iter, **lr_schedule_params
            )
        else:
            raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}")

        return scheduler

    def _get_optimizer(self):
        """Instantiate the optimizer named in ``cfg.train.optimizer.method``."""
        opt_cfg = self.cfg.train.optimizer
        method = opt_cfg.method.lower()

        if method not in self.optimizer_types:
            raise ValueError(f"Unknown optimizer method={method}")

        kwargs = dict(opt_cfg.params)
        kwargs["params"] = self.net.parameters()
        optimizer = self.optimizer_types[method](**kwargs)

        return optimizer

    @rank_zero_only
    @overrides
    def on_fit_end(self):
        """Log the config as hyper-parameters to TensorBoard (rank 0 only)."""
        hparam_dict = cfg_to_hparams(self.cfg, {})
        metric_dict = {"metric/dummy": 0}

        tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger)

        if tb_logger:
            tb_logger.log_hyperparams(hparam_dict, metric_dict)

    @overrides
    def training_epoch_end(self, training_step_outputs):
        """Accumulate step losses and pass them to the shell logger."""
        avg_loss = torch.tensor(0.0).to(self.device)
        for step_out in training_step_outputs:
            # NOTE(review): this is a running *sum*, not an average,
            # despite the variable name — confirm intent.
            avg_loss += step_out["loss"]

        log_dict = {"train_loss": avg_loss}
        self._log_shell(log_dict, prefix="train ")

    def _log_shell(self, log_info, prefix=""):
        """Format ``log_info`` into a one-line summary and reset the timer.

        The formatted line is currently not emitted (the print is commented
        out); calling this still resets ``time_tracker``.
        """
        log_info_shell = {}
        for k, v in log_info.items():
            new_v = v
            # Convert scalar tensors so round() works on plain floats.
            if type(new_v) is torch.Tensor:
                new_v = new_v.item()
            log_info_shell[k] = new_v

        out_str = prefix.upper()
        if prefix.upper().strip() in ["TRAIN", "VAL"]:
            out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]"

        if self.training:
            # NOTE(review): relies on a private Lightning attribute; may
            # break across pytorch-lightning versions.
            lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"]
            log_info_shell["lr"] = lr

        for key, value in log_info_shell.items():
            out_str += f" || {key}: {round(value, 5)}"
        out_str += f" || time: {round(time.time() - self.time_tracker, 1)}"
        out_str += " secs."
        # self.print(out_str)
        self.time_tracker = time.time()
|
390
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier_module.py
Executable file
390
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier_module.py
Executable file
@ -0,0 +1,390 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from overrides import overrides
|
||||
|
||||
from lightning_modules.classifier import ClassifierModule
|
||||
from utils import get_class_names
|
||||
|
||||
|
||||
class KVUClassifierModule(ClassifierModule):
    """Lightning module for key-value understanding (EE/EL/ELK tasks).

    ``stage`` selects the input layout: stage 1 feeds per-window batches
    (``batch['windows']``), stage 2 feeds whole documents.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        class_names = get_class_names(self.cfg.dataset_root_path)

        self.window_size = cfg.train.max_num_words
        self.slice_interval = cfg.train.slice_interval
        # Shared kwargs for do_eval_step; dummy_idx is overwritten per batch
        # in the stage-2 validation path.
        self.eval_kwargs = {
            "class_names": class_names,
            "dummy_idx": self.cfg.train.max_seq_length, # update dummy_idx in next step
        }
        self.stage = cfg.stage

    @overrides
    def training_step(self, batch, batch_idx, *args):
        """Forward pass + loss; input layout depends on ``self.stage``."""
        if self.stage == 1:
            _, loss = self.net(batch['windows'])
        elif self.stage == 2:
            _, loss = self.net(batch)
        else:
            raise ValueError(
                f"Not supported stage: {self.stage}"
            )

        log_dict_input = {"train_loss": loss}
        self.log_dict(log_dict_input, sync_dist=True)
        return loss

    @torch.no_grad()
    @overrides
    def validation_step(self, batch, batch_idx, *args):
        """Run evaluation and return per-task gt/pred/correct counts.

        NOTE(review): implicitly returns None for stages other than 1/2,
        unlike training_step which raises — confirm this asymmetry.
        """
        if self.stage == 1:
            # Accumulator for loss plus counts of the three tasks.
            step_out_total = {
                "loss": 0,
                "ee":{
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
                "el":{
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
                "el_from_key":{
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                }}
            # Evaluate window by window and sum the counters.
            for window in batch['windows']:
                head_outputs, loss = self.net(window)
                step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
                for key in step_out_total:
                    if key == 'loss':
                        step_out_total[key] += step_out[key]
                    else:
                        for subkey in step_out_total[key]:
                            step_out_total[key][subkey] += step_out[key][subkey]
            return step_out_total

        elif self.stage == 2:
            head_outputs, loss = self.net(batch)
            # self.eval_kwargs['dummy_idx'] = batch['itc_labels'].shape[1]
            # step_out = do_eval_step(batch, head_outputs, loss, self.eval_kwargs)
            # dummy_idx must match this batch's padded sequence length.
            self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1]
            step_out = do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs)
            return step_out

    @torch.no_grad()
    @overrides
    def validation_epoch_end(self, validation_step_outputs):
        """Aggregate step counts into P/R/F1 per task and log val_f1."""
        scores = do_eval_epoch_end(validation_step_outputs)
        self.print(
            f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}"
        )
        self.print(
            f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}"
        )
        self.print(
            f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}"
        )
        # val_f1 = unweighted mean of the three task F1 scores.
        self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.)
        tensorboard_logs = {'val_precision_ee': scores['ee']['precision'], 'val_recall_ee': scores['ee']['recall'], 'val_f1_ee': scores['ee']['f1'],
                            'val_precision_el': scores['el']['precision'], 'val_recall_el': scores['el']['recall'], 'val_f1_el': scores['el']['f1'],
                            'val_precision_el_from_key': scores['el_from_key']['precision'], 'val_recall_el_from_key': scores['el_from_key']['recall'], \
                            'val_f1_el_from_key': scores['el_from_key']['f1'],}
        return {'log': tensorboard_logs}
||||
|
||||
|
||||
def do_eval_step(batch, head_outputs, loss, eval_kwargs):
    """Compute per-batch gt/predicted/correct counts for EE, EL and ELK.

    Returns a dict with the loss plus one counter sub-dict per task.
    """
    class_names = eval_kwargs["class_names"]
    dummy_idx = eval_kwargs["dummy_idx"]

    # Hard predictions from each head's logits.
    itc_preds = torch.argmax(head_outputs["itc_outputs"], -1)
    stc_preds = torch.argmax(head_outputs["stc_outputs"], -1)
    el_preds = torch.argmax(head_outputs["el_outputs"], -1)
    el_key_preds = torch.argmax(head_outputs["el_outputs_from_key"], -1)

    # Entity extraction (EE).
    ee_gt, ee_pr, ee_ok = eval_ee_spade_batch(
        itc_preds,
        batch["itc_labels"],
        batch["are_box_first_tokens"],
        stc_preds,
        batch["stc_labels"],
        batch["attention_mask_layoutxlm"],
        class_names,
        dummy_idx,
    )

    # Entity linking (EL).
    el_gt, el_pr, el_ok = eval_el_spade_batch(
        el_preds,
        batch["el_labels"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    # Entity linking from key tokens (ELK).
    elk_gt, elk_pr, elk_ok = eval_el_spade_batch(
        el_key_preds,
        batch["el_labels_from_key"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    return {
        "loss": loss,
        "ee": {
            "n_batch_gt": ee_gt,
            "n_batch_pr": ee_pr,
            "n_batch_correct": ee_ok,
        },
        "el": {
            "n_batch_gt": el_gt,
            "n_batch_pr": el_pr,
            "n_batch_correct": el_ok,
        },
        "el_from_key": {
            "n_batch_gt": elk_gt,
            "n_batch_pr": elk_pr,
            "n_batch_correct": elk_ok,
        },
    }
|
||||
|
||||
|
||||
def eval_ee_spade_batch(
    pr_itc_labels,
    gt_itc_labels,
    are_box_first_tokens,
    pr_stc_labels,
    gt_stc_labels,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Sum entity-extraction counts (gt / predicted / correct) over a batch."""
    totals = [0, 0, 0]
    for idx in range(pr_itc_labels.shape[0]):
        example_counts = eval_ee_spade_example(
            pr_itc_labels[idx],
            gt_itc_labels[idx],
            are_box_first_tokens[idx],
            pr_stc_labels[idx],
            gt_stc_labels[idx],
            attention_mask[idx],
            class_names,
            dummy_idx,
        )
        totals = [total + count for total, count in zip(totals, example_counts)]

    return tuple(totals)
|
||||
|
||||
|
||||
def eval_ee_spade_example(
    pr_itc_label,
    gt_itc_label,
    box_first_token_mask,
    pr_stc_label,
    gt_stc_label,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Count gt / predicted / correctly-predicted entities for one example.

    Entities are rebuilt from initial-token and subsequent-token labels and
    compared per class by their exact token-index tuples.
    """
    gt_heads = parse_initial_words(
        gt_itc_label, box_first_token_mask, class_names
    )
    gt_class_words = parse_subsequent_words(
        gt_stc_label, attention_mask, gt_heads, dummy_idx
    )

    pr_heads = parse_initial_words(pr_itc_label, box_first_token_mask, class_names)
    pr_class_words = parse_subsequent_words(
        pr_stc_label, attention_mask, pr_heads, dummy_idx
    )

    n_gt, n_pr, n_ok = 0, 0, 0
    for gt_words, pr_words in zip(gt_class_words, pr_class_words):
        gt_set = set(gt_words)
        pr_set = set(pr_words)
        n_gt += len(gt_set)
        n_pr += len(pr_set)
        n_ok += len(gt_set & pr_set)

    return n_gt, n_pr, n_ok
|
||||
|
||||
|
||||
def parse_initial_words(itc_label, box_first_token_mask, class_names):
    """Group token indices by their initial-token class label.

    Returns one list per entry of *class_names*; label 0 is skipped, as are
    tokens that are not the first token of a box.
    """
    labels = itc_label.cpu().numpy()
    first_token_mask = box_first_token_mask.cpu().numpy()

    buckets = [[] for _ in class_names]
    for idx, (label, is_first) in enumerate(zip(labels, first_token_mask)):
        if is_first and label != 0:
            buckets[label].append(idx)

    return buckets
|
||||
|
||||
|
||||
def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    """Follow subsequent-token links to rebuild full entities per class.

    Args:
        stc_label: 1-D tensor; ``stc_label[j] = i`` links token ``j`` back
            to predecessor ``i`` (0 / ``dummy_idx`` mean "no link").
        attention_mask: 1-D mask of valid tokens.
        init_words: Per-class lists of entity head token indices
            (output of ``parse_initial_words``).
        dummy_idx: Sentinel label meaning "no link".

    Returns:
        Per-class lists of entities, each entity a tuple of token indices.
    """
    max_connections = 50  # hard cap on entity length; also guards cycles

    # Mask out padded positions before looking for links.
    valid_stc_label = stc_label * attention_mask.bool()
    valid_stc_label = valid_stc_label.cpu().numpy()
    stc_label_np = stc_label.cpu().numpy()

    valid_token_indices = np.where(
        (valid_stc_label != dummy_idx) * (valid_stc_label != 0)
    )

    # predecessor index -> successor index.
    # NOTE(review): if several tokens claim the same predecessor, later ones
    # overwrite earlier ones — confirm this is intended.
    next_token_idx_dict = {}
    for token_idx in valid_token_indices[0]:
        next_token_idx_dict[stc_label_np[token_idx]] = token_idx

    outputs = []
    for init_token_indices in init_words:
        sub_outputs = []
        for init_token_idx in init_token_indices:
            # Walk the successor chain starting from each head token.
            cur_token_indices = [init_token_idx]
            for _ in range(max_connections):
                if cur_token_indices[-1] in next_token_idx_dict:
                    if (
                        next_token_idx_dict[cur_token_indices[-1]]
                        not in init_token_indices
                    ):
                        cur_token_indices.append(
                            next_token_idx_dict[cur_token_indices[-1]]
                        )
                    else:
                        # Next token is another entity head: stop here.
                        break
                else:
                    break
            sub_outputs.append(tuple(cur_token_indices))

        outputs.append(sub_outputs)

    return outputs
|
||||
|
||||
|
||||
def eval_el_spade_batch(
    pr_el_labels,
    gt_el_labels,
    are_box_first_tokens,
    dummy_idx,
):
    """Sum entity-linking counts (gt / predicted / correct) over a batch."""
    n_gt_total = 0
    n_pr_total = 0
    n_ok_total = 0
    for idx in range(pr_el_labels.shape[0]):
        n_gt, n_pr, n_ok = eval_el_spade_example(
            pr_el_labels[idx],
            gt_el_labels[idx],
            are_box_first_tokens[idx],
            dummy_idx,
        )
        n_gt_total += n_gt
        n_pr_total += n_pr
        n_ok_total += n_ok

    return n_gt_total, n_pr_total, n_ok_total
|
||||
|
||||
|
||||
def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
    """Compare predicted vs. ground-truth link sets for one example."""
    gt_links = set(parse_relations(gt_el_label, box_first_token_mask, dummy_idx))
    pr_links = set(parse_relations(pr_el_label, box_first_token_mask, dummy_idx))

    return len(gt_links), len(pr_links), len(gt_links & pr_links)
|
||||
|
||||
|
||||
def parse_relations(el_label, box_first_token_mask, dummy_idx):
    """Extract ``(head_label, token_idx)`` link pairs from a linking row.

    A token at index ``i`` whose masked label ``l`` is neither 0 nor the
    ``dummy_idx`` sentinel encodes one link ``(l, i)``.

    Args:
        el_label: 1-D integer tensor of link labels per token.
        box_first_token_mask: 1-D mask zeroing out tokens that are not the
            first token of a box.
        dummy_idx: Sentinel label meaning "no link".

    Returns:
        set: The set of ``(label, token_idx)`` pairs.
    """
    masked_labels = (el_label * box_first_token_mask).cpu().numpy()
    labels = el_label.cpu().numpy()

    # (Removed an unused `max_token` local present in the original.)
    token_indices = np.where((masked_labels != dummy_idx) * (masked_labels != 0))[0]

    return {(labels[token_idx], token_idx) for token_idx in token_indices}
|
||||
|
||||
def do_eval_epoch_end(step_outputs):
    """Aggregate per-step counts into precision/recall/F1 per task.

    Args:
        step_outputs: list of dicts; each maps a task name ('ee', 'el',
            'el_from_key') to a dict with keys ``n_batch_gt``,
            ``n_batch_pr`` and ``n_batch_correct``.

    Returns:
        Dict mapping each task to {"precision", "recall", "f1"}; each
        metric is 0.0 when its denominator would be zero.
    """
    scores = {}

    for task in ("ee", "el", "el_from_key"):
        n_gt = sum(step[task]["n_batch_gt"] for step in step_outputs)
        n_pr = sum(step[task]["n_batch_pr"] for step in step_outputs)
        n_correct = sum(step[task]["n_batch_correct"] for step in step_outputs)

        precision = n_correct / n_pr if n_pr else 0.0
        recall = n_correct / n_gt if n_gt else 0.0
        if recall * precision == 0:
            f1 = 0.0
        else:
            f1 = 2.0 * recall * precision / (recall + precision)

        scores[task] = {"precision": precision, "recall": recall, "f1": f1}

    return scores
def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
    """Build keyword arguments for SPADE entity-extraction evaluation.

    Class names are read from the dataset root via ``get_class_names``;
    ``max_seq_length`` doubles as the dummy/padding label index.
    """
    return {
        "class_names": get_class_names(dataset_root_path),
        "dummy_idx": max_seq_length,
    }
def get_eval_kwargs_spade_rel(max_seq_length):
    """Build keyword arguments for SPADE relation (EL) evaluation.

    ``max_seq_length`` doubles as the dummy/padding label index.
    """
    return {"dummy_idx": max_seq_length}
53
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/schedulers.py
Executable file
53
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/schedulers.py
Executable file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
BROS
|
||||
Copyright 2022-present NAVER Corp.
|
||||
Apache License v2.0
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
|
||||
def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """Linear warmup followed by linear decay to zero (huggingface-style).

    The multiplier rises from 0 to 1 over ``warmup_steps``, then decays
    linearly to 0 at ``training_steps`` (clamped at 0 afterwards).
    """

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        remaining = float(training_steps - step)
        decay_span = float(max(1, training_steps - warmup_steps))
        return max(0.0, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
||||
def cosine_scheduler(
    optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1
):
    """Cosine LR schedule with linear warmup (huggingface-style).

    After warmup the multiplier follows ``0.5 * (1 + cos(pi * cycles * 2 *
    progress))`` where ``progress`` runs 0..1 over the post-warmup steps;
    clamped at 0.
    """

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, training_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
||||
def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1):
    """Step-decay LR schedule with linear warmup.

    After warmup, the multiplier is ``gamma ** k`` where ``k`` is the number
    of (sorted) ``milestones`` strictly below the current step.
    """

    def lr_lambda(step):
        # Linear warmup ratio.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Decay by `gamma` once per milestone already passed.
        n_passed = np.searchsorted(milestones, step)
        return gamma ** n_passed

    return LambdaLR(optimizer, lr_lambda, last_epoch)
161
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/utils.py
Executable file
161
cope2n-ai-fi/common/AnyKey_Value/lightning_modules/utils.py
Executable file
@ -0,0 +1,161 @@
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
|
||||
def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    """Split `elements` into overlapping windows of `window_size` items.

    Window starts advance by `slice_interval`; the final window runs to the
    end of the list (it may be shorter than `window_size`).  Each window is
    a deep copy so callers may mutate windows independently.  If the list
    already fits in one window it is returned as-is (uncopied).
    """
    if len(elements) <= window_size:
        return [elements]

    windows = []
    n_steps = math.ceil((len(elements) - window_size) / slice_interval)
    for step in range(n_steps + 1):
        start = step * slice_interval
        if start + window_size >= len(elements):
            windows.append(copy.deepcopy(elements[start:]))
        else:
            windows.append(copy.deepcopy(elements[start:start + window_size]))
    return windows
def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> list:
    """Slice a document's word list plus its EE/EL annotations into windows.

    Args:
        lwords: list of word dicts, each carrying a global ``word_id``.
        parse_class: entity-extraction annotation, class name -> list of
            word-id groups.
        parse_relation: entity-linking annotation, list of word-id pairs.
        window_size: maximum words per window.
        slice_interval: stride between window starts.

    Returns:
        ``(word_windows, parse_class_windows, parse_relation_windows)``;
        word ids inside each window are re-indexed to start at 0.  If the
        document fits in one window, the inputs are returned unchanged.
    """
    word_windows = []
    parse_class_windows = []
    parse_relation_windows = []

    if len(lwords) <= window_size:
        return [lwords], [parse_class], [parse_relation]

    max_step = math.ceil((len(lwords) - window_size) / slice_interval)
    for i in range(0, max_step + 1):
        start = i * slice_interval
        if start + window_size >= len(lwords):
            _word_window = copy.deepcopy(lwords[start:])
        else:
            _word_window = copy.deepcopy(lwords[start:start + window_size])

        # A window with fewer than two words cannot carry any relation; skip.
        if len(_word_window) < 2:
            continue

        first_word_id = _word_window[0]['word_id']
        last_word_id = _word_window[-1]['word_id']

        # Re-index word ids so each window starts at 0.
        for _word in _word_window:
            _word['word_id'] -= first_word_id

        # Entity extraction (EE) annotations restricted to this window.
        _class_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id)

        # Entity linking (EL) annotations restricted to this window.
        # BUG FIX: this previously re-ran entity_extraction_by_words on
        # parse_class, so relation windows held EE groups instead of link
        # pairs; the pairs must come from parse_relation via
        # entity_linking_by_words.
        _relation_window = entity_linking_by_words(parse_relation, first_word_id, last_word_id)

        word_windows.append(_word_window)
        parse_class_windows.append(_class_window)
        parse_relation_windows.append(_relation_window)

    return word_windows, parse_class_windows, parse_relation_windows
def entity_extraction_by_words(parse_class, first_word_id, last_word_id):
    """Restrict EE annotations to the word-id span [first_word_id, last_word_id].

    Word ids are shifted into window coordinates (minus ``first_word_id``);
    ids falling outside the span are dropped.  A group is kept (possibly
    empty) for every group in the input, preserving group positions.
    """
    span = last_word_id - first_word_id
    window = {class_name: [] for class_name in parse_class}

    for class_name, groups in parse_class.items():
        for group in groups:
            shifted = [idw - first_word_id for idw in group]
            window[class_name].append([idw for idw in shifted if 0 <= idw <= span])

    return window
def entity_linking_by_words(parse_relation, first_word_id, last_word_id):
    """Restrict EL annotations to the word-id span [first_word_id, last_word_id].

    A pair is kept only when BOTH endpoints fall inside the span; kept
    pairs are shifted into window coordinates (minus ``first_word_id``).
    """
    span = last_word_id - first_word_id
    kept = []

    for pair in parse_relation:
        shifted = [idw - first_word_id for idw in pair]
        if all(0 <= idw <= span for idw in shifted):
            kept.append(shifted)

    return kept
def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
    """Stitch per-window token embeddings back into one document sequence.

    Each entry of `lpatches` is a window's hidden states shaped
    (batch, seq, dim) with a CLS-like token at position 0 and a SEP-like
    token at the end; `lvalids[i]` is the number of valid tokens in window
    i and `loverlaps[i-1]` how many tokens window i shares with the running
    sequence.  Overlapping tokens are either averaged (`average=True`) or
    taken from the *current* window (`average=False`).

    NOTE(review): the deepcopies here look redundant — the slices are only
    read and torch.cat allocates new storage — but they are kept to avoid
    any behavior drift; the sibling merged_token_embeddings2 omits them.
    """
    # Valid tokens of the first window, skipping its leading CLS token.
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
    # CLS/SEP wrappers are taken from the first window only.
    cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
    sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]

        overlap_gap = copy.deepcopy(loverlaps[i-1])
        window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])

        if overlap_gap != 0:
            # Tokens shared by the tail of the running sequence and the
            # head of the current window.
            prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
            curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"

            if average:
                # Blend the two views of the overlapped tokens.
                avg_overlap = (
                    prev_overlap + curr_overlap
                ) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                # Keep the current window's version of the overlap.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            # No overlap: plain concatenation.
            embedding_tokens = torch.cat(
                [embedding_tokens, window], dim=1
            )
    # Re-wrap the merged sequence with the first window's CLS/SEP tokens.
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
def merged_token_embeddings2(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor:
    """Copy-free variant of merged_token_embeddings (see its docstring).

    NOTE(review): with average=False this version splices in
    ``prev_overlap`` (keeping the PREVIOUS window's overlap tokens —
    equivalent to simply dropping the current window's head), whereas
    merged_token_embeddings keeps ``curr_overlap``.  Confirm which behavior
    the trained checkpoints expect before unifying the two.
    """
    # Valid tokens of the first window, skipping its leading CLS token.
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = lpatches[0][:, start_pos:end_pos, ...]
    cls_token = lpatches[0][:, :1, ...]
    sep_token = lpatches[0][:, -1:, ...]

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]

        overlap_gap = loverlaps[i-1]
        window = lpatches[i][:, start_pos:end_pos, ...]

        if overlap_gap != 0:
            # Tokens shared by the running sequence tail and the window head.
            prev_overlap = embedding_tokens[:, -overlap_gap:, ...]
            curr_overlap = window[:, :overlap_gap, ...]
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"

            if average:
                # Blend the two views of the overlapped tokens.
                avg_overlap = (
                    prev_overlap + curr_overlap
                ) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                # Keeps the previous window's overlap (see NOTE above).
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            # No overlap: plain concatenation.
            embedding_tokens = torch.cat(
                [embedding_tokens, window], dim=1
            )
    # Re-wrap the merged sequence with the first window's CLS/SEP tokens.
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
15
cope2n-ai-fi/common/AnyKey_Value/model/__init__.py
Executable file
15
cope2n-ai-fi/common/AnyKey_Value/model/__init__.py
Executable file
@ -0,0 +1,15 @@
|
||||
|
||||
from model.combined_model import CombinedKVUModel
|
||||
from model.kvu_model import KVUModel
|
||||
from model.document_kvu_model import DocumentKVUModel
|
||||
|
||||
def get_model(cfg):
    """Build the KVU model matching the configured training stage.

    Args:
        cfg: config object with an integer ``stage`` attribute
            (1 = combined, 2 = windowed KVU, 3 = document-level KVU).

    Returns:
        The instantiated model.

    Raises:
        Exception: if ``cfg.stage`` is not 1, 2 or 3.
    """
    if cfg.stage == 1:
        model = CombinedKVUModel(cfg=cfg)
    elif cfg.stage == 2:
        model = KVUModel(cfg=cfg)
    elif cfg.stage == 3:
        model = DocumentKVUModel(cfg=cfg)
    else:
        # Fixed typo in the original message ("Trainging").
        raise Exception('[ERROR] Training stage is wrong')
    return model
76
cope2n-ai-fi/common/AnyKey_Value/model/combined_model.py
Executable file
76
cope2n-ai-fi/common/AnyKey_Value/model/combined_model.py
Executable file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
|
||||
from transformers import LayoutXLMTokenizer
|
||||
from transformers import AutoTokenizer, XLMRobertaModel
|
||||
|
||||
|
||||
from model.relation_extractor import RelationExtractor
|
||||
from model.kvu_model import KVUModel
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
class CombinedKVUModel(KVUModel):
    """Stage-1 KVU model: LayoutXLM backbone plus EE/EL heads trained jointly.

    Extends KVUModel; optionally restores backbone and head weights from
    ``cfg.model.ckpt_model_file`` and freezes parameter groups according to
    ``cfg.train.freeze`` / ``cfg.train.finetune_only``.
    """
    def __init__(self, cfg):
        # NOTE(review): super().__init__ already builds backbones/heads;
        # they are rebuilt below — confirm the double construction is
        # intentional before changing (affects RNG/init order).
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        # Restore backbone and all four heads from the checkpoint, if any.
        if os.path.exists(self.model_cfg.ckpt_model_file):
            self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')
            self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer')
            self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer')
            self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer')
            self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key')

        self.loss_func = nn.CrossEntropyLoss()

        # Selective freezing: backbone only, or everything except the heads
        # named by finetune_only ('EE' = token classification heads,
        # 'EL' = relation_layer only, 'ELK' = relation_layer_from_key only).
        if self.freeze:
            for name, param in self.named_parameters():
                if 'backbone' in name:
                    param.requires_grad = False
        if self.finetune_only == 'EE':
            for name, param in self.named_parameters():
                if 'itc_layer' not in name and 'stc_layer' not in name:
                    param.requires_grad = False
        if self.finetune_only == 'EL':
            # 'relation_layer_from_key' also contains 'relation_layer', so
            # the extra `or` clause is needed to freeze it here.
            for name, param in self.named_parameters():
                if 'relation_layer' not in name or 'relation_layer_from_key' in name:
                    param.requires_grad = False
        if self.finetune_only == 'ELK':
            for name, param in self.named_parameters():
                if 'relation_layer_from_key' not in name:
                    param.requires_grad = False


    def forward(self, batch):
        """Run backbone + all heads on one window batch.

        Returns (head_outputs, loss); loss is 0.0 unless the batch carries
        any '*labels*' keys.
        """
        image = batch["image"]
        input_ids_layoutxlm = batch["input_ids_layoutxlm"]
        bbox = batch["bbox"]
        attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]

        backbone_outputs_layoutxlm = self.backbone_layoutxlm(
            image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)

        # Keep only the first 512 (text) positions, then go seq-first for
        # the heads.
        last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]
        last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)
        head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}

        loss = 0.0
        if any(['labels' in key for key in batch.keys()]):
            loss = self._get_loss(head_outputs, batch)

        return head_outputs, loss
185
cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py
Executable file
185
cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py
Executable file
@ -0,0 +1,185 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor
|
||||
from transformers import LayoutLMv2Config, LayoutLMv2Model
|
||||
from transformers import LayoutXLMTokenizer
|
||||
from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel
|
||||
|
||||
|
||||
from model.relation_extractor import RelationExtractor
|
||||
from model.kvu_model import KVUModel
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
class DocumentKVUModel(KVUModel):
    """Stage-3 KVU model: per-window heads plus document-level heads.

    Runs the backbone over each sliding window, applies window-level EE/EL
    heads, then concatenates window representations and applies a second
    set of heads over the whole document.
    """
    def __init__(self, cfg):
        super().__init__(cfg)

        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.train_cfg = cfg.train

        self._get_backbones(self.model_cfg.backbone)
        # if 'pth' in self.model_cfg.ckpt_model_file:
        #     self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone)

        self._create_head()

        self.loss_func = nn.CrossEntropyLoss()

    def _create_head(self):
        """Build window-level and document-level head modules."""
        self.backbone_hidden_size = self.backbone_config.hidden_size
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # +1 class for the "no entity" label.
        self.n_classes = self.model_cfg.n_classes + 1
        self.repr_hiddent_size = self.backbone_hidden_size

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (4) Linking token classification
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # Classfication Layer for whole document
        # (1) Initial token classification
        self.itc_layer_document = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.repr_hiddent_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (3) Linking token classification
        self.relation_layer_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )
        # (4) Linking token classification
        self.relation_layer_from_key_document = RelationExtractor(
            n_relations=1,
            backbone_hidden_size=self.repr_hiddent_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)
        self.relation_layer_from_key.apply(self._init_weight)

        self.itc_layer_document.apply(self._init_weight)
        self.stc_layer_document.apply(self._init_weight)
        self.relation_layer_document.apply(self._init_weight)
        self.relation_layer_from_key_document.apply(self._init_weight)


    def _get_backbones(self, config_type):
        """Instantiate config/tokenizer/feature-extractor/backbone for the
        configured backbone family ('layoutlm', 'layoutxlm' or 'xlm-roberta')."""
        configs = {
            'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extrator': LayoutLMv2FeatureExtractor},
            'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extrator': LayoutLMv2FeatureExtractor},
            'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extrator': LayoutLMv2FeatureExtractor},
        }

        self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path)
        if config_type != 'xlm-roberta':
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path)
        else:
            # AutoTokenizer needs use_fast=False here.
            self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False)
        self.feature_extractor = configs[config_type]['feature_extrator'](apply_ocr=False)
        self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path)


    def forward(self, batches):
        """Run window-level heads per window, then document-level heads.

        Args:
            batches: dict with "windows" (list of window batches) and
                "documents" (document-level labels/inputs).

        Returns:
            (document head_outputs, accumulated loss over windows + document).
        """
        head_outputs_list = []
        loss = 0.0
        for batch in batches["windows"]:
            image = batch["image"]
            input_ids = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask = batch["attention_mask_layoutxlm"]

            if self.freeze:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            # LayoutXLM consumes the image + bboxes; text-only backbones
            # take just ids and attention mask.
            if self.model_cfg.backbone == 'layoutxlm':
                backbone_outputs = self.backbone(
                    image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask
                )
            else:
                backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask)

            # Keep the first 512 (text) positions; heads expect seq-first.
            last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :]
            last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()

            itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous()
            stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0)
            el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0)

            # Batch-first window representation, reused for the document pass.
            window_repr = last_hidden_states.transpose(0, 1).contiguous()

            head_outputs = {"window_repr": window_repr,
                            "itc_outputs": itc_outputs,
                            "stc_outputs": stc_outputs,
                            "el_outputs": el_outputs,
                            "el_outputs_from_key": el_outputs_from_key}

            if any(['labels' in key for key in batch.keys()]):
                loss += self._get_loss(head_outputs, batch)

            head_outputs_list.append(head_outputs)

        batch = batches["documents"]

        # Concatenate window representations along the sequence axis, then
        # run the document-level heads (seq-first again).
        document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1)
        document_repr = document_repr.transpose(0, 1).contiguous()

        itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0)

        head_outputs = {"itc_outputs": itc_outputs,
                        "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs,
                        "el_outputs_from_key": el_outputs_from_key}

        if any(['labels' in key for key in batch.keys()]):
            loss += self._get_loss(head_outputs, batch)

        return head_outputs, loss
248
cope2n-ai-fi/common/AnyKey_Value/model/kvu_model.py
Executable file
248
cope2n-ai-fi/common/AnyKey_Value/model/kvu_model.py
Executable file
@ -0,0 +1,248 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor
|
||||
from transformers import LayoutXLMTokenizer
|
||||
from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2
|
||||
|
||||
|
||||
from model.relation_extractor import RelationExtractor
|
||||
from utils import load_checkpoint
|
||||
|
||||
|
||||
class KVUModel(nn.Module):
    """Stage-2 KVU model: LayoutXLM over sliding windows, merged embeddings,
    then EE/EL heads over the whole merged document sequence.
    """
    def __init__(self, cfg):
        super().__init__()

        self.device = 'cuda'
        self.model_cfg = cfg.model
        self.freeze = cfg.train.freeze
        self.finetune_only = cfg.train.finetune_only

        # if cfg.stage == 2:
        #     self.freeze = True

        self._get_backbones(self.model_cfg.backbone)
        self._create_head()

        # At stage 2, restore only the backbone from the checkpoint.
        if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)):
            self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm')

        # NOTE(review): _create_head() was already called above; this second
        # call re-initializes all head weights.  Harmless for fresh heads
        # but consumes RNG state — confirm it is intentional.
        self._create_head()
        self.loss_func = nn.CrossEntropyLoss()

        if self.freeze:
            for name, param in self.named_parameters():
                if 'backbone' in name:
                    param.requires_grad = False

    def _create_head(self):
        """Build the four head modules (ITC, STC, EL, EL-from-key)."""
        self.backbone_hidden_size = 768
        self.head_hidden_size = self.model_cfg.head_hidden_size
        self.head_p_dropout = self.model_cfg.head_p_dropout
        # +1 class for the "no entity" label.
        self.n_classes = self.model_cfg.n_classes + 1

        # (1) Initial token classification
        self.itc_layer = nn.Sequential(
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            nn.Dropout(self.head_p_dropout),
            nn.Linear(self.backbone_hidden_size, self.n_classes),
        )
        # (2) Subsequent token classification
        self.stc_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (3) Linking token classification
        self.relation_layer = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        # (4) Linking token classification
        self.relation_layer_from_key = RelationExtractor(
            n_relations=1, #1
            backbone_hidden_size=self.backbone_hidden_size,
            head_hidden_size=self.head_hidden_size,
            head_p_dropout=self.head_p_dropout,
        )

        self.itc_layer.apply(self._init_weight)
        self.stc_layer.apply(self._init_weight)
        self.relation_layer.apply(self._init_weight)


    def _get_backbones(self, config_type):
        """Load the LayoutXLM tokenizer, feature extractor and backbone.
        `config_type` is currently ignored — LayoutXLM is hard-coded."""
        self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base')
        self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')

    @staticmethod
    def _init_weight(module):
        """Normal(0, 0.02) init for Linear/LayerNorm weights, zero biases."""
        init_std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight, 1.0, init_std)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)


    # def forward(self, inputs):
    #     token_embeddings = inputs['embeddings'].transpose(0, 1).contiguous().cuda()
    #     itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
    #     stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
    #     el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
    #     el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
    #     head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
    #                     "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key}

    #     loss = self._get_loss(head_outputs, inputs)
    #     return head_outputs, loss


    # def forward_single_doccument(self, lbatches):
    def forward(self, lbatches):
        """Encode each window, merge embeddings, run the heads once.

        Args:
            lbatches: dict with "windows" (list of window batches) and
                "documents" (document-level labels).

        Returns:
            (head_outputs, loss); loss is 0.0 when no '*labels*' keys exist.
        """
        windows = lbatches['windows']
        token_embeddings_windows = []
        lvalids = []
        loverlaps = []

        for i, batch in enumerate(windows):
            # Move tensors to GPU; non-tensor entries are dropped.
            batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')}
            image = batch["image"]
            input_ids_layoutxlm = batch["input_ids_layoutxlm"]
            bbox = batch["bbox"]
            attention_mask_layoutxlm = batch["attention_mask_layoutxlm"]


            backbone_outputs_layoutxlm = self.backbone_layoutxlm(
                image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm)


            # Keep only the first 512 (text) positions of each window.
            last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :]

            lvalids.append(batch['len_valid_tokens'])
            loverlaps.append(batch['len_overlap_tokens'])
            token_embeddings_windows.append(last_hidden_states_layoutxlm)


        # Stitch windows into one document sequence (keep-previous-overlap
        # variant; the averaged variant is kept commented for reference).
        token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False)
        # token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True)


        # Heads expect seq-first input.
        token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda()
        itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous()
        stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0)
        el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0)
        head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs,
                        "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key,
                        'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()}



        loss = 0.0
        if any(['labels' in key for key in lbatches.keys()]):
            labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k not in ('img_path')}
            loss = self._get_loss(head_outputs, labels)

        return head_outputs, loss

    def _get_loss(self, head_outputs, batch):
        """Sum the four head losses (ITC + STC + EL + EL-from-key)."""
        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        itc_loss = self._get_itc_loss(itc_outputs, batch)
        stc_loss = self._get_stc_loss(stc_outputs, batch)
        el_loss = self._get_el_loss(el_outputs, batch)
        el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True)

        loss = itc_loss + stc_loss + el_loss + el_loss_from_key

        return loss

    def _get_itc_loss(self, itc_outputs, batch):
        """Cross-entropy over box-first tokens only."""
        itc_mask = batch["are_box_first_tokens"].view(-1)

        itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1)
        itc_logits = itc_logits[itc_mask]

        itc_labels = batch["itc_labels"].view(-1)
        itc_labels = itc_labels[itc_mask]

        itc_loss = self.loss_func(itc_logits, itc_labels)

        return itc_loss

    def _get_stc_loss(self, stc_outputs, batch):
        """Cross-entropy over attended tokens, with padded positions and
        self-links masked out via large negative logits."""
        inv_attention_mask = 1 - batch["attention_mask_layoutxlm"]

        bsz, max_seq_length = inv_attention_mask.shape
        device = inv_attention_mask.device

        # Extra column accounts for the dummy (no-link) target.
        invalid_token_mask = torch.cat(
            [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1
        ).bool()

        stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0)

        # Disallow a token linking to itself.
        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )
        stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool()
        stc_logits = stc_outputs.view(-1, max_seq_length + 1)
        stc_logits = stc_logits[stc_mask]

        stc_labels = batch["stc_labels"].view(-1)
        stc_labels = stc_labels[stc_mask]

        stc_loss = self.loss_func(stc_logits, stc_labels)

        return stc_loss

    def _get_el_loss(self, el_outputs, batch, from_key=False):
        """Cross-entropy for entity linking over box-first tokens; non-first
        tokens and self-links are masked with large negative logits."""
        bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape

        device = batch["attention_mask_layoutxlm"].device

        self_token_mask = (
            torch.eye(max_seq_length, max_seq_length + 1).to(device).bool()
        )

        # Mask out targets that are NOT box-first tokens (plus dummy column).
        box_first_token_mask = torch.cat(
            [
                (batch["are_box_first_tokens"] == False),
                torch.zeros([bsz, 1], dtype=torch.bool).to(device),
            ],
            axis=1,
        )
        el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0)
        el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0)

        mask = batch["are_box_first_tokens"].view(-1)

        logits = el_outputs.view(-1, max_seq_length + 1)
        logits = logits[mask]

        # from_key selects the key->value label set instead of the default.
        if from_key:
            el_labels = batch["el_labels_from_key"]
        else:
            el_labels = batch["el_labels"]
        labels = el_labels.view(-1)
        labels = labels[mask]

        loss = self.loss_func(logits, labels)
        return loss
48
cope2n-ai-fi/common/AnyKey_Value/model/relation_extractor.py
Executable file
48
cope2n-ai-fi/common/AnyKey_Value/model/relation_extractor.py
Executable file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class RelationExtractor(nn.Module):
    """Bilinear relation scorer between query and key token embeddings.

    Both inputs are projected into `n_relations` attention-style heads;
    the output is a (n_relations, batch, len_q, len_k + 1) score tensor.
    The extra key slot is a learned "dummy" node meaning "no relation".
    Inputs are sequence-first: (seq_len, batch, backbone_hidden_size).
    """

    def __init__(
        self,
        n_relations,
        backbone_hidden_size,
        head_hidden_size,
        head_p_dropout=0.1,
    ):
        super().__init__()

        self.n_relations = n_relations
        self.backbone_hidden_size = backbone_hidden_size
        self.head_hidden_size = head_hidden_size
        self.head_p_dropout = head_p_dropout

        self.drop = nn.Dropout(head_p_dropout)
        projection_size = self.n_relations * self.head_hidden_size
        self.q_net = nn.Linear(self.backbone_hidden_size, projection_size)

        self.k_net = nn.Linear(self.backbone_hidden_size, projection_size)

        # Learned embedding appended to the key sequence as the
        # "links to nothing" target.
        self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size))
        nn.init.normal_(self.dummy_node)

    def forward(self, h_q, h_k):
        """Score all (query, key) pairs per relation head."""
        q_proj = self.q_net(self.drop(h_q))

        # Broadcast the dummy node over the batch, append along the
        # sequence axis, then project.
        batch_size = h_k.size(1)
        dummy = self.dummy_node.unsqueeze(0).repeat(1, batch_size, 1)
        k_with_dummy = torch.cat([h_k, dummy], axis=0)
        k_proj = self.k_net(self.drop(k_with_dummy))

        per_head_q = q_proj.view(
            q_proj.size(0), q_proj.size(1), self.n_relations, self.head_hidden_size
        )
        per_head_k = k_proj.view(
            k_proj.size(0), k_proj.size(1), self.n_relations, self.head_hidden_size
        )

        # (len_q, B, n, d) x (len_k+1, B, n, d) -> (n, B, len_q, len_k+1)
        return torch.einsum("ibnd,jbnd->nbij", (per_head_q, per_head_k))
|
228
cope2n-ai-fi/common/AnyKey_Value/predictor.py
Executable file
228
cope2n-ai-fi/common/AnyKey_Value/predictor.py
Executable file
@ -0,0 +1,228 @@
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
# from functions import get_colormap, visualize
|
||||
import sys
|
||||
sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') #TODO: ??????
|
||||
|
||||
from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations
|
||||
from model import get_model
|
||||
from utils import load_model_weight
|
||||
|
||||
|
||||
class KVUPredictor:
    """Inference wrapper around the Key-Value Understanding network.

    ``mode`` selects how a sample is fed to the model:
      0 - single flat sample (``combined_predict``),
      1 - sliding windows concatenated into one document (``cat_predict``),
      2 - each window predicted independently,
      3 - fixed-window "document" layout (``doc_predict``).

    All predict paths return (bbox, words, class_words, relations); the
    public ``predict`` wraps them in per-page lists.
    """

    def __init__(self, configs, class_names, dummy_idx, mode=0):
        # configs: dict with paths {'cfg': <OmegaConf yaml>, 'ckpt': <weights>}.
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']

        self.class_names = class_names  # entity classes (e.g. key/value/header)
        self.dummy_idx = dummy_idx      # index of the "no link" dummy slot
        self.mode = mode

        print('[INFO] Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        print("[INFO] Loaded model")

        if mode == 3:
            # Document mode: fixed window count; the dummy index refers to
            # the concatenated sequence, hence the scaling.
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words

        # NOTE(review): CUDA is hard-coded here and in _load_model; this
        # class cannot run on CPU-only hosts without modification.
        self.device = 'cuda'

    def _load_model(self, cfg_path, ckpt_path):
        """Build the network from config, load weights, move to GPU, eval."""
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone

        print('[INFO] Checkpoint:', ckpt_path)
        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to('cuda')
        net.eval()
        return net, cfg, backbone_type

    def predict(self, input_sample):
        """Dispatch on self.mode; returns ([bbox], [words], [class_words], [relations]).

        Empty-OCR samples short-circuit to four empty lists.
        """
        if self.mode == 0:
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 1:
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        elif self.mode == 2:
            # Each window is an independent sample; results are collected
            # per window rather than merged.
            if len(input_sample['windows'][0]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in input_sample['windows']:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
                bbox.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bbox, lwords, pr_class_words, pr_relations

        elif self.mode == 3:
            if len(input_sample["documents"]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]

        else:
            raise ValueError(
                f"Not supported mode: {self.mode}"
            )

    def doc_predict(self, input_sample):
        """Mode-3 inference: batchify each window, run the net once, decode
        predictions against the concatenated 'documents' view."""
        lwords = input_sample['documents']['words']
        for idx, window in enumerate(input_sample['windows']):
            input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')}

        # input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')}
        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = input_sample['documents']

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        # argmax over the class / link-target dimension, drop batch dim.
        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        # Union of header->key/value links and key->value links.
        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        """Mode-0/2 inference on one flat (single-window) sample."""
        lwords = input_sample['words']
        input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')}

        input_sample = {k: v.to(self.device) for k, v in input_sample.items()}

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()}
        input_sample = {k: v.detach().cpu() for k, v in input_sample.items()}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0)
        attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0)
        bbox = input_sample['bbox'].squeeze(0)

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, self.dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        """Mode-1 inference: run the net over all windows, decode against
        the concatenated 'documents' tensors."""
        lwords = input_sample['documents']['words']

        inputs = []
        for window in input_sample['windows']:
            inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')})
        input_sample['windows'] = inputs

        with torch.no_grad():
            head_outputs, _ = self.net(input_sample)

        # FIX: was `if k not in ('embedding_tokens')` — parentheses without a
        # comma make that a *string*, so the test was substring containment,
        # which only worked by coincidence. A one-element tuple is intended.
        head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens',)}

        itc_outputs = head_outputs["itc_outputs"]
        stc_outputs = head_outputs["stc_outputs"]
        el_outputs = head_outputs["el_outputs"]
        el_outputs_from_key = head_outputs["el_outputs_from_key"]

        pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)
        pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)
        pr_el_label = torch.argmax(el_outputs, -1).squeeze(0)
        pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0)

        box_first_token_mask = input_sample['documents']['are_box_first_tokens']
        attention_mask = input_sample['documents']['attention_mask_layoutxlm']
        bbox = input_sample['documents']['bbox']

        # Concatenated-document dummy slot sits right after the last bbox.
        dummy_idx = input_sample['documents']['bbox'].shape[0]

        pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names)
        pr_class_words = parse_subsequent_words(
            pr_stc_label, attention_mask, pr_init_words, dummy_idx
        )

        pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)
        pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx)
        pr_relations = pr_relations_from_header | pr_relations_from_key

        return bbox, lwords, pr_class_words, pr_relations

    def get_ground_truth_label(self, ground_truth):
        """Decode gold labels in the same format as the predict paths.

        NOTE(review): reads 'attention_mask' while the predict paths read
        'attention_mask_layoutxlm' — presumably the GT loader emits the
        shorter key; verify against the ground-truth preprocessor.
        """
        # ground_truth = self.preprocessor.load_ground_truth(json_file)
        gt_itc_label = ground_truth['itc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_stc_label = ground_truth['stc_labels'].squeeze(0)  # [1, 512] => [512]
        gt_el_label = ground_truth['el_labels'].squeeze(0)

        gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0)
        lwords = ground_truth["words"]

        box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0)
        attention_mask = ground_truth['attention_mask'].squeeze(0)

        bbox = ground_truth['bbox'].squeeze(0)
        gt_first_words = parse_initial_words(
            gt_itc_label, box_first_token_mask, self.class_names
        )
        gt_class_words = parse_subsequent_words(
            gt_stc_label, attention_mask, gt_first_words, self.dummy_idx
        )

        gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx)
        gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx)
        gt_relations = gt_relations_from_header | gt_relations_from_key

        return bbox, lwords, gt_class_words, gt_relations
|
456
cope2n-ai-fi/common/AnyKey_Value/preprocess.py
Executable file
456
cope2n-ai-fi/common/AnyKey_Value/preprocess.py
Executable file
@ -0,0 +1,456 @@
|
||||
import os
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import imagesize
|
||||
import itertools
|
||||
from PIL import Image
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr
|
||||
from utils.run_ocr import load_ocr_engine, process_img
|
||||
from lightning_modules.utils import sliding_windows
|
||||
|
||||
|
||||
class KVUProcess:
    """Turns an image + OCR result into model-ready tensors for KVU.

    Tokenizes OCR words with a LayoutXLM tokenizer, builds padded
    input_ids / bbox / attention / first-token masks, and (for training
    data) converts entity and relation annotations into label arrays.

    ``mode``: 0 - one flat sample; 1 - list of sliding windows;
    2 - long 'documents' view (max_seq_length=2048) plus windows.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        self.tokenizer_layoutxlm = tokenizer_layoutxlm
        self.feature_extractor = feature_extractor

        self.max_seq_length = max_seq_length
        self.backbone_type = backbone_type
        self.class_names = class_names

        self.slice_interval = slice_interval  # window overlap, in words
        self.window_size = window_size        # words per sliding window
        self.run_ocr = run_ocr                # 1 => run OCR ourselves
        self.mode = mode

        # Special-token ids resolved once from the tokenizer.
        self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token)
        self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token)
        self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token)
        self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token)

        # class name -> class index, used when building itc labels.
        self.class_idx_dic = dict(
            [(class_name, idx) for idx, class_name in enumerate(self.class_names)]
        )
        self.ocr_engine = None
        if self.run_ocr == 1:
            self.ocr_engine = load_ocr_engine()

    def __call__(self, img_path: str, ocr_path: str) -> list:
        """Process one image; returns a sample dict shaped by self.mode."""
        # NOTE(review): with run_ocr == 0 and a missing ocr_path this still
        # calls process_img with ocr_engine=None — confirm intended.
        if (self.run_ocr == 1) or (not os.path.exists(ocr_path)):
            ocr_path = "tmp.txt"
            process_img(img_path, ocr_path, self.ocr_engine, export_img=False)

        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)
        bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval)
        word_windows = sliding_windows(lwords, self.window_size, self.slice_interval)
        assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}"

        width, height = imagesize.get(img_path)
        images = [Image.open(img_path).convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())

        if self.mode == 0:
            output = self.preprocess(lbboxes, lwords,
                                     {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                                     max_seq_length=self.max_seq_length)
        elif self.mode == 1:
            output = {}
            windows = []
            for i in range(len(bbox_windows)):
                _words = word_windows[i]
                _bboxes = bbox_windows[i]
                windows.append(
                    self.preprocess(
                        _bboxes, _words,
                        {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                        max_seq_length=self.max_seq_length)
                )

            output['windows'] = windows
        elif self.mode == 2:
            output = {}
            windows = []
            # FIX: key was misspelled 'doduments'; every consumer
            # (e.g. KVUPredictor) reads 'documents'.
            output['documents'] = self.preprocess(lbboxes, lwords,
                                                  {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                                                  max_seq_length=2048)
            for i in range(len(bbox_windows)):
                _words = word_windows[i]
                _bboxes = bbox_windows[i]
                windows.append(
                    self.preprocess(
                        _bboxes, _words,
                        {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                        max_seq_length=self.max_seq_length)
                )

            output['windows'] = windows
        else:
            raise ValueError(
                f"Not supported mode: {self.mode }"
            )
        return output

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Tokenize one word list and pack it into padded tensors.

        Returns a dict with input_ids_layoutxlm / bbox / attention mask /
        first-token mask plus bookkeeping (overlap counts, words, image).
        """
        list_word_objects = []
        for bb, text in zip(bounding_boxes, words):
            # OCR gives [x1, y1, x2, y2]; expand to a 4-point polygon.
            boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
            tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text))
            list_word_objects.append({
                "layoutxlm_tokens": tokens,
                "boundingBox": boundingBox,
                "text": text
            })

        # NOTE(review): parser_words is given self.max_seq_length, not the
        # `max_seq_length` parameter — for mode-2 'documents' (2048) the
        # arrays are still built at self.max_seq_length; confirm intended.
        (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            box_to_token_indices,
            box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens
        ) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"])

        assert len_list_tokens == len_valid_tokens + 2  # +2 for CLS/SEP
        len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens

        ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2

        input_ids = input_ids[:ntokens]
        attention_mask = attention_mask[:ntokens]
        bbox = bbox[:ntokens]
        are_box_first_tokens = are_box_first_tokens[:ntokens]

        input_ids = torch.from_numpy(input_ids)
        attention_mask = torch.from_numpy(attention_mask)
        bbox = torch.from_numpy(bbox)
        are_box_first_tokens = torch.from_numpy(are_box_first_tokens)

        len_valid_tokens = torch.tensor(len_valid_tokens)
        len_overlap_tokens = torch.tensor(len_overlap_tokens)
        return_dict = {
            "img_path": feature_maps['img_path'],
            "words": lwords,
            "len_overlap_tokens": len_overlap_tokens,
            'len_valid_tokens': len_valid_tokens,
            "image": feature_maps['image'],
            "input_ids_layoutxlm": input_ids,
            "attention_mask_layoutxlm": attention_mask,
            "are_box_first_tokens": are_box_first_tokens,
            "bbox": bbox,
        }
        return return_dict

    def parser_words(self, words, max_seq_length, width, height):
        """Flatten tokenized word objects into fixed-size numpy arrays.

        Adds CLS/SEP, clips and normalizes bboxes to the page size, and
        records which token starts each word (are_box_first_tokens).
        """
        list_bbs = []
        list_words = []
        list_tokens = []
        cls_bbs = [0.0] * 8
        box2token_span_map = []
        box_to_token_indices = []
        lwords = [''] * max_seq_length

        cum_token_idx = 0
        len_valid_tokens = 0
        len_non_overlap_tokens = 0

        input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm
        bbox = np.zeros((max_seq_length, 8), dtype=np.float32)
        attention_mask = np.zeros(max_seq_length, dtype=int)
        are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_)

        for word_idx, word in enumerate(words):
            this_box_token_indices = []

            tokens = word["layoutxlm_tokens"]
            bb = word["boundingBox"]
            text = word["text"]

            len_valid_tokens += len(tokens)
            if word_idx < self.slice_interval:
                len_non_overlap_tokens += len(tokens)

            if len(tokens) == 0:
                # FIX: was self.unk_token_id, which this class never defines
                # (AttributeError on words the tokenizer maps to nothing);
                # the subclass alias has the identical value.
                tokens.append(self.unk_token_id_layoutxlm)

            if len(list_tokens) + len(tokens) > max_seq_length - 2:
                break

            box2token_span_map.append(
                [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1]
            )  # including st_idx
            list_tokens += tokens

            # min, max clipping
            for coord_idx in range(4):
                bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width))
                bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height))

            bb = list(itertools.chain(*bb))
            bbs = [bb for _ in range(len(tokens))]
            texts = [text for _ in range(len(tokens))]

            for _ in tokens:
                cum_token_idx += 1
                this_box_token_indices.append(cum_token_idx)

            list_bbs.extend(bbs)
            list_words.extend(texts)
            box_to_token_indices.append(this_box_token_indices)

        sep_bbs = [width, height] * 4

        # For [CLS] and [SEP]
        list_tokens = (
            [self.cls_token_id_layoutxlm]
            + list_tokens[: max_seq_length - 2]
            + [self.sep_token_id_layoutxlm]
        )
        if len(list_bbs) == 0:
            # When len(json_obj["words"]) == 0 (no OCR result)
            list_bbs = [cls_bbs] + [sep_bbs]
        else:  # len(list_bbs) > 0
            list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs]
        # list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] ###
        # if len(list_words) < 510:
        #     list_words.extend(['</p>' for _ in range(510 - len(list_words))])
        list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token]

        len_list_tokens = len(list_tokens)
        input_ids[:len_list_tokens] = list_tokens
        attention_mask[:len_list_tokens] = 1

        bbox[:len_list_tokens, :] = list_bbs
        lwords[:len_list_tokens] = list_words

        # Normalize bbox -> 0 ~ 1
        bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
        bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height

        if self.backbone_type in ("layoutlm", "layoutxlm"):
            # LayoutLM family expects [x1, y1, x3, y3] in 0..1000 ints.
            bbox = bbox[:, [0, 1, 4, 5]]
            bbox = bbox * 1000
            bbox = bbox.astype(int)
        else:
            assert False

        st_indices = [
            indices[0]
            for indices in box_to_token_indices
            if indices[0] < max_seq_length
        ]
        are_box_first_tokens[st_indices] = True

        return (
            bbox,
            input_ids,
            attention_mask,
            are_box_first_tokens,
            box_to_token_indices,
            box2token_span_map,
            lwords,
            len_valid_tokens,
            len_non_overlap_tokens,
            len_list_tokens
        )

    def parser_entity_extraction(self, parse_class, box_to_token_indices, max_seq_length):
        """Build initial-token-class (itc) and subsequent-token-chain (stc)
        label arrays from per-class word-index groups."""
        itc_labels = np.zeros(max_seq_length, dtype=int)
        stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length

        classes_dic = parse_class
        for class_name in self.class_names:
            if class_name == "others":
                continue
            if class_name not in classes_dic:
                continue

            for word_list in classes_dic[class_name]:
                is_first, last_word_idx = True, -1
                for word_idx in word_list:
                    if word_idx >= len(box_to_token_indices):
                        break
                    box2token_list = box_to_token_indices[word_idx]
                    for converted_word_idx in box2token_list:
                        if converted_word_idx >= max_seq_length:
                            break  # out of idx

                        if is_first:
                            itc_labels[converted_word_idx] = self.class_idx_dic[
                                class_name
                            ]
                            is_first, last_word_idx = False, converted_word_idx
                        else:
                            # Each subsequent token points back at its predecessor.
                            stc_labels[converted_word_idx] = last_word_idx
                            last_word_idx = converted_word_idx

        return itc_labels, stc_labels

    def parser_entity_linking(self, parse_relation, itc_labels, box2token_span_map, max_seq_length):
        """Build entity-linking label arrays from (from, to) word pairs.

        itc class indices: 2 = key, 3 = value, 4 = header (per class_names
        ordering); key->value pairs go to el_labels_from_key, header links
        to el_labels.
        """
        el_labels = np.ones(max_seq_length, dtype=int) * max_seq_length
        el_labels_from_key = np.ones(max_seq_length, dtype=int) * max_seq_length

        relations = parse_relation
        for relation in relations:
            if relation[0] >= len(box2token_span_map) or relation[1] >= len(
                box2token_span_map
            ):
                continue
            if (
                box2token_span_map[relation[0]][0] >= max_seq_length
                or box2token_span_map[relation[1]][0] >= max_seq_length
            ):
                continue

            word_from = box2token_span_map[relation[0]][0]
            word_to = box2token_span_map[relation[1]][0]
            # el_labels[word_to] = word_from

            # NOTE(review): 512 is hard-coded as the "unset" sentinel; the
            # arrays are initialized with max_seq_length — these only agree
            # when max_seq_length == 512. Confirm for other lengths.
            if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512:
                continue

            if itc_labels[word_from] == 2 and itc_labels[word_to] == 3:
                el_labels_from_key[word_to] = word_from  # pair of (key-value)
            if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)):
                el_labels[word_to] = word_from  # pair of (header, key) or (header-value)
        return el_labels, el_labels_from_key
|
||||
|
||||
|
||||
class DocumentKVUProcess(KVUProcess):
    """Document-level variant of KVUProcess for the fixed-window (mode 3)
    pipeline: the page is cut into exactly ``max_window_count`` windows of
    ``window_size`` words each, and a concatenated 'documents' view is
    assembled from the per-window tensors.
    """

    def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0):
        super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode)
        self.max_window_count = max_window_count
        # Short aliases for the layoutxlm special-token ids resolved by the
        # parent constructor (parser_words' unk fallback relies on these).
        self.pad_token_id = self.pad_token_id_layoutxlm
        self.cls_token_id = self.cls_token_id_layoutxlm
        self.sep_token_id = self.sep_token_id_layoutxlm
        self.unk_token_id = self.unk_token_id_layoutxlm
        self.tokenizer = self.tokenizer_layoutxlm

    def __call__(self, img_path: str, ocr_path: str) -> list:
        """Process one image into the windows + documents sample dict.

        NOTE(review): the OCR condition here is `and` while the parent
        uses `or` — with run_ocr == 0 and a missing ocr_path this reads a
        nonexistent file; confirm which behavior is intended.
        """
        if (self.run_ocr == 1) and (not os.path.exists(ocr_path)):
            ocr_path = "tmp.txt"
            process_img(img_path, ocr_path, self.ocr_engine, export_img=False)

        lbboxes, lwords = read_ocr_result_from_txt(ocr_path)
        lwords = post_process_basic_ocr(lwords)

        width, height = imagesize.get(img_path)
        images = [Image.open(img_path).convert("RGB")]
        image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy())
        output = self.preprocess(lbboxes, lwords,
                                 {'image': image_features, 'width': width, 'height': height, 'img_path': img_path},
                                 self.max_seq_length)
        return output

    def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length):
        """Build per-window tensors plus a concatenated 'documents' view.

        Windows beyond the available words are counted as empty; their
        positions in the concatenated attention / first-token masks are
        zeroed at the end so the model ignores them.
        """
        n_words = len(words)
        output_dicts = {'windows': [], 'documents': []}
        n_empty_windows = 0

        for i in range(self.max_window_count):
            # Fresh zero/pad arrays; only used directly for empty windows.
            input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
            bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32)
            attention_mask = np.zeros(self.max_seq_length, dtype=int)
            are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)

            if n_words == 0:
                # No OCR output at all: emit an all-padding window.
                n_empty_windows += 1
                output_dicts['windows'].append({
                    "image": feature_maps['image'],
                    "input_ids_layoutxlm": torch.from_numpy(input_ids),
                    "bbox": torch.from_numpy(bbox),
                    "words": [],
                    "attention_mask_layoutxlm": torch.from_numpy(attention_mask),
                    "are_box_first_tokens": torch.from_numpy(are_box_first_tokens),
                })
                continue

            start_word_idx = i * self.window_size
            stop_word_idx = min(n_words, (i+1)*self.window_size)

            if start_word_idx >= stop_word_idx:
                # Ran out of words: reuse the previous window dict as filler
                # (its masked-out region is zeroed below).
                n_empty_windows += 1
                output_dicts['windows'].append(output_dicts['windows'][-1])
                continue

            list_word_objects = []
            for bb, text in zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx]):
                # OCR gives [x1, y1, x2, y2]; expand to a 4-point polygon.
                boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
                tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text))
                list_word_objects.append({
                    "layoutxlm_tokens": tokens,
                    "boundingBox": boundingBox,
                    "text": text
                })

            (
                bbox,
                input_ids,
                attention_mask,
                are_box_first_tokens,
                box_to_token_indices,
                box2token_span_map,
                lwords,
                len_valid_tokens,
                len_non_overlap_tokens,
                len_list_layoutxlm_tokens
            ) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"])

            input_ids = torch.from_numpy(input_ids)
            bbox = torch.from_numpy(bbox)
            attention_mask = torch.from_numpy(attention_mask)
            are_box_first_tokens = torch.from_numpy(are_box_first_tokens)

            return_dict = {
                "image": feature_maps['image'],
                "input_ids_layoutxlm": input_ids,
                "bbox": bbox,
                "words": lwords,
                "attention_mask_layoutxlm": attention_mask,
                "are_box_first_tokens": are_box_first_tokens,
            }
            output_dicts["windows"].append(return_dict)

        # Concatenate per-window tensors into the document-level view.
        attention_mask = torch.cat([o['attention_mask_layoutxlm'] for o in output_dicts["windows"]])
        are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]])
        if n_empty_windows > 0:
            # Zero the tail belonging to empty/filler windows so reused
            # window dicts do not leak tokens into the document masks.
            attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int))
            are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_))
        bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]])
        words = []
        for o in output_dicts['windows']:
            words.extend(o['words'])

        return_dict = {
            "attention_mask_layoutxlm": attention_mask,
            "bbox": bbox,
            "are_box_first_tokens": are_box_first_tokens,
            "n_empty_windows": n_empty_windows,
            "words": words
        }
        output_dicts['documents'] = return_dict

        return output_dicts
|
||||
|
||||
|
||||
|
30
cope2n-ai-fi/common/AnyKey_Value/requirements.txt
Executable file
30
cope2n-ai-fi/common/AnyKey_Value/requirements.txt
Executable file
@ -0,0 +1,30 @@
|
||||
nptyping==1.4.2
|
||||
numpy==1.20.3
|
||||
opencv-python-headless==4.5.4.60
|
||||
pytorch-lightning==1.5.6
|
||||
omegaconf
|
||||
# pillow
|
||||
six
|
||||
overrides==4.1.2
|
||||
# transformers==4.11.3
|
||||
seqeval==0.0.12
|
||||
imagesize
|
||||
pandas==2.0.1
|
||||
xmltodict
|
||||
dicttoxml
|
||||
|
||||
tensorboard>=2.2.0
|
||||
|
||||
# code-style
|
||||
isort==5.9.3
|
||||
black==21.9b0
|
||||
|
||||
# # pytorch
|
||||
# --find-links https://download.pytorch.org/whl/torch_stable.html
|
||||
# torch==1.9.1+cu102
|
||||
# torchvision==0.10.1+cu102
|
||||
|
||||
# pytorch
|
||||
# --find-links https://download.pytorch.org/whl/torch_stable.html
|
||||
# torch==1.10.0+cu113
|
||||
# torchvision==0.11.1+cu113
|
1
cope2n-ai-fi/common/AnyKey_Value/run.sh
Executable file
1
cope2n-ai-fi/common/AnyKey_Value/run.sh
Executable file
@ -0,0 +1 @@
|
||||
python anyKeyValue.py --img_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --save_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --exp_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900 --export_img 1 --mode 3 --dir_level 0
|
393
cope2n-ai-fi/common/AnyKey_Value/tmp.txt
Executable file
393
cope2n-ai-fi/common/AnyKey_Value/tmp.txt
Executable file
@ -0,0 +1,393 @@
|
||||
550 14 644 53 HÓA
|
||||
654 20 745 53 ĐƠN
|
||||
755 13 833 53 GIÁ
|
||||
842 20 916 60 TRỊ
|
||||
925 22 1001 53 GIA
|
||||
1012 14 1130 54 TĂNG
|
||||
208 63 363 87 CHUNGHO
|
||||
752 76 815 101 (VAT
|
||||
821 76 940 101 INVOICE)
|
||||
1193 80 1230 109 Ký
|
||||
1235 80 1286 109 hiệu
|
||||
1293 83 1390 108 (Serial):
|
||||
1398 81 1519 105 1C23TCV
|
||||
135 130 330 171 ChungHo
|
||||
339 131 436 165 Vina
|
||||
600 150 662 178 (BẢN
|
||||
667 152 719 177 THỂ
|
||||
678 119 743 147 Ngày
|
||||
724 152 788 177 HIÊN
|
||||
749 119 779 142 01
|
||||
786 119 854 147 tháng
|
||||
794 152 847 176 CỦA
|
||||
853 151 908 176 HÓA
|
||||
860 119 891 142 03
|
||||
897 121 949 142 năm
|
||||
915 156 973 176 ĐƠN
|
||||
956 120 1014 142 2023
|
||||
979 154 1043 179 ĐIỆN
|
||||
1049 151 1093 181 TỬ)
|
||||
1193 118 1228 148 Số
|
||||
1234 125 1303 151 (No.):
|
||||
1309 116 1372 148 829
|
||||
187 179 270 197 HEALTH
|
||||
276 179 383 197 SOLUTION
|
||||
664 187 795 210 (EINVOICE
|
||||
802 188 908 209 DISPLAY
|
||||
915 188 1032 213 VERSION)
|
||||
465 240 507 264 Mã
|
||||
512 241 554 264 của
|
||||
560 241 626 266 CQT:
|
||||
655 239 1186 266 PURIBUZANAMOIC99C9CTROINS
|
||||
90 286 141 308 Đơn
|
||||
147 285 172 312 vị
|
||||
177 285 221 308 bán
|
||||
226 285 285 313 hàng
|
||||
293 286 387 312 (Seller):
|
||||
395 279 495 307 CÔNG
|
||||
504 280 549 307 TY
|
||||
558 280 657 307 TNHH
|
||||
664 279 844 307 CHUNGHO
|
||||
852 279 937 307 VINA
|
||||
945 279 1091 309 HEALTH
|
||||
1100 280 1274 309 SOLUTION
|
||||
90 326 130 348 Mã
|
||||
135 322 165 349 số
|
||||
170 322 223 349 thuế
|
||||
229 327 281 351 (Tax
|
||||
288 327 358 351 code):
|
||||
363 324 383 347 0
|
||||
390 324 406 348 1
|
||||
411 321 605 350 08118215
|
||||
89 366 132 392 Địa
|
||||
138 366 176 389 chỉ
|
||||
182 368 298 393 (Address):
|
||||
303 362 334 389 Số
|
||||
340 367 373 392 11,
|
||||
378 368 401 389 lô
|
||||
407 368 477 392 N04A,
|
||||
484 367 526 389 khu
|
||||
530 367 560 389 đô
|
||||
564 367 595 392 thị
|
||||
599 367 643 390 mới
|
||||
648 367 702 392 Dịch
|
||||
707 368 773 393 Vọng,
|
||||
779 367 863 394 phường
|
||||
867 367 922 392 Dịch
|
||||
927 368 993 394 Vọng,
|
||||
999 368 1053 394 quận
|
||||
1057 363 1103 390 Cầu
|
||||
1108 364 1167 394 Giấy,
|
||||
1173 368 1233 390 thành
|
||||
1238 364 1281 394 phố
|
||||
1285 367 1318 390 Hà
|
||||
1323 368 1371 393 Nội,
|
||||
1377 367 1424 393 Việt
|
||||
1430 368 1483 390 Nam
|
||||
89 406 146 431 Điện
|
||||
152 407 211 432 thoại
|
||||
218 407 279 431 (Tel):
|
||||
286 407 331 428 024
|
||||
339 406 397 428 7300
|
||||
404 406 460 428 0891
|
||||
826 406 878 430 Fax:
|
||||
89 440 121 467 Số
|
||||
127 444 158 467 tài
|
||||
165 445 236 466 khoản
|
||||
245 447 348 470 (Account
|
||||
355 446 413 472 No.):
|
||||
421 445 616 469 700-010-446490
|
||||
622 445 653 471 tại
|
||||
659 445 722 471 Ngân
|
||||
728 445 792 472 Hàng
|
||||
799 444 893 468 Shinhan
|
||||
898 449 911 467 -
|
||||
916 444 960 468 Chi
|
||||
965 445 1036 468 nhánh
|
||||
1043 445 1078 469 Hà
|
||||
1084 444 1129 472 Nội
|
||||
89 486 126 513 Họ
|
||||
132 486 169 510 tên
|
||||
174 486 244 513 người
|
||||
251 493 302 509 mua
|
||||
308 486 367 514 hàng
|
||||
374 488 469 514 (Buyer):
|
||||
90 529 136 550 Tên
|
||||
142 529 188 551 đơn
|
||||
193 529 219 555 vị
|
||||
226 530 342 556 (Company
|
||||
349 531 428 555 name):
|
||||
435 524 518 552 CÔNG
|
||||
525 529 562 552 TY
|
||||
570 530 650 551 TNHH
|
||||
657 529 796 552 SAMSUNG
|
||||
803 529 856 552 SDS
|
||||
862 527 930 556 VIỆT
|
||||
937 529 1004 552 NAM
|
||||
89 570 129 592 Mã
|
||||
136 564 165 593 số
|
||||
170 565 223 592 thuế
|
||||
229 571 281 595 (Tax
|
||||
287 571 358 596 code):
|
||||
365 569 510 592 2300680991
|
||||
88 611 133 638 Địa
|
||||
138 612 175 634 chỉ
|
||||
182 613 297 638 (Address):
|
||||
303 611 339 633 Lô
|
||||
346 612 424 636 CN05,
|
||||
432 611 514 639 đường
|
||||
521 612 583 638 YP6,
|
||||
589 612 645 634 Khu
|
||||
652 611 713 640 công
|
||||
719 612 803 641 nghiệp
|
||||
810 611 862 635 Yên
|
||||
869 612 956 640 Phong,
|
||||
962 612 1000 635 Xã
|
||||
1006 612 1057 635 Yên
|
||||
1064 612 1152 640 Trung,
|
||||
1159 611 1242 640 Huyện
|
||||
1249 611 1300 635 Yên
|
||||
1307 612 1394 640 Phong,
|
||||
1402 612 1463 635 Tỉnh
|
||||
1470 605 1518 636 Bắc
|
||||
89 654 158 678 Ninh,
|
||||
165 654 219 679 Việt
|
||||
225 654 286 676 Nam
|
||||
89 681 122 705 Số
|
||||
127 684 157 705 tài
|
||||
164 682 236 705 khoản
|
||||
244 685 348 709 (Account
|
||||
354 684 413 709 No.):
|
||||
90 724 149 747 Hình
|
||||
157 724 208 747 thức
|
||||
216 724 282 746 thanh
|
||||
289 724 341 747 toán
|
||||
349 726 457 749 (Payment
|
||||
464 726 563 748 method):
|
||||
572 725 663 747 TM/CK
|
||||
98 789 148 812 STT
|
||||
163 770 212 793 Tên
|
||||
218 770 281 797 hàng
|
||||
287 770 341 796 hóa,
|
||||
348 769 405 796 dịch
|
||||
436 770 491 792 Đơn
|
||||
498 770 525 798 vị
|
||||
569 783 603 811 Số
|
||||
610 789 684 817 lượng
|
||||
747 789 802 811 Đơn
|
||||
808 788 850 818 giá
|
||||
979 788 1063 812 Thành
|
||||
1070 782 1119 812 tiền
|
||||
1214 765 1282 793 Thuế
|
||||
1287 764 1344 793 suất
|
||||
1393 764 1452 792 Tiền
|
||||
1459 765 1518 792 thuế
|
||||
94 825 153 850 (No.)
|
||||
206 844 356 870 (Description)
|
||||
266 813 299 836 vụ
|
||||
448 845 515 868 (Unit)
|
||||
454 808 507 830 tính
|
||||
568 825 685 852 (Quantity)
|
||||
733 826 793 848 (Unit
|
||||
798 826 865 851 price)
|
||||
993 824 1103 851 (Amount)
|
||||
1223 807 1287 830 (VAT
|
||||
1262 843 1295 869 %)
|
||||
1290 809 1335 829 rate
|
||||
1378 843 1542 869 (VAT Amount)
|
||||
1417 807 1491 830 GTGT
|
||||
116 891 128 913 1
|
||||
273 890 290 913 2
|
||||
472 890 488 912 3
|
||||
617 889 635 912 4
|
||||
790 889 807 914 5
|
||||
992 889 1103 916 6=4x5
|
||||
1270 889 1287 913 7
|
||||
1399 889 1510 914 8=6X7
|
||||
158 939 200 961 Phí
|
||||
207 939 259 961 thuê
|
||||
266 939 316 966 máy
|
||||
323 938 361 965 lọc
|
||||
159 977 219 998 nước
|
||||
225 977 285 1002 nóng
|
||||
292 976 345 1001 lạnh
|
||||
161 1014 236 1040 Digital
|
||||
244 1014 307 1035 CHP-
|
||||
114 1032 127 1055 1
|
||||
159 1052 267 1073 3800ST1
|
||||
276 1052 312 1076 (từ
|
||||
318 1051 377 1078 ngày
|
||||
453 1032 507 1060 Máy
|
||||
678 1035 697 1059 4
|
||||
795 1031 893 1057 800.000
|
||||
1074 1031 1195 1057 3,200,000
|
||||
1297 1031 1353 1057 10%
|
||||
1452 1031 1551 1058 320,000
|
||||
158 1089 292 1111 01/02/2023
|
||||
301 1083 343 1110 đến
|
||||
351 1084 388 1110 hết
|
||||
159 1125 304 1151 28/02/2023)
|
||||
159 1173 200 1195 Phí
|
||||
207 1173 258 1195 thuê
|
||||
265 1173 316 1200 máy
|
||||
323 1173 361 1199 lọc
|
||||
158 1213 218 1234 nước
|
||||
225 1212 285 1238 nóng
|
||||
292 1211 344 1237 lạnh
|
||||
161 1249 236 1274 Digital
|
||||
243 1248 307 1270 CHP-
|
||||
112 1267 129 1291 2
|
||||
160 1287 267 1309 3800ST1
|
||||
276 1287 312 1312 (từ
|
||||
318 1287 377 1313 ngày
|
||||
454 1267 506 1293 Máy
|
||||
664 1266 697 1292 20
|
||||
793 1265 896 1294 876,800
|
||||
1062 1265 1198 1297 17,536,000
|
||||
1430 1266 1552 1295 1,753,600
|
||||
159 1323 292 1346 29/01/2023
|
||||
301 1319 343 1345 đến
|
||||
351 1319 388 1345 hết
|
||||
160 1360 304 1387 28/02/2023)
|
||||
160 1409 200 1431 Phí
|
||||
207 1409 258 1431 thuê
|
||||
265 1409 316 1436 máy
|
||||
323 1408 361 1435 lọc
|
||||
158 1447 218 1468 nước
|
||||
225 1446 285 1473 nóng
|
||||
292 1446 344 1472 lạnh
|
||||
113 1503 128 1527 3
|
||||
160 1484 236 1511 Digital
|
||||
243 1484 307 1505 CHP-
|
||||
452 1503 506 1531 Máy
|
||||
795 1502 897 1529 544,000
|
||||
1074 1502 1197 1529 2,176,000
|
||||
1450 1502 1550 1528 217,600
|
||||
160 1522 267 1544 3800ST1
|
||||
276 1522 312 1546 (từ
|
||||
318 1522 377 1549 ngày
|
||||
162 1559 292 1581 10/02/2023
|
||||
301 1555 343 1581 đến
|
||||
351 1555 388 1581 hết
|
||||
159 1596 304 1623 28/02/2023)
|
||||
160 1645 200 1667 Phí
|
||||
207 1644 259 1667 thuê
|
||||
265 1645 316 1672 máy
|
||||
324 1645 362 1671 lọc
|
||||
159 1683 219 1706 nước
|
||||
226 1683 286 1709 nóng
|
||||
293 1683 345 1708 lạnh
|
||||
160 1720 237 1746 Digital
|
||||
245 1720 307 1742 CHP-
|
||||
112 1738 129 1760 4
|
||||
159 1758 268 1780 3800ST1
|
||||
276 1758 313 1783 (từ
|
||||
319 1758 377 1785 ngày
|
||||
453 1737 508 1767 Máy
|
||||
677 1737 699 1764 4
|
||||
795 1737 895 1764 256,000
|
||||
1077 1737 1197 1764 1,024,000
|
||||
1297 1737 1354 1762 10%
|
||||
1453 1737 1552 1764 102,400
|
||||
158 1795 293 1817 20/02/2023
|
||||
301 1791 344 1817 đến
|
||||
350 1791 388 1817 hết
|
||||
158 1831 304 1859 28/02/2023)
|
||||
93 2004 134 2027 Giá
|
||||
139 2006 169 2031 trị
|
||||
173 2007 221 2026 theo
|
||||
226 2006 276 2026 mức
|
||||
281 2002 331 2026 thuế
|
||||
337 2005 412 2026 GTGT
|
||||
676 1986 747 2008 Thành
|
||||
752 1982 795 2008 tiền
|
||||
800 1986 859 2008 trước
|
||||
865 1983 915 2008 thuế
|
||||
920 1986 992 2007 GTGT
|
||||
1025 1983 1074 2008 Tiền
|
||||
1079 1983 1127 2008 thuế
|
||||
1132 1986 1201 2008 GTGT
|
||||
1234 1985 1303 2008 Thành
|
||||
1308 1982 1352 2008 tiền
|
||||
1357 1985 1406 2012 gồm
|
||||
1411 1982 1460 2008 thuế
|
||||
1466 1985 1538 2008 GTGT
|
||||
718 2023 813 2047 (Amount
|
||||
819 2023 889 2048 before
|
||||
895 2022 952 2048 VAT)
|
||||
1035 2022 1096 2046 (VAT
|
||||
1100 2023 1195 2048 Amount)
|
||||
1248 2023 1345 2047 (Amount
|
||||
1351 2023 1459 2049 including
|
||||
1465 2022 1524 2048 VAT)
|
||||
92 2071 132 2093 Giá
|
||||
137 2072 162 2096 trị
|
||||
166 2072 202 2092 với
|
||||
207 2069 253 2093 thuế
|
||||
259 2072 328 2092 GTGT
|
||||
334 2071 371 2092 0%
|
||||
378 2072 472 2094 (Amount
|
||||
478 2074 502 2093 at
|
||||
507 2071 545 2093 0%
|
||||
552 2070 610 2097 VAT)
|
||||
93 2119 132 2141 Giá
|
||||
137 2120 162 2145 trị
|
||||
167 2120 202 2141 với
|
||||
207 2117 253 2141 thuế
|
||||
259 2121 328 2140 GTGT
|
||||
335 2120 371 2140 5%
|
||||
378 2121 471 2141 (Amount
|
||||
477 2122 502 2141 at
|
||||
507 2119 546 2141 5%
|
||||
552 2119 609 2145 VAT)
|
||||
94 2167 132 2189 Giá
|
||||
137 2168 162 2192 trị
|
||||
166 2167 202 2189 với
|
||||
207 2164 253 2189 thuế
|
||||
259 2169 328 2188 GTGT
|
||||
335 2168 372 2188 8%
|
||||
379 2168 472 2189 (Amount
|
||||
478 2171 502 2189 at
|
||||
507 2168 545 2189 8%
|
||||
552 2167 609 2193 VAT)
|
||||
92 2216 132 2237 Giá
|
||||
137 2217 162 2241 trị
|
||||
167 2215 202 2236 với
|
||||
207 2212 253 2237 thuế
|
||||
259 2216 329 2236 GTGT
|
||||
337 2216 384 2236 10%
|
||||
391 2217 486 2237 (Amount
|
||||
491 2219 516 2236 at
|
||||
523 2216 572 2237 10%
|
||||
579 2214 636 2240 VAT)
|
||||
891 2213 1005 2239 23,936,000
|
||||
1111 2213 1215 2239 2,393,600
|
||||
1435 2213 1552 2240 26,329,600
|
||||
94 2263 170 2285 TỔNG
|
||||
176 2263 252 2288 CỘNG
|
||||
260 2264 361 2287 (GRAND
|
||||
368 2263 461 2288 TOTAL)
|
||||
874 2262 1007 2288 23,936,000
|
||||
1097 2262 1215 2288 2,393,600
|
||||
1416 2261 1549 2288 26,329,600
|
||||
87 2307 119 2333 Số
|
||||
125 2307 171 2332 tiền
|
||||
178 2304 223 2332 viết
|
||||
230 2308 289 2337 bằng
|
||||
295 2308 340 2332 chữ
|
||||
347 2311 443 2332 (Amount
|
||||
449 2310 473 2332 in
|
||||
479 2309 564 2336 words):
|
||||
571 2308 616 2331 Hai
|
||||
621 2310 684 2331 mươi
|
||||
690 2308 731 2331 sáu
|
||||
736 2308 793 2336 triệu
|
||||
797 2308 828 2331 ba
|
||||
833 2309 888 2331 trăm
|
||||
894 2308 932 2331 hai
|
||||
938 2309 1001 2331 mươi
|
||||
1007 2308 1059 2331 chín
|
||||
1065 2308 1133 2336 nghìn
|
||||
1139 2307 1180 2331 sáu
|
||||
1186 2309 1241 2331 trăm
|
||||
1247 2305 1308 2336 đồng
|
127
cope2n-ai-fi/common/AnyKey_Value/utils/__init__.py
Executable file
127
cope2n-ai-fi/common/AnyKey_Value/utils/__init__.py
Executable file
@ -0,0 +1,127 @@
|
||||
import os
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.plugins import DDPPlugin
|
||||
from utils.ema_callbacks import EMA
|
||||
|
||||
|
||||
def _update_config(cfg):
|
||||
cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
|
||||
cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")
|
||||
|
||||
# set per-gpu batch size
|
||||
num_devices = torch.cuda.device_count()
|
||||
print('No. devices:', num_devices)
|
||||
for mode in ["train", "val"]:
|
||||
new_batch_size = cfg[mode].batch_size // num_devices
|
||||
cfg[mode].batch_size = new_batch_size
|
||||
|
||||
def _get_config_from_cli():
    """Parse command-line overrides with OmegaConf, normalizing option names.

    Keys passed as ``--foo=...`` are renamed to ``foo`` so they merge cleanly
    with the YAML config.
    """
    cfg_cli = OmegaConf.from_cli()
    # Snapshot the keys first: the config is mutated while renaming.
    for raw_key in list(cfg_cli.keys()):
        if "--" in raw_key:
            cfg_cli[raw_key.replace("--", "")] = cfg_cli[raw_key]
            del cfg_cli[raw_key]

    return cfg_cli
|
||||
|
||||
def get_callbacks(cfg):
    """Build training callbacks: a best/last checkpoint saver plus optional EMA.

    Returns a list when EMA is enabled, otherwise the bare checkpoint callback
    (historical contract kept for existing callers).
    """
    ckpt_cb = ModelCheckpoint(
        dirpath=cfg.save_weight_dir,
        filename='best_model',
        save_last=True,
        save_top_k=1,
        save_weights_only=True,
        verbose=True,
        monitor='val_f1',
        mode='max',
    )
    ckpt_cb.FILE_EXTENSION = ".pth"
    ckpt_cb.CHECKPOINT_NAME_LAST = "last_model"
    callbacks = [ckpt_cb]
    # A decay of -1 is the sentinel for "EMA disabled".
    if cfg.callbacks.ema.decay != -1:
        callbacks.append(EMA(decay=0.9999))
    return callbacks if len(callbacks) > 1 else ckpt_cb
|
||||
|
||||
def get_plugins(cfg):
    """Return Lightning trainer plugins; currently only DDP when configured."""
    if cfg.train.strategy.type == "ddp":
        return [DDPPlugin()]
    return []
|
||||
|
||||
def get_loggers(cfg):
    """Create experiment loggers (a single TensorBoard logger)."""
    tb_logger = TensorBoardLogger(
        cfg.tensorboard_dir, name="", version="", default_hp_metric=False
    )
    return [tb_logger]
|
||||
|
||||
def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
    """Flatten a (possibly nested) config into *hparam_dict*.

    Nested keys are joined with ``__``; every leaf value is stringified so it
    can be logged as an hparam. Returns the (mutated) dict.
    """
    for name, value in cfg.items():
        prefixed = parent_str + name
        if isinstance(value, DictConfig):
            # Recurse into sub-configs, extending the key prefix.
            hparam_dict = cfg_to_hparams(value, hparam_dict, prefixed + "__")
        else:
            hparam_dict[prefixed] = str(value)
    return hparam_dict
|
||||
|
||||
def get_specific_pl_logger(pl_loggers, logger_type):
    """Return the first logger in *pl_loggers* that is an instance of *logger_type*.

    Returns None when no logger matches.
    """
    return next(
        (logger for logger in pl_loggers if isinstance(logger, logger_type)),
        None,
    )
|
||||
|
||||
def get_class_names(dataset_root_path):
    """Read class names from ``<dataset_root>/class_names.txt`` (one per line).

    Args:
        dataset_root_path: sequence of dataset root paths; only the first
            entry is consulted.

    Returns:
        list[str]: class names in file order.
    """
    class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
    # "with" closes the handle deterministically (the original leaked it).
    with open(class_names_file, "r", encoding="utf-8") as f:
        class_names = f.read().strip().split("\n")
    return class_names
|
||||
|
||||
def create_exp_dir(save_dir=''):
    """Create the experiment directory if it is missing and report its path."""
    if os.path.exists(save_dir):
        print("DIR already existed.")
    else:
        os.makedirs(save_dir, exist_ok=True)
    print('Experiment dir : {}'.format(save_dir))
|
||||
|
||||
def create_dir(save_dir=''):
    """Create *save_dir* if it is missing and report its path."""
    if os.path.exists(save_dir):
        print("DIR already existed.")
    else:
        os.makedirs(save_dir, exist_ok=True)
    print('Save dir : {}'.format(save_dir))
|
||||
|
||||
def load_checkpoint(ckpt_path, model, key_include):
    """Load a sub-module's weights from a Lightning checkpoint into *model*.

    Keeps only state-dict entries whose key contains ``.<key_include>.``,
    then strips the leading ``net.`` prefix and the ``<key_include>.``
    segment before loading with ``strict=True``.

    Args:
        ckpt_path: path to a checkpoint containing a ``state_dict`` entry.
        model: module whose parameter names match the filtered keys.
        key_include: sub-module name used to select relevant keys.

    Returns:
        The same *model* with weights loaded.

    Raises:
        FileNotFoundError: if *ckpt_path* does not exist (the original used
            ``assert``, which is stripped under ``python -O``).
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Ckpt path at {ckpt_path} not exist!")
    state_dict = torch.load(ckpt_path, 'cpu')['state_dict']
    for key in list(state_dict.keys()):
        if f'.{key_include}.' not in key:
            del state_dict[key]
        else:
            # remove "net." (first 4 chars) and the "<key_include>." segment.
            new_key = key[4:].replace(key_include + '.', "")
            if new_key != key:  # guard: never delete a value we just stored
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
    model.load_state_dict(state_dict, strict=True)
    print(f"Load checkpoint at {ckpt_path}")
    return model
|
||||
|
||||
def load_model_weight(net, pretrained_model_file):
    """Load pretrained weights into *net*, dropping any leading ``net.`` prefix.

    The checkpoint is expected to hold its weights under a ``state_dict`` key,
    as saved by the Lightning wrapper around the network.
    """
    state = torch.load(pretrained_model_file, map_location="cpu")["state_dict"]
    prefix = "net."
    renamed = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }
    net.load_state_dict(renamed)
|
||||
|
||||
|
346
cope2n-ai-fi/common/AnyKey_Value/utils/ema_callbacks.py
Executable file
346
cope2n-ai-fi/common/AnyKey_Value/utils/ema_callbacks.py
Executable file
@ -0,0 +1,346 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning import Callback
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_info
|
||||
|
||||
|
||||
class EMA(Callback):
    """
    Implements Exponential Moving Averaging (EMA).

    When training a model, this callback will maintain moving averages of the trained parameters.
    When evaluating, we use the moving averages copy of the trained parameters.
    When saving, we save an additional set of parameters with the prefix `ema`.

    Args:
        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
        validate_original_weights: Validate the original weights, as apposed to the EMA weights.
        every_n_steps: Apply EMA every N steps.
        cpu_offload: Offload weights to CPU.
    """

    def __init__(
        self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False,
    ):
        if not (0 <= decay <= 1):
            raise MisconfigurationException("EMA decay value must be between 0 and 1")
        self.decay = decay
        self.validate_original_weights = validate_original_weights
        self.every_n_steps = every_n_steps
        self.cpu_offload = cpu_offload

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Wrap each trainer optimizer in an EMAOptimizer so EMA shadow weights
        # are maintained transparently on every optimizer step.
        # NOTE(review): the comprehension FILTERS OUT optimizers that are
        # already EMAOptimizer instead of keeping them — on a second fit()
        # those optimizers would be dropped from trainer.optimizers; confirm
        # fit() is only ever called once per trainer here.
        device = pl_module.device if not self.cpu_offload else torch.device('cpu')
        trainer.optimizers = [
            EMAOptimizer(
                optim,
                device=device,
                decay=self.decay,
                every_n_steps=self.every_n_steps,
                current_step=trainer.global_step,
            )
            for optim in trainer.optimizers
            if not isinstance(optim, EMAOptimizer)
        ]

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap EMA weights in so validation runs against the averaged model...
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # ...and swap the original training weights back afterwards.
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Same swap-in/swap-out protocol for the test loop.
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self._should_validate_ema_weights(trainer):
            self.swap_model_weights(trainer)

    def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
        # Only swap when EMA is active and the user asked to evaluate the
        # averaged weights (validate_original_weights=False).
        return not self.validate_original_weights and self._ema_initialized(trainer)

    def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
        return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)

    def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
        # Swap main <-> EMA parameters in place on every wrapped optimizer.
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.switch_main_parameter_weights(saving_ema_model)

    @contextlib.contextmanager
    def save_ema_model(self, trainer: "pl.Trainer"):
        """
        Saves an EMA copy of the model + EMA optimizer states for resume.
        """
        self.swap_model_weights(trainer, saving_ema_model=True)
        try:
            yield
        finally:
            self.swap_model_weights(trainer, saving_ema_model=False)

    @contextlib.contextmanager
    def save_original_optimizer_state(self, trainer: "pl.Trainer"):
        # Temporarily make state_dict() return the wrapped optimizer's own
        # state (without the EMA extras) while inside the context.
        for optimizer in trainer.optimizers:
            assert isinstance(optimizer, EMAOptimizer)
            optimizer.save_original_optimizer_state = True
        try:
            yield
        finally:
            for optimizer in trainer.optimizers:
                optimizer.save_original_optimizer_state = False

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        # Restore EMA optimizer state from the sibling "-EMA" checkpoint file
        # when resuming from a NeMo-style checkpoint.
        checkpoint_callback = trainer.checkpoint_callback

        # use the connector as NeMo calls the connector directly in the exp_manager when restoring.
        connector = trainer._checkpoint_connector
        ckpt_path = connector.resume_checkpoint_path

        if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__:
            ext = checkpoint_callback.FILE_EXTENSION
            if ckpt_path.endswith(f'-EMA{ext}'):
                # The user resumed directly from the EMA checkpoint: treat the
                # EMA weights as the main weights and start a fresh EMA copy.
                rank_zero_info(
                    "loading EMA based weights. "
                    "The callback will treat the loaded EMA weights as the main weights"
                    " and create a new EMA copy when training."
                )
                return
            ema_path = ckpt_path.replace(ext, f'-EMA{ext}')
            if os.path.exists(ema_path):
                ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))

                checkpoint['optimizer_states'] = ema_state_dict['optimizer_states']
                del ema_state_dict
                rank_zero_info("EMA state has been restored.")
            else:
                raise MisconfigurationException(
                    "Unable to find the associated EMA weights when re-loading, "
                    f"training will start with new EMA weights. Expected them to be at: {ema_path}",
                )
|
||||
|
||||
|
||||
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
    """Apply one EMA step in place: ema <- decay * ema + (1 - decay) * current.

    Uses fused multi-tensor ops so the whole parameter set updates in a
    handful of kernel launches.
    """
    complement = 1.0 - decay
    torch._foreach_mul_(ema_model_tuple, decay)
    torch._foreach_add_(ema_model_tuple, current_model_tuple, alpha=complement)
|
||||
|
||||
|
||||
def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
    """Thread target for the CPU-offloaded EMA update.

    Waits on *pre_sync_stream* (if given) so the device-to-host parameter
    copies are complete, then runs the EMA update in place.
    """
    stream = pre_sync_stream
    if stream is not None:
        stream.synchronize()
    ema_update(ema_model_tuple, current_model_tuple, decay)
|
||||
|
||||
|
||||
class EMAOptimizer(torch.optim.Optimizer):
    r"""
    EMAOptimizer is a wrapper for torch.optim.Optimizer that computes
    Exponential Moving Average of parameters registered in the optimizer.

    EMA parameters are automatically updated after every step of the optimizer
    with the following formula:

        ema_weight = decay * ema_weight + (1 - decay) * training_weight

    To access EMA parameters, use ``swap_ema_weights()`` context manager to
    perform a temporary in-place swap of regular parameters with EMA
    parameters.

    Notes:
        - EMAOptimizer is not compatible with APEX AMP O2.

    Args:
        optimizer (torch.optim.Optimizer): optimizer to wrap
        device (torch.device): device for EMA parameters
        decay (float): decay factor

    Returns:
        returns an instance of torch.optim.Optimizer that computes EMA of
        parameters

    Example:
        model = Model().to(device)
        opt = torch.optim.Adam(model.parameters())

        opt = EMAOptimizer(opt, device, 0.9999)

        for epoch in range(epochs):
            training_loop(model, opt)

            regular_eval_accuracy = evaluate(model)

            with opt.swap_ema_weights():
                ema_eval_accuracy = evaluate(model)
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        decay: float = 0.9999,
        every_n_steps: int = 1,
        current_step: int = 0,
    ):
        # NOTE(review): torch.optim.Optimizer.__init__ is deliberately not
        # called; this wrapper proxies state to the wrapped optimizer via
        # __getattr__ instead.
        self.optimizer = optimizer
        self.decay = decay
        self.device = device
        self.current_step = current_step
        self.every_n_steps = every_n_steps
        self.save_original_optimizer_state = False

        self.first_iteration = True
        # When True, the EMA shadow copies are (re)built lazily on next step().
        self.rebuild_ema_params = True
        self.stream = None   # CUDA side-stream for async EMA updates
        self.thread = None   # background thread for the CPU-offload path

        self.ema_params = ()
        self.in_saving_ema_model_context = False

    def all_parameters(self) -> Iterable[torch.Tensor]:
        # Flat generator over every parameter in every param group.
        return (param for group in self.param_groups for param in group['params'])

    def step(self, closure=None, **kwargs):
        # Wait for any in-flight asynchronous EMA update before stepping.
        self.join()

        if self.first_iteration:
            # Allocate the side stream only when parameters live on CUDA.
            if any(p.is_cuda for p in self.all_parameters()):
                self.stream = torch.cuda.Stream()

            self.first_iteration = False

        if self.rebuild_ema_params:
            opt_params = list(self.all_parameters())

            # Only parameters appended since the last build get new EMA copies;
            # existing copies are kept as-is.
            self.ema_params += tuple(
                copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :]
            )
            self.rebuild_ema_params = False

        loss = self.optimizer.step(closure)

        if self._should_update_at_step():
            self.update()
        self.current_step += 1
        return loss

    def _should_update_at_step(self) -> bool:
        # Apply EMA every `every_n_steps` optimizer steps.
        return self.current_step % self.every_n_steps == 0

    @torch.no_grad()
    def update(self):
        if self.stream is not None:
            # Let the EMA math run on a side stream, overlapped with training.
            self.stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self.stream):
            current_model_state = tuple(
                param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
            )

            if self.device.type == 'cuda':
                ema_update(self.ema_params, current_model_state, self.decay)

            if self.device.type == 'cpu':
                # CPU offload: finish the update on a worker thread after the
                # device-to-host copies are synchronized.
                self.thread = threading.Thread(
                    target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,),
                )
                self.thread.start()

    def swap_tensors(self, tensor1, tensor2):
        # In-place three-way swap via a temporary buffer.
        tmp = torch.empty_like(tensor1)
        tmp.copy_(tensor1)
        tensor1.copy_(tensor2)
        tensor2.copy_(tmp)

    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
        # Swap every main parameter with its EMA copy (and remember whether
        # we are inside a save-EMA-model context, see state_dict()).
        self.join()
        self.in_saving_ema_model_context = saving_ema_model
        for param, ema_param in zip(self.all_parameters(), self.ema_params):
            self.swap_tensors(param.data, ema_param)

    @contextlib.contextmanager
    def swap_ema_weights(self, enabled: bool = True):
        r"""
        A context manager to in-place swap regular parameters with EMA
        parameters.
        It swaps back to the original regular parameters on context manager
        exit.

        Args:
            enabled (bool): whether the swap should be performed
        """

        if enabled:
            self.switch_main_parameter_weights()
        try:
            yield
        finally:
            # Always restore the training weights, even if the body raised.
            if enabled:
                self.switch_main_parameter_weights()

    def __getattr__(self, name):
        # Delegate unknown attributes (param_groups, defaults, zero_grad, ...)
        # to the wrapped optimizer.
        return getattr(self.optimizer, name)

    def join(self):
        # Block until any asynchronous EMA update (stream or thread) finishes.
        if self.stream is not None:
            self.stream.synchronize()

        if self.thread is not None:
            self.thread.join()

    def state_dict(self):
        self.join()

        if self.save_original_optimizer_state:
            return self.optimizer.state_dict()

        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
        ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters())
        state_dict = {
            'opt': self.optimizer.state_dict(),
            'ema': ema_params,
            'current_step': self.current_step,
            'decay': self.decay,
            'every_n_steps': self.every_n_steps,
        }
        return state_dict

    def load_state_dict(self, state_dict):
        self.join()

        self.optimizer.load_state_dict(state_dict['opt'])
        self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema']))
        self.current_step = state_dict['current_step']
        self.decay = state_dict['decay']
        self.every_n_steps = state_dict['every_n_steps']
        # Copies were just restored; no rebuild needed on the next step().
        self.rebuild_ema_params = False

    def add_param_group(self, param_group):
        self.optimizer.add_param_group(param_group)
        # New parameters need EMA copies; defer the rebuild to the next step().
        self.rebuild_ema_params = True
|
68
cope2n-ai-fi/common/AnyKey_Value/utils/kvu_dictionary.py
Executable file
68
cope2n-ai-fi/common/AnyKey_Value/utils/kvu_dictionary.py
Executable file
@ -0,0 +1,68 @@
|
||||
|
||||
# Mapping from Vietnamese VAT-invoice field labels (as produced by the KVU
# extractor) to the tag names used in the exported XML/JSON result.
# The Vietnamese keys are runtime lookup values — do not translate them.
DKVU2XML = {
    "Ký hiệu mẫu hóa đơn": "form_no",
    "Ký hiệu hóa đơn": "serial_no",
    "Số hóa đơn": "invoice_no",
    "Ngày, tháng, năm lập hóa đơn": "issue_date",
    "Tên người bán": "seller_name",
    "Mã số thuế người bán": "seller_tax_code",
    "Thuế suất": "tax_rate",
    "Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount",
    "Mặt hàng": "item",
    "Đơn vị tính": "unit",
    "Số lượng": "quantity",
    "Đơn giá": "unit_price",
    "Doanh số mua chưa có thuế": "amount"
}
|
||||
|
||||
|
||||
def ap_dictionary(header: bool):
    """Alias lists for AP (purchase-receipt) field matching.

    Args:
        header: True returns table-header aliases (line items);
            False returns key-field aliases (document-level keys).

    Returns:
        dict[str, list[str]] mapping canonical field name to normalized
        alias strings.
    """
    if header:
        return {
            'productname': ['description', 'paticulars', 'articledescription', 'descriptionofgood', 'itemdescription', 'product', 'productdescription', 'modelname', 'device', 'items', 'itemno'],
            'modelnumber': ['serialno', 'model', 'code', 'mcode', 'simimeiserial', 'serial', 'productcode', 'product', 'imeiccid', 'articles', 'article', 'articlenumber', 'articleidmaterialcode', 'transaction', 'itemcode'],
            'qty': ['quantity', 'invoicequantity'],
        }
    return {
        'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'],
        'retailername': ['retailer', 'retailername', 'ownedoperatedby'],
        'serial_number': ['serialnumber', 'serialno'],
        'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2'],
    }
|
||||
|
||||
|
||||
def vat_dictionary(header: bool):
    """Alias tables for the Vietnamese VAT-invoice fields.

    Args:
        header: True for table-header aliases, False for single-key aliases.

    Returns:
        dict mapping canonical Vietnamese field names to lists of normalized
        (lowercase, accent- and space-stripped) alias strings.  Note the date
        field deliberately has no aliases; it is matched specially upstream.
    """
    if header:
        return {
            'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'sanpham', 'tenquycachhanghoa', 'description', 'descriptionofgood', 'itemdescription'],
            'Đơn vị tính': ['dvt', 'donvitinh'],
            'Số lượng': ['soluong', 'sl', 'qty', 'quantity', 'invoicequantity'],
            'Đơn giá': ['dongia'],
            'Doanh số mua chưa có thuế': ['thanhtien', 'thanhtientruocthuegtgt', 'tienchuathue'],
        }
    return {
        'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'],
        'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'],
        'Số hóa đơn': ['soinvoiceno', 'invoiceno'],
        'Ngày, tháng, năm lập hóa đơn': [],
        'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'],
        'Mã số thuế người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'],
        'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'],
        'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'],
    }
|
33
cope2n-ai-fi/common/AnyKey_Value/utils/run_ocr.py
Executable file
33
cope2n-ai-fi/common/AnyKey_Value/utils/run_ocr.py
Executable file
@ -0,0 +1,33 @@
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple, List
|
||||
import sys
|
||||
# sys.path.append('/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/ocr-engine')
|
||||
# from src.ocr import OcrEngine
|
||||
sys.path.append('/home/thucpd/thucpd/git/PV2-2023/kie-invoice/components/prediction') # TODO: ??????
|
||||
import serve_model
|
||||
|
||||
|
||||
# def load_ocr_engine() -> OcrEngine:
def load_ocr_engine() -> "OcrEngine":
    """Return the shared OCR engine instance from ``serve_model``.

    The return annotation is quoted: the ``OcrEngine`` import above is
    commented out, so a bare annotation would raise NameError the moment
    this module is imported (annotations are evaluated at def time).
    """
    print("[INFO] Loading engine...")
    # engine = OcrEngine()
    engine = serve_model.engine
    print("[INFO] Engine loaded")
    return engine
|
||||
|
||||
def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: "OcrEngine", export_img: bool) -> None:
    """Run OCR on an image and write word-level results to a .txt file.

    Args:
        img: path to an image file, or an already-loaded numpy array.
        save_dir_or_path: output directory (path inputs only — the file name
            is derived from the input stem) or an exact output .txt path.
        engine: OCR engine instance.  Quoted annotation: the ``OcrEngine``
            import above is commented out, so a bare name would raise
            NameError at def time.
        export_img: when True, also save a rendered .jpg next to the .txt.

    Raises:
        ValueError: if ``img`` is a numpy array and ``save_dir_or_path`` is a
            directory (an array has no stem to derive a file name from).
    """
    save_dir_or_path = Path(save_dir_or_path)
    if isinstance(img, np.ndarray):
        if save_dir_or_path.is_dir():
            raise ValueError("numpy array input require a save path, not a save dir")
    page = engine(img)
    # Directory output: derive <stem>.txt from the input path; otherwise use
    # the given path verbatim.
    save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt")
                    ) if save_dir_or_path.is_dir() else str(save_dir_or_path)
    page.write_to_file('word', save_path)
    if export_img:
        page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, )
|
||||
|
||||
def read_img(img: Union[str, np.ndarray], engine: "OcrEngine"):
    """OCR an image and return all recognized words joined by single spaces.

    Quoted annotation: ``OcrEngine`` is not imported (its import is commented
    out above), so a bare annotation would raise NameError at def time.
    """
    page = engine(img)
    # `llines` is the engine page's line container; each element exposes `.text`.
    return ' '.join([f.text for f in page.llines])
|
101
cope2n-ai-fi/common/AnyKey_Value/utils/split_docs.py
Normal file
101
cope2n-ai-fi/common/AnyKey_Value/utils/split_docs.py
Normal file
@ -0,0 +1,101 @@
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
def longestCommonSubsequence(text1: str, text2: str) -> int:
    """Return the length of the longest common subsequence of two strings.

    Classic O(len1 * len2) dynamic program, kept to two rolling rows so only
    O(len2) memory is used.  Results are identical to the full-table version.
    """
    prev = [0] * (len(text2) + 1)
    for c in text1:
        curr = [0] * (len(text2) + 1)
        for j, d in enumerate(text2):
            curr[j + 1] = 1 + prev[j] if c == d else max(prev[j + 1], curr[j])
        prev = curr
    return prev[-1]
|
||||
|
||||
def write_to_json(file_path, content):
    """Serialize *content* as JSON (UTF-8, non-ASCII preserved) to file_path."""
    payload = json.dumps(content, ensure_ascii=False)
    with open(file_path, mode="w", encoding="utf8") as fh:
        fh.write(payload)
|
||||
|
||||
|
||||
def read_json(file_path):
    """Load and return the JSON document stored at file_path."""
    with open(file_path, "r") as fh:
        raw = fh.read()
    return json.loads(raw)
|
||||
|
||||
def check_label_exists(array, target_label):
    """Return True if any element of *array* carries 'label' == target_label."""
    return any(obj["label"] == target_label for obj in array)
|
||||
|
||||
def merged_kvu_outputs(loutputs: list) -> dict:
    """Merge per-page KVU field lists into one compiled list.

    The first page with a non-empty value wins for each label; later 'table'
    values are appended (as whole page-tables) onto the already-collected
    table entry.  NOTE: despite the annotation this returns a list, matching
    the original implementation.
    """
    compiled = []
    for page_fields in loutputs:
        for field in page_fields:
            label, value = field['label'], field['value']
            if value != "" and all(entry['label'] != label for entry in compiled):
                compiled.append({'label': label, 'value': value})
            elif label == 'table' and any(entry['label'] == 'table' for entry in compiled):
                for entry in compiled:
                    if entry['label'] == 'table' and len(value) > 0:
                        entry['value'].append(value)
    return compiled
|
||||
|
||||
|
||||
def split_docs(doc_data: list, threshold: float=0.6) -> list:
    """Split a multi-page document stream into sub-documents by title similarity.

    Pages are accumulated while the normalized LCS between the current page
    title and the previous one stays >= ``threshold``; a drop below it closes
    the current chunk and starts a new one.  Relies on the sibling helpers
    ``longestCommonSubsequence`` and ``merged_kvu_outputs`` in this module.

    Args:
        doc_data: per-page dicts with 'page_number', 'document_type',
            'document_class' and 'fields' keys.
        threshold: similarity cut-off in [0, 1] below which a new sub-document
            begins.

    Returns:
        list of dicts with 'doc_type', 'start_page', 'end_page' and 'content'.

    NOTE(review): assumes page numbers start at 0 — if the first page has
    page_number != 0, ``prev_title`` is read before assignment (NameError).
    Confirm with the producer of ``doc_data``.
    """
    num_pages = len(doc_data)
    outputs = []
    kvu_content = []
    # Process pages in ascending page order regardless of input order.
    doc_data = sorted(doc_data, key=lambda x: int(x['page_number']))
    for data in doc_data:
        page_id = int(data['page_number'])
        doc_type = data['document_type']
        doc_class = data['document_class']
        fields = data['fields']
        if page_id == 0:
            # First page seeds the running title/class of the current chunk.
            prev_title = doc_type
            start_page_id = page_id
            prev_class = doc_class
        # "unknown" pages inherit the previous title / fall back to "other".
        curr_title = doc_type if doc_type != "unknown" else prev_title
        curr_class = doc_class if doc_class != "unknown" else "other"
        kvu_content.append(fields)
        similarity_score = longestCommonSubsequence(curr_title, prev_title) / len(prev_title)
        if similarity_score < threshold:
            # Title changed enough: close the chunk that ended on the previous
            # page (hence kvu_content[:-1] — the current page starts the next).
            end_page_id = page_id - 1
            outputs.append({
                "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title,
                "start_page": start_page_id,
                "end_page": end_page_id,
                "content": merged_kvu_outputs(kvu_content[:-1])
            })
            prev_title = curr_title
            prev_class = curr_class
            start_page_id = page_id
            # Keep only the current page's fields for the new chunk.
            kvu_content = kvu_content[-1:]
            if page_id == num_pages - 1:  # end_page
                # The very last page also closes its (single-page) chunk.
                outputs.append({
                    "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title,
                    "start_page": start_page_id,
                    "end_page": page_id,
                    "content": merged_kvu_outputs(kvu_content)
                })
        elif page_id == num_pages - 1:  # end_page
            # Reached the last page without a split: flush the open chunk.
            outputs.append({
                "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title,
                "start_page": start_page_id,
                "end_page": page_id,
                "content": merged_kvu_outputs(kvu_content)
            })
    return outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: split one pre-computed KVU page dump with a strict
    # threshold and write the result next to the input file.
    threshold = 0.9
    json_path = "/home/sds/tuanlv/02-KVU/02-KVU_test/visualize/manulife_v2/json_outputs/HS_YCBT_No_IP_HMTD.json"
    doc_data = read_json(json_path)

    outputs = split_docs(doc_data, threshold)

    write_to_json(os.path.join(os.path.dirname(json_path), "splited_doc.json"), outputs)
|
548
cope2n-ai-fi/common/AnyKey_Value/utils/utils.py
Executable file
548
cope2n-ai-fi/common/AnyKey_Value/utils/utils.py
Executable file
@ -0,0 +1,548 @@
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import glob
|
||||
import re
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pdf2image import convert_from_path
|
||||
from dicttoxml import dicttoxml
|
||||
from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item
|
||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary
|
||||
|
||||
|
||||
|
||||
def create_dir(save_dir=''):
    """Create *save_dir* (including parents) if missing, then report the path."""
    if os.path.exists(save_dir):
        print("DIR already existed.")
    else:
        os.makedirs(save_dir, exist_ok=True)
    print('Save dir : {}'.format(save_dir))
|
||||
|
||||
def pdf2image(pdf_dir, save_dir):
    """Rasterize every PDF in *pdf_dir* into per-page JPEGs under *save_dir*.

    Each page of ``<name>.pdf`` is written as ``<name>_<i>.jpg``.
    NOTE(review): 500 DPI yields very large images — confirm this is intended.
    """
    pdf_files = glob.glob(f'{pdf_dir}/*.pdf')
    print('No. pdf files:', len(pdf_files))

    for file in tqdm(pdf_files):
        pages = convert_from_path(file, 500)
        for i, page in enumerate(pages):
            # <name>.pdf -> save_dir/<name>_<i>.jpg
            page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG')
    print('Done!!!')
|
||||
|
||||
def xyxy2xywh(bbox):
    """Convert an (x1, y1, x2, y2) box to [x, y, width, height] as floats."""
    x1, y1, x2, y2 = (float(v) for v in bbox[:4])
    return [x1, y1, x2 - x1, y2 - y1]
|
||||
|
||||
def write_to_json(file_path, content):
    """Write *content* to *file_path* as UTF-8 JSON (non-ASCII kept verbatim)."""
    serialized = json.dumps(content, ensure_ascii=False)
    with open(file_path, mode='w', encoding='utf8') as out:
        out.write(serialized)
|
||||
|
||||
|
||||
def read_json(file_path):
    """Parse and return the JSON document stored at *file_path*."""
    with open(file_path, 'r') as handle:
        text = handle.read()
    return json.loads(text)
|
||||
|
||||
def read_xml(file_path):
    """Return the entire text content of the XML file at *file_path*."""
    with open(file_path, 'r') as xml_file:
        content = xml_file.read()
    return content
|
||||
|
||||
def write_to_xml(file_path, content):
    """Write an XML string to *file_path* using UTF-8 encoding."""
    with open(file_path, mode="w", encoding='utf8') as xml_file:
        xml_file.write(content)
|
||||
|
||||
def write_to_xml_from_dict(file_path, content):
    """Serialize a dict to XML via ``dicttoxml`` and write it to *file_path*.

    Bug fix: the original reassigned ``xml = content`` right after the
    dicttoxml conversion, discarding the converted bytes and then calling
    ``.decode()`` on the raw dict (AttributeError for dict input).
    """
    xml = dicttoxml(content)  # returns bytes
    xml_decode = xml.decode()

    with open(file_path, mode="w") as f:
        f.write(xml_decode)
|
||||
|
||||
|
||||
def load_ocr_result(ocr_path):
    """Read a tab-separated OCR dump; return one list of fields per line."""
    with open(ocr_path, 'r') as f:
        lines = f.read().splitlines()
    return [line.split('\t') for line in lines]
|
||||
|
||||
def post_process_basic_ocr(lwords: list) -> list:
    """Replace the ✪ placeholder (protected space) with a real space per word."""
    return [word.replace("✪", " ") for word in lwords]
|
||||
|
||||
def read_ocr_result_from_txt(file_path: str):
    """Parse an OCR .txt dump into parallel box and word lists.

    Each non-empty line is ``x1\\ty1\\tx2\\ty2\\ttext[\\tconf]``.

    Bug fix: the original left ``x1..text`` unassigned (UnboundLocalError) or
    silently reused the previous line's values when a line had neither 5 nor
    6 tab-separated fields; malformed lines are now skipped explicitly.

    Returns:
        (boxes, words): list of (x1, y1, x2, y2) int tuples and list of the
        corresponding non-blank text strings.
    """
    with open(file_path, 'r') as f:
        lines = f.read().splitlines()

    boxes, words = [], []
    for line in lines:
        if line == "":
            continue
        word_info = line.split("\t")
        if len(word_info) == 6:
            x1, y1, x2, y2, text, _ = word_info
        elif len(word_info) == 5:
            x1, y1, x2, y2, text = word_info
        else:
            # Malformed line: neither 5 nor 6 fields — skip it.
            continue

        x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2))
        if text and text != " ":
            words.append(text)
            boxes.append((x1, y1, x2, y2))
    return boxes, words
|
||||
|
||||
def get_colormap():
    """Return the BGR color palette used for visualization, keyed by entity type."""
    palette = {}
    palette['others'] = (0, 0, 255)      # others: red
    palette['title'] = (0, 255, 255)     # title: yellow
    palette['key'] = (255, 0, 0)         # key: blue
    palette['value'] = (0, 255, 0)       # value: green
    palette['header'] = (233, 197, 15)   # header
    palette['group'] = (0, 128, 128)     # group
    palette['relation'] = (0, 0, 255)    # relation
    return palette
|
||||
|
||||
|
||||
def convert_image(image):
    """Convert a PIL image to a 3-channel BGR numpy array, applying EXIF rotation.

    Args:
        image: PIL image (must support the private ``_getexif`` accessor).

    Returns:
        (image, orientation): BGR ndarray with shape (H, W, 3) and the raw
        EXIF orientation value (or None when no EXIF data is present).
    """
    exif = image._getexif()
    orientation = None
    if exif is not None:
        orientation = exif.get(0x0112)  # EXIF tag 274: Orientation

    # Convert the PIL image to OpenCV format
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Rotate the image in OpenCV if necessary
    if orientation == 3:
        image = cv2.rotate(image, cv2.ROTATE_180)
    elif orientation == 6:
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    elif orientation == 8:
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        # Already an ndarray after cvtColor, so this branch is a no-op.
        image = np.asarray(image)

    if len(image.shape) == 2:
        # Grayscale -> 3 channels by repeating the single plane.
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    assert len(image.shape) == 3

    return image, orientation
|
||||
|
||||
def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
    """Draw predicted word-group boxes, group chains and relation lines on an image.

    Args:
        image: PIL image; converted to BGR and EXIF-rotated via convert_image.
        bbox: per-word boxes in 0-1000 normalized coordinates (scaled to pixels).
        pr_class_words: per-class lists of word-id groups; class 0 ('others')
            is skipped.
        pr_relations: list of (word_id, word_id) pairs to connect.
        color_map: palette from get_colormap().
        labels: class names aligned with pr_class_words indices (the mutable
            default is never mutated here).
        thickness: line/rectangle thickness in pixels.

    Returns:
        The annotated BGR ndarray.
    """
    image, orientation = convert_image(image)

    # Orientation 6 images were rotated 90°, so swap the axes when scaling.
    if orientation is not None and orientation == 6:
        width, height, _ = image.shape
    else:
        height, width, _ = image.shape

    if len(pr_class_words) > 0:
        id2label = {k: labels[k] for k in range(len(labels))}
        for lb, groups in enumerate(pr_class_words):
            if lb == 0:
                # Skip the 'others' class.
                continue
            for group_id, group in enumerate(groups):
                for i, word_id in enumerate(group):
                    # Boxes are normalized to 0-1000; scale to pixel coords.
                    x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000)
                    cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)

                    # Chain consecutive words of a group center-to-center.
                    if i == 0:
                        x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2)
                    else:
                        x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2)
                        cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
                        x_center0, y_center0 = x_center1, y_center1

    if len(pr_relations) > 0:
        # Draw a relation line between the centers of each related word pair.
        for pair in pr_relations:
            xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000)
            xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000)

            x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2)
            x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2)

            cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)

    return image
|
||||
|
||||
|
||||
def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
    """Collect (rel_from, rel_to) linked group-id pairs.

    Note: the parameter is named ``json`` (shadowing the module) to keep the
    original call signature intact.

    Args:
        json: list of linked word-group pairs (each a list of group dicts).
        rel_from / rel_to: the two classes that must both appear in a pair.

    Returns:
        {to_group_id: [from_group_id, to_group_id]} — the last element of each
        class within a pair wins, matching the original behavior.
    """
    outputs = {}
    for pair in json:
        found = {}
        for element in pair:
            if element['class'] in (rel_from, rel_to):
                found[element['class']] = element
        if rel_from in found and rel_to in found:
            outputs[found[rel_to]['group_id']] = [found[rel_from]['group_id'], found[rel_to]['group_id']]
    return outputs
|
||||
|
||||
def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
    """Map each table key group to the value groups linked to it.

    Only keys present in ``header_key_pairs`` (i.e. keys that have a header)
    participate.  The last matching element of each class in a pair wins,
    matching the original behavior.

    Returns:
        {key_group_id: [value_group_id, ...]} with an (possibly empty) entry
        for every key in header_key_pairs.
    """
    valid_keys = list(header_key_pairs.keys())
    relations = {k: [] for k in valid_keys}
    for pair in json:
        found = {}
        for element in pair:
            if element['class'] == rel_from and element['group_id'] in valid_keys:
                found[rel_from] = element
            if element['class'] == rel_to:
                found[rel_to] = element
        if rel_from in found and rel_to in found:
            relations[found[rel_from]['group_id']].append(found[rel_to]['group_id'])
    return relations
|
||||
|
||||
def get_key2values_relations(key_value_pairs: dict):
    """Invert {value_gid: [key_gid, value_gid]} into {key_gid: [value_gids...]}."""
    triple_linkings = {}
    for value_group_id, key_value_pair in key_value_pairs.items():
        key_group_id = key_value_pair[0]
        triple_linkings.setdefault(key_group_id, []).append(value_group_id)
    return triple_linkings
|
||||
|
||||
|
||||
def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict:
    """Merge token ids into word-group records keyed by each group's first token.

    Uses the sibling ``get_string`` (imported at module top) to join/dedupe
    the group's token texts.

    Returns:
        {first_token_id: {'group_id', 'text', 'class', 'tokens'}}
    """
    id2class = dict(enumerate(labels))
    word_groups = {}
    for class_id, groups_in_class in enumerate(class_words):
        for token_ids in groups_in_class:
            token_texts = [lwords[tok] for tok in token_ids]
            word_groups[token_ids[0]] = {
                'group_id': token_ids[0],
                'text': get_string(token_texts),
                'class': id2class[class_id],
                'tokens': token_ids,
            }
    return word_groups
|
||||
|
||||
def verify_linking_id(word_groups: dict, linking_id: int) -> int:
    """Resolve a token id to its word-group id.

    If *linking_id* is already a group key it is returned as-is; otherwise the
    group whose 'tokens' list contains it is returned; failing that, the id
    itself is returned unchanged.
    """
    if linking_id in word_groups:
        return linking_id
    for wg_id, group in word_groups.items():
        if linking_id in group['tokens']:
            return wg_id
    return linking_id
|
||||
|
||||
def matched_wordgroup_relations(word_groups: dict, lrelations: list) -> list:
    """Resolve raw relation id pairs into pairs of word-group records.

    Pairs whose endpoints cannot be found in *word_groups* are reported on
    stdout and dropped, matching the original behavior.
    """
    outputs = []
    for pair in lrelations:
        wg_from = verify_linking_id(word_groups, pair[0])
        wg_to = verify_linking_id(word_groups, pair[1])
        try:
            outputs.append([word_groups[wg_from], word_groups[wg_to]])
        except Exception:
            print('Not valid pair:', wg_from, wg_to)
    return outputs
|
||||
|
||||
|
||||
def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Assemble KVU predictions into single pairs, triplets and a table, and
    write them to ``<dir>/kvu_results/<name>``.

    Args:
        file_path: output path; the file is redirected into a 'kvu_results'
            sibling directory.  NOTE(review): that directory must already
            exist or write_to_json will fail — confirm with callers.
        lwords: token texts indexed by token id.
        class_words: per-class lists of token-id groups.
        lrelations: raw (id, id) relation pairs.
        labels: class names aligned with class_words (default never mutated).

    Returns:
        {'single': [...], 'triplet': [...], 'table': [...]}
    """
    word_groups = merged_token_to_wordgroup(class_words, lwords, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)

    header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]}
    header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]}
    key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]}

    # table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]}
    key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]}

    triplet_pairs = []
    single_pairs = []
    table = []
    # print('key2values_relations', key2values_relations)
    for key_group_id, list_value_group_ids in key2values_relations.items():
        if len(list_value_group_ids) == 0: continue
        elif len(list_value_group_ids) == 1:
            # One value per key -> flat key/value pair.
            value_group_id = list_value_group_ids[0]
            single_pairs.append({word_groups[key_group_id]['text']: {
                'text': word_groups[value_group_id]['text'],
                'id': value_group_id,
                'class': "value"
            }})
        else:
            # Multiple values: attach each value's header name (if any).
            item = []
            for value_group_id in list_value_group_ids:
                if value_group_id not in header_value.keys():
                    header_name_for_value = "non-header"
                else:
                    header_group_id = header_value[value_group_id][0]
                    header_name_for_value = word_groups[header_group_id]['text']
                item.append({
                    'text': word_groups[value_group_id]['text'],
                    'header': header_name_for_value,
                    'id': value_group_id,
                    'class': 'value'
                })
            if key_group_id not in list(header_key.keys()):
                # Key has no header -> a free-form triplet.
                triplet_pairs.append({
                    word_groups[key_group_id]['text']: item
                })
            else:
                # Key has a header -> treat the key itself as a table cell.
                header_group_id = header_key[key_group_id][0]
                header_name_for_key = word_groups[header_group_id]['text']
                item.append({
                    'text': word_groups[key_group_id]['text'],
                    'header': header_name_for_key,
                    'id': key_group_id,
                    'class': 'key'
                })
                table.append({key_group_id: item})

    if len(table) > 0:
        # Sort rows by key group id, then flatten to the row lists.
        table = sorted(table, key=lambda x: list(x.keys())[0])
        table = [v for item in table for k, v in item.items()]

    outputs = {}
    outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
    outputs['triplet'] = triplet_pairs
    outputs['table'] = table

    # Redirect the output file into the 'kvu_results' sibling directory.
    file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path))
    write_to_json(file_path, outputs)
    return outputs
|
||||
|
||||
# For FI-VAT project
|
||||
|
||||
def get_vat_table_information(outputs):
    """Normalize the generic KVU table into VAT line items.

    Each cell header is mapped onto a canonical VAT column via
    ``vat_standardizer``; per column the candidate with the highest LCS score
    wins.  Rows without a product name ("Mặt hàng") are dropped.

    Args:
        outputs: dict from export_kvu_outputs with a 'table' entry.

    Returns:
        list of dicts keyed by the canonical VAT column names.
    """
    table = []
    for single_item in outputs['table']:
        item = {k: [] for k in list(vat_dictionary(header=True).keys())}
        for cell in single_item:
            header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True)
            if header_name in list(item.keys()):
                # item[header_name] = value['text']
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': cell['id']
                })

        for header_name, value in item.items():
            if len(value) == 0:
                # Numeric columns default to '0', everything else to None.
                if header_name in ("Số lượng", "Doanh số mua chưa có thuế"):
                    item[header_name] = '0'
                else:
                    item[header_name] = None
                continue
            item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score

        item = post_process_for_item(item)

        # Drop rows with no product name.
        if item["Mặt hàng"] == None:
            continue
        table.append(item)
    return table
|
||||
|
||||
def get_vat_information(outputs):
    """Collect candidate values per canonical VAT key from all KVU sections.

    Candidates are gathered from 'single' pairs, 'triplet' groups and 'table'
    rows; each candidate records its content, processed key, LCS score and
    token id for later selection in post_process_vat_information.

    NOTE(review): a triplet with exactly one value is appended once via its
    key AND again via the `for pair in value_list` loop below (by its header)
    — possibly intentional double-voting, confirm.
    """
    # VAT Information
    single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())}
    for pair in outputs['single']:
        for raw_key_name, value in pair.items():
            key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False)
            # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

            if key_name in list(single_pairs.keys()):
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': value['id'],
                })

    for triplet in outputs['triplet']:
        for key, value_list in triplet.items():
            if len(value_list) == 1:
                # Single-value triplet: match by the triplet's key text.
                key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False)
                # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

                if key_name in list(single_pairs.keys()):
                    single_pairs[key_name].append({
                        'content': value_list[0]['text'],
                        'processed_key_name': proceessed_text,
                        'lcs_score': score,
                        'token_id': value_list[0]['id']
                    })

            # Match every triplet value by its header text.
            for pair in value_list:
                key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
                # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

                if key_name in list(single_pairs.keys()):
                    single_pairs[key_name].append({
                        'content': pair['text'],
                        'processed_key_name': proceessed_text,
                        'lcs_score': score,
                        'token_id': pair['id']
                    })

    # Table cells can also carry key-like information (e.g. dates).
    for table_row in outputs['table']:
        for pair in table_row:
            key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False)
            # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

            if key_name in list(single_pairs.keys()):
                single_pairs[key_name].append({
                    'content': pair['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': pair['id']
                })

    return single_pairs
|
||||
|
||||
|
||||
def post_process_vat_information(single_pairs):
    """Reduce the per-field candidate lists to one final string value each.

    Dates are assembled from separately-detected day/month/year parts; the
    seller tax code takes the earliest token and repairs duplicated digits;
    all other fields take the candidate with the highest LCS score.

    NOTE(review): `key_name in ("...")` below is a substring test on a single
    string (parens are not a tuple) — it behaves like equality for the full
    field names used here, but would also match substrings.
    """
    vat_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        if key_name in ("Ngày, tháng, năm lập hóa đơn"):
            if len(list_potential_value) == 1:
                vat_outputs[key_name] = list_potential_value[0]['content']
            else:
                # Assemble dd/mm/yyyy from the detected parts; missing parts
                # keep their 'dd'/'mm'/'yyyy' placeholders.
                # assumes processed_key_name is one of 'day'/'month'/'year' —
                # TODO confirm against vat_standardizer.
                date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'}
                for value in list_potential_value:
                    date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content'])
                vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}"
        else:
            if len(list_potential_value) == 0: continue
            if key_name in ("Mã số thuế người bán"):
                selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code
                # tax_code_raw = selected_value['content'].replace(' ', '')
                tax_code_raw = selected_value['content']
                if len(tax_code_raw.replace(' ', '')) not in (10, 13): # to remove the first number dupicated
                    tax_code_raw = tax_code_raw.split(' ')
                    tax_code_raw = sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0]
                vat_outputs[key_name] = tax_code_raw.replace(' ', '')

            else:
                selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
                vat_outputs[key_name] = selected_value['content']
    return vat_outputs
|
||||
|
||||
|
||||
def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Run the generic KVU export, derive VAT fields and line items, and write
    the combined result to *file_path* as JSON.

    Returns:
        dict of canonical VAT fields plus a 'table' list of line items.

    NOTE(review): mutable default for ``labels`` — safe here since it is never
    mutated, but a tuple would be more defensive.
    """
    vat_outputs = {}
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)

    # List of items in table
    table = get_vat_table_information(outputs)

    # VAT Information
    single_pairs = get_vat_information(outputs)
    vat_outputs = post_process_vat_information(single_pairs)

    # Combine VAT information and table
    vat_outputs['table'] = table

    write_to_json(file_path, vat_outputs)
    return vat_outputs
|
||||
|
||||
|
||||
# For SBT project
|
||||
|
||||
def get_ap_table_information(outputs):
    """Normalize the generic KVU table into AP (SBT) line items.

    Cell headers are mapped onto canonical AP columns via ``ap_standardizer``;
    per column the candidate with the highest LCS score wins, missing columns
    become None.

    Returns:
        list of dicts keyed by the canonical AP column names.
    """
    table = []
    for single_item in outputs['table']:
        item = {k: [] for k in list(ap_dictionary(header=True).keys())}
        for cell in single_item:
            header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True)
            # print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}")
            if header_name in list(item.keys()):
                item[header_name].append({
                    'content': cell['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': cell['id']
                })
        for header_name, value in item.items():
            if len(value) == 0:
                item[header_name] = None
                continue
            item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score

        table.append(item)
    return table
|
||||
|
||||
def get_ap_triplet_information(outputs):
    """Convert KVU triplets into AP table-like rows.

    Values whose header is "non-header" are skipped; a triplet only yields a
    row when at least one of its values maps onto a known AP column.

    Returns:
        list of row dicts, each with 'productname' set to the triplet key.

    NOTE(review): ``item['productname'] = key_name`` below uses the LAST
    iterated key of ``single_item`` — correct only if each triplet dict has
    exactly one key; confirm upstream.
    """
    triplet_pairs = []
    for single_item in outputs['triplet']:
        item = {k: [] for k in list(ap_dictionary(header=True).keys())}
        is_item_valid = 0
        for key_name, list_value in single_item.items():
            for value in list_value:
                if value['header'] == "non-header":
                    continue
                header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True)
                if header_name in list(item.keys()):
                    is_item_valid = 1
                    item[header_name].append({
                        'content': value['text'],
                        'processed_key_name': proceessed_text,
                        'lcs_score': score,
                        'token_id': value['id']
                    })

        if is_item_valid == 1:
            for header_name, value in item.items():
                if len(value) == 0:
                    item[header_name] = None
                    continue
                item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score

            item['productname'] = key_name
            # triplet_pairs.append({key_name: new_item})
            triplet_pairs.append(item)
    return triplet_pairs
|
||||
|
||||
|
||||
def get_ap_information(outputs):
    """Select the best candidate value for each canonical AP key.

    Candidates come from the 'single' pairs; per key the candidate with the
    highest LCS score wins, keys without candidates stay None.

    NOTE(review): the loop variable ``key_name`` is immediately rebound by the
    standardizer result, shadowing the raw key — intentional but fragile.
    """
    single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())}
    for pair in outputs['single']:
        for key_name, value in pair.items():
            key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False)
            # print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}")

            if key_name in list(single_pairs):
                single_pairs[key_name].append({
                    'content': value['text'],
                    'processed_key_name': proceessed_text,
                    'lcs_score': score,
                    'token_id': value['id']
                })

    ap_outputs = {k: None for k in list(single_pairs)}
    for key_name, list_potential_value in single_pairs.items():
        if len(list_potential_value) == 0: continue
        selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score
        ap_outputs[key_name] = selected_value['content']

    return ap_outputs
|
||||
|
||||
def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    """Run the generic KVU export, derive AP (SBT) fields plus line items, and
    write the combined result to *file_path* as JSON.

    Consistency fix: now returns the assembled dict, matching the sibling
    ``export_kvu_for_VAT_invoice`` (the original wrote the file but returned
    None, so callers could not reuse the result).

    Returns:
        dict of canonical AP fields plus a 'table' list (table rows followed
        by triplet-derived rows).
    """
    outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels)
    # List of items in table
    table = get_ap_table_information(outputs)
    triplet_pairs = get_ap_triplet_information(outputs)
    table = table + triplet_pairs

    ap_outputs = get_ap_information(outputs)

    ap_outputs['table'] = table
    # ap_outputs['triplet'] = triplet_pairs

    write_to_json(file_path, ap_outputs)
    return ap_outputs
|
224
cope2n-ai-fi/common/AnyKey_Value/word_preprocess.py
Executable file
224
cope2n-ai-fi/common/AnyKey_Value/word_preprocess.py
Executable file
@ -0,0 +1,224 @@
|
||||
import nltk
|
||||
import re
|
||||
import string
|
||||
import copy
|
||||
from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML
|
||||
# Download the English word corpus once at import time so that
# remove_english_words() below can filter out English OCR tokens.
# NOTE(review): network side effect on import — consider guarding/caching.
nltk.download('words')
words = set(nltk.corpus.words.words())

# Parallel transliteration tables: each accented Vietnamese character in s1
# maps to the unaccented ASCII character at the same index in s0
# (see remove_accents below).
s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ'
s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy'
|
||||
# def clean_text(text):
|
||||
# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text)
|
||||
|
||||
|
||||
def get_string(lwords: list):
    """Join words with spaces, dropping duplicates — except single digits,
    which are always kept (they legitimately repeat in numbers)."""
    kept = []
    for word in lwords:
        is_single_digit = word.isdigit() and len(word) == 1
        if is_single_digit or word not in kept:
            kept.append(word)
    return ' '.join(kept)
|
||||
|
||||
def remove_english_words(text):
    """Drop tokens found in the NLTK English word list; return the rest joined
    by spaces (lowercased)."""
    kept = (tok.lower() for tok in nltk.wordpunct_tokenize(text) if tok.lower() not in words)
    return ' '.join(kept)
|
||||
|
||||
def remove_punctuation(text):
    """Delete all ASCII punctuation characters from *text* (spaces are kept)."""
    translator = str.maketrans(" ", " ", string.punctuation)
    return text.translate(translator)
|
||||
|
||||
def remove_accents(input_str, s0, s1):
    """Transliterate characters: any char of *input_str* found in *s1* is
    replaced by the char at the same index in *s0*; others pass through."""
    return ''.join(s0[s1.index(ch)] if ch in s1 else ch for ch in input_str)
|
||||
|
||||
def remove_spaces(text):
    """Strip every ASCII space character (other whitespace is untouched)."""
    return ''.join(text.split(' '))
|
||||
|
||||
def preprocessing(text: str):
    """Normalize a raw OCR key: strip punctuation, transliterate Vietnamese
    accents, remove spaces, lowercase."""
    cleaned = remove_punctuation(text)
    cleaned = remove_accents(cleaned, s0, s1)
    cleaned = remove_spaces(cleaned)
    return cleaned.lower()
|
||||
|
||||
|
||||
def vat_standardize_outputs(vat_outputs: dict) -> dict:
    """Rename Vietnamese VAT field names to their XML tags via DKVU2XML.

    The 'table' entry is handled recursively: each row's keys are translated
    the same way.
    """
    outputs = {}
    for key, value in vat_outputs.items():
        if key == "table":
            outputs['table'] = [
                {DKVU2XML[item_key]: item_value for item_key, item_value in row.items()}
                for row in value
            ]
        else:
            outputs[DKVU2XML[key]] = value
    return outputs
|
||||
|
||||
|
||||
|
||||
def vat_standardizer(text: str, threshold: float, header: bool):
    """Map a raw OCR field name onto a canonical VAT-invoice key.

    Args:
        text: raw field text from OCR.
        threshold: minimum LCS-based similarity ratio for a fuzzy match.
        header: whether the field comes from the table header.

    Returns:
        (key, score, processed_text): *key* is the canonical key, or the
        original *text* when nothing clears *threshold*; *score* is the
        match confidence (5 = date keyword, 8/10 = exact matches, else
        the LCS ratio); *processed_text* is the normalised input.
    """
    dictionary = vat_dictionary(header)
    processed_text = preprocessing(text)

    # Date-related fragments match by substring against known variants.
    # Guard against empty input: '' is a substring of everything, so the
    # original code mis-classified empty fields as the date key.
    for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'),
                       ('thangmonth', 'thang', 'month'),
                       ('namyear', 'nam', 'year')]:
        if processed_text and any(processed_text in txt for txt in candidates):
            processed_text = candidates[-1]
            return "Ngày, tháng, năm lập hóa đơn", 5, processed_text

    # Footer (non-header) fields get extra exact-match aliases.
    _dictionary = copy.deepcopy(dictionary)
    if not header:
        exact_dictionary = {
            'Số hóa đơn': ['sono', 'so'],
            'Mã số thuế người bán': ['mst'],
            'Tên người bán': ['kyboi'],
            'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate']
        }
        for k, aliases in exact_dictionary.items():
            _dictionary[k] = dictionary[k] + aliases

    for k in dictionary:
        # BUG FIX: was `k in ('Tên người bán')` -- a parenthesised string,
        # not a tuple, so it performed substring containment; use equality.
        if k == 'Tên người bán' and processed_text == "kyboi":
            return k, 8, processed_text

        if any(processed_text == key for key in _dictionary[k]):
            return k, 10, processed_text

    # Fuzzy fallback: best longest-common-subsequence ratio per key.
    scores = {k: 0.0 for k in dictionary}
    for k in dictionary:
        # BUG FIX: was `k in ("Ngày, ...")` -- string membership, now equality.
        if k == "Ngày, tháng, năm lập hóa đơn":
            continue

        scores[k] = max(
            longestCommonSubsequence(processed_text, key) / len(key)
            for key in dictionary[k]
        )

    key, score = max(scores.items(), key=lambda x: x[1])
    return key if score > threshold else text, score, processed_text
|
||||
|
||||
def ap_standardizer(text: str, threshold: float, header: bool):
    """Map a raw OCR field name onto a canonical AP-invoice key.

    Returns (key, score, processed_text): score is 10 for an exact alias
    match, otherwise the best LCS ratio over each key's aliases; the raw
    text itself is returned as the key when that ratio is below
    *threshold*.
    """
    dictionary = ap_dictionary(header)
    processed_text = preprocessing(text)

    # Extend the alias lists with context-specific exact aliases.
    _dictionary = copy.deepcopy(dictionary)
    if header:
        _dictionary['modelnumber'] = dictionary['modelnumber'] + ['sku', 'sn', 'imei']
        _dictionary['qty'] = dictionary['qty'] + ['qty']
    else:
        _dictionary['serial_number'] = dictionary['serial_number'] + ['sn']
        _dictionary['imei_number'] = dictionary['imei_number'] + ['imel']

    # An exact alias match wins outright.
    for k in dictionary:
        if processed_text in _dictionary[k]:
            return k, 10, processed_text

    # Otherwise score every key by its best LCS ratio.
    scores = {
        k: max(longestCommonSubsequence(processed_text, key) / len(key)
               for key in dictionary[k])
        for k in dictionary
    }
    key, score = max(scores.items(), key=lambda x: x[1])
    return key if score >= threshold else text, score, processed_text
|
||||
|
||||
|
||||
def convert_format_number(s: str) -> float:
    """Parse an OCR-extracted monetary string into a number.

    Handles common OCR artifacts: stray spaces, 'O'/'o' misread as '0',
    and mixed thousand/decimal separators ('.' for thousands, ',' for
    the decimal mark when both appear).

    Args:
        s: raw numeric string, e.g. "1.234,56" or "2.500.000".

    Returns:
        int when there is no fractional part, float otherwise.

    Raises:
        ValueError: when the cleaned string is not numeric.
    """
    s = s.replace(' ', '').replace('O', '0').replace('o', '0')
    # Drop an explicit zero fraction ("123,00" / "123.00" -> "123").
    if s.endswith(",00") or s.endswith(".00"):
        s = s[:-3]
    if all(delimiter in s for delimiter in (',', '.')):
        # '.' is the thousands separator, ',' the decimal mark.
        parts = s.replace('.', '').split(',')
        whole, frac = parts[0], parts[1]
        # Keep only the digits before the first '0' of the fraction
        # (OCR often appends spurious zeros). NOTE(review): this also
        # truncates e.g. "506" to "5" -- confirm intended.
        frac = frac.split('0')[0]
        if not frac:
            # BUG FIX: a fraction starting with '0' (e.g. ",05") collapsed
            # to '' and crashed on int(''); treat it as a whole number.
            return int(whole)
        return int(whole) + int(frac) / (10 ** len(frac))
    # Single (or no) separator kind: treat all delimiters as thousands marks.
    return int(s.replace(',', '').replace('.', ''))
|
||||
|
||||
|
||||
def post_process_for_item(item: dict) -> dict:
    """Fill in one missing numeric field of an invoice line item in place.

    Looks at quantity, unit price and pre-tax amount; when exactly one of
    them is missing (None or '0'), derives it from the other two via
    amount = quantity * price. Parsing/division errors are printed and
    the item is returned unchanged. The derived value is stored back as
    a string, matching the other fields.
    """
    qty_key, price_key, amount_key = 'Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế'
    missing = [k for k in (qty_key, price_key, amount_key) if item[k] in (None, '0')]
    if len(missing) != 1:
        return item
    target = missing[0]
    try:
        if target == qty_key and convert_format_number(item[price_key]) != 0:
            item[target] = str(round(
                convert_format_number(item[amount_key]) / convert_format_number(item[price_key])))
        elif target == price_key and convert_format_number(item[qty_key]) != 0:
            item[target] = str(
                convert_format_number(item[amount_key]) / convert_format_number(item[qty_key]))
        elif target == amount_key:
            item[target] = str(
                convert_format_number(item[qty_key]) * convert_format_number(item[price_key]))
    except Exception as e:
        print("Cannot post process this item with error:", e)
    return item
|
||||
|
||||
|
||||
def longestCommonSubsequence(text1: str, text2: str) -> int:
    """Return the length of the longest common subsequence of two strings.

    Classic O(m*n) dynamic programme: table[r+1][c+1] holds the LCS
    length of text1[:r+1] and text2[:c+1].
    """
    rows, cols = len(text1), len(text2)
    table = [[0] * (cols + 1) for _ in range(rows + 1)]
    for r in range(rows):
        for c in range(cols):
            if text1[r] == text2[c]:
                table[r + 1][c + 1] = table[r][c] + 1
            else:
                table[r + 1][c + 1] = max(table[r][c + 1], table[r + 1][c])
    return table[rows][cols]
|
||||
|
||||
|
||||
def longest_common_subsequence_with_idx(X, Y):
    """Length of the LCS of X and Y plus approximate index bounds in X.

    Dynamic programme identical to the classic LCS table, followed by a
    backtrack over the table.

    Returns:
        (lcs_len, left_idx, right_idx):
        - lcs_len: length of the longest common subsequence;
        - left_idx: value of ``i`` where the backtrack stopped -- an
          approximation of where the LCS starts in X;
        - right_idx: the larger of ``left_idx + lcs_len`` and the row at
          which the running maximum was last improved during the forward
          pass -- an approximation of where the LCS ends in X.
        NOTE(review): these indices come from heuristics on the DP
        table, not an exact reconstruction -- confirm they are precise
        enough for the callers.
    """
    m, n = len(X), len(Y)
    L = [[0 for i in range(n + 1)] for j in range(m + 1)]

    # Following steps build L[m+1][n+1] in bottom up fashion. Note
    # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1]
    right_idx = 0
    max_lcs = 0
    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
                # Remember the row where the best length was last improved;
                # used below as a candidate right boundary of the LCS in X.
                if L[i][j] > max_lcs:
                    max_lcs = L[i][j]
                    right_idx = i
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    # LCS length; relies on the loop variables still holding i == m, j == n.
    lcs = L[i][j]
    # Start from the right-most-bottom-most corner and
    # one by one store characters in lcs[]
    i = m
    j = n
    # right_idx = 0
    while i > 0 and j > 0:
        # If current character in X[] and Y are same, then
        # current character is part of LCS
        if X[i - 1] == Y[j - 1]:

            i -= 1
            j -= 1
        # If not same, then find the larger of two and
        # go in the direction of larger value
        elif L[i - 1][j] > L[i][j - 1]:
            # right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs
            i -= 1
        else:
            j -= 1
    return lcs, i, max(i + lcs, right_idx)
|
172
cope2n-ai-fi/common/crop_location.py
Executable file
172
cope2n-ai-fi/common/crop_location.py
Executable file
@ -0,0 +1,172 @@
|
||||
from mmdet.apis import inference_detector, init_detector
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib
|
||||
|
||||
def get_center(box):
    """Return the integer [x, y] centre of an (xmin, ymin, xmax, ymax) box."""
    xmin, ymin, xmax, ymax = box
    return [int((xmin + xmax) / 2), int((ymin + ymax) / 2)]
|
||||
|
||||
|
||||
def cal_euclidean_dist(p1, p2):
    """Euclidean (L2) distance between two points given as numpy arrays."""
    diff = p1 - p2
    return np.linalg.norm(diff)
|
||||
|
||||
|
||||
def bbox_to_four_poinst(bbox):
    """Convert one (xmin, ymin, xmax, ymax) bounding box into its four
    corner points, ordered top-left, top-right, bottom-right, bottom-left.

    Args:
        bbox: flat (xmin, ymin, xmax, ymax) box.

    Returns:
        list of four [x, y] corner points.
    """
    xmin, ymin, xmax, ymax = bbox
    return [
        [xmin, ymin],
        [xmax, ymin],
        [xmax, ymax],
        [xmin, ymax],
    ]
|
||||
|
||||
|
||||
def find_closest_point(src_point, point_list):
    """Return the index of the point in *point_list* closest to *src_point*.

    Args:
        src_point: point as [x, y] (array-like).
        point_list: list of [x, y] points.

    Returns:
        Index of the nearest point (L2 distance).
    """
    point_list = np.array(point_list)
    # BUG FIX: np.array(<generator>) wraps the generator in a 0-d object
    # array instead of evaluating the distances, so argmin never compared
    # real distances; materialise them as a list first.
    dist_list = np.array(
        [cal_euclidean_dist(src_point, target_point) for target_point in point_list]
    )

    index_closest_point = np.argmin(dist_list)
    return index_closest_point
|
||||
|
||||
|
||||
def crop_align_card(img_src, corner_box_list):
    """Deskew a card given either four corner boxes or one flat bbox.

    Args:
        img_src: source image.
        corner_box_list: either a list of four corner boxes (each box's
            centre becomes a quad vertex) or a single flat
            (xmin, ymin, xmax, ymax) box.

    Returns:
        The perspective-corrected card image.
    """
    img = img_src.copy()
    if isinstance(corner_box_list[0], list):
        # Four corner boxes: use each box centre as a quad vertex.
        quad = [get_center(box) for box in corner_box_list]
    else:
        # One flat bbox: take its own four corners.
        xmin, ymin, xmax, ymax = corner_box_list
        quad = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
    return dewarp(img, quad)
|
||||
|
||||
|
||||
def dewarp(image, poinst):
    """Perspective-rectify *image* so the quadrilateral *poinst* becomes
    an axis-aligned rectangle.

    Args:
        image: source image (np.ndarray).
        poinst: four corners ordered top-left, top-right, bottom-right,
            bottom-left; list or float32 array.

    Returns:
        The warped (cropped and straightened) image.
    """
    if isinstance(poinst, list):
        poinst = np.array(poinst, dtype="float32")
    tl, tr, br, bl = poinst

    # Output size: the longer edge of each opposing pair.
    bottom_w = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    top_w = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    out_w = max(int(bottom_w), int(top_w))
    right_h = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    left_h = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    out_h = max(int(right_h), int(left_h))

    target = np.array(
        [[0, 0], [out_w - 1, 0], [out_w - 1, out_h - 1], [0, out_h - 1]],
        dtype="float32",
    )

    matrix = cv2.getPerspectiveTransform(poinst, target)
    return cv2.warpPerspective(image, matrix, (out_w, out_h))
|
||||
|
||||
|
||||
class MdetPredictor:
    """Thin wrapper around an MMDetection detector.

    NOTE(review): relies on the mmdet 2.x result format where
    inference_detector returns one (N, 5) array of
    [x1, y1, x2, y2, score] rows per class -- confirm against the
    installed mmdet version.
    """

    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        # Build the detector from its config file and checkpoint weights.
        self.model = init_detector(config, checkpoint, device=device)
        # Class-index -> class-name lookup exposed by the model.
        self.class_names = self.model.CLASSES

    def infer(self, image, threshold=0.2):
        """Run detection on *image*, keeping boxes scoring >= *threshold*.

        Returns:
            (boxes, labels): boxes as [x1, y1, x2, y2] lists and their
            class names, index-aligned.
        """
        bbox_result = inference_detector(self.model, image)

        # Flatten the per-class arrays into one (N, 5) array ...
        bboxes = np.vstack(bbox_result)
        # ... and build a parallel array of class indices, repeating each
        # class id once per box of that class.
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)

        res_bboxes = []
        res_labels = []
        for idx, box in enumerate(bboxes):
            # Last column is the detection confidence.
            score = box[-1]
            if score >= threshold:
                label = labels[idx]
                res_bboxes.append(box.tolist()[:4])
                res_labels.append(self.class_names[label])

        return res_bboxes, res_labels
|
||||
|
||||
|
||||
class ImageTransformer:
    """Detect the four corners of a card in an image and crop/deskew it."""

    def __init__(self, config: str, checkpoint: str, device: str = "cpu"):
        # Corner/card detector backed by MdetPredictor.
        self.corner_detect_model = MdetPredictor(config, checkpoint, device)

    def __call__(self, image, threshold=0.2):
        """Run corner detection on a BGR image and return the rectified
        card image, or None when no card is detected."""
        corner_result = self.corner_detect_model.infer(image)
        corners_dict = self.__extract_corners(corner_result)
        return self.__crop_image_based_on_corners(image, corners_dict)

    def __extract_corners(self, corner_result):
        """Map each detected label to its first bounding box, as ints."""
        bboxes, labels = corner_result
        int_boxes = [[int(coord) for coord in box] for box in bboxes]
        # print(output)
        return {label: int_boxes[labels.index(label)] for label in labels}

    def __crop_image_based_on_corners(self, image, corners_dict):
        """Crop using the four corner boxes when all five detections
        (card + 4 corners) are present, otherwise fall back to the coarse
        'card' box; return None when no card was detected at all."""
        if "card" not in corners_dict.keys():
            return None
        if len(corners_dict.keys()) == 5:
            points = [
                corners_dict["top_left"],
                corners_dict["top_right"],
                corners_dict["bottom_right"],
                corners_dict["bottom_left"],
            ]
        else:
            points = corners_dict["card"]
        return crop_align_card(image, points)
|
||||
|
||||
|
||||
def crop_location(image_url):
    """Download an image by URL and try to crop/deskew the card inside it.

    Returns the rectified card image when corner detection succeeds,
    otherwise the original downloaded image.
    """
    transform_module = ImageTransformer(
        config="./models/Kie_AHung/yolox_s_8x8_300e_idcard5_coco.py",
        checkpoint="./models/Kie_AHung/best_bbox_mAP_epoch_100.pth",
        device="cuda:0",
    )
    # Fetch the raw bytes and decode them with OpenCV (keeps channels as-is).
    response = urllib.request.urlopen(image_url)
    buffer = np.asarray(bytearray(response.read()), dtype=np.uint8)
    img = cv2.imdecode(buffer, -1)
    card_image = transform_module(img)
    return card_image if card_image is not None else img
|
622
cope2n-ai-fi/common/dates_gplx.json
Executable file
622
cope2n-ai-fi/common/dates_gplx.json
Executable file
@ -0,0 +1,622 @@
|
||||
{
|
||||
"20221027_154840.json": {
|
||||
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
|
||||
"pred": "ngày date 05 tháng month 04 năm/year 2016"
|
||||
},
|
||||
"7ba0f6b2f2ff34a16dee18.json": {
|
||||
"label": "ngày /date 01 tháng /month 04 năm /year 2022",
|
||||
"pred": "ngày date 01 tháng Amount 04 năm year 2022"
|
||||
},
|
||||
"20221027_155646.json": {
|
||||
"label": "ngày /date 01 tháng /month 04 năm/year✪2022",
|
||||
"pred": "ngày date or tháng month 04 năm/year2022"
|
||||
},
|
||||
"799c679b63d6a588fcc717.json": {
|
||||
"label": "ngày /date 30 tháng /month 07 năm/year 2015",
|
||||
"pred": "ngày lone 30 thể 2 work 07 ndmyvar 2015"
|
||||
},
|
||||
"150094748_1728953773944481_6269983404281027305_n.json": {
|
||||
"label": "ngày/da e 15✪tháng # #th 04 năm /year 2020",
|
||||
"pred": "ngày/da - 15tháng % with 04 năm (year 2020"
|
||||
},
|
||||
"20221027_155711.json": {
|
||||
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
|
||||
"pred": "ngày date 16 tháng month 06 năm 'year 2020"
|
||||
},
|
||||
"20221027_155638.json": {
|
||||
"label": "ngày /date 05 tháng /month 04 năm/year✪2016",
|
||||
"pred": "ngày /date 05 tháng month 04 năm/year2016"
|
||||
},
|
||||
"20221027_155754.json": {
|
||||
"label": "ngày/date 19 tháng /month 03 năm/year 2018",
|
||||
"pred": "ngày/date 19 tháng month 03 năm/year 2018"
|
||||
},
|
||||
"201417393_4052045311588809_501501345369021923_n.json": {
|
||||
"label": "ngày /date 27 tháng /month 07 năm/year 2020",
|
||||
"pred": "ngày date 27 tháng month 07 năm/year 2020"
|
||||
},
|
||||
"c50de8e0e2ad24f37dbc28.json": {
|
||||
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày date 24 tháng month 05 năm/year 2016"
|
||||
},
|
||||
"178033599_360318769019204_7688552975060615249_n.json": {
|
||||
"label": "ngày /date 13 tháng month 08 năm /year 2020",
|
||||
"pred": "ngày date 73 tháng month 08 năm year 2020"
|
||||
},
|
||||
"20221027_154755.json": {
|
||||
"label": "ngày /date 22 tháng /month 04 năm/year 2022",
|
||||
"pred": "ngày date 22 tháng month 04 năm/year 2022"
|
||||
},
|
||||
"20221027_154701.json": {
|
||||
"label": "ngày/date 27 tháng /month 10 năm/year 2014",
|
||||
"pred": "ngày/date 27 tháng month 10 năm/year 2014"
|
||||
},
|
||||
"20221027_154617.json": {
|
||||
"label": "ngày/date 10 tháng /month 03 năm/year 2022",
|
||||
"pred": "ngày/date 10 than Wmonth 03 năm/year 2022"
|
||||
},
|
||||
"20221027_155429.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
|
||||
"pred": "ngày date 29 tháng month 10 năm/year 2020"
|
||||
},
|
||||
"38949066_1826046370807755_4672633932229902336_n.json": {
|
||||
"label": "ngày /date 03 tháng /month 07 năm /year 2017",
|
||||
"pred": "ngày date 03 tháng month 0Z năm year 2017"
|
||||
},
|
||||
"174353602_3093780914182491_6316781286680210887_n.json": {
|
||||
"label": "ngày /date 09 tháng /month 09 năm/year 2019",
|
||||
"pred": "ngày date 09 tháng month 09 năm/voce 2019"
|
||||
},
|
||||
"135575662_779633739578177_65454671165731184_n.json": {
|
||||
"label": "ngày /date 02 tháng /month 10 năm /year 2016",
|
||||
"pred": "ngày 'date 07 tháng month 10 năm Sear 2016"
|
||||
},
|
||||
"198291731_4210067705720978_7154894655460708366_n.json": {
|
||||
"label": "ngày /date 05 tháng month 05 năm/year 2014",
|
||||
"pred": "ngày date 05 tháng month 05 ndmivar 2014"
|
||||
},
|
||||
"20221027_155325.json": {
|
||||
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
|
||||
"pred": "ngày date 09 tháng month 07 năm 'year 2019"
|
||||
},
|
||||
"20221027_155526.json": {
|
||||
"label": "ngày/date 14 tháng /month 01 năm/year✪2019",
|
||||
"pred": "ngày/date 14 tháng month 01 năm/year2019"
|
||||
},
|
||||
"20221027_155759.json": {
|
||||
"label": "ngày/date 24 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày/date 24 tháng month 05 năm/year 2016"
|
||||
},
|
||||
"f40789388d754b2b126416.json": {
|
||||
"label": "ngày /date 10 tháng/month 09 năm /year 2020",
|
||||
"pred": "ngày /date 10thing month 09 năm year 2020"
|
||||
},
|
||||
"20221027_155453.json": {
|
||||
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
|
||||
"pred": "ngày date 27 tháng month 10 năm/year 2014"
|
||||
},
|
||||
"88bb0caa03e7c5b99cf64.json": {
|
||||
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
|
||||
"pred": "ngày dg h 6AR ag M ath 06 năm hear 2020"
|
||||
},
|
||||
"20221027_154408.json": {
|
||||
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
|
||||
"pred": "ngày date 28 tháng month 05 năm year 2019"
|
||||
},
|
||||
"200064855_1976213642526776_4676665588314498194_n.json": {
|
||||
"label": "ngày /date 04 tháng /month 01 năm /year 2018",
|
||||
"pred": "ngày date 04 tháng /month 01 năm year 2018"
|
||||
},
|
||||
"61c259385775912bc86410.json": {
|
||||
"label": "ngày /date 20 tháng/month 09 năm/year 2017",
|
||||
"pred": "ngày /de l 20 th and 4 onth 09 năm/year 2017"
|
||||
},
|
||||
"20221027_155630.json": {
|
||||
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
|
||||
"pred": "ngày date 10 tháng month 09 năm year 2020"
|
||||
},
|
||||
"20221027_155342.json": {
|
||||
"label": "ngày/date 19 tháng /month 03 năm/year 2018",
|
||||
"pred": "ngày/date 19 tháng month 03 năm/year 2018"
|
||||
},
|
||||
"165824171_1804443136392377_8891768953420682785_n.json": {
|
||||
"label": "ngày /date 03 tháng month 04 năm /year 2018",
|
||||
"pred": "ngày date 03 tháng month 04 năm year 2018"
|
||||
},
|
||||
"107005164_1003706713393058_3039921363490738151_n.json": {
|
||||
"label": "ngày /date 1 8 tháng month 11 năm /year 2019",
|
||||
"pred": "ngày 'date I 8 tháng month 11 năm vear 2019"
|
||||
},
|
||||
"20221027_155658.json": {
|
||||
"label": "ngày/date 14 tháng /month 01 năm/year 2019",
|
||||
"pred": "ngày/date 14 tháng month 01 năm/year 2019"
|
||||
},
|
||||
"20221027_154654.json": {
|
||||
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
|
||||
"pred": "ngày date 08 tháng month 12 năm/year 2015"
|
||||
},
|
||||
"f2badabfd1f217ac4ee327.json": {
|
||||
"label": "ngày /date 19 tháng /month 03 năm/year 2018",
|
||||
"pred": "ngày /date 19 their almonth 03 năm/year 2018"
|
||||
},
|
||||
"20221027_155501.json": {
|
||||
"label": "ngày/date 11 tháng /month 06 năm/year 2014",
|
||||
"pred": "ngày/date 1 l tháng month 06 năm/year. 2014"
|
||||
},
|
||||
"74179950_806804509749947_7322741604127604736_n.json": {
|
||||
"label": "ngày /date 11 tháng /month 06 năm /year 2019",
|
||||
"pred": "ngày date 11 tháng month 06 năm \\/gar 2019"
|
||||
},
|
||||
"197892118_1197184050744529_3186157591590303981_n.json": {
|
||||
"label": "ngày /date 15 tháng /month 05 năm year 2015",
|
||||
"pred": "ngày date IS tháng month 05 năm year 2015"
|
||||
},
|
||||
"8265ab8ba0c666983fd722.json": {
|
||||
"label": "ngày /date 08 tháng /month 12 năm /year 2015",
|
||||
"pred": "ngày /date 08 tháng month 12 năm year 2015"
|
||||
},
|
||||
"185421881_360318465685901_6968676669094190049_n.json": {
|
||||
"label": "ngày /date 13 tháng month 08 năm /year 2020",
|
||||
"pred": "ngày date 73 tháng month 08 năm year 2020"
|
||||
},
|
||||
"20221027_155142.json": {
|
||||
"label": "ngày/date 30 tháng/month 07 năm/year 2015",
|
||||
"pred": "ngày/date 30 tháng/month 07 năm/year 2015"
|
||||
},
|
||||
"39441751_326131871285503_8401816317220356096_n.json": {
|
||||
"label": "ngày /date 03 tháng /month 08 năm/year✪2017",
|
||||
"pred": "ngày 'date 03 tháng month 08 năm✪year✪2017"
|
||||
},
|
||||
"20221027_154912.json": {
|
||||
"label": "ngày/date 30 tháng /month 07 năm/year 2015",
|
||||
"pred": "ngày/date 30 tháng month 07 năm/year 2015"
|
||||
},
|
||||
"199235658_1994072707411372_221206969024509405_n.json": {
|
||||
"label": "ngày /date 05 tháng month 06 năm year 2020",
|
||||
"pred": "ngày date 05 tháng month 06 năm year 2020"
|
||||
},
|
||||
"168434217_1698404580364474_5439600436729489777_n.json": {
|
||||
"label": "ngày /date 05 tháng /month 11 năm/year 2020",
|
||||
"pred": "ngày date 05 tháng month 11 ndmhear 2020"
|
||||
},
|
||||
"20221027_154805.json": {
|
||||
"label": "ngày /date 14 tháng /month 01 năm/year 2019",
|
||||
"pred": "ngày date 14 tháng month 01 năm/year 2019"
|
||||
},
|
||||
"e043761b7256b408ed4712.json": {
|
||||
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
|
||||
"pred": "ngày date 27 tháng month 10 năm/year 2014"
|
||||
},
|
||||
"20221027_155401.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year✪2013",
|
||||
"pred": "ngày date 29 tháng /uponth 10 năm/year2013"
|
||||
},
|
||||
"20221027_155316.json": {
|
||||
"label": "ngày/date 20 tháng /month 09 năm/year 2017",
|
||||
"pred": "ngày/date 20 tháng month 09 năm/year 2017"
|
||||
},
|
||||
"193926583_1626674417674644_309549447428666454_n.json": {
|
||||
"label": "ngày /date 29 tháng /month 07 năm/year 2020",
|
||||
"pred": "ngày (date 29 tháng /month 07 năm✪year 2020"
|
||||
},
|
||||
"20221027_154823.json": {
|
||||
"label": "ngày /date 01 tháng /month 04 năm /year 2022",
|
||||
"pred": "ngày date or tháng month 04 năm year 2022"
|
||||
},
|
||||
"138942425_242694630705291_5683978028617807264_n.json": {
|
||||
"label": "ngày /date 30 tháng/month 01 năm/year✪2013",
|
||||
"pred": "ngày date 30 tháng/month 01 năm/year2013"
|
||||
},
|
||||
"20221027_155334.json": {
|
||||
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày date 24 tháng month 05 năm/year 2016"
|
||||
},
|
||||
"187421917_1668044206721318_779901369147309116_n.json": {
|
||||
"label": "ngày date 04 # # 05 # 2022",
|
||||
"pred": "ngày date 4 them Please 05 advise 2021"
|
||||
},
|
||||
"20221027_154716.json": {
|
||||
"label": "ngày /date 11 tháng /month 06 năm/year 2014",
|
||||
"pred": "ngày date 11 tháng month 06 năm/year 2014"
|
||||
},
|
||||
"5b1116c61d8bdbd5829a23.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
|
||||
"pred": "ngày date 29 tháng /month 10 năm/year 2020"
|
||||
},
|
||||
"20221027_154850.json": {
|
||||
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
|
||||
"pred": "ngày date 10 tháng month 09 năm year 2020"
|
||||
},
|
||||
"1489864e88034e5d17122.json": {
|
||||
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
|
||||
"pred": "ngày date 08 hàng month 12 năm/year 2015"
|
||||
},
|
||||
"199990418_1443812262638998_8173300652488821384_n.json": {
|
||||
"label": "ngày/date 10✪tháng /month 03 năm/year 2021",
|
||||
"pred": "ngày/date 10.50m Noun th 03 ndm/year 2021"
|
||||
},
|
||||
"139073668_833869180522062_7998364448555134241_n.json": {
|
||||
"label": "ngày/date 26 tháng /month 07 năm/year 2017",
|
||||
"pred": "ngày/date 26 tháng /month 07 năm/year2 2017"
|
||||
},
|
||||
"20221027_154528.json": {
|
||||
"label": "ngày date 19 tháng /month 03 năm/year 2018",
|
||||
"pred": "ngày date 19 tháng month 03 năm/year 2018"
|
||||
},
|
||||
"20221027_154423.json": {
|
||||
"label": "ngày/date 20 tháng /month 09 năm/year 2017",
|
||||
"pred": "ngày/date 20 tháng month 09 năm/yew 2017"
|
||||
},
|
||||
"20221027_154722.json": {
|
||||
"label": "ngày /date 11 tháng /month 06 năm/year 2014",
|
||||
"pred": "ngày date 1 I tháng month 06 năm/year 2014"
|
||||
},
|
||||
"144003628_4199201026774245_6202264670231940239_n.json": {
|
||||
"label": "ngày date 07 tháng month 03 ######## ### #",
|
||||
"pred": "ngày date0 thôn Normmm 05 adortear 7"
|
||||
},
|
||||
"20221027_154434.json": {
|
||||
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
|
||||
"pred": "ngày date 09 tháng month 07 năm 'year 2019"
|
||||
},
|
||||
"20221027_154630.json": {
|
||||
"label": "ngày/date 29 tháng /month 10 năm/year 2020",
|
||||
"pred": "ngày/date 29 tháng month 10 năm/year 2020"
|
||||
},
|
||||
"20221027_155552.json": {
|
||||
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
|
||||
"pred": "ngày date 10 tháng month 09 năm 'year 2020"
|
||||
},
|
||||
"131114177_1027132027767399_411142190418396877_n.json": {
|
||||
"label": "ngày /date 17 tháng /month 01 năm/year✪2018",
|
||||
"pred": "ngày date 17 tháng /month 01 năm/year2018"
|
||||
},
|
||||
"164703297_455738728964651_5260332814645460915_n.json": {
|
||||
"label": "ngày date 03 tháng month 11 năm/year 2020",
|
||||
"pred": "ngày date 03 tháng month I I năm/year 2020"
|
||||
},
|
||||
"20221027_155926.json": {
|
||||
"label": "ngày/date 20 tháng /month 09 năm /year 2017",
|
||||
"pred": "ngày/date 20 tháng month 09 năm year 2017"
|
||||
},
|
||||
"20221027_154832.json": {
|
||||
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
|
||||
"pred": "ngày date 05 tháng month 04 năm/year 2016 6"
|
||||
},
|
||||
"dd09b9b6b2fb74a52dea24.json": {
|
||||
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
|
||||
"pred": "123) d 2010 Way Yount 06 năm year 2020"
|
||||
},
|
||||
"20221027_154646.json": {
|
||||
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
|
||||
"pred": "ngày date 08 2 tháng month 12 năm/year2 2015"
|
||||
},
|
||||
"180534342_1213803569050037_4381710158357942629_n.json": {
|
||||
"label": "ngày /date 20 tháng /month 10 năm/year 21",
|
||||
"pred": "ngày date 20 tháng month 10 năm/year 2"
|
||||
},
|
||||
"20221027_155443.json": {
|
||||
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
|
||||
"pred": "ngày date 08 tháng month 12 năm/year 2015"
|
||||
},
|
||||
"25a9717c7b31bd6fe42029.json": {
|
||||
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
|
||||
"pred": "ngày date 09 tháng month 07 năm hear 2019"
|
||||
},
|
||||
"a0a8281f2652e00cb9435.json": {
|
||||
"label": "ngày/date 10 tháng /month 03 năm/year 2022",
|
||||
"pred": "ngày/date 10 that abroath 03 năm/year 2022"
|
||||
},
|
||||
"48793dfd37b0f1eea8a130.json": {
|
||||
"label": "tháng/month 09 năm/year\n ngày/date 20 tháng/month\n 163",
|
||||
"pred": "ngày/date 20 chăn knowin năm/ya\n 163"
|
||||
},
|
||||
"20221027_154730.json": {
|
||||
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
|
||||
"pred": "ngày date to tháng month 06 năm year 2020"
|
||||
},
|
||||
"174102242_893537741194123_1381062036549019974_n.json": {
|
||||
"label": "ngày /date 11 tháng /month 11 năm/year 2019",
|
||||
"pred": "ngày date 17 tháng Thuonth 11 năm/year 2019"
|
||||
},
|
||||
"20221027_154541.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year✪2013",
|
||||
"pred": "ngày date 29 tháng month 10 năm/year2013"
|
||||
},
|
||||
"20221027_155939.json": {
|
||||
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
|
||||
"pred": "ngày date 28 tháng month 05 năm 'year 2019"
|
||||
},
|
||||
"20221027_154452.json": {
|
||||
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày date 24 tháng month 05 năm/year 2016"
|
||||
},
|
||||
"104353445_990772771353119_6131582365614146594_n.json": {
|
||||
"label": "ngày date 18 tháng /month 03 năm /year 2019",
|
||||
"pred": "ngày date If tháng month 03 năm year 2019"
|
||||
},
|
||||
"20221027_155418.json": {
|
||||
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
|
||||
"pred": "ngày/date 10than dmonth 03 năm/year 2022"
|
||||
},
|
||||
"20221027_155916.json": {
|
||||
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
|
||||
"pred": "ngày date 09 tháng month 07 năm 'year 2019"
|
||||
},
|
||||
"195887607_545276056640128_7265052621888807786_n.json": {
|
||||
"label": "ngày /date 12 tháng /month 01 năm /year 2017",
|
||||
"pred": "ngày /dote 12 tháng month 01 năm hear 201 7"
|
||||
},
|
||||
"168303942_358282189193092_4968412916165104911_n.json": {
|
||||
"label": "ngày/date 19 tháng month 03 năm year 2015",
|
||||
"pred": "ngày/date 19 tháng month 03 năm your 1019"
|
||||
},
|
||||
"6a70d15bd51613484a0713.json": {
|
||||
"label": "ngày /date 22 tháng /month 04 năm /year 2022",
|
||||
"pred": "ngày date n háu Z with 04 năm year 2022"
|
||||
},
|
||||
"20221027_155302.json": {
|
||||
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
|
||||
"pred": "ngày date 28 tháng month 05 năm 'year 2019"
|
||||
},
|
||||
"20221027_154815.json": {
|
||||
"label": "ngày /date 01 tháng /month 04 năm/year 2022",
|
||||
"pred": "ngày date or tháng month 04 năm/year 2022"
|
||||
},
|
||||
"c87b81298e64483a11757.json": {
|
||||
"label": "ngày /date 19 tháng /month 03 năm/year 201",
|
||||
"pred": "ngày date 19 thớ almonth 03 năm/year 201"
|
||||
},
|
||||
"20221027_155620.json": {
|
||||
"label": "ngày/date 30 tháng /month 07 năm/year 2015",
|
||||
"pred": "ngày/date 30 tháng month 07 năm/year 2015"
|
||||
},
|
||||
"20221027_155511.json": {
|
||||
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
|
||||
"pred": "ngày date to tháng month 06 năm 'year 2020"
|
||||
},
|
||||
"745c5d6b52269478cd379.json": {
|
||||
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
|
||||
"pred": "ngày date 09 thân g worth 07 năm hear 2019"
|
||||
},
|
||||
"9329511f5552930cca4315.json": {
|
||||
"label": "ngày/date 11 tháng /month 06 năm/year 2014",
|
||||
"pred": "ngày/date 11 tháng month 06 năm/year 2014"
|
||||
},
|
||||
"158882925_262850065433814_5526034984745996835_n.json": {
|
||||
"label": "ngày /date 16 tháng month 12 năm/year 2015",
|
||||
"pred": "ngày date 16 tháng month 12 năm/year 2015"
|
||||
},
|
||||
"140687263_3755683421155059_7637736837539526203_n.json": {
|
||||
"label": "năm/year 2017\n ngày /date # # # # 08 năm/year",
|
||||
"pred": "năm 2017\n ngày 'date 422 ins 1.1 ma 08 năm"
|
||||
},
|
||||
"20221027_155717.json": {
|
||||
"label": "ngày/date 11 tháng /month 06 năm/year 2014",
|
||||
"pred": "ngày/date 11 tháng month 06 năm/year 2014"
|
||||
},
|
||||
"148919455_2877898732481724_2579276238538203411_n.json": {
|
||||
"label": "ngày /date 05 tháng /month 09 năm/year✪2018",
|
||||
"pred": "ngày date 05 thán g /month 09 năm/year2018"
|
||||
},
|
||||
"c8b0dc9cd8d11e8f47c014.json": {
|
||||
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
|
||||
"pred": "ngày /date 0.5 tháng g/month 04 năm/year 2016"
|
||||
},
|
||||
"175913976_2827333254262221_2873818403698028020_n.json": {
|
||||
"label": "ngày date 05 tháng month 01 năm year 2018",
|
||||
"pred": "ngày dan us thông month 01 năm year 2018"
|
||||
},
|
||||
"20221027_155029.json": {
|
||||
"label": "ngày/date 30 tháng/month\n 07 năm/year 2015",
|
||||
"pred": "ngày/date\n năm/year 2015"
|
||||
},
|
||||
"196165776_1160925321042008_58817602967276351_n.json": {
|
||||
"label": "ngày date 01 tháng month 09 năm year 2020",
|
||||
"pred": "ngày date or tháng month 09 năm year 2020"
|
||||
},
|
||||
"162820484_451115505943597_8326385834717580925_n.json": {
|
||||
"label": "ngày /date 27 tháng /month 08 thường 2014",
|
||||
"pred": "ngày/ date 27 tháng furonth 08 năm/year 2014"
|
||||
},
|
||||
"41594446_1316256461838413_661251515624718336_n.json": {
|
||||
"label": "ngày /date 21 tháng /month 06 năm/year✪2017",
|
||||
"pred": "ngày /date 27 tháng month 06 năm/year201 7"
|
||||
},
|
||||
"20221027_155728.json": {
|
||||
"label": "ngày /date 08 tháng /month 12 năm/year 2015",
|
||||
"pred": "ngày date 08 tháng month 12 năm/year 2015"
|
||||
},
|
||||
"142090201_2826170774286852_1233962294093312865_n.json": {
|
||||
"label": "ngày /date 29 tháng month 07 năm /year 2019",
|
||||
"pred": "ngày date 29 tháng month 07 năm year 2019"
|
||||
},
|
||||
"0dd2fe8bf5c633986ad726.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year✪2013",
|
||||
"pred": "march date 29 tháng month 10✪năm✪volur✪2013"
|
||||
},
|
||||
"190919701_1913643422140446_6855763478065892825_n.json": {
|
||||
"label": "ngày date 24 tháng /month 07 năm/year 2017",
|
||||
"pred": "ngày /date 24 tháng month 07 năm/year 2017"
|
||||
},
|
||||
"147585615_2757487791230682_5515346433540820516_n.json": {
|
||||
"label": "ngày /date # 9✪tháng /month 05 năm /year 2019",
|
||||
"pred": "ngày date 2 Danate month 05 năm year 2019"
|
||||
},
|
||||
"5fcae09ee4d3228d7bc211.json": {
|
||||
"label": "ngày /date 14 tháng /month 01 năm/year 2019",
|
||||
"pred": "ngày 'date 14 tháng month 0.1 năm/year 2019"
|
||||
},
|
||||
"20221027_154417.json": {
|
||||
"label": "ngày/date 20 tháng /month 09 năm/year 2017",
|
||||
"pred": "ngày/date 20 thân ghi onth 09 năm/year 2017"
|
||||
},
|
||||
"20221027_154802.json": {
|
||||
"label": "ngày /date 14 tháng /month 01 năm/year 2019",
|
||||
"pred": "ngày date 14 tháng month 01 năm/year 2019"
|
||||
},
|
||||
"20221027_154749.json": {
|
||||
"label": "ngày /date 22 tháng /month 04 năm /year 2022",
|
||||
"pred": "ngày date 22 tháng month 04 năm year 2022"
|
||||
},
|
||||
"20221027_154900.json": {
|
||||
"label": "ngày /date 10 tháng /month 09 năm /year 2020",
|
||||
"pred": "ngày date 10 tháng month 09 năm Year 2020"
|
||||
},
|
||||
"4b2b35453e08f856a11925.json": {
|
||||
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
|
||||
"pred": "ngày/date 1 in 03 năm/year 2022"
|
||||
},
|
||||
"20221027_155734.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
|
||||
"pred": "ngày date 29 tháng month 10 năm/year 2020"
|
||||
},
|
||||
"198876248_2931797967062357_4287721016641237281_n.json": {
|
||||
"label": "ngày /date 29 áng /month 10 năm /year 2019",
|
||||
"pred": "ngày Warm 204 đón Quanth 10 năm vear 2019"
|
||||
},
|
||||
"20221027_154613.json": {
|
||||
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
|
||||
"pred": "ngày/date 10than Imonth 03 năm/year 2022"
|
||||
},
|
||||
"20221027_154600.json": {
|
||||
"label": "ngày/date 29 tháng /month 10 năm/year✪2013",
|
||||
"pred": "ngày/date 29 tháng month 10 năm/year2013"
|
||||
},
|
||||
"191389634_910736173104716_4923402486196996972_n.json": {
|
||||
"label": "ngày /date 24 tháng /month 02 năm/year 2021",
|
||||
"pred": "ngày date 24 tháng month 02 năm/year 2021"
|
||||
},
|
||||
"20221027_155723.json": {
|
||||
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
|
||||
"pred": "ngày /date 27 tháng month 10 năm/year 2014"
|
||||
},
|
||||
"184606042_1586323798376373_2179113485447088932_n.json": {
|
||||
"label": "ngày /date 29 tháng /month 07 năm/year 2020",
|
||||
"pred": "ngày (date 29 tháng /month 07 năm✪year 2020"
|
||||
},
|
||||
"6fb460d06a9dacc3f58c31.json": {
|
||||
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
|
||||
"pred": "ngày date 28 tháng month 05 năm year 2019"
|
||||
},
|
||||
"7e810ce502a8c4f69db93.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm /year 2020",
|
||||
"pred": "ngày de it 29 thing month 10 năm Year 2020"
|
||||
},
|
||||
"20221027_154400.json": {
|
||||
"label": "ngày /date 28 tháng /month 05 năm /year 2019",
|
||||
"pred": "ngày date 28 tháng month 05 năm year 2019"
|
||||
},
|
||||
"139579296_107198731349013_7325819456715999953_n.json": {
|
||||
"label": "ngày /date 10tháng month 05 năm /year 2018",
|
||||
"pred": "ngày date 1 month 05 năm hear 2018"
|
||||
},
|
||||
"20221027_155748.json": {
|
||||
"label": "ngày/date 29 tháng/month 10 năm/year✪2013",
|
||||
"pred": "ngày/date 29 tháng/month 10 năm/year2013"
|
||||
},
|
||||
"164359233_2788848161366629_6843431895499380423_n.json": {
|
||||
"label": "ngày /date 25 tháng /month 12 năm/year 2015",
|
||||
"pred": "ngày dute 25 tháng month 12 năm/year 2015"
|
||||
},
|
||||
"962650ff5eb298ecc1a36.json": {
|
||||
"label": "ngày /date 29 tháng/month 10 năm/year 2013",
|
||||
"pred": "10 năm/year 2013\n ngày the 29thanginonth\n TL"
|
||||
},
|
||||
"20221027_155600.json": {
|
||||
"label": "ngày/date 30 tháng /month 07 năm/year 2015",
|
||||
"pred": "ngày/date 30 tháng month 07 năm/year 2015"
|
||||
},
|
||||
"951970367f7bb925e06a1.json": {
|
||||
"label": "ngày /date 28 tháng /month 05 năm /year 2019 -",
|
||||
"pred": "ngày late 28 hàng worth 05 năm hear 2019"
|
||||
},
|
||||
"20221027_154627.json": {
|
||||
"label": "ngày /date 29 tháng /month 10 năm/year 2020",
|
||||
"pred": "ngày date 29 tháng month 10 năm/year. 2020"
|
||||
},
|
||||
"20221027_155535.json": {
|
||||
"label": "ngày /date 01 tháng /month 04 năm /year 2022",
|
||||
"pred": "ngày date 01 tháng month 04 năm 'year 2022"
|
||||
},
|
||||
"106402928_1000018507095212_5438034148254460378_n.json": {
|
||||
"label": "ngày date 04 tháng month 09 năm /year 2019",
|
||||
"pred": "ngày dan 04 tháng month 09 năm Year 2019"
|
||||
},
|
||||
"20221027_154458.json": {
|
||||
"label": "ngày /date 24 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày date 24 tháng month 03 năm/year 2016"
|
||||
},
|
||||
"190841312_3057702594458479_8551202571498845435_n.json": {
|
||||
"label": "ngày /date 14 tháng month 12 năm /year 2018",
|
||||
"pred": "ngày date 14 tháng month 12 năm year 2018"
|
||||
},
|
||||
"20221027_154907.json": {
|
||||
"label": "ngày/date 30 tháng /month 07 năm/year✪2015",
|
||||
"pred": "ngày/date 30 tháng month 07 năm/year2015"
|
||||
},
|
||||
"136098726_3413968628702123_4090292519699106839_n.json": {
|
||||
"label": "ngày /date 20tháng /month 11 năm /year 2020",
|
||||
"pred": "ngày date 70tháng month 11 năm year 2020"
|
||||
},
|
||||
"20221027_154440.json": {
|
||||
"label": "ngày /date 09 tháng /month 07 năm /year 2019",
|
||||
"pred": "ngày date 09 tháng mc each 07 năm year 2019"
|
||||
},
|
||||
"07966c1b6256a408fd478.json": {
|
||||
"label": "ngày/date 24 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày/d te 24 thán day anth 05 năm/year 2016"
|
||||
},
|
||||
"20221027_155740.json": {
|
||||
"label": "ngày/date 10 tháng/month 03 năm/year 2022",
|
||||
"pred": "ngày/date 10than almonth 03 năm/year 2022"
|
||||
},
|
||||
"130727988_1377512982599930_6481917606912865462_n.json": {
|
||||
"label": "ngày/date 30 tháng /month 05 năm/year 2016",
|
||||
"pred": "ngày/date 30 tháng month 05 năm/year 2016"
|
||||
},
|
||||
"194993073_539716147195165_8525378287933192246_n.json": {
|
||||
"label": "ngày /date 12 tháng /month 06 năm /year 2020",
|
||||
"pred": "ngày date 12 tháng /month 06 năm vear 2020"
|
||||
},
|
||||
"20221027_154521.json": {
|
||||
"label": "ngày/date 19 tháng /month 03 năm/year 2018",
|
||||
"pred": "ngày/dec 19 tháng month 03 năm/year 2018"
|
||||
},
|
||||
"174247511_900878714088455_7516565117828455890_n.json": {
|
||||
"label": "ngày/date 1 0✪tháng /month 05 năm/year 2018",
|
||||
"pred": "ngày/date 4 Other month 05 năm/year 2018"
|
||||
},
|
||||
"20221027_155519.json": {
|
||||
"label": "ngày /date 22 tháng /month 04 năm /year 2022",
|
||||
"pred": "ngày date 22 tháng month 04 năm year 2022"
|
||||
},
|
||||
"20221027_154706.json": {
|
||||
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
|
||||
"pred": "ngày date 27 tháng month 10 năm/year. 2014"
|
||||
},
|
||||
"a88ae66aed272b79723621.json": {
|
||||
"label": "ngày /date 27 tháng /month 10 năm/year 2014",
|
||||
"pred": ""
|
||||
},
|
||||
"4378298-95583fbb1edb703a6c5bbc1744246058-1-1.json": {
|
||||
"label": "ngày/d 24 tháng /month 1 năm/year 2015",
|
||||
"pred": "ngày/d / 24 than chuonth 1 1 năm/year 2015"
|
||||
},
|
||||
"20221027_155704.json": {
|
||||
"label": "ngày /date 22 tháng /month 04 năm/year 2022",
|
||||
"pred": "ngày date 22 tháng month 04 năm hear 2022"
|
||||
},
|
||||
"179642128_1945301335636182_5557211235870766646_n.json": {
|
||||
"label": "ngày/date #✪thá# # onth 10 năm/year 201",
|
||||
"pred": "ngowdan (Ký - touth 10 năm/year 201"
|
||||
},
|
||||
"20221027_154734.json": {
|
||||
"label": "ngày /date 16 tháng /month 06 năm /year 2020",
|
||||
"pred": "ngày date to tháng month 06 năm year 2020"
|
||||
},
|
||||
"20221027_155542.json": {
|
||||
"label": "ngày /date 05 tháng /month 04 năm/year 2016",
|
||||
"pred": "ngày /date 05 tháng month 04 năm/year 2016"
|
||||
}
|
||||
}
|
196
cope2n-ai-fi/common/json2xml.py
Executable file
196
cope2n-ai-fi/common/json2xml.py
Executable file
@ -0,0 +1,196 @@
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import datetime
|
||||
# Register the XML-DSig namespace as the default namespace so that serialized
# invoices keep unprefixed tag names (Vietnamese e-invoice convention).
ET.register_namespace('', "http://www.w3.org/2000/09/xmldsig#")


# Skeleton of a Vietnamese e-invoice (HDon) XML document.  Every leaf element
# holds the placeholder string "None"; replace_xml_values() overwrites these
# placeholders with extracted field values.  Section meanings (per the
# Vietnamese e-invoice tag naming): TTChung = general invoice info,
# NBan = seller, NMua = buyer, DSHHDVu = list of line items (one HHDVu each),
# TToan = payment totals.
xml_template3 = """
<HDon>
<DLHDon>
<TTChung>
<PBan>None</PBan>
<THDon>None</THDon>
<KHMSHDon>None</KHMSHDon>
<KHHDon>None</KHHDon>
<SHDon>None</SHDon>
<NLap>None</NLap>
<DVTTe>None</DVTTe>
<TGia>None</TGia>
<HTTToan>None</HTTToan>
<MSTTCGP>None</MSTTCGP>
</TTChung>
<NDHDon>
<NBan>
<Ten>None</Ten>
<MST>None</MST>
<DChi>None</DChi>
<SDThoai>None</SDThoai>
</NBan>
<NMua>
<Ten>None</Ten>
<MST>None</MST>
<DChi>None</DChi>
<SDThoai>None</SDThoai>
<HVTNMHang>None</HVTNMHang>
</NMua>
<DSHHDVu>
<HHDVu>
<TChat>None</TChat>
<STT>None</STT>
<THHDVu>None</THHDVu>
<DVTinh>None</DVTinh>
<SLuong>None</SLuong>
<DGia>None</DGia>
<TLCKhau>None</TLCKhau>
<STCKhau>None</STCKhau>
<ThTien>None</ThTien>
<TSuat>None</TSuat>
</HHDVu>
</DSHHDVu>
<TToan>
<THTTLTSuat>
<LTSuat>
<TSuat>None</TSuat>
<ThTien>None</ThTien>
<TThue>None</TThue>
</LTSuat>
</THTTLTSuat>
<TgTCThue>None</TgTCThue>
<TgTThue>None</TgTThue>
<TTCKTMai>None</TTCKTMai>
<TgTTTBSo>None</TgTTTBSo>
<TgTTTBChu>None</TgTTTBChu>
</TToan>
</NDHDon>
</DLHDon>
<MCCQT Id="None">None</MCCQT>
<DLQRCode>None</DLQRCode>
</HDon>
"""
|
||||
|
||||
def replace_xml_values(xml_str, replacement_dict):
    """Fill the placeholder leaves of an e-invoice XML template.

    Args:
        xml_str (str): XML document (typically ``xml_template3``) whose leaf
            elements contain the placeholder text ``"None"``.
        replacement_dict (dict): Maps tag names to replacement values.
            Special keys:
              * ``"TToan"`` -- dict of totals: ``TgTThue`` (tax amount),
                ``TgTTTBSo`` (grand total, digits), ``TgTTTBChu`` (in words).
              * ``"NMua"`` / ``"NBan"`` -- buyer / seller sub-field dicts.
              * ``"HHDVu"`` -- list of line-item dicts; the single template
                row is removed and one <HHDVu> element is appended per item.
              * ``"NLap"`` -- issue date; ``dd/mm/YYYY`` input is normalized
                to ``YYYY-mm-dd``, any other format is stored verbatim.
            Falsy values are skipped entirely.

    Returns:
        str | None: Serialized XML on success, or None when *xml_str* is not
        well-formed XML.
    """
    try:
        root = ET.fromstring(xml_str)
        for key, value in replacement_dict.items():
            if not value:
                continue  # nothing extracted for this field
            if key == "TToan":
                ttoan_element = root.find(".//TToan")
                if ttoan_element is None:
                    continue
                # Use dict.get(): a missing sub-key must not raise KeyError,
                # and a template lacking the tag must not raise AttributeError.
                for tag in ("TgTThue", "TgTTTBSo", "TgTTTBChu"):
                    leaf = ttoan_element.find(f".//{tag}")
                    if leaf is not None and value.get(tag):
                        leaf.text = value[tag]
            elif key in ("NMua", "NBan"):
                # Buyer and seller share the same sub-field handling; the
                # buyer additionally carries the personal-name field.
                party_element = root.find(f".//{key}")
                if party_element is None:
                    continue
                sub_keys = ["DChi", "SDThoai", "MST", "Ten"]
                if key == "NMua":
                    sub_keys.append("HVTNMHang")
                for sub_key in sub_keys:
                    if value.get(sub_key):
                        leaf = party_element.find(f".//{sub_key}")
                        if leaf is not None:
                            leaf.text = value[sub_key]
            elif key == "HHDVu":
                dshhdvu_element = root.find(".//DSHHDVu")
                hhdvu_template = root.find(".//HHDVu")
                if hhdvu_template is not None and dshhdvu_element is not None:
                    dshhdvu_element.remove(hhdvu_template)  # drop template row
                    for hhdvu_data in value:
                        hhdvu_element = ET.SubElement(dshhdvu_element, "HHDVu")
                        for h_key, h_value in hhdvu_data.items():
                            h_element = ET.SubElement(hhdvu_element, h_key)
                            h_element.text = h_value if h_value is not None else "None"
            elif key == "NLap":
                nlap_element = root.find(".//NLap")
                if nlap_element is not None:
                    # Normalize the extracted date to ISO yyyy-mm-dd.
                    try:
                        date_obj = datetime.strptime(value, "%d/%m/%Y")
                        nlap_element.text = date_obj.strftime("%Y-%m-%d")
                    except ValueError:
                        # Keep the raw value rather than losing the field.
                        print(f"Invalid date format for {key}: {value}")
                        nlap_element.text = value
            else:
                element = root.find(f".//{key}")
                if element is not None:
                    element.text = value
        ET.register_namespace("", "http://www.w3.org/2000/09/xmldsig#")
        return ET.tostring(root, encoding="unicode")
    except ET.ParseError as e:
        print(f"Error parsing XML: {e}")
        return None
|
||||
|
||||
|
||||
def convert_key_names(original_dict):
    """Translate extractor field names into e-invoice XML tag names.

    Dotted targets such as ``"NMua.DChi"`` become nested dictionaries; the
    ``"table"`` key maps to ``"HHDVu"`` and its rows are converted
    recursively. Keys without a mapping are copied through unchanged.

    Args:
        original_dict (dict): Raw extraction output keyed by field name.

    Returns:
        dict: Same values re-keyed (and nested) by XML tag name.
    """
    key_mapping = {
        "table": "HHDVu",
        "Mặt hàng": "THHDVu",
        "Đơn vị tính": "DVTinh",
        "Số lượng": "SLuong",
        "Đơn giá": "DGia",
        "Doanh số mua chưa có thuế": "ThTien",
        "buyer_address_value": "NMua.DChi",
        "buyer_company_name_value": "NMua.Ten",
        "buyer_personal_name_value": "NMua.HVTNMHang",
        "buyer_tax_code_value": "NMua.MST",
        "buyer_tel_value": "NMua.SDThoai",
        "seller_address_value": "NBan.DChi",
        "seller_company_name_value": "NBan.Ten",
        "seller_tax_code_value": "NBan.MST",
        "seller_tel_value": "NBan.SDThoai",
        "date_value": "NLap",
        "form_value": "KHMSHDon",
        "no_value": "SHDon",
        "serial_value": "KHHDon",
        "tax_amount_value": "TToan.TgTThue",
        "total_in_words_value": "TToan.TgTTTBChu",
        "total_value": "TToan.TgTTTBSo",
    }

    result = {}
    for src_key, src_value in original_dict.items():
        target = key_mapping.get(src_key, src_key)
        if "." in target:
            # Walk/create the nested dicts for all parents, set the leaf.
            *parents, leaf = target.split(".")
            node = result
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = src_value
        elif src_key == "table":
            # Each line-item row is itself a dict of extractor field names.
            result[target] = [convert_key_names(row) for row in src_value]
        else:
            result[target] = src_value

    return result
|
37
cope2n-ai-fi/common/ocr.py
Executable file
37
cope2n-ai-fi/common/ocr.py
Executable file
@ -0,0 +1,37 @@
|
||||
from common.utils.ocr_yolox import OcrEngineForYoloX_ID_Driving
|
||||
from common.utils.word_formation import Word, words_to_lines
|
||||
|
||||
# Checkpoint identifiers passed to the OCR engine: a YOLOX-based text
# detector and a SATRN-lite text recognizer (general-domain pretrained
# weights, dated snapshots).
det_ckpt = "yolox-s-general-text-pretrain-20221226"
cls_ckpt = "satrn-lite-general-pretrain-20230106"

# Module-level singleton: the engine is constructed once at import time and
# reused by every ocr_predict() call (model loading is presumably expensive
# -- TODO confirm against OcrEngineForYoloX_ID_Driving).
engine = OcrEngineForYoloX_ID_Driving(det_ckpt, cls_ckpt)
|
||||
|
||||
|
||||
def ocr_predict(image):
    """Run OCR on an image and group the detected words into lines.

    Args:
        image: Input accepted by ``engine.run_image``.

    Returns:
        list: Line objects produced by ``words_to_lines``; an empty list
        when the engine raises AssertionError (best-effort behavior).
    """
    try:
        lbboxes, lwords = engine.run_image(image)
        detected_words = [
            Word(text=text, bndbox=box) for text, box in zip(lwords, lbboxes)
        ]
        lines, _ = words_to_lines(detected_words)
        return lines
    except AssertionError as error:
        # Invalid/unreadable input: log and return no lines instead of failing.
        print(error)
        return []
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point for quick manual testing: run OCR on a single image.
    import argparse

    parser = argparse.ArgumentParser()
    # Path to the image file to run OCR on.
    parser.add_argument("--image", type=str, required=True)
    args = parser.parse_args()

    list_lines = ocr_predict(args.image)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user