diff --git a/.gitignore b/.gitignore
index 8004166..1ae0e35 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,6 @@ backup/
 *.sqlite3
 *.log
 __pycache__
-migrations/
\ No newline at end of file
+migrations/
+test/
+._git/
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..c1f5013
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "cope2n-ai-fi/modules/sdsvkvu"]
+	path = cope2n-ai-fi/modules/sdsvkvu
+	url = https://code.sdsdev.co.kr/tuanlv/sdsvkvu
diff --git a/cope2n-ai-fi/._gitmodules b/cope2n-ai-fi/._gitmodules
new file mode 100644
index 0000000..9483200
--- /dev/null
+++ b/cope2n-ai-fi/._gitmodules
@@ -0,0 +1,6 @@
+[submodule "modules/sdsvkvu"]
+	path = modules/sdsvkvu
+	url = https://code.sdsdev.co.kr/tuanlv/sdsvkvu.git
+[submodule "modules/ocr_engine"]
+	path = modules/ocr_engine
+	url = https://code.sdsdev.co.kr/tuanlv/IDP-BasicOCR.git
diff --git a/cope2n-ai-fi/.dockerignore b/cope2n-ai-fi/.dockerignore
new file mode 100755
index 0000000..70a2b49
--- /dev/null
+++ b/cope2n-ai-fi/.dockerignore
@@ -0,0 +1,7 @@
+.github
+.git
+.vscode
+__pycache__
+DataBase/image_temp/
+DataBase/json_temp/
+DataBase/template.db
\ No newline at end of file
diff --git a/cope2n-ai-fi/.gitignore b/cope2n-ai-fi/.gitignore
new file mode 100755
index 0000000..f2c8e28
--- /dev/null
+++ b/cope2n-ai-fi/.gitignore
@@ -0,0 +1,21 @@
+.vscode
+__pycache__
+DataBase/image_temp/
+DataBase/json_temp/
+DataBase/template.db
+sdsvtd/
+sdsvtr/
+sdsvkie/
+detectron2/
+output/
+data/
+models/
+server/
+image_logs/
+experiments/
+weights/
+packages/
+tmp_results/
+.env
+.zip
+.json
\ No newline at end of file
diff --git a/cope2n-ai-fi/Dockerfile b/cope2n-ai-fi/Dockerfile
new file mode 100755
index 0000000..135713e
--- /dev/null
+++ b/cope2n-ai-fi/Dockerfile
@@ -0,0 +1,43 @@
+# FROM thucpd2408/env-cope2n:v1
+FROM thucpd2408/env-deskew
+
+COPY ./packages/cudnn-linux*.tar.xz /tmp/cudnn-linux*.tar.xz
+
+RUN tar -xvf /tmp/cudnn-linux*.tar.xz -C /tmp/ \
+    && cp /tmp/cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include \
+    && cp -P /tmp/cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64 \
+    && chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn* \
+    && rm -rf /tmp/cudnn-*-archive
+
+RUN apt-get update && apt-get install -y gcc g++ ffmpeg libsm6 libxext6
+# RUN pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
+
+WORKDIR /workspace
+
+
+COPY ./modules/ocr_engine/externals/ /workspace/cope2n-ai-fi/modules/ocr_engine/externals/
+COPY ./modules/ocr_engine/requirements.txt /workspace/cope2n-ai-fi/modules/ocr_engine/requirements.txt
+COPY ./modules/sdsvkie/ /workspace/cope2n-ai-fi/modules/sdsvkie/
+COPY ./modules/sdsvkvu/ /workspace/cope2n-ai-fi/modules/sdsvkvu/
+COPY ./requirements.txt /workspace/cope2n-ai-fi/requirements.txt
+
+RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp && pip3 install -v -e .
+RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtd && pip3 install -v -e .
+RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtr && pip3 install -v -e .
+
+# RUN cd /workspace/cope2n-ai-fi/modules/ocr_engine/ && pip3 install -r requirements.txt
+RUN cd /workspace/cope2n-ai-fi/modules/sdsvkie && pip3 install -v -e .
+RUN cd /workspace/cope2n-ai-fi/modules/sdsvkvu && pip3 install -v -e .
+RUN cd /workspace/cope2n-ai-fi && pip3 install -r requirements.txt
+
+RUN rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublasLt.so.11 && \
+    rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublas.so.11 && \
+    rm -f /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libnvblas.so.11 && \
+    ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublasLt.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublasLt.so.11 && \
+    ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcublas.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libcublas.so.11 && \
+    ln -s /usr/local/cuda-11.8/targets/x86_64-linux/lib/libnvblas.so.11 /usr/local/lib/python3.10/dist-packages/nvidia/cublas/lib/libnvblas.so.11
+
+ENV PYTHONPATH="."
+
+CMD [ "sh", "run.sh"]
+# CMD ["tail -f > /dev/null"]
\ No newline at end of file
diff --git a/cope2n-ai-fi/Dockerfile-dev b/cope2n-ai-fi/Dockerfile-dev
new file mode 100755
index 0000000..f58c519
--- /dev/null
+++ b/cope2n-ai-fi/Dockerfile-dev
@@ -0,0 +1,21 @@
+FROM thucpd2408/env-cope2n:v1
+
+RUN apt-get update && apt-get install -y gcc g++ ffmpeg libsm6 libxext6
+
+WORKDIR /workspace
+
+COPY ./requirements.txt /workspace/cope2n-ai-fi/requirements.txt
+COPY ./sdsvkie/ /workspace/cope2n-ai-fi/sdsvkie/
+COPY ./sdsvtd /workspace/cope2n-ai-fi/sdsvtd/
+COPY ./sdsvtr/ /workspace/cope2n-ai-fi/sdsvtr/
+COPY ./models/ /models
+
+RUN cd /workspace/cope2n-ai-fi && pip3 install -r requirements.txt
+RUN cd /workspace/cope2n-ai-fi/sdsvkie && pip3 install -v -e .
+RUN cd /workspace/cope2n-ai-fi/sdsvtd && pip3 install -v -e .
+RUN cd /workspace/cope2n-ai-fi/sdsvtr && pip3 install -v -e .
+
+ENV PYTHONPATH="."
+
+CMD [ "sh", "run.sh"]
+# CMD ["tail -f > /dev/null"]
\ No newline at end of file
diff --git a/cope2n-ai-fi/Dockerfile_fwd b/cope2n-ai-fi/Dockerfile_fwd
new file mode 100644
index 0000000..e93fc0c
--- /dev/null
+++ b/cope2n-ai-fi/Dockerfile_fwd
@@ -0,0 +1,8 @@
+FROM hisiter/fwd_env:1.0.0
+RUN apt-get update && apt-get install -y \
+    libssl-dev \
+    libxml2-dev \
+    libxslt-dev \
+    && rm -rf /var/lib/apt/lists/*
+RUN pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+RUN pip install paddleocr>=2.0.1
\ No newline at end of file
diff --git a/cope2n-ai-fi/LICENSE b/cope2n-ai-fi/LICENSE
new file mode 100755
index 0000000..f288702
--- /dev/null
+++ b/cope2n-ai-fi/LICENSE
@@ -0,0 +1,674 @@
+                    GNU GENERAL PUBLIC LICENSE
+                       Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+                            Preamble
+
+  The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+  The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+  When we speak of free software, we are referring to freedom, not
+price.
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. 
+ + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. 
This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. 
+ + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. 
+ + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. 
+ + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. 
+ + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. 
If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. 
Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. 
diff --git a/cope2n-ai-fi/NOTE.md b/cope2n-ai-fi/NOTE.md
new file mode 100755
index 0000000..9d8e22d
--- /dev/null
+++ b/cope2n-ai-fi/NOTE.md
@@ -0,0 +1,7 @@
+#### Environment
+```
+docker run -itd --privileged --name=TannedCungnoCope2n-ai-fi \
+    -v /mnt/hdd2T/dxtan/TannedCung/OCR/cope2n-ai-fi:/workspace \
+    tannedcung/mmocr:latest \
+    tail -f > /dev/null
+```
\ No newline at end of file
diff --git a/cope2n-ai-fi/README.md b/cope2n-ai-fi/README.md
new file mode 100755
index 0000000..72c209b
--- /dev/null
+++ b/cope2n-ai-fi/README.md
@@ -0,0 +1,36 @@
+# AI-core
+
+## Add your files
+
+- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
+- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
+
+```bash
+cd existing_repo
+git remote add origin http://code.sdsrv.vn/c-ope2n/ai-core.git
+git branch -M main
+git push -uf origin main
+```
+
+## Develop
+
+Assume you are at the root folder with the following structure:
+
+```bash
+.
+├── cope2n-ai-fi
+├── cope2n-api
+├── cope2n-fe
+├── .env
+└── docker-compose-dev.yml
+```
+
+Run `docker-compose -f docker-compose-dev.yml up --build -d` to bring the project up.
+
+## Environment Variables
+
+Variable | Default | Usage | Note
+------------- | ------- | ----- | ----
+CELERY_BROKER | $250 | | |
+SAP_KIE_MODEL | $80 | | |
+FI_KIE_MODEL | $420 | | |
diff --git a/cope2n-ai-fi/TODO.md b/cope2n-ai-fi/TODO.md
new file mode 100644
index 0000000..9a1c625
--- /dev/null
+++ b/cope2n-ai-fi/TODO.md
@@ -0,0 +1,26 @@
+## Bring abs path to relative
+
+- [x] save_dir `Kie_Invoice_AP/prediction_sap.py:18`
+- [x] detector `Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml` [_fixed_](#refactor)
+- [x] rotator_version `Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml` [_fixed_](#refactor)
+- [x] cfg `Kie_Invoice_AP/prediction_fi.py`
+- [x] weight `Kie_Invoice_AP/prediction_fi.py`
+- [x] save_dir `Kie_Invoice_AP/prediction_fi.py:18`
+
+## Bring abs path to .env
+
+- [x] CELERY_BROKER:
+- [ ] SAP_KIE_MODEL: `Kie_Invoice_AP/prediction_sap.py:20` [_NEED_REFACTOR_](#refactor)
+- [ ] FI_KIE_MODEL: `Kie_Invoice_AP/prediction_fi.py:20` [_NEED_REFACTOR_](#refactor)
+
+## Possible logic conflict
+
+### Refactor
+
+- [ ] Each model should be loaded in a Docker container and served as a service
+- [ ] Some files (weights, ...) should be mounted into the container in a format that ensures durability
+- [ ] `Kie_Invoice_AP/prediction_fi.py` and `Kie_Invoice_AP/prediction_fi.py` should be merged into a single file, as they share resources with different logic
+- [ ] `Kie_Invoice_AP/prediction.py` seems to be the base function; it should act as a proxy that imports all other `predict_{anything else}` functions
+- [ ] There should be a unique folder to keep all models with different versions, then mount it as /models in the container. Currently, `fi` is loading from `/models/Kie_invoice_fi` while `sap` is loading from `Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20231003-171748`.
Another model weight is at `sdsvtd/hub` for unknown reason +- [ ] Env variables should have its description in README +- [ ] diff --git a/cope2n-ai-fi/api/Kie_AHung/prediction.py b/cope2n-ai-fi/api/Kie_AHung/prediction.py new file mode 100755 index 0000000..4755df2 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_AHung/prediction.py @@ -0,0 +1,149 @@ +from PIL import Image +import cv2 + +import numpy as np +from transformers import ( + LayoutXLMTokenizer, + LayoutLMv2FeatureExtractor, + LayoutXLMProcessor, + LayoutLMv2ForTokenClassification, +) + +from common.utils.word_formation import * + +from common.utils.global_variables import * +from common.utils.process_label import * +import ssl + +ssl._create_default_https_context = ssl._create_unverified_context +os.environ["CURL_CA_BUNDLE"] = "" +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +# config +IGNORE_KIE_LABEL = "others" +KIE_LABELS = [ + "Number", + "Name", + "Birthday", + "Home Town", + "Address", + "Sex", + "Nationality", + "Expiry Date", + "Nation", + "Religion", + "Date Range", + "Issued By", + IGNORE_KIE_LABEL, + "Rank" +] +DEVICE = "cuda:0" + +# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code + +# tokenizer = LayoutXLMTokenizer.from_pretrained( +# "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH +# ) + +# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) +# processor = LayoutXLMProcessor(feature_extractor, tokenizer) + +model = LayoutLMv2ForTokenClassification.from_pretrained( + "Kie_AHung/model/driver_license", num_labels=len(KIE_LABELS), local_files_only=True +).to( + DEVICE +) # TODO FIX this hard code + + +def load_ocr_labels(list_lines): + words, boxes, labels = [], [], [] + for line in list_lines: + for word_group in line.list_word_groups: + for word in word_group.list_words: + xmin, ymin, xmax, ymax = ( + word.boundingbox[0], + word.boundingbox[1], + word.boundingbox[2], + word.boundingbox[3], + ) + text = word.text + label = "seller_name_value" # TODO ??? 
fix this + x1, y1, x2, y2 = float(xmin), float(ymin), float(xmax), float(ymax) + if text != " ": + words.append(text) + boxes.append([x1, y1, x2, y2]) + labels.append(label) + return words, boxes, labels + + +def _normalize_box(box, width, height): + return [ + int(1000 * (box[0] / width)), + int(1000 * (box[1] / height)), + int(1000 * (box[2] / width)), + int(1000 * (box[3] / height)), + ] + + +def infer_driving_license(image_crop, list_lines, max_n_words, processor): + # Load inputs + # image = Image.open(image_path) + image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + batch_words, batch_boxes, _ = load_ocr_labels(list_lines) + batch_preds, batch_true_boxes = [], [] + list_words = [] + for i in range(0, len(batch_words), max_n_words): + words = batch_words[i : i + max_n_words] + boxes = batch_boxes[i : i + max_n_words] + boxes_norm = [ + _normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes + ] + + # Preprocess + dummy_word_labels = [0] * len(words) + encoding = processor( + image, + text=words, + boxes=boxes_norm, + word_labels=dummy_word_labels, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=512, + ) + + # Run model + for k, v in encoding.items(): + encoding[k] = v.to(DEVICE) + outputs = model(**encoding) + predictions = outputs.logits.argmax(-1).squeeze().tolist() + + # Postprocess + is_subword = ( + (encoding["labels"] == -100).detach().cpu().numpy()[0] + ) # remove padding + true_predictions = [ + pred for idx, pred in enumerate(predictions) if not is_subword[idx] + ] + true_boxes = ( + boxes # TODO check assumption that layourlm do not change box order + ) + + for i, word in enumerate(words): + bndbox = [int(j) for j in true_boxes[i]] + list_words.append( + Word( + text=word, bndbox=bndbox, kie_label=KIE_LABELS[true_predictions[i]] + ) + ) + + batch_preds.extend(true_predictions) + batch_true_boxes.extend(true_boxes) + + batch_preds = np.array(batch_preds) + batch_true_boxes = np.array(batch_true_boxes) + return batch_words, batch_preds, batch_true_boxes, list_words diff --git a/cope2n-ai-fi/api/Kie_AHung_ID/prediction.py b/cope2n-ai-fi/api/Kie_AHung_ID/prediction.py new file mode 100755 index 0000000..ec1b22c --- /dev/null +++ b/cope2n-ai-fi/api/Kie_AHung_ID/prediction.py @@ -0,0 +1,145 @@ +from PIL import Image +import numpy as np +import cv2 + +from transformers import LayoutLMv2ForTokenClassification + +from common.utils.word_formation import * + +from common.utils.global_variables import * +from common.utils.process_label import * +import ssl + +ssl._create_default_https_context = ssl._create_unverified_context +os.environ["CURL_CA_BUNDLE"] = "" +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +# config +IGNORE_KIE_LABEL = "others" +KIE_LABELS = [ + "Number", + "Name", + "Birthday", + "Home Town", + "Address", + "Sex", + "Nationality", + "Expiry Date", + "Nation", + "Religion", + "Date Range", + "Issued By", + IGNORE_KIE_LABEL +] +DEVICE = "cuda:0" + +# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code + +# tokenizer = LayoutXLMTokenizer.from_pretrained( +# "Kie_AHung_ID/model/pretrained/layoutxlm-base/tokenizer", +# model_max_length=MAX_SEQ_LENGTH, +# ) + +# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) +# processor = LayoutXLMProcessor(feature_extractor, tokenizer) + +model = LayoutLMv2ForTokenClassification.from_pretrained( + "Kie_AHung_ID/model/ID_CARD_145_train_300_val_0.02_char_0.06_word", + num_labels=len(KIE_LABELS), + local_files_only=True, +).to( 
+ DEVICE +) # TODO FIX this hard code + + +def load_ocr_labels(list_lines): + words, boxes, labels = [], [], [] + for line in list_lines: + for word_group in line.list_word_groups: + for word in word_group.list_words: + xmin, ymin, xmax, ymax = ( + word.boundingbox[0], + word.boundingbox[1], + word.boundingbox[2], + word.boundingbox[3], + ) + text = word.text + label = "seller_name_value" # TODO ??? fix this + x1, y1, x2, y2 = float(xmin), float(ymin), float(xmax), float(ymax) + if text != " ": + words.append(text) + boxes.append([x1, y1, x2, y2]) + labels.append(label) + return words, boxes, labels + + +def _normalize_box(box, width, height): + return [ + int(1000 * (box[0] / width)), + int(1000 * (box[1] / height)), + int(1000 * (box[2] / width)), + int(1000 * (box[3] / height)), + ] + + +def infer_id_card(image_crop, list_lines, max_n_words, processor): + # Load inputs + image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + batch_words, batch_boxes, _ = load_ocr_labels(list_lines) + batch_preds, batch_true_boxes = [], [] + list_words = [] + for i in range(0, len(batch_words), max_n_words): + words = batch_words[i : i + max_n_words] + boxes = batch_boxes[i : i + max_n_words] + boxes_norm = [ + _normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes + ] + + # Preprocess + dummy_word_labels = [0] * len(words) + encoding = processor( + image, + text=words, + boxes=boxes_norm, + word_labels=dummy_word_labels, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=512, + ) + + # Run model + for k, v in encoding.items(): + encoding[k] = v.to(DEVICE) + outputs = model(**encoding) + predictions = outputs.logits.argmax(-1).squeeze().tolist() + + # Postprocess + is_subword = ( + (encoding["labels"] == -100).detach().cpu().numpy()[0] + ) # remove padding + true_predictions = [ + pred for idx, pred in enumerate(predictions) if not is_subword[idx] + ] + true_boxes = ( + boxes # TODO check assumption that layourlm do not change box order + ) + + for i, word in enumerate(words): + bndbox = [int(j) for j in true_boxes[i]] + list_words.append( + Word( + text=word, bndbox=bndbox, kie_label=KIE_LABELS[true_predictions[i]] + ) + ) + + batch_preds.extend(true_predictions) + batch_true_boxes.extend(true_boxes) + + batch_preds = np.array(batch_preds) + batch_true_boxes = np.array(batch_true_boxes) + return batch_words, batch_preds, batch_true_boxes, list_words diff --git a/cope2n-ai-fi/api/Kie_Hoanglv/prediction.py b/cope2n-ai-fi/api/Kie_Hoanglv/prediction.py new file mode 100755 index 0000000..7fbe734 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Hoanglv/prediction.py @@ -0,0 +1,57 @@ +from sdsvkie import Predictor +import cv2 +import numpy as np +import urllib +from common import serve_model +from common import ocr + +model = Predictor( + cfg = "./models/kie_invoice/config.yaml", + device = "cuda:0", + weights = "./models/models/kie_invoice/last", + proccessor = serve_model.processor, + ocr_engine = ocr.engine + ) + +def predict(page_numb, image_url): + """ + module predict function + + Args: + image_url (str): image url + + Returns: + example output: + "data": { + "document_type": "invoice", + "fields": [ + { + "label": "Invoice Number", + "value": "INV-12345", + "box": [0, 0, 0, 0], + "confidence": 0.98 + }, + ... 
+ ] + } + dict: output of model + """ + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + image = cv2.imdecode(arr, -1) + out = model(image) + output = out["end2end_results"] + output_dict = { + "document_type": "invoice", + "fields": [] + } + for key in output.keys(): + field = { + "label": key, + "value": output[key]['value'] if output[key]['value'] else "", + "box": output[key]['box'], + "confidence": output[key]['conf'], + "page": page_numb + } + output_dict['fields'].append(field) + return output_dict \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Hoanglv/serve_model.py b/cope2n-ai-fi/api/Kie_Hoanglv/serve_model.py new file mode 100755 index 0000000..e6ef934 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Hoanglv/serve_model.py @@ -0,0 +1,83 @@ +from common.utils_invoice.run_ocr import ocr_predict +import os +from Kie_Hoanglv.prediction2 import KIEInvoiceInfer +from configs.config_invoice.layoutxlm_base_invoice import * + +from PIL import Image +import requests +from io import BytesIO +import numpy as np +import cv2 + +model = KIEInvoiceInfer( + weight_dir=TRAINED_DIR, + tokenizer_dir=TOKENIZER_DIR, + max_seq_len=MAX_SEQ_LENGTH, + classes=KIE_LABELS, + device=DEVICE, + outdir_visualize=VISUALIZE_DIR, +) + +def format_result(result): + """ + return: + [ + { + key: 'name', + value: 'Nguyen Hoang Hiep', + true_box: [ + 373, + 113, + 700, + 420 + ] + }, + { + key: 'name', + value: 'Nguyen Hoang Hiep 1', + true_box: [ + 10, + 10, + 20, + 20, + ] + }, + ] + """ + new_result = [] + for i, item in enumerate(result[0]): + new_result.append( + { + "key": item, + "value": result[0][item], + "true_box": result[1][i], + } + ) + return new_result + +def predict(image_url): + if not os.path.exists(PRED_DIR): + os.makedirs(PRED_DIR, exist_ok=True) + + if not os.path.exists(VISUALIZE_DIR): + os.makedirs(VISUALIZE_DIR, exist_ok=True) + + + response = requests.get(image_url) + image = Image.open(BytesIO(response.content)) + + cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + + bboxes, texts = ocr_predict(cv_image) + + texts_replaced = [] + for text in texts: + if "✪" in text: + text_replaced = text.replace("✪", " ") + texts_replaced.append(text_replaced) + else: + texts_replaced.append(text) + inputs = model.prepare_kie_inputs(image, ocr_info=[bboxes, texts_replaced]) + result = model(inputs) + result = format_result(result) + return result \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/__init__.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/anyKeyValue.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/anyKeyValue.py new file mode 100755 index 0000000..dd21d2f --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/anyKeyValue.py @@ -0,0 +1,137 @@ +import os +import glob +import cv2 +import argparse +from tqdm import tqdm +import urllib +import numpy as np +import imagesize +# from omegaconf import OmegaConf +import sys +cur_dir = os.path.dirname(__file__) +sys.path.append(cur_dir) +# sys.path.append('/cope2n-ai-fi/Kie_Invoice_AP/AnyKey_Value/') +from predictor import KVUPredictor +from preprocess import KVUProcess, DocumentKVUProcess +from utils.utils import create_dir, visualize, get_colormap, export_kvu_outputs, export_kvu_for_manulife + + +def get_args(): + args = argparse.ArgumentParser(description='Main file') + args.add_argument('--img_dir', 
default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str, + help='Input image directory') + args.add_argument('--save_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str, + help='Save directory') + # args.add_argument('--exp_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900', type=str, + # help='Checkpoint and config of model') + args.add_argument('--exp_dir', default='/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900', type=str, + help='Checkpoint and config of model') + args.add_argument('--export_img', default=0, type=int, + help='Save visualize on image') + args.add_argument('--mode', default=3, type=int, + help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'") + args.add_argument('--dir_level', default=0, type=int, + help='Number of subfolders contains image') + + return args.parse_args() + + +def load_engine(exp_dir: str, class_names: list, mode: int) -> KVUPredictor: + configs = { + 'cfg': glob.glob(f'{exp_dir}/*.yaml')[0], + 'ckpt': f'{exp_dir}/checkpoints/best_model.pth' + } + dummy_idx = 512 + predictor = KVUPredictor(configs, class_names, dummy_idx, mode) + + # processor = KVUProcess(predictor.net.tokenizer_layoutxlm, + # predictor.net.feature_extractor, predictor.backbone_type, class_names, + # predictor.slice_interval, predictor.window_size, run_ocr=1, mode=mode) + + processor = DocumentKVUProcess(predictor.net.tokenizer, predictor.net.feature_extractor, + predictor.backbone_type, class_names, + predictor.max_window_count, predictor.slice_interval, predictor.window_size, + run_ocr=1, mode=mode) + return predictor, processor + +def revert_box(box, width, height): + return [ + int((box[0] / 1000) * width), + int((box[1] / 1000) * height), + int((box[2] / 1000) * width), + int((box[3] / 1000) * height) + ] + +def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor) -> None: + fname = os.path.basename(img_path) + img_ext = img_path.split('.')[-1] + inputs = processor(img_path, ocr_path='') + width, height = imagesize.get(img_path) + + bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs) + # slide_window = False if len(bbox) == 1 else True + + if len(bbox) == 0: + bbox, lwords, pr_class_words, pr_relations = [bbox], [lwords], [pr_class_words], [pr_relations] + + for i in range(len(bbox)): + bbox[i] = [revert_box(bb, width, height) for bb in bbox[i]] + # vat_outputs_invoice = export_kvu_for_VAT_invoice(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names) + # vat_outputs_receipt = export_kvu_for_SDSAP(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names) + # vat_outputs_invoice = export_kvu_for_all(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names) + vat_outputs_invoice = export_kvu_for_manulife(os.path.join(save_dir, fname.replace(img_ext, '.json')), lwords[i], bbox[i], pr_class_words[i], pr_relations[i], predictor.class_names) + + print(vat_outputs_invoice) + return vat_outputs_invoice + + +def load_groundtruth(img_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor: KVUProcess, export_img: int) -> None: + fname = os.path.basename(img_path) + 
img_ext = img_path.split('.')[-1] + inputs = processor.load_ground_truth(os.path.join(json_dir, fname.replace(f".{img_ext}", ".json"))) + bbox, lwords, pr_class_words, pr_relations = predictor.get_ground_truth_label(inputs) + + export_kvu_outputs(os.path.join(save_dir, fname.replace(f'.{img_ext}', '.json')), lwords, pr_class_words, pr_relations, predictor.class_names) + + if export_img == 1: + save_path = os.path.join(save_dir, 'kvu_results') + create_dir(save_path) + color_map = get_colormap() + image = cv2.imread(img_path) + image = visualize(image, bbox, pr_class_words, pr_relations, color_map, class_names, thickness=1) + cv2.imwrite(os.path.join(save_path, fname), image) + +def show_groundtruth(dir_path: str, json_dir: str, save_dir: str, predictor: KVUPredictor, processor, export_img: int) -> None: + list_images = [] + for ext in ['JPG', 'PNG', 'jpeg', 'jpg', 'png']: + list_images += glob.glob(os.path.join(dir_path, f'*.{ext}')) + print('No. images:', len(list_images)) + for img_path in tqdm(list_images): + load_groundtruth(img_path, json_dir, save_dir, predictor, processor, export_img) + +def Predictor_KVU(image_url: str, save_dir: str, predictor: KVUPredictor, processor) -> None: + + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + image_path = "./Kie_Invoice_AP/tmp_image/{image_url}.jpg" + cv2.imwrite(image_path, img) + vat_outputs = predict_image(image_path, save_dir, predictor, processor) + return vat_outputs + + +if __name__ == "__main__": + args = get_args() + class_names = ['others', 'title', 'key', 'value', 'header'] + predict_mode = { + 'normal': 0, + 'full_tokens': 1, + 'sliding': 2, + 'document': 3 + } + predictor, processor = load_engine(args.exp_dir, class_names, args.mode) + create_dir(args.save_dir) + image_path = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg" + save_dir = "/home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test" + predict_image(image_path, save_dir, predictor, processor) + print('[INFO] Done') \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/__init__.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier.py new file mode 100755 index 0000000..45b668f --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier.py @@ -0,0 +1,133 @@ +import time + +import torch +import torch.utils.data +from overrides import overrides +from pytorch_lightning import LightningModule +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.utilities.distributed import rank_zero_only +from torch.optim import SGD, Adam, AdamW +from torch.optim.lr_scheduler import LambdaLR + +from lightning_modules.schedulers import ( + cosine_scheduler, + linear_scheduler, + multistep_scheduler, +) +from model import get_model +from utils import cfg_to_hparams, get_specific_pl_logger + + +class ClassifierModule(LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.net = get_model(self.cfg) + self.ignore_index = -100 + + self.time_tracker = None + + self.optimizer_types = { + "sgd": SGD, + "adam": Adam, + "adamw": AdamW, + } + + @overrides + def setup(self, stage): + self.time_tracker 
= time.time() + + @overrides + def configure_optimizers(self): + optimizer = self._get_optimizer() + scheduler = self._get_lr_scheduler(optimizer) + scheduler = { + "scheduler": scheduler, + "name": "learning_rate", + "interval": "step", + } + return [optimizer], [scheduler] + + def _get_lr_scheduler(self, optimizer): + cfg_train = self.cfg.train + lr_schedule_method = cfg_train.optimizer.lr_schedule.method + lr_schedule_params = cfg_train.optimizer.lr_schedule.params + + if lr_schedule_method is None: + scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1) + elif lr_schedule_method == "step": + scheduler = multistep_scheduler(optimizer, **lr_schedule_params) + elif lr_schedule_method == "cosine": + total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch + total_batch_size = cfg_train.batch_size * self.trainer.world_size + max_iter = total_samples / total_batch_size + scheduler = cosine_scheduler( + optimizer, training_steps=max_iter, **lr_schedule_params + ) + elif lr_schedule_method == "linear": + total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch + total_batch_size = cfg_train.batch_size * self.trainer.world_size + max_iter = total_samples / total_batch_size + scheduler = linear_scheduler( + optimizer, training_steps=max_iter, **lr_schedule_params + ) + else: + raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}") + + return scheduler + + def _get_optimizer(self): + opt_cfg = self.cfg.train.optimizer + method = opt_cfg.method.lower() + + if method not in self.optimizer_types: + raise ValueError(f"Unknown optimizer method={method}") + + kwargs = dict(opt_cfg.params) + kwargs["params"] = self.net.parameters() + optimizer = self.optimizer_types[method](**kwargs) + + return optimizer + + @rank_zero_only + @overrides + def on_fit_end(self): + hparam_dict = cfg_to_hparams(self.cfg, {}) + metric_dict = {"metric/dummy": 0} + + tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger) + + if tb_logger: + tb_logger.log_hyperparams(hparam_dict, metric_dict) + + @overrides + def training_epoch_end(self, training_step_outputs): + avg_loss = torch.tensor(0.0).to(self.device) + for step_out in training_step_outputs: + avg_loss += step_out["loss"] + + log_dict = {"train_loss": avg_loss} + self._log_shell(log_dict, prefix="train ") + + def _log_shell(self, log_info, prefix=""): + log_info_shell = {} + for k, v in log_info.items(): + new_v = v + if type(new_v) is torch.Tensor: + new_v = new_v.item() + log_info_shell[k] = new_v + + out_str = prefix.upper() + if prefix.upper().strip() in ["TRAIN", "VAL"]: + out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]" + + if self.training: + lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"] + log_info_shell["lr"] = lr + + for key, value in log_info_shell.items(): + out_str += f" || {key}: {round(value, 5)}" + out_str += f" || time: {round(time.time() - self.time_tracker, 1)}" + out_str += " secs." 
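# Minimal sketch of how _get_lr_scheduler() above sizes the cosine/linear schedules
# (hypothetical config values): the step budget is the number of samples seen over
# training divided by the effective, world-size-scaled batch size.
max_epochs = 10
num_samples_per_epoch = 8000
batch_size, world_size = 8, 2
total_samples = max_epochs * num_samples_per_epoch    # 80000
total_batch_size = batch_size * world_size            # 16
max_iter = total_samples / total_batch_size           # 5000.0 scheduler steps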
+ self.print(out_str) + self.time_tracker = time.time() diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier_module.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier_module.py new file mode 100755 index 0000000..5786cc5 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/classifier_module.py @@ -0,0 +1,390 @@ +import numpy as np +import torch +import torch.utils.data +from overrides import overrides + +from lightning_modules.classifier import ClassifierModule +from utils import get_class_names + + +class KVUClassifierModule(ClassifierModule): + def __init__(self, cfg): + super().__init__(cfg) + + class_names = get_class_names(self.cfg.dataset_root_path) + + self.window_size = cfg.train.max_num_words + self.slice_interval = cfg.train.slice_interval + self.eval_kwargs = { + "class_names": class_names, + "dummy_idx": self.cfg.train.max_seq_length, # update dummy_idx in next step + } + self.stage = cfg.stage + + @overrides + def training_step(self, batch, batch_idx, *args): + if self.stage == 1: + _, loss = self.net(batch['windows']) + elif self.stage == 2: + _, loss = self.net(batch) + else: + raise ValueError( + f"Not supported stage: {self.stage}" + ) + + log_dict_input = {"train_loss": loss} + self.log_dict(log_dict_input, sync_dist=True) + return loss + + @torch.no_grad() + @overrides + def validation_step(self, batch, batch_idx, *args): + if self.stage == 1: + step_out_total = { + "loss": 0, + "ee":{ + "n_batch_gt": 0, + "n_batch_pr": 0, + "n_batch_correct": 0, + }, + "el":{ + "n_batch_gt": 0, + "n_batch_pr": 0, + "n_batch_correct": 0, + }, + "el_from_key":{ + "n_batch_gt": 0, + "n_batch_pr": 0, + "n_batch_correct": 0, + }} + for window in batch['windows']: + head_outputs, loss = self.net(window) + step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs) + for key in step_out_total: + if key == 'loss': + step_out_total[key] += step_out[key] + else: + for subkey in step_out_total[key]: + step_out_total[key][subkey] += step_out[key][subkey] + return step_out_total + + elif self.stage == 2: + head_outputs, loss = self.net(batch) + # self.eval_kwargs['dummy_idx'] = batch['itc_labels'].shape[1] + # step_out = do_eval_step(batch, head_outputs, loss, self.eval_kwargs) + self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1] + step_out = do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs) + return step_out + + @torch.no_grad() + @overrides + def validation_epoch_end(self, validation_step_outputs): + scores = do_eval_epoch_end(validation_step_outputs) + self.print( + f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}" + ) + self.print( + f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}" + ) + self.print( + f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}" + ) + self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.) 
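# Sketch with hypothetical scores: the 'val_f1' logged above is the unweighted mean of
# the three sub-task F1 scores (entity extraction, entity linking, linking-from-key),
# so checkpoint selection weights the tasks equally.
scores = {'ee': {'f1': 0.91}, 'el': {'f1': 0.84}, 'el_from_key': {'f1': 0.88}}
val_f1 = (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.
# val_f1 == 0.8766666666666667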
+ tensorboard_logs = {'val_precision_ee': scores['ee']['precision'], 'val_recall_ee': scores['ee']['recall'], 'val_f1_ee': scores['ee']['f1'], + 'val_precision_el': scores['el']['precision'], 'val_recall_el': scores['el']['recall'], 'val_f1_el': scores['el']['f1'], + 'val_precision_el_from_key': scores['el_from_key']['precision'], 'val_recall_el_from_key': scores['el_from_key']['recall'], \ + 'val_f1_el_from_key': scores['el_from_key']['f1'],} + return {'log': tensorboard_logs} + + +def do_eval_step(batch, head_outputs, loss, eval_kwargs): + class_names = eval_kwargs["class_names"] + dummy_idx = eval_kwargs["dummy_idx"] + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_labels = torch.argmax(itc_outputs, -1) + pr_stc_labels = torch.argmax(stc_outputs, -1) + pr_el_labels = torch.argmax(el_outputs, -1) + pr_el_labels_from_key = torch.argmax(el_outputs_from_key, -1) + + ( + n_batch_gt_classes, + n_batch_pr_classes, + n_batch_correct_classes, + ) = eval_ee_spade_batch( + pr_itc_labels, + batch["itc_labels"], + batch["are_box_first_tokens"], + pr_stc_labels, + batch["stc_labels"], + batch["attention_mask_layoutxlm"], + class_names, + dummy_idx, + ) + + n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch( + pr_el_labels, + batch["el_labels"], + batch["are_box_first_tokens"], + dummy_idx, + ) + + n_batch_gt_rel_from_key, n_batch_pr_rel_from_key, n_batch_correct_rel_from_key = eval_el_spade_batch( + pr_el_labels_from_key, + batch["el_labels_from_key"], + batch["are_box_first_tokens"], + dummy_idx, + ) + + step_out = { + "loss": loss, + "ee":{ + "n_batch_gt": n_batch_gt_classes, + "n_batch_pr": n_batch_pr_classes, + "n_batch_correct": n_batch_correct_classes, + }, + "el":{ + "n_batch_gt": n_batch_gt_rel, + "n_batch_pr": n_batch_pr_rel, + "n_batch_correct": n_batch_correct_rel, + }, + "el_from_key":{ + "n_batch_gt": n_batch_gt_rel_from_key, + "n_batch_pr": n_batch_pr_rel_from_key, + "n_batch_correct": n_batch_correct_rel_from_key, + } + + } + + return step_out + + +def eval_ee_spade_batch( + pr_itc_labels, + gt_itc_labels, + are_box_first_tokens, + pr_stc_labels, + gt_stc_labels, + attention_mask, + class_names, + dummy_idx, +): + n_batch_gt_classes, n_batch_pr_classes, n_batch_correct_classes = 0, 0, 0 + + bsz = pr_itc_labels.shape[0] + for example_idx in range(bsz): + n_gt_classes, n_pr_classes, n_correct_classes = eval_ee_spade_example( + pr_itc_labels[example_idx], + gt_itc_labels[example_idx], + are_box_first_tokens[example_idx], + pr_stc_labels[example_idx], + gt_stc_labels[example_idx], + attention_mask[example_idx], + class_names, + dummy_idx, + ) + + n_batch_gt_classes += n_gt_classes + n_batch_pr_classes += n_pr_classes + n_batch_correct_classes += n_correct_classes + + return ( + n_batch_gt_classes, + n_batch_pr_classes, + n_batch_correct_classes, + ) + + +def eval_ee_spade_example( + pr_itc_label, + gt_itc_label, + box_first_token_mask, + pr_stc_label, + gt_stc_label, + attention_mask, + class_names, + dummy_idx, +): + gt_first_words = parse_initial_words( + gt_itc_label, box_first_token_mask, class_names + ) + gt_class_words = parse_subsequent_words( + gt_stc_label, attention_mask, gt_first_words, dummy_idx + ) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, dummy_idx + ) + + 
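# Sketch of the SPADE-style exact-span matching used just below (hypothetical spans):
# every ground-truth/predicted field is a tuple of token indices, and a prediction only
# counts as correct when the whole tuple matches.
gt_spans = {(3, 4, 5), (10,), (21, 22)}
pr_spans = {(3, 4, 5), (10, 11), (21, 22)}
n_correct = len(gt_spans & pr_spans)    # 2 -- (10,) vs (10, 11) is a miss
precision = n_correct / len(pr_spans)   # 0.666...
recall = n_correct / len(gt_spans)      # 0.666...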
n_gt_classes, n_pr_classes, n_correct_classes = 0, 0, 0 + for class_idx in range(len(class_names)): + # Evaluate by ID + gt_parse = set(gt_class_words[class_idx]) + pr_parse = set(pr_class_words[class_idx]) + + n_gt_classes += len(gt_parse) + n_pr_classes += len(pr_parse) + n_correct_classes += len(gt_parse & pr_parse) + + return n_gt_classes, n_pr_classes, n_correct_classes + + +def parse_initial_words(itc_label, box_first_token_mask, class_names): + itc_label_np = itc_label.cpu().numpy() + box_first_token_mask_np = box_first_token_mask.cpu().numpy() + + outputs = [[] for _ in range(len(class_names))] + + for token_idx, label in enumerate(itc_label_np): + if box_first_token_mask_np[token_idx] and label != 0: + outputs[label].append(token_idx) + + return outputs + + +def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx): + max_connections = 50 + + valid_stc_label = stc_label * attention_mask.bool() + valid_stc_label = valid_stc_label.cpu().numpy() + stc_label_np = stc_label.cpu().numpy() + + valid_token_indices = np.where( + (valid_stc_label != dummy_idx) * (valid_stc_label != 0) + ) + + next_token_idx_dict = {} + for token_idx in valid_token_indices[0]: + next_token_idx_dict[stc_label_np[token_idx]] = token_idx + + outputs = [] + for init_token_indices in init_words: + sub_outputs = [] + for init_token_idx in init_token_indices: + cur_token_indices = [init_token_idx] + for _ in range(max_connections): + if cur_token_indices[-1] in next_token_idx_dict: + if ( + next_token_idx_dict[cur_token_indices[-1]] + not in init_token_indices + ): + cur_token_indices.append( + next_token_idx_dict[cur_token_indices[-1]] + ) + else: + break + else: + break + sub_outputs.append(tuple(cur_token_indices)) + + outputs.append(sub_outputs) + + return outputs + + +def eval_el_spade_batch( + pr_el_labels, + gt_el_labels, + are_box_first_tokens, + dummy_idx, +): + n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = 0, 0, 0 + + bsz = pr_el_labels.shape[0] + for example_idx in range(bsz): + n_gt_rel, n_pr_rel, n_correct_rel = eval_el_spade_example( + pr_el_labels[example_idx], + gt_el_labels[example_idx], + are_box_first_tokens[example_idx], + dummy_idx, + ) + + n_batch_gt_rel += n_gt_rel + n_batch_pr_rel += n_pr_rel + n_batch_correct_rel += n_correct_rel + + return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel + + +def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx): + gt_relations = parse_relations(gt_el_label, box_first_token_mask, dummy_idx) + pr_relations = parse_relations(pr_el_label, box_first_token_mask, dummy_idx) + + gt_relations = set(gt_relations) + pr_relations = set(pr_relations) + + n_gt_rel = len(gt_relations) + n_pr_rel = len(pr_relations) + n_correct_rel = len(gt_relations & pr_relations) + + return n_gt_rel, n_pr_rel, n_correct_rel + + +def parse_relations(el_label, box_first_token_mask, dummy_idx): + valid_el_labels = el_label * box_first_token_mask + valid_el_labels = valid_el_labels.cpu().numpy() + el_label_np = el_label.cpu().numpy() + + max_token = box_first_token_mask.shape[0] - 1 + + valid_token_indices = np.where( + ((valid_el_labels != dummy_idx) * (valid_el_labels != 0)) ### + ) + + link_map_tuples = [] + for token_idx in valid_token_indices[0]: + link_map_tuples.append((el_label_np[token_idx], token_idx)) + + return set(link_map_tuples) + +def do_eval_epoch_end(step_outputs): + scores = {} + for task in ['ee', 'el', 'el_from_key']: + n_total_gt_classes, n_total_pr_classes, n_total_correct_classes = 0, 0, 0 + + for 
step_out in step_outputs: + n_total_gt_classes += step_out[task]["n_batch_gt"] + n_total_pr_classes += step_out[task]["n_batch_pr"] + n_total_correct_classes += step_out[task]["n_batch_correct"] + + precision = ( + 0.0 if n_total_pr_classes == 0 else n_total_correct_classes / n_total_pr_classes + ) + recall = ( + 0.0 if n_total_gt_classes == 0 else n_total_correct_classes / n_total_gt_classes + ) + f1 = ( + 0.0 + if recall * precision == 0 + else 2.0 * recall * precision / (recall + precision) + ) + + scores[task] = { + "precision": precision, + "recall": recall, + "f1": f1, + } + + return scores + + +def get_eval_kwargs_spade(dataset_root_path, max_seq_length): + class_names = get_class_names(dataset_root_path) + dummy_idx = max_seq_length + + eval_kwargs = {"class_names": class_names, "dummy_idx": dummy_idx} + + return eval_kwargs + + +def get_eval_kwargs_spade_rel(max_seq_length): + dummy_idx = max_seq_length + + eval_kwargs = {"dummy_idx": dummy_idx} + + return eval_kwargs \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/__init__.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/data_module.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/data_module.py new file mode 100755 index 0000000..1b9a255 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/data_module.py @@ -0,0 +1,135 @@ +import time + +import torch +import pytorch_lightning as pl +from overrides import overrides +from torch.utils.data.dataloader import DataLoader + +from lightning_modules.data_modules.kvu_dataset import KVUDataset, KVUEmbeddingDataset +from lightning_modules.utils import _get_number_samples + +class KVUDataModule(pl.LightningDataModule): + def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor): + super().__init__() + self.cfg = cfg + self.train_loader = None + self.val_loader = None + self.tokenizer_layoutxlm = tokenizer_layoutxlm + self.feature_extractor = feature_extractor + self.collate_fn = None + + self.backbone_type = self.cfg.model.backbone + self.num_samples_per_epoch = _get_number_samples(cfg.dataset_root_path) + + + @overrides + def setup(self, stage=None): + self.train_loader = self._get_train_loader() + self.val_loader = self._get_val_loaders() + + @overrides + def train_dataloader(self): + return self.train_loader + + @overrides + def val_dataloader(self): + return self.val_loader + + def _get_train_loader(self): + start_time = time.time() + + if self.cfg.stage == 1: + dataset = KVUDataset( + self.cfg, + self.tokenizer_layoutxlm, + self.feature_extractor, + mode="train", + ) + elif self.cfg.stage == 2: + # dataset = KVUEmbeddingDataset( + # self.cfg, + # mode="train", + # ) + dataset = KVUDataset( + self.cfg, + self.tokenizer_layoutxlm, + self.feature_extractor, + mode="train", + ) + else: + raise ValueError( + f"Not supported stage: {self.cfg.stage}" + ) + + print('No. 
training samples:', len(dataset)) + + data_loader = DataLoader( + dataset, + batch_size=self.cfg.train.batch_size, + shuffle=True, + num_workers=self.cfg.train.num_workers, + pin_memory=True, + ) + + elapsed_time = time.time() - start_time + print(f"Elapsed time for loading training data: {elapsed_time}") + + return data_loader + + def _get_val_loaders(self): + + if self.cfg.stage == 1: + dataset = KVUDataset( + self.cfg, + self.tokenizer_layoutxlm, + self.feature_extractor, + mode="val", + ) + elif self.cfg.stage == 2: + # dataset = KVUEmbeddingDataset( + # self.cfg, + # mode="val", + # ) + dataset = KVUDataset( + self.cfg, + self.tokenizer_layoutxlm, + self.feature_extractor, + mode="val", + ) + else: + raise ValueError( + f"Not supported stage: {self.cfg.stage}" + ) + + print('No. validation samples:', len(dataset)) + + data_loader = DataLoader( + dataset, + batch_size=self.cfg.val.batch_size, + shuffle=False, + num_workers=self.cfg.val.num_workers, + pin_memory=True, + drop_last=False, + ) + + return data_loader + + @overrides + def transfer_batch_to_device(self, batch, device, dataloader_idx): + if isinstance(batch, list): + for sub_batch in batch: + for k in sub_batch.keys(): + if isinstance(sub_batch[k], torch.Tensor): + sub_batch[k] = sub_batch[k].to(device) + else: + for k in batch.keys(): + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device) + return batch + + + + + + + diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/kvu_dataset.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/kvu_dataset.py new file mode 100755 index 0000000..e15ec64 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/data_modules/kvu_dataset.py @@ -0,0 +1,728 @@ +import os +import json +import omegaconf +import itertools +import copy + +import numpy as np +from PIL import Image + +import torch +from torch.utils.data.dataset import Dataset + +from utils import get_class_names +from lightning_modules.utils import sliding_windows_by_words + +class KVUDataset(Dataset): + def __init__( + self, + cfg, + tokenizer_layoutxlm, + feature_extractor, + mode=None, + ): + super(KVUDataset, self).__init__() + + self.dataset_root_path = cfg.dataset_root_path + if not isinstance(self.dataset_root_path, omegaconf.listconfig.ListConfig): + self.dataset_root_path = [self.dataset_root_path] + + self.backbone_type = cfg.model.backbone + self.max_seq_length = cfg.train.max_seq_length + self.window_size = cfg.train.max_num_words + self.slice_interval = cfg.train.slice_interval + + self.tokenizer_layoutxlm = tokenizer_layoutxlm + self.feature_extractor = feature_extractor + + self.stage = cfg.stage + self.mode = mode + + self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._pad_token) + self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._cls_token) + self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._sep_token) + self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token) + + self.examples = self._load_examples() + + self.class_names = get_class_names(self.dataset_root_path) + self.class_idx_dic = dict( + [(class_name, idx) for idx, class_name in enumerate(self.class_names)] + ) + + def _load_examples(self): + examples = [] + for dataset_dir in self.dataset_root_path: + with open( + os.path.join(dataset_dir, 
f"preprocessed_files_{self.mode}.txt"), + "r", + encoding="utf-8", + ) as fp: + for line in fp.readlines(): + preprocessed_file = os.path.join(dataset_dir, line.strip()) + examples.append( + json.load(open(preprocessed_file, "r", encoding="utf-8")) + ) + + return examples + + def __len__(self): + return len(self.examples) + + + def __getitem__(self, index): + json_obj = self.examples[index] + + width = json_obj["meta"]["imageSize"]["width"] + height = json_obj["meta"]["imageSize"]["height"] + img_path = json_obj["meta"]["image_path"] + + images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")] + image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + + word_windows, parse_class_windows, parse_relation_windows = sliding_windows_by_words( + json_obj["words"], + json_obj['parse']['class'], + json_obj['parse']['relations'], + self.window_size, self.slice_interval) + outputs = {} + if self.stage == 1: + if self.mode == 'train': + i = np.random.randint(0, len(word_windows), 1)[0] + outputs['windows'] = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i], + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + else: + outputs['windows'] = [] + for i in range(len(word_windows)): + single_window = self.preprocess(word_windows[i], parse_class_windows[i], parse_relation_windows[i], + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + outputs['windows'].append(single_window) + + elif self.stage == 2: + outputs['documents'] = self.preprocess(json_obj["words"], json_obj['parse']['class'], json_obj['parse']['relations'], + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=2048) + + windows = [] + for i in range(len(word_windows)): + _words = word_windows[i] + _parse_class = parse_class_windows[i] + _parse_relation = parse_relation_windows[i] + windows.append( + self.preprocess(_words, _parse_class, _parse_relation, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + self.max_seq_length) + ) + outputs['windows'] = windows + return outputs + + def preprocess(self, words, parse_class, parse_relation, feature_maps, max_seq_length): + input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm + + attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int) + + bbox = np.zeros((max_seq_length, 8), dtype=np.float32) + + are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_) + + itc_labels = np.zeros(max_seq_length, dtype=int) + stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length + el_labels = np.ones((max_seq_length,), dtype=int) * max_seq_length + el_labels_from_key = np.ones((max_seq_length,), dtype=int) * max_seq_length + list_layoutxlm_tokens = [] + + list_bbs = [] + box2token_span_map = [] + + + box_to_token_indices = [] + cum_token_idx = 0 + + cls_bbs = [0.0] * 8 + len_overlap_tokens = 0 + len_non_overlap_tokens = 0 + len_valid_tokens = 0 + + for word_idx, word in enumerate(words): + this_box_token_indices = [] + + layoutxlm_tokens = word["layoutxlm_tokens"] + bb = word["boundingBox"] + len_valid_tokens += len(layoutxlm_tokens) + if word_idx < self.slice_interval: + len_non_overlap_tokens += len(layoutxlm_tokens) + # print(word_idx, layoutxlm_tokens, non_overlap_tokens) + + if len(layoutxlm_tokens) == 0: + 
layoutxlm_tokens.append(self.unk_token_id_layoutxlm) + + if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2: + break + + box2token_span_map.append( + [len(list_layoutxlm_tokens) + 1, len(list_layoutxlm_tokens) + len(layoutxlm_tokens) + 1] + ) # including st_idx + list_layoutxlm_tokens += layoutxlm_tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width'])) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height'])) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(layoutxlm_tokens))] + + for _ in layoutxlm_tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [feature_maps['width'], feature_maps['height']] * 4 + + # For [CLS] and [SEP] + list_layoutxlm_tokens = ( + [self.cls_token_id_layoutxlm] + + list_layoutxlm_tokens[: max_seq_length - 2] + + [self.sep_token_id_layoutxlm] + ) + + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs] + + len_list_layoutxlm_tokens = len(list_layoutxlm_tokens) + input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens + attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1 + + bbox[:len_list_layoutxlm_tokens, :] = list_bbs + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width'] + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height'] + + if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"): + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + else: + assert False + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < max_seq_length + ] + are_box_first_tokens[st_indices] = True + + # Label for entity extraction + classes_dic = parse_class + for class_name in self.class_names: + if class_name == "others": + continue + if class_name not in classes_dic: + continue + + for word_list in classes_dic[class_name]: + is_first, last_word_idx = True, -1 + for word_idx in word_list: + if word_idx >= len(box_to_token_indices): + break + box2token_list = box_to_token_indices[word_idx] + for converted_word_idx in box2token_list: + if converted_word_idx >= max_seq_length: + break # out of idx + + if is_first: + itc_labels[converted_word_idx] = self.class_idx_dic[ + class_name + ] + is_first, last_word_idx = False, converted_word_idx + else: + stc_labels[converted_word_idx] = last_word_idx + last_word_idx = converted_word_idx + + + # Label for entity linking + relations = parse_relation + for relation in relations: + if relation[0] >= len(box2token_span_map) or relation[1] >= len( + box2token_span_map + ): + continue + if ( + box2token_span_map[relation[0]][0] >= max_seq_length + or box2token_span_map[relation[1]][0] >= max_seq_length + ): + continue + + word_from = box2token_span_map[relation[0]][0] + word_to = box2token_span_map[relation[1]][0] + # el_labels[word_to] = word_from + + + #### 1st relation => ['key, 'value'] + #### 2st relation => ['header', 'key'or'value'] + if itc_labels[word_from] == 2 and itc_labels[word_to] == 3: + el_labels_from_key[word_to] = word_from # pair of (key-value) + if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): + el_labels[word_to] = word_from # pair of (header, key) or 
(header-value) + + assert len_list_layoutxlm_tokens == len_valid_tokens + 2 + len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens + # overlap_tokens = max_seq_length - non_overlap_tokens + ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2 + + # ntokens = max_seq_length + + input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm[:ntokens]) + + attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm[:ntokens]) + + bbox = torch.from_numpy(bbox[:ntokens]) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens[:ntokens]) + + itc_labels = itc_labels[:ntokens] + stc_labels = stc_labels[:ntokens] + el_labels = el_labels[:ntokens] + el_labels_from_key = el_labels_from_key[:ntokens] + + itc_labels = np.where(itc_labels != max_seq_length, itc_labels, ntokens) + stc_labels = np.where(stc_labels != max_seq_length, stc_labels, ntokens) + el_labels = np.where(el_labels != max_seq_length, el_labels, ntokens) + el_labels_from_key = np.where(el_labels_from_key != max_seq_length, el_labels_from_key, ntokens) + + itc_labels = torch.from_numpy(itc_labels) + stc_labels = torch.from_numpy(stc_labels) + el_labels = torch.from_numpy(el_labels) + el_labels_from_key = torch.from_numpy(el_labels_from_key) + + + return_dict = { + "img_path": feature_maps['img_path'], + "len_overlap_tokens": len_overlap_tokens, + 'len_valid_tokens': len_valid_tokens, + "image": feature_maps['image'], + "input_ids_layoutxlm": input_ids_layoutxlm, + "attention_mask_layoutxlm": attention_mask_layoutxlm, + "bbox": bbox, + "itc_labels": itc_labels, + "are_box_first_tokens": are_box_first_tokens, + "stc_labels": stc_labels, + "el_labels": el_labels, + "el_labels_from_key": el_labels_from_key, + } + return return_dict + + + +class KVUPredefinedDataset(KVUDataset): + def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor, mode=None): + super().__init__(cfg, tokenizer_layoutxlm, feature_extractor, mode) + self.max_windows = cfg.train.max_windows + + def __getitem__(self, index): + json_obj = self.examples[index] + + width = json_obj["meta"]["imageSize"]["width"] + height = json_obj["meta"]["imageSize"]["height"] + img_path = json_obj["meta"]["image_path"] + + images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")] + image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + + + word_windows, parse_class_windows, parse_relation_windows = sliding_windows_by_words( + json_obj["words"], + json_obj['parse']['class'], + json_obj['parse']['relations'], + self.window_size, self.slice_interval) + + word_windows = word_windows[: self.max_windows] if len(word_windows) >= self.max_windows else word_windows + [[]] * (self.max_windows - len(word_windows)) + parse_class_windows = parse_class_windows[: self.max_windows] if len(parse_class_windows) >= self.max_windows else parse_class_windows + [[]] * (self.max_windows - len(parse_class_windows)) + parse_relation_windows = parse_relation_windows[: self.max_windows] if len(parse_relation_windows) >= self.max_windows else parse_relation_windows + [[]] * (self.max_windows - len(parse_relation_windows)) + + + outputs = {} + # outputs['labels'] = self.preprocess(json_obj["words"], json_obj['parse']['class'], json_obj['parse']['relations'], + # {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + # max_seq_length=self.max_seq_length*self.max_windows) + + outputs['windows'] = [] + for i in range(len(self.max_windows)): + single_window = self.preprocess(word_windows[i], 
parse_class_windows[i], parse_relation_windows[i], + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + outputs['windows'].append(single_window) + + outputs['windows'] = torch.cat(outputs["windows"], dim=0) + + return outputs + + +class KVUEmbeddingDataset(Dataset): + def __init__( + self, + cfg, + mode=None, + ): + super(KVUEmbeddingDataset, self).__init__() + + self.dataset_root_path = cfg.dataset_root_path + if not isinstance(self.dataset_root_path, omegaconf.listconfig.ListConfig): + self.dataset_root_path = [self.dataset_root_path] + + self.stage = cfg.stage + self.mode = mode + + self.examples = self._load_examples() + + def _load_examples(self): + examples = [] + for dataset_dir in self.dataset_root_path: + with open( + os.path.join(dataset_dir, f"preprocessed_files_{self.mode}.txt"), + "r", + encoding="utf-8", + ) as fp: + for line in fp.readlines(): + preprocessed_file = os.path.join(dataset_dir, line.strip()) + examples.append( + preprocessed_file.replace('preprocessed', 'embedding_matrix').replace('.json', '.npz') + ) + return examples + + def __len__(self): + return len(self.examples) + + def __getitem__(self, index): + input_embeddings = np.load(self.examples[index]) + + return_dict = { + "embeddings": torch.from_numpy(input_embeddings["embeddings"]).type(torch.HalfTensor), + "attention_mask_layoutxlm": torch.from_numpy(input_embeddings["attention_mask_layoutxlm"]), + "are_box_first_tokens": torch.from_numpy(input_embeddings["are_box_first_tokens"]), + "bbox": torch.from_numpy(input_embeddings["bbox"]), + "itc_labels": torch.from_numpy(input_embeddings["itc_labels"]), + "stc_labels": torch.from_numpy(input_embeddings["stc_labels"]), + "el_labels": torch.from_numpy(input_embeddings["el_labels"]), + "el_labels_from_key": torch.from_numpy(input_embeddings["el_labels_from_key"]), + } + return return_dict + + +class DocumentKVUDataset(KVUDataset): + def __init__(self, cfg, tokenizer_layoutxlm, feature_extractor, mode=None): + super().__init__(cfg, tokenizer_layoutxlm, feature_extractor, mode) + self.self.max_window_count = cfg.train.max_window_count + + def __getitem__(self, idx): + json_obj = self.examples[idx] + + width = json_obj["meta"]["imageSize"]["width"] + height = json_obj["meta"]["imageSize"]["height"] + + images = [Image.open(json_obj["meta"]["image_path"]).convert("RGB")] + feature_maps = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + + n_words = len(json_obj['words']) + output_dicts = {'windows': [], 'documents': []} + box_to_token_indices_document = [] + box2token_span_map_document = [] + n_empty_windows = 0 + + for i in range(self.max_window_count): + input_ids = np.ones(self.max_seq_length, dtype=int) * 1 + bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(self.max_seq_length, dtype=int) + + itc_labels = np.zeros(self.max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_) + + # stc_labels stores the index of the previous token. + # A stored index of max_seq_length (512) indicates that + # this token is the initial token of a word box. 
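# Toy illustration of the convention described above (max_seq_length shrunk to 8):
# a word box covering tokens [2, 3, 4] leaves token 2 at the sentinel value, while
# tokens 3 and 4 point back to their predecessor.
import numpy as np
toy_max_seq_length = 8
toy_stc = np.ones(toy_max_seq_length, dtype=np.int64) * toy_max_seq_length
toy_stc[3] = 2   # token 3 continues the box started at token 2
toy_stc[4] = 3   # token 4 continues from token 3
# toy_stc -> array([8, 8, 8, 2, 3, 8, 8, 8])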
+ stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length + el_labels = np.ones((self.max_seq_length,), dtype=int) * self.max_seq_length + el_labels_from_key = np.ones((self.max_seq_length,), dtype=int) * self.max_seq_length + + start_word_idx = i * self.window_size + stop_word_idx = min(n_words, (i+1)*self.window_size) + + if start_word_idx >= stop_word_idx: + n_empty_windows += 1 + output_dicts.append(output_dicts[-1]) + + box_to_token_indices_to_mod = copy.deepcopy(box_to_token_indices) + for i_box in range(len(box_to_token_indices_to_mod)): + for j in range(len(box_to_token_indices_to_mod[i_box])): + box_to_token_indices_to_mod[i_box][j] += i * self.max_seq_length + for element in box_to_token_indices_to_mod: + box_to_token_indices_document.append(element) + + box2token_span_map_to_mod = copy.deepcopy(box2token_span_map) + for i_box in range(len(box2token_span_map_to_mod)): + for j in range(len(box2token_span_map_to_mod[i_box])): + box2token_span_map_to_mod[i_box][j] += i * self.max_seq_length + for element in box2token_span_map_to_mod: + box2token_span_map_document.append(element) + + continue + + list_tokens = [] + list_bbs = [] + box2token_span_map = [] + + box_to_token_indices = [] + cum_token_idx = 0 + + cls_bbs = [0.0] * 8 + + # Parse words + for word_idx, word in enumerate(json_obj["words"][start_word_idx:stop_word_idx]): + this_box_token_indices = [] + tokens = word["tokens"] + bb = word["boundingBox"] + if len(tokens) == 0: + tokens.append(self.unk_token_id) + + if len(list_tokens) + len(tokens) > self.max_seq_length - 2: + break ### be able to apply sliding window here + + box2token_span_map.append( + [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1] + ) # including st_idx + list_tokens += tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width)) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height)) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(tokens))] + + for _ in tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [width, height] * 4 + + # For [CLS] and [SEP] + list_tokens = ( + [self.cls_token_id] + + list_tokens[: self.max_seq_length - 2] + + [self.sep_token_id] + ) + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs] + + len_list_tokens = len(list_tokens) + input_ids[:len_list_tokens] = list_tokens + attention_mask[:len_list_tokens] = 1 + + bbox[:len_list_tokens, :] = list_bbs + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height + + if self.backbone_type in ('layoutlm', 'layoutxlm', 'xlm-roberta'): + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + else: + assert False + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < self.max_seq_length + ] + are_box_first_tokens[st_indices] = True + + # Parse word groups + classes_dic = json_obj["parse"]["class"] + for class_name in self.class_names: + if class_name == "others": + continue + if class_name not in classes_dic: + continue + + for word_list in classes_dic[class_name]: + word_list = [w for w in word_list if w >= start_word_idx and w < stop_word_idx] + if len(word_list) == 0: 
+ continue # no more word left + word_list = [w - start_word_idx for w in word_list] + + is_first, last_word_idx = True, -1 + for word_idx in word_list: + if word_idx >= len(box_to_token_indices): + break + box2token_list = box_to_token_indices[word_idx] + for converted_word_idx in box2token_list: + if converted_word_idx >= self.max_seq_length: + break # out of idx + + if is_first: + itc_labels[converted_word_idx] = self.class_idx_dic[ + class_name + ] + is_first, last_word_idx = False, converted_word_idx + else: + stc_labels[converted_word_idx] = last_word_idx + last_word_idx = converted_word_idx + + # Parse relation + relations = json_obj["parse"]["relations"] + for relation in relations: + relation = [r for r in relation if r >= start_word_idx and r < stop_word_idx] + if len(relation) != 2: + continue # relation popped due to window inconsistent + relation[0] -= start_word_idx + relation[1] -= start_word_idx + + word_from = box2token_span_map[relation[0]][0] + word_to = box2token_span_map[relation[1]][0] + + #### 1st relation => ['key, 'value'] + #### 2st relation => ['header', 'key'or'value'] + if itc_labels[word_from] == 2 and itc_labels[word_to] == 3: + el_labels_from_key[word_to] = word_from # pair of (key-value) + if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): + el_labels[word_to] = word_from # pair of (header, key) or (header-value) + + input_ids = torch.from_numpy(input_ids) + bbox = torch.from_numpy(bbox) + attention_mask = torch.from_numpy(attention_mask) + + itc_labels = torch.from_numpy(itc_labels) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + stc_labels = torch.from_numpy(stc_labels) + el_labels = torch.from_numpy(el_labels) + el_labels_from_key = torch.from_numpy(el_labels_from_key) + + box_to_token_indices_to_mod = copy.deepcopy(box_to_token_indices) + for i_box in range(len(box_to_token_indices_to_mod)): + for j in range(len(box_to_token_indices_to_mod[i_box])): + box_to_token_indices_to_mod[i_box][j] += i * self.max_seq_length + for element in box_to_token_indices_to_mod: + box_to_token_indices_document.append(element) + + box2token_span_map_to_mod = copy.deepcopy(box2token_span_map) + for i_box in range(len(box2token_span_map_to_mod)): + for j in range(len(box2token_span_map_to_mod[i_box])): + box2token_span_map_to_mod[i_box][j] += i * self.max_seq_length + for element in box2token_span_map_to_mod: + box2token_span_map_document.append(element) + + return_dict = { + "image": feature_maps, + "input_ids": input_ids, + "bbox": bbox, + "attention_mask": attention_mask, + "itc_labels": itc_labels, + "are_box_first_tokens": are_box_first_tokens, + "stc_labels": stc_labels, + "el_labels": el_labels, + "el_labels_from_key": el_labels_from_key, + } + + output_dicts["windows"].append(return_dict) + + + # Parse whole document labels + attention_mask = torch.cat([o['attention_mask'] for o in output_dicts]) + are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts]) + if n_empty_windows > 0: + attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int)) + are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_)) + bbox = torch.cat([o['bbox'] for o in output_dicts]) + + self.max_seq_length_document = self.max_seq_length * self.max_window_count + itc_labels = np.zeros(self.max_seq_length_document, dtype=int) + 
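# Sketch (hypothetical values) of the index shift used when the per-window
# box_to_token_indices above are merged into the document-level label arrays built
# here: indices local to window i are offset by i * max_seq_length so they stay valid
# after the 512-token windows are laid out back to back.
max_seq_length = 512
window_index = 2
window_local_box = [1, 2, 3]                                   # token indices inside window 2
document_level_box = [t + window_index * max_seq_length for t in window_local_box]
# document_level_box == [1025, 1026, 1027]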
stc_labels = np.ones(self.max_seq_length_document, dtype=np.int64) * self.max_seq_length_document + el_labels = np.ones((self.max_seq_length_document,), dtype=int) * self.max_seq_length_document + el_labels_from_key = np.ones((self.max_seq_length_document,), dtype=int) * self.max_seq_length_document + + # Parse word groups + classes_dic = json_obj["parse"]["class"] + for class_name in self.class_names: + if class_name == "others": + continue + if class_name not in classes_dic: + continue + + word_lists = classes_dic[class_name] + + for word_list in word_lists: + is_first, last_word_idx = True, -1 + for word_idx in word_list: + if word_idx >= len(box_to_token_indices_document): + break + box2token_list = box_to_token_indices_document[word_idx] + for converted_word_idx in box2token_list: + if converted_word_idx >= self.max_seq_length_document: + break # out of idx + + if is_first: + itc_labels[converted_word_idx] = self.class_idx_dic[ + class_name + ] + is_first, last_word_idx = False, converted_word_idx + else: + stc_labels[converted_word_idx] = last_word_idx + last_word_idx = converted_word_idx + + # Parse relation + relations = json_obj["parse"]["relations"] + + for relation in relations: + if relation[0] >= len(box2token_span_map_document) or relation[1] >= len( + box2token_span_map_document + ): + continue + if ( + box2token_span_map_document[relation[0]][0] >= self.max_seq_length_document + or box2token_span_map_document[relation[1]][0] >= self.max_seq_length_document + ): + continue + + word_from = box2token_span_map_document[relation[0]][0] + word_to = box2token_span_map_document[relation[1]][0] + + if itc_labels[word_from] == 2 and itc_labels[word_to] == 3: + el_labels_from_key[word_to] = word_from # pair of (key-value) + if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): + el_labels[word_to] = word_from # pair of (header, key) or (header-value) + + itc_labels = torch.from_numpy(itc_labels) + stc_labels = torch.from_numpy(stc_labels) + el_labels = torch.from_numpy(el_labels) + el_labels_from_key = torch.from_numpy(el_labels_from_key) + + return_dict = { + "img_path": json_obj["meta"]["image_path"], + "attention_mask": attention_mask, + "bbox": bbox, + "itc_labels": itc_labels, + "are_box_first_tokens": are_box_first_tokens, + "stc_labels": stc_labels, + "el_labels": el_labels, + "el_labels_from_key": el_labels_from_key, + "n_empty_windows": n_empty_windows + } + + output_dicts['documents'] = return_dict + return output_dicts \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/schedulers.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/schedulers.py new file mode 100755 index 0000000..b49abc2 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/schedulers.py @@ -0,0 +1,53 @@ +""" +BROS +Copyright 2022-present NAVER Corp. 
+Apache License v2.0 +""" + +import math + +import numpy as np +from torch.optim.lr_scheduler import LambdaLR + + +def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1): + """linear_scheduler with warmup from huggingface""" + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 0.0, + float(training_steps - current_step) + / float(max(1, training_steps - warmup_steps)), + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def cosine_scheduler( + optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1 +): + """Cosine LR scheduler with warmup from huggingface""" + + def lr_lambda(current_step): + if current_step < warmup_steps: + return current_step / max(1, warmup_steps) + progress = current_step - warmup_steps + progress /= max(1, training_steps - warmup_steps) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1): + def lr_lambda(current_step): + if current_step < warmup_steps: + # calculate a warmup ratio + return current_step / max(1, warmup_steps) + else: + # calculate a multistep lr scaling ratio + idx = np.searchsorted(milestones, current_step) + return gamma ** idx + + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/utils.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/utils.py new file mode 100755 index 0000000..92cdd5e --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/lightning_modules/utils.py @@ -0,0 +1,218 @@ +import os +import copy +import numpy as np +import torch +import torch.nn as nn +import math + + +def _get_number_samples(dataset_root_path): + n_samples = 0 + for dataset_dir in dataset_root_path: + with open( + os.path.join(dataset_dir, f"preprocessed_files_train.txt"), + "r", + encoding="utf-8", + ) as fp: + for line in fp.readlines(): + n_samples += 1 + return n_samples + + +def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list: + element_windows = [] + + if len(elements) > window_size: + max_step = math.ceil((len(elements) - window_size)/slice_interval) + + for i in range(0, max_step + 1): + # element_windows.append(copy.deepcopy(elements[min(i, len(elements) - window_size): min(i+window_size, len(elements))])) + if (i*slice_interval+window_size) >= len(elements): + _window = copy.deepcopy(elements[i*slice_interval:]) + else: + _window = copy.deepcopy(elements[i*slice_interval: i*slice_interval+window_size]) + element_windows.append(_window) + return element_windows + else: + return [elements] + +def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> list: + word_windows = [] + parse_class_windows = [] + parse_relation_windows = [] + + if len(lwords) > window_size: + max_step = math.ceil((len(lwords) - window_size)/slice_interval) + for i in range(0, max_step+1): + # _word_window = copy.deepcopy(lwords[min(i*slice_interval, len(lwords) - window_size): min(i*slice_interval+window_size, len(lwords))]) + if (i*slice_interval+window_size) >= len(lwords): + _word_window = copy.deepcopy(lwords[i*slice_interval:]) + else: + _word_window = copy.deepcopy(lwords[i*slice_interval: i*slice_interval+window_size]) + + if len(_word_window) < 2: + continue + + first_word_id = 
_word_window[0]['word_id'] + last_word_id = _word_window[-1]['word_id'] + + # assert (last_word_id - first_word_id == window_size - 1) or (first_word_id == 0 and last_word_id == len(lwords) - 1), [v['word_id'] for v in _word_window] #(last_word_id,first_word_id,len(lwords)) + # word list + for _word in _word_window: + _word['word_id'] -= first_word_id + + + # Entity extraction + _class_window = {k: [] for k in list(parse_class.keys())} + for class_name, _parse_class in parse_class.items(): + for group in _parse_class: + tmp = [] + for idw in group: + idw -= first_word_id + if 0 <= idw <= (last_word_id - first_word_id): + tmp.append(idw) + _class_window[class_name].append(tmp) + + # Entity Linking + _relation_window = [] + for pair in parse_relation: + if all([0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair]): + _relation_window.append([idw - first_word_id for idw in pair]) + + word_windows.append(_word_window) + parse_class_windows.append(_class_window) + parse_relation_windows.append(_relation_window) + + return word_windows, parse_class_windows, parse_relation_windows + else: + return [lwords], [parse_class], [parse_relation] + +def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor: + start_pos = 1 + end_pos = start_pos + lvalids[0] + embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...]) + cls_token = copy.deepcopy(lpatches[0][:, :1, ...]) + sep_token = copy.deepcopy(lpatches[0][:, -1:, ...]) + + for i in range(1, len(lpatches)): + start_pos = 1 + end_pos = start_pos + lvalids[i] + + overlap_gap = copy.deepcopy(loverlaps[i-1]) + window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...]) + + if overlap_gap != 0: + prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...]) + curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...]) + assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}" + + if average: + avg_overlap = ( + prev_overlap + curr_overlap + ) / 2. + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens, window], dim=1 + ) + return torch.cat([cls_token, embedding_tokens, sep_token], dim=1) + + + +def merged_token_embeddings2(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor: + start_pos = 1 + end_pos = start_pos + lvalids[0] + embedding_tokens = lpatches[0][:, start_pos:end_pos, ...] + cls_token = lpatches[0][:, :1, ...] + sep_token = lpatches[0][:, -1:, ...] + + for i in range(1, len(lpatches)): + start_pos = 1 + end_pos = start_pos + lvalids[i] + + overlap_gap = loverlaps[i-1] + window = lpatches[i][:, start_pos:end_pos, ...] + + if overlap_gap != 0: + prev_overlap = embedding_tokens[:, -overlap_gap:, ...] + curr_overlap = window[:, :overlap_gap, ...] + assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}" + + if average: + avg_overlap = ( + prev_overlap + curr_overlap + ) / 2. 
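# Toy illustration of this averaging step: where two consecutive sliding windows cover
# the same tokens, the merged sequence takes the element-wise mean of the two windows'
# embeddings over the overlap region.
import torch
prev_overlap_demo = torch.tensor([[[1.0], [3.0]]])   # last 2 overlapping tokens, previous window
curr_overlap_demo = torch.tensor([[[2.0], [5.0]]])   # first 2 tokens, current window
avg_demo = (prev_overlap_demo + curr_overlap_demo) / 2.   # tensor([[[1.5000], [4.0000]]])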
+ embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens, window], dim=1 + ) + return torch.cat([cls_token, embedding_tokens, sep_token], dim=1) + + + +# def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor: +# # start_pos = 1 +# # end_pos = start_pos + lvalids[0] +# # embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...]) +# embedding_tokens = np.zeros((1, 2046, 768)) +# cls_token = copy.deepcopy(lpatches[0][:, :1, ...]) +# sep_token = copy.deepcopy(lpatches[0][:, -1:, ...]) + +# token_apperance_count = np.zeros((2046,)) +# start_idx = 1 +# for loverlap, lvalid in zip(loverlaps, lvalids): +# token_apperance_count[start_idx:start_idx+lvalid] += 1 +# start_idx = start_idx + lvalid - loverlap + +# embedding_matrix_spos = 0 +# for i in range(0, len(lpatches)): +# embedding_matrix_epos = embedding_matrix_spos + int(lvalids[i]) + +# # assert embedding_matrix_epos - embedding_matrix_spos == end_pos - 1, (embedding_matrix_spos, embedding_matrix_epos, lvalid[i], end_pos) + +# overlap_gap = copy.deepcopy(loverlaps[i]).cpu().numpy() +# window = copy.deepcopy(lpatches[i][:, 1:int(lvalids[i])+1, ...]).cpu().numpy() + +# # embedding_tokens[:,embedding_matrix_spos:embedding_matrix_epos:] = window * token_apperance_count[embedding_matrix_spos:embedding_matrix_epos] +# for i in range(len(window)): +# window[:,i,:] *= token_apperance_count[embedding_matrix_spos + i] +# embedding_tokens[: ,int(embedding_matrix_spos): int(embedding_matrix_epos), :] += window +# embedding_matrix_spos -= overlap_gap + + +# embedding_tokens = torch.tensor(embedding_tokens).type(torch.HalfTensor).cuda() + +# # if overlap_gap != 0: +# # prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...]) +# # curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...]) +# # assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}" + +# # if average: +# # avg_overlap = ( +# # prev_overlap + curr_overlap +# # ) / 2. 
+# # embedding_tokens = torch.cat( +# # [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1 +# # ) +# # else: +# # embedding_tokens = torch.cat( +# # [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1 +# # ) +# # else: +# # embedding_tokens = torch.cat( +# # [embedding_tokens, window], dim=1 +# # ) +# return torch.cat([cls_token, embedding_tokens, sep_token], dim=1) \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/__init__.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/__init__.py new file mode 100755 index 0000000..859a3f7 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/__init__.py @@ -0,0 +1,15 @@ + +from model.combined_model import CombinedKVUModel +from model.kvu_model import KVUModel +from model.document_kvu_model import DocumentKVUModel + +def get_model(cfg): + if cfg.stage == 1: + model = CombinedKVUModel(cfg=cfg) + elif cfg.stage == 2: + model = KVUModel(cfg=cfg) + elif cfg.stage == 3: + model = DocumentKVUModel(cfg=cfg) + else: + AssertionError('[ERROR] Trainging stage is wrong') + return model diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/combined_model.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/combined_model.py new file mode 100755 index 0000000..86b3cd9 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/combined_model.py @@ -0,0 +1,224 @@ +import os +import torch +from torch import nn +from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor +from transformers import LayoutXLMTokenizer +from transformers import AutoTokenizer, XLMRobertaModel + + +from model.relation_extractor import RelationExtractor +from utils import load_checkpoint + + +class CombinedKVUModel(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.finetune_only = cfg.train.finetune_only + + self._get_backbones(self.model_cfg.backbone) + + self._create_head() + + if os.path.exists(self.model_cfg.ckpt_model_file): + self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm') + self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer') + self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer') + self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer') + self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key') + + + self.loss_func = nn.CrossEntropyLoss() + + if self.freeze: + for name, param in self.named_parameters(): + if 'backbone' in name: + param.requires_grad = False + if self.finetune_only == 'EE': + for name, param in self.named_parameters(): + if 'itc_layer' not in name and 'stc_layer' not in name: + param.requires_grad = False + if self.finetune_only == 'EL': + for name, param in self.named_parameters(): + if 'relation_layer' not in name or 'relation_layer_from_key' in name: + param.requires_grad = False + if self.finetune_only == 'ELK': + for name, param in self.named_parameters(): + if 'relation_layer_from_key' not in name: + param.requires_grad = False + + def _create_head(self): + self.backbone_hidden_size = 768 + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + self.n_classes = 
self.model_cfg.n_classes + 1 + + # (1) Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + + + def _get_backbones(self, config_type): + + self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base') + self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) + self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base') + + @staticmethod + def _init_weight(module): + init_std = 0.02 + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.LayerNorm): + nn.init.normal_(module.weight, 1.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + + def forward(self, batch): + image = batch["image"] + input_ids_layoutxlm = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask_layoutxlm = batch["attention_mask_layoutxlm"] + + backbone_outputs_layoutxlm = self.backbone_layoutxlm( + image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm) + + last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0) + head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key} + + loss = 0.0 + if any(['labels' in key for key in batch.keys()]): + loss = self._get_loss(head_outputs, batch) + + return head_outputs, loss + + def _get_loss(self, head_outputs, batch): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + itc_loss = self._get_itc_loss(itc_outputs, batch) + stc_loss = self._get_stc_loss(stc_outputs, batch) + el_loss = self._get_el_loss(el_outputs, batch) + el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True) + + loss = itc_loss + stc_loss + el_loss + el_loss_from_key + + return loss + + def _get_itc_loss(self, itc_outputs, 
batch): + itc_mask = batch["are_box_first_tokens"].view(-1) + + itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1) + itc_logits = itc_logits[itc_mask] + + itc_labels = batch["itc_labels"].view(-1) + itc_labels = itc_labels[itc_mask] + + itc_loss = self.loss_func(itc_logits, itc_labels) + + return itc_loss + + def _get_stc_loss(self, stc_outputs, batch): + inv_attention_mask = 1 - batch["attention_mask_layoutxlm"] + + bsz, max_seq_length = inv_attention_mask.shape + device = inv_attention_mask.device + + invalid_token_mask = torch.cat( + [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1 + ).bool() + stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0) + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool() + + stc_logits = stc_outputs.view(-1, max_seq_length + 1) + stc_logits = stc_logits[stc_mask] + + stc_labels = batch["stc_labels"].view(-1) + stc_labels = stc_labels[stc_mask] + + stc_loss = self.loss_func(stc_logits, stc_labels) + + return stc_loss + + def _get_el_loss(self, el_outputs, batch, from_key=False): + bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape + device = batch["attention_mask_layoutxlm"].device + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + + box_first_token_mask = torch.cat( + [ + (batch["are_box_first_tokens"] == False), + torch.zeros([bsz, 1], dtype=torch.bool).to(device), + ], + axis=1, + ) + el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0) + el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + mask = batch["are_box_first_tokens"].view(-1) + + logits = el_outputs.view(-1, max_seq_length + 1) + logits = logits[mask] + + if from_key: + el_labels = batch["el_labels_from_key"] + else: + el_labels = batch["el_labels"] + labels = el_labels.view(-1) + labels = labels[mask] + + loss = self.loss_func(logits, labels) + return loss diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/document_kvu_model.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/document_kvu_model.py new file mode 100755 index 0000000..c97be90 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/document_kvu_model.py @@ -0,0 +1,285 @@ +import torch +from torch import nn +from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor +from transformers import LayoutLMv2Config, LayoutLMv2Model +from transformers import LayoutXLMTokenizer +from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel + + +from model.relation_extractor import RelationExtractor +from utils import load_checkpoint + + +class DocumentKVUModel(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.train_cfg = cfg.train + + self._get_backbones(self.model_cfg.backbone) + # if 'pth' in self.model_cfg.ckpt_model_file: + # self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone, 'backbone') + + self._create_head() + + self.loss_func = nn.CrossEntropyLoss() + + def _create_head(self): + self.backbone_hidden_size = self.backbone_config.hidden_size + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + self.n_classes = self.model_cfg.n_classes + 1 + self.relations = self.model_cfg.n_relations + # 
self.repr_hiddent_size = self.backbone_hidden_size + self.n_classes + (self.train_cfg.max_seq_length + 1) * 3 + self.repr_hiddent_size = self.backbone_hidden_size + + # (1) Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=self.relations, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=self.relations, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # Classfication Layer for whole document + # (1) Initial token classification + self.itc_layer_document = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + # (3) Linking token classification + self.relation_layer_document = RelationExtractor( + n_relations=self.relations, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + # (4) Linking token classification + self.relation_layer_from_key_document = RelationExtractor( + n_relations=self.relations, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + self.relation_layer_from_key.apply(self._init_weight) + + self.itc_layer_document.apply(self._init_weight) + self.stc_layer_document.apply(self._init_weight) + self.relation_layer_document.apply(self._init_weight) + self.relation_layer_from_key_document.apply(self._init_weight) + + + def _get_backbones(self, config_type): + configs = { + 'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extrator': LayoutLMv2FeatureExtractor}, + 'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extrator': LayoutLMv2FeatureExtractor}, + 'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extrator': LayoutLMv2FeatureExtractor}, + } + + self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path) + if config_type != 'xlm-roberta': + self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path) + else: + self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False) + 
self.feature_extractor = configs[config_type]['feature_extrator'](apply_ocr=False) + self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path) + + @staticmethod + def _init_weight(module): + init_std = 0.02 + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.LayerNorm): + nn.init.normal_(module.weight, 1.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + + def forward(self, batches): + head_outputs_list = [] + loss = 0.0 + for batch in batches["windows"]: + image = batch["image"] + input_ids = batch["input_ids"] + bbox = batch["bbox"] + attention_mask = batch["attention_mask"] + + if self.freeze: + for param in self.backbone.parameters(): + param.requires_grad = False + + if self.model_cfg.backbone == 'layoutxlm': + backbone_outputs = self.backbone( + image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask + ) + else: + backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask) + + last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0) + + window_repr = last_hidden_states.transpose(0, 1).contiguous() + + head_outputs = {"window_repr": window_repr, + "itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + head_outputs_list.append(head_outputs) + + batch = batches["documents"] + + document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1) + document_repr = document_repr.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0) + el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0) + + head_outputs = {"itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + return head_outputs, loss + + def _get_loss(self, head_outputs, batch): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + itc_loss = self._get_itc_loss(itc_outputs, batch) + stc_loss = self._get_stc_loss(stc_outputs, batch) + el_loss = self._get_el_loss(el_outputs, batch) + el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True) + + loss = itc_loss + stc_loss + el_loss + el_loss_from_key + + return loss + + def _get_itc_loss(self, itc_outputs, batch): + itc_mask = batch["are_box_first_tokens"].view(-1) + + itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1) 
+ itc_logits = itc_logits[itc_mask] + + itc_labels = batch["itc_labels"].view(-1) + itc_labels = itc_labels[itc_mask] + + itc_loss = self.loss_func(itc_logits, itc_labels) + + return itc_loss + + def _get_stc_loss(self, stc_outputs, batch): + inv_attention_mask = 1 - batch["attention_mask"] + + bsz, max_seq_length = inv_attention_mask.shape + device = inv_attention_mask.device + + invalid_token_mask = torch.cat( + [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1 + ).bool() + stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0) + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + stc_mask = batch["attention_mask"].view(-1).bool() + + stc_logits = stc_outputs.view(-1, max_seq_length + 1) + stc_logits = stc_logits[stc_mask] + + stc_labels = batch["stc_labels"].view(-1) + stc_labels = stc_labels[stc_mask] + + stc_loss = self.loss_func(stc_logits, stc_labels) + + return stc_loss + + def _get_el_loss(self, el_outputs, batch, from_key=False): + bsz, max_seq_length = batch["attention_mask"].shape + device = batch["attention_mask"].device + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + + box_first_token_mask = torch.cat( + [ + (batch["are_box_first_tokens"] == False), + torch.zeros([bsz, 1], dtype=torch.bool).to(device), + ], + axis=1, + ) + el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0) + el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + mask = batch["are_box_first_tokens"].view(-1) + + logits = el_outputs.view(-1, max_seq_length + 1) + logits = logits[mask] + + if from_key: + el_labels = batch["el_labels_from_key"] + else: + el_labels = batch["el_labels"] + labels = el_labels.view(-1) + labels = labels[mask] + + loss = self.loss_func(logits, labels) + return loss diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/kvu_model.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/kvu_model.py new file mode 100755 index 0000000..d500370 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/kvu_model.py @@ -0,0 +1,248 @@ +import os +import torch +from torch import nn +from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor +from transformers import LayoutXLMTokenizer +from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2 + + +from model.relation_extractor import RelationExtractor +from utils import load_checkpoint + + +class KVUModel(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.device = 'cuda' + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.finetune_only = cfg.train.finetune_only + + # if cfg.stage == 2: + # self.freeze = True + + self._get_backbones(self.model_cfg.backbone) + self._create_head() + + if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)): + self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm') + + self._create_head() + self.loss_func = nn.CrossEntropyLoss() + + if self.freeze: + for name, param in self.named_parameters(): + if 'backbone' in name: + param.requires_grad = False + + def _create_head(self): + self.backbone_hidden_size = 768 + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + self.n_classes = self.model_cfg.n_classes + 1 + + # (1) Initial token classification + self.itc_layer = 
nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + + + def _get_backbones(self, config_type): + self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base') + self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) + self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base') + + @staticmethod + def _init_weight(module): + init_std = 0.02 + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.LayerNorm): + nn.init.normal_(module.weight, 1.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + + + # def forward(self, inputs): + # token_embeddings = inputs['embeddings'].transpose(0, 1).contiguous().cuda() + # itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous() + # stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0) + # el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0) + # el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0) + # head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + # "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key} + + # loss = self._get_loss(head_outputs, inputs) + # return head_outputs, loss + + + # def forward_single_doccument(self, lbatches): + def forward(self, lbatches): + windows = lbatches['windows'] + token_embeddings_windows = [] + lvalids = [] + loverlaps = [] + + for i, batch in enumerate(windows): + batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')} + image = batch["image"] + input_ids_layoutxlm = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask_layoutxlm = batch["attention_mask_layoutxlm"] + + + backbone_outputs_layoutxlm = self.backbone_layoutxlm( + image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm) + + + last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :] + + lvalids.append(batch['len_valid_tokens']) + loverlaps.append(batch['len_overlap_tokens']) + token_embeddings_windows.append(last_hidden_states_layoutxlm) + + + token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False) + # token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True) + + + token_embeddings = 
token_embeddings.transpose(0, 1).contiguous().cuda() + itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0) + el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0) + head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key, + 'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()} + + + + loss = 0.0 + if any(['labels' in key for key in lbatches.keys()]): + labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k not in ('img_path')} + loss = self._get_loss(head_outputs, labels) + + return head_outputs, loss + + def _get_loss(self, head_outputs, batch): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + itc_loss = self._get_itc_loss(itc_outputs, batch) + stc_loss = self._get_stc_loss(stc_outputs, batch) + el_loss = self._get_el_loss(el_outputs, batch) + el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True) + + loss = itc_loss + stc_loss + el_loss + el_loss_from_key + + return loss + + def _get_itc_loss(self, itc_outputs, batch): + itc_mask = batch["are_box_first_tokens"].view(-1) + + itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1) + itc_logits = itc_logits[itc_mask] + + itc_labels = batch["itc_labels"].view(-1) + itc_labels = itc_labels[itc_mask] + + itc_loss = self.loss_func(itc_logits, itc_labels) + + return itc_loss + + def _get_stc_loss(self, stc_outputs, batch): + inv_attention_mask = 1 - batch["attention_mask_layoutxlm"] + + bsz, max_seq_length = inv_attention_mask.shape + device = inv_attention_mask.device + + invalid_token_mask = torch.cat( + [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1 + ).bool() + + stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0) + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool() + stc_logits = stc_outputs.view(-1, max_seq_length + 1) + stc_logits = stc_logits[stc_mask] + + stc_labels = batch["stc_labels"].view(-1) + stc_labels = stc_labels[stc_mask] + + stc_loss = self.loss_func(stc_logits, stc_labels) + + return stc_loss + + def _get_el_loss(self, el_outputs, batch, from_key=False): + bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape + + device = batch["attention_mask_layoutxlm"].device + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + + box_first_token_mask = torch.cat( + [ + (batch["are_box_first_tokens"] == False), + torch.zeros([bsz, 1], dtype=torch.bool).to(device), + ], + axis=1, + ) + el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0) + el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + mask = batch["are_box_first_tokens"].view(-1) + + logits = el_outputs.view(-1, max_seq_length + 1) + logits = logits[mask] + + if from_key: + el_labels = batch["el_labels_from_key"] + else: + el_labels = batch["el_labels"] + labels = el_labels.view(-1) + labels = labels[mask] + + loss = self.loss_func(logits, labels) + return loss diff --git 
a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/relation_extractor.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/relation_extractor.py new file mode 100755 index 0000000..40a169e --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/model/relation_extractor.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + + +class RelationExtractor(nn.Module): + def __init__( + self, + n_relations, + backbone_hidden_size, + head_hidden_size, + head_p_dropout=0.1, + ): + super().__init__() + + self.n_relations = n_relations + self.backbone_hidden_size = backbone_hidden_size + self.head_hidden_size = head_hidden_size + self.head_p_dropout = head_p_dropout + + self.drop = nn.Dropout(head_p_dropout) + self.q_net = nn.Linear( + self.backbone_hidden_size, self.n_relations * self.head_hidden_size + ) + + self.k_net = nn.Linear( + self.backbone_hidden_size, self.n_relations * self.head_hidden_size + ) + + self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size)) + nn.init.normal_(self.dummy_node) + + def forward(self, h_q, h_k): + h_q = self.q_net(self.drop(h_q)) + + dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1) + h_k = torch.cat([h_k, dummy_vec], axis=0) + h_k = self.k_net(self.drop(h_k)) + + head_q = h_q.view( + h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size + ) + head_k = h_k.view( + h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size + ) + + relation_score = torch.einsum("ibnd,jbnd->nbij", (head_q, head_k)) + + return relation_score diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitignore b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitignore new file mode 100755 index 0000000..bf9f45b --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitignore @@ -0,0 +1,133 @@ +externals/sdsv_dewarp +test/ +.vscode/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*/__pycache__/ +*.py[cod] +*$py.class +results/ +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitmodules b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitmodules new file mode 100755 index 0000000..4e60ad1 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/.gitmodules @@ -0,0 +1,6 @@ +[submodule "sdsvtr"] + path = externals/sdsvtr + url = https://github.com/mrlasdt/sdsvtr.git +[submodule "sdsvtd"] + path = externals/sdsvtd + url = https://github.com/mrlasdt/sdsvtd.git diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/README.md b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/README.md new file mode 100755 index 0000000..ca1b349 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/README.md @@ -0,0 +1,47 @@ +# OCR Engine + +OCR Engine is a Python package that combines text detection and recognition models from [mmdet](https://github.com/open-mmlab/mmdetection) and [mmocr](https://github.com/open-mmlab/mmocr) to perform Optical Character Recognition (OCR) on various inputs. The package currently supports three types of input: a single image, a recursive directory, or a csv file. + +## Installation + +To install OCR Engine, clone the repository and install the required packages: + +```bash +git clone git@github.com:mrlasdt/ocr-engine.git +cd ocr-engine +pip install -r requirements.txt + +``` + + +## Usage + +To use OCR Engine, simply run the `ocr_engine.py` script with the desired input type and input path. For example, to perform OCR on a single image: + +```css +python ocr_engine.py --input_type image --input_path /path/to/image.jpg +``` + +To perform OCR on a recursive directory: + +```css +python ocr_engine.py --input_type directory --input_path /path/to/directory/ + +``` + +To perform OCR on a csv file: + + +``` +python ocr_engine.py --input_type csv --input_path /path/to/file.csv +``` + +OCR Engine will automatically detect and recognize text in the input and output the results in a CSV file named `ocr_results.csv`. + +## Contributing + +If you would like to contribute to OCR Engine, please fork the repository and submit a pull request. We welcome contributions of all types, including bug fixes, new features, and documentation improvements. + +## License + +OCR Engine is released under the [MIT License](https://opensource.org/licenses/MIT). See the LICENSE file for more information. 
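
The README above documents the command-line entry points; the modules added later in this diff (`src/ocr.py`, `src/dto.py`, and the package `__init__.py`) also expose a small Python API around `OcrEngine`, `Page`, and `Document`. The snippet below is a minimal sketch of driving that API directly, based only on what this diff shows: it assumes the `ocr-engine` directory is importable (e.g. added to `PYTHONPATH`), that the detector/recognizer checkpoints referenced in `settings.yml` exist on the target machine, and the image path is a placeholder.

```python
# Minimal sketch, not part of the committed code. Import path is an assumption:
# the package __init__ re-exports OcrEngine from .src.ocr, so adjust the import
# to however the directory is exposed in your environment.
from src.ocr import OcrEngine

# Keyword arguments override entries in settings.yml (the constructor rejects
# unknown keys), e.g. pinning the device to a specific GPU.
engine = OcrEngine(device="cuda:0")

# __call__ accepts an image path, numpy array, or PIL image and returns a Page;
# a list of images requires batch_mode=True and returns a Document instead.
page = engine("/path/to/invoice.jpg")  # placeholder path

# Persist word-level boxes and text, and optionally a visualisation image,
# mirroring what run.py's process_img() does.
page.write_to_file("word", "/path/to/invoice.txt")
page.save_img("/path/to/invoice_viz.jpg")
```
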
diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/__init__.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/__init__.py new file mode 100755 index 0000000..aabc310 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/__init__.py @@ -0,0 +1,11 @@ +# # Define package-level variables +# __version__ = '0.0' + +# Import modules +from .src.ocr import OcrEngine +# from .src.word_formation import words_to_lines +from .src.word_formation import words_to_lines_tesseract as words_to_lines +from .src.utils import ImageReader, read_ocr_result_from_txt +from .src.dto import Word, Line, Page, Document, Box +# Expose package contents +__all__ = ["OcrEngine", "Box", "Word", "Line", "Page", "Document", "words_to_lines", "ImageReader", "read_ocr_result_from_txt"] diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/requirements.txt b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/requirements.txt new file mode 100755 index 0000000..b247e52 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/requirements.txt @@ -0,0 +1,82 @@ +addict==2.4.0 +asttokens==2.2.1 +autopep8==1.6.0 +backcall==0.2.0 +backports.functools-lru-cache==1.6.4 +brotlipy==0.7.0 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==2.0.4 +click==8.1.3 +colorama==0.4.6 +cryptography==39.0.1 +debugpy==1.5.1 +decorator==5.1.1 +docopt==0.6.2 +entrypoints==0.4 +executing==1.2.0 +flit_core==3.6.0 +idna==3.4 +importlib-metadata==6.0.0 +ipykernel==6.15.0 +ipython==8.11.0 +jedi==0.18.2 +jupyter-client==7.0.6 +jupyter_core==4.12.0 +Markdown==3.4.1 +markdown-it-py==2.2.0 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mkl-fft==1.3.1 +mkl-random==1.2.2 +mkl-service==2.4.0 +mmcv-full==1.7.1 +model-index==0.1.11 +nest-asyncio==1.5.6 +numpy==1.23.5 +opencv-python==4.7.0.72 +openmim==0.3.6 +ordered-set==4.1.0 +packaging==23.0 +pandas==1.5.3 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.4.0 +pip==22.3.1 +pipdeptree==2.5.2 +prompt-toolkit==3.0.38 +psutil==5.9.0 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pycodestyle==2.10.0 +pycparser==2.21 +Pygments==2.14.0 +pyOpenSSL==23.0.0 +PySocks==1.7.1 +python-dateutil==2.8.2 +pytz==2022.7.1 +PyYAML==6.0 +pyzmq==19.0.2 +requests==2.28.1 +rich==13.3.1 +sdsvtd==0.1.1 +sdsvtr==0.0.5 +setuptools==65.6.3 +Shapely==1.8.4 +six==1.16.0 +stack-data==0.6.2 +tabulate==0.9.0 +toml==0.10.2 +torch==1.13.1 +torchvision==0.14.1 +tornado==6.1 +tqdm==4.65.0 +traitlets==5.9.0 +typing_extensions==4.4.0 +urllib3==1.26.14 +wcwidth==0.2.6 +wheel==0.38.4 +yapf==0.32.0 +yarg==0.1.9 +zipp==3.15.0 diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/run.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/run.py new file mode 100755 index 0000000..5837bf1 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/run.py @@ -0,0 +1,143 @@ +""" +see scripts/run_ocr.sh to run +""" +# from pathlib import Path # add parent path to run debugger +# import sys +# FILE = Path(__file__).absolute() +# sys.path.append(FILE.parents[2].as_posix()) + + +from src.utils import construct_file_path, ImageReader +from src.dto import Line +from src.ocr import OcrEngine +import argparse +import tqdm +import pandas as pd +from pathlib import Path +import json +import os +import numpy as np +from typing import Union, Tuple, List +current_dir = os.getcwd() + + +def get_args(): + parser = argparse.ArgumentParser() + # parser image + parser.add_argument("--image", type=str, required=True, + help="path to input image/directory/csv 
file") + parser.add_argument("--save_dir", type=str, required=True, + help="path to save directory") + parser.add_argument( + "--base_dir", type=str, required=False, default=current_dir, + help="used when --image and --save_dir are relative paths to a base directory, default to current directory") + parser.add_argument( + "--export_csv", type=str, required=False, default="", + help="used when --image is a directory. If set, a csv file contains image_path, ocr_path and label will be exported to save_dir.") + parser.add_argument( + "--export_img", type=bool, required=False, default=False, help="whether to save the visualize img") + parser.add_argument("--ocr_kwargs", type=str, required=False, default="") + opt = parser.parse_args() + return opt + + +def load_engine(opt) -> OcrEngine: + print("[INFO] Loading engine...") + kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {} + engine = OcrEngine(**kw) + print("[INFO] Engine loaded") + return engine + + +def convert_relative_path_to_positive_path(tgt_dir: Path, base_dir: Path) -> Path: + return tgt_dir if tgt_dir.is_absolute() else base_dir.joinpath(tgt_dir) + + +def get_paths_from_opt(opt) -> Tuple[Path, Path]: + # BC\ kiem\ tra\ y\ te -> BC kiem tra y te + img_path = opt.image.replace("\\ ", " ").strip() + save_dir = opt.save_dir.replace("\\ ", " ").strip() + base_dir = opt.base_dir.replace("\\ ", " ").strip() + input_image = convert_relative_path_to_positive_path( + Path(img_path), Path(base_dir)) + save_dir = convert_relative_path_to_positive_path( + Path(save_dir), Path(base_dir)) + if not save_dir.exists(): + save_dir.mkdir() + print("[INFO]: Creating folder ", save_dir) + return input_image, save_dir + + +def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None: + save_dir_or_path = Path(save_dir_or_path) + if isinstance(img, np.ndarray): + if save_dir_or_path.is_dir(): + raise ValueError( + "numpy array input require a save path, not a save dir") + page = engine(img) + save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt") + ) if save_dir_or_path.is_dir() else str(save_dir_or_path) + page.write_to_file('word', save_path) + if export_img: + page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, ) + + +def process_dir( + dir_path: str, save_dir: str, engine: OcrEngine, export_img: bool, lskip_dir: List[str] = [], + ddata: dict = {"img_path": list(), + "ocr_path": list(), + "label": list()}) -> None: + dir_path = Path(dir_path) + # save_dir_sub = Path(construct_file_path(save_dir, dir_path, ext="")) + save_dir = Path(save_dir) + save_dir.mkdir(exist_ok=True) + for img_path in (pbar := tqdm.tqdm(dir_path.iterdir())): + pbar.set_description(f"Processing {dir_path}") + if img_path.is_dir() and img_path not in lskip_dir: + save_dir_sub = save_dir.joinpath(img_path.stem) + process_dir(img_path, str(save_dir_sub), engine, ddata) + elif img_path.suffix.lower() in ImageReader.supported_ext: + simg_path = str(img_path) + try: + img = ImageReader.read( + simg_path) if img_path.suffix != ".pdf" else ImageReader.read(simg_path)[0] + save_path = str(Path(save_dir).joinpath( + img_path.stem + ".txt")) + process_img(img, save_path, engine, export_img) + except Exception as e: + print('[ERROR]: ', e, ' at ', simg_path) + continue + ddata["img_path"].append(simg_path) + ddata["ocr_path"].append(save_path) + ddata["label"].append(dir_path.stem) + # ddata.update({"img_path": img_path, "save_path": save_path, "label": dir_path.stem}) + return ddata + + +def process_csv(csv_path: 
str, engine: OcrEngine) -> None: + df = pd.read_csv(csv_path) + if not 'image_path' in df.columns or not 'ocr_path' in df.columns: + raise AssertionError('Cannot fing image_path in df headers') + for row in df.iterrows(): + process_img(row.image_path, row.ocr_path, engine) + + +if __name__ == "__main__": + opt = get_args() + engine = load_engine(opt) + print("[INFO]: OCR engine settings:", engine.settings) + img, save_dir = get_paths_from_opt(opt) + + lskip_dir = [] + if img.is_dir(): + ddata = process_dir(img, save_dir, engine, opt.export_img) + if opt.export_csv: + pd.DataFrame.from_dict(ddata).to_csv( + Path(save_dir).joinpath(opt.export_csv)) + elif img.suffix in ImageReader.supported_ext: + process_img(str(img), save_dir, engine, opt.export_img) + elif img.suffix == '.csv': + print("[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used") + process_csv(img, engine) + else: + raise NotImplementedError('[ERROR]: Unsupported file {}'.format(img)) diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/scripts/run_ocr.sh b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/scripts/run_ocr.sh new file mode 100755 index 0000000..13d5c26 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/scripts/run_ocr.sh @@ -0,0 +1,23 @@ +#bash scripts/run_ocr.sh -i /mnt/ssd1T/hungbnt/DocumentClassification/data/OCR040_043 -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr/OCR040_043 -e out.csv -k "{\"device\":\"cuda:1\"}" -x True +export PYTHONWARNINGS="ignore" + +while getopts i:o:b:e:x:k: flag +do + case "${flag}" in + i) img=${OPTARG};; + o) out_dir=${OPTARG};; + b) base_dir=${OPTARG};; + e) export_csv=${OPTARG};; + x) export_img=${OPTARG};; + k) ocr_kwargs=${OPTARG};; + esac +done +echo "run.py --image=\"$img\" --save_dir \"$out_dir\" --base_dir \"$base_dir\" --export_csv \"$export_csv\" --export_img \"$export_img\" --ocr_kwargs \"$ocr_kwargs\"" + +python run.py \ + --image="$img" \ + --save_dir $out_dir \ + --export_csv $export_csv\ + --export_img $export_img\ + --ocr_kwargs $ocr_kwargs\ + diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml new file mode 100755 index 0000000..3232828 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/settings.yml @@ -0,0 +1,17 @@ +detector: "/models/Kie_invoice_ap/wild_receipt_finetune_weights_c_lite.pth" +rotator_version: "/models/Kie_invoice_ap/best_bbox_mAP_epoch_30_lite.pth" +recog_max_seq_len: 25 +recognizer: "satrn-lite-general-pretrain-20230106" +device: "cuda:0" +do_extend_bbox: True +margin_bbox: [0, 0.03, 0.02, 0.05] +batch_mode: False +batch_size: 16 +auto_rotate: True +img_size: [] #[1920,1920] #text det default size: 1280x1280 #[] = originla size, TODO: fix the deskew code to resize the image only for detecting the angle, we want to feed the original size image to the text detection pipeline so that the bounding boxes would be mapped back to the original size +deskew: False #True +words_to_lines: { + "gradient": 0.6, + "max_x_dist": 20, + "y_overlap_threshold": 0.5, +} \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/dto.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/dto.py new file mode 100755 index 0000000..c1c644d --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/dto.py @@ -0,0 +1,453 @@ +import numpy as np +from typing 
import Optional, List +import cv2 +from PIL import Image +from .utils import visualize_bbox_and_label + + +class Box: + def __init__(self, x1, y1, x2, y2, conf=-1., label=""): + self.x1 = x1 + self.y1 = y1 + self.x2 = x2 + self.y2 = y2 + self.conf = conf + self.label = label + + def __repr__(self) -> str: + return str(self.bbox) + + def __str__(self) -> str: + return str(self.bbox) + + def get(self, return_confidence=False) -> list: + return self.bbox if not return_confidence else self.xyxyc + + def __getitem__(self, key): + return self.bbox[key] + + @property + def width(self): + return self.x2 - self.x1 + + @property + def height(self): + return self.y2 - self.y1 + + @property + def bbox(self) -> list: + return [self.x1, self.y1, self.x2, self.y2] + + @bbox.setter + def bbox(self, bbox_: list): + self.x1, self.y1, self.x2, self.y2 = bbox_ + + @property + def xyxyc(self) -> list: + return [self.x1, self.y1, self.x2, self.y2, self.conf] + + @staticmethod + def normalize_bbox(bbox: list): + return [int(b) for b in bbox] + + def to_int(self): + self.x1, self.y1, self.x2, self.y2 = self.normalize_bbox([self.x1, self.y1, self.x2, self.y2]) + return self + + @staticmethod + def clamp_bbox_by_img_wh(bbox: list, width: int, height: int): + x1, y1, x2, y2 = bbox + x1 = min(max(0, x1), width) + x2 = min(max(0, x2), width) + y1 = min(max(0, y1), height) + y2 = min(max(0, y2), height) + return (x1, y1, x2, y2) + + def clamp_by_img_wh(self, width: int, height: int): + self.x1, self.y1, self.x2, self.y2 = self.clamp_bbox_by_img_wh( + [self.x1, self.y1, self.x2, self.y2], width, height) + return self + + @staticmethod + def extend_bbox(bbox: list, margin: list): # -> Self (python3.11) + margin_l, margin_t, margin_r, margin_b = margin + l, t, r, b = bbox # left, top, right, bottom + t = t - (b - t) * margin_t + b = b + (b - t) * margin_b + l = l - (r - l) * margin_l + r = r + (r - l) * margin_r + return [l, t, r, b] + + def get_extend_bbox(self, margin: list): + extended_bbox = self.extend_bbox(self.bbox, margin) + return Box(*extended_bbox, label=self.label) + + @staticmethod + def bbox_is_valid(bbox: list) -> bool: + l, t, r, b = bbox # left, top, right, bottom + return True if (b - t) * (r - l) > 0 else False + + def is_valid(self) -> bool: + return self.bbox_is_valid(self.bbox) + + @staticmethod + def crop_img_by_bbox(img: np.ndarray, bbox: list) -> np.ndarray: + l, t, r, b = bbox + return img[t:b, l:r] + + def crop_img(self, img: np.ndarray) -> np.ndarray: + return self.crop_img_by_bbox(img, self.bbox) + + +class Word: + def __init__( + self, + image=None, + text="", + conf_cls=-1., + bndbox: Optional[Box] = None, + conf_detect=-1., + kie_label="", + ): + self.type = "word" + self.text = text + self.image = image + self.conf_detect = conf_detect + self.conf_cls = conf_cls + # [left, top,right,bot] coordinate of top-left and bottom-right point + self.boundingbox = bndbox + self.word_id = 0 # id of word + self.word_group_id = 0 # id of word_group which instance belongs to + self.line_id = 0 # id of line which instance belongs to + self.paragraph_id = 0 # id of line which instance belongs to + self.kie_label = kie_label + + @property + def bbox(self): + return self.boundingbox.bbox + + @property + def height(self): + return self.boundingbox.height + + @property + def width(self): + return self.boundingbox.width + + def __repr__(self) -> str: + return self.text + + def __str__(self) -> str: + return self.text + + def invalid_size(self): + return (self.boundingbox[2] - self.boundingbox[0]) * ( + 
self.boundingbox[3] - self.boundingbox[1] + ) > 0 + + def is_special_word(self): + left, top, right, bottom = self.boundingbox + width, height = right - left, bottom - top + text = self.text + + if text is None: + return True + + # if len(text) > 7: + # return True + if len(text) >= 7: + no_digits = sum(c.isdigit() for c in text) + return no_digits / len(text) >= 0.3 + + return False + + +class Word_group: + def __init__(self, list_words_: List[Word] = list(), + text: str = '', boundingbox: Box = Box(-1, -1, -1, -1), conf_cls: float = -1): + self.type = "word_group" + self.list_words = list_words_ # dict of word instances + self.word_group_id = 0 # word group id + self.line_id = 0 # id of line which instance belongs to + self.paragraph_id = 0 # id of paragraph which instance belongs to + self.text = text + self.boundingbox = boundingbox + self.kie_label = "" + self.conf_cls = conf_cls + + @property + def bbox(self): + return self.boundingbox + + def __repr__(self) -> str: + return self.text + + def __str__(self) -> str: + return self.text + + def add_word(self, word: Word): # add a word instance to the word_group + if word.text != "✪": + for w in self.list_words: + if word.word_id == w.word_id: + print("Word id collision") + return False + word.word_group_id = self.word_group_id # + word.line_id = self.line_id + word.paragraph_id = self.paragraph_id + self.list_words.append(word) + self.text += " " + word.text + if self.boundingbox == [-1, -1, -1, -1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [ + min(self.boundingbox[0], word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3]), + ] + return True + else: + return False + + def update_word_group_id(self, new_word_group_id): + self.word_group_id = new_word_group_id + for i in range(len(self.list_words)): + self.list_words[i].word_group_id = new_word_group_id + + def update_kie_label(self): + list_kie_label = [word.kie_label for word in self.list_words] + dict_kie = dict() + for label in list_kie_label: + if label not in dict_kie: + dict_kie[label] = 1 + else: + dict_kie[label] += 1 + total = len(list(dict_kie.values())) + max_value = max(list(dict_kie.values())) + list_keys = list(dict_kie.keys()) + list_values = list(dict_kie.values()) + self.kie_label = list_keys[list_values.index(max_value)] + + def update_text(self): # update text after changing positions of words in list word + text = "" + for word in self.list_words: + text += " " + word.text + self.text = text + + +class Line: + def __init__(self, list_word_groups: List[Word_group] = [], + text: str = '', boundingbox: Box = Box(-1, -1, -1, -1), conf_cls: float = -1): + self.type = "line" + self.list_word_groups = list_word_groups # list of Word_group instances in the line + self.line_id = 0 # id of line in the paragraph + self.paragraph_id = 0 # id of paragraph which instance belongs to + self.text = text + self.boundingbox = boundingbox + self.conf_cls = conf_cls + + @property + def bbox(self): + return self.boundingbox + + def __repr__(self) -> str: + return self.text + + def __str__(self) -> str: + return self.text + + def add_group(self, word_group: Word_group): # add a word_group instance + if word_group.list_words is not None: + for wg in self.list_word_groups: + if word_group.word_group_id == wg.word_group_id: + print("Word_group id collision") + return False + + self.list_word_groups.append(word_group) + self.text += word_group.text + 
word_group.paragraph_id = self.paragraph_id + word_group.line_id = self.line_id + + for i in range(len(word_group.list_words)): + word_group.list_words[ + i + ].paragraph_id = self.paragraph_id # set paragraph_id for word + word_group.list_words[i].line_id = self.line_id # set line_id for word + return True + return False + + def update_line_id(self, new_line_id): + self.line_id = new_line_id + for i in range(len(self.list_word_groups)): + self.list_word_groups[i].line_id = new_line_id + for j in range(len(self.list_word_groups[i].list_words)): + self.list_word_groups[i].list_words[j].line_id = new_line_id + + def merge_word(self, word): # word can be a Word instance or a Word_group instance + if word.text != "✪": + if self.boundingbox == [-1, -1, -1, -1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [ + min(self.boundingbox[0], word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3]), + ] + self.list_word_groups.append(word) + self.text += " " + word.text + return True + return False + + def __cal_ratio(self, top1, bottom1, top2, bottom2): + sorted_vals = sorted([top1, bottom1, top2, bottom2]) + intersection = sorted_vals[2] - sorted_vals[1] + min_height = min(bottom1 - top1, bottom2 - top2) + if min_height == 0: + return -1 + ratio = intersection / min_height + return ratio + + def __cal_ratio_height(self, top1, bottom1, top2, bottom2): + + height1, height2 = top1 - bottom1, top2 - bottom2 + ratio_height = float(max(height1, height2)) / float(min(height1, height2)) + return ratio_height + + def in_same_line(self, input_line, thresh=0.7): + # calculate iou in vertical direction + _, top1, _, bottom1 = self.boundingbox + _, top2, _, bottom2 = input_line.boundingbox + + ratio = self.__cal_ratio(top1, bottom1, top2, bottom2) + ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2) + + if ( + (top2 <= top1 <= bottom2) or (top1 <= top2 <= bottom1) + and ratio >= thresh + and (ratio_height < 2) + ): + return True + return False + + +class Paragraph: + def __init__(self, id=0, lines=None): + self.list_lines = lines if lines is not None else [] # list of all lines in the paragraph + self.paragraph_id = id # index of paragraph in the ist of paragraph + self.text = "" + self.boundingbox = [-1, -1, -1, -1] + + @property + def bbox(self): + return self.boundingbox + + def __repr__(self) -> str: + return self.text + + def __str__(self) -> str: + return self.text + + def add_line(self, line: Line): # add a line instance + if line.list_word_groups is not None: + for l in self.list_lines: + if line.line_id == l.line_id: + print("Line id collision") + return False + for i in range(len(line.list_word_groups)): + line.list_word_groups[ + i + ].paragraph_id = ( + self.paragraph_id + ) # set paragraph id for every word group in line + for j in range(len(line.list_word_groups[i].list_words)): + line.list_word_groups[i].list_words[ + j + ].paragraph_id = ( + self.paragraph_id + ) # set paragraph id for every word in word groups + line.paragraph_id = self.paragraph_id # set paragraph id for line + self.list_lines.append(line) # add line to paragraph + self.text += " " + line.text + return True + else: + return False + + def update_paragraph_id( + self, new_paragraph_id + ): # update new paragraph_id for all lines, word_groups, words inside paragraph + self.paragraph_id = new_paragraph_id + for i in range(len(self.list_lines)): + self.list_lines[ + i + ].paragraph_id = 
new_paragraph_id # set new paragraph_id for line + for j in range(len(self.list_lines[i].list_word_groups)): + self.list_lines[i].list_word_groups[ + j + ].paragraph_id = new_paragraph_id # set new paragraph_id for word_group + for k in range(len(self.list_lines[i].list_word_groups[j].list_words)): + self.list_lines[i].list_word_groups[j].list_words[ + k + ].paragraph_id = new_paragraph_id # set new paragraph id for word + return True + + +class Page: + def __init__(self, llines: List[Line], image: np.ndarray) -> None: + self.__llines = llines + self.__image = image + self.__drawed_image = None + + @property + def llines(self): + return self.__llines + + @property + def image(self): + return self.__image + + @property + def PIL_image(self): + return Image.fromarray(self.__image) + + @property + def drawed_image(self): + if self.__drawed_image: + self.__drawed_image = self + + def visualize_bbox_and_label(self, **kwargs: dict): + if self.__drawed_image is not None: + return self.__drawed_image + bboxes = list() + texts = list() + for line in self.__llines: + for word_group in line.list_word_groups: + for word in word_group.list_words: + bboxes.append([int(float(b)) for b in word.bbox[:]]) + texts.append(word.text) + img = visualize_bbox_and_label(self.__image, bboxes, texts, **kwargs) + self.__drawed_image = img + return self.__drawed_image + + def save_img(self, save_path: str, **kwargs: dict) -> None: + img = self.visualize_bbox_and_label(**kwargs) + cv2.imwrite(save_path, img) + + def write_to_file(self, mode: str, save_path: str) -> None: + f = open(save_path, "w+", encoding="utf-8") + for line in self.__llines: + if mode == 'line': + xmin, ymin, xmax, ymax = line.bbox[:] + f.write("{}\t{}\t{}\t{}\t{}\n".format(xmin, ymin, xmax, ymax, line.text)) + elif mode == "word": + for word_group in line.list_word_groups: + for word in word_group.list_words: + # xmin, ymin, xmax, ymax = word.bbox[:] + xmin, ymin, xmax, ymax = [int(float(b)) for b in word.bbox[:]] + f.write("{}\t{}\t{}\t{}\t{}\n".format(xmin, ymin, xmax, ymax, word.text)) + f.close() + + +class Document: + def __init__(self, lpages: List[Page]) -> None: + self.lpages = lpages diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/ocr.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/ocr.py new file mode 100755 index 0000000..2c30c01 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/ocr.py @@ -0,0 +1,207 @@ +from typing import Union, overload, List, Optional, Tuple +from PIL import Image +import torch +import numpy as np +import yaml +from pathlib import Path +import mmcv +from sdsvtd import StandaloneYOLOXRunner +from sdsvtr import StandaloneSATRNRunner +from .utils import ImageReader, chunks, rotate_bbox, Timer +# from .utils import jdeskew as deskew +# from externals.deskew.sdsv_dewarp import pdeskew as deskew +from .utils import deskew, post_process_recog +from .dto import Word, Line, Page, Document, Box +# from .word_formation import words_to_lines as words_to_lines +# from .word_formation import wo rds_to_lines_mmocr as words_to_lines +from .word_formation import words_to_lines_tesseract as words_to_lines +DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml" + + +class OcrEngine: + def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs: dict): + """ Warper of text detection and text recognition + :param settings_file: path to default setting file + :param kwargs: keyword arguments to overwrite the default settings file + """ + + 
with open(settings_file) as f: + # use safe_load instead load + self.__settings = yaml.safe_load(f) + for k, v in kwargs.items(): # overwrite default settings by keyword arguments + if k not in self.__settings: + raise ValueError("Invalid setting found in OcrEngine: ", k) + self.__settings[k] = v + + if "cuda" in self.__settings["device"]: + if not torch.cuda.is_available(): + print("[WARNING]: CUDA is not available, running with cpu instead") + self.__settings["device"] = "cpu" + self._detector = StandaloneYOLOXRunner( + version=self.__settings["detector"], + device=self.__settings["device"], + auto_rotate=self.__settings["auto_rotate"], + rotator_version=self.__settings["rotator_version"]) + self._recognizer = StandaloneSATRNRunner( + version=self.__settings["recognizer"], + return_confident=True, device=self.__settings["device"], + max_seq_len_overwrite=self.__settings["recog_max_seq_len"] + ) + # extend the bbox to avoid losing accent mark in vietnames, if using ocr for only english, disable it + self._do_extend_bbox = self.__settings["do_extend_bbox"] + # left, top, right, bottom"] + self._margin_bbox = self.__settings["margin_bbox"] + self._batch_mode = self.__settings["batch_mode"] + self._batch_size = self.__settings["batch_size"] + self._deskew = self.__settings["deskew"] + self._img_size = self.__settings["img_size"] + self.__version__ = { + "detector": self.__settings["detector"], + "recognizer": self.__settings["recognizer"], + } + + @property + def version(self): + return self.__version__ + + @property + def settings(self): + return self.__settings + + # @staticmethod + # def xyxyc_to_xyxy_c(xyxyc: np.ndarray) -> Tuple[List[list], list]: + # ''' + # convert sdsvtd yoloX detection output to list of bboxes and list of confidences + # @param xyxyc: array of shape (n, 5) + # ''' + # xyxy = xyxyc[:, :4].tolist() + # confs = xyxyc[:, 4].tolist() + # return xyxy, confs + # -> Tuple[np.ndarray, List[Box]]: + def preprocess(self, img: np.ndarray) -> np.ndarray: + img_ = img.copy() + if self.__settings["img_size"]: + img_ = mmcv.imrescale( + img, tuple(self.__settings["img_size"]), + return_scale=False, interpolation='bilinear', backend='cv2') + if self._deskew: + with Timer("deskew"): + img_, angle = deskew(img_) + # for i, bbox in enumerate(bboxes): + # rotated_bbox = rotate_bbox(bbox[:], angle, img.shape[:2]) + # bboxes[i].bbox = rotated_bbox + return img_ # , bboxes + + def run_detect(self, img: np.ndarray, return_raw: bool = False) -> Tuple[np.ndarray, Union[List[Box], List[list]]]: + ''' + run text detection and return list of xyxyc if return_confidence is True, otherwise return a list of xyxy + ''' + pred_det = self._detector(img) + if self.__settings["auto_rotate"]: + img, pred_det = pred_det + pred_det = pred_det[0] # only image at a time + return (img, pred_det.tolist()) if return_raw else (img, [Box(*xyxyc) for xyxyc in pred_det.tolist()]) + + def run_recog(self, imgs: List[np.ndarray]) -> Union[List[str], List[Tuple[str, float]]]: + if len(imgs) == 0: + return list() + pred_rec = self._recognizer(imgs) + return [(post_process_recog(word), conf) for word, conf in zip(pred_rec[0], pred_rec[1])] + + def read_img(self, img: str) -> np.ndarray: + return ImageReader.read(img) + + def get_cropped_imgs(self, img: np.ndarray, bboxes: List[Union[Box, list]]) -> Tuple[List[np.ndarray], List[bool]]: + """ + img: np image + bboxes: list of xyxy + """ + lcropped_imgs = list() + mask = list() + for bbox in bboxes: + bbox = Box(*bbox) if isinstance(bbox, list) else bbox + bbox = 
bbox.get_extend_bbox( + self._margin_bbox) if self._do_extend_bbox else bbox + bbox.clamp_by_img_wh(img.shape[1], img.shape[0]) + bbox.to_int() + if not bbox.is_valid(): + mask.append(False) + continue + cropped_img = bbox.crop_img(img) + lcropped_imgs.append(cropped_img) + mask.append(True) + return lcropped_imgs, mask + + def read_page(self, img: np.ndarray, bboxes: List[Union[Box, list]]) -> List[Line]: + if len(bboxes) == 0: # no bbox found + return list() + with Timer("cropped imgs"): + lcropped_imgs, mask = self.get_cropped_imgs(img, bboxes) + with Timer("recog"): + # batch_mode for efficiency + pred_recs = self.run_recog(lcropped_imgs) + with Timer("construct words"): + lwords = list() + for i in range(len(pred_recs)): + if not mask[i]: + continue + text, conf_rec = pred_recs[i][0], pred_recs[i][1] + bbox = Box(*bboxes[i]) if isinstance(bboxes[i], + list) else bboxes[i] + lwords.append(Word( + image=img, text=text, conf_cls=conf_rec, bndbox=bbox, conf_detect=bbox.conf)) + with Timer("words to lines"): + return words_to_lines( + lwords, **self.__settings["words_to_lines"])[0] + + # https://stackoverflow.com/questions/48127642/incompatible-types-in-assignment-on-union + + @overload + def __call__(self, img: Union[str, np.ndarray, Image.Image]) -> Page: ... + + @overload + def __call__( + self, img: List[Union[str, np.ndarray, Image.Image]]) -> Document: ... + + def __call__(self, img): + """ + Accept an image or list of them, return ocr result as a page or document + """ + with Timer("read image"): + img = ImageReader.read(img) + if not self._batch_mode: + if isinstance(img, list): + if len(img) == 1: + img = img[0] # in case input type is a 1 page pdf + else: + raise AssertionError( + "list input can only be used with batch_mode enabled") + img = self.preprocess(img) + with Timer("detect"): + img, bboxes = self.run_detect(img) + with Timer("read_page"): + llines = self.read_page(img, bboxes) + return Page(llines, img) + else: + lpages = [] + # chunks to reduce memory footprint + for imgs in chunks(img, self._batch_size): + # pred_dets = self._detector(imgs) + # TEMP: use list comprehension because sdsvtd do not support batch mode of text detection + img = self.preprocess(img) + img, bboxes = self.run_detect(img) + for img_, bboxes_ in zip(imgs, bboxes): + llines = self.read_page(img, bboxes_) + page = Page(llines, img) + lpages.append(page) + return Document(lpages) + + + +if __name__ == "__main__": + img_path = "/mnt/ssd1T/hungbnt/Cello/data/PH/Sea7/Sea_7_1.jpg" + engine = OcrEngine(device="cuda:0", return_confidence=True) + # https://stackoverflow.com/questions/66435480/overload-following-optional-argument + page = engine(img_path) # type: ignore + print(page.__llines) + diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/utils.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/utils.py new file mode 100755 index 0000000..d66405a --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/utils.py @@ -0,0 +1,346 @@ +from PIL import ImageFont, ImageDraw, Image, ImageOps +# import matplotlib.pyplot as plt +import numpy as np +import cv2 +import os +import time +from typing import Generator, Union, List, overload, Tuple, Callable +import glob +import math +from pathlib import Path +from pdf2image import convert_from_path +from deskew import determine_skew +from jdeskew.estimator import get_angle +from jdeskew.utility import rotate as jrotate + + +def post_process_recog(text: str) -> str: + text = text.replace("✪", " ") + return 
text + + +class Timer: + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, func: Callable, *args): + self.end_time = time.perf_counter() + self.elapsed_time = self.end_time - self.start_time + print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds") + + +def rotate( + image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]] +) -> np.ndarray: + old_width, old_height = image.shape[:2] + angle_radian = math.radians(angle) + width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width) + height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height) + image_center = tuple(np.array(image.shape[1::-1]) / 2) + rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + rot_mat[1, 2] += (width - old_width) / 2 + rot_mat[0, 2] += (height - old_height) / 2 + return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background) + + +# def rotate_bbox(bbox: list, angle: float) -> list: +# # Compute the center point of the bounding box +# cx = bbox[0] + bbox[2] / 2 +# cy = bbox[1] + bbox[3] / 2 + +# # Define the scale factor for the rotated bounding box +# scale = 1.0 # following the deskew and jdeskew function +# angle_radian = math.radians(angle) + +# # Obtain the rotation matrix using cv2.getRotationMatrix2D() +# M = cv2.getRotationMatrix2D((cx, cy), angle_radian, scale) + +# # Apply the rotation matrix to the four corners of the bounding box +# corners = np.array([[bbox[0], bbox[1]], +# [bbox[0] + bbox[2], bbox[1]], +# [bbox[0] + bbox[2], bbox[1] + bbox[3]], +# [bbox[0], bbox[1] + bbox[3]]], dtype=np.float32) +# rotated_corners = cv2.transform(np.array([corners]), M)[0] + +# # Compute the bounding box of the rotated corners +# x = int(np.min(rotated_corners[:, 0])) +# y = int(np.min(rotated_corners[:, 1])) +# w = int(np.max(rotated_corners[:, 0]) - np.min(rotated_corners[:, 0])) +# h = int(np.max(rotated_corners[:, 1]) - np.min(rotated_corners[:, 1])) +# rotated_bbox = [x, y, w, h] + +# return rotated_bbox + +def rotate_bbox(bbox: List[int], angle: float, old_shape: Tuple[int, int]) -> List[int]: + # https://medium.com/@pokomaru/image-and-bounding-box-rotation-using-opencv-python-2def6c39453 + bbox_ = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]] + h, w = old_shape + cx, cy = (int(w / 2), int(h / 2)) + + bbox_tuple = [ + (bbox_[0], bbox_[1]), + (bbox_[2], bbox_[3]), + (bbox_[4], bbox_[5]), + (bbox_[6], bbox_[7]), + ] # put x and y coordinates in tuples, we will iterate through the tuples and perform rotation + + rotated_bbox = [] + + for i, coord in enumerate(bbox_tuple): + M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + cos, sin = abs(M[0, 0]), abs(M[0, 1]) + newW = int((h * sin) + (w * cos)) + newH = int((h * cos) + (w * sin)) + M[0, 2] += (newW / 2) - cx + M[1, 2] += (newH / 2) - cy + v = [coord[0], coord[1], 1] + adjusted_coord = np.dot(M, v) + rotated_bbox.insert(i, (adjusted_coord[0], adjusted_coord[1])) + result = [int(x) for t in rotated_bbox for x in t] + return [result[i] for i in [0, 1, 2, -1]] # reformat to xyxy + + +def deskew(image: np.ndarray) -> Tuple[np.ndarray, float]: + grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + angle = 0. 
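# The rotate() helper above enlarges the output canvas so no pixels are clipped:
# for a rotation by angle a, the bounding canvas is
#   new_w = |sin a| * h + |cos a| * w,   new_h = |sin a| * w + |cos a| * h.
# A small sketch of just that size computation (pure math, no OpenCV needed):
import math

def expanded_canvas(w: float, h: float, angle_deg: float) -> tuple:
    a = math.radians(angle_deg)
    new_w = abs(math.sin(a)) * h + abs(math.cos(a)) * w
    new_h = abs(math.sin(a)) * w + abs(math.cos(a)) * h
    return new_w, new_h

# expanded_canvas(200, 100, 90) -> (100.0, 200.0): a quarter turn swaps the two sides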
+ try: + angle = determine_skew(grayscale) + except Exception: + pass + rotated = rotate(image, angle, (0, 0, 0)) if angle else image + return rotated, angle + + +def jdeskew(image: np.ndarray) -> Tuple[np.ndarray, float]: + angle = 0. + try: + angle = get_angle(image) + except Exception: + pass + # TODO: change resize = True and scale the bounding box + rotated = jrotate(image, angle, resize=False) if angle else image + return rotated, angle + + +class ImageReader: + """ + accept anything, return numpy array image + """ + supported_ext = [".png", ".jpg", ".jpeg", ".pdf", ".gif"] + + @staticmethod + def validate_img_path(img_path: str) -> None: + if not os.path.exists(img_path): + raise FileNotFoundError(img_path) + if os.path.isdir(img_path): + raise IsADirectoryError(img_path) + if not Path(img_path).suffix.lower() in ImageReader.supported_ext: + raise NotImplementedError("Not supported extension at {}".format(img_path)) + + @overload + @staticmethod + def read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray: ... + + @overload + @staticmethod + def read(img: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: ... + + @overload + @staticmethod + def read(img: str) -> List[np.ndarray]: ... # for pdf or directory + + @staticmethod + def read(img): + if isinstance(img, list): + return ImageReader.from_list(img) + elif isinstance(img, str) and os.path.isdir(img): + return ImageReader.from_dir(img) + elif isinstance(img, str) and img.endswith(".pdf"): + return ImageReader.from_pdf(img) + else: + return ImageReader._read(img) + + @staticmethod + def from_dir(dir_path: str) -> List[np.ndarray]: + if os.path.isdir(dir_path): + image_files = glob.glob(os.path.join(dir_path, "*")) + return ImageReader.from_list(image_files) + else: + raise NotADirectoryError(dir_path) + + @staticmethod + def from_str(img_path: str) -> np.ndarray: + ImageReader.validate_img_path(img_path) + return ImageReader.from_PIL(Image.open(img_path)) + + @staticmethod + def from_np(img_array: np.ndarray) -> np.ndarray: + return img_array + + @staticmethod + def from_PIL(img_pil: Image.Image, transpose=True) -> np.ndarray: + # if img_pil.is_animated: + # raise NotImplementedError("Only static images are supported, animated image found") + if transpose: + img_pil = ImageOps.exif_transpose(img_pil) + if img_pil.mode != "RGB": + img_pil = img_pil.convert("RGB") + + return np.array(img_pil) + + @staticmethod + def from_list(img_list: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: + limgs = list() + for img_path in img_list: + try: + if isinstance(img_path, str): + ImageReader.validate_img_path(img_path) + limgs.append(ImageReader._read(img_path)) + except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e: + print("[ERROR]: ", e) + print("[INFO]: Skipping image {}".format(img_path)) + return limgs + + @staticmethod + def from_pdf(pdf_path: str, start_page: int = 0, end_page: int = 0) -> List[np.ndarray]: + pdf_file = convert_from_path(pdf_path) + if end_page is not None: + end_page = min(len(pdf_file), end_page + 1) + limgs = [np.array(pdf_page) for pdf_page in pdf_file[start_page:end_page]] + return limgs + + @staticmethod + def _read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray: + if isinstance(img, str): + return ImageReader.from_str(img) + elif isinstance(img, Image.Image): + return ImageReader.from_PIL(img) + elif isinstance(img, np.ndarray): + return ImageReader.from_np(img) + else: + raise ValueError("Invalid img argument type: ", type(img)) + + +def 
get_name(file_path, ext: bool = True): + file_path_ = os.path.basename(file_path) + return file_path_ if ext else os.path.splitext(file_path_)[0] + + +def construct_file_path(dir, file_path, ext=''): + ''' + args: + dir: /path/to/dir + file_path /example_path/to/file.txt + ext = '.json' + return + /path/to/dir/file.json + ''' + return os.path.join( + dir, get_name(file_path, + True)) if ext == '' else os.path.join( + dir, get_name(file_path, + False)) + ext + + +def chunks(lst: list, n: int) -> Generator: + """ + Yield successive n-sized chunks from lst. + https://stackoverflow.com/questions/312443/how-do-i-split-a-list-into-equally-sized-chunks + """ + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +def read_ocr_result_from_txt(file_path: str) -> Tuple[list, list]: + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + boxes, words = [], [] + for line in lines: + if line == "": + continue + x1, y1, x2, y2, text = line.split("\t") + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + + +def get_xyxywh_base_on_format(bbox, format): + if format == "xywh": + x1, y1, w, h = bbox[0], bbox[1], bbox[2], bbox[3] + x2, y2 = x1 + w, y1 + h + elif format == "xyxy": + x1, y1, x2, y2 = bbox + w, h = x2 - x1, y2 - y1 + else: + raise NotImplementedError("Invalid format {}".format(format)) + return (x1, y1, x2, y2, w, h) + + +def get_dynamic_params_for_bbox_of_label(text, x1, y1, w, h, img_h, img_w, font): + font_scale_factor = img_h / (img_w + img_h) + font_scale = w / (w + h) * font_scale_factor # adjust font scale by width height + thickness = int(font_scale_factor) + 1 + (text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0] + text_offset_x = x1 + text_offset_y = y1 - thickness + box_coords = ((text_offset_x, text_offset_y + 1), (text_offset_x + text_width - 2, text_offset_y - text_height - 2)) + return (font_scale, thickness, text_height, box_coords) + + +def visualize_bbox_and_label( + img, bboxes, texts, bbox_color=(200, 180, 60), + text_color=(0, 0, 0), + format="xyxy", is_vnese=False, draw_text=True): + ori_img_type = type(img) + if is_vnese: + img = Image.fromarray(img) if ori_img_type is np.ndarray else img + draw = ImageDraw.Draw(img) + img_w, img_h = img.size + font_pil_str = "fonts/arial.ttf" + font_cv2 = cv2.FONT_HERSHEY_SIMPLEX + else: + img_h, img_w = img.shape[0], img.shape[1] + font_cv2 = cv2.FONT_HERSHEY_SIMPLEX + for i in range(len(bboxes)): + text = texts[i] # text = "{}: {:.0f}%".format(LABELS[classIDs[i]], confidences[i]*100) + x1, y1, x2, y2, w, h = get_xyxywh_base_on_format(bboxes[i], format) + font_scale, thickness, text_height, box_coords = get_dynamic_params_for_bbox_of_label( + text, x1, y1, w, h, img_h, img_w, font=font_cv2) + if is_vnese: + font_pil = ImageFont.truetype(font_pil_str, size=text_height) # type: ignore + fdraw_text = draw.text # type: ignore + fdraw_bbox = draw.rectangle # type: ignore + # Pil use different coordinate => y = y+thickness = y-thickness + 2*thickness + arg_text = ((box_coords[0][0], box_coords[1][1]), text) + kwarg_text = {"font": font_pil, "fill": text_color, "width": thickness} + arg_rec = ((x1, y1, x2, y2),) + kwarg_rec = {"outline": bbox_color, "width": thickness} + arg_rec_text = ((box_coords[0], box_coords[1]),) + kwarg_rec_text = {"fill": bbox_color, "width": thickness} + else: + # cv2.rectangle(img, 
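# read_ocr_result_from_txt (above) and Page.write_to_file (dto.py, earlier in this diff)
# share one plain-text format: one word or line per row, tab-separated as
# "xmin<TAB>ymin<TAB>xmax<TAB>ymax<TAB>text". A minimal writer for that format, so the
# round trip is visible (the file name below is hypothetical):
def write_ocr_result(path: str, records: list) -> None:
    with open(path, "w", encoding="utf-8") as f:
        for (x1, y1, x2, y2), text in records:
            f.write("{}\t{}\t{}\t{}\t{}\n".format(x1, y1, x2, y2, text))

# write_ocr_result("page_0.txt", [((10, 20, 110, 45), "INVOICE")])
# read_ocr_result_from_txt("page_0.txt") -> ([(10, 20, 110, 45)], ["INVOICE"])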
box_coords[0], box_coords[1], color, cv2.FILLED) + # cv2.putText(img, text, (text_offset_x, text_offset_y), font, fontScale=font_scale, color=(50, 0,0), thickness=thickness) + # cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness) + fdraw_text = cv2.putText + fdraw_bbox = cv2.rectangle + arg_text = (img, text, box_coords[0]) + kwarg_text = {"fontFace": font_cv2, "fontScale": font_scale, "color": text_color, "thickness": thickness} + arg_rec = (img, (x1, y1), (x2, y2)) + kwarg_rec = {"color": bbox_color, "thickness": thickness} + arg_rec_text = (img, box_coords[0], box_coords[1]) + kwarg_rec_text = {"color": bbox_color, "thickness": cv2.FILLED} + # draw a bounding box rectangle and label on the img + fdraw_bbox(*arg_rec, **kwarg_rec) # type: ignore + if draw_text: + fdraw_bbox(*arg_rec_text, **kwarg_rec_text) # type: ignore + fdraw_text(*arg_text, **kwarg_text) # type: ignore # text have to put in front of rec_text + return np.array(img) if ori_img_type is np.ndarray and is_vnese else img diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/word_formation.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/word_formation.py new file mode 100755 index 0000000..511c783 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/ocr-engine/src/word_formation.py @@ -0,0 +1,673 @@ +from builtins import dict +from .dto import Word, Line, Word_group, Box +import numpy as np +from typing import Optional, List, Tuple, Union +MIN_IOU_HEIGHT = 0.7 +MIN_WIDTH_LINE_RATIO = 0.05 + + +def resize_to_original( + boundingbox, scale +): # resize coordinates to match size of original image + left, top, right, bottom = boundingbox + left *= scale[1] + right *= scale[1] + top *= scale[0] + bottom *= scale[0] + return [left, top, right, bottom] + + +def check_iomin(word: Word, word_group: Word_group): + min_height = min( + word.boundingbox[3] - word.boundingbox[1], + word_group.boundingbox[3] - word_group.boundingbox[1], + ) + intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max( + word.boundingbox[1], word_group.boundingbox[1] + ) + if intersect / min_height > 0.7: + return True + return False + + +def prepare_line(words): + lines = [] + visited = [False] * len(words) + for id_word, word in enumerate(words): + if word.invalid_size() == 0: + continue + new_line = True + for i in range(len(lines)): + if ( + lines[i].in_same_line(word) and not visited[id_word] + ): # check if word is in the same line with lines[i] + lines[i].merge_word(word) + new_line = False + visited[id_word] = True + + if new_line == True: + new_line = Line() + new_line.merge_word(word) + lines.append(new_line) + + # print(len(lines)) + # sort line from top to bottom according top coordinate + lines.sort(key=lambda x: x.boundingbox[1]) + return lines + + +def __create_word_group(word, word_group_id): + new_word_group_ = Word_group() + new_word_group_.list_words = list() + new_word_group_.word_group_id = word_group_id + new_word_group_.add_word(word) + + return new_word_group_ + + +def __sort_line(line): + line.list_word_groups.sort( + key=lambda x: x.boundingbox[0] + ) # sort word in lines from left to right + + return line + + +def __merge_text_for_line(line): + line.text = "" + for word in line.list_word_groups: + line.text += " " + word.text + + return line + + +def __update_list_word_groups(line, word_group_id, word_id, line_width): + + old_list_word_group = line.list_word_groups + list_word_groups = [] + + inital_word_group = __create_word_group( + old_list_word_group[0], 
word_group_id) + old_list_word_group[0].word_id = word_id + list_word_groups.append(inital_word_group) + word_group_id += 1 + word_id += 1 + + for word in old_list_word_group[1:]: + check_word_group = True + word.word_id = word_id + word_id += 1 + + if ( + (not list_word_groups[-1].text.endswith(":")) + and ( + (word.boundingbox[0] - list_word_groups[-1].boundingbox[2]) + / line_width + < MIN_WIDTH_LINE_RATIO + ) + and check_iomin(word, list_word_groups[-1]) + ): + list_word_groups[-1].add_word(word) + check_word_group = False + + if check_word_group: + new_word_group = __create_word_group(word, word_group_id) + list_word_groups.append(new_word_group) + word_group_id += 1 + line.list_word_groups = list_word_groups + return line, word_group_id, word_id + + +def construct_word_groups_in_each_line(lines): + line_id = 0 + word_group_id = 0 + word_id = 0 + for i in range(len(lines)): + if len(lines[i].list_word_groups) == 0: + continue + + # left, top ,right, bottom + line_width = lines[i].boundingbox[2] - \ + lines[i].boundingbox[0] # right - left + line_width = 1 # TODO: to remove + lines[i] = __sort_line(lines[i]) + + # update text for lines after sorting + lines[i] = __merge_text_for_line(lines[i]) + + lines[i], word_group_id, word_id = __update_list_word_groups( + lines[i], + word_group_id, + word_id, + line_width) + lines[i].update_line_id(line_id) + line_id += 1 + return lines + + +def words_to_lines(words, check_special_lines=True): # words is list of Word instance + # sort word by top + words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0])) + # words.sort(key=lambda x: (sum(x.bbox[:]))) + number_of_word = len(words) + # print(number_of_word) + # sort list words to list lines, which have not contained word_group yet + lines = prepare_line(words) + + # construct word_groups in each line + lines = construct_word_groups_in_each_line(lines) + return lines, number_of_word + + +def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8): + """Check if two boxes are on the same line by their y-axis coordinates. + + Two boxes are on the same line if they overlap vertically, and the length + of the overlapping line segment is greater than min_y_overlap_ratio * the + height of either of the boxes. 
+ + Args: + box_a (list), box_b (list): Two bounding boxes to be checked + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for boxes in the same line + + Returns: + The bool flag indicating if they are on the same line + """ + a_y_min = np.min(box_a[1::2]) + b_y_min = np.min(box_b[1::2]) + a_y_max = np.max(box_a[1::2]) + b_y_max = np.max(box_b[1::2]) + + # Make sure that box a is always the box above another + if a_y_min > b_y_min: + a_y_min, b_y_min = b_y_min, a_y_min + a_y_max, b_y_max = b_y_max, a_y_max + + if b_y_min <= a_y_max: + if min_y_overlap_ratio is not None: + sorted_y = sorted([b_y_min, b_y_max, a_y_max]) + overlap = sorted_y[1] - sorted_y[0] + min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio + min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio + return overlap >= min_a_overlap or \ + overlap >= min_b_overlap + else: + return True + return False + + +def merge_bboxes_to_group(bboxes_group, x_sorted_boxes): + merged_bboxes = [] + for box_group in bboxes_group: + merged_box = {} + merged_box['text'] = ' '.join( + [x_sorted_boxes[idx]['text'] for idx in box_group]) + x_min, y_min = float('inf'), float('inf') + x_max, y_max = float('-inf'), float('-inf') + for idx in box_group: + x_max = max(np.max(x_sorted_boxes[idx]['box'][::2]), x_max) + x_min = min(np.min(x_sorted_boxes[idx]['box'][::2]), x_min) + y_max = max(np.max(x_sorted_boxes[idx]['box'][1::2]), y_max) + y_min = min(np.min(x_sorted_boxes[idx]['box'][1::2]), y_min) + merged_box['box'] = [ + x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max + ] + merged_box['list_words'] = [x_sorted_boxes[idx]['word'] + for idx in box_group] + merged_bboxes.append(merged_box) + return merged_bboxes + + +def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.3): + """Stitch fragmented boxes of words into lines. 
+ + Note: part of its logic is inspired by @Johndirr + (https://github.com/faustomorales/keras-ocr/issues/22) + + Args: + boxes (list): List of ocr results to be stitched + max_x_dist (int): The maximum horizontal distance between the closest + edges of neighboring boxes in the same line + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for any pairs of neighboring boxes in the same line + + Returns: + merged_boxes(List[dict]): List of merged boxes and texts + """ + + if len(boxes) <= 1: + if len(boxes) == 1: + boxes[0]["list_words"] = [boxes[0]["word"]] + return boxes + + # merged_groups = [] + merged_lines = [] + + # sort groups based on the x_min coordinate of boxes + x_sorted_boxes = sorted(boxes, key=lambda x: np.min(x['box'][::2])) + # store indexes of boxes which are already parts of other lines + skip_idxs = set() + + i = 0 + # locate lines of boxes starting from the leftmost one + for i in range(len(x_sorted_boxes)): + if i in skip_idxs: + continue + # the rightmost box in the current line + rightmost_box_idx = i + line = [rightmost_box_idx] + for j in range(i + 1, len(x_sorted_boxes)): + if j in skip_idxs: + continue + if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'], + x_sorted_boxes[j]['box'], min_y_overlap_ratio): + line.append(j) + skip_idxs.add(j) + rightmost_box_idx = j + + # split line into lines if the distance between two neighboring + # sub-lines' is greater than max_x_dist + # groups = [] + # line_idx = 0 + # groups.append([line[0]]) + # for k in range(1, len(line)): + # curr_box = x_sorted_boxes[line[k]] + # prev_box = x_sorted_boxes[line[k - 1]] + # dist = np.min(curr_box['box'][::2]) - np.max(prev_box['box'][::2]) + # if dist > max_x_dist: + # line_idx += 1 + # groups.append([]) + # groups[line_idx].append(line[k]) + + # # Get merged boxes + merged_line = merge_bboxes_to_group([line], x_sorted_boxes) + merged_lines.extend(merged_line) + # merged_group = merge_bboxes_to_group(groups,x_sorted_boxes) + # merged_groups.extend(merged_group) + + merged_lines = sorted(merged_lines, key=lambda x: np.min(x['box'][1::2])) + # merged_groups = sorted(merged_groups, key=lambda x: np.min(x['box'][1::2])) + return merged_lines # , merged_groups + +# REFERENCE +# https://vigneshgig.medium.com/bounding-box-sorting-algorithm-for-text-detection-and-object-detection-from-left-to-right-and-top-cf2c523c8a85 +# https://huggingface.co/spaces/tomofi/MMOCR/blame/main/mmocr/utils/box_util.py + + +def words_to_lines_mmocr(words: List[Word], *args) -> Tuple[List[Line], Optional[int]]: + bboxes = [{"box": [w.bbox[0], w.bbox[1], w.bbox[2], w.bbox[1], w.bbox[2], w.bbox[3], w.bbox[0], w.bbox[3]], + "text":w.text, "word":w} for w in words] + merged_lines = stitch_boxes_into_lines(bboxes) + merged_groups = merged_lines # TODO: fix code to return both word group and line + lwords_groups = [Word_group(list_words_=merged_box["list_words"], + text=merged_box["text"], + boundingbox=[merged_box["box"][i] for i in [0, 1, 2, -1]]) + for merged_box in merged_groups] + + llines = [Line(text=word_group.text, list_word_groups=[word_group], boundingbox=word_group.boundingbox) + for word_group in lwords_groups] + + return llines, None # same format with the origin words_to_lines + # lines = [Line() for merged] + + +# def most_overlapping_row(rows, top, bottom, y_shift): +# max_overlap = -1 +# max_overlap_idx = -1 +# for i, row in enumerate(rows): +# row_top, row_bottom = row +# overlap = min(top + y_shift, row_top) - max(bottom + y_shift, row_bottom) +# if overlap > 
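# Whether two word boxes sit on the same text line is decided purely by vertical overlap
# in the functions above. A compact sketch of that criterion with plain xyxy boxes
# (equivalent, for axis-aligned boxes, to the polygon-based test above):
def same_line(box_a, box_b, min_y_overlap_ratio=0.8):
    overlap = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1])
    if overlap <= 0:
        return False                      # no vertical intersection at all
    min_height = min(box_a[3] - box_a[1], box_b[3] - box_b[1])
    return overlap >= min_y_overlap_ratio * min_height

# same_line((0, 10, 50, 30), (60, 12, 120, 28)) -> True: the boxes share most of their height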
max_overlap: +# max_overlap = overlap +# max_overlap_idx = i +# return max_overlap_idx +def most_overlapping_row(rows, row_words, top, bottom, y_shift, max_row_size, y_overlap_threshold=0.5): + max_overlap = -1 + max_overlap_idx = -1 + overlapping_rows = [] + + for i, row in enumerate(rows): + row_top, row_bottom = row + overlap = min(top - y_shift[i], row_top) - \ + max(bottom - y_shift[i], row_bottom) + + if overlap > max_overlap: + max_overlap = overlap + max_overlap_idx = i + + # if at least overlap 1 pixel and not (overlap too much and overlap too little) + if (row_bottom <= top and row_top >= bottom) and not (top - bottom - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold): + overlapping_rows.append(i) + + # Merge overlapping rows if necessary + if len(overlapping_rows) > 1: + merge_top = max(rows[i][0] for i in overlapping_rows) + merge_bottom = min(rows[i][1] for i in overlapping_rows) + + if merge_top - merge_bottom <= max_row_size: + # Merge rows + merged_row = (merge_top, merge_bottom) + merged_words = [] + # Remove other overlapping rows + + for row_idx in overlapping_rows[:0:-1]: # [1,2,3] -> 3,2 + merged_words.extend(row_words[row_idx]) + del rows[row_idx] + del row_words[row_idx] + + rows[overlapping_rows[0]] = merged_row + row_words[overlapping_rows[0]].extend(merged_words[::-1]) + max_overlap_idx = overlapping_rows[0] + + if top - bottom - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold: + max_overlap_idx = -1 + return max_overlap_idx + + +def stitch_boxes_into_lines_tesseract(words: list[Word], + gradient: float, y_overlap_threshold: float) -> Tuple[list[list[Word]], + float]: + sorted_words = sorted(words, key=lambda x: x.bbox[0]) + rows = [] + row_words = [] + max_row_size = sorted_words[0].height + running_y_shift = [] + for _i, word in enumerate(sorted_words): + # if word.bbox[1] > 340 and word.bbox[3] < 450: + # print("DEBUG") + # if word.text == "Lực": + # print("DEBUG") + bbox, _text = word.bbox[:], word.text + _x1, y1, _x2, y2 = bbox + top, bottom = y2, y1 + max_row_size = max(max_row_size, top - bottom) + overlap_row_idx = most_overlapping_row( + rows, row_words, top, bottom, running_y_shift, max_row_size, y_overlap_threshold) + + if overlap_row_idx == -1: # No overlapping row found + new_row = (top, bottom) + rows.append(new_row) + row_words.append([word]) + running_y_shift.append(0) + else: # Overlapping row found + row_top, row_bottom = rows[overlap_row_idx] + new_top = max(row_top, top) + new_bottom = min(row_bottom, bottom) + rows[overlap_row_idx] = (new_top, new_bottom) + row_words[overlap_row_idx].append(word) + new_shift = (bottom + top) / 2 - (row_bottom + row_top) / 2 + running_y_shift[overlap_row_idx] = gradient * \ + running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift + + # Sort rows and row_texts based on the top y-coordinate + sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][0]) + _sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data) + # /_|<- the perpendicular line of the horizontal line and the skew line of the page + page_skew_dist = sum(running_y_shift) / len(running_y_shift) + return sorted_row_words, page_skew_dist + + +def construct_word_groups_tesseract(sorted_row_words: list[list[Word]], + max_x_dist: int, page_skew_dist: float) -> list[list[list[Word]]]: + # approximate page_skew_angle by page_skew_dist + corrected_max_x_dist = max_x_dist * abs(np.cos(page_skew_dist / 180 * 3.14)) + 
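# stitch_boxes_into_lines_tesseract (above) tracks the vertical drift of each text row
# with an exponential moving average: running_shift = gradient * running_shift
# + (1 - gradient) * new_offset, where new_offset is how far the latest word's centre
# sits from the row's centre. A tiny sketch of that update with assumed numbers:
def update_row_shift(prev_shift: float, new_offset: float, gradient: float = 0.6) -> float:
    return gradient * prev_shift + (1.0 - gradient) * new_offset

shift = 0.0
for offset in (2.0, 3.0, 2.5):    # successive centre offsets of words joining one row
    shift = update_row_shift(shift, offset)
print(round(shift, 3))            # 2.008: recent offsets dominate, one-off jumps are damped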
constructed_row_word_groups = [] + for row_words in sorted_row_words: + lword_groups = [] + line_idx = 0 + lword_groups.append([row_words[0]]) + for k in range(1, len(row_words)): + curr_box = row_words[k].bbox[:] + prev_box = row_words[k - 1].bbox[:] + dist = curr_box[0] - prev_box[2] + if dist > corrected_max_x_dist: + line_idx += 1 + lword_groups.append([]) + lword_groups[line_idx].append(row_words[k]) + constructed_row_word_groups.append(lword_groups) + return constructed_row_word_groups + + +def group_bbox_and_text(lwords: list[Word]) -> tuple[Box, tuple[str, float]]: + text = ' '.join([word.text for word in lwords]) + x_min, y_min = float('inf'), float('inf') + x_max, y_max = float('-inf'), float('-inf') + conf_det = 0 + conf_cls = 0 + for word in lwords: + x_max = max(np.max(word.bbox[::2]), x_max) + x_min = min(np.min(word.bbox[::2]), x_min) + y_max = max(np.max(word.bbox[1::2]), y_max) + y_min = min(np.min(word.bbox[1::2]), y_min) + conf_det += word.conf_detect + conf_cls += word.conf_cls + bbox = Box(x_min, y_min, x_max, y_max, conf=conf_det / len(lwords)) + return bbox, (text, conf_cls / len(lwords)) + + +def words_to_lines_tesseract(words: List[Word], + gradient: float, max_x_dist: int, y_overlap_threshold: float) -> Tuple[List[Line], + Optional[int]]: + sorted_row_words, page_skew_dist = stitch_boxes_into_lines_tesseract( + words, gradient, y_overlap_threshold) + constructed_row_word_groups = construct_word_groups_tesseract( + sorted_row_words, max_x_dist, page_skew_dist) + llines = [] + for row in constructed_row_word_groups: + lwords_row = [] + lword_groups = [] + for word_group in row: + bbox_word_group, text_word_group = group_bbox_and_text(word_group) + lwords_row.extend(word_group) + lword_groups.append( + Word_group( + list_words_=word_group, text=text_word_group[0], + conf_cls=text_word_group[1], + boundingbox=bbox_word_group)) + bbox_line, text_line = group_bbox_and_text(lwords_row) + llines.append( + Line( + list_word_groups=lword_groups, text=text_line[0], + boundingbox=bbox_line, conf_cls=text_line[1])) + return llines, None + + +def near(word_group1: Word_group, word_group2: Word_group): + min_height = min( + word_group1.boundingbox[3] - word_group1.boundingbox[1], + word_group2.boundingbox[3] - word_group2.boundingbox[1], + ) + overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max( + word_group1.boundingbox[1], word_group2.boundingbox[1] + ) + + if overlap > 0: + return True + if abs(overlap / min_height) < 1.5: + print("near enough", abs(overlap / min_height), overlap, min_height) + return True + return False + + +def calculate_iou_and_near(wg1: Word_group, wg2: Word_group): + min_height = min( + wg1.boundingbox[3] - + wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1] + ) + overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max( + wg1.boundingbox[1], wg2.boundingbox[1] + ) + iou = overlap / min_height + distance = min( + abs(wg1.boundingbox[0] - wg2.boundingbox[2]), + abs(wg1.boundingbox[2] - wg2.boundingbox[0]), + ) + if iou > 0.7 and distance < 0.5 * (wg1.boundingboxp[2] - wg1.boundingbox[0]): + return True + return False + + +def construct_word_groups_to_kie_label(list_word_groups: list): + kie_dict = dict() + for wg in list_word_groups: + if wg.kie_label == "other": + continue + if wg.kie_label not in kie_dict: + kie_dict[wg.kie_label] = [wg] + else: + kie_dict[wg.kie_label].append(wg) + + new_dict = dict() + for key, value in kie_dict.items(): + if len(value) == 1: + new_dict[key] = value + continue + + 
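# group_bbox_and_text (above) collapses a run of words into one enclosing box, a joined
# text string, and averaged detection/classification confidences. The same reduction in
# a standalone form (the sample words below are assumed data):
def merge_words(words):
    # each word: (xmin, ymin, xmax, ymax, text, confidence)
    bbox = (min(w[0] for w in words), min(w[1] for w in words),
            max(w[2] for w in words), max(w[3] for w in words))
    text = " ".join(w[4] for w in words)
    conf = sum(w[5] for w in words) / len(words)
    return bbox, text, conf

# merge_words([(10, 10, 40, 20, "Total", 0.98), (45, 11, 90, 21, "amount", 0.94)])
# -> ((10, 10, 90, 21), "Total amount", about 0.96)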
value.sort(key=lambda x: x.boundingbox[1]) + new_dict[key] = value + return new_dict + + +def invoice_construct_word_groups_to_kie_label(list_word_groups: list): + kie_dict = dict() + + for wg in list_word_groups: + if wg.kie_label == "other": + continue + if wg.kie_label not in kie_dict: + kie_dict[wg.kie_label] = [wg] + else: + kie_dict[wg.kie_label].append(wg) + + return kie_dict + + +def postprocess_total_value(kie_dict): + if "total_in_words_value" not in kie_dict: + return kie_dict + + for k, value in kie_dict.items(): + if k == "total_in_words_value": + continue + l = [] + for v in value: + if v.boundingbox[3] <= kie_dict["total_in_words_value"][0].boundingbox[3]: + l.append(v) + + if len(l) != 0: + kie_dict[k] = l + + return kie_dict + + +def postprocess_tax_code_value(kie_dict): + if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict: + return kie_dict + + kie_dict["buyer_tax_code_value"] = [] + for v in kie_dict["seller_tax_code_value"]: + if "buyer_name_key" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_key"][0]) + ): + kie_dict["buyer_tax_code_value"].append(v) + continue + + if "buyer_name_value" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_value"][0]) + ): + kie_dict["buyer_tax_code_value"].append(v) + continue + + if "buyer_address_value" in kie_dict and near( + kie_dict["buyer_address_value"][0], v + ): + kie_dict["buyer_tax_code_value"].append(v) + return kie_dict + + +def postprocess_tax_code_key(kie_dict): + if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict: + return kie_dict + kie_dict["buyer_tax_code_key"] = [] + for v in kie_dict["seller_tax_code_key"]: + if "buyer_name_key" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_key"][0]) + ): + kie_dict["buyer_tax_code_key"].append(v) + continue + + if "buyer_name_value" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_value"][0]) + ): + kie_dict["buyer_tax_code_key"].append(v) + continue + + if "buyer_address_value" in kie_dict and near( + kie_dict["buyer_address_value"][0], v + ): + kie_dict["buyer_tax_code_key"].append(v) + + return kie_dict + + +def invoice_postprocess(kie_dict: dict): + # all keys or values which are below total_in_words_value will be thrown away + kie_dict = postprocess_total_value(kie_dict) + kie_dict = postprocess_tax_code_value(kie_dict) + kie_dict = postprocess_tax_code_key(kie_dict) + return kie_dict + + +def throw_overlapping_words(list_words): + new_list = [list_words[0]] + for word in list_words: + overlap = False + area = (word.boundingbox[2] - word.boundingbox[0]) * ( + word.boundingbox[3] - word.boundingbox[1] + ) + for word2 in new_list: + area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * ( + word2.boundingbox[3] - word2.boundingbox[1] + ) + xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0]) + xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2]) + ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1]) + ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3]) + if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: + continue + + area_intersect = (xmax_intersect - xmin_intersect) * ( + ymax_intersect - ymin_intersect + ) + if area_intersect / area > 0.7 or area_intersect / area2 > 0.7: + 
overlap = True + if overlap == False: + new_list.append(word) + return new_list + + +def check_iou(box1: Word, box2: Box, threshold=0.9): + area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * ( + box1.boundingbox[3] - box1.boundingbox[1] + ) + area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin) + xmin_intersect = max(box1.boundingbox[0], box2.xmin) + ymin_intersect = max(box1.boundingbox[1], box2.ymin) + xmax_intersect = min(box1.boundingbox[2], box2.xmax) + ymax_intersect = min(box1.boundingbox[3], box2.ymax) + if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: + area_intersect = 0 + else: + area_intersect = (xmax_intersect - xmin_intersect) * ( + ymax_intersect - ymin_intersect + ) + union = area1 + area2 - area_intersect + iou = area_intersect / union + if iou > threshold: + return True + return False diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/predictor.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/predictor.py new file mode 100755 index 0000000..3f4ce36 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/predictor.py @@ -0,0 +1,230 @@ +from omegaconf import OmegaConf +import os +import cv2 +import torch +# from functions import get_colormap, visualize +import sys +sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') # TODO: ??????? + +from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations +from model import get_model +from utils import load_model_weight + + +class KVUPredictor: + def __init__(self, configs, class_names, dummy_idx, mode=0): + cfg_path = configs['cfg'] + ckpt_path = configs['ckpt'] + + self.class_names = class_names + self.dummy_idx = dummy_idx + self.mode = mode + + print('[INFO] Loading Key-Value Understanding model ...') + self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path) + print("[INFO] Loaded model") + + if mode == 3: + self.max_window_count = cfg.train.max_window_count + self.window_size = cfg.train.window_size + self.slice_interval = 0 + self.dummy_idx = dummy_idx * self.max_window_count + else: + self.slice_interval = cfg.train.slice_interval + self.window_size = cfg.train.max_num_words + + + self.device = 'cuda' + + def _load_model(self, cfg_path, ckpt_path): + cfg = OmegaConf.load(cfg_path) + cfg.stage = self.mode + backbone_type = cfg.model.backbone + + print('[INFO] Checkpoint:', ckpt_path) + net = get_model(cfg) + load_model_weight(net, ckpt_path) + net.to('cuda') + net.eval() + return net, cfg, backbone_type + + def predict(self, input_sample): + if self.mode == 0: + if len(input_sample['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 1: + if len(input_sample['documents']['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 2: + if len(input_sample['windows'][0]['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = [], [], [], [] + for window in input_sample['windows']: + _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window) + bbox.append(_bbox) + lwords.append(_lwords) + pr_class_words.append(_pr_class_words) + pr_relations.append(_pr_relations) + return bbox, lwords, pr_class_words, pr_relations + + elif self.mode == 3: + if 
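# throw_overlapping_words and check_iou (in word_formation.py above) both compare
# intersection areas of axis-aligned boxes. A compact IoU helper in the same spirit,
# with xyxy tuples:
def iou(a, b):
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    if ix2 <= ix1 or iy2 <= iy1:
        return 0.0                                # boxes do not intersect
    inter = (ix2 - ix1) * (iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

# iou((0, 0, 10, 10), (5, 0, 15, 10)) -> 0.333...: half of each box is shared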
len(input_sample["documents"]['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + else: + raise ValueError( + f"Not supported mode: {self.mode}" + ) + + def doc_predict(self, input_sample): + lwords = input_sample['documents']['words'] + for idx, window in enumerate(input_sample['windows']): + input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')} + + # input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')} + with torch.no_grad(): + head_outputs, _ = self.net(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + input_sample = input_sample['documents'] + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0) + attention_mask = input_sample['attention_mask'].squeeze(0) + bbox = input_sample['bbox'].squeeze(0) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, self.dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx) + pr_relations = pr_relations_from_header | pr_relations_from_key + + return bbox, lwords, pr_class_words, pr_relations + + + def combined_predict(self, input_sample): + lwords = input_sample['words'] + input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')} + + input_sample = {k: v.to(self.device) for k, v in input_sample.items()} + + + with torch.no_grad(): + head_outputs, _ = self.net(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + input_sample = {k: v.detach().cpu() for k, v in input_sample.items()} + + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0) + attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0) + bbox = input_sample['bbox'].squeeze(0) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, self.dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx) + pr_relations = 
pr_relations_from_header | pr_relations_from_key + + return bbox, lwords, pr_class_words, pr_relations + + def cat_predict(self, input_sample): + lwords = input_sample['documents']['words'] + + inputs = [] + for window in input_sample['windows']: + inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')}) + input_sample['windows'] = inputs + + with torch.no_grad(): + head_outputs, _ = self.net(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens')} + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + box_first_token_mask = input_sample['documents']['are_box_first_tokens'] + attention_mask = input_sample['documents']['attention_mask_layoutxlm'] + bbox = input_sample['documents']['bbox'] + + dummy_idx = input_sample['documents']['bbox'].shape[0] + + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx) + pr_relations = pr_relations_from_header | pr_relations_from_key + + return bbox, lwords, pr_class_words, pr_relations + + + def get_ground_truth_label(self, ground_truth): + # ground_truth = self.preprocessor.load_ground_truth(json_file) + gt_itc_label = ground_truth['itc_labels'].squeeze(0) # [1, 512] => [512] + gt_stc_label = ground_truth['stc_labels'].squeeze(0) # [1, 512] => [512] + gt_el_label = ground_truth['el_labels'].squeeze(0) + + gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0) + lwords = ground_truth["words"] + + box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0) + attention_mask = ground_truth['attention_mask'].squeeze(0) + + bbox = ground_truth['bbox'].squeeze(0) + gt_first_words = parse_initial_words( + gt_itc_label, box_first_token_mask, self.class_names + ) + gt_class_words = parse_subsequent_words( + gt_stc_label, attention_mask, gt_first_words, self.dummy_idx + ) + + gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx) + gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx) + gt_relations = gt_relations_from_header | gt_relations_from_key + + return bbox, lwords, gt_class_words, gt_relations \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py new file mode 100755 index 0000000..89e3167 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/preprocess.py @@ -0,0 +1,601 @@ +import os +from typing import Any +import numpy as np +import pandas as pd +import imagesize +import itertools +from PIL import Image +import argparse + +import torch + +from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr +from utils.run_ocr import load_ocr_engine, process_img +from lightning_modules.utils import 
sliding_windows + + +class KVUProcess: + def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0): + self.tokenizer_layoutxlm = tokenizer_layoutxlm + self.feature_extractor = feature_extractor + + self.max_seq_length = max_seq_length + self.backbone_type = backbone_type + self.class_names = class_names + + self.slice_interval = slice_interval + self.window_size = window_size + self.run_ocr = run_ocr + self.mode = mode + + self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token) + self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token) + self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token) + self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token) + + + self.class_idx_dic = dict( + [(class_name, idx) for idx, class_name in enumerate(self.class_names)] + ) + self.ocr_engine = None + if self.run_ocr == 1: + self.ocr_engine = load_ocr_engine() + + def __call__(self, img_path: str, ocr_path: str) -> list: + if (self.run_ocr == 1) or (not os.path.exists(ocr_path)): + process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False) + ocr_path = "tmp.txt" + lbboxes, lwords = read_ocr_result_from_txt(ocr_path) + lwords = post_process_basic_ocr(lwords) + bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval) + word_windows = sliding_windows(lwords, self.window_size, self.slice_interval) + assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}" + + width, height = imagesize.get(img_path) + images = [Image.open(img_path).convert("RGB")] + image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + + + if self.mode == 0: + output = self.preprocess(lbboxes, lwords, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + elif self.mode == 1: + output = {} + windows = [] + for i in range(len(bbox_windows)): + _words = word_windows[i] + _bboxes = bbox_windows[i] + windows.append( + self.preprocess( + _bboxes, _words, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + ) + + output['windows'] = windows + elif self.mode == 2: + output = {} + windows = [] + output['doduments'] = self.preprocess(lbboxes, lwords, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=2048) + for i in range(len(bbox_windows)): + _words = word_windows[i] + _bboxes = bbox_windows[i] + windows.append( + self.preprocess( + _bboxes, _words, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + ) + + output['windows'] = windows + else: + raise ValueError( + f"Not supported mode: {self.mode }" + ) + return output + + + def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length): + input_ids_layoutxlm = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm + + attention_mask_layoutxlm = np.zeros(max_seq_length, dtype=int) + + bbox = np.zeros((max_seq_length, 8), dtype=np.float32) + + are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_) + + list_layoutxlm_tokens = [] + + list_bbs = [] + list_words = 
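# KVUProcess.__call__ (above) splits long OCR output into overlapping windows via
# sliding_windows(lbboxes, self.window_size, self.slice_interval) from
# lightning_modules.utils. A simplified stand-in is sketched below; it assumes
# slice_interval acts as the stride between window starts, which is an assumption made
# for illustration, not a statement about the real helper.
def sliding_windows_sketch(items, window_size, stride):
    if len(items) <= window_size:
        return [items]
    return [items[i:i + window_size] for i in range(0, len(items), stride)]

# sliding_windows_sketch(list(range(10)), window_size=6, stride=4)
# -> [[0, 1, 2, 3, 4, 5], [4, 5, 6, 7, 8, 9], [8, 9]]
# consecutive windows overlap by window_size - stride items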
[] + lwords = [''] * max_seq_length + + box_to_token_indices = [] + cum_token_idx = 0 + + cls_bbs = [0.0] * 8 + len_overlap_tokens = 0 + len_non_overlap_tokens = 0 + len_valid_tokens = 0 + + for word_idx, (bounding_box, word) in enumerate(zip(bounding_boxes, words)): + bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]] + layoutxlm_tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(word)) + + this_box_token_indices = [] + + len_valid_tokens += len(layoutxlm_tokens) + if word_idx < self.slice_interval: + len_non_overlap_tokens += len(layoutxlm_tokens) + + if len(layoutxlm_tokens) == 0: + layoutxlm_tokens.append(self.unk_token_id) + + + if len(list_layoutxlm_tokens) + len(layoutxlm_tokens) > max_seq_length - 2: + break + + + list_layoutxlm_tokens += layoutxlm_tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width'])) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height'])) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(layoutxlm_tokens))] + texts = [word for _ in range(len(layoutxlm_tokens))] + + for _ in layoutxlm_tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + list_words.extend(texts) ### + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [feature_maps['width'], feature_maps['height']] * 4 + + # For [CLS] and [SEP] + list_layoutxlm_tokens = ( + [self.cls_token_id_layoutxlm] + + list_layoutxlm_tokens[: max_seq_length - 2] + + [self.sep_token_id_layoutxlm] + ) + + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs] + # list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] + list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token] + + len_list_layoutxlm_tokens = len(list_layoutxlm_tokens) + input_ids_layoutxlm[:len_list_layoutxlm_tokens] = list_layoutxlm_tokens + attention_mask_layoutxlm[:len_list_layoutxlm_tokens] = 1 + + + bbox[:len_list_layoutxlm_tokens, :] = list_bbs + lwords[:len_list_layoutxlm_tokens] = list_words ### + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width'] + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height'] + + if self.backbone_type in ("layoutlm", "layoutxlm", "xlm-roberta"): + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + else: + assert False + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < max_seq_length + ] + are_box_first_tokens[st_indices] = True + + assert len_list_layoutxlm_tokens == len_valid_tokens + 2 + len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens + + ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2 + + input_ids_layoutxlm = input_ids_layoutxlm[:ntokens] + attention_mask_layoutxlm = attention_mask_layoutxlm[:ntokens] + bbox = bbox[:ntokens] + are_box_first_tokens = are_box_first_tokens[:ntokens] + + + input_ids_layoutxlm = torch.from_numpy(input_ids_layoutxlm) + attention_mask_layoutxlm = torch.from_numpy(attention_mask_layoutxlm) + bbox = torch.from_numpy(bbox) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + + 
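# The block above follows the LayoutLM/LayoutXLM convention for box coordinates: divide
# by page width/height, then scale to integers in [0, 1000]. A minimal sketch of that
# transform on a single xyxy box (page size and box values are assumed):
def normalise_bbox(xyxy, page_w, page_h):
    x1, y1, x2, y2 = xyxy
    return [int(1000 * x1 / page_w), int(1000 * y1 / page_h),
            int(1000 * x2 / page_w), int(1000 * y2 / page_h)]

# normalise_bbox((210, 99, 420, 132), page_w=840, page_h=1188) -> [250, 83, 500, 111]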
len_valid_tokens = torch.tensor(len_valid_tokens) + len_overlap_tokens = torch.tensor(len_overlap_tokens) + return_dict = { + "img_path": feature_maps['img_path'], + "words": list_words, + "len_overlap_tokens": len_overlap_tokens, + 'len_valid_tokens': len_valid_tokens, + "image": feature_maps['image'], + "input_ids_layoutxlm": input_ids_layoutxlm, + "attention_mask_layoutxlm": attention_mask_layoutxlm, + "are_box_first_tokens": are_box_first_tokens, + "bbox": bbox, + } + return return_dict + + def load_ground_truth(self, json_file): + json_obj = read_json(json_file) + width = json_obj["meta"]["imageSize"]["width"] + height = json_obj["meta"]["imageSize"]["height"] + + input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id_layoutxlm + bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(self.max_seq_length, dtype=int) + + itc_labels = np.zeros(self.max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_) + + # stc_labels stores the index of the previous token. + # A stored index of max_seq_length (512) indicates that + # this token is the initial token of a word box. + stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length + el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length + el_labels_from_key = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length + + + list_tokens = [] + list_bbs = [] + list_words = [] + box2token_span_map = [] + lwords = [''] * self.max_seq_length + + box_to_token_indices = [] + cum_token_idx = 0 + + cls_bbs = [0.0] * 8 + + for word_idx, word in enumerate(json_obj["words"]): + this_box_token_indices = [] + + tokens = word["layoutxlm_tokens"] + bb = word["boundingBox"] + text = word["text"] + + if len(tokens) == 0: + tokens.append(self.unk_token_id) + + if len(list_tokens) + len(tokens) > self.max_seq_length - 2: + break + + box2token_span_map.append( + [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1] + ) # including st_idx + list_tokens += tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width)) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height)) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(tokens))] + texts = [text for _ in range(len(tokens))] + + for _ in tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + list_words.extend(texts) #### + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [width, height] * 4 + + # For [CLS] and [SEP] + list_tokens = ( + [self.cls_token_id_layoutxlm] + + list_tokens[: self.max_seq_length - 2] + + [self.sep_token_id_layoutxlm] + ) + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs] + # list_words = ['CLS'] + list_words[: self.max_seq_length - 2] + ['SEP'] ### + list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token] + + + len_list_tokens = len(list_tokens) + input_ids[:len_list_tokens] = list_tokens + attention_mask[:len_list_tokens] = 1 + + bbox[:len_list_tokens, :] = list_bbs + lwords[:len_list_tokens] = list_words + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height + + if self.backbone_type in 
("layoutlm", "layoutxlm"): + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + else: + assert False + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < self.max_seq_length + ] + are_box_first_tokens[st_indices] = True + + # Label + classes_dic = json_obj["parse"]["class"] + for class_name in self.class_names: + if class_name == "others": + continue + if class_name not in classes_dic: + continue + + for word_list in classes_dic[class_name]: + is_first, last_word_idx = True, -1 + for word_idx in word_list: + if word_idx >= len(box_to_token_indices): + break + box2token_list = box_to_token_indices[word_idx] + for converted_word_idx in box2token_list: + if converted_word_idx >= self.max_seq_length: + break # out of idx + + if is_first: + itc_labels[converted_word_idx] = self.class_idx_dic[ + class_name + ] + is_first, last_word_idx = False, converted_word_idx + else: + stc_labels[converted_word_idx] = last_word_idx + last_word_idx = converted_word_idx + + # Label + relations = json_obj["parse"]["relations"] + for relation in relations: + if relation[0] >= len(box2token_span_map) or relation[1] >= len( + box2token_span_map + ): + continue + if ( + box2token_span_map[relation[0]][0] >= self.max_seq_length + or box2token_span_map[relation[1]][0] >= self.max_seq_length + ): + continue + + word_from = box2token_span_map[relation[0]][0] + word_to = box2token_span_map[relation[1]][0] + # el_labels[word_to] = word_from + + if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512: + continue + + # if self.second_relations == 1: + # if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): + # el_labels[word_to] = word_from # pair of (header, key) or (header-value) + # else: + #### 1st relation => ['key, 'value'] + #### 2st relation => ['header', 'key'or'value'] + if itc_labels[word_from] == 2 and itc_labels[word_to] == 3: + el_labels_from_key[word_to] = word_from # pair of (key-value) + if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): + el_labels[word_to] = word_from # pair of (header, key) or (header-value) + + + + input_ids = torch.from_numpy(input_ids) + bbox = torch.from_numpy(bbox) + attention_mask = torch.from_numpy(attention_mask) + + itc_labels = torch.from_numpy(itc_labels) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + stc_labels = torch.from_numpy(stc_labels) + el_labels = torch.from_numpy(el_labels) + el_labels_from_key = torch.from_numpy(el_labels_from_key) + + return_dict = { + # "image": feature_maps, + "input_ids": input_ids, + "bbox": bbox, + "words": lwords, + "attention_mask": attention_mask, + "itc_labels": itc_labels, + "are_box_first_tokens": are_box_first_tokens, + "stc_labels": stc_labels, + "el_labels": el_labels, + "el_labels_from_key": el_labels_from_key + } + + return return_dict + + +class DocumentKVUProcess(KVUProcess): + def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0): + super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode) + self.max_window_count = max_window_count + self.pad_token_id = self.pad_token_id_layoutxlm + self.cls_token_id = self.cls_token_id_layoutxlm + self.sep_token_id = self.sep_token_id_layoutxlm + self.unk_token_id = self.unk_token_id_layoutxlm + self.tokenizer = self.tokenizer_layoutxlm + + def __call__(self, img_path: str, 
ocr_path: str) -> list: + if (self.run_ocr == 1) and (not os.path.exists(ocr_path)): + process_img(img_path, "tmp.txt", self.ocr_engine, export_img=False) + ocr_path = "tmp.txt" + lbboxes, lwords = read_ocr_result_from_txt(ocr_path) + lwords = post_process_basic_ocr(lwords) + + width, height = imagesize.get(img_path) + images = [Image.open(img_path).convert("RGB")] + image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + output = self.preprocess(lbboxes, lwords, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}) + return output + + def preprocess(self, bounding_boxes, words, feature_maps): + n_words = len(words) + output_dicts = {'windows': [], 'documents': []} + n_empty_windows = 0 + + for i in range(self.max_window_count): + input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id + bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(self.max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_) + + if n_words == 0: + n_empty_windows += 1 + output_dicts['windows'].append({ + "image": feature_maps['image'], + "input_ids": torch.from_numpy(input_ids), + "bbox": torch.from_numpy(bbox), + "words": [], + "attention_mask": torch.from_numpy(attention_mask), + "are_box_first_tokens": torch.from_numpy(are_box_first_tokens), + }) + continue + + start_word_idx = i * self.window_size + stop_word_idx = min(n_words, (i+1)*self.window_size) + + if start_word_idx >= stop_word_idx: + n_empty_windows += 1 + output_dicts['windows'].append(output_dicts['windows'][-1]) + continue + + list_tokens = [] + list_bbs = [] + list_words = [] + lwords = [''] * self.max_seq_length + + box_to_token_indices = [] + cum_token_idx = 0 + + cls_bbs = [0.0] * 8 + + for _, (bounding_box, word) in enumerate(zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx])): + bb = [[bounding_box[0], bounding_box[1]], [bounding_box[2], bounding_box[1]], [bounding_box[2], bounding_box[3]], [bounding_box[0], bounding_box[3]]] + tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(word)) + + this_box_token_indices = [] + + if len(tokens) == 0: + tokens.append(self.unk_token_id) + + if len(list_tokens) + len(tokens) > self.max_seq_length - 2: + break + + list_tokens += tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], feature_maps['width'])) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], feature_maps['height'])) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(tokens))] + texts = [word for _ in range(len(tokens))] + + for _ in tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + list_words.extend(texts) ### + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [feature_maps['width'], feature_maps['height']] * 4 + + # For [CLS] and [SEP] + list_tokens = ( + [self.cls_token_id] + + list_tokens[: self.max_seq_length - 2] + + [self.sep_token_id] + ) + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: self.max_seq_length - 2] + [sep_bbs] + if len(list_words) < 510: + list_words.extend(['

' for _ in range(510 - len(list_words))]) + list_words = [self.tokenizer._cls_token] + list_words[: self.max_seq_length - 2] + [self.tokenizer._sep_token] + + len_list_tokens = len(list_tokens) + input_ids[:len_list_tokens] = list_tokens + attention_mask[:len_list_tokens] = 1 + + bbox[:len_list_tokens, :] = list_bbs + lwords[:len_list_tokens] = list_words ### + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / feature_maps['width'] + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / feature_maps['height'] + + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < self.max_seq_length + ] + are_box_first_tokens[st_indices] = True + + input_ids = torch.from_numpy(input_ids) + bbox = torch.from_numpy(bbox) + attention_mask = torch.from_numpy(attention_mask) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + + return_dict = { + "image": feature_maps['image'], + "input_ids": input_ids, + "bbox": bbox, + "words": list_words, + "attention_mask": attention_mask, + "are_box_first_tokens": are_box_first_tokens, + } + output_dicts["windows"].append(return_dict) + + attention_mask = torch.cat([o['attention_mask'] for o in output_dicts["windows"]]) + are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]]) + if n_empty_windows > 0: + attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int)) + are_box_first_tokens[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_)) + bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]]) + words = [] + for o in output_dicts['windows']: + words.extend(o['words']) + + return_dict = { + "attention_mask": attention_mask, + "bbox": bbox, + "are_box_first_tokens": are_box_first_tokens, + "n_empty_windows": n_empty_windows, + "words": words + } + output_dicts['documents'] = return_dict + + return output_dicts + + + \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/requirements.txt b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/requirements.txt new file mode 100755 index 0000000..639eb12 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/requirements.txt @@ -0,0 +1,19 @@ +nptyping==1.4.2 +numpy==1.20.3 +pytorch-lightning==1.5.6 +omegaconf +pillow +six +overrides==4.1.2 +transformers==4.11.3 +seqeval==0.0.12 +imagesize +pandas==2.0.1 +xmltodict +dicttoxml + +tensorboard>=2.2.0 + +# code-style +isort==5.9.3 +black==21.9b0 \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/run.sh b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/run.sh new file mode 100755 index 0000000..9f64108 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/run.sh @@ -0,0 +1 @@ +python anyKeyValue.py --img_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --save_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --exp_dir /home/thucpd/thucpd/cope2n-ai/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900 --export_img 1 --mode 3 --dir_level 0 \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/tmp.txt b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/tmp.txt new file mode 100755 index 0000000..52cb8ee --- /dev/null +++ 
b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/tmp.txt @@ -0,0 +1,106 @@ +1113 773 1220 825 BEST +1243 759 1378 808 DENKI +1410 752 1487 799 (S) +1430 707 1515 748 TAX +1511 745 1598 790 PTE +1542 700 1725 740 TNVOICE +1618 742 1706 783 LTD +1783 725 1920 773 FUNAN +1943 723 2054 767 MALL +1434 797 1576 843 WORTH +1599 785 1760 831 BRIDGE +1784 778 1846 822 RD +1277 846 1632 897 #02-16/#03-1 +1655 832 1795 877 FUNAN +1817 822 1931 869 MALL +1272 897 1518 956 S(179105) +1548 890 1655 943 TEL: +1686 877 1911 928 69046183 +1247 1011 1334 1068 GST +1358 1006 1447 1059 REG +1360 1063 1449 1115 RCB +1473 1003 1575 1055 NO.: +1474 1059 1555 1110 NO. +1595 1042 1868 1096 198202199E +1607 985 1944 1040 M2-0053813-7 +1056 1134 1254 1194 Opening +1276 1127 1391 1181 Hrs: +1425 1112 1647 1170 10:00:00 +1672 1102 1735 1161 AN +1755 1101 1819 1157 to +1846 1090 2067 1147 10:00:00 +2090 1080 2156 1141 PH +1061 1308 1228 1366 Staff: +1258 1300 1378 1357 3296 +1710 1283 1880 1337 Trans: +1936 1266 2192 1322 262152554 +1060 1372 1201 1429 Date: +1260 1358 1494 1419 22-03-23 +1540 1344 1664 1409 9:05 +1712 1339 1856 1407 Slip: +1917 1328 2196 1387 2000130286 +1124 1487 1439 1545 SALESPERSON +1465 1477 1601 1537 CODE. +1633 1471 1752 1530 6043 +1777 1462 2004 1519 HUHAHHAD +2032 1451 2177 1509 RAZIH +1070 1558 1187 1617 Item +1211 1554 1276 1615 No +1439 1542 1585 1601 Price +1750 1530 1841 1597 Qty +1951 1517 2120 1579 Amount +1076 1683 1276 1741 ANDROID +1304 1673 1477 1733 TABLET +1080 1746 1280 1804 2105976 +1509 1729 1705 1784 SAMSUNG +1734 1719 1931 1776 SH-P613 +1964 1709 2101 1768 128GB +1082 1809 1285 1869 SM-P613 +1316 1802 1454 1860 12838 +1429 1859 1600 1919 518.00 +1481 1794 1596 1855 WIFI +1622 1790 1656 1850 G +1797 1845 1824 1904 1 +1993 1832 2165 1892 518.00 +1088 1935 1347 1995 PROMOTION +1091 2000 1294 2062 2105664 +1520 1983 1717 2039 SAMSUNG +1743 1963 2106 2030 F-Sam-Redeen +1439 2111 1557 2173 0.00 +1806 2095 1832 2156 1 +2053 2081 2174 2144 0.00 +1106 2248 1250 2312 Total +1974 2206 2146 2266 518.00 +1107 2312 1204 2377 UOB +1448 2291 1567 2355 CARD +1978 2268 2147 2327 518.00 +1253 2424 1375 2497 GST% +1456 2411 1655 2475 Net.Amt +1818 2393 1912 2460 GST +2023 2387 2192 2445 Amount +1106 2494 1231 2560 GST8 +1486 2472 1661 2537 479.63 +1770 2458 1916 2523 38.37 +2027 2448 2203 2511 518.00 +1553 2601 1699 2666 THANK +1721 2592 1821 2661 YOU +1436 2678 1616 2749 please +1644 2682 1764 2732 come +1790 2660 1942 2729 again +1191 2862 1391 2931 Those +1426 2870 2018 2945 facebook.com +1565 2809 1690 2884 join +1709 2816 1777 2870 us +1799 2811 1868 2865 on +1838 2946 2024 3003 com .89 +1533 3006 2070 3088 ar.com/askbe +1300 3326 1659 3446 That's +1696 3308 1905 3424 not +1937 3289 2131 3408 all! 
+1450 3511 1633 3573 SCAN +1392 3589 1489 3645 QR +1509 3577 1698 3635 CODE +1321 3656 1370 3714 & +1517 3638 1768 3699 updates +1643 3882 1769 3932 Scan +1789 3868 1859 3926 Me \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/__init__.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/__init__.py new file mode 100755 index 0000000..066cfe3 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/__init__.py @@ -0,0 +1,127 @@ +import os +import torch +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.plugins import DDPPlugin +from utils.ema_callbacks import EMA + + +def _update_config(cfg): + cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints") + cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs") + + # set per-gpu batch size + num_devices = torch.cuda.device_count() + print('No. devices:', num_devices) + for mode in ["train", "val"]: + new_batch_size = cfg[mode].batch_size // num_devices + cfg[mode].batch_size = new_batch_size + +def _get_config_from_cli(): + cfg_cli = OmegaConf.from_cli() + cli_keys = list(cfg_cli.keys()) + for cli_key in cli_keys: + if "--" in cli_key: + cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key] + del cfg_cli[cli_key] + + return cfg_cli + +def get_callbacks(cfg): + callback_list = [] + checkpoint_callback = ModelCheckpoint(dirpath=cfg.save_weight_dir, + filename='best_model', + save_last=True, + save_top_k=1, + save_weights_only=True, + verbose=True, + monitor='val_f1', mode='max') + checkpoint_callback.FILE_EXTENSION = ".pth" + checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model" + callback_list.append(checkpoint_callback) + if cfg.callbacks.ema.decay != -1: + ema_callback = EMA(decay=0.9999) + callback_list.append(ema_callback) + return callback_list if len(callback_list) > 1 else checkpoint_callback + +def get_plugins(cfg): + plugins = [] + if cfg.train.strategy.type == "ddp": + plugins.append(DDPPlugin()) + + return plugins + +def get_loggers(cfg): + loggers = [] + + loggers.append( + TensorBoardLogger( + cfg.tensorboard_dir, name="", version="", default_hp_metric=False + ) + ) + + return loggers + +def cfg_to_hparams(cfg, hparam_dict, parent_str=""): + for key, val in cfg.items(): + if isinstance(val, DictConfig): + hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__") + else: + hparam_dict[parent_str + key] = str(val) + return hparam_dict + +def get_specific_pl_logger(pl_loggers, logger_type): + for pl_logger in pl_loggers: + if isinstance(pl_logger, logger_type): + return pl_logger + return None + +def get_class_names(dataset_root_path): + class_names_file = os.path.join(dataset_root_path[0], "class_names.txt") + class_names = ( + open(class_names_file, "r", encoding="utf-8").read().strip().split("\n") + ) + return class_names + +def create_exp_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + else: + print("DIR already existed.") + print('Experiment dir : {}'.format(save_dir)) + +def create_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + else: + print("DIR already existed.") + print('Save dir : {}'.format(save_dir)) + +def load_checkpoint(ckpt_path, model, key_include): + assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!" 
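# The block below keeps only entries whose key contains ".{key_include}." and renames
# them by dropping the wrapping module prefix (the slice key[4:] assumes a
# four-character prefix such as "net."), so the surviving names line up with the
# sub-module's own state_dict before load_state_dict is called with strict=True.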
+ state_dict = torch.load(ckpt_path, 'cpu')['state_dict'] + for key in list(state_dict.keys()): + if f'.{key_include}.' not in key: + del state_dict[key] + else: + state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something. + del state_dict[key] + model.load_state_dict(state_dict, strict=True) + print(f"Load checkpoint at {ckpt_path}") + return model + +def load_model_weight(net, pretrained_model_file): + pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[ + "state_dict" + ] + new_state_dict = {} + for k, v in pretrained_model_state_dict.items(): + new_k = k + if new_k.startswith("net."): + new_k = new_k[len("net.") :] + new_state_dict[new_k] = v + net.load_state_dict(new_state_dict) + + diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/ema_callbacks.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/ema_callbacks.py new file mode 100755 index 0000000..5f57cf8 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/ema_callbacks.py @@ -0,0 +1,346 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import copy +import os +import threading +from typing import Any, Dict, Iterable + +import pytorch_lightning as pl +import torch +from pytorch_lightning import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import rank_zero_info + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. + cpu_offload: Offload weights to CPU. 
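    Example (illustrative sketch, assuming a standard pytorch_lightning Trainer):
        trainer = pl.Trainer(callbacks=[EMA(decay=0.9999, every_n_steps=1)])
        trainer.fit(model)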
+ """ + + def __init__( + self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False, + ): + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self.decay = decay + self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps + self.cpu_offload = cpu_offload + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + device = pl_module.device if not self.cpu_offload else torch.device('cpu') + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) + ] + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers) + + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) + + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + # use the connector as NeMo calls the connector directly in the exp_manager when restoring. + connector = trainer._checkpoint_connector + ckpt_path = connector.resume_checkpoint_path + + if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__: + ext = checkpoint_callback.FILE_EXTENSION + if ckpt_path.endswith(f'-EMA{ext}'): + rank_zero_info( + "loading EMA based weights. " + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." 
+ ) + return + ema_path = ckpt_path.replace(ext, f'-EMA{ext}') + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) + + checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] + del ema_state_dict + rank_zero_info("EMA state has been restored.") + else: + raise MisconfigurationException( + "Unable to find the associated EMA weights when re-loading, " + f"training will start with new EMA weights. Expected them to be at: {ema_path}", + ) + + +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + +def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() + + ema_update(ema_model_tuple, current_model_tuple, decay) + + +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. + + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + self.in_saving_ema_model_context = False + + def all_parameters(self) -> Iterable[torch.Tensor]: + return (param for group in self.param_groups for param in group['params']) + + def step(self, closure=None, **kwargs): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + loss = self.optimizer.step(closure) + + if self._should_update_at_step(): + self.update() + self.current_step += 1 + return loss + + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + + @torch.no_grad() + def update(self): + if self.stream is not None: + 
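# If a side CUDA stream was created in step(), the update below is queued on that
# stream (parameter copies use non_blocking=True) so it can overlap with the next
# training step; join() later synchronizes the stream before the EMA weights are
# swapped or saved.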
self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) for param in self.all_parameters() + ) + + if self.device.type == 'cuda': + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == 'cpu': + self.thread = threading.Thread( + target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,), + ) + self.thread.start() + + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self, saving_ema_model: bool = False): + self.join() + self.in_saving_ema_model_context = saving_ema_model + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. + + Args: + enabled (bool): whether the swap should be performed + """ + + if enabled: + self.switch_main_parameter_weights() + try: + yield + finally: + if enabled: + self.switch_main_parameter_weights() + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + def state_dict(self): + self.join() + + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters()) + state_dict = { + 'opt': self.optimizer.state_dict(), + 'ema': ema_params, + 'current_step': self.current_step, + 'decay': self.decay, + 'every_n_steps': self.every_n_steps, + } + return state_dict + + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict['opt']) + self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema'])) + self.current_step = state_dict['current_step'] + self.decay = state_dict['decay'] + self.every_n_steps = state_dict['every_n_steps'] + self.rebuild_ema_params = False + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/functions.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/functions.py new file mode 100755 index 0000000..f998f1a --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/functions.py @@ -0,0 +1,459 @@ +import os +import cv2 +import json +import torch +import glob +import re +import numpy as np +from tqdm import tqdm +from pdf2image import convert_from_path +from dicttoxml import dicttoxml +from word_preprocess import vat_standardizer, get_string, ap_standardizer +from kvu_dictionary import vat_dictionary, ap_dictionary + + + +def create_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + else: + print("DIR already existed.") + print('Save dir : {}'.format(save_dir)) + +def pdf2image(pdf_dir, save_dir): + pdf_files = glob.glob(f'{pdf_dir}/*.pdf') + print('No. 
pdf files:', len(pdf_files)) + + for file in tqdm(pdf_files): + pages = convert_from_path(file, 500) + for i, page in enumerate(pages): + page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG') + print('Done!!!') + +def xyxy2xywh(bbox): + return [ + float(bbox[0]), + float(bbox[1]), + float(bbox[2]) - float(bbox[0]), + float(bbox[3]) - float(bbox[1]), + ] + +def write_to_json(file_path, content): + with open(file_path, mode='w', encoding='utf8') as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, 'r') as f: + return json.load(f) + +def read_xml(file_path): + with open(file_path, 'r') as xml_file: + return xml_file.read() + +def write_to_xml(file_path, content): + with open(file_path, mode="w", encoding='utf8') as f: + f.write(content) + +def write_to_xml_from_dict(file_path, content): + xml = dicttoxml(content) + xml = content + xml_decode = xml.decode() + + with open(file_path, mode="w") as f: + f.write(xml_decode) + + +def load_ocr_result(ocr_path): + with open(ocr_path, 'r') as f: + lines = f.read().splitlines() + + preds = [] + for line in lines: + preds.append(line.split('\t')) + return preds + +def post_process_basic_ocr(lwords: list) -> list: + pp_lwords = [] + for word in lwords: + pp_lwords.append(word.replace("✪", " ")) + return pp_lwords + +def read_ocr_result_from_txt(file_path: str): + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + + boxes, words = [], [] + for line in lines: + if line == "": + continue + word_info = line.split("\t") + if len(word_info) == 6: + x1, y1, x2, y2, text, _ = word_info + elif len(word_info) == 5: + x1, y1, x2, y2, text = word_info + + x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2)) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + +def get_colormap(): + return { + 'others': (0, 0, 255), # others: red + 'title': (0, 255, 255), # title: yellow + 'key': (255, 0, 0), # key: blue + 'value': (0, 255, 0), # value: green + 'header': (233, 197, 15), # header + 'group': (0, 128, 128), # group + 'relation': (0, 0, 255)# (128, 128, 128), # relation + } + +def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1): + exif = image._getexif() + orientation = None + if exif is not None: + orientation = exif.get(0x0112) + + # Convert the PIL image to OpenCV format + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + + # Rotate the image in OpenCV if necessary + if orientation == 3: + image = cv2.rotate(image, cv2.ROTATE_180) + elif orientation == 6: + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + elif orientation == 8: + image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) + else: + image = np.asarray(image) + + if len(image.shape) == 2: + image = np.repeat(image[:, :, np.newaxis], 3, axis=2) + assert len(image.shape) == 3 + + if orientation is not None and orientation == 6: + width, height, _ = image.shape + else: + height, width, _ = image.shape + if len(pr_class_words) > 0: + id2label = {k: labels[k] for k in range(len(labels))} + for lb, groups in enumerate(pr_class_words): + if lb == 0: + continue + for group_id, group in enumerate(groups): + for i, word_id in enumerate(group): + x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), 
int(bbox[word_id][3]*height/1000) + cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness) + + if i == 0: + x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2) + else: + x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2) + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness) + x_center0, y_center0 = x_center1, y_center1 + + if len(pr_relations) > 0: + for pair in pr_relations: + xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000) + xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000) + + x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2) + x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2) + + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness) + + return image + + +def get_pairs(json: list, rel_from: str, rel_to: str) -> dict: + outputs = {} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] in (rel_from, rel_to): + is_rel[element['class']]['status'] = 1 + is_rel[element['class']]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']] + return outputs + +def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict: + list_keys = list(header_key_pairs.keys()) + relations = {k: [] for k in list_keys} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] == rel_from and element['group_id'] in list_keys: + is_rel[rel_from]['status'] = 1 + is_rel[rel_from]['value'] = element + if element['class'] == rel_to: + is_rel[rel_to]['status'] = 1 + is_rel[rel_to]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id']) + return relations + +def get_key2values_relations(key_value_pairs: dict): + triple_linkings = {} + for value_group_id, key_value_pair in key_value_pairs.items(): + key_group_id = key_value_pair[0] + if key_group_id not in list(triple_linkings.keys()): + triple_linkings[key_group_id] = [] + triple_linkings[key_group_id].append(value_group_id) + return triple_linkings + + +def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict: + word_groups = {} + id2class = {i: labels[i] for i in range(len(labels))} + for class_id, lwgroups_in_class in enumerate(class_words): + for ltokens_in_wgroup in lwgroups_in_class: + group_id = ltokens_in_wgroup[0] + ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup] + text_string = get_string(ltokens_to_ltexts) + word_groups[group_id] = { + 'group_id': group_id, + 'text': text_string, + 'class': id2class[class_id], + 'tokens': ltokens_in_wgroup + } + return word_groups + +def verify_linking_id(word_groups: dict, linking_id: int) -> int: + if linking_id not in list(word_groups): + for wg_id, _word_group in word_groups.items(): + if linking_id in _word_group['tokens']: + return wg_id + return linking_id + +def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list: + outputs = 
[] + for pair in lrelations: + wg_from = verify_linking_id(word_groups, pair[0]) + wg_to = verify_linking_id(word_groups, pair[1]) + try: + outputs.append([word_groups[wg_from], word_groups[wg_to]]) + except: + print('Not valid pair:', wg_from, wg_to) + return outputs + + +def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + word_groups = merged_token_to_wordgroup(class_words, lwords, labels) + linking_pairs = matched_wordgroup_relations(word_groups, lrelations) + + header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]} + header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]} + key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]} + + # table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + + triplet_pairs = [] + single_pairs = [] + table = [] + # print('key2values_relations', key2values_relations) + for key_group_id, list_value_group_ids in key2values_relations.items(): + if len(list_value_group_ids) == 0: continue + elif len(list_value_group_ids) == 1: + value_group_id = list_value_group_ids[0] + single_pairs.append({word_groups[key_group_id]['text']: { + 'text': word_groups[value_group_id]['text'], + 'id': value_group_id, + 'class': "value" + }}) + else: + item = [] + for value_group_id in list_value_group_ids: + if value_group_id not in header_value.keys(): + header_name_for_value = "non-header" + else: + header_group_id = header_value[value_group_id][0] + header_name_for_value = word_groups[header_group_id]['text'] + item.append({ + 'text': word_groups[value_group_id]['text'], + 'header': header_name_for_value, + 'id': value_group_id, + 'class': 'value' + }) + if key_group_id not in list(header_key.keys()): + triplet_pairs.append({ + word_groups[key_group_id]['text']: item + }) + else: + header_group_id = header_key[key_group_id][0] + header_name_for_key = word_groups[header_group_id]['text'] + item.append({ + 'text': word_groups[key_group_id]['text'], + 'header': header_name_for_key, + 'id': key_group_id, + 'class': 'key' + }) + table.append({key_group_id: item}) + + if len(table) > 0: + table = sorted(table, key=lambda x: list(x.keys())[0]) + table = [v for item in table for k, v in item.items()] + + outputs = {} + outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id']))) + outputs['triplet'] = triplet_pairs + outputs['table'] = table + + file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path)) + write_to_json(file_path, outputs) + return outputs + + +def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + vat_outputs = {} + outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels) + + # List of items in table + table = [] + for single_item in outputs['table']: + item = {k: [] for k in list(vat_dictionary(header=True).keys())} + for cell in single_item: + header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True) + if header_name in list(item.keys()): + # item[header_name] = value['text'] + 
item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + table.append(item) + + + # VAT Information + single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())} + for pair in outputs['single']: + for key_name, value in pair.items(): + key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False) + # print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + # print('='*10, file_path) + # print(vat_info) + # Combine VAT information and table + vat_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if key_name in ("Ngày, tháng, năm lập hóa đơn"): + if len(list_potential_value) == 1: + vat_outputs[key_name] = list_potential_value[0]['content'] + else: + date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'} + for value in list_potential_value: + date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content']) + vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}" + else: + if len(list_potential_value) == 0: continue + if key_name in ("Mã số thuế người bán"): + selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code + vat_outputs[key_name] = selected_value['content'].replace(' ', '') + else: + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + vat_outputs[key_name] = selected_value['content'] + + vat_outputs['table'] = table + + write_to_json(file_path, vat_outputs) + + +def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels) + # List of items in table + table = [] + for single_item in outputs['table']: + item = {k: [] for k in list(ap_dictionary(header=True).keys())} + for cell in single_item: + header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True) + # print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}") + if header_name in list(item.keys()): + item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + table.append(item) + + triplet_pairs = [] + for single_item in outputs['triplet']: + item = {k: [] for k in list(ap_dictionary(header=True).keys())} + is_item_valid = 0 + for key_name, list_value in single_item.items(): + for value in list_value: + if value['header'] == "non-header": + continue + header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True) + if header_name in list(item.keys()): + is_item_valid = 1 + item[header_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 
'lcs_score': score, + 'token_id': value['id'] + }) + + if is_item_valid == 1: + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + item['productname'] = key_name + # triplet_pairs.append({key_name: new_item}) + triplet_pairs.append(item) + + single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())} + for pair in outputs['single']: + for key_name, value in pair.items(): + key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False) + # print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + ap_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if len(list_potential_value) == 0: continue + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + ap_outputs[key_name] = selected_value['content'] + + table = table + triplet_pairs + ap_outputs['table'] = table + # ap_outputs['triplet'] = triplet_pairs + + write_to_json(file_path, ap_outputs) \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/kvu_dictionary.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/kvu_dictionary.py new file mode 100755 index 0000000..1248aa0 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/kvu_dictionary.py @@ -0,0 +1,138 @@ + +DKVU2XML = { + "Ký hiệu mẫu hóa đơn": "form_no", + "Ký hiệu hóa đơn": "serial_no", + "Số hóa đơn": "invoice_no", + "Ngày, tháng, năm lập hóa đơn": "issue_date", + "Tên người bán": "seller_name", + "Mã số thuế người bán": "seller_tax_code", + "Thuế suất": "tax_rate", + "Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount", + "Mặt hàng": "item", + "Đơn vị tính": "unit", + "Số lượng": "quantity", + "Đơn giá": "unit_price", + "Doanh số mua chưa có thuế": "amount" +} + + +def ap_dictionary(header: bool): + header_dictionary = { + 'productname': ['description', 'paticulars', 'articledescription', 'descriptionofgood', 'itemdescription', 'product', 'productdescription', 'modelname', 'device', 'items', 'itemno'], + 'modelnumber': ['serialno', 'model', 'code', 'mcode', 'simimeiserial', 'serial', 'productcode', 'product', 'imeiccid', 'articles', 'article', 'articlenumber', 'articleidmaterialcode', 'transaction', 'itemcode'], + 'qty': ['quantity', 'invoicequantity'] + } + + key_dictionary = { + 'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'], + 'retailername': ['retailer', 'retailername', 'ownedoperatedby'], + 'serial_number': ['serialnumber', 'serialno'], + 'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2'] + } + + return header_dictionary if header else key_dictionary + + +def vat_dictionary(header: bool): + header_dictionary = { + 'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'sanpham', 'tenquycachhanghoa','description', 'descriptionofgood', 'itemdescription'], + 'Đơn vị tính': ['dvt', 'donvitinh'], + 'Số lượng': ['soluong', 'sl','qty', 'quantity', 'invoicequantity'], + 'Đơn giá': ['dongia'], + 'Doanh số mua chưa có thuế': ['thanhtien', 'thanhtientruocthuegtgt', 'tienchuathue'], + # 
'Số sản phẩm': ['serialno', 'model', 'mcode', 'simimeiserial', 'serial', 'sku', 'sn', 'productcode', 'product', 'particulars', 'imeiccid', 'articles', 'article', 'articleidmaterialcode', 'transaction', 'imei', 'articlenumber'] + } + + key_dictionary = { + 'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'], + 'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'], + 'Số hóa đơn': ['soinvoiceno', 'invoiceno'], + 'Ngày, tháng, năm lập hóa đơn': [], + 'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'], + 'Mã số thuế người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'], + 'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'], + 'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'], + # 'Ghi chú': [], + # 'Ngày': ['ngayday', 'ngay', 'day'], + # 'Tháng': ['thangmonth', 'thang', 'month'], + # 'Năm': ['namyear', 'nam', 'year'] + } + + # exact_dictionary = { + # 'Số hóa đơn': ['sono', 'so'], + # 'Mã số thuế người bán': ['mst'], + # 'Tên người bán': ['kyboi'], + # 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate'] + # } + + return header_dictionary if header else key_dictionary + +def manulife_dictionary(type: str): + key_dict = { + "Document type": ["documenttype", "loaichungtu"], + "Document name": ["documentname", "tenchungtu"], + "Patient Name": ["patientname", "tenbenhnhan"], + "Date of Birth/Year of birth": [ + "dateofbirth", + "yearofbirth", + "ngaythangnamsinh", + "namsinh", + ], + "Age": ["age", "tuoi"], + "Gender": ["gender", "gioitinh"], + "Social insurance card No.": ["socialinsurancecardno", "sothebhyt"], + "Medical service provider name": ["medicalserviceprovidername", "tencosoyte"], + "Department name": ["departmentname", "tenkhoadieutri"], + "Diagnosis description": ["diagnosisdescription", "motachandoan"], + "Diagnosis code": ["diagnosiscode", "machandoan"], + "Admission date": ["admissiondate", "ngaynhapvien"], + "Discharge date": ["dischargedate", "ngayxuatvien"], + "Treatment method": ["treatmentmethod", "phuongphapdieutri"], + "Treatment date": ["treatmentdate", "ngaydieutri", "ngaykham"], + "Follow up treatment date": ["followuptreatmentdate", "ngaytaikham"], + # "Name of precribed medicine": [], + # "Quantity of prescribed medicine": [], + # "Dosage for each medicine": [] + "Medical expense": ["Medical expense", "chiphiyte"], + "Invoice No.": ["invoiceno", "sohoadon"], + } + + title_dict = { + "Chứng từ y tế": [ + "giayravien", + "giaychungnhanphauthuat", + "cachthucphauthuat", + "phauthuat", + "tomtathosobenhan", + "donthuoc", + "toathuoc", + "donbosung", + "ketquaconghuongtu" + "ketqua", + "phieuchidinh", + "phieudangkykham", + "giayhenkhamlai", + "phieukhambenh", + "phieukhambenhvaovien", + "phieuxetnghiem", + "phieuketquaxetnghiem", + "phieuchidinhxetnghiem", + "ketquasieuam", + "phieuchidinhxetnghiem" + ], + "Chứng từ thanh toán": [ + "hoadon", + "hoadongiatrigiatang", + "hoadongiatrigiatangchuyendoituhoadondientu", + "bangkechiphibaohiem", + "bienlaithutien", + "bangkechiphidieutrinoitru" + ], + } + + if type == "key": + return key_dict + elif type == "title": + return title_dict + else: + raise ValueError(f"[ERROR] Dictionary type of {type} is not supported") \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/run_ocr.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/run_ocr.py new file mode 100755 index 0000000..6190b0a --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/run_ocr.py @@ -0,0 +1,30 @@ 
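The run_ocr.py helper added below wraps the shared OcrEngine. A minimal driver, sketched here under the assumption that the package is importable as `utils` and that the output directory already exists (both paths are placeholders, not from the repo), could look like:

    from utils.run_ocr import load_ocr_engine, process_img

    engine = load_ocr_engine()  # load the shared OcrEngine once
    # write the OCR words for sample.jpg to ocr_out/sample.txt, skip the debug image
    process_img("sample.jpg", "ocr_out/", engine, export_img=False)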
+import numpy as np +from pathlib import Path +from typing import Union, Tuple, List +import sys, os +cur_dir = os.path.dirname(__file__) +sys.path.append(os.path.join(os.path.dirname(cur_dir), "ocr-engine")) +from src.ocr import OcrEngine + + +def load_ocr_engine() -> OcrEngine: + print("[INFO] Loading engine...") + engine = OcrEngine() + print("[INFO] Engine loaded") + return engine + +def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None: + save_dir_or_path = Path(save_dir_or_path) + if isinstance(img, np.ndarray): + if save_dir_or_path.is_dir(): + raise ValueError("numpy array input require a save path, not a save dir") + page = engine(img) + save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt") + ) if save_dir_or_path.is_dir() else str(save_dir_or_path) + page.write_to_file('word', save_path) + if export_img: + page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, ) + +def read_img(img: Union[str, np.ndarray], engine: OcrEngine): + page = engine(img) + return ' '.join([f.text for f in page.llines]) \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/utils.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/utils.py new file mode 100755 index 0000000..8bd4062 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/utils/utils.py @@ -0,0 +1,808 @@ +import os +import cv2 +import json +import random +import glob +import re +import numpy as np +from tqdm import tqdm +from pdf2image import convert_from_path +from dicttoxml import dicttoxml +from word_preprocess import ( + vat_standardizer, + ap_standardizer, + get_string_with_word2line, + split_key_value_by_colon, + normalize_kvu_output, + normalize_kvu_output_for_manulife, + manulife_standardizer +) +from utils.kvu_dictionary import ( + vat_dictionary, + ap_dictionary, + manulife_dictionary +) + + + +def create_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + # else: + # print("DIR already existed.") + # print('Save dir : {}'.format(save_dir)) + +def convert_pdf2img(pdf_dir, save_dir): + pdf_files = glob.glob(f'{pdf_dir}/*.pdf') + print('No. 
pdf files:', len(pdf_files)) + print(pdf_files) + + for file in tqdm(pdf_files): + pdf2img(file, save_dir, n_pages=-1, return_fname=False) + # pages = convert_from_path(file, 500) + # for i, page in enumerate(pages): + # page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG') + print('Done!!!') + +def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False): + file_names = [] + pages = convert_from_path(pdf_path) + if n_pages != -1: + pages = pages[:n_pages] + for i, page in enumerate(pages): + _save_path = os.path.join(save_dir, os.path.basename(pdf_path).replace('.pdf', f'_{i}.jpg')) + page.save(_save_path, 'JPEG') + file_names.append(_save_path) + if return_fname: + return file_names + +def xyxy2xywh(bbox): + return [ + float(bbox[0]), + float(bbox[1]), + float(bbox[2]) - float(bbox[0]), + float(bbox[3]) - float(bbox[1]), + ] + +def write_to_json(file_path, content): + with open(file_path, mode='w', encoding='utf8') as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, 'r') as f: + return json.load(f) + +def read_xml(file_path): + with open(file_path, 'r') as xml_file: + return xml_file.read() + +def write_to_xml(file_path, content): + with open(file_path, mode="w", encoding='utf8') as f: + f.write(content) + +def write_to_xml_from_dict(file_path, content): + xml = dicttoxml(content) + xml = content + xml_decode = xml.decode() + + with open(file_path, mode="w") as f: + f.write(xml_decode) + + +def load_ocr_result(ocr_path): + with open(ocr_path, 'r') as f: + lines = f.read().splitlines() + + preds = [] + for line in lines: + preds.append(line.split('\t')) + return preds + +def post_process_basic_ocr(lwords: list) -> list: + pp_lwords = [] + for word in lwords: + pp_lwords.append(word.replace("✪", " ")) + return pp_lwords + +def read_ocr_result_from_txt(file_path: str): + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + + boxes, words = [], [] + for line in lines: + if line == "": + continue + word_info = line.split("\t") + if len(word_info) == 6: + x1, y1, x2, y2, text, _ = word_info + elif len(word_info) == 5: + x1, y1, x2, y2, text = word_info + + x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2)) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + +def get_colormap(): + return { + 'others': (0, 0, 255), # others: red + 'title': (0, 255, 255), # title: yellow + 'key': (255, 0, 0), # key: blue + 'value': (0, 255, 0), # value: green + 'header': (233, 197, 15), # header + 'group': (0, 128, 128), # group + 'relation': (0, 0, 255)# (128, 128, 128), # relation + } + +def convert_image(image): + exif = image._getexif() + orientation = None + if exif is not None: + orientation = exif.get(0x0112) + # Convert the PIL image to OpenCV format + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + # Rotate the image in OpenCV if necessary + if orientation == 3: + image = cv2.rotate(image, cv2.ROTATE_180) + elif orientation == 6: + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + elif orientation == 8: + image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) + else: + image = np.asarray(image) + + if len(image.shape) == 2: + image = np.repeat(image[:, :, np.newaxis], 3, axis=2) + assert len(image.shape) == 3 + + return image, orientation + +def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 
'value', 'header'], thickness=1): + image, orientation = convert_image(image) + + # if orientation is not None and orientation == 6: + # width, height, _ = image.shape + # else: + # height, width, _ = image.shape + + if len(pr_class_words) > 0: + id2label = {k: labels[k] for k in range(len(labels))} + for lb, groups in enumerate(pr_class_words): + if lb == 0: + continue + for group_id, group in enumerate(groups): + for i, word_id in enumerate(group): + # x0, y0, x1, y1 = int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000) + # x0, y0, x1, y1 = revert_box(bbox[word_id], width, height) + x0, y0, x1, y1 = bbox[word_id] + cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness) + + if i == 0: + x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2) + else: + x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2) + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness) + x_center0, y_center0 = x_center1, y_center1 + + if len(pr_relations) > 0: + for pair in pr_relations: + # xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000) + # xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000) + # xyxy0 = revert_box(bbox[pair[0]], width, height) + # xyxy1 = revert_box(bbox[pair[1]], width, height) + + xyxy0 = bbox[pair[0]] + xyxy1 = bbox[pair[1]] + + x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2) + x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2) + + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness) + + return image + +def revert_box(box, width, height): + return [ + int((box[0] / 1000) * width), + int((box[1] / 1000) * height), + int((box[2] / 1000) * width), + int((box[3] / 1000) * height) + ] + + +def get_wordgroup_bbox(lbbox: list, lword_ids: list) -> list: + points = [lbbox[i] for i in lword_ids] + x_min, y_min = min(points, key=lambda x: x[0])[0], min(points, key=lambda x: x[1])[1] + x_max, y_max = max(points, key=lambda x: x[2])[2], max(points, key=lambda x: x[3])[3] + return [x_min, y_min, x_max, y_max] + + +def get_pairs(json: list, rel_from: str, rel_to: str) -> dict: + outputs = {} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] in (rel_from, rel_to): + is_rel[element['class']]['status'] = 1 + is_rel[element['class']]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']] + return outputs + +def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict: + list_keys = list(header_key_pairs.keys()) + relations = {k: [] for k in list_keys} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] == rel_from and element['group_id'] in list_keys: + is_rel[rel_from]['status'] = 1 + is_rel[rel_from]['value'] = element + if element['class'] == rel_to: + is_rel[rel_to]['status'] = 1 + is_rel[rel_to]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + 
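+            # Both sides of this linking pair matched: a key group that is already tied to a header,
+            # and a value group. Collect the value under its key; a single key group can accumulate
+            # several value groups this way, which is how table cells are gathered per key.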
relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id']) + return relations + +def get_key2values_relations(key_value_pairs: dict): + triple_linkings = {} + for value_group_id, key_value_pair in key_value_pairs.items(): + key_group_id = key_value_pair[0] + if key_group_id not in list(triple_linkings.keys()): + triple_linkings[key_group_id] = [] + triple_linkings[key_group_id].append(value_group_id) + return triple_linkings + + +def merged_token_to_wordgroup(class_words: list, lwords: list, lbboxes: list, labels: list) -> dict: + word_groups = {} + id2class = {i: labels[i] for i in range(len(labels))} + for class_id, lwgroups_in_class in enumerate(class_words): + for ltokens_in_wgroup in lwgroups_in_class: + group_id = ltokens_in_wgroup[0] + ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup] + ltokens_to_lbboxes = [lbboxes[token] for token in ltokens_in_wgroup] + # text_string = get_string(ltokens_to_ltexts) + # text_string= get_string_by_deduplicate_bbox(ltokens_to_ltexts, ltokens_to_lbboxes) + text_string = get_string_with_word2line(ltokens_to_ltexts, ltokens_to_lbboxes) + group_bbox = get_wordgroup_bbox(lbboxes, ltokens_in_wgroup) + word_groups[group_id] = { + 'group_id': group_id, + 'text': text_string, + 'class': id2class[class_id], + 'tokens': ltokens_in_wgroup, + 'bbox': group_bbox + } + return word_groups + +def verify_linking_id(word_groups: dict, linking_id: int) -> int: + if linking_id not in list(word_groups): + for wg_id, _word_group in word_groups.items(): + if linking_id in _word_group['tokens']: + return wg_id + return linking_id + +def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list: + outputs = [] + for pair in lrelations: + wg_from = verify_linking_id(word_groups, pair[0]) + wg_to = verify_linking_id(word_groups, pair[1]) + try: + outputs.append([word_groups[wg_from], word_groups[wg_to]]) + except: + print('Not valid pair:', wg_from, wg_to) + return outputs + +def get_single_entity(word_groups: dict, lrelations: list) -> list: + single_entity = {'title': [], 'key': [], 'value': [], 'header': []} + list_linked_ids = [] + for pair in lrelations: + list_linked_ids.extend(pair) + + for word_group_id, word_group in word_groups.items(): + if word_group_id not in list_linked_ids: + single_entity[word_group['class']].append(word_group) + return single_entity + + +def export_kvu_outputs(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + word_groups = merged_token_to_wordgroup(class_words, lwords, lbboxes, labels) + linking_pairs = matched_wordgroup_relations(word_groups, lrelations) + + header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]} + header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]} + key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]} + single_entity = get_single_entity(word_groups, lrelations) + # table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + triplet_pairs = [] + single_pairs = [] + table = [] + # print('key2values_relations', key2values_relations) + for key_group_id, list_value_group_ids in key2values_relations.items(): 
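+        # Route each key group and its linked value groups (the group ids below are illustrative only):
+        #   key 12 -> [15], neither side linked to a header  => flat key/value entry in single_pairs
+        #   key 20 -> [21, 22, 23], key linked to a header   => row of cells appended to table
+        #   key 30 -> [31, 32], key not linked to any header => grouped entry in triplet_pairs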
+ if len(list_value_group_ids) == 0: continue + elif (len(list_value_group_ids) == 1) and (list_value_group_ids[0] not in list(header_value.keys())) and (key_group_id not in list(header_key.keys())): + value_group_id = list_value_group_ids[0] + + single_pairs.append({word_groups[key_group_id]['text']: { + 'text': word_groups[value_group_id]['text'], + 'id': value_group_id, + 'class': "value", + 'bbox': word_groups[value_group_id]['bbox'], + 'key_bbox': word_groups[key_group_id]['bbox'] + }}) + else: + item = [] + for value_group_id in list_value_group_ids: + if value_group_id not in header_value.keys(): + header_group_id = -1 # temp + header_name_for_value = "non-header" + else: + header_group_id = header_value[value_group_id][0] + header_name_for_value = word_groups[header_group_id]['text'] + item.append({ + 'text': word_groups[value_group_id]['text'], + 'header': header_name_for_value, + 'id': value_group_id, + "key_id": key_group_id, + "header_id": header_group_id, + 'class': 'value', + 'bbox': word_groups[value_group_id]['bbox'], + 'key_bbox': word_groups[key_group_id]['bbox'], + 'header_bbox': word_groups[header_group_id]['bbox'] if header_group_id != -1 else [0, 0, 0, 0], + }) + if key_group_id not in list(header_key.keys()): + triplet_pairs.append({ + word_groups[key_group_id]['text']: item + }) + else: + header_group_id = header_key[key_group_id][0] + header_name_for_key = word_groups[header_group_id]['text'] + item.append({ + 'text': word_groups[key_group_id]['text'], + 'header': header_name_for_key, + 'id': key_group_id, + "key_id": key_group_id, + "header_id": header_group_id, + 'class': 'key', + 'bbox': word_groups[value_group_id]['bbox'], + 'key_bbox': word_groups[key_group_id]['bbox'], + 'header_bbox': word_groups[header_group_id]['bbox'], + }) + table.append({key_group_id: item}) + + + single_entity_dict = {} + for class_name, single_items in single_entity.items(): + single_entity_dict[class_name] = [] + for single_item in single_items: + single_entity_dict[class_name].append({ + 'text': single_item['text'], + 'id': single_item['group_id'], + 'class': class_name, + 'bbox': single_item['bbox'] + }) + + + + if len(table) > 0: + table = sorted(table, key=lambda x: list(x.keys())[0]) + table = [v for item in table for k, v in item.items()] + + + outputs = {} + outputs['title'] = single_entity_dict['title'] + outputs['key'] = single_entity_dict['key'] + outputs['value'] = single_entity_dict['value'] + outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id']))) + outputs['triplet'] = triplet_pairs + outputs['table'] = table + + + create_dir(os.path.join(os.path.dirname(file_path), 'kvu_results')) + file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path)) + write_to_json(file_path, outputs) + return outputs + +def export_kvu_for_all(file_path, lwords, lbboxes, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']) -> dict: + raw_outputs = export_kvu_outputs( + file_path, lwords, lbboxes, class_words, lrelations, labels + ) + outputs = {} + # Title + outputs["title"] = ( + raw_outputs["title"][0]["text"] if len(raw_outputs["title"]) > 0 else None + ) + + # Pairs of key-value + for pair in raw_outputs["single"]: + for key, values in pair.items(): + # outputs[key] = values["text"] + elements = split_key_value_by_colon(key, values["text"]) + outputs[elements[0]] = elements[1] + + # Only key fields + for key in raw_outputs["key"]: + # outputs[key["text"]] = None + elements = 
split_key_value_by_colon(key["text"], None) + outputs[elements[0]] = elements[1] + + # Triplet data + for triplet in raw_outputs["triplet"]: + for key, list_value in triplet.items(): + outputs[key] = [value["text"] for value in list_value] + + # Table data + table = [] + header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row} + if header_list: + header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0]))) + print("Header_list:", header_list.keys()) + + for row in raw_outputs["table"]: + item = {header: None for header in list(header_list.keys())} + for cell in row: + item[cell["header"]] = cell["text"] + table.append(item) + outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}] + else: + outputs["tables"] = [] + outputs = normalize_kvu_output(outputs) + # write_to_json(file_path, outputs) + return outputs + +def export_kvu_for_manulife( + file_path, + lwords, + lbboxes, + class_words, + lrelations, + labels=["others", "title", "key", "value", "header"], +) -> dict: + raw_outputs = export_kvu_outputs( + file_path, lwords, lbboxes, class_words, lrelations, labels + ) + outputs = {} + # Title + title_list = [] + for title in raw_outputs["title"]: + is_match, title_name, score, proceessed_text = manulife_standardizer(title["text"], threshold=0.6, type_dict="title") + title_list.append({ + 'documment_type': title_name if is_match else None, + 'content': title['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': title['id'] + }) + + if len(title_list) > 0: + selected_element = max(title_list, key=lambda x: x['lcs_score']) + outputs["title"] = selected_element['content'].upper() + outputs["class_doc"] = selected_element['documment_type'] + + outputs["Loại chứng từ"] = selected_element['documment_type'] + outputs["Tên chứng từ"] = selected_element['content'] + else: + outputs["title"] = None + outputs["class_doc"] = None + outputs["Loại chứng từ"] = None + outputs["Tên chứng từ"] = None + + # Pairs of key-value + for pair in raw_outputs["single"]: + for key, values in pair.items(): + # outputs[key] = values["text"] + elements = split_key_value_by_colon(key, values["text"]) + outputs[elements[0]] = elements[1] + + # Only key fields + for key in raw_outputs["key"]: + # outputs[key["text"]] = None + elements = split_key_value_by_colon(key["text"], None) + outputs[elements[0]] = elements[1] + + # Triplet data + for triplet in raw_outputs["triplet"]: + for key, list_value in triplet.items(): + outputs[key] = [value["text"] for value in list_value] + + # Table data + table = [] + header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row} + if header_list: + header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0]))) + # print("Header_list:", header_list.keys()) + + for row in raw_outputs["table"]: + item = {header: None for header in list(header_list.keys())} + for cell in row: + item[cell["header"]] = cell["text"] + table.append(item) + outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}] + else: + outputs["tables"] = [] + outputs = normalize_kvu_output_for_manulife(outputs) + # write_to_json(file_path, outputs) + return outputs + + +# For FI-VAT project + +def get_vat_table_information(outputs): + table = [] + for single_item in outputs['table']: + headers = [item['header'] for sublist in outputs['table'] for item in sublist if 'header' in item] + item = {k: [] for k in headers} + print(item) + for cell in 
single_item: + # header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True) + # if header_name in list(item.keys()): + # item[header_name] = value['text'] + item[cell['header']].append({ + 'content': cell['text'], + 'processed_key_name': cell['header'], + 'lcs_score': random.uniform(0.75, 1.0), + 'token_id': cell['id'] + }) + + # for header_name, value in item.items(): + # if len(value) == 0: + # if header_name in ("Số lượng", "Doanh số mua chưa có thuế"): + # item[header_name] = '0' + # else: + # item[header_name] = None + # continue + # item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + # item = post_process_for_item(item) + + # if item["Mặt hàng"] == None: + # continue + table.append(item) + print(table) + return table + +def get_vat_information(outputs): + # VAT Information + single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())} + for pair in outputs['single']: + for raw_key_name, value in pair.items(): + key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'], + }) + + for triplet in outputs['triplet']: + for key, value_list in triplet.items(): + if len(value_list) == 1: + key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value_list[0]['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value_list[0]['id'] + }) + + for pair in value_list: + key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': pair['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'] + }) + + for table_row in outputs['table']: + for pair in table_row: + key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': pair['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'] + }) + + return single_pairs + + +def post_process_vat_information(single_pairs): + vat_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if key_name in ("Ngày, tháng, năm lập hóa đơn"): + if len(list_potential_value) == 1: + vat_outputs[key_name] = list_potential_value[0]['content'] + else: + date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'} + for value in list_potential_value: + date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content']) + vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}" + else: + if len(list_potential_value) == 0: continue + if key_name in ("Mã số thuế người bán"): + selected_value = 
min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code + # tax_code_raw = selected_value['content'].replace(' ', '') + tax_code_raw = selected_value['content'] + if len(tax_code_raw.replace(' ', '')) not in (10, 13): # to remove the first number dupicated + tax_code_raw = tax_code_raw.split(' ') + tax_code_raw = sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0] + vat_outputs[key_name] = tax_code_raw.replace(' ', '') + + else: + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + vat_outputs[key_name] = selected_value['content'] + return vat_outputs + + +def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + vat_outputs = {} + outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels) + + # List of items in table + table = get_vat_table_information(outputs) + # table = outputs["table"] + + for pair in outputs['single']: + for raw_key_name, value in pair.items(): + vat_outputs[raw_key_name] = value['text'] + + # VAT Information + # single_pairs = get_vat_information(outputs) + # vat_outputs = post_process_vat_information(single_pairs) + + # Combine VAT information and table + vat_outputs['table'] = table + + write_to_json(file_path, vat_outputs) + print(vat_outputs) + return vat_outputs + + +# For SBT project + +def get_ap_table_information(outputs): + table = [] + for single_item in outputs['table']: + item = {k: [] for k in list(ap_dictionary(header=True).keys())} + for cell in single_item: + header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True) + # print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}") + if header_name in list(item.keys()): + item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + table.append(item) + return table + +def get_ap_triplet_information(outputs): + triplet_pairs = [] + for single_item in outputs['triplet']: + item = {k: [] for k in list(ap_dictionary(header=True).keys())} + is_item_valid = 0 + for key_name, list_value in single_item.items(): + for value in list_value: + if value['header'] == "non-header": + continue + header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True) + if header_name in list(item.keys()): + is_item_valid = 1 + item[header_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + if is_item_valid == 1: + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + item['productname'] = key_name + # triplet_pairs.append({key_name: new_item}) + triplet_pairs.append(item) + return triplet_pairs + + +def get_ap_information(outputs): + single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())} + for pair in outputs['single']: + for raw_key_name, value in pair.items(): + key_name, score, proceessed_text = ap_standardizer(raw_key_name, threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - 
{value['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + ## Get single_pair if it in a table (Product Information) + is_product_info = False + for table_row in outputs['table']: + pair = {"key": None, 'value': None} + for cell in table_row: + _, _, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=False) + if any(txt in proceessed_text for txt in ['product', 'information', 'productinformation']): + is_product_info = True + if cell['class'] in pair: + pair[cell['class']] = cell + + if all(v is not None for k, v in pair.items()) and is_product_info == True: + key_name, score, proceessed_text = ap_standardizer(pair['key']['text'], threshold=0.8, header=False) + # print(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': pair['value']['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['value']['id'] + }) + ## end_block + + ap_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if len(list_potential_value) == 0: continue + if key_name == "imei_number": + # print('list_potential_value', list_potential_value) + # ap_outputs[key_name] = [v['content'] for v in list_potential_value if v['content'].replace(' ', '').isdigit() and len(v['content'].replace(' ', '')) > 5] + ap_outputs[key_name] = [] + for v in list_potential_value: + imei = v['content'].replace(' ', '') + if imei.isdigit() and len(imei) > 5: # imei is number and have more 5 digits + ap_outputs[key_name].append(imei) + else: + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + ap_outputs[key_name] = selected_value['content'] + + return ap_outputs + +def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels) + # List of items in table + table = get_ap_table_information(outputs) + triplet_pairs = get_ap_triplet_information(outputs) + table = table + triplet_pairs + + ap_outputs = get_ap_information(outputs) + + ap_outputs['table'] = table + # ap_outputs['triplet'] = triplet_pairs + + write_to_json(file_path, ap_outputs) + + return ap_outputs \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word2line.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word2line.py new file mode 100755 index 0000000..d8380ef --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word2line.py @@ -0,0 +1,226 @@ +class Word(): + def __init__(self, text="",image=None, conf_detect=0.0, conf_cls=0.0, bndbox = [-1,-1,-1,-1], kie_label =""): + self.type = "word" + self.text =text + self.image = image + self.conf_detect = conf_detect + self.conf_cls = conf_cls + self.boundingbox = bndbox # [left, top,right,bot] coordinate of top-left and bottom-right point + self.word_id = 0 # id of word + self.word_group_id = 0 # id of word_group which instance belongs to + self.line_id = 0 #id of line which instance belongs to + self.paragraph_id = 0 #id of line which instance belongs to + self.kie_label = kie_label + def invalid_size(self): + return (self.boundingbox[2] - self.boundingbox[0]) * (self.boundingbox[3] - self.boundingbox[1]) > 0 
+ def is_special_word(self): + left, top, right, bottom = self.boundingbox + width, height = right - left, bottom - top + text = self.text + + if text is None: + return True + + # if len(text) > 7: + # return True + if len(text) >= 7: + no_digits = sum(c.isdigit() for c in text) + return no_digits / len(text) >= 0.3 + + return False + +class Word_group(): + def __init__(self): + self.type = "word_group" + self.list_words = [] # dict of word instances + self.word_group_id = 0 # word group id + self.line_id = 0 #id of line which instance belongs to + self.paragraph_id = 0# id of paragraph which instance belongs to + self.text ="" + self.boundingbox = [-1,-1,-1,-1] + self.kie_label ="" + def add_word(self, word:Word): #add a word instance to the word_group + if word.text != "✪": + for w in self.list_words: + if word.word_id == w.word_id: + print("Word id collision") + return False + word.word_group_id = self.word_group_id # + word.line_id = self.line_id + word.paragraph_id = self.paragraph_id + self.list_words.append(word) + self.text += ' '+ word.text + if self.boundingbox == [-1,-1,-1,-1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3])] + return True + else: + return False + + def update_word_group_id(self, new_word_group_id): + self.word_group_id = new_word_group_id + for i in range(len(self.list_words)): + self.list_words[i].word_group_id = new_word_group_id + + def update_kie_label(self): + list_kie_label = [word.kie_label for word in self.list_words] + dict_kie = dict() + for label in list_kie_label: + if label not in dict_kie: + dict_kie[label]=1 + else: + dict_kie[label]+=1 + total = len(list(dict_kie.values())) + max_value = max(list(dict_kie.values())) + list_keys = list(dict_kie.keys()) + list_values = list(dict_kie.values()) + self.kie_label = list_keys[list_values.index(max_value)] + +class Line(): + def __init__(self): + self.type = "line" + self.list_word_groups = [] # list of Word_group instances in the line + self.line_id = 0 #id of line in the paragraph + self.paragraph_id = 0 # id of paragraph which instance belongs to + self.text = "" + self.boundingbox = [-1,-1,-1,-1] + def add_group(self, word_group:Word_group): # add a word_group instance + if word_group.list_words is not None: + for wg in self.list_word_groups: + if word_group.word_group_id == wg.word_group_id: + print("Word_group id collision") + return False + + self.list_word_groups.append(word_group) + self.text += word_group.text + word_group.paragraph_id = self.paragraph_id + word_group.line_id = self.line_id + + for i in range(len(word_group.list_words)): + word_group.list_words[i].paragraph_id = self.paragraph_id #set paragraph_id for word + word_group.list_words[i].line_id = self.line_id #set line_id for word + return True + return False + def update_line_id(self, new_line_id): + self.line_id = new_line_id + for i in range(len(self.list_word_groups)): + self.list_word_groups[i].line_id = new_line_id + for j in range(len(self.list_word_groups[i].list_words)): + self.list_word_groups[i].list_words[j].line_id = new_line_id + + + def merge_word(self, word): # word can be a Word instance or a Word_group instance + if word.text != "✪": + if self.boundingbox == [-1,-1,-1,-1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]), + 
min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3])] + self.list_word_groups.append(word) + self.text += ' ' + word.text + return True + return False + + + def in_same_line(self, input_line, thresh=0.7): + # calculate iou in vertical direction + left1, top1, right1, bottom1 = self.boundingbox + left2, top2, right2, bottom2 = input_line.boundingbox + + sorted_vals = sorted([top1, bottom1, top2, bottom2]) + intersection = sorted_vals[2] - sorted_vals[1] + union = sorted_vals[3]-sorted_vals[0] + min_height = min(bottom1-top1, bottom2-top2) + if min_height==0: + return False + ratio = intersection / min_height + height1, height2 = top1-bottom1, top2-bottom2 + ratio_height = float(max(height1, height2))/float(min(height1, height2)) + # height_diff = (float(top1-bottom1))/(float(top2-bottom2)) + + + if (top1 in range(top2, bottom2) or top2 in range(top1, bottom1)) and ratio >= thresh and (ratio_height<2): + return True + return False + +def check_iomin(word:Word, word_group:Word_group): + min_height = min(word.boundingbox[3]-word.boundingbox[1],word_group.boundingbox[3]-word_group.boundingbox[1]) + intersect = min(word.boundingbox[3],word_group.boundingbox[3]) - max(word.boundingbox[1],word_group.boundingbox[1]) + if intersect/min_height > 0.7: + return True + return False + +def words_to_lines(words, check_special_lines=True): #words is list of Word instance + #sort word by top + words.sort(key = lambda x: (x.boundingbox[1], x.boundingbox[0])) + number_of_word = len(words) + #sort list words to list lines, which have not contained word_group yet + lines = [] + for i, word in enumerate(words): + if word.invalid_size()==0: + continue + new_line = True + for i in range(len(lines)): + if lines[i].in_same_line(word): #check if word is in the same line with lines[i] + lines[i].merge_word(word) + new_line = False + + if new_line ==True: + new_line = Line() + new_line.merge_word(word) + lines.append(new_line) + + # print(len(lines)) + #sort line from top to bottom according top coordinate + lines.sort(key = lambda x: x.boundingbox[1]) + + #construct word_groups in each line + line_id = 0 + word_group_id =0 + word_id = 0 + for i in range(len(lines)): + if len(lines[i].list_word_groups)==0: + continue + #left, top ,right, bottom + line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left + # print("line_width",line_width) + lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right + + #update text for lines after sorting + lines[i].text ="" + for word in lines[i].list_word_groups: + lines[i].text += " "+word.text + + list_word_groups=[] + inital_word_group = Word_group() + inital_word_group.word_group_id= word_group_id + word_group_id +=1 + lines[i].list_word_groups[0].word_id=word_id + inital_word_group.add_word(lines[i].list_word_groups[0]) + word_id+=1 + list_word_groups.append(inital_word_group) + for word in lines[i].list_word_groups[1:]: #iterate through each word object in list_word_groups (has not been construted to word_group yet) + check_word_group= True + #set id for each word in each line + word.word_id = word_id + word_id+=1 + if (not list_word_groups[-1].text.endswith(':')) and ((word.boundingbox[0]-list_word_groups[-1].boundingbox[2])/line_width <0.05) and check_iomin(word, list_word_groups[-1]): + list_word_groups[-1].add_word(word) + check_word_group=False + if check_word_group ==True: + new_word_group = Word_group() + 
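+                # The word could not be merged into the previous group (that group ends with ':',
+                # the horizontal gap exceeds 5% of the line width, or the vertical overlap is too
+                # small), so it starts a new word group of its own.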
new_word_group.word_group_id= word_group_id + word_group_id +=1 + new_word_group.add_word(word) + list_word_groups.append(new_word_group) + lines[i].list_word_groups = list_word_groups + # set id for lines from top to bottom + lines[i].update_line_id(line_id) + line_id +=1 + return lines, number_of_word \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word_preprocess.py b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word_preprocess.py new file mode 100755 index 0000000..20e6c4f --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/AnyKey_Value/word_preprocess.py @@ -0,0 +1,388 @@ +import nltk +import re +import string +import copy +from utils.kvu_dictionary import vat_dictionary, ap_dictionary, manulife_dictionary, DKVU2XML +from word2line import Word, words_to_lines +nltk.download('words') +words = set(nltk.corpus.words.words()) + +s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ' +s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy' + +# def clean_text(text): +# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text) + + +def get_string(lwords: list): + unique_list = [] + for item in lwords: + if item.isdigit() and len(item) == 1: + unique_list.append(item) + elif item not in unique_list: + unique_list.append(item) + return ' '.join(unique_list) + +def remove_english_words(text): + _word = [w.lower() for w in nltk.wordpunct_tokenize(text) if w.lower() not in words] + return ' '.join(_word) + +def remove_punctuation(text): + return text.translate(str.maketrans(" ", " ", string.punctuation)) + +def remove_accents(input_str, s0, s1): + s = '' + # print input_str.encode('utf-8') + for c in input_str: + if c in s1: + s += s0[s1.index(c)] + else: + s += c + return s + +def remove_spaces(text): + return text.replace(' ', '') + +def preprocessing(text: str): + # text = remove_english_words(text) if table else text + text = remove_punctuation(text) + text = remove_accents(text, s0, s1) + text = remove_spaces(text) + return text.lower() + + +def vat_standardize_outputs(vat_outputs: dict) -> dict: + outputs = {} + for key, value in vat_outputs.items(): + if key != "table": + outputs[DKVU2XML[key]] = value + else: + list_items = [] + for item in value: + list_items.append({ + DKVU2XML[item_key]: item_value for item_key, item_value in item.items() + }) + outputs['table'] = list_items + return outputs + + + +def vat_standardizer(text: str, threshold: float, header: bool): + dictionary = vat_dictionary(header) + processed_text = preprocessing(text) + + for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'), ('thangmonth', 'thang', 'month'), ('namyear', 'nam', 'year')]: + if any([processed_text in txt for txt in candidates]): + processed_text = candidates[-1] + return "Ngày, tháng, năm lập hóa đơn", 5, processed_text + + _dictionary = copy.deepcopy(dictionary) + if not header: + exact_dictionary = { + 'Số hóa đơn': ['sono', 'so'], + 'Mã số thuế người bán': ['mst'], + 'Tên người bán': ['kyboi'], + 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate'] + } + for k, v in exact_dictionary.items(): + _dictionary[k] = dictionary[k] + exact_dictionary[k] + + for k, v in dictionary.items(): + # if k in ("Ngày, tháng, năm lập hóa đơn"): + # continue + # Prioritize match completely + if k in ('Tên người bán') and processed_text == "kyboi": + return k, 8, processed_text + + if any([processed_text == key 
for key in _dictionary[k]]): + return k, 10, processed_text + + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + if k in ("Ngày, tháng, năm lập hóa đơn"): + continue + + scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score > threshold else text, score, processed_text + +def ap_standardizer(text: str, threshold: float, header: bool): + dictionary = ap_dictionary(header) + processed_text = preprocessing(text) + + # Prioritize match completely + _dictionary = copy.deepcopy(dictionary) + if not header: + _dictionary['serial_number'] = dictionary['serial_number'] + ['sn'] + _dictionary['imei_number'] = dictionary['imei_number'] + ['imel', 'imed', 'ime'] # text recog error + else: + _dictionary['modelnumber'] = dictionary['modelnumber'] + ['sku', 'sn', 'imei'] + _dictionary['qty'] = dictionary['qty'] + ['qty'] + for k, v in dictionary.items(): + if any([processed_text == key for key in _dictionary[k]]): + return k, 10, processed_text + + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score >= threshold else text, score, processed_text + +def manulife_standardizer(text: str, threshold: float, type_dict: str): + dictionary = manulife_dictionary(type=type_dict) + processed_text = preprocessing(text) + + for key, candidates in dictionary.items(): + + if any([txt == processed_text for txt in candidates]): + return True, key, 5 * (1 + len(processed_text)), processed_text + + if any([txt in processed_text for txt in candidates]): + return True, key, 5, processed_text + + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + if len(v) == 0: + continue + scores[k] = max( + [ + longestCommonSubsequence(processed_text, key) / len(key) + for key in dictionary[k] + ] + ) + key, score = max(scores.items(), key=lambda x: x[1]) + return score > threshold, key if score > threshold else text, score, processed_text + + +def convert_format_number(s: str) -> float: + s = s.replace(' ', '').replace('O', '0').replace('o', '0') + if s.endswith(",00") or s.endswith(".00"): + s = s[:-3] + if all([delimiter in s for delimiter in [',', '.']]): + s = s.replace('.', '').split(',') + remain_value = s[1].split('0')[0] + return int(s[0]) + int(remain_value) * 1 / (10**len(remain_value)) + else: + s = s.replace(',', '').replace('.', '') + return int(s) + + +def post_process_for_item(item: dict) -> dict: + check_keys = ['Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế'] + mis_key = [] + for key in check_keys: + if item[key] in (None, '0'): + mis_key.append(key) + if len(mis_key) == 1: + try: + if mis_key[0] == check_keys[0] and convert_format_number(item[check_keys[1]]) != 0: + item[mis_key[0]] = round(convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[1]])).__str__() + elif mis_key[0] == check_keys[1] and convert_format_number(item[check_keys[0]]) != 0: + item[mis_key[0]] = (convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[0]])).__str__() + elif mis_key[0] == check_keys[2]: + item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__() + except Exception as e: + print("Cannot post process this item with error:", e) + return item + + +def 
longestCommonSubsequence(text1: str, text2: str) -> int: + # https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis + dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)] + for i, c in enumerate(text1): + for j, d in enumerate(text2): + dp[i + 1][j + 1] = 1 + \ + dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j]) + return dp[-1][-1] + + +def longest_common_subsequence_with_idx(X, Y): + """ + This implementation uses dynamic programming to calculate the length of the LCS, and uses a path array to keep track of the characters in the LCS. + The longest_common_subsequence function takes two strings as input, and returns a tuple with three values: + the length of the LCS, + the index of the first character of the LCS in the first string, + and the index of the last character of the LCS in the first string. + """ + m, n = len(X), len(Y) + L = [[0 for i in range(n + 1)] for j in range(m + 1)] + + # Following steps build L[m+1][n+1] in bottom up fashion. Note + # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1] + right_idx = 0 + max_lcs = 0 + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0: + L[i][j] = 0 + elif X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + if L[i][j] > max_lcs: + max_lcs = L[i][j] + right_idx = i + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + + # Create a string variable to store the lcs string + lcs = L[i][j] + # Start from the right-most-bottom-most corner and + # one by one store characters in lcs[] + i = m + j = n + # right_idx = 0 + while i > 0 and j > 0: + # If current character in X[] and Y are same, then + # current character is part of LCS + if X[i - 1] == Y[j - 1]: + + i -= 1 + j -= 1 + # If not same, then find the larger of two and + # go in the direction of larger value + elif L[i - 1][j] > L[i][j - 1]: + # right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs + i -= 1 + else: + j -= 1 + return lcs, i, max(i + lcs, right_idx) + + +def get_string_by_deduplicate_bbox(lwords: list, lbboxes: list): + unique_list = [] + prev_bbox = [-1, -1, -1, -1] + for word, bbox in zip(lwords, lbboxes): + if bbox != prev_bbox: + unique_list.append(word) + prev_bbox = bbox + return ' '.join(unique_list) + +def get_string_with_word2line(lwords: list, lbboxes: list): + list_words = [] + unique_list = [] + list_sorted_words = [] + + prev_bbox = [-1, -1, -1, -1] + for word, bbox in zip(lwords, lbboxes): + if bbox != prev_bbox: + prev_bbox = bbox + list_words.append(Word(image=None, text=word, conf_cls=-1, bndbox=bbox, conf_detect=-1)) + unique_list.append(word) + llines = words_to_lines(list_words)[0] + + for line in llines: + for _word_group in line.list_word_groups: + for _word in _word_group.list_words: + list_sorted_words.append(_word.text) + + string_from_model = ' '.join(unique_list) + string_after_word2line = ' '.join(list_sorted_words) + + if string_from_model != string_after_word2line: + print("[Warning] Word group from model is different with word2line module") + print("Model: ", ' '.join(unique_list)) + print("Word2line: ", ' '.join(list_sorted_words)) + + return string_after_word2line + +def remove_bullet_points_and_punctuation(text): + # Remove bullet points (e.g., • or -) + text = re.sub(r'^\s*[\•\-\*]\s*', '', text, flags=re.MULTILINE) + text = re.sub("^\d+\s*", "", text) + text = text.strip() + # # Remove end-of-sentence punctuation (e.g., ., !, ?) 
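+    # Only a single leading and a single trailing punctuation character are trimmed below, so
+    # separators inside the value (decimal points, dates, times) are left untouched.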
+ # text = re.sub(r'[.!?]', '', text) + if len(text) > 0 and text[0] in (',', '.', ':', ';', '?', '!'): + text = text[1:] + if len(text) > 0 and text[-1] in (',', '.', ':', ';', '?', '!'): + text = text[:-1] + return text.strip() + +def split_key_value_by_colon(key: str, value: str) -> list: + text_string = key + " " + value if value is not None else key + elements = text_string.split(':') + if len(elements) > 1: + return elements[0], text_string[len(elements[0]):] + return key, value + + +# def normalize_kvu_output(raw_outputs: dict) -> dict: +# outputs = {} +# for key, values in raw_outputs.items(): +# if key == "table": +# table = [] +# for row in values: +# item = {} +# for k, v in row.items(): +# k = remove_bullet_points_and_punctuation(k) +# if v is not None and len(v) > 0: +# v = remove_bullet_points_and_punctuation(v) +# item[k] = v +# table.append(item) +# outputs[key] = table +# else: +# key = remove_bullet_points_and_punctuation(key) +# if isinstance(values, list): +# values = [remove_bullet_points_and_punctuation(v) for v in values] +# elif values is not None and len(values) > 0: +# values = remove_bullet_points_and_punctuation(values) +# outputs[key] = values +# return outputs +def normalize_kvu_output(raw_outputs: dict) -> dict: + outputs = {} + for key, values in raw_outputs.items(): + if key == "tables" and len(values) > 0: + table_list = [] + for table in values: + headers, data = [], [] + headers = [remove_bullet_points_and_punctuation(header).upper() for header in table['headers']] + for row in table['data']: + item = [] + for k, v in row.items(): + if v is not None and len(v) > 0: + item.append(remove_bullet_points_and_punctuation(v)) + else: + item.append(v) + data.append(item) + table_list.append({"headers": headers, "data": data}) + outputs[key] = table_list + else: + key = remove_bullet_points_and_punctuation(key) + if isinstance(values, list): + values = [remove_bullet_points_and_punctuation(v) for v in values] + elif values is not None and len(values) > 0: + values = remove_bullet_points_and_punctuation(values) + outputs[key] = values + return outputs + +def normalize_kvu_output_for_manulife(raw_outputs: dict) -> dict: + outputs = {} + for key, values in raw_outputs.items(): + if key == "tables" and len(values) > 0: + table_list = [] + for table in values: + headers, data = [], [] + headers = [ + remove_bullet_points_and_punctuation(header).upper() + for header in table["headers"] + ] + for row in table["data"]: + item = [] + for k, v in row.items(): + if v is not None and len(v) > 0: + item.append(remove_bullet_points_and_punctuation(v)) + else: + item.append(v) + data.append(item) + table_list.append({"headers": headers, "data": data}) + outputs[key] = table_list + else: + if key not in ("title", "tables", "class_doc"): + key = remove_bullet_points_and_punctuation(key).capitalize() + if isinstance(values, list): + values = [remove_bullet_points_and_punctuation(v) for v in values] + elif values is not None and len(values) > 0: + values = remove_bullet_points_and_punctuation(values) + outputs[key] = values + return outputs \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/prediction.py b/cope2n-ai-fi/api/Kie_Invoice_AP/prediction.py new file mode 100755 index 0000000..f6d4ad1 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/prediction.py @@ -0,0 +1,58 @@ +from sdsvkie import Predictor +import cv2 +import numpy as np +import urllib + +model = Predictor( + cfg = "/ai-core/models/Kie_invoice_ap/config.yaml", + device = "cuda:0", + weights = 
"/ai-core/models/Kie_invoice_ap/ep21" + ) + +def predict(image_url): + """ + module predict function + + Args: + image_url (str): image url + + Returns: + example output: + "data": { + "document_type": "invoice", + "fields": [ + { + "label": "Invoice Number", + "value": "INV-12345", + "box": [0, 0, 0, 0], + "confidence": 0.98 + }, + ... + ] + } + dict: output of model + """ + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + out = model(img) + output = out["end2end_results"] + output_dict = { + "document_type": "invoice", + "fields": [] + } + for key in output.keys(): + field = { + "label": key if key != "id" else "Receipt Number", + "value": output[key]['value'] if output[key]['value'] else "", + "box": output[key]['box'], + "confidence": output[key]['conf'] + } + output_dict['fields'].append(field) + print(output_dict) + return output_dict + +if __name__ == "__main__": + image_url = "/mnt/ssd1T/hoanglv/Projects/KIE/sdsvkie/demos/2022_07_25 farewell lunch.jpg" + output = predict(image_url) + print(output) \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/prediction_fi.py b/cope2n-ai-fi/api/Kie_Invoice_AP/prediction_fi.py new file mode 100755 index 0000000..57981f4 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/prediction_fi.py @@ -0,0 +1,77 @@ +import os, sys +cur_dir = os.path.dirname(__file__) +KIE_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvkie") +TD_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvtd") +TR_PATH = os.path.join(os.path.dirname(cur_dir), "sdsvtr") +sys.path.append(KIE_PATH) +sys.path.append(TD_PATH) +sys.path.append(TR_PATH) + +from sdsvkie import Predictor +from .AnyKey_Value.anyKeyValue import load_engine, Predictor_KVU +import cv2 +import numpy as np +import urllib + +model = Predictor( + cfg = "/models/Kie_invoice_ap/06062023/config.yaml", # TODO: Better be scalable + device = "cuda:0", + weights = "/models/Kie_invoice_ap/06062023/best" # TODO: Better be scalable + ) + +class_names = ['others', 'title', 'key', 'value', 'header'] +save_dir = os.path.join(cur_dir, "AnyKey_Value/visualize/test") + +predictor, processor = load_engine(exp_dir="/models/Kie_invoice_fi/key_value_understanding-20230627-164536", + class_names=class_names, + mode=3) + +def predict_fi(page_numb, image_url): + """ + module predict function + + Args: + image_url (str): image url + + Returns: + example output: + "data": { + "document_type": "invoice", + "fields": [ + { + "label": "Invoice Number", + "value": "INV-12345", + "box": [0, 0, 0, 0], + "confidence": 0.98 + }, + ... 
+ ] + } + dict: output of model + """ + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + # img = cv2.imread(image_url) + + # Phan cua LeHoang + out = model(img) + output = out["end2end_results"] + output_kie = { + field_name: field_item['value'] for field_name, field_item in output.items() + } + # print("Hoangggggggggggggggggggggggggggggggggggggggggggggg") + # print(output_kie) + + + #Phan cua Tuan + kvu_result, _ = Predictor_KVU(image_url, save_dir, predictor, processor) + # print("TuanNnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn") + # print(kvu_result) + # if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None: + return kvu_result, output_kie + +if __name__ == "__main__": + image_url = "/mnt/hdd2T/dxtan/TannedCung/OCR/workspace/Kie_Invoice_AP/tmp_image/{image_url}.jpg" + output = predict_fi(0, image_url) + print(output) \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/prediction_sap.py b/cope2n-ai-fi/api/Kie_Invoice_AP/prediction_sap.py new file mode 100755 index 0000000..6fffaaa --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/prediction_sap.py @@ -0,0 +1,146 @@ +from sdsvkie import Predictor +import sys, os +cur_dir = os.path.dirname(__file__) +# sys.path.append("cope2n-ai-fi/Kie_Invoice_AP") # Better be relative +from .AnyKey_Value.anyKeyValue import load_engine, Predictor_KVU +import cv2 +import numpy as np +import urllib +import random + +# model = Predictor( +# cfg = "/cope2n-ai-fi/models/Kie_invoice_ap/config.yaml", +# device = "cuda:0", +# weights = "/cope2n-ai-fi/models/Kie_invoice_ap/best" +# ) + +class_names = ['others', 'title', 'key', 'value', 'header'] +save_dir = os.path.join(cur_dir, "AnyKey_Value/visualize/test") + +predictor, processor = load_engine(exp_dir=os.path.join(cur_dir, "AnyKey_Value/experiments/key_value_understanding-20231003-171748"), + class_names=class_names, + mode=3) + +def predict(page_numb, image_url): + """ + module predict function + + Args: + image_url (str): image url + + Returns: + example output: + "data": { + "document_type": "invoice", + "fields": [ + { + "label": "Invoice Number", + "value": "INV-12345", + "box": [0, 0, 0, 0], + "confidence": 0.98 + }, + ... 
+ ] + } + dict: output of model + """ + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + # img = cv2.imread(image_url) + + # Phan cua LeHoang + # out = model(img) + # output = out["end2end_results"] + + + #Phan cua Tuan + kvu_result = Predictor_KVU(image_url, save_dir, predictor, processor) + output_dict = { + "document_type": "invoice", + "fields": [] + } + for key in kvu_result.keys(): + field = { + "label": key, + "value": kvu_result[key], + "box": [0, 0, 0, 0], + "confidence": random.uniform(0.9, 1.0), + "page": page_numb + } + output_dict['fields'].append(field) + print(output_dict) + return output_dict + + # if kvu_result['imei_number'] == None and kvu_result['serial_number'] == None: + # output_dict = { + # "document_type": "invoice", + # "fields": [] + # } + # for key in output.keys(): + # field = { + # "label": key if key != "id" else "Receipt Number", + # "value": output[key]['value'] if output[key]['value'] else "", + # "box": output[key]['box'], + # "confidence": output[key]['conf'], + # "page": page_numb + # } + # output_dict['fields'].append(field) + # table = kvu_result['table'] + # field_table = { + # "label": "table", + # "value": table, + # "box": [0, 0, 0, 0], + # "confidence": 0.98, + # "page": page_numb + # } + # output_dict['fields'].append(field_table) + # return output_dict + + # else: + # output_dict = { + # "document_type": "KSU", + # "fields": [] + # } + # # for key in output.keys(): + # # field = { + # # "label": key if key != "id" else "Receipt Number", + # # "value": output[key]['value'] if output[key]['value'] else "", + # # "box": output[key]['box'], + # # "confidence": output[key]['conf'], + # # "page": page_numb + # # } + # # output_dict['fields'].append(field) + + # # Serial Number + # serial_number = kvu_result['serial_number'] + # field_serial = { + # "label" : "serial_number", + # "value": serial_number, + # "box": [0, 0, 0, 0], + # "confidence": 0.98, + # "page": page_numb + # } + # output_dict['fields'].append(field_serial) + + # # IMEI Number + # imei_number = kvu_result['imei_number'] + # if imei_number == None: + # return output_dict + # if imei_number != None: + # for i in range(len(imei_number)): + # field_imei = { + # "label": "imei_number_{}".format(i+1), + # "value": imei_number[i], + # "box": [0, 0, 0, 0], + # "confidence": 0.98, + # "page": page_numb + # } + # output_dict['fields'].append(field_imei) + + # return output_dict + +if __name__ == "__main__": + image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg" + output = predict(0, image_url) + print(output) \ No newline at end of file diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/tmp.txt b/cope2n-ai-fi/api/Kie_Invoice_AP/tmp.txt new file mode 100755 index 0000000..4a32426 --- /dev/null +++ b/cope2n-ai-fi/api/Kie_Invoice_AP/tmp.txt @@ -0,0 +1,106 @@ +1113 773 1220 825 BEST +1243 759 1378 808 DENKI +1410 752 1487 799 (S) +1430 707 1515 748 TAX +1511 745 1598 790 PTE +1542 700 1725 740 TNVOICE +1618 742 1706 783 LTD +1783 725 1920 773 FUNAN +1943 723 2054 767 MALL +1434 797 1576 843 WORTH +1599 785 1760 831 BRIDGE +1784 778 1846 822 RD +1277 846 1632 897 #02-16/#03-1 +1655 832 1795 877 FUNAN +1817 822 1931 869 MALL +1272 897 1518 956 S(179105) +1548 890 1655 943 TEL: +1686 877 1911 928 69046183 +1247 1011 1334 1068 GST +1358 1006 1447 1059 REG +1360 1063 1449 1115 RCB +1473 1003 1575 1055 NO.: +1474 1059 1555 1110 NO. 
+1595 1042 1868 1096 198202199E +1607 985 1944 1040 M2-0053813-7 +1056 1134 1254 1194 Opening +1276 1127 1391 1181 Hrs: +1425 1112 1647 1170 10:00:00 +1672 1102 1735 1161 AN +1755 1101 1819 1157 to +1846 1090 2067 1147 10:00:00 +2090 1080 2156 1141 PH +1061 1308 1228 1366 Staff: +1258 1300 1378 1357 3296 +1710 1283 1880 1337 Trans: +1936 1266 2192 1322 262152554 +1060 1372 1201 1429 Date: +1260 1358 1494 1419 22-03-23 +1540 1344 1664 1409 9:05 +1712 1339 1856 1407 Slip: +1917 1328 2196 1387 2000130286 +1124 1487 1439 1545 SALESPERSON +1465 1477 1601 1537 CODE. +1633 1471 1752 1530 6043 +1777 1462 2004 1519 HUHAHHAD +2032 1451 2177 1509 RAZIH +1070 1558 1187 1617 Item +1211 1554 1276 1615 No +1439 1542 1585 1601 Price +1750 1530 1841 1597 Qty +1951 1517 2120 1579 Amount +1076 1683 1276 1741 ANDROID +1304 1673 1477 1733 TABLET +1080 1746 1280 1804 2105976 +1509 1729 1705 1784 SAMSUNG +1734 1719 1931 1776 SH-P613 +1964 1709 2101 1768 128GB +1082 1809 1285 1869 SM-P613 +1316 1802 1454 1860 12838 +1429 1859 1600 1919 518.00 +1481 1794 1596 1855 WIFI +1622 1790 1656 1850 G +1797 1845 1824 1904 1 +1993 1832 2165 1892 518.00 +1088 1935 1347 1995 PROMOTION +1091 2000 1294 2062 2105664 +1520 1983 1717 2039 SAMSUNG +1743 1963 2106 2030 F-Sam-Redeen +1439 2111 1557 2173 0.00 +1806 2095 1832 2156 1 +2053 2081 2174 2144 0.00 +1106 2248 1250 2312 Total +1974 2206 2146 2266 518.00 +1107 2312 1204 2377 UOB +1448 2291 1567 2355 CARD +1978 2268 2147 2327 518.00 +1253 2424 1375 2497 GST% +1456 2411 1655 2475 Net.Amt +1818 2393 1912 2460 GST +2023 2387 2192 2445 Amount +1106 2494 1231 2560 GST8 +1486 2472 1661 2537 479.63 +1770 2458 1916 2523 38.37 +2027 2448 2203 2511 518.00 +1553 2601 1699 2666 THANK +1721 2592 1821 2661 YOU +1436 2678 1616 2749 please +1644 2682 1764 2732 come +1790 2660 1942 2729 again +1191 2862 1391 2931 Those +1426 2870 2018 2945 facebook.com +1565 2809 1690 2884 join +1709 2816 1777 2870 us +1799 2811 1868 2865 on +1838 2946 2024 3003 com .89 +1533 3006 2070 3088 ar.com/askbe +1300 3326 1659 3446 That's +1696 3308 1905 3424 not +1937 3289 2131 3408 all! 
+1450 3511 1633 3573 SCAN +1392 3589 1489 3645 QR +1509 3577 1698 3635 CODE +1321 3656 1370 3714 & +1517 3638 1768 3699 updates +1643 3882 1769 3932 Scan +1789 3868 1859 3926 Me diff --git a/cope2n-ai-fi/api/Kie_Invoice_AP/tmp_image/{image_url}.jpg b/cope2n-ai-fi/api/Kie_Invoice_AP/tmp_image/{image_url}.jpg new file mode 100755 index 0000000..2fa1bfb Binary files /dev/null and b/cope2n-ai-fi/api/Kie_Invoice_AP/tmp_image/{image_url}.jpg differ diff --git a/cope2n-ai-fi/api/OCRBase/prediction.py b/cope2n-ai-fi/api/OCRBase/prediction.py new file mode 100755 index 0000000..f986686 --- /dev/null +++ b/cope2n-ai-fi/api/OCRBase/prediction.py @@ -0,0 +1,155 @@ +from OCRBase.text_recognition import ocr_predict +import cv2 +from shapely.geometry import Polygon +import urllib +import numpy as np + +def check_percent_overlap_bbox(boxA, boxB): + """check percent box A in boxB + + Args: + boxA (_type_): _description_ + boxB (_type_): _description_ + + Returns: + Float: percent overlap bbox + """ + # determine the (x, y)-coordinates of the intersection rectangle + box_shape_1 = [ + [boxA[0], boxA[1]], + [boxA[2], boxA[1]], + [boxA[2], boxA[3]], + [boxA[0], boxA[3]], + ] + + # Give dimensions of shape 2 + box_shape_2 = [ + [boxB[0], boxB[1]], + [boxB[2], boxB[1]], + [boxB[2], boxB[3]], + [boxB[0], boxB[3]], + ] + # Draw polygon 1 from shape 1 + # dimensions + polygon_1 = Polygon(box_shape_1) + + # Draw polygon 2 from shape 2 + # dimensions + polygon_2 = Polygon(box_shape_2) + + # Calculate the intersection of + # bounding boxes + intersect = polygon_1.intersection(polygon_2).area / polygon_1.area + + return intersect + + +def check_box_in_box(boxA, boxB): + """check boxA in boxB + + Args: + boxA (_type_): _description_ + boxB (_type_): _description_ + + Returns: + Boolean: True if boxA in boxB + """ + if ( + boxA[0] >= boxB[0] + and boxA[1] >= boxB[1] + and boxA[2] <= boxB[2] + and boxA[3] <= boxB[3] + ): + return True + else: + return False + + +def word_to_line_image_origin(list_words, bbox): + """use for predict image with bbox selected + + Args: + list_words (_type_): _description_ + bbox (_type_): _description_ + + Returns: + _type_: _description_ + """ + texts, boundingboxes = [], [] + for line in list_words: + if line.text == "": + continue + else: + # convert to bbox image original + boundingbox = line.boundingbox + boundingbox = list(boundingbox) + boundingbox[0] = boundingbox[0] + bbox[0] + boundingbox[1] = boundingbox[1] + bbox[1] + boundingbox[2] = boundingbox[2] + bbox[0] + boundingbox[3] = boundingbox[3] + bbox[1] + texts.append(line.text) + boundingboxes.append(boundingbox) + return texts, boundingboxes + + +def word_to_line(list_words): + """use for predict full image + + Args: + list_words (_type_): _description_ + """ + texts, boundingboxes = [], [] + for line in list_words: + print(line.text) + if line.text == "": + continue + else: + boundingbox = line.boundingbox + boundingbox = list(boundingbox) + texts.append(line.text) + boundingboxes.append(boundingbox) + return texts, boundingboxes + + +def predict(page_numb, image_url): + """predict text from image + + Args: + image_path (String): path image to predict + list_id (List): List id of bbox selected + list_bbox (List): List bbox selected + + Returns: + Dict: Dict result of prediction + """ + + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + image = cv2.imdecode(arr, -1) + list_lines = ocr_predict(image) + texts, boundingboxes = word_to_line(list_lines) + result = {} + texts_replace 
= [] + for text in texts: + if "✪" in text: + text = text.replace("✪", " ") + texts_replace.append(text) + else: + texts_replace.append(text) + result["texts"] = texts_replace + result["boundingboxes"] = boundingboxes + + output_dict = { + "document_type": "ocr-base", + "fields": [] + } + field = { + "label": "Text", + "value": result["texts"], + "box": result["boundingboxes"], + "confidence": 0.98, + "page": page_numb + } + output_dict['fields'].append(field) + + return output_dict diff --git a/cope2n-ai-fi/api/OCRBase/text_detection.py b/cope2n-ai-fi/api/OCRBase/text_detection.py new file mode 100755 index 0000000..010f7a8 --- /dev/null +++ b/cope2n-ai-fi/api/OCRBase/text_detection.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from distutils.command.config import config + +from mmdet.apis import inference_detector, init_detector +from PIL import Image +from io import BytesIO + +config = "model/yolox_s_8x8_300e_cocotext_1280.py" +checkpoint = "model/best_bbox_mAP_epoch_294.pth" +device = "cpu" + + +def read_imagefile(file) -> Image.Image: + image = Image.open(BytesIO(file)) + return image + + +def detection_predict(image): + model = init_detector(config, checkpoint, device=device) + # test a single image + result = inference_detector(model, image) + return result diff --git a/cope2n-ai-fi/api/OCRBase/text_recognition.py b/cope2n-ai-fi/api/OCRBase/text_recognition.py new file mode 100755 index 0000000..d431792 --- /dev/null +++ b/cope2n-ai-fi/api/OCRBase/text_recognition.py @@ -0,0 +1,39 @@ +from common.utils.ocr_yolox import OcrEngineForYoloX_Invoice +from common.utils.word_formation import Word, words_to_lines + + +det_ckpt = "/models/sdsvtd/hub/wild_receipt_finetune_weights_c_lite.pth" +cls_ckpt = "satrn-lite-general-pretrain-20230106" + +engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt) + + +def ocr_predict(img): + """Predict text from image + + Args: + image_path (str): _description_ + + Returns: + list: list of words + """ + try: + lbboxes, lwords = engine.run_image(img) + lWords = [Word(text=word, bndbox=bbox) for word, bbox in zip(lwords, lbboxes)] + list_lines, _ = words_to_lines(lWords) + return list_lines + # return lbboxes, lwords + except AssertionError as e: + print(e) + list_lines = [] + return list_lines + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--image", type=str, required=True) + args = parser.parse_args() + + list_lines = ocr_predict(args.image) diff --git a/cope2n-ai-fi/api/manulife/predict_manulife.py b/cope2n-ai-fi/api/manulife/predict_manulife.py new file mode 100644 index 0000000..6c6800e --- /dev/null +++ b/cope2n-ai-fi/api/manulife/predict_manulife.py @@ -0,0 +1,98 @@ +import cv2 +import urllib +import random +import numpy as np +from pathlib import Path +import sys, os +cur_dir = str(Path(__file__).parents[2]) +sys.path.append(cur_dir) +from modules.sdsvkvu import load_engine, process_img +from modules.ocr_engine import OcrEngine +from configs.manulife import device, ocr_cfg, kvu_cfg + +def load_ocr_engine(opt) -> OcrEngine: + print("[INFO] Loading engine...") + engine = OcrEngine(**opt) + print("[INFO] Engine loaded") + return engine + + +print("OCR engine configfs: \n", ocr_cfg) +print("KVU configfs: \n", kvu_cfg) + +ocr_engine = load_ocr_engine(ocr_cfg) +kvu_cfg['ocr_engine'] = ocr_engine +option = kvu_cfg['option'] +kvu_cfg.pop("option") # pop option +manulife_engine = load_engine(kvu_cfg) + + +def manulife_predict(image_url, engine) -> None: + req = 
urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + + save_dir = "./tmp_results" + # image_path = os.path.join(save_dir, f"{image_url}.jpg") + image_path = os.path.join(save_dir, "abc.jpg") + cv2.imwrite(image_path, img) + + outputs = process_img(img_path=image_path, + save_dir=save_dir, + engine=engine, + export_all=False, + option=option) + return outputs + + +def predict(page_numb, image_url): + """ + module predict function + + Args: + image_url (str): image url + + Returns: + example output: + "data": { + "document_type": "invoice", + "fields": [ + { + "label": "Invoice Number", + "value": "INV-12345", + "box": [0, 0, 0, 0], + "confidence": 0.98 + }, + ... + ] + } + dict: output of model + """ + kvu_result = manulife_predict(image_url, engine=manulife_engine) + output_dict = { + "document_type": kvu_result['title'] if kvu_result['title'] is not None else "unknown", + "document_class": kvu_result['class_doc'] if kvu_result['class_doc'] is not None else "unknown", + "page_number": page_numb, + "fields": [] + } + for key in kvu_result.keys(): + if key in ("title", "class_doc"): + continue + field = { + "label": key, + "value": kvu_result[key], + "box": [0, 0, 0, 0], + "confidence": random.uniform(0.9, 1.0), + "page": page_numb + } + output_dict['fields'].append(field) + print(output_dict) + return output_dict + + + + +if __name__ == "__main__": + image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg" + output = predict(0, image_url) + print(output) \ No newline at end of file diff --git a/cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py b/cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py new file mode 100755 index 0000000..13f2b85 --- /dev/null +++ b/cope2n-ai-fi/api/sdsap_sbt/prediction_sbt.py @@ -0,0 +1,94 @@ +import cv2 +import urllib +import random +import numpy as np +from pathlib import Path +import sys, os +cur_dir = str(Path(__file__).parents[2]) +sys.path.append(cur_dir) +from modules.sdsvkvu import load_engine, process_img +from modules.ocr_engine import OcrEngine +from configs.sdsap_sbt import device, ocr_cfg, kvu_cfg + + +def load_ocr_engine(opt) -> OcrEngine: + print("[INFO] Loading engine...") + engine = OcrEngine(**opt) + print("[INFO] Engine loaded") + return engine + + +print("OCR engine configfs: \n", ocr_cfg) +print("KVU configfs: \n", kvu_cfg) + +ocr_engine = load_ocr_engine(ocr_cfg) +kvu_cfg['ocr_engine'] = ocr_engine +option = kvu_cfg['option'] +kvu_cfg.pop("option") # pop option +sbt_engine = load_engine(kvu_cfg) + + +def sbt_predict(image_url, engine) -> None: + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + + save_dir = "./tmp_results" + # image_path = os.path.join(save_dir, f"{image_url}.jpg") + image_path = os.path.join(save_dir, "abc.jpg") + cv2.imwrite(image_path, img) + + outputs = process_img(img_path=image_path, + save_dir=save_dir, + engine=engine, + export_all=False, + option=option) + return outputs + +def predict(page_numb, image_url): + """ + module predict function + + Args: + image_url (str): image url + + Returns: + example output: + "data": { + "document_type": "invoice", + "fields": [ + { + "label": "Invoice Number", + "value": "INV-12345", + "box": [0, 0, 0, 0], + "confidence": 0.98 + }, + ... 
+ ] + } + dict: output of model + """ + + sbt_result = sbt_predict(image_url, engine=sbt_engine) + output_dict = { + "document_type": "invoice", + "document_class": " ", + "page_number": page_numb, + "fields": [] + } + for key in sbt_result.keys(): + field = { + "label": key, + "value": sbt_result[key], + "box": [0, 0, 0, 0], + "confidence": random.uniform(0.9, 1.0), + "page": page_numb + } + output_dict['fields'].append(field) + return output_dict + + +if __name__ == "__main__": + image_url = "/root/thucpd/20230322144639VUzu_16794962527791962785161104697882.jpg" + output = predict(0, image_url) + print(output) \ No newline at end of file diff --git a/cope2n-ai-fi/celery_worker/__init__.py b/cope2n-ai-fi/celery_worker/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/cope2n-ai-fi/celery_worker/client_connector.py b/cope2n-ai-fi/celery_worker/client_connector.py new file mode 100755 index 0000000..49191e7 --- /dev/null +++ b/cope2n-ai-fi/celery_worker/client_connector.py @@ -0,0 +1,81 @@ +from celery import Celery +import base64 +import environ +env = environ.Env( + DEBUG=(bool, True) +) + +class CeleryConnector: + task_routes = { + "process_id_result": {"queue": "id_card_rs"}, + "process_driver_license_result": {"queue": "driver_license_rs"}, + "process_invoice_result": {"queue": "invoice_rs"}, + "process_ocr_with_box_result": {"queue": "ocr_with_box_rs"}, + "process_template_matching_result": {"queue": "template_matching_rs"}, + # mock task + "process_id": {"queue": "id_card"}, + "process_driver_license": {"queue": "driver_license"}, + "process_invoice": {"queue": "invoice"}, + "process_ocr_with_box": {"queue": "ocr_with_box"}, + "process_template_matching": {"queue": "template_matching"}, + } + app = Celery( + "postman", + broker=env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"), + # backend="rpc://", + ) + + def process_id_result(self, args): + return self.send_task("process_id_result", args) + + def process_driver_license_result(self, args): + return self.send_task("process_driver_license_result", args) + + def process_invoice_result(self, args): + return self.send_task("process_invoice_result", args) + + def process_ocr_with_box_result(self, args): + return self.send_task("process_ocr_with_box_result", args) + + def process_template_matching_result(self, args): + return self.send_task("process_template_matching_result", args) + + def process_id(self, args): + return self.send_task("process_id", args) + + def process_driver_license(self, args): + return self.send_task("process_driver_license", args) + + def process_invoice(self, args): + return self.send_task("process_invoice", args) + + def process_ocr_with_box(self, args): + return self.send_task("process_ocr_with_box", args) + + def process_template_matching(self, args): + return self.send_task("process_template_matching", args) + + def send_task(self, name=None, args=None): + if name not in self.task_routes or "queue" not in self.task_routes[name]: + return self.app.send_task(name, args) + + return self.app.send_task(name, args, queue=self.task_routes[name]["queue"]) + + +def main(): + rq_id = 345 + file_names = "abc.jpg" + list_data = [] + + with open("/home/sds/thucpd/aicr-2022/abc.jpg", "rb") as fs: + encoded_string = base64.b64encode(fs.read()).decode("utf-8") + list_data.append(encoded_string) + + c_connector = CeleryConnector() + a = c_connector.process_id(args=(rq_id, list_data, file_names)) + + print(a) + + +if __name__ == "__main__": + main() diff --git 
a/cope2n-ai-fi/celery_worker/client_connector_fi.py b/cope2n-ai-fi/celery_worker/client_connector_fi.py new file mode 100755 index 0000000..cd5f3ba --- /dev/null +++ b/cope2n-ai-fi/celery_worker/client_connector_fi.py @@ -0,0 +1,56 @@ +from celery import Celery +import environ +env = environ.Env( + DEBUG=(bool, True) +) + +class CeleryConnector: + task_routes = { + 'process_fi_invoice_result': {'queue': 'invoice_fi_rs'}, + 'process_sap_invoice_result': {'queue': 'invoice_sap_rs'}, + 'process_manulife_invoice_result': {'queue': 'invoice_manulife_rs'}, + 'process_sbt_invoice_result': {'queue': 'invoice_sbt_rs'}, + # mock task + 'process_fi_invoice': {'queue': "invoice_fi"}, + 'process_sap_invoice': {'queue': "invoice_sap"}, + 'process_manulife_invoice': {'queue': "invoice_manulife"}, + 'process_sbt_invoice': {'queue': "invoice_sbt"}, + } + app = Celery( + "postman", + broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"), + ) + + # mock task for FI + def process_fi_invoice_result(self, args): + return self.send_task("process_fi_invoice_result", args) + + def process_fi_invoice(self, args): + return self.send_task("process_fi_invoice", args) + + # mock task for SAP + def process_sap_invoice_result(self, args): + return self.send_task("process_sap_invoice_result", args) + + def process_sap_invoice(self, args): + return self.send_task("process_sap_invoice", args) + + # mock task for manulife + def process_manulife_invoice_result(self, args): + return self.send_task("process_manulife_invoice_result", args) + + def process_manulife_invoice(self, args): + return self.send_task("process_manulife_invoice", args) + + # mock task for manulife + def process_sbt_invoice_result(self, args): + return self.send_task("process_sbt_invoice_result", args) + + def process_sbt_invoice(self, args): + return self.send_task("process_sbt_invoice", args) + + def send_task(self, name=None, args=None): + if name not in self.task_routes or "queue" not in self.task_routes[name]: + return self.app.send_task(name, args) + + return self.app.send_task(name, args, queue=self.task_routes[name]["queue"]) \ No newline at end of file diff --git a/cope2n-ai-fi/celery_worker/mock_process_tasks.py b/cope2n-ai-fi/celery_worker/mock_process_tasks.py new file mode 100755 index 0000000..4aa97cf --- /dev/null +++ b/cope2n-ai-fi/celery_worker/mock_process_tasks.py @@ -0,0 +1,220 @@ +from celery_worker.worker import app +import numpy as np +import cv2 + +@app.task(name="process_id") +def process_id(rq_id, sub_id, folder_name, list_url, user_id): + from common.serve_model import predict + from celery_worker.client_connector import CeleryConnector + + c_connector = CeleryConnector() + try: + result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="id_card") + print(result) + result = { + "status": 200, + "content": result, + "message": "Success", + } + c_connector.process_id_result((rq_id, result)) + return {"rq_id": rq_id} + # if image_croped is not None: + # if result["data"] == []: + # result = { + # "status": 404, + # "content": {}, + # } + # c_connector.process_id_result((rq_id, result, None)) + # return {"rq_id": rq_id} + # else: + # result = { + # "status": 200, + # "content": result, + # "message": "Success", + # } + # c_connector.process_id_result((rq_id, result)) + # return {"rq_id": rq_id} + # elif image_croped is None: + # result = { + # "status": 404, + # "content": {}, + # } + # c_connector.process_id_result((rq_id, result, None)) + # return {"rq_id": rq_id} + + except Exception as e: + print(e) 
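# NOTE: minimal producer-side sketch for the CeleryConnector defined in
# celery_worker/client_connector_fi.py above; illustrative only, not part of the
# patch. The request id and image URL are hypothetical placeholders, and
# CELERY_BROKER must point at the RabbitMQ broker the workers listen on.
#
#   from celery_worker.client_connector_fi import CeleryConnector
#
#   connector = CeleryConnector()
#   # Routed to the "invoice_sbt" queue via task_routes above; the worker's
#   # process_sbt_invoice task expects (rq_id, list_url).
#   connector.process_sbt_invoice(args=("rq-0001", ["https://example.com/invoice_page_0.jpg"]))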
+ result = { + "status": 404, + "content": {}, + } + c_connector.process_id_result((rq_id, result, None)) + return {"rq_id": rq_id} + + +@app.task(name="process_driver_license") +def process_driver_license(rq_id, sub_id, folder_name, list_url, user_id): + from common.serve_model import predict + from celery_worker.client_connector import CeleryConnector + + c_connector = CeleryConnector() + try: + result = predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name="driving_license") + result = { + "status": 200, + "content": result, + "message": "Success", + } + c_connector.process_driver_license_result((rq_id, result)) + return {"rq_id": rq_id} + # result, image_croped = predict(str(url), "driving_license") + # if image_croped is not None: + # if result["data"] == []: + # result = { + # "status": 404, + # "content": {}, + # } + # c_connector.process_driver_license_result((rq_id, result, None)) + # return {"rq_id": rq_id} + # else: + # result = { + # "status": 200, + # "content": result, + # "message": "Success", + # } + # path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id) + # cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_croped) + # c_connector.process_driver_license_result((rq_id, result, path_image_croped)) + # return {"rq_id": rq_id} + # elif image_croped is None: + # result = { + # "status": 404, + # "content": {}, + # } + # c_connector.process_driver_license_result((rq_id, result, None)) + # return {"rq_id": rq_id} + except Exception as e: + print(e) + result = { + "status": 404, + "content": {}, + } + c_connector.process_driver_license_result((rq_id, result, None)) + return {"rq_id": rq_id} + + +@app.task(name="process_template_matching") +def process_template_matching(rq_id, sub_id, folder_name, url, tmp_json, user_id): + from TemplateMatching.src.ocr_master import Extractor + from celery_worker.client_connector import CeleryConnector + import urllib + + c_connector = CeleryConnector() + extractor = Extractor() + try: + req = urllib.request.urlopen(url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + imgs = [img] + image_aliged = extractor.image_alige(imgs, tmp_json) + if image_aliged is None: + result = { + "status": 401, + "content": "Image is not match with template", + } + c_connector.process_template_matching_result( + (rq_id, result, None) + ) + return {"rq_id": rq_id} + else: + output = extractor.extract_information( + image_aliged, tmp_json + ) + path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id) + cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_aliged) + if output == {}: + result = {"status": 404, "content": {}} + c_connector.process_template_matching_result((rq_id, result, None)) + return {"rq_id": rq_id} + else: + result = { + "document_type": "template_matching", + "fields": [] + } + print(output) + for field in tmp_json["fields"]: + print(field["label"]) + field_value = { + "label": field["label"], + "value": output[field["label"]], + "box": [float(num) for num in field["box"]], + "confidence": 0.98 #TODO confidence + } + result["fields"].append(field_value) + + print(result) + result = {"status": 200, "content": result} + c_connector.process_template_matching_result( + (rq_id, result, path_image_croped) + ) + return 
{"rq_id": rq_id} + except Exception as e: + print(e) + result = {"status": 404, "content": {}} + c_connector.process_template_matching_result((rq_id, result, None)) + return {"rq_id": rq_id} + + +# @app.task(name="process_invoice") +# def process_invoice(rq_id, url): +# from celery_worker.client_connector import CeleryConnector +# from Kie_Hoanglv.prediction import predict + +# c_connector = CeleryConnector() +# try: +# print(url) +# result = predict(str(url)) +# hoadon = {"status": 200, "content": result, "message": "Success"} +# c_connector.process_invoice_result((rq_id, hoadon)) +# return {"rq_id": rq_id} + +# except Exception as e: +# print(e) +# hoadon = {"status": 404, "content": {}} +# c_connector.process_invoice_result((rq_id, hoadon)) +# return {"rq_id": rq_id} + +@app.task(name="process_invoice") +def process_invoice(rq_id, list_url): + from celery_worker.client_connector import CeleryConnector + from common.process_pdf import compile_output + + c_connector = CeleryConnector() + try: + result = compile_output(list_url) + hoadon = {"status": 200, "content": result, "message": "Success"} + c_connector.process_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + except Exception as e: + print(e) + hoadon = {"status": 404, "content": {}} + c_connector.process_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + + +@app.task(name="process_ocr_with_box") +def process_ocr_with_box(rq_id, list_url): + from celery_worker.client_connector import CeleryConnector + from common.process_pdf import compile_output_ocr_base + + c_connector = CeleryConnector() + try: + result = compile_output_ocr_base(list_url) + result = {"status": 200, "content": result, "message": "Success"} + c_connector.process_ocr_with_box_result((rq_id, result)) + return {"rq_id": rq_id} + except Exception as e: + print(e) + result = {"status": 404, "content": {}} + c_connector.process_ocr_with_box_result((rq_id, result)) + return {"rq_id": rq_id} \ No newline at end of file diff --git a/cope2n-ai-fi/celery_worker/mock_process_tasks_fi.py b/cope2n-ai-fi/celery_worker/mock_process_tasks_fi.py new file mode 100755 index 0000000..00bec4a --- /dev/null +++ b/cope2n-ai-fi/celery_worker/mock_process_tasks_fi.py @@ -0,0 +1,74 @@ +from celery_worker.worker_fi import app + +@app.task(name="process_fi_invoice") +def process_invoice(rq_id, list_url): + from celery_worker.client_connector_fi import CeleryConnector + from common.process_pdf import compile_output_fi + + c_connector = CeleryConnector() + try: + result = compile_output_fi(list_url) + hoadon = {"status": 200, "content": result, "message": "Success"} + print(hoadon) + c_connector.process_fi_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + except Exception as e: + print(e) + hoadon = {"status": 404, "content": {}} + c_connector.process_fi_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + + +@app.task(name="process_sap_invoice") +def process_sap_invoice(rq_id, list_url): + from celery_worker.client_connector_fi import CeleryConnector + from common.process_pdf import compile_output + + print(list_url) + c_connector = CeleryConnector() + try: + result = compile_output(list_url) + hoadon = {"status": 200, "content": result, "message": "Success"} + c_connector.process_sap_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + except Exception as e: + print(e) + hoadon = {"status": 404, "content": {}} + c_connector.process_sap_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + +@app.task(name="process_manulife_invoice") +def 
process_manulife_invoice(rq_id, list_url): + from celery_worker.client_connector_fi import CeleryConnector + from common.process_pdf import compile_output_manulife + # TODO: simply returning 200 and 404 doesn't make any sense + c_connector = CeleryConnector() + try: + result = compile_output_manulife(list_url) + hoadon = {"status": 200, "content": result, "message": "Success"} + print(hoadon) + c_connector.process_manulife_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + except Exception as e: + print(e) + hoadon = {"status": 404, "content": {}} + c_connector.process_manulife_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + +@app.task(name="process_sbt_invoice") +def process_sbt_invoice(rq_id, list_url): + from celery_worker.client_connector_fi import CeleryConnector + from common.process_pdf import compile_output_sbt + # TODO: simply returning 200 and 404 doesn't make any sense + c_connector = CeleryConnector() + try: + result = compile_output_sbt(list_url) + hoadon = {"status": 200, "content": result, "message": "Success"} + print(hoadon) + c_connector.process_sbt_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} + except Exception as e: + print(e) + hoadon = {"status": 404, "content": {}} + c_connector.process_sbt_invoice_result((rq_id, hoadon)) + return {"rq_id": rq_id} \ No newline at end of file diff --git a/cope2n-ai-fi/celery_worker/worker.py b/cope2n-ai-fi/celery_worker/worker.py new file mode 100755 index 0000000..68b9d89 --- /dev/null +++ b/cope2n-ai-fi/celery_worker/worker.py @@ -0,0 +1,40 @@ +from celery import Celery +from kombu import Queue, Exchange +import environ +env = environ.Env( + DEBUG=(bool, True) +) + +app: Celery = Celery( + "postman", + broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"), + # backend="rpc://", + include=[ + "celery_worker.mock_process_tasks", + ], +) +task_exchange = Exchange("default", type="direct") +task_create_missing_queues = False +app.conf.update( + { + "result_expires": 3600, + "task_queues": [ + Queue("id_card"), + Queue("driver_license"), + Queue("invoice"), + Queue("ocr_with_box"), + Queue("template_matching"), + ], + "task_routes": { + "process_id": {"queue": "id_card"}, + "process_driver_license": {"queue": "driver_license"}, + "process_invoice": {"queue": "invoice"}, + "process_ocr_with_box": {"queue": "ocr_with_box"}, + "process_template_matching": {"queue": "template_matching"}, + }, + } +) + +if __name__ == "__main__": + argv = ["celery_worker.worker", "--loglevel=INFO", "--pool=solo"] # Window opts + app.worker_main(argv) \ No newline at end of file diff --git a/cope2n-ai-fi/celery_worker/worker_fi.py b/cope2n-ai-fi/celery_worker/worker_fi.py new file mode 100755 index 0000000..54c49ea --- /dev/null +++ b/cope2n-ai-fi/celery_worker/worker_fi.py @@ -0,0 +1,37 @@ +from celery import Celery +from kombu import Queue, Exchange +import environ +env = environ.Env( + DEBUG=(bool, True) +) + +app: Celery = Celery( + "postman", + broker= env.str("CELERY_BROKER", "amqp://test:test@rabbitmq:5672"), + include=[ + "celery_worker.mock_process_tasks_fi", + ], +) +task_exchange = Exchange("default", type="direct") +task_create_missing_queues = False +app.conf.update( + { + "result_expires": 3600, + "task_queues": [ + Queue("invoice_fi"), + Queue("invoice_sap"), + Queue("invoice_manulife"), + Queue("invoice_sbt"), + ], + "task_routes": { + 'process_fi_invoice': {'queue': "invoice_fi"}, + 'process_fi_invoice_result': {'queue': 'invoice_fi_rs'}, + 'process_sap_invoice': {'queue': "invoice_sap"}, + 
'process_sap_invoice_result': {'queue': 'invoice_sap_rs'}, + 'process_manulife_invoice': {'queue': 'invoice_manulife'}, + 'process_manulife_invoice_result': {'queue': 'invoice_manulife_rs'}, + 'process_sbt_invoice': {'queue': 'invoice_sbt'}, + 'process_sbt_invoice_result': {'queue': 'invoice_sbt_rs'}, + }, + } +) \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/anyKeyValue.py b/cope2n-ai-fi/common/AnyKey_Value/anyKeyValue.py new file mode 100755 index 0000000..1157c7c --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/anyKeyValue.py @@ -0,0 +1,101 @@ +import os +import glob +import cv2 +import argparse +from tqdm import tqdm +from datetime import datetime +# from omegaconf import OmegaConf +import sys +sys.path.append('/home/thucpd/thucpd/git/PV2-2023/common/AnyKey_Value') # TODO: ???? +from predictor import KVUPredictor +from preprocess import KVUProcess, DocumentKVUProcess +from utils.utils import create_dir, visualize, get_colormap, export_kvu_for_VAT_invoice, export_kvu_outputs + + +def get_args(): + args = argparse.ArgumentParser(description='Main file') + args.add_argument('--img_dir', default='/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str, + help='Input image directory') + args.add_argument('--save_dir', default='/home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/', type=str, + help='Save directory') + args.add_argument('--exp_dir', default='/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/experiments/key_value_understanding-20230608-171900', type=str, + help='Checkpoint and config of model') + args.add_argument('--export_img', default=0, type=int, + help='Save visualize on image') + args.add_argument('--mode', default=3, type=int, + help="0:'normal' - 1:'full_tokens' - 2:'sliding' - 3: 'document'") + args.add_argument('--dir_level', default=0, type=int, + help='Number of subfolders contains image') + + return args.parse_args() + + +def load_engine(exp_dir: str, class_names: list, mode: int) -> KVUPredictor: + configs = { + 'cfg': glob.glob(f'{exp_dir}/*.yaml')[0], + 'ckpt': f'{exp_dir}/checkpoints/best_model.pth' + } + dummy_idx = 512 + predictor = KVUPredictor(configs, class_names, dummy_idx, mode) + + # processor = KVUProcess(predictor.net.tokenizer_layoutxlm, + # predictor.net.feature_extractor, predictor.backbone_type, class_names, + # predictor.slice_interval, predictor.window_size, run_ocr=1, mode=mode) + + processor = DocumentKVUProcess(predictor.net.tokenizer, predictor.net.feature_extractor, + predictor.backbone_type, class_names, + predictor.max_window_count, predictor.slice_interval, predictor.window_size, + run_ocr=1, mode=mode) + return predictor, processor + + +def predict_image(img_path: str, save_dir: str, predictor: KVUPredictor, processor) -> None: + fname = os.path.basename(img_path) + img_ext = os.path.splitext(img_path)[1] + output_ext = ".json" + inputs = processor(img_path, ocr_path='') + + bbox, lwords, pr_class_words, pr_relations = predictor.predict(inputs) + + slide_window = False if len(bbox) == 1 else True + + if len(bbox) == 0: + vat_outputs = export_kvu_for_VAT_invoice(os.path.join(save_dir, fname.replace(img_ext, output_ext)), lwords, pr_class_words, pr_relations, predictor.class_names) + else: + for i in range(len(bbox)): + if not slide_window: + save_path = os.path.join(save_dir, 'kvu_results') + create_dir(save_path) + # export_kvu_for_SDSAP(os.path.join(save_dir, fname.replace(img_ext, output_ext)), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names) + vat_outputs = 
export_kvu_for_VAT_invoice(os.path.join(save_dir, fname.replace(img_ext, output_ext)), lwords[i], pr_class_words[i], pr_relations[i], predictor.class_names) + + return vat_outputs + + +def Predictor_KVU(img: str, save_dir: str, predictor: KVUPredictor, processor) -> None: + + # req = urllib.request.urlopen(image_url) + # arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + # img = cv2.imdecode(arr, -1) + curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S') + image_path = "/home/thucpd/thucpd/PV2-2023/tmp_image/{}.jpg".format(curr_datetime) + cv2.imwrite(image_path, img) + vat_outputs = predict_image(image_path, save_dir, predictor, processor) + return vat_outputs + + +if __name__ == "__main__": + args = get_args() + class_names = ['others', 'title', 'key', 'value', 'header'] + predict_mode = { + 'normal': 0, + 'full_tokens': 1, + 'sliding': 2, + 'document': 3 + } + predictor, processor = load_engine(args.exp_dir, class_names, args.mode) + create_dir(args.save_dir) + image_path = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg" + save_dir = "/mnt/ssd1T/tuanlv/PV2-2023/common/AnyKey_Value/visualize/test1" + vat_outputs = predict_image(image_path, save_dir, predictor, processor) + print('[INFO] Done') diff --git a/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/__init__.py b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier.py b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier.py new file mode 100755 index 0000000..83b4769 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier.py @@ -0,0 +1,133 @@ +import time + +import torch +import torch.utils.data +from overrides import overrides +from pytorch_lightning import LightningModule +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.utilities.distributed import rank_zero_only +from torch.optim import SGD, Adam, AdamW +from torch.optim.lr_scheduler import LambdaLR + +from lightning_modules.schedulers import ( + cosine_scheduler, + linear_scheduler, + multistep_scheduler, +) +from model import get_model +from utils import cfg_to_hparams, get_specific_pl_logger + + +class ClassifierModule(LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.net = get_model(self.cfg) + self.ignore_index = -100 + + self.time_tracker = None + + self.optimizer_types = { + "sgd": SGD, + "adam": Adam, + "adamw": AdamW, + } + + @overrides + def setup(self, stage): + self.time_tracker = time.time() + + @overrides + def configure_optimizers(self): + optimizer = self._get_optimizer() + scheduler = self._get_lr_scheduler(optimizer) + scheduler = { + "scheduler": scheduler, + "name": "learning_rate", + "interval": "step", + } + return [optimizer], [scheduler] + + def _get_lr_scheduler(self, optimizer): + cfg_train = self.cfg.train + lr_schedule_method = cfg_train.optimizer.lr_schedule.method + lr_schedule_params = cfg_train.optimizer.lr_schedule.params + + if lr_schedule_method is None: + scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1) + elif lr_schedule_method == "step": + scheduler = multistep_scheduler(optimizer, **lr_schedule_params) + elif lr_schedule_method == "cosine": + total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch + total_batch_size = cfg_train.batch_size * self.trainer.world_size + max_iter = total_samples / 
total_batch_size + scheduler = cosine_scheduler( + optimizer, training_steps=max_iter, **lr_schedule_params + ) + elif lr_schedule_method == "linear": + total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch + total_batch_size = cfg_train.batch_size * self.trainer.world_size + max_iter = total_samples / total_batch_size + scheduler = linear_scheduler( + optimizer, training_steps=max_iter, **lr_schedule_params + ) + else: + raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}") + + return scheduler + + def _get_optimizer(self): + opt_cfg = self.cfg.train.optimizer + method = opt_cfg.method.lower() + + if method not in self.optimizer_types: + raise ValueError(f"Unknown optimizer method={method}") + + kwargs = dict(opt_cfg.params) + kwargs["params"] = self.net.parameters() + optimizer = self.optimizer_types[method](**kwargs) + + return optimizer + + @rank_zero_only + @overrides + def on_fit_end(self): + hparam_dict = cfg_to_hparams(self.cfg, {}) + metric_dict = {"metric/dummy": 0} + + tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger) + + if tb_logger: + tb_logger.log_hyperparams(hparam_dict, metric_dict) + + @overrides + def training_epoch_end(self, training_step_outputs): + avg_loss = torch.tensor(0.0).to(self.device) + for step_out in training_step_outputs: + avg_loss += step_out["loss"] + + log_dict = {"train_loss": avg_loss} + self._log_shell(log_dict, prefix="train ") + + def _log_shell(self, log_info, prefix=""): + log_info_shell = {} + for k, v in log_info.items(): + new_v = v + if type(new_v) is torch.Tensor: + new_v = new_v.item() + log_info_shell[k] = new_v + + out_str = prefix.upper() + if prefix.upper().strip() in ["TRAIN", "VAL"]: + out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]" + + if self.training: + lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"] + log_info_shell["lr"] = lr + + for key, value in log_info_shell.items(): + out_str += f" || {key}: {round(value, 5)}" + out_str += f" || time: {round(time.time() - self.time_tracker, 1)}" + out_str += " secs." 
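# NOTE: descriptive comment only, not part of the patch. With the loop above, the
# assembled string takes a form like
#   "TRAIN [epoch: 3/50] || train_loss: 0.12345 || lr: 1e-05 || time: 42.0 secs."
# (the numbers are hypothetical). In this revision the final self.print(out_str)
# call below is commented out, so the string is built but never emitted.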
+ # self.print(out_str) + self.time_tracker = time.time() diff --git a/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier_module.py b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier_module.py new file mode 100755 index 0000000..5786cc5 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier_module.py @@ -0,0 +1,390 @@ +import numpy as np +import torch +import torch.utils.data +from overrides import overrides + +from lightning_modules.classifier import ClassifierModule +from utils import get_class_names + + +class KVUClassifierModule(ClassifierModule): + def __init__(self, cfg): + super().__init__(cfg) + + class_names = get_class_names(self.cfg.dataset_root_path) + + self.window_size = cfg.train.max_num_words + self.slice_interval = cfg.train.slice_interval + self.eval_kwargs = { + "class_names": class_names, + "dummy_idx": self.cfg.train.max_seq_length, # update dummy_idx in next step + } + self.stage = cfg.stage + + @overrides + def training_step(self, batch, batch_idx, *args): + if self.stage == 1: + _, loss = self.net(batch['windows']) + elif self.stage == 2: + _, loss = self.net(batch) + else: + raise ValueError( + f"Not supported stage: {self.stage}" + ) + + log_dict_input = {"train_loss": loss} + self.log_dict(log_dict_input, sync_dist=True) + return loss + + @torch.no_grad() + @overrides + def validation_step(self, batch, batch_idx, *args): + if self.stage == 1: + step_out_total = { + "loss": 0, + "ee":{ + "n_batch_gt": 0, + "n_batch_pr": 0, + "n_batch_correct": 0, + }, + "el":{ + "n_batch_gt": 0, + "n_batch_pr": 0, + "n_batch_correct": 0, + }, + "el_from_key":{ + "n_batch_gt": 0, + "n_batch_pr": 0, + "n_batch_correct": 0, + }} + for window in batch['windows']: + head_outputs, loss = self.net(window) + step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs) + for key in step_out_total: + if key == 'loss': + step_out_total[key] += step_out[key] + else: + for subkey in step_out_total[key]: + step_out_total[key][subkey] += step_out[key][subkey] + return step_out_total + + elif self.stage == 2: + head_outputs, loss = self.net(batch) + # self.eval_kwargs['dummy_idx'] = batch['itc_labels'].shape[1] + # step_out = do_eval_step(batch, head_outputs, loss, self.eval_kwargs) + self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1] + step_out = do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs) + return step_out + + @torch.no_grad() + @overrides + def validation_epoch_end(self, validation_step_outputs): + scores = do_eval_epoch_end(validation_step_outputs) + self.print( + f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}" + ) + self.print( + f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}" + ) + self.print( + f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}" + ) + self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.) 
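# NOTE: descriptive comment only, not part of the patch. The 'val_f1' scalar
# logged above is the unweighted mean of the three task F1 scores (entity
# extraction, entity linking, entity linking from key); e.g. F1s of 0.90, 0.80
# and 0.70 would log val_f1 = 0.80 (hypothetical numbers). Whether this scalar
# drives checkpoint selection depends on the trainer configuration, which is not
# part of this file.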
+ tensorboard_logs = {'val_precision_ee': scores['ee']['precision'], 'val_recall_ee': scores['ee']['recall'], 'val_f1_ee': scores['ee']['f1'], + 'val_precision_el': scores['el']['precision'], 'val_recall_el': scores['el']['recall'], 'val_f1_el': scores['el']['f1'], + 'val_precision_el_from_key': scores['el_from_key']['precision'], 'val_recall_el_from_key': scores['el_from_key']['recall'], \ + 'val_f1_el_from_key': scores['el_from_key']['f1'],} + return {'log': tensorboard_logs} + + +def do_eval_step(batch, head_outputs, loss, eval_kwargs): + class_names = eval_kwargs["class_names"] + dummy_idx = eval_kwargs["dummy_idx"] + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_labels = torch.argmax(itc_outputs, -1) + pr_stc_labels = torch.argmax(stc_outputs, -1) + pr_el_labels = torch.argmax(el_outputs, -1) + pr_el_labels_from_key = torch.argmax(el_outputs_from_key, -1) + + ( + n_batch_gt_classes, + n_batch_pr_classes, + n_batch_correct_classes, + ) = eval_ee_spade_batch( + pr_itc_labels, + batch["itc_labels"], + batch["are_box_first_tokens"], + pr_stc_labels, + batch["stc_labels"], + batch["attention_mask_layoutxlm"], + class_names, + dummy_idx, + ) + + n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch( + pr_el_labels, + batch["el_labels"], + batch["are_box_first_tokens"], + dummy_idx, + ) + + n_batch_gt_rel_from_key, n_batch_pr_rel_from_key, n_batch_correct_rel_from_key = eval_el_spade_batch( + pr_el_labels_from_key, + batch["el_labels_from_key"], + batch["are_box_first_tokens"], + dummy_idx, + ) + + step_out = { + "loss": loss, + "ee":{ + "n_batch_gt": n_batch_gt_classes, + "n_batch_pr": n_batch_pr_classes, + "n_batch_correct": n_batch_correct_classes, + }, + "el":{ + "n_batch_gt": n_batch_gt_rel, + "n_batch_pr": n_batch_pr_rel, + "n_batch_correct": n_batch_correct_rel, + }, + "el_from_key":{ + "n_batch_gt": n_batch_gt_rel_from_key, + "n_batch_pr": n_batch_pr_rel_from_key, + "n_batch_correct": n_batch_correct_rel_from_key, + } + + } + + return step_out + + +def eval_ee_spade_batch( + pr_itc_labels, + gt_itc_labels, + are_box_first_tokens, + pr_stc_labels, + gt_stc_labels, + attention_mask, + class_names, + dummy_idx, +): + n_batch_gt_classes, n_batch_pr_classes, n_batch_correct_classes = 0, 0, 0 + + bsz = pr_itc_labels.shape[0] + for example_idx in range(bsz): + n_gt_classes, n_pr_classes, n_correct_classes = eval_ee_spade_example( + pr_itc_labels[example_idx], + gt_itc_labels[example_idx], + are_box_first_tokens[example_idx], + pr_stc_labels[example_idx], + gt_stc_labels[example_idx], + attention_mask[example_idx], + class_names, + dummy_idx, + ) + + n_batch_gt_classes += n_gt_classes + n_batch_pr_classes += n_pr_classes + n_batch_correct_classes += n_correct_classes + + return ( + n_batch_gt_classes, + n_batch_pr_classes, + n_batch_correct_classes, + ) + + +def eval_ee_spade_example( + pr_itc_label, + gt_itc_label, + box_first_token_mask, + pr_stc_label, + gt_stc_label, + attention_mask, + class_names, + dummy_idx, +): + gt_first_words = parse_initial_words( + gt_itc_label, box_first_token_mask, class_names + ) + gt_class_words = parse_subsequent_words( + gt_stc_label, attention_mask, gt_first_words, dummy_idx + ) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, dummy_idx + ) + + 
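# NOTE: descriptive comment only, not part of the patch. Each entity below is a
# tuple of token indices (the first token from the ITC head, continuation tokens
# chained by the STC head), and scoring is exact-match at entity level: a
# prediction is only counted as correct when its full token-index tuple appears
# in the ground-truth set of the same class. Hypothetical example: ground truth
# {(3, 4, 5)} vs. prediction {(3, 4)} gives 1 gt, 1 pr, 0 correct.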
n_gt_classes, n_pr_classes, n_correct_classes = 0, 0, 0 + for class_idx in range(len(class_names)): + # Evaluate by ID + gt_parse = set(gt_class_words[class_idx]) + pr_parse = set(pr_class_words[class_idx]) + + n_gt_classes += len(gt_parse) + n_pr_classes += len(pr_parse) + n_correct_classes += len(gt_parse & pr_parse) + + return n_gt_classes, n_pr_classes, n_correct_classes + + +def parse_initial_words(itc_label, box_first_token_mask, class_names): + itc_label_np = itc_label.cpu().numpy() + box_first_token_mask_np = box_first_token_mask.cpu().numpy() + + outputs = [[] for _ in range(len(class_names))] + + for token_idx, label in enumerate(itc_label_np): + if box_first_token_mask_np[token_idx] and label != 0: + outputs[label].append(token_idx) + + return outputs + + +def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx): + max_connections = 50 + + valid_stc_label = stc_label * attention_mask.bool() + valid_stc_label = valid_stc_label.cpu().numpy() + stc_label_np = stc_label.cpu().numpy() + + valid_token_indices = np.where( + (valid_stc_label != dummy_idx) * (valid_stc_label != 0) + ) + + next_token_idx_dict = {} + for token_idx in valid_token_indices[0]: + next_token_idx_dict[stc_label_np[token_idx]] = token_idx + + outputs = [] + for init_token_indices in init_words: + sub_outputs = [] + for init_token_idx in init_token_indices: + cur_token_indices = [init_token_idx] + for _ in range(max_connections): + if cur_token_indices[-1] in next_token_idx_dict: + if ( + next_token_idx_dict[cur_token_indices[-1]] + not in init_token_indices + ): + cur_token_indices.append( + next_token_idx_dict[cur_token_indices[-1]] + ) + else: + break + else: + break + sub_outputs.append(tuple(cur_token_indices)) + + outputs.append(sub_outputs) + + return outputs + + +def eval_el_spade_batch( + pr_el_labels, + gt_el_labels, + are_box_first_tokens, + dummy_idx, +): + n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = 0, 0, 0 + + bsz = pr_el_labels.shape[0] + for example_idx in range(bsz): + n_gt_rel, n_pr_rel, n_correct_rel = eval_el_spade_example( + pr_el_labels[example_idx], + gt_el_labels[example_idx], + are_box_first_tokens[example_idx], + dummy_idx, + ) + + n_batch_gt_rel += n_gt_rel + n_batch_pr_rel += n_pr_rel + n_batch_correct_rel += n_correct_rel + + return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel + + +def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx): + gt_relations = parse_relations(gt_el_label, box_first_token_mask, dummy_idx) + pr_relations = parse_relations(pr_el_label, box_first_token_mask, dummy_idx) + + gt_relations = set(gt_relations) + pr_relations = set(pr_relations) + + n_gt_rel = len(gt_relations) + n_pr_rel = len(pr_relations) + n_correct_rel = len(gt_relations & pr_relations) + + return n_gt_rel, n_pr_rel, n_correct_rel + + +def parse_relations(el_label, box_first_token_mask, dummy_idx): + valid_el_labels = el_label * box_first_token_mask + valid_el_labels = valid_el_labels.cpu().numpy() + el_label_np = el_label.cpu().numpy() + + max_token = box_first_token_mask.shape[0] - 1 + + valid_token_indices = np.where( + ((valid_el_labels != dummy_idx) * (valid_el_labels != 0)) ### + ) + + link_map_tuples = [] + for token_idx in valid_token_indices[0]: + link_map_tuples.append((el_label_np[token_idx], token_idx)) + + return set(link_map_tuples) + +def do_eval_epoch_end(step_outputs): + scores = {} + for task in ['ee', 'el', 'el_from_key']: + n_total_gt_classes, n_total_pr_classes, n_total_correct_classes = 0, 0, 0 + + for 
step_out in step_outputs: + n_total_gt_classes += step_out[task]["n_batch_gt"] + n_total_pr_classes += step_out[task]["n_batch_pr"] + n_total_correct_classes += step_out[task]["n_batch_correct"] + + precision = ( + 0.0 if n_total_pr_classes == 0 else n_total_correct_classes / n_total_pr_classes + ) + recall = ( + 0.0 if n_total_gt_classes == 0 else n_total_correct_classes / n_total_gt_classes + ) + f1 = ( + 0.0 + if recall * precision == 0 + else 2.0 * recall * precision / (recall + precision) + ) + + scores[task] = { + "precision": precision, + "recall": recall, + "f1": f1, + } + + return scores + + +def get_eval_kwargs_spade(dataset_root_path, max_seq_length): + class_names = get_class_names(dataset_root_path) + dummy_idx = max_seq_length + + eval_kwargs = {"class_names": class_names, "dummy_idx": dummy_idx} + + return eval_kwargs + + +def get_eval_kwargs_spade_rel(max_seq_length): + dummy_idx = max_seq_length + + eval_kwargs = {"dummy_idx": dummy_idx} + + return eval_kwargs \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/schedulers.py b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/schedulers.py new file mode 100755 index 0000000..b49abc2 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/schedulers.py @@ -0,0 +1,53 @@ +""" +BROS +Copyright 2022-present NAVER Corp. +Apache License v2.0 +""" + +import math + +import numpy as np +from torch.optim.lr_scheduler import LambdaLR + + +def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1): + """linear_scheduler with warmup from huggingface""" + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 0.0, + float(training_steps - current_step) + / float(max(1, training_steps - warmup_steps)), + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def cosine_scheduler( + optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1 +): + """Cosine LR scheduler with warmup from huggingface""" + + def lr_lambda(current_step): + if current_step < warmup_steps: + return current_step / max(1, warmup_steps) + progress = current_step - warmup_steps + progress /= max(1, training_steps - warmup_steps) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1): + def lr_lambda(current_step): + if current_step < warmup_steps: + # calculate a warmup ratio + return current_step / max(1, warmup_steps) + else: + # calculate a multistep lr scaling ratio + idx = np.searchsorted(milestones, current_step) + return gamma ** idx + + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/utils.py b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/utils.py new file mode 100755 index 0000000..b9ca682 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/utils.py @@ -0,0 +1,161 @@ +import os +import copy +import numpy as np +import torch +import torch.nn as nn +import math + + + +def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list: + element_windows = [] + + if len(elements) > window_size: + max_step = math.ceil((len(elements) - window_size)/slice_interval) + + for i in range(0, max_step + 1): + # element_windows.append(copy.deepcopy(elements[min(i, len(elements) - window_size): min(i+window_size, len(elements))])) + if 
(i*slice_interval+window_size) >= len(elements): + _window = copy.deepcopy(elements[i*slice_interval:]) + else: + _window = copy.deepcopy(elements[i*slice_interval: i*slice_interval+window_size]) + element_windows.append(_window) + return element_windows + else: + return [elements] + +def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> list: + word_windows = [] + parse_class_windows = [] + parse_relation_windows = [] + + if len(lwords) > window_size: + max_step = math.ceil((len(lwords) - window_size)/slice_interval) + for i in range(0, max_step+1): + # _word_window = copy.deepcopy(lwords[min(i*slice_interval, len(lwords) - window_size): min(i*slice_interval+window_size, len(lwords))]) + if (i*slice_interval+window_size) >= len(lwords): + _word_window = copy.deepcopy(lwords[i*slice_interval:]) + else: + _word_window = copy.deepcopy(lwords[i*slice_interval: i*slice_interval+window_size]) + + if len(_word_window) < 2: + continue + + first_word_id = _word_window[0]['word_id'] + last_word_id = _word_window[-1]['word_id'] + + # assert (last_word_id - first_word_id == window_size - 1) or (first_word_id == 0 and last_word_id == len(lwords) - 1), [v['word_id'] for v in _word_window] #(last_word_id,first_word_id,len(lwords)) + # word list + for _word in _word_window: + _word['word_id'] -= first_word_id + + + # Entity extraction + _class_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id) + + # Entity Linking + _relation_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id) + + word_windows.append(_word_window) + parse_class_windows.append(_class_window) + parse_relation_windows.append(_relation_window) + + return word_windows, parse_class_windows, parse_relation_windows + else: + return [lwords], [parse_class], [parse_relation] + + +def entity_extraction_by_words(parse_class, first_word_id, last_word_id): + _class_window = {k: [] for k in list(parse_class.keys())} + for class_name, _parse_class in parse_class.items(): + for group in _parse_class: + tmp = [] + for idw in group: + idw -= first_word_id + if 0 <= idw <= (last_word_id - first_word_id): + tmp.append(idw) + _class_window[class_name].append(tmp) + return _class_window + +def entity_linking_by_words(parse_relation, first_word_id, last_word_id): + _relation_window = [] + for pair in parse_relation: + if all([0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair]): + _relation_window.append([idw - first_word_id for idw in pair]) + return _relation_window + + +def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor: + start_pos = 1 + end_pos = start_pos + lvalids[0] + embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...]) + cls_token = copy.deepcopy(lpatches[0][:, :1, ...]) + sep_token = copy.deepcopy(lpatches[0][:, -1:, ...]) + + for i in range(1, len(lpatches)): + start_pos = 1 + end_pos = start_pos + lvalids[i] + + overlap_gap = copy.deepcopy(loverlaps[i-1]) + window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...]) + + if overlap_gap != 0: + prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...]) + curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...]) + assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}" + + if average: + avg_overlap = ( + prev_overlap + curr_overlap + ) / 2. 
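# NOTE: descriptive comment only, not part of the patch. With average=True, the
# token embeddings in the region where two consecutive sliding windows overlap
# are replaced by the element-wise mean of the two windows before concatenation
# below, so the merged sequence grows by (valid window length - overlap) tokens
# per additional window; CLS and SEP embeddings are re-attached at the end. For
# a hypothetical overlap of 128 tokens, prev_overlap and curr_overlap are both
# of shape (batch, 128, hidden).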
+ embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens, window], dim=1 + ) + return torch.cat([cls_token, embedding_tokens, sep_token], dim=1) + + + +def merged_token_embeddings2(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor: + start_pos = 1 + end_pos = start_pos + lvalids[0] + embedding_tokens = lpatches[0][:, start_pos:end_pos, ...] + cls_token = lpatches[0][:, :1, ...] + sep_token = lpatches[0][:, -1:, ...] + + for i in range(1, len(lpatches)): + start_pos = 1 + end_pos = start_pos + lvalids[i] + + overlap_gap = loverlaps[i-1] + window = lpatches[i][:, start_pos:end_pos, ...] + + if overlap_gap != 0: + prev_overlap = embedding_tokens[:, -overlap_gap:, ...] + curr_overlap = window[:, :overlap_gap, ...] + assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}" + + if average: + avg_overlap = ( + prev_overlap + curr_overlap + ) / 2. + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens, window], dim=1 + ) + return torch.cat([cls_token, embedding_tokens, sep_token], dim=1) + diff --git a/cope2n-ai-fi/common/AnyKey_Value/model/__init__.py b/cope2n-ai-fi/common/AnyKey_Value/model/__init__.py new file mode 100755 index 0000000..1ef6c85 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/model/__init__.py @@ -0,0 +1,15 @@ + +from model.combined_model import CombinedKVUModel +from model.kvu_model import KVUModel +from model.document_kvu_model import DocumentKVUModel + +def get_model(cfg): + if cfg.stage == 1: + model = CombinedKVUModel(cfg=cfg) + elif cfg.stage == 2: + model = KVUModel(cfg=cfg) + elif cfg.stage == 3: + model = DocumentKVUModel(cfg=cfg) + else: + raise Exception('[ERROR] Trainging stage is wrong') + return model diff --git a/cope2n-ai-fi/common/AnyKey_Value/model/combined_model.py b/cope2n-ai-fi/common/AnyKey_Value/model/combined_model.py new file mode 100755 index 0000000..db87e49 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/model/combined_model.py @@ -0,0 +1,76 @@ +import os +import torch +from torch import nn +from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor +from transformers import LayoutXLMTokenizer +from transformers import AutoTokenizer, XLMRobertaModel + + +from model.relation_extractor import RelationExtractor +from model.kvu_model import KVUModel +from utils import load_checkpoint + + +class CombinedKVUModel(KVUModel): + def __init__(self, cfg): + super().__init__(cfg) + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.finetune_only = cfg.train.finetune_only + + self._get_backbones(self.model_cfg.backbone) + self._create_head() + + if os.path.exists(self.model_cfg.ckpt_model_file): + self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm') + self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer') + self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 
'stc_layer') + self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer') + self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key') + + self.loss_func = nn.CrossEntropyLoss() + + if self.freeze: + for name, param in self.named_parameters(): + if 'backbone' in name: + param.requires_grad = False + if self.finetune_only == 'EE': + for name, param in self.named_parameters(): + if 'itc_layer' not in name and 'stc_layer' not in name: + param.requires_grad = False + if self.finetune_only == 'EL': + for name, param in self.named_parameters(): + if 'relation_layer' not in name or 'relation_layer_from_key' in name: + param.requires_grad = False + if self.finetune_only == 'ELK': + for name, param in self.named_parameters(): + if 'relation_layer_from_key' not in name: + param.requires_grad = False + + + def forward(self, batch): + image = batch["image"] + input_ids_layoutxlm = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask_layoutxlm = batch["attention_mask_layoutxlm"] + + backbone_outputs_layoutxlm = self.backbone_layoutxlm( + image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm) + + last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0) + head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key} + + loss = 0.0 + if any(['labels' in key for key in batch.keys()]): + loss = self._get_loss(head_outputs, batch) + + return head_outputs, loss + diff --git a/cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py b/cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py new file mode 100755 index 0000000..0e04ed3 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/model/document_kvu_model.py @@ -0,0 +1,185 @@ +import torch +from torch import nn +from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor +from transformers import LayoutLMv2Config, LayoutLMv2Model +from transformers import LayoutXLMTokenizer +from transformers import XLMRobertaConfig, AutoTokenizer, XLMRobertaModel + + +from model.relation_extractor import RelationExtractor +from model.kvu_model import KVUModel +from utils import load_checkpoint + + +class DocumentKVUModel(KVUModel): + def __init__(self, cfg): + super().__init__(cfg) + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.train_cfg = cfg.train + + self._get_backbones(self.model_cfg.backbone) + # if 'pth' in self.model_cfg.ckpt_model_file: + # self.backbone = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone) + + self._create_head() + + self.loss_func = nn.CrossEntropyLoss() + + def _create_head(self): + self.backbone_hidden_size = self.backbone_config.hidden_size + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + self.n_classes = self.model_cfg.n_classes + 1 + self.repr_hiddent_size = self.backbone_hidden_size + + # (1) 
Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # Classfication Layer for whole document + # (1) Initial token classification + self.itc_layer_document = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + # (3) Linking token classification + self.relation_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + # (4) Linking token classification + self.relation_layer_from_key_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + self.relation_layer_from_key.apply(self._init_weight) + + self.itc_layer_document.apply(self._init_weight) + self.stc_layer_document.apply(self._init_weight) + self.relation_layer_document.apply(self._init_weight) + self.relation_layer_from_key_document.apply(self._init_weight) + + + def _get_backbones(self, config_type): + configs = { + 'layoutlm': {'config': LayoutLMConfig, 'tokenizer': LayoutLMTokenizer, 'backbone': LayoutLMModel, 'feature_extrator': LayoutLMv2FeatureExtractor}, + 'layoutxlm': {'config': LayoutLMv2Config, 'tokenizer': LayoutXLMTokenizer, 'backbone': LayoutLMv2Model, 'feature_extrator': LayoutLMv2FeatureExtractor}, + 'xlm-roberta': {'config': XLMRobertaConfig, 'tokenizer': AutoTokenizer, 'backbone': XLMRobertaModel, 'feature_extrator': LayoutLMv2FeatureExtractor}, + } + + self.backbone_config = configs[config_type]['config'].from_pretrained(self.model_cfg.pretrained_model_path) + if config_type != 'xlm-roberta': + self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path) + else: + self.tokenizer = configs[config_type]['tokenizer'].from_pretrained(self.model_cfg.pretrained_model_path, use_fast=False) + self.feature_extractor = configs[config_type]['feature_extrator'](apply_ocr=False) + self.backbone = configs[config_type]['backbone'].from_pretrained(self.model_cfg.pretrained_model_path) + + + def forward(self, batches): + 
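# Note on the flow below: each 512-token window is encoded by the backbone on
# its own and scored by the per-window ITC/STC/EL heads; the window
# representations ("window_repr") are then concatenated into one document-level
# sequence that the *_document heads score again, and the losses from every
# window plus the document pass are summed into a single scalar.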
head_outputs_list = [] + loss = 0.0 + for batch in batches["windows"]: + image = batch["image"] + input_ids = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask = batch["attention_mask_layoutxlm"] + + if self.freeze: + for param in self.backbone.parameters(): + param.requires_grad = False + + if self.model_cfg.backbone == 'layoutxlm': + backbone_outputs = self.backbone( + image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask + ) + else: + backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask) + + last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0) + + window_repr = last_hidden_states.transpose(0, 1).contiguous() + + head_outputs = {"window_repr": window_repr, + "itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + head_outputs_list.append(head_outputs) + + batch = batches["documents"] + + document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1) + document_repr = document_repr.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0) + el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0) + + head_outputs = {"itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + return head_outputs, loss + diff --git a/cope2n-ai-fi/common/AnyKey_Value/model/kvu_model.py b/cope2n-ai-fi/common/AnyKey_Value/model/kvu_model.py new file mode 100755 index 0000000..d500370 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/model/kvu_model.py @@ -0,0 +1,248 @@ +import os +import torch +from torch import nn +from transformers import LayoutLMv2Model, LayoutLMv2FeatureExtractor +from transformers import LayoutXLMTokenizer +from lightning_modules.utils import merged_token_embeddings, merged_token_embeddings2 + + +from model.relation_extractor import RelationExtractor +from utils import load_checkpoint + + +class KVUModel(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.device = 'cuda' + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.finetune_only = cfg.train.finetune_only + + # if cfg.stage == 2: + # self.freeze = True + + self._get_backbones(self.model_cfg.backbone) + self._create_head() + + if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)): + self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm') + + self._create_head() + self.loss_func = nn.CrossEntropyLoss() + + if self.freeze: + for name, param in self.named_parameters(): + if 'backbone' in name: + 
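# With freeze=True every parameter whose name contains 'backbone' is frozen, so
# only the classification / relation heads receive gradients.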
param.requires_grad = False + + def _create_head(self): + self.backbone_hidden_size = 768 + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + self.n_classes = self.model_cfg.n_classes + 1 + + # (1) Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + + + def _get_backbones(self, config_type): + self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base') + self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) + self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base') + + @staticmethod + def _init_weight(module): + init_std = 0.02 + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.LayerNorm): + nn.init.normal_(module.weight, 1.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + + + # def forward(self, inputs): + # token_embeddings = inputs['embeddings'].transpose(0, 1).contiguous().cuda() + # itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous() + # stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0) + # el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0) + # el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0) + # head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + # "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key} + + # loss = self._get_loss(head_outputs, inputs) + # return head_outputs, loss + + + # def forward_single_doccument(self, lbatches): + def forward(self, lbatches): + windows = lbatches['windows'] + token_embeddings_windows = [] + lvalids = [] + loverlaps = [] + + for i, batch in enumerate(windows): + batch = {k: v.cuda() for k, v in batch.items() if k not in ('img_path', 'words')} + image = batch["image"] + input_ids_layoutxlm = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask_layoutxlm = batch["attention_mask_layoutxlm"] + + + backbone_outputs_layoutxlm = self.backbone_layoutxlm( + image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm) + + + last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :] + + lvalids.append(batch['len_valid_tokens']) + loverlaps.append(batch['len_overlap_tokens']) + 
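# The per-window valid/overlap token counts gathered above are what
# merged_token_embeddings2 (called right after this loop) uses to stitch the
# window hidden states back into a single document-length sequence; the exact
# overlap handling lives in lightning_modules.utils and is not shown here.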
token_embeddings_windows.append(last_hidden_states_layoutxlm) + + + token_embeddings = merged_token_embeddings2(token_embeddings_windows, loverlaps, lvalids, average=False) + # token_embeddings = merged_token_embeddings(token_embeddings_windows, loverlaps, lvalids, average=True) + + + token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda() + itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0) + el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(token_embeddings, token_embeddings).squeeze(0) + head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key, + 'embedding_tokens': token_embeddings.transpose(0, 1).contiguous().detach().cpu().numpy()} + + + + loss = 0.0 + if any(['labels' in key for key in lbatches.keys()]): + labels = {k: v.cuda() for k, v in lbatches["documents"].items() if k not in ('img_path')} + loss = self._get_loss(head_outputs, labels) + + return head_outputs, loss + + def _get_loss(self, head_outputs, batch): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + itc_loss = self._get_itc_loss(itc_outputs, batch) + stc_loss = self._get_stc_loss(stc_outputs, batch) + el_loss = self._get_el_loss(el_outputs, batch) + el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True) + + loss = itc_loss + stc_loss + el_loss + el_loss_from_key + + return loss + + def _get_itc_loss(self, itc_outputs, batch): + itc_mask = batch["are_box_first_tokens"].view(-1) + + itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1) + itc_logits = itc_logits[itc_mask] + + itc_labels = batch["itc_labels"].view(-1) + itc_labels = itc_labels[itc_mask] + + itc_loss = self.loss_func(itc_logits, itc_labels) + + return itc_loss + + def _get_stc_loss(self, stc_outputs, batch): + inv_attention_mask = 1 - batch["attention_mask_layoutxlm"] + + bsz, max_seq_length = inv_attention_mask.shape + device = inv_attention_mask.device + + invalid_token_mask = torch.cat( + [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1 + ).bool() + + stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0) + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool() + stc_logits = stc_outputs.view(-1, max_seq_length + 1) + stc_logits = stc_logits[stc_mask] + + stc_labels = batch["stc_labels"].view(-1) + stc_labels = stc_labels[stc_mask] + + stc_loss = self.loss_func(stc_logits, stc_labels) + + return stc_loss + + def _get_el_loss(self, el_outputs, batch, from_key=False): + bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape + + device = batch["attention_mask_layoutxlm"].device + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + + box_first_token_mask = torch.cat( + [ + (batch["are_box_first_tokens"] == False), + torch.zeros([bsz, 1], dtype=torch.bool).to(device), + ], + axis=1, + ) + el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0) + el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + mask = 
batch["are_box_first_tokens"].view(-1) + + logits = el_outputs.view(-1, max_seq_length + 1) + logits = logits[mask] + + if from_key: + el_labels = batch["el_labels_from_key"] + else: + el_labels = batch["el_labels"] + labels = el_labels.view(-1) + labels = labels[mask] + + loss = self.loss_func(logits, labels) + return loss diff --git a/cope2n-ai-fi/common/AnyKey_Value/model/relation_extractor.py b/cope2n-ai-fi/common/AnyKey_Value/model/relation_extractor.py new file mode 100755 index 0000000..40a169e --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/model/relation_extractor.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + + +class RelationExtractor(nn.Module): + def __init__( + self, + n_relations, + backbone_hidden_size, + head_hidden_size, + head_p_dropout=0.1, + ): + super().__init__() + + self.n_relations = n_relations + self.backbone_hidden_size = backbone_hidden_size + self.head_hidden_size = head_hidden_size + self.head_p_dropout = head_p_dropout + + self.drop = nn.Dropout(head_p_dropout) + self.q_net = nn.Linear( + self.backbone_hidden_size, self.n_relations * self.head_hidden_size + ) + + self.k_net = nn.Linear( + self.backbone_hidden_size, self.n_relations * self.head_hidden_size + ) + + self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size)) + nn.init.normal_(self.dummy_node) + + def forward(self, h_q, h_k): + h_q = self.q_net(self.drop(h_q)) + + dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1) + h_k = torch.cat([h_k, dummy_vec], axis=0) + h_k = self.k_net(self.drop(h_k)) + + head_q = h_q.view( + h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size + ) + head_k = h_k.view( + h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size + ) + + relation_score = torch.einsum("ibnd,jbnd->nbij", (head_q, head_k)) + + return relation_score diff --git a/cope2n-ai-fi/common/AnyKey_Value/predictor.py b/cope2n-ai-fi/common/AnyKey_Value/predictor.py new file mode 100755 index 0000000..3a68bf0 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/predictor.py @@ -0,0 +1,228 @@ +from omegaconf import OmegaConf +import torch +# from functions import get_colormap, visualize +import sys +sys.path.append('/mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/') #TODO: ?????? 
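# The hard-coded absolute path above is machine-specific (hence the TODO); it
# only matters when lightning_modules / model / utils are not already importable
# from the working directory. A more portable sketch (hypothetical KVU_ROOT env
# var, not part of this repo):
#   KVU_ROOT = os.environ.get("KVU_ROOT", os.path.dirname(os.path.abspath(__file__)))
#   sys.path.append(KVU_ROOT)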
+ +from lightning_modules.classifier_module import parse_initial_words, parse_subsequent_words, parse_relations +from model import get_model +from utils import load_model_weight + + +class KVUPredictor: + def __init__(self, configs, class_names, dummy_idx, mode=0): + cfg_path = configs['cfg'] + ckpt_path = configs['ckpt'] + + self.class_names = class_names + self.dummy_idx = dummy_idx + self.mode = mode + + print('[INFO] Loading Key-Value Understanding model ...') + self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path) + print("[INFO] Loaded model") + + if mode == 3: + self.max_window_count = cfg.train.max_window_count + self.window_size = cfg.train.window_size + self.slice_interval = 0 + self.dummy_idx = dummy_idx * self.max_window_count + else: + self.slice_interval = cfg.train.slice_interval + self.window_size = cfg.train.max_num_words + + + self.device = 'cuda' + + def _load_model(self, cfg_path, ckpt_path): + cfg = OmegaConf.load(cfg_path) + cfg.stage = self.mode + backbone_type = cfg.model.backbone + + print('[INFO] Checkpoint:', ckpt_path) + net = get_model(cfg) + load_model_weight(net, ckpt_path) + net.to('cuda') + net.eval() + return net, cfg, backbone_type + + def predict(self, input_sample): + if self.mode == 0: + if len(input_sample['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 1: + if len(input_sample['documents']['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 2: + if len(input_sample['windows'][0]['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = [], [], [], [] + for window in input_sample['windows']: + _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window) + bbox.append(_bbox) + lwords.append(_lwords) + pr_class_words.append(_pr_class_words) + pr_relations.append(_pr_relations) + return bbox, lwords, pr_class_words, pr_relations + + elif self.mode == 3: + if len(input_sample["documents"]['words']) == 0: + return [], [], [], [] + bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + else: + raise ValueError( + f"Not supported mode: {self.mode}" + ) + + def doc_predict(self, input_sample): + lwords = input_sample['documents']['words'] + for idx, window in enumerate(input_sample['windows']): + input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')} + + # input_sample['documents'] = {k: v.unsqueeze(0).to(self.device) for k, v in input_sample['documents'].items() if k not in ('words', 'n_empty_windows')} + with torch.no_grad(): + head_outputs, _ = self.net(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + input_sample = input_sample['documents'] + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + 
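# Taking argmax over the last dimension turns each head's logits into per-token
# index predictions: itc -> entity class of a box's first token, stc -> index of
# the previous token in the same entity, el / el_outputs_from_key -> index of
# the token a value links back to (header or key), with the dummy index meaning
# "no link". The parse_* helpers below decode these into entity spans and
# relations; their internals live in lightning_modules.classifier_module.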
box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0) + attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0) + bbox = input_sample['bbox'].squeeze(0) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, self.dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx) + pr_relations = pr_relations_from_header | pr_relations_from_key + + return bbox, lwords, pr_class_words, pr_relations + + + def combined_predict(self, input_sample): + lwords = input_sample['words'] + input_sample = {k: v.unsqueeze(0) for k, v in input_sample.items() if k not in ('words', 'img_path')} + + input_sample = {k: v.to(self.device) for k, v in input_sample.items()} + + + with torch.no_grad(): + head_outputs, _ = self.net(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + input_sample = {k: v.detach().cpu() for k, v in input_sample.items()} + + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0) + attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0) + bbox = input_sample['bbox'].squeeze(0) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, self.dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx) + pr_relations = pr_relations_from_header | pr_relations_from_key + + return bbox, lwords, pr_class_words, pr_relations + + def cat_predict(self, input_sample): + lwords = input_sample['documents']['words'] + + inputs = [] + for window in input_sample['windows']: + inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')}) + input_sample['windows'] = inputs + + with torch.no_grad(): + head_outputs, _ = self.net(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens')} + + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + box_first_token_mask = input_sample['documents']['are_box_first_tokens'] + attention_mask = input_sample['documents']['attention_mask_layoutxlm'] + bbox = input_sample['documents']['bbox'] + + dummy_idx = input_sample['documents']['bbox'].shape[0] + + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, 
self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, dummy_idx) + pr_relations = pr_relations_from_header | pr_relations_from_key + + return bbox, lwords, pr_class_words, pr_relations + + + def get_ground_truth_label(self, ground_truth): + # ground_truth = self.preprocessor.load_ground_truth(json_file) + gt_itc_label = ground_truth['itc_labels'].squeeze(0) # [1, 512] => [512] + gt_stc_label = ground_truth['stc_labels'].squeeze(0) # [1, 512] => [512] + gt_el_label = ground_truth['el_labels'].squeeze(0) + + gt_el_label_from_key = ground_truth['el_labels_from_key'].squeeze(0) + lwords = ground_truth["words"] + + box_first_token_mask = ground_truth['are_box_first_tokens'].squeeze(0) + attention_mask = ground_truth['attention_mask'].squeeze(0) + + bbox = ground_truth['bbox'].squeeze(0) + gt_first_words = parse_initial_words( + gt_itc_label, box_first_token_mask, self.class_names + ) + gt_class_words = parse_subsequent_words( + gt_stc_label, attention_mask, gt_first_words, self.dummy_idx + ) + + gt_relations_from_header = parse_relations(gt_el_label, box_first_token_mask, self.dummy_idx) + gt_relations_from_key = parse_relations(gt_el_label_from_key, box_first_token_mask, self.dummy_idx) + gt_relations = gt_relations_from_header | gt_relations_from_key + + return bbox, lwords, gt_class_words, gt_relations \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/preprocess.py b/cope2n-ai-fi/common/AnyKey_Value/preprocess.py new file mode 100755 index 0000000..365c745 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/preprocess.py @@ -0,0 +1,456 @@ +import os +from typing import Any +import numpy as np +import pandas as pd +import imagesize +import itertools +from PIL import Image +import argparse + +import torch + +from utils.utils import read_ocr_result_from_txt, read_json, post_process_basic_ocr +from utils.run_ocr import load_ocr_engine, process_img +from lightning_modules.utils import sliding_windows + + +class KVUProcess: + def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0): + self.tokenizer_layoutxlm = tokenizer_layoutxlm + self.feature_extractor = feature_extractor + + self.max_seq_length = max_seq_length + self.backbone_type = backbone_type + self.class_names = class_names + + self.slice_interval = slice_interval + self.window_size = window_size + self.run_ocr = run_ocr + self.mode = mode + + self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._pad_token) + self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._cls_token) + self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(tokenizer_layoutxlm._sep_token) + self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm._unk_token) + + + self.class_idx_dic = dict( + [(class_name, idx) for idx, class_name in enumerate(self.class_names)] + ) + self.ocr_engine = None + if self.run_ocr == 1: + self.ocr_engine = load_ocr_engine() + + def __call__(self, img_path: str, ocr_path: str) -> list: + if (self.run_ocr == 1) or (not os.path.exists(ocr_path)): + ocr_path = "tmp.txt" + process_img(img_path, ocr_path, self.ocr_engine, export_img=False) + + lbboxes, 
lwords = read_ocr_result_from_txt(ocr_path) + lwords = post_process_basic_ocr(lwords) + bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval) + word_windows = sliding_windows(lwords, self.window_size, self.slice_interval) + assert len(bbox_windows) == len(word_windows), f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}" + + width, height = imagesize.get(img_path) + images = [Image.open(img_path).convert("RGB")] + image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + + + if self.mode == 0: + output = self.preprocess(lbboxes, lwords, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + elif self.mode == 1: + output = {} + windows = [] + for i in range(len(bbox_windows)): + _words = word_windows[i] + _bboxes = bbox_windows[i] + windows.append( + self.preprocess( + _bboxes, _words, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + ) + + output['windows'] = windows + elif self.mode == 2: + output = {} + windows = [] + output['doduments'] = self.preprocess(lbboxes, lwords, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=2048) + for i in range(len(bbox_windows)): + _words = word_windows[i] + _bboxes = bbox_windows[i] + windows.append( + self.preprocess( + _bboxes, _words, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + max_seq_length=self.max_seq_length) + ) + + output['windows'] = windows + else: + raise ValueError( + f"Not supported mode: {self.mode }" + ) + return output + + + def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length): + list_word_objects = [] + for bb, text in zip(bounding_boxes, words): + boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]] + tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text)) + list_word_objects.append({ + "layoutxlm_tokens": tokens, + "boundingBox": boundingBox, + "text": text + }) + + ( + bbox, + input_ids, + attention_mask, + are_box_first_tokens, + box_to_token_indices, + box2token_span_map, + lwords, + len_valid_tokens, + len_non_overlap_tokens, + len_list_tokens + ) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"]) + + assert len_list_tokens == len_valid_tokens + 2 + len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens + + ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2 + + input_ids = input_ids[:ntokens] + attention_mask = attention_mask[:ntokens] + bbox = bbox[:ntokens] + are_box_first_tokens = are_box_first_tokens[:ntokens] + + + input_ids = torch.from_numpy(input_ids) + attention_mask = torch.from_numpy(attention_mask) + bbox = torch.from_numpy(bbox) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + + len_valid_tokens = torch.tensor(len_valid_tokens) + len_overlap_tokens = torch.tensor(len_overlap_tokens) + return_dict = { + "img_path": feature_maps['img_path'], + "words": lwords, + "len_overlap_tokens": len_overlap_tokens, + 'len_valid_tokens': len_valid_tokens, + "image": feature_maps['image'], + "input_ids_layoutxlm": input_ids, + "attention_mask_layoutxlm": attention_mask, + "are_box_first_tokens": are_box_first_tokens, + "bbox": bbox, + } + return return_dict + + + def parser_words(self, 
words, max_seq_length, width, height): + list_bbs = [] + list_words = [] + list_tokens = [] + cls_bbs = [0.0] * 8 + box2token_span_map = [] + box_to_token_indices = [] + lwords = [''] * max_seq_length + + cum_token_idx = 0 + len_valid_tokens = 0 + len_non_overlap_tokens = 0 + + + input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm + bbox = np.zeros((max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_) + + for word_idx, word in enumerate(words): + this_box_token_indices = [] + + tokens = word["layoutxlm_tokens"] + bb = word["boundingBox"] + text = word["text"] + + len_valid_tokens += len(tokens) + if word_idx < self.slice_interval: + len_non_overlap_tokens += len(tokens) + + if len(tokens) == 0: + tokens.append(self.unk_token_id) + + if len(list_tokens) + len(tokens) > max_seq_length - 2: + break + + box2token_span_map.append( + [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1] + ) # including st_idx + list_tokens += tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width)) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height)) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(tokens))] + texts = [text for _ in range(len(tokens))] + + for _ in tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + list_words.extend(texts) #### + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [width, height] * 4 + + # For [CLS] and [SEP] + list_tokens = ( + [self.cls_token_id_layoutxlm] + + list_tokens[: max_seq_length - 2] + + [self.sep_token_id_layoutxlm] + ) + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs] + # list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] ### + # if len(list_words) < 510: + # list_words.extend(['

' for _ in range(510 - len(list_words))]) + list_words = [self.tokenizer_layoutxlm._cls_token] + list_words[: max_seq_length - 2] + [self.tokenizer_layoutxlm._sep_token] + + + len_list_tokens = len(list_tokens) + input_ids[:len_list_tokens] = list_tokens + attention_mask[:len_list_tokens] = 1 + + bbox[:len_list_tokens, :] = list_bbs + lwords[:len_list_tokens] = list_words + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height + + if self.backbone_type in ("layoutlm", "layoutxlm"): + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + else: + assert False + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < max_seq_length + ] + are_box_first_tokens[st_indices] = True + + return ( + bbox, + input_ids, + attention_mask, + are_box_first_tokens, + box_to_token_indices, + box2token_span_map, + lwords, + len_valid_tokens, + len_non_overlap_tokens, + len_list_tokens + ) + + + def parser_entity_extraction(self, parse_class, box_to_token_indices, max_seq_length): + itc_labels = np.zeros(max_seq_length, dtype=int) + stc_labels = np.ones(max_seq_length, dtype=np.int64) * max_seq_length + + classes_dic = parse_class + for class_name in self.class_names: + if class_name == "others": + continue + if class_name not in classes_dic: + continue + + for word_list in classes_dic[class_name]: + is_first, last_word_idx = True, -1 + for word_idx in word_list: + if word_idx >= len(box_to_token_indices): + break + box2token_list = box_to_token_indices[word_idx] + for converted_word_idx in box2token_list: + if converted_word_idx >= max_seq_length: + break # out of idx + + if is_first: + itc_labels[converted_word_idx] = self.class_idx_dic[ + class_name + ] + is_first, last_word_idx = False, converted_word_idx + else: + stc_labels[converted_word_idx] = last_word_idx + last_word_idx = converted_word_idx + + return itc_labels, stc_labels + + + def parser_entity_linking(self, parse_relation, itc_labels, box2token_span_map, max_seq_length): + el_labels = np.ones(max_seq_length, dtype=int) * max_seq_length + el_labels_from_key = np.ones(max_seq_length, dtype=int) * max_seq_length + + + relations = parse_relation + for relation in relations: + if relation[0] >= len(box2token_span_map) or relation[1] >= len( + box2token_span_map + ): + continue + if ( + box2token_span_map[relation[0]][0] >= max_seq_length + or box2token_span_map[relation[1]][0] >= max_seq_length + ): + continue + + word_from = box2token_span_map[relation[0]][0] + word_to = box2token_span_map[relation[1]][0] + # el_labels[word_to] = word_from + + if el_labels[word_to] != 512 and el_labels_from_key[word_to] != 512: + continue + + if itc_labels[word_from] == 2 and itc_labels[word_to] == 3: + el_labels_from_key[word_to] = word_from # pair of (key-value) + if itc_labels[word_from] == 4 and (itc_labels[word_to] in (2, 3)): + el_labels[word_to] = word_from # pair of (header, key) or (header-value) + return el_labels, el_labels_from_key + + +class DocumentKVUProcess(KVUProcess): + def __init__(self, tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, max_window_count, slice_interval, window_size, run_ocr, max_seq_length=512, mode=0): + super().__init__(tokenizer_layoutxlm, feature_extractor, backbone_type, class_names, slice_interval, window_size, run_ocr, max_seq_length, mode) + self.max_window_count = max_window_count + self.pad_token_id = self.pad_token_id_layoutxlm + self.cls_token_id = 
self.cls_token_id_layoutxlm + self.sep_token_id = self.sep_token_id_layoutxlm + self.unk_token_id = self.unk_token_id_layoutxlm + self.tokenizer = self.tokenizer_layoutxlm + + def __call__(self, img_path: str, ocr_path: str) -> list: + if (self.run_ocr == 1) and (not os.path.exists(ocr_path)): + ocr_path = "tmp.txt" + process_img(img_path, ocr_path, self.ocr_engine, export_img=False) + + lbboxes, lwords = read_ocr_result_from_txt(ocr_path) + lwords = post_process_basic_ocr(lwords) + + width, height = imagesize.get(img_path) + images = [Image.open(img_path).convert("RGB")] + image_features = torch.from_numpy(self.feature_extractor(images)['pixel_values'][0].copy()) + output = self.preprocess(lbboxes, lwords, + {'image': image_features, 'width': width, 'height': height, 'img_path': img_path}, + self.max_seq_length) + return output + + def preprocess(self, bounding_boxes, words, feature_maps, max_seq_length): + n_words = len(words) + output_dicts = {'windows': [], 'documents': []} + n_empty_windows = 0 + + for i in range(self.max_window_count): + input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id + bbox = np.zeros((self.max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(self.max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_) + + if n_words == 0: + n_empty_windows += 1 + output_dicts['windows'].append({ + "image": feature_maps['image'], + "input_ids_layoutxlm": torch.from_numpy(input_ids), + "bbox": torch.from_numpy(bbox), + "words": [], + "attention_mask_layoutxlm": torch.from_numpy(attention_mask), + "are_box_first_tokens": torch.from_numpy(are_box_first_tokens), + }) + continue + + start_word_idx = i * self.window_size + stop_word_idx = min(n_words, (i+1)*self.window_size) + + if start_word_idx >= stop_word_idx: + n_empty_windows += 1 + output_dicts['windows'].append(output_dicts['windows'][-1]) + continue + + list_word_objects = [] + for bb, text in zip(bounding_boxes[start_word_idx:stop_word_idx], words[start_word_idx:stop_word_idx]): + boundingBox = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]] + tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids(self.tokenizer_layoutxlm.tokenize(text)) + list_word_objects.append({ + "layoutxlm_tokens": tokens, + "boundingBox": boundingBox, + "text": text + }) + + ( + bbox, + input_ids, + attention_mask, + are_box_first_tokens, + box_to_token_indices, + box2token_span_map, + lwords, + len_valid_tokens, + len_non_overlap_tokens, + len_list_layoutxlm_tokens + ) = self.parser_words(list_word_objects, self.max_seq_length, feature_maps["width"], feature_maps["height"]) + + + input_ids = torch.from_numpy(input_ids) + bbox = torch.from_numpy(bbox) + attention_mask = torch.from_numpy(attention_mask) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + + return_dict = { + "image": feature_maps['image'], + "input_ids_layoutxlm": input_ids, + "bbox": bbox, + "words": lwords, + "attention_mask_layoutxlm": attention_mask, + "are_box_first_tokens": are_box_first_tokens, + } + output_dicts["windows"].append(return_dict) + + attention_mask = torch.cat([o['attention_mask_layoutxlm'] for o in output_dicts["windows"]]) + are_box_first_tokens = torch.cat([o['are_box_first_tokens'] for o in output_dicts["windows"]]) + if n_empty_windows > 0: + attention_mask[self.max_seq_length * (self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=int)) + are_box_first_tokens[self.max_seq_length * 
(self.max_window_count - n_empty_windows):] = torch.from_numpy(np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_)) + bbox = torch.cat([o['bbox'] for o in output_dicts["windows"]]) + words = [] + for o in output_dicts['windows']: + words.extend(o['words']) + + return_dict = { + "attention_mask_layoutxlm": attention_mask, + "bbox": bbox, + "are_box_first_tokens": are_box_first_tokens, + "n_empty_windows": n_empty_windows, + "words": words + } + output_dicts['documents'] = return_dict + + return output_dicts + + + \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/requirements.txt b/cope2n-ai-fi/common/AnyKey_Value/requirements.txt new file mode 100755 index 0000000..82c2ce3 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/requirements.txt @@ -0,0 +1,30 @@ +nptyping==1.4.2 +numpy==1.20.3 +opencv-python-headless==4.5.4.60 +pytorch-lightning==1.5.6 +omegaconf +# pillow +six +overrides==4.1.2 +# transformers==4.11.3 +seqeval==0.0.12 +imagesize +pandas==2.0.1 +xmltodict +dicttoxml + +tensorboard>=2.2.0 + +# code-style +isort==5.9.3 +black==21.9b0 + +# # pytorch +# --find-links https://download.pytorch.org/whl/torch_stable.html +# torch==1.9.1+cu102 +# torchvision==0.10.1+cu102 + +# pytorch +# --find-links https://download.pytorch.org/whl/torch_stable.html +# torch==1.10.0+cu113 +# torchvision==0.11.1+cu113 diff --git a/cope2n-ai-fi/common/AnyKey_Value/run.sh b/cope2n-ai-fi/common/AnyKey_Value/run.sh new file mode 100755 index 0000000..1b0442f --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/run.sh @@ -0,0 +1 @@ +python anyKeyValue.py --img_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --save_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/visualize/test/ --exp_dir /home/ai-core/Kie_Invoice_AP/AnyKey_Value/experiments/key_value_understanding-20230608-171900 --export_img 1 --mode 3 --dir_level 0 \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/tmp.txt b/cope2n-ai-fi/common/AnyKey_Value/tmp.txt new file mode 100755 index 0000000..776ddc3 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/tmp.txt @@ -0,0 +1,393 @@ +550 14 644 53 HÓA +654 20 745 53 ĐƠN +755 13 833 53 GIÁ +842 20 916 60 TRỊ +925 22 1001 53 GIA +1012 14 1130 54 TĂNG +208 63 363 87 CHUNGHO +752 76 815 101 (VAT +821 76 940 101 INVOICE) +1193 80 1230 109 Ký +1235 80 1286 109 hiệu +1293 83 1390 108 (Serial): +1398 81 1519 105 1C23TCV +135 130 330 171 ChungHo +339 131 436 165 Vina +600 150 662 178 (BẢN +667 152 719 177 THỂ +678 119 743 147 Ngày +724 152 788 177 HIÊN +749 119 779 142 01 +786 119 854 147 tháng +794 152 847 176 CỦA +853 151 908 176 HÓA +860 119 891 142 03 +897 121 949 142 năm +915 156 973 176 ĐƠN +956 120 1014 142 2023 +979 154 1043 179 ĐIỆN +1049 151 1093 181 TỬ) +1193 118 1228 148 Số +1234 125 1303 151 (No.): +1309 116 1372 148 829 +187 179 270 197 HEALTH +276 179 383 197 SOLUTION +664 187 795 210 (EINVOICE +802 188 908 209 DISPLAY +915 188 1032 213 VERSION) +465 240 507 264 Mã +512 241 554 264 của +560 241 626 266 CQT: +655 239 1186 266 PURIBUZANAMOIC99C9CTROINS +90 286 141 308 Đơn +147 285 172 312 vị +177 285 221 308 bán +226 285 285 313 hàng +293 286 387 312 (Seller): +395 279 495 307 CÔNG +504 280 549 307 TY +558 280 657 307 TNHH +664 279 844 307 CHUNGHO +852 279 937 307 VINA +945 279 1091 309 HEALTH +1100 280 1274 309 SOLUTION +90 326 130 348 Mã +135 322 165 349 số +170 322 223 349 thuế +229 327 281 351 (Tax +288 327 358 351 code): +363 324 383 347 0 +390 324 406 348 1 +411 321 605 350 08118215 +89 366 132 392 Địa +138 366 176 389 chỉ 
+182 368 298 393 (Address): +303 362 334 389 Số +340 367 373 392 11, +378 368 401 389 lô +407 368 477 392 N04A, +484 367 526 389 khu +530 367 560 389 đô +564 367 595 392 thị +599 367 643 390 mới +648 367 702 392 Dịch +707 368 773 393 Vọng, +779 367 863 394 phường +867 367 922 392 Dịch +927 368 993 394 Vọng, +999 368 1053 394 quận +1057 363 1103 390 Cầu +1108 364 1167 394 Giấy, +1173 368 1233 390 thành +1238 364 1281 394 phố +1285 367 1318 390 Hà +1323 368 1371 393 Nội, +1377 367 1424 393 Việt +1430 368 1483 390 Nam +89 406 146 431 Điện +152 407 211 432 thoại +218 407 279 431 (Tel): +286 407 331 428 024 +339 406 397 428 7300 +404 406 460 428 0891 +826 406 878 430 Fax: +89 440 121 467 Số +127 444 158 467 tài +165 445 236 466 khoản +245 447 348 470 (Account +355 446 413 472 No.): +421 445 616 469 700-010-446490 +622 445 653 471 tại +659 445 722 471 Ngân +728 445 792 472 Hàng +799 444 893 468 Shinhan +898 449 911 467 - +916 444 960 468 Chi +965 445 1036 468 nhánh +1043 445 1078 469 Hà +1084 444 1129 472 Nội +89 486 126 513 Họ +132 486 169 510 tên +174 486 244 513 người +251 493 302 509 mua +308 486 367 514 hàng +374 488 469 514 (Buyer): +90 529 136 550 Tên +142 529 188 551 đơn +193 529 219 555 vị +226 530 342 556 (Company +349 531 428 555 name): +435 524 518 552 CÔNG +525 529 562 552 TY +570 530 650 551 TNHH +657 529 796 552 SAMSUNG +803 529 856 552 SDS +862 527 930 556 VIỆT +937 529 1004 552 NAM +89 570 129 592 Mã +136 564 165 593 số +170 565 223 592 thuế +229 571 281 595 (Tax +287 571 358 596 code): +365 569 510 592 2300680991 +88 611 133 638 Địa +138 612 175 634 chỉ +182 613 297 638 (Address): +303 611 339 633 Lô +346 612 424 636 CN05, +432 611 514 639 đường +521 612 583 638 YP6, +589 612 645 634 Khu +652 611 713 640 công +719 612 803 641 nghiệp +810 611 862 635 Yên +869 612 956 640 Phong, +962 612 1000 635 Xã +1006 612 1057 635 Yên +1064 612 1152 640 Trung, +1159 611 1242 640 Huyện +1249 611 1300 635 Yên +1307 612 1394 640 Phong, +1402 612 1463 635 Tỉnh +1470 605 1518 636 Bắc +89 654 158 678 Ninh, +165 654 219 679 Việt +225 654 286 676 Nam +89 681 122 705 Số +127 684 157 705 tài +164 682 236 705 khoản +244 685 348 709 (Account +354 684 413 709 No.): +90 724 149 747 Hình +157 724 208 747 thức +216 724 282 746 thanh +289 724 341 747 toán +349 726 457 749 (Payment +464 726 563 748 method): +572 725 663 747 TM/CK +98 789 148 812 STT +163 770 212 793 Tên +218 770 281 797 hàng +287 770 341 796 hóa, +348 769 405 796 dịch +436 770 491 792 Đơn +498 770 525 798 vị +569 783 603 811 Số +610 789 684 817 lượng +747 789 802 811 Đơn +808 788 850 818 giá +979 788 1063 812 Thành +1070 782 1119 812 tiền +1214 765 1282 793 Thuế +1287 764 1344 793 suất +1393 764 1452 792 Tiền +1459 765 1518 792 thuế +94 825 153 850 (No.) 
+206 844 356 870 (Description) +266 813 299 836 vụ +448 845 515 868 (Unit) +454 808 507 830 tính +568 825 685 852 (Quantity) +733 826 793 848 (Unit +798 826 865 851 price) +993 824 1103 851 (Amount) +1223 807 1287 830 (VAT +1262 843 1295 869 %) +1290 809 1335 829 rate +1378 843 1542 869 (VAT Amount) +1417 807 1491 830 GTGT +116 891 128 913 1 +273 890 290 913 2 +472 890 488 912 3 +617 889 635 912 4 +790 889 807 914 5 +992 889 1103 916 6=4x5 +1270 889 1287 913 7 +1399 889 1510 914 8=6X7 +158 939 200 961 Phí +207 939 259 961 thuê +266 939 316 966 máy +323 938 361 965 lọc +159 977 219 998 nước +225 977 285 1002 nóng +292 976 345 1001 lạnh +161 1014 236 1040 Digital +244 1014 307 1035 CHP- +114 1032 127 1055 1 +159 1052 267 1073 3800ST1 +276 1052 312 1076 (từ +318 1051 377 1078 ngày +453 1032 507 1060 Máy +678 1035 697 1059 4 +795 1031 893 1057 800.000 +1074 1031 1195 1057 3,200,000 +1297 1031 1353 1057 10% +1452 1031 1551 1058 320,000 +158 1089 292 1111 01/02/2023 +301 1083 343 1110 đến +351 1084 388 1110 hết +159 1125 304 1151 28/02/2023) +159 1173 200 1195 Phí +207 1173 258 1195 thuê +265 1173 316 1200 máy +323 1173 361 1199 lọc +158 1213 218 1234 nước +225 1212 285 1238 nóng +292 1211 344 1237 lạnh +161 1249 236 1274 Digital +243 1248 307 1270 CHP- +112 1267 129 1291 2 +160 1287 267 1309 3800ST1 +276 1287 312 1312 (từ +318 1287 377 1313 ngày +454 1267 506 1293 Máy +664 1266 697 1292 20 +793 1265 896 1294 876,800 +1062 1265 1198 1297 17,536,000 +1430 1266 1552 1295 1,753,600 +159 1323 292 1346 29/01/2023 +301 1319 343 1345 đến +351 1319 388 1345 hết +160 1360 304 1387 28/02/2023) +160 1409 200 1431 Phí +207 1409 258 1431 thuê +265 1409 316 1436 máy +323 1408 361 1435 lọc +158 1447 218 1468 nước +225 1446 285 1473 nóng +292 1446 344 1472 lạnh +113 1503 128 1527 3 +160 1484 236 1511 Digital +243 1484 307 1505 CHP- +452 1503 506 1531 Máy +795 1502 897 1529 544,000 +1074 1502 1197 1529 2,176,000 +1450 1502 1550 1528 217,600 +160 1522 267 1544 3800ST1 +276 1522 312 1546 (từ +318 1522 377 1549 ngày +162 1559 292 1581 10/02/2023 +301 1555 343 1581 đến +351 1555 388 1581 hết +159 1596 304 1623 28/02/2023) +160 1645 200 1667 Phí +207 1644 259 1667 thuê +265 1645 316 1672 máy +324 1645 362 1671 lọc +159 1683 219 1706 nước +226 1683 286 1709 nóng +293 1683 345 1708 lạnh +160 1720 237 1746 Digital +245 1720 307 1742 CHP- +112 1738 129 1760 4 +159 1758 268 1780 3800ST1 +276 1758 313 1783 (từ +319 1758 377 1785 ngày +453 1737 508 1767 Máy +677 1737 699 1764 4 +795 1737 895 1764 256,000 +1077 1737 1197 1764 1,024,000 +1297 1737 1354 1762 10% +1453 1737 1552 1764 102,400 +158 1795 293 1817 20/02/2023 +301 1791 344 1817 đến +350 1791 388 1817 hết +158 1831 304 1859 28/02/2023) +93 2004 134 2027 Giá +139 2006 169 2031 trị +173 2007 221 2026 theo +226 2006 276 2026 mức +281 2002 331 2026 thuế +337 2005 412 2026 GTGT +676 1986 747 2008 Thành +752 1982 795 2008 tiền +800 1986 859 2008 trước +865 1983 915 2008 thuế +920 1986 992 2007 GTGT +1025 1983 1074 2008 Tiền +1079 1983 1127 2008 thuế +1132 1986 1201 2008 GTGT +1234 1985 1303 2008 Thành +1308 1982 1352 2008 tiền +1357 1985 1406 2012 gồm +1411 1982 1460 2008 thuế +1466 1985 1538 2008 GTGT +718 2023 813 2047 (Amount +819 2023 889 2048 before +895 2022 952 2048 VAT) +1035 2022 1096 2046 (VAT +1100 2023 1195 2048 Amount) +1248 2023 1345 2047 (Amount +1351 2023 1459 2049 including +1465 2022 1524 2048 VAT) +92 2071 132 2093 Giá +137 2072 162 2096 trị +166 2072 202 2092 với +207 2069 253 2093 thuế +259 2072 328 2092 GTGT +334 2071 371 2092 0% +378 2072 472 2094 
(Amount +478 2074 502 2093 at +507 2071 545 2093 0% +552 2070 610 2097 VAT) +93 2119 132 2141 Giá +137 2120 162 2145 trị +167 2120 202 2141 với +207 2117 253 2141 thuế +259 2121 328 2140 GTGT +335 2120 371 2140 5% +378 2121 471 2141 (Amount +477 2122 502 2141 at +507 2119 546 2141 5% +552 2119 609 2145 VAT) +94 2167 132 2189 Giá +137 2168 162 2192 trị +166 2167 202 2189 với +207 2164 253 2189 thuế +259 2169 328 2188 GTGT +335 2168 372 2188 8% +379 2168 472 2189 (Amount +478 2171 502 2189 at +507 2168 545 2189 8% +552 2167 609 2193 VAT) +92 2216 132 2237 Giá +137 2217 162 2241 trị +167 2215 202 2236 với +207 2212 253 2237 thuế +259 2216 329 2236 GTGT +337 2216 384 2236 10% +391 2217 486 2237 (Amount +491 2219 516 2236 at +523 2216 572 2237 10% +579 2214 636 2240 VAT) +891 2213 1005 2239 23,936,000 +1111 2213 1215 2239 2,393,600 +1435 2213 1552 2240 26,329,600 +94 2263 170 2285 TỔNG +176 2263 252 2288 CỘNG +260 2264 361 2287 (GRAND +368 2263 461 2288 TOTAL) +874 2262 1007 2288 23,936,000 +1097 2262 1215 2288 2,393,600 +1416 2261 1549 2288 26,329,600 +87 2307 119 2333 Số +125 2307 171 2332 tiền +178 2304 223 2332 viết +230 2308 289 2337 bằng +295 2308 340 2332 chữ +347 2311 443 2332 (Amount +449 2310 473 2332 in +479 2309 564 2336 words): +571 2308 616 2331 Hai +621 2310 684 2331 mươi +690 2308 731 2331 sáu +736 2308 793 2336 triệu +797 2308 828 2331 ba +833 2309 888 2331 trăm +894 2308 932 2331 hai +938 2309 1001 2331 mươi +1007 2308 1059 2331 chín +1065 2308 1133 2336 nghìn +1139 2307 1180 2331 sáu +1186 2309 1241 2331 trăm +1247 2305 1308 2336 đồng diff --git a/cope2n-ai-fi/common/AnyKey_Value/utils/__init__.py b/cope2n-ai-fi/common/AnyKey_Value/utils/__init__.py new file mode 100755 index 0000000..12f320a --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/utils/__init__.py @@ -0,0 +1,127 @@ +import os +import torch +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.plugins import DDPPlugin +from utils.ema_callbacks import EMA + + +def _update_config(cfg): + cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints") + cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs") + + # set per-gpu batch size + num_devices = torch.cuda.device_count() + print('No. 
devices:', num_devices) + for mode in ["train", "val"]: + new_batch_size = cfg[mode].batch_size // num_devices + cfg[mode].batch_size = new_batch_size + +def _get_config_from_cli(): + cfg_cli = OmegaConf.from_cli() + cli_keys = list(cfg_cli.keys()) + for cli_key in cli_keys: + if "--" in cli_key: + cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key] + del cfg_cli[cli_key] + + return cfg_cli + +def get_callbacks(cfg): + callback_list = [] + checkpoint_callback = ModelCheckpoint(dirpath=cfg.save_weight_dir, + filename='best_model', + save_last=True, + save_top_k=1, + save_weights_only=True, + verbose=True, + monitor='val_f1', mode='max') + checkpoint_callback.FILE_EXTENSION = ".pth" + checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model" + callback_list.append(checkpoint_callback) + if cfg.callbacks.ema.decay != -1: + ema_callback = EMA(decay=0.9999) + callback_list.append(ema_callback) + return callback_list if len(callback_list) > 1 else checkpoint_callback + +def get_plugins(cfg): + plugins = [] + if cfg.train.strategy.type == "ddp": + plugins.append(DDPPlugin()) + + return plugins + +def get_loggers(cfg): + loggers = [] + + loggers.append( + TensorBoardLogger( + cfg.tensorboard_dir, name="", version="", default_hp_metric=False + ) + ) + + return loggers + +def cfg_to_hparams(cfg, hparam_dict, parent_str=""): + for key, val in cfg.items(): + if isinstance(val, DictConfig): + hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__") + else: + hparam_dict[parent_str + key] = str(val) + return hparam_dict + +def get_specific_pl_logger(pl_loggers, logger_type): + for pl_logger in pl_loggers: + if isinstance(pl_logger, logger_type): + return pl_logger + return None + +def get_class_names(dataset_root_path): + class_names_file = os.path.join(dataset_root_path[0], "class_names.txt") + class_names = ( + open(class_names_file, "r", encoding="utf-8").read().strip().split("\n") + ) + return class_names + +def create_exp_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + else: + print("DIR already existed.") + print('Experiment dir : {}'.format(save_dir)) + +def create_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + else: + print("DIR already existed.") + print('Save dir : {}'.format(save_dir)) + +def load_checkpoint(ckpt_path, model, key_include): + assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!" + state_dict = torch.load(ckpt_path, 'cpu')['state_dict'] + for key in list(state_dict.keys()): + if f'.{key_include}.' not in key: + del state_dict[key] + else: + state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something. + del state_dict[key] + model.load_state_dict(state_dict, strict=True) + print(f"Load checkpoint at {ckpt_path}") + return model + +def load_model_weight(net, pretrained_model_file): + pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[ + "state_dict" + ] + new_state_dict = {} + for k, v in pretrained_model_state_dict.items(): + new_k = k + if new_k.startswith("net."): + new_k = new_k[len("net.") :] + new_state_dict[new_k] = v + net.load_state_dict(new_state_dict) + + diff --git a/cope2n-ai-fi/common/AnyKey_Value/utils/ema_callbacks.py b/cope2n-ai-fi/common/AnyKey_Value/utils/ema_callbacks.py new file mode 100755 index 0000000..956e0bf --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/utils/ema_callbacks.py @@ -0,0 +1,346 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import copy +import os +import threading +from typing import Any, Dict, Iterable + +import pytorch_lightning as pl +import torch +from pytorch_lightning import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import rank_zero_info + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. + cpu_offload: Offload weights to CPU. + """ + + def __init__( + self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False, + ): + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self.decay = decay + self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps + self.cpu_offload = cpu_offload + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + device = pl_module.device if not self.cpu_offload else torch.device('cpu') + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) + ] + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers) + + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + 
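# The swap is performed in-place by each wrapped EMAOptimizer, so the
# validation/test hooks above run the model with the averaged weights, and the
# matching call in the corresponding *_end hook restores the original weights.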
optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) + + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + # use the connector as NeMo calls the connector directly in the exp_manager when restoring. + connector = trainer._checkpoint_connector + ckpt_path = connector.resume_checkpoint_path + + if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__: + ext = checkpoint_callback.FILE_EXTENSION + if ckpt_path.endswith(f'-EMA{ext}'): + rank_zero_info( + "loading EMA based weights. " + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." + ) + return + ema_path = ckpt_path.replace(ext, f'-EMA{ext}') + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) + + checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] + del ema_state_dict + rank_zero_info("EMA state has been restored.") + else: + raise MisconfigurationException( + "Unable to find the associated EMA weights when re-loading, " + f"training will start with new EMA weights. Expected them to be at: {ema_path}", + ) + + +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + +def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() + + ema_update(ema_model_tuple, current_model_tuple, decay) + + +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. 
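- The EMA copy of the parameters is kept on ``device``; when the EMA callback is created with cpu_offload=True it passes torch.device('cpu') here, so the averaged weights stay out of GPU memory and the update runs on a background thread.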
+ + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + self.in_saving_ema_model_context = False + + def all_parameters(self) -> Iterable[torch.Tensor]: + return (param for group in self.param_groups for param in group['params']) + + def step(self, closure=None, **kwargs): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + loss = self.optimizer.step(closure) + + if self._should_update_at_step(): + self.update() + self.current_step += 1 + return loss + + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + + @torch.no_grad() + def update(self): + if self.stream is not None: + self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) for param in self.all_parameters() + ) + + if self.device.type == 'cuda': + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == 'cpu': + self.thread = threading.Thread( + target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,), + ) + self.thread.start() + + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self, saving_ema_model: bool = False): + self.join() + self.in_saving_ema_model_context = saving_ema_model + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. 
+ + Args: + enabled (bool): whether the swap should be performed + """ + + if enabled: + self.switch_main_parameter_weights() + try: + yield + finally: + if enabled: + self.switch_main_parameter_weights() + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + def state_dict(self): + self.join() + + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters()) + state_dict = { + 'opt': self.optimizer.state_dict(), + 'ema': ema_params, + 'current_step': self.current_step, + 'decay': self.decay, + 'every_n_steps': self.every_n_steps, + } + return state_dict + + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict['opt']) + self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema'])) + self.current_step = state_dict['current_step'] + self.decay = state_dict['decay'] + self.every_n_steps = state_dict['every_n_steps'] + self.rebuild_ema_params = False + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/utils/kvu_dictionary.py b/cope2n-ai-fi/common/AnyKey_Value/utils/kvu_dictionary.py new file mode 100755 index 0000000..51e33b0 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/utils/kvu_dictionary.py @@ -0,0 +1,68 @@ + +DKVU2XML = { + "Ký hiệu mẫu hóa đơn": "form_no", + "Ký hiệu hóa đơn": "serial_no", + "Số hóa đơn": "invoice_no", + "Ngày, tháng, năm lập hóa đơn": "issue_date", + "Tên người bán": "seller_name", + "Mã số thuế người bán": "seller_tax_code", + "Thuế suất": "tax_rate", + "Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount", + "Mặt hàng": "item", + "Đơn vị tính": "unit", + "Số lượng": "quantity", + "Đơn giá": "unit_price", + "Doanh số mua chưa có thuế": "amount" +} + + +def ap_dictionary(header: bool): + header_dictionary = { + 'productname': ['description', 'paticulars', 'articledescription', 'descriptionofgood', 'itemdescription', 'product', 'productdescription', 'modelname', 'device', 'items', 'itemno'], + 'modelnumber': ['serialno', 'model', 'code', 'mcode', 'simimeiserial', 'serial', 'productcode', 'product', 'imeiccid', 'articles', 'article', 'articlenumber', 'articleidmaterialcode', 'transaction', 'itemcode'], + 'qty': ['quantity', 'invoicequantity'] + } + + key_dictionary = { + 'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'], + 'retailername': ['retailer', 'retailername', 'ownedoperatedby'], + 'serial_number': ['serialnumber', 'serialno'], + 'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2'] + } + + return header_dictionary if header else key_dictionary + + +def vat_dictionary(header: bool): + header_dictionary = { + 'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'sanpham', 'tenquycachhanghoa','description', 'descriptionofgood', 'itemdescription'], + 'Đơn vị tính': ['dvt', 'donvitinh'], + 'Số lượng': ['soluong', 'sl','qty', 'quantity', 'invoicequantity'], + 'Đơn giá': ['dongia'], + 'Doanh số mua chưa có thuế': ['thanhtien', 
'thanhtientruocthuegtgt', 'tienchuathue'], + # 'Số sản phẩm': ['serialno', 'model', 'mcode', 'simimeiserial', 'serial', 'sku', 'sn', 'productcode', 'product', 'particulars', 'imeiccid', 'articles', 'article', 'articleidmaterialcode', 'transaction', 'imei', 'articlenumber'] + } + + key_dictionary = { + 'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'], + 'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'], + 'Số hóa đơn': ['soinvoiceno', 'invoiceno'], + 'Ngày, tháng, năm lập hóa đơn': [], + 'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'], + 'Mã số thuế người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'], + 'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'], + 'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'], + # 'Ghi chú': [], + # 'Ngày': ['ngayday', 'ngay', 'day'], + # 'Tháng': ['thangmonth', 'thang', 'month'], + # 'Năm': ['namyear', 'nam', 'year'] + } + + # exact_dictionary = { + # 'Số hóa đơn': ['sono', 'so'], + # 'Mã số thuế người bán': ['mst'], + # 'Tên người bán': ['kyboi'], + # 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate'] + # } + + return header_dictionary if header else key_dictionary \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/utils/run_ocr.py b/cope2n-ai-fi/common/AnyKey_Value/utils/run_ocr.py new file mode 100755 index 0000000..9976914 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/utils/run_ocr.py @@ -0,0 +1,33 @@ +import numpy as np +from pathlib import Path +from typing import Union, Tuple, List +import sys +# sys.path.append('/home/thucpd/thucpd/PV2-2023/common/AnyKey_Value/ocr-engine') +# from src.ocr import OcrEngine +sys.path.append('/home/thucpd/thucpd/git/PV2-2023/kie-invoice/components/prediction') # TODO: ?????? 
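# Assumption (not stated in this diff): the hard-coded sys.path entry above makes a `serve_model` module importable that already holds a fully initialised OCR engine as `serve_model.engine`; load_ocr_engine below simply returns that shared instance instead of constructing a new OcrEngine.
# Note: the OcrEngine annotations further down only resolve if the commented-out `from src.ocr import OcrEngine` import (or an equivalent definition) is restored.
# Minimal usage sketch (the file path is illustrative only):
#   engine = load_ocr_engine()
#   text = read_img("/path/to/invoice.jpg", engine)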
+import serve_model + + +# def load_ocr_engine() -> OcrEngine: +def load_ocr_engine() -> OcrEngine: + print("[INFO] Loading engine...") + # engine = OcrEngine() + engine = serve_model.engine + print("[INFO] Engine loaded") + return engine + +def process_img(img: Union[str, np.ndarray], save_dir_or_path: str, engine: OcrEngine, export_img: bool) -> None: + save_dir_or_path = Path(save_dir_or_path) + if isinstance(img, np.ndarray): + if save_dir_or_path.is_dir(): + raise ValueError("numpy array input require a save path, not a save dir") + page = engine(img) + save_path = str(save_dir_or_path.joinpath(Path(img).stem + ".txt") + ) if save_dir_or_path.is_dir() else str(save_dir_or_path) + page.write_to_file('word', save_path) + if export_img: + page.save_img(save_path.replace(".txt", ".jpg"), is_vnese=True, ) + +def read_img(img: Union[str, np.ndarray], engine: OcrEngine): + page = engine(img) + return ' '.join([f.text for f in page.llines]) \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/utils/split_docs.py b/cope2n-ai-fi/common/AnyKey_Value/utils/split_docs.py new file mode 100644 index 0000000..1b517d0 --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/utils/split_docs.py @@ -0,0 +1,101 @@ +import os +import glob +import json +from tqdm import tqdm + +def longestCommonSubsequence(text1: str, text2: str) -> int: + # https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis + dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)] + for i, c in enumerate(text1): + for j, d in enumerate(text2): + dp[i + 1][j + 1] = 1 + \ + dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j]) + return dp[-1][-1] + +def write_to_json(file_path, content): + with open(file_path, mode="w", encoding="utf8") as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, "r") as f: + return json.load(f) + +def check_label_exists(array, target_label): + for obj in array: + if obj["label"] == target_label: + return True # Label exists in the array + return False # Label does not exist in the array + +def merged_kvu_outputs(loutputs: list) -> dict: + compiled = [] + for output_model in loutputs: + for field in output_model: + if field['value'] != "" and not check_label_exists(compiled, field['label']): + element = { + 'label': field['label'], + 'value': field['value'], + } + compiled.append(element) + elif field['label'] == 'table' and check_label_exists(compiled, "table"): + for index, obj in enumerate(compiled): + if obj['label'] == 'table' and len(field['value']) > 0: + compiled[index]['value'].append(field['value']) + return compiled + + +def split_docs(doc_data: list, threshold: float=0.6) -> list: + num_pages = len(doc_data) + outputs = [] + kvu_content = [] + doc_data = sorted(doc_data, key=lambda x: int(x['page_number'])) + for data in doc_data: + page_id = int(data['page_number']) + doc_type = data['document_type'] + doc_class = data['document_class'] + fields = data['fields'] + if page_id == 0: + prev_title = doc_type + start_page_id = page_id + prev_class = doc_class + curr_title = doc_type if doc_type != "unknown" else prev_title + curr_class = doc_class if doc_class != "unknown" else "other" + kvu_content.append(fields) + similarity_score = longestCommonSubsequence(curr_title, prev_title) / len(prev_title) + if similarity_score < threshold: + end_page_id = page_id - 1 + outputs.append({ + "doc_type": f"({prev_class}) {prev_title}" if 
prev_class != "other" else prev_title, + "start_page": start_page_id, + "end_page": end_page_id, + "content": merged_kvu_outputs(kvu_content[:-1]) + }) + prev_title = curr_title + prev_class = curr_class + start_page_id = page_id + kvu_content = kvu_content[-1:] + if page_id == num_pages - 1: # end_page + outputs.append({ + "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title, + "start_page": start_page_id, + "end_page": page_id, + "content": merged_kvu_outputs(kvu_content) + }) + elif page_id == num_pages - 1: # end_page + outputs.append({ + "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title, + "start_page": start_page_id, + "end_page": page_id, + "content": merged_kvu_outputs(kvu_content) + }) + return outputs + + +if __name__ == "__main__": + threshold = 0.9 + json_path = "/home/sds/tuanlv/02-KVU/02-KVU_test/visualize/manulife_v2/json_outputs/HS_YCBT_No_IP_HMTD.json" + doc_data = read_json(json_path) + + outputs = split_docs(doc_data, threshold) + + write_to_json(os.path.join(os.path.dirname(json_path), "splited_doc.json"), outputs) \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/utils/utils.py b/cope2n-ai-fi/common/AnyKey_Value/utils/utils.py new file mode 100755 index 0000000..2e857ac --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/utils/utils.py @@ -0,0 +1,548 @@ +import os +import cv2 +import json +import torch +import glob +import re +import numpy as np +from tqdm import tqdm +from pdf2image import convert_from_path +from dicttoxml import dicttoxml +from word_preprocess import vat_standardizer, get_string, ap_standardizer, post_process_for_item +from utils.kvu_dictionary import vat_dictionary, ap_dictionary + + + +def create_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + else: + print("DIR already existed.") + print('Save dir : {}'.format(save_dir)) + +def pdf2image(pdf_dir, save_dir): + pdf_files = glob.glob(f'{pdf_dir}/*.pdf') + print('No. 
pdf files:', len(pdf_files)) + + for file in tqdm(pdf_files): + pages = convert_from_path(file, 500) + for i, page in enumerate(pages): + page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG') + print('Done!!!') + +def xyxy2xywh(bbox): + return [ + float(bbox[0]), + float(bbox[1]), + float(bbox[2]) - float(bbox[0]), + float(bbox[3]) - float(bbox[1]), + ] + +def write_to_json(file_path, content): + with open(file_path, mode='w', encoding='utf8') as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, 'r') as f: + return json.load(f) + +def read_xml(file_path): + with open(file_path, 'r') as xml_file: + return xml_file.read() + +def write_to_xml(file_path, content): + with open(file_path, mode="w", encoding='utf8') as f: + f.write(content) + +def write_to_xml_from_dict(file_path, content): + xml = dicttoxml(content) + xml = content + xml_decode = xml.decode() + + with open(file_path, mode="w") as f: + f.write(xml_decode) + + +def load_ocr_result(ocr_path): + with open(ocr_path, 'r') as f: + lines = f.read().splitlines() + + preds = [] + for line in lines: + preds.append(line.split('\t')) + return preds + +def post_process_basic_ocr(lwords: list) -> list: + pp_lwords = [] + for word in lwords: + pp_lwords.append(word.replace("✪", " ")) + return pp_lwords + +def read_ocr_result_from_txt(file_path: str): + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + + boxes, words = [], [] + for line in lines: + if line == "": + continue + word_info = line.split("\t") + if len(word_info) == 6: + x1, y1, x2, y2, text, _ = word_info + elif len(word_info) == 5: + x1, y1, x2, y2, text = word_info + + x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2)) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + +def get_colormap(): + return { + 'others': (0, 0, 255), # others: red + 'title': (0, 255, 255), # title: yellow + 'key': (255, 0, 0), # key: blue + 'value': (0, 255, 0), # value: green + 'header': (233, 197, 15), # header + 'group': (0, 128, 128), # group + 'relation': (0, 0, 255)# (128, 128, 128), # relation + } + + +def convert_image(image): + exif = image._getexif() + orientation = None + if exif is not None: + orientation = exif.get(0x0112) + + # Convert the PIL image to OpenCV format + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + + # Rotate the image in OpenCV if necessary + if orientation == 3: + image = cv2.rotate(image, cv2.ROTATE_180) + elif orientation == 6: + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + elif orientation == 8: + image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) + else: + image = np.asarray(image) + + if len(image.shape) == 2: + image = np.repeat(image[:, :, np.newaxis], 3, axis=2) + assert len(image.shape) == 3 + + return image, orientation + +def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1): + image, orientation = convert_image(image) + + if orientation is not None and orientation == 6: + width, height, _ = image.shape + else: + height, width, _ = image.shape + + if len(pr_class_words) > 0: + id2label = {k: labels[k] for k in range(len(labels))} + for lb, groups in enumerate(pr_class_words): + if lb == 0: + continue + for group_id, group in enumerate(groups): + for i, word_id in enumerate(group): + x0, y0, x1, y1 = 
int(bbox[word_id][0]*width/1000), int(bbox[word_id][1]*height/1000), int(bbox[word_id][2]*width/1000), int(bbox[word_id][3]*height/1000) + cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness) + + if i == 0: + x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2) + else: + x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2) + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness) + x_center0, y_center0 = x_center1, y_center1 + + if len(pr_relations) > 0: + for pair in pr_relations: + xyxy0 = int(bbox[pair[0]][0]*width/1000), int(bbox[pair[0]][1]*height/1000), int(bbox[pair[0]][2]*width/1000), int(bbox[pair[0]][3]*height/1000) + xyxy1 = int(bbox[pair[1]][0]*width/1000), int(bbox[pair[1]][1]*height/1000), int(bbox[pair[1]][2]*width/1000), int(bbox[pair[1]][3]*height/1000) + + x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2) + x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2) + + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness) + + return image + + +def get_pairs(json: list, rel_from: str, rel_to: str) -> dict: + outputs = {} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] in (rel_from, rel_to): + is_rel[element['class']]['status'] = 1 + is_rel[element['class']]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']] + return outputs + +def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict: + list_keys = list(header_key_pairs.keys()) + relations = {k: [] for k in list_keys} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] == rel_from and element['group_id'] in list_keys: + is_rel[rel_from]['status'] = 1 + is_rel[rel_from]['value'] = element + if element['class'] == rel_to: + is_rel[rel_to]['status'] = 1 + is_rel[rel_to]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id']) + return relations + +def get_key2values_relations(key_value_pairs: dict): + triple_linkings = {} + for value_group_id, key_value_pair in key_value_pairs.items(): + key_group_id = key_value_pair[0] + if key_group_id not in list(triple_linkings.keys()): + triple_linkings[key_group_id] = [] + triple_linkings[key_group_id].append(value_group_id) + return triple_linkings + + +def merged_token_to_wordgroup(class_words: list, lwords, labels) -> dict: + word_groups = {} + id2class = {i: labels[i] for i in range(len(labels))} + for class_id, lwgroups_in_class in enumerate(class_words): + for ltokens_in_wgroup in lwgroups_in_class: + group_id = ltokens_in_wgroup[0] + ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup] + text_string = get_string(ltokens_to_ltexts) + word_groups[group_id] = { + 'group_id': group_id, + 'text': text_string, + 'class': id2class[class_id], + 'tokens': ltokens_in_wgroup + } + return word_groups + +def verify_linking_id(word_groups: dict, linking_id: int) -> int: + if linking_id not in list(word_groups): + for wg_id, _word_group in word_groups.items(): + if linking_id in _word_group['tokens']: + return wg_id + return 
linking_id + +def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list: + outputs = [] + for pair in lrelations: + wg_from = verify_linking_id(word_groups, pair[0]) + wg_to = verify_linking_id(word_groups, pair[1]) + try: + outputs.append([word_groups[wg_from], word_groups[wg_to]]) + except Exception as e: + print('Not valid pair:', wg_from, wg_to) + return outputs + + +def export_kvu_outputs(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + word_groups = merged_token_to_wordgroup(class_words, lwords, labels) + linking_pairs = matched_wordgroup_relations(word_groups, lrelations) + + header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]} + header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]} + key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]} + + # table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + + triplet_pairs = [] + single_pairs = [] + table = [] + # print('key2values_relations', key2values_relations) + for key_group_id, list_value_group_ids in key2values_relations.items(): + if len(list_value_group_ids) == 0: continue + elif len(list_value_group_ids) == 1: + value_group_id = list_value_group_ids[0] + single_pairs.append({word_groups[key_group_id]['text']: { + 'text': word_groups[value_group_id]['text'], + 'id': value_group_id, + 'class': "value" + }}) + else: + item = [] + for value_group_id in list_value_group_ids: + if value_group_id not in header_value.keys(): + header_name_for_value = "non-header" + else: + header_group_id = header_value[value_group_id][0] + header_name_for_value = word_groups[header_group_id]['text'] + item.append({ + 'text': word_groups[value_group_id]['text'], + 'header': header_name_for_value, + 'id': value_group_id, + 'class': 'value' + }) + if key_group_id not in list(header_key.keys()): + triplet_pairs.append({ + word_groups[key_group_id]['text']: item + }) + else: + header_group_id = header_key[key_group_id][0] + header_name_for_key = word_groups[header_group_id]['text'] + item.append({ + 'text': word_groups[key_group_id]['text'], + 'header': header_name_for_key, + 'id': key_group_id, + 'class': 'key' + }) + table.append({key_group_id: item}) + + if len(table) > 0: + table = sorted(table, key=lambda x: list(x.keys())[0]) + table = [v for item in table for k, v in item.items()] + + outputs = {} + outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id']))) + outputs['triplet'] = triplet_pairs + outputs['table'] = table + + file_path = os.path.join(os.path.dirname(file_path), 'kvu_results', os.path.basename(file_path)) + write_to_json(file_path, outputs) + return outputs + +# For FI-VAT project + +def get_vat_table_information(outputs): + table = [] + for single_item in outputs['table']: + item = {k: [] for k in list(vat_dictionary(header=True).keys())} + for cell in single_item: + header_name, score, proceessed_text = vat_standardizer(cell['header'], threshold=0.75, header=True) + if header_name in list(item.keys()): + # item[header_name] = value['text'] + item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': 
proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + + for header_name, value in item.items(): + if len(value) == 0: + if header_name in ("Số lượng", "Doanh số mua chưa có thuế"): + item[header_name] = '0' + else: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + item = post_process_for_item(item) + + if item["Mặt hàng"] == None: + continue + table.append(item) + return table + +def get_vat_information(outputs): + # VAT Information + single_pairs = {k: [] for k in list(vat_dictionary(header=False).keys())} + for pair in outputs['single']: + for raw_key_name, value in pair.items(): + key_name, score, proceessed_text = vat_standardizer(raw_key_name, threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'], + }) + + for triplet in outputs['triplet']: + for key, value_list in triplet.items(): + if len(value_list) == 1: + key_name, score, proceessed_text = vat_standardizer(key, threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value_list[0]['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value_list[0]['id'] + }) + + for pair in value_list: + key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': pair['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'] + }) + + for table_row in outputs['table']: + for pair in table_row: + key_name, score, proceessed_text = vat_standardizer(pair['header'], threshold=0.8, header=False) + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': pair['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'] + }) + + return single_pairs + + +def post_process_vat_information(single_pairs): + vat_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if key_name in ("Ngày, tháng, năm lập hóa đơn"): + if len(list_potential_value) == 1: + vat_outputs[key_name] = list_potential_value[0]['content'] + else: + date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'} + for value in list_potential_value: + date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content']) + vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}" + else: + if len(list_potential_value) == 0: continue + if key_name in ("Mã số thuế người bán"): + selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code + # tax_code_raw = selected_value['content'].replace(' ', '') + tax_code_raw = selected_value['content'] + if len(tax_code_raw.replace(' ', '')) not in (10, 13): # to remove the first number dupicated + tax_code_raw = tax_code_raw.split(' ') + tax_code_raw = 
sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0] + vat_outputs[key_name] = tax_code_raw.replace(' ', '') + + else: + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + vat_outputs[key_name] = selected_value['content'] + return vat_outputs + + +def export_kvu_for_VAT_invoice(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + vat_outputs = {} + outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels) + + # List of items in table + table = get_vat_table_information(outputs) + + # VAT Information + single_pairs = get_vat_information(outputs) + vat_outputs = post_process_vat_information(single_pairs) + + # Combine VAT information and table + vat_outputs['table'] = table + + write_to_json(file_path, vat_outputs) + return vat_outputs + + +# For SBT project + +def get_ap_table_information(outputs): + table = [] + for single_item in outputs['table']: + item = {k: [] for k in list(ap_dictionary(header=True).keys())} + for cell in single_item: + header_name, score, proceessed_text = ap_standardizer(cell['header'], threshold=0.8, header=True) + # print(f"{key} ==> {proceessed_text} ==> {header_name} : {score} - {value['text']}") + if header_name in list(item.keys()): + item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + table.append(item) + return table + +def get_ap_triplet_information(outputs): + triplet_pairs = [] + for single_item in outputs['triplet']: + item = {k: [] for k in list(ap_dictionary(header=True).keys())} + is_item_valid = 0 + for key_name, list_value in single_item.items(): + for value in list_value: + if value['header'] == "non-header": + continue + header_name, score, proceessed_text = ap_standardizer(value['header'], threshold=0.8, header=True) + if header_name in list(item.keys()): + is_item_valid = 1 + item[header_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + if is_item_valid == 1: + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + item['productname'] = key_name + # triplet_pairs.append({key_name: new_item}) + triplet_pairs.append(item) + return triplet_pairs + + +def get_ap_information(outputs): + single_pairs = {k: [] for k in list(ap_dictionary(header=False).keys())} + for pair in outputs['single']: + for key_name, value in pair.items(): + key_name, score, proceessed_text = ap_standardizer(key_name, threshold=0.8, header=False) + # print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + ap_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if len(list_potential_value) == 0: continue + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + ap_outputs[key_name] = selected_value['content'] + + return 
ap_outputs + +def export_kvu_for_SDSAP(file_path, lwords, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + outputs = export_kvu_outputs(file_path, lwords, class_words, lrelations, labels) + # List of items in table + table = get_ap_table_information(outputs) + triplet_pairs = get_ap_triplet_information(outputs) + table = table + triplet_pairs + + ap_outputs = get_ap_information(outputs) + + ap_outputs['table'] = table + # ap_outputs['triplet'] = triplet_pairs + + write_to_json(file_path, ap_outputs) \ No newline at end of file diff --git a/cope2n-ai-fi/common/AnyKey_Value/word_preprocess.py b/cope2n-ai-fi/common/AnyKey_Value/word_preprocess.py new file mode 100755 index 0000000..19273cf --- /dev/null +++ b/cope2n-ai-fi/common/AnyKey_Value/word_preprocess.py @@ -0,0 +1,224 @@ +import nltk +import re +import string +import copy +from utils.kvu_dictionary import vat_dictionary, ap_dictionary, DKVU2XML +nltk.download('words') +words = set(nltk.corpus.words.words()) + +s1 = u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ' +s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy' + +# def clean_text(text): +# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text) + + +def get_string(lwords: list): + unique_list = [] + for item in lwords: + if item.isdigit() and len(item) == 1: + unique_list.append(item) + elif item not in unique_list: + unique_list.append(item) + return ' '.join(unique_list) + +def remove_english_words(text): + _word = [w.lower() for w in nltk.wordpunct_tokenize(text) if w.lower() not in words] + return ' '.join(_word) + +def remove_punctuation(text): + return text.translate(str.maketrans(" ", " ", string.punctuation)) + +def remove_accents(input_str, s0, s1): + s = '' + # print input_str.encode('utf-8') + for c in input_str: + if c in s1: + s += s0[s1.index(c)] + else: + s += c + return s + +def remove_spaces(text): + return text.replace(' ', '') + +def preprocessing(text: str): + # text = remove_english_words(text) if table else text + text = remove_punctuation(text) + text = remove_accents(text, s0, s1) + text = remove_spaces(text) + return text.lower() + + +def vat_standardize_outputs(vat_outputs: dict) -> dict: + outputs = {} + for key, value in vat_outputs.items(): + if key != "table": + outputs[DKVU2XML[key]] = value + else: + list_items = [] + for item in value: + list_items.append({ + DKVU2XML[item_key]: item_value for item_key, item_value in item.items() + }) + outputs['table'] = list_items + return outputs + + + +def vat_standardizer(text: str, threshold: float, header: bool): + dictionary = vat_dictionary(header) + processed_text = preprocessing(text) + + for candidates in [('ngayday', 'ngaydate', 'ngay', 'day'), ('thangmonth', 'thang', 'month'), ('namyear', 'nam', 'year')]: + if any([processed_text in txt for txt in candidates]): + processed_text = candidates[-1] + return "Ngày, tháng, năm lập hóa đơn", 5, processed_text + + _dictionary = copy.deepcopy(dictionary) + if not header: + exact_dictionary = { + 'Số hóa đơn': ['sono', 'so'], + 'Mã số thuế người bán': ['mst'], + 'Tên người bán': ['kyboi'], + 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate'] + } + for k, v in exact_dictionary.items(): + _dictionary[k] = dictionary[k] + exact_dictionary[k] + + for k, v in dictionary.items(): + # if k in ("Ngày, tháng, năm lập hóa đơn"): + # continue + # Prioritize match 
completely + if k in ('Tên người bán') and processed_text == "kyboi": + return k, 8, processed_text + + if any([processed_text == key for key in _dictionary[k]]): + return k, 10, processed_text + + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + if k in ("Ngày, tháng, năm lập hóa đơn"): + continue + + scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score > threshold else text, score, processed_text + +def ap_standardizer(text: str, threshold: float, header: bool): + dictionary = ap_dictionary(header) + processed_text = preprocessing(text) + + # Prioritize match completely + _dictionary = copy.deepcopy(dictionary) + if not header: + _dictionary['serial_number'] = dictionary['serial_number'] + ['sn'] + _dictionary['imei_number'] = dictionary['imei_number'] + ['imel'] + else: + _dictionary['modelnumber'] = dictionary['modelnumber'] + ['sku', 'sn', 'imei'] + _dictionary['qty'] = dictionary['qty'] + ['qty'] + + for k, v in dictionary.items(): + if any([processed_text == key for key in _dictionary[k]]): + return k, 10, processed_text + + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score >= threshold else text, score, processed_text + + +def convert_format_number(s: str) -> float: + s = s.replace(' ', '').replace('O', '0').replace('o', '0') + if s.endswith(",00") or s.endswith(".00"): + s = s[:-3] + if all([delimiter in s for delimiter in [',', '.']]): + s = s.replace('.', '').split(',') + remain_value = s[1].split('0')[0] + return int(s[0]) + int(remain_value) * 1 / (10**len(remain_value)) + else: + s = s.replace(',', '').replace('.', '') + return int(s) + + +def post_process_for_item(item: dict) -> dict: + check_keys = ['Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế'] + mis_key = [] + for key in check_keys: + if item[key] in (None, '0'): + mis_key.append(key) + if len(mis_key) == 1: + try: + if mis_key[0] == check_keys[0] and convert_format_number(item[check_keys[1]]) != 0: + item[mis_key[0]] = round(convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[1]])).__str__() + elif mis_key[0] == check_keys[1] and convert_format_number(item[check_keys[0]]) != 0: + item[mis_key[0]] = (convert_format_number(item[check_keys[2]]) / convert_format_number(item[check_keys[0]])).__str__() + elif mis_key[0] == check_keys[2]: + item[mis_key[0]] = (convert_format_number(item[check_keys[0]]) * convert_format_number(item[check_keys[1]])).__str__() + except Exception as e: + print("Cannot post process this item with error:", e) + return item + + +def longestCommonSubsequence(text1: str, text2: str) -> int: + # https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis + dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)] + for i, c in enumerate(text1): + for j, d in enumerate(text2): + dp[i + 1][j + 1] = 1 + \ + dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j]) + return dp[-1][-1] + + +def longest_common_subsequence_with_idx(X, Y): + """ + This implementation uses dynamic programming to calculate the length of the LCS, and uses a path array to keep track of the characters in the LCS. 
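For example, longest_common_subsequence_with_idx("abcde", "ace") evaluates to (3, 0, 5): the LCS "ace" has length 3 and is contained in X[0:5].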
+ The longest_common_subsequence function takes two strings as input, and returns a tuple with three values: + the length of the LCS, + the index of the first character of the LCS in the first string, + and the index of the last character of the LCS in the first string. + """ + m, n = len(X), len(Y) + L = [[0 for i in range(n + 1)] for j in range(m + 1)] + + # Following steps build L[m+1][n+1] in bottom up fashion. Note + # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1] + right_idx = 0 + max_lcs = 0 + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0: + L[i][j] = 0 + elif X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + if L[i][j] > max_lcs: + max_lcs = L[i][j] + right_idx = i + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + + # Create a string variable to store the lcs string + lcs = L[i][j] + # Start from the right-most-bottom-most corner and + # one by one store characters in lcs[] + i = m + j = n + # right_idx = 0 + while i > 0 and j > 0: + # If current character in X[] and Y are same, then + # current character is part of LCS + if X[i - 1] == Y[j - 1]: + + i -= 1 + j -= 1 + # If not same, then find the larger of two and + # go in the direction of larger value + elif L[i - 1][j] > L[i][j - 1]: + # right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs + i -= 1 + else: + j -= 1 + return lcs, i, max(i + lcs, right_idx) diff --git a/cope2n-ai-fi/common/crop_location.py b/cope2n-ai-fi/common/crop_location.py new file mode 100755 index 0000000..818803c --- /dev/null +++ b/cope2n-ai-fi/common/crop_location.py @@ -0,0 +1,172 @@ +from mmdet.apis import inference_detector, init_detector +import cv2 +import numpy as np +import urllib + +def get_center(box): + xmin, ymin, xmax, ymax = box + x_center = int((xmin + xmax) / 2) + y_center = int((ymin + ymax) / 2) + return [x_center, y_center] + + +def cal_euclidean_dist(p1, p2): + return np.linalg.norm(p1 - p2) + + +def bbox_to_four_poinst(bbox): + """convert one bouding box to 4 corner poinst + + Args: + bbox (_type_): _description_ + """ + xmin, ymin, xmax, ymax = bbox + + poinst = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + return poinst + + +def find_closest_point(src_point, point_list): + """ + + Args: + point (list): point format xy + point_list (list[list]): list of point xy + """ + + point_list = np.array(point_list) + dist_list = np.array( + cal_euclidean_dist(src_point, target_point) for target_point in point_list + ) + + index_closest_point = np.argmin(dist_list) + return index_closest_point + + +def crop_align_card(img_src, corner_box_list): + """Dewarp image based on four courners + + Args: + corner_list (list): four points of corners + """ + img = img_src.copy() + if isinstance(corner_box_list[0], list): + poinst = [get_center(box) for box in corner_box_list] + else: + # print(corner_box_list) + xmin, ymin, xmax, ymax = corner_box_list + poinst = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + + return dewarp(img, poinst) + + +def dewarp(image, poinst): + if isinstance(poinst, list): + poinst = np.array(poinst, dtype="float32") + (tl, tr, br, bl) = poinst + widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) + widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) + maxWidth = max(int(widthA), int(widthB)) + heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) + heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) + maxHeight = max(int(heightA), int(heightB)) + + 
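# Destination quadrilateral for the perspective transform: the four source corners (tl, tr, br, bl) are mapped onto an axis-aligned maxWidth x maxHeight rectangle.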
dst = np.array( + [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]], + dtype="float32", + ) + + M = cv2.getPerspectiveTransform(poinst, dst) + warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight)) + return warped + + +class MdetPredictor: + def __init__(self, config: str, checkpoint: str, device: str = "cpu"): + self.model = init_detector(config, checkpoint, device=device) + self.class_names = self.model.CLASSES + + def infer(self, image, threshold=0.2): + bbox_result = inference_detector(self.model, image) + + bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) + + res_bboxes = [] + res_labels = [] + for idx, box in enumerate(bboxes): + score = box[-1] + if score >= threshold: + label = labels[idx] + res_bboxes.append(box.tolist()[:4]) + res_labels.append(self.class_names[label]) + + return res_bboxes, res_labels + + +class ImageTransformer: + def __init__(self, config: str, checkpoint: str, device: str = "cpu"): + self.corner_detect_model = MdetPredictor(config, checkpoint, device) + + def __call__(self, image, threshold=0.2): + """ + + Args: + image (np.ndarray): BGR image + """ + corner_result = self.corner_detect_model.infer(image) + corners_dict = self.__extract_corners(corner_result) + card_image = self.__crop_image_based_on_corners(image, corners_dict) + + return card_image + + def __extract_corners(self, corner_result): + + bboxes, labels = corner_result + # convert bbox to int + bboxes = [[int(x) for x in box] for box in bboxes] + output = {k: bboxes[labels.index(k)] for k in labels} + # print(output) + return output + + def __crop_image_based_on_corners(self, image, corners_dict): + """ + + Args: + corners_dict (_type_): _description_ + """ + if "card" in corners_dict.keys(): + if len(corners_dict.keys()) == 5: + points = [ + corners_dict["top_left"], + corners_dict["top_right"], + corners_dict["bottom_right"], + corners_dict["bottom_left"], + ] + else: + points = corners_dict["card"] + card_image = crop_align_card(image, points) + else: + card_image = None + + return card_image + + +def crop_location(image_url): + transform_module = ImageTransformer( + config="./models/Kie_AHung/yolox_s_8x8_300e_idcard5_coco.py", + checkpoint="./models/Kie_AHung/best_bbox_mAP_epoch_100.pth", + device="cuda:0", + ) + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + card_image = transform_module(img) + if card_image is not None: + return card_image + else: + return img diff --git a/cope2n-ai-fi/common/dates_gplx.json b/cope2n-ai-fi/common/dates_gplx.json new file mode 100755 index 0000000..b39df86 --- /dev/null +++ b/cope2n-ai-fi/common/dates_gplx.json @@ -0,0 +1,622 @@ +{ + "20221027_154840.json": { + "label": "ngày /date 05 tháng /month 04 năm/year 2016", + "pred": "ngày date 05 tháng month 04 năm/year 2016" + }, + "7ba0f6b2f2ff34a16dee18.json": { + "label": "ngày /date 01 tháng /month 04 năm /year 2022", + "pred": "ngày date 01 tháng Amount 04 năm year 2022" + }, + "20221027_155646.json": { + "label": "ngày /date 01 tháng /month 04 năm/year✪2022", + "pred": "ngày date or tháng month 04 năm/year2022" + }, + "799c679b63d6a588fcc717.json": { + "label": "ngày /date 30 tháng /month 07 năm/year 2015", + "pred": "ngày lone 30 thể 2 work 07 ndmyvar 2015" + }, + "150094748_1728953773944481_6269983404281027305_n.json": { + "label": "ngày/da e 15✪tháng # #th 04 năm /year 
2020", + "pred": "ngày/da - 15tháng % with 04 năm (year 2020" + }, + "20221027_155711.json": { + "label": "ngày /date 16 tháng /month 06 năm /year 2020", + "pred": "ngày date 16 tháng month 06 năm 'year 2020" + }, + "20221027_155638.json": { + "label": "ngày /date 05 tháng /month 04 năm/year✪2016", + "pred": "ngày /date 05 tháng month 04 năm/year2016" + }, + "20221027_155754.json": { + "label": "ngày/date 19 tháng /month 03 năm/year 2018", + "pred": "ngày/date 19 tháng month 03 năm/year 2018" + }, + "201417393_4052045311588809_501501345369021923_n.json": { + "label": "ngày /date 27 tháng /month 07 năm/year 2020", + "pred": "ngày date 27 tháng month 07 năm/year 2020" + }, + "c50de8e0e2ad24f37dbc28.json": { + "label": "ngày /date 24 tháng /month 05 năm/year 2016", + "pred": "ngày date 24 tháng month 05 năm/year 2016" + }, + "178033599_360318769019204_7688552975060615249_n.json": { + "label": "ngày /date 13 tháng month 08 năm /year 2020", + "pred": "ngày date 73 tháng month 08 năm year 2020" + }, + "20221027_154755.json": { + "label": "ngày /date 22 tháng /month 04 năm/year 2022", + "pred": "ngày date 22 tháng month 04 năm/year 2022" + }, + "20221027_154701.json": { + "label": "ngày/date 27 tháng /month 10 năm/year 2014", + "pred": "ngày/date 27 tháng month 10 năm/year 2014" + }, + "20221027_154617.json": { + "label": "ngày/date 10 tháng /month 03 năm/year 2022", + "pred": "ngày/date 10 than Wmonth 03 năm/year 2022" + }, + "20221027_155429.json": { + "label": "ngày /date 29 tháng /month 10 năm/year 2020", + "pred": "ngày date 29 tháng month 10 năm/year 2020" + }, + "38949066_1826046370807755_4672633932229902336_n.json": { + "label": "ngày /date 03 tháng /month 07 năm /year 2017", + "pred": "ngày date 03 tháng month 0Z năm year 2017" + }, + "174353602_3093780914182491_6316781286680210887_n.json": { + "label": "ngày /date 09 tháng /month 09 năm/year 2019", + "pred": "ngày date 09 tháng month 09 năm/voce 2019" + }, + "135575662_779633739578177_65454671165731184_n.json": { + "label": "ngày /date 02 tháng /month 10 năm /year 2016", + "pred": "ngày 'date 07 tháng month 10 năm Sear 2016" + }, + "198291731_4210067705720978_7154894655460708366_n.json": { + "label": "ngày /date 05 tháng month 05 năm/year 2014", + "pred": "ngày date 05 tháng month 05 ndmivar 2014" + }, + "20221027_155325.json": { + "label": "ngày /date 09 tháng /month 07 năm /year 2019", + "pred": "ngày date 09 tháng month 07 năm 'year 2019" + }, + "20221027_155526.json": { + "label": "ngày/date 14 tháng /month 01 năm/year✪2019", + "pred": "ngày/date 14 tháng month 01 năm/year2019" + }, + "20221027_155759.json": { + "label": "ngày/date 24 tháng /month 05 năm/year 2016", + "pred": "ngày/date 24 tháng month 05 năm/year 2016" + }, + "f40789388d754b2b126416.json": { + "label": "ngày /date 10 tháng/month 09 năm /year 2020", + "pred": "ngày /date 10thing month 09 năm year 2020" + }, + "20221027_155453.json": { + "label": "ngày /date 27 tháng /month 10 năm/year 2014", + "pred": "ngày date 27 tháng month 10 năm/year 2014" + }, + "88bb0caa03e7c5b99cf64.json": { + "label": "ngày /date 16 tháng /month 06 năm /year 2020", + "pred": "ngày dg h 6AR ag M ath 06 năm hear 2020" + }, + "20221027_154408.json": { + "label": "ngày /date 28 tháng /month 05 năm /year 2019", + "pred": "ngày date 28 tháng month 05 năm year 2019" + }, + "200064855_1976213642526776_4676665588314498194_n.json": { + "label": "ngày /date 04 tháng /month 01 năm /year 2018", + "pred": "ngày date 04 tháng /month 01 năm year 2018" + }, + "61c259385775912bc86410.json": { + "label": "ngày 
/date 20 tháng/month 09 năm/year 2017", + "pred": "ngày /de l 20 th and 4 onth 09 năm/year 2017" + }, + "20221027_155630.json": { + "label": "ngày /date 10 tháng /month 09 năm /year 2020", + "pred": "ngày date 10 tháng month 09 năm year 2020" + }, + "20221027_155342.json": { + "label": "ngày/date 19 tháng /month 03 năm/year 2018", + "pred": "ngày/date 19 tháng month 03 năm/year 2018" + }, + "165824171_1804443136392377_8891768953420682785_n.json": { + "label": "ngày /date 03 tháng month 04 năm /year 2018", + "pred": "ngày date 03 tháng month 04 năm year 2018" + }, + "107005164_1003706713393058_3039921363490738151_n.json": { + "label": "ngày /date 1 8 tháng month 11 năm /year 2019", + "pred": "ngày 'date I 8 tháng month 11 năm vear 2019" + }, + "20221027_155658.json": { + "label": "ngày/date 14 tháng /month 01 năm/year 2019", + "pred": "ngày/date 14 tháng month 01 năm/year 2019" + }, + "20221027_154654.json": { + "label": "ngày /date 08 tháng /month 12 năm/year 2015", + "pred": "ngày date 08 tháng month 12 năm/year 2015" + }, + "f2badabfd1f217ac4ee327.json": { + "label": "ngày /date 19 tháng /month 03 năm/year 2018", + "pred": "ngày /date 19 their almonth 03 năm/year 2018" + }, + "20221027_155501.json": { + "label": "ngày/date 11 tháng /month 06 năm/year 2014", + "pred": "ngày/date 1 l tháng month 06 năm/year. 2014" + }, + "74179950_806804509749947_7322741604127604736_n.json": { + "label": "ngày /date 11 tháng /month 06 năm /year 2019", + "pred": "ngày date 11 tháng month 06 năm \\/gar 2019" + }, + "197892118_1197184050744529_3186157591590303981_n.json": { + "label": "ngày /date 15 tháng /month 05 năm year 2015", + "pred": "ngày date IS tháng month 05 năm year 2015" + }, + "8265ab8ba0c666983fd722.json": { + "label": "ngày /date 08 tháng /month 12 năm /year 2015", + "pred": "ngày /date 08 tháng month 12 năm year 2015" + }, + "185421881_360318465685901_6968676669094190049_n.json": { + "label": "ngày /date 13 tháng month 08 năm /year 2020", + "pred": "ngày date 73 tháng month 08 năm year 2020" + }, + "20221027_155142.json": { + "label": "ngày/date 30 tháng/month 07 năm/year 2015", + "pred": "ngày/date 30 tháng/month 07 năm/year 2015" + }, + "39441751_326131871285503_8401816317220356096_n.json": { + "label": "ngày /date 03 tháng /month 08 năm/year✪2017", + "pred": "ngày 'date 03 tháng month 08 năm✪year✪2017" + }, + "20221027_154912.json": { + "label": "ngày/date 30 tháng /month 07 năm/year 2015", + "pred": "ngày/date 30 tháng month 07 năm/year 2015" + }, + "199235658_1994072707411372_221206969024509405_n.json": { + "label": "ngày /date 05 tháng month 06 năm year 2020", + "pred": "ngày date 05 tháng month 06 năm year 2020" + }, + "168434217_1698404580364474_5439600436729489777_n.json": { + "label": "ngày /date 05 tháng /month 11 năm/year 2020", + "pred": "ngày date 05 tháng month 11 ndmhear 2020" + }, + "20221027_154805.json": { + "label": "ngày /date 14 tháng /month 01 năm/year 2019", + "pred": "ngày date 14 tháng month 01 năm/year 2019" + }, + "e043761b7256b408ed4712.json": { + "label": "ngày /date 27 tháng /month 10 năm/year 2014", + "pred": "ngày date 27 tháng month 10 năm/year 2014" + }, + "20221027_155401.json": { + "label": "ngày /date 29 tháng /month 10 năm/year✪2013", + "pred": "ngày date 29 tháng /uponth 10 năm/year2013" + }, + "20221027_155316.json": { + "label": "ngày/date 20 tháng /month 09 năm/year 2017", + "pred": "ngày/date 20 tháng month 09 năm/year 2017" + }, + "193926583_1626674417674644_309549447428666454_n.json": { + "label": "ngày /date 29 tháng /month 07 năm/year 2020", + 
"pred": "ngày (date 29 tháng /month 07 năm✪year 2020" + }, + "20221027_154823.json": { + "label": "ngày /date 01 tháng /month 04 năm /year 2022", + "pred": "ngày date or tháng month 04 năm year 2022" + }, + "138942425_242694630705291_5683978028617807264_n.json": { + "label": "ngày /date 30 tháng/month 01 năm/year✪2013", + "pred": "ngày date 30 tháng/month 01 năm/year2013" + }, + "20221027_155334.json": { + "label": "ngày /date 24 tháng /month 05 năm/year 2016", + "pred": "ngày date 24 tháng month 05 năm/year 2016" + }, + "187421917_1668044206721318_779901369147309116_n.json": { + "label": "ngày date 04 # # 05 # 2022", + "pred": "ngày date 4 them Please 05 advise 2021" + }, + "20221027_154716.json": { + "label": "ngày /date 11 tháng /month 06 năm/year 2014", + "pred": "ngày date 11 tháng month 06 năm/year 2014" + }, + "5b1116c61d8bdbd5829a23.json": { + "label": "ngày /date 29 tháng /month 10 năm/year 2020", + "pred": "ngày date 29 tháng /month 10 năm/year 2020" + }, + "20221027_154850.json": { + "label": "ngày /date 10 tháng /month 09 năm /year 2020", + "pred": "ngày date 10 tháng month 09 năm year 2020" + }, + "1489864e88034e5d17122.json": { + "label": "ngày /date 08 tháng /month 12 năm/year 2015", + "pred": "ngày date 08 hàng month 12 năm/year 2015" + }, + "199990418_1443812262638998_8173300652488821384_n.json": { + "label": "ngày/date 10✪tháng /month 03 năm/year 2021", + "pred": "ngày/date 10.50m Noun th 03 ndm/year 2021" + }, + "139073668_833869180522062_7998364448555134241_n.json": { + "label": "ngày/date 26 tháng /month 07 năm/year 2017", + "pred": "ngày/date 26 tháng /month 07 năm/year2 2017" + }, + "20221027_154528.json": { + "label": "ngày date 19 tháng /month 03 năm/year 2018", + "pred": "ngày date 19 tháng month 03 năm/year 2018" + }, + "20221027_154423.json": { + "label": "ngày/date 20 tháng /month 09 năm/year 2017", + "pred": "ngày/date 20 tháng month 09 năm/yew 2017" + }, + "20221027_154722.json": { + "label": "ngày /date 11 tháng /month 06 năm/year 2014", + "pred": "ngày date 1 I tháng month 06 năm/year 2014" + }, + "144003628_4199201026774245_6202264670231940239_n.json": { + "label": "ngày date 07 tháng month 03 ######## ### #", + "pred": "ngày date0 thôn Normmm 05 adortear 7" + }, + "20221027_154434.json": { + "label": "ngày /date 09 tháng /month 07 năm /year 2019", + "pred": "ngày date 09 tháng month 07 năm 'year 2019" + }, + "20221027_154630.json": { + "label": "ngày/date 29 tháng /month 10 năm/year 2020", + "pred": "ngày/date 29 tháng month 10 năm/year 2020" + }, + "20221027_155552.json": { + "label": "ngày /date 10 tháng /month 09 năm /year 2020", + "pred": "ngày date 10 tháng month 09 năm 'year 2020" + }, + "131114177_1027132027767399_411142190418396877_n.json": { + "label": "ngày /date 17 tháng /month 01 năm/year✪2018", + "pred": "ngày date 17 tháng /month 01 năm/year2018" + }, + "164703297_455738728964651_5260332814645460915_n.json": { + "label": "ngày date 03 tháng month 11 năm/year 2020", + "pred": "ngày date 03 tháng month I I năm/year 2020" + }, + "20221027_155926.json": { + "label": "ngày/date 20 tháng /month 09 năm /year 2017", + "pred": "ngày/date 20 tháng month 09 năm year 2017" + }, + "20221027_154832.json": { + "label": "ngày /date 05 tháng /month 04 năm/year 2016", + "pred": "ngày date 05 tháng month 04 năm/year 2016 6" + }, + "dd09b9b6b2fb74a52dea24.json": { + "label": "ngày /date 16 tháng /month 06 năm /year 2020", + "pred": "123) d 2010 Way Yount 06 năm year 2020" + }, + "20221027_154646.json": { + "label": "ngày /date 08 tháng /month 12 năm/year 2015", 
+ "pred": "ngày date 08 2 tháng month 12 năm/year2 2015" + }, + "180534342_1213803569050037_4381710158357942629_n.json": { + "label": "ngày /date 20 tháng /month 10 năm/year 21", + "pred": "ngày date 20 tháng month 10 năm/year 2" + }, + "20221027_155443.json": { + "label": "ngày /date 08 tháng /month 12 năm/year 2015", + "pred": "ngày date 08 tháng month 12 năm/year 2015" + }, + "25a9717c7b31bd6fe42029.json": { + "label": "ngày /date 09 tháng /month 07 năm /year 2019", + "pred": "ngày date 09 tháng month 07 năm hear 2019" + }, + "a0a8281f2652e00cb9435.json": { + "label": "ngày/date 10 tháng /month 03 năm/year 2022", + "pred": "ngày/date 10 that abroath 03 năm/year 2022" + }, + "48793dfd37b0f1eea8a130.json": { + "label": "tháng/month 09 năm/year\n ngày/date 20 tháng/month\n 163", + "pred": "ngày/date 20 chăn knowin năm/ya\n 163" + }, + "20221027_154730.json": { + "label": "ngày /date 16 tháng /month 06 năm /year 2020", + "pred": "ngày date to tháng month 06 năm year 2020" + }, + "174102242_893537741194123_1381062036549019974_n.json": { + "label": "ngày /date 11 tháng /month 11 năm/year 2019", + "pred": "ngày date 17 tháng Thuonth 11 năm/year 2019" + }, + "20221027_154541.json": { + "label": "ngày /date 29 tháng /month 10 năm/year✪2013", + "pred": "ngày date 29 tháng month 10 năm/year2013" + }, + "20221027_155939.json": { + "label": "ngày /date 28 tháng /month 05 năm /year 2019", + "pred": "ngày date 28 tháng month 05 năm 'year 2019" + }, + "20221027_154452.json": { + "label": "ngày /date 24 tháng /month 05 năm/year 2016", + "pred": "ngày date 24 tháng month 05 năm/year 2016" + }, + "104353445_990772771353119_6131582365614146594_n.json": { + "label": "ngày date 18 tháng /month 03 năm /year 2019", + "pred": "ngày date If tháng month 03 năm year 2019" + }, + "20221027_155418.json": { + "label": "ngày/date 10 tháng/month 03 năm/year 2022", + "pred": "ngày/date 10than dmonth 03 năm/year 2022" + }, + "20221027_155916.json": { + "label": "ngày /date 09 tháng /month 07 năm /year 2019", + "pred": "ngày date 09 tháng month 07 năm 'year 2019" + }, + "195887607_545276056640128_7265052621888807786_n.json": { + "label": "ngày /date 12 tháng /month 01 năm /year 2017", + "pred": "ngày /dote 12 tháng month 01 năm hear 201 7" + }, + "168303942_358282189193092_4968412916165104911_n.json": { + "label": "ngày/date 19 tháng month 03 năm year 2015", + "pred": "ngày/date 19 tháng month 03 năm your 1019" + }, + "6a70d15bd51613484a0713.json": { + "label": "ngày /date 22 tháng /month 04 năm /year 2022", + "pred": "ngày date n háu Z with 04 năm year 2022" + }, + "20221027_155302.json": { + "label": "ngày /date 28 tháng /month 05 năm /year 2019", + "pred": "ngày date 28 tháng month 05 năm 'year 2019" + }, + "20221027_154815.json": { + "label": "ngày /date 01 tháng /month 04 năm/year 2022", + "pred": "ngày date or tháng month 04 năm/year 2022" + }, + "c87b81298e64483a11757.json": { + "label": "ngày /date 19 tháng /month 03 năm/year 201", + "pred": "ngày date 19 thớ almonth 03 năm/year 201" + }, + "20221027_155620.json": { + "label": "ngày/date 30 tháng /month 07 năm/year 2015", + "pred": "ngày/date 30 tháng month 07 năm/year 2015" + }, + "20221027_155511.json": { + "label": "ngày /date 16 tháng /month 06 năm /year 2020", + "pred": "ngày date to tháng month 06 năm 'year 2020" + }, + "745c5d6b52269478cd379.json": { + "label": "ngày /date 09 tháng /month 07 năm /year 2019", + "pred": "ngày date 09 thân g worth 07 năm hear 2019" + }, + "9329511f5552930cca4315.json": { + "label": "ngày/date 11 tháng /month 06 năm/year 2014", 
+ "pred": "ngày/date 11 tháng month 06 năm/year 2014" + }, + "158882925_262850065433814_5526034984745996835_n.json": { + "label": "ngày /date 16 tháng month 12 năm/year 2015", + "pred": "ngày date 16 tháng month 12 năm/year 2015" + }, + "140687263_3755683421155059_7637736837539526203_n.json": { + "label": "năm/year 2017\n ngày /date # # # # 08 năm/year", + "pred": "năm 2017\n ngày 'date 422 ins 1.1 ma 08 năm" + }, + "20221027_155717.json": { + "label": "ngày/date 11 tháng /month 06 năm/year 2014", + "pred": "ngày/date 11 tháng month 06 năm/year 2014" + }, + "148919455_2877898732481724_2579276238538203411_n.json": { + "label": "ngày /date 05 tháng /month 09 năm/year✪2018", + "pred": "ngày date 05 thán g /month 09 năm/year2018" + }, + "c8b0dc9cd8d11e8f47c014.json": { + "label": "ngày /date 05 tháng /month 04 năm/year 2016", + "pred": "ngày /date 0.5 tháng g/month 04 năm/year 2016" + }, + "175913976_2827333254262221_2873818403698028020_n.json": { + "label": "ngày date 05 tháng month 01 năm year 2018", + "pred": "ngày dan us thông month 01 năm year 2018" + }, + "20221027_155029.json": { + "label": "ngày/date 30 tháng/month\n 07 năm/year 2015", + "pred": "ngày/date\n năm/year 2015" + }, + "196165776_1160925321042008_58817602967276351_n.json": { + "label": "ngày date 01 tháng month 09 năm year 2020", + "pred": "ngày date or tháng month 09 năm year 2020" + }, + "162820484_451115505943597_8326385834717580925_n.json": { + "label": "ngày /date 27 tháng /month 08 thường 2014", + "pred": "ngày/ date 27 tháng furonth 08 năm/year 2014" + }, + "41594446_1316256461838413_661251515624718336_n.json": { + "label": "ngày /date 21 tháng /month 06 năm/year✪2017", + "pred": "ngày /date 27 tháng month 06 năm/year201 7" + }, + "20221027_155728.json": { + "label": "ngày /date 08 tháng /month 12 năm/year 2015", + "pred": "ngày date 08 tháng month 12 năm/year 2015" + }, + "142090201_2826170774286852_1233962294093312865_n.json": { + "label": "ngày /date 29 tháng month 07 năm /year 2019", + "pred": "ngày date 29 tháng month 07 năm year 2019" + }, + "0dd2fe8bf5c633986ad726.json": { + "label": "ngày /date 29 tháng /month 10 năm/year✪2013", + "pred": "march date 29 tháng month 10✪năm✪volur✪2013" + }, + "190919701_1913643422140446_6855763478065892825_n.json": { + "label": "ngày date 24 tháng /month 07 năm/year 2017", + "pred": "ngày /date 24 tháng month 07 năm/year 2017" + }, + "147585615_2757487791230682_5515346433540820516_n.json": { + "label": "ngày /date # 9✪tháng /month 05 năm /year 2019", + "pred": "ngày date 2 Danate month 05 năm year 2019" + }, + "5fcae09ee4d3228d7bc211.json": { + "label": "ngày /date 14 tháng /month 01 năm/year 2019", + "pred": "ngày 'date 14 tháng month 0.1 năm/year 2019" + }, + "20221027_154417.json": { + "label": "ngày/date 20 tháng /month 09 năm/year 2017", + "pred": "ngày/date 20 thân ghi onth 09 năm/year 2017" + }, + "20221027_154802.json": { + "label": "ngày /date 14 tháng /month 01 năm/year 2019", + "pred": "ngày date 14 tháng month 01 năm/year 2019" + }, + "20221027_154749.json": { + "label": "ngày /date 22 tháng /month 04 năm /year 2022", + "pred": "ngày date 22 tháng month 04 năm year 2022" + }, + "20221027_154900.json": { + "label": "ngày /date 10 tháng /month 09 năm /year 2020", + "pred": "ngày date 10 tháng month 09 năm Year 2020" + }, + "4b2b35453e08f856a11925.json": { + "label": "ngày/date 10 tháng/month 03 năm/year 2022", + "pred": "ngày/date 1 in 03 năm/year 2022" + }, + "20221027_155734.json": { + "label": "ngày /date 29 tháng /month 10 năm/year 2020", + "pred": "ngày date 29 
tháng month 10 năm/year 2020" + }, + "198876248_2931797967062357_4287721016641237281_n.json": { + "label": "ngày /date 29 áng /month 10 năm /year 2019", + "pred": "ngày Warm 204 đón Quanth 10 năm vear 2019" + }, + "20221027_154613.json": { + "label": "ngày/date 10 tháng/month 03 năm/year 2022", + "pred": "ngày/date 10than Imonth 03 năm/year 2022" + }, + "20221027_154600.json": { + "label": "ngày/date 29 tháng /month 10 năm/year✪2013", + "pred": "ngày/date 29 tháng month 10 năm/year2013" + }, + "191389634_910736173104716_4923402486196996972_n.json": { + "label": "ngày /date 24 tháng /month 02 năm/year 2021", + "pred": "ngày date 24 tháng month 02 năm/year 2021" + }, + "20221027_155723.json": { + "label": "ngày /date 27 tháng /month 10 năm/year 2014", + "pred": "ngày /date 27 tháng month 10 năm/year 2014" + }, + "184606042_1586323798376373_2179113485447088932_n.json": { + "label": "ngày /date 29 tháng /month 07 năm/year 2020", + "pred": "ngày (date 29 tháng /month 07 năm✪year 2020" + }, + "6fb460d06a9dacc3f58c31.json": { + "label": "ngày /date 28 tháng /month 05 năm /year 2019", + "pred": "ngày date 28 tháng month 05 năm year 2019" + }, + "7e810ce502a8c4f69db93.json": { + "label": "ngày /date 29 tháng /month 10 năm /year 2020", + "pred": "ngày de it 29 thing month 10 năm Year 2020" + }, + "20221027_154400.json": { + "label": "ngày /date 28 tháng /month 05 năm /year 2019", + "pred": "ngày date 28 tháng month 05 năm year 2019" + }, + "139579296_107198731349013_7325819456715999953_n.json": { + "label": "ngày /date 10tháng month 05 năm /year 2018", + "pred": "ngày date 1 month 05 năm hear 2018" + }, + "20221027_155748.json": { + "label": "ngày/date 29 tháng/month 10 năm/year✪2013", + "pred": "ngày/date 29 tháng/month 10 năm/year2013" + }, + "164359233_2788848161366629_6843431895499380423_n.json": { + "label": "ngày /date 25 tháng /month 12 năm/year 2015", + "pred": "ngày dute 25 tháng month 12 năm/year 2015" + }, + "962650ff5eb298ecc1a36.json": { + "label": "ngày /date 29 tháng/month 10 năm/year 2013", + "pred": "10 năm/year 2013\n ngày the 29thanginonth\n TL" + }, + "20221027_155600.json": { + "label": "ngày/date 30 tháng /month 07 năm/year 2015", + "pred": "ngày/date 30 tháng month 07 năm/year 2015" + }, + "951970367f7bb925e06a1.json": { + "label": "ngày /date 28 tháng /month 05 năm /year 2019 -", + "pred": "ngày late 28 hàng worth 05 năm hear 2019" + }, + "20221027_154627.json": { + "label": "ngày /date 29 tháng /month 10 năm/year 2020", + "pred": "ngày date 29 tháng month 10 năm/year. 
2020" + }, + "20221027_155535.json": { + "label": "ngày /date 01 tháng /month 04 năm /year 2022", + "pred": "ngày date 01 tháng month 04 năm 'year 2022" + }, + "106402928_1000018507095212_5438034148254460378_n.json": { + "label": "ngày date 04 tháng month 09 năm /year 2019", + "pred": "ngày dan 04 tháng month 09 năm Year 2019" + }, + "20221027_154458.json": { + "label": "ngày /date 24 tháng /month 05 năm/year 2016", + "pred": "ngày date 24 tháng month 03 năm/year 2016" + }, + "190841312_3057702594458479_8551202571498845435_n.json": { + "label": "ngày /date 14 tháng month 12 năm /year 2018", + "pred": "ngày date 14 tháng month 12 năm year 2018" + }, + "20221027_154907.json": { + "label": "ngày/date 30 tháng /month 07 năm/year✪2015", + "pred": "ngày/date 30 tháng month 07 năm/year2015" + }, + "136098726_3413968628702123_4090292519699106839_n.json": { + "label": "ngày /date 20tháng /month 11 năm /year 2020", + "pred": "ngày date 70tháng month 11 năm year 2020" + }, + "20221027_154440.json": { + "label": "ngày /date 09 tháng /month 07 năm /year 2019", + "pred": "ngày date 09 tháng mc each 07 năm year 2019" + }, + "07966c1b6256a408fd478.json": { + "label": "ngày/date 24 tháng /month 05 năm/year 2016", + "pred": "ngày/d te 24 thán day anth 05 năm/year 2016" + }, + "20221027_155740.json": { + "label": "ngày/date 10 tháng/month 03 năm/year 2022", + "pred": "ngày/date 10than almonth 03 năm/year 2022" + }, + "130727988_1377512982599930_6481917606912865462_n.json": { + "label": "ngày/date 30 tháng /month 05 năm/year 2016", + "pred": "ngày/date 30 tháng month 05 năm/year 2016" + }, + "194993073_539716147195165_8525378287933192246_n.json": { + "label": "ngày /date 12 tháng /month 06 năm /year 2020", + "pred": "ngày date 12 tháng /month 06 năm vear 2020" + }, + "20221027_154521.json": { + "label": "ngày/date 19 tháng /month 03 năm/year 2018", + "pred": "ngày/dec 19 tháng month 03 năm/year 2018" + }, + "174247511_900878714088455_7516565117828455890_n.json": { + "label": "ngày/date 1 0✪tháng /month 05 năm/year 2018", + "pred": "ngày/date 4 Other month 05 năm/year 2018" + }, + "20221027_155519.json": { + "label": "ngày /date 22 tháng /month 04 năm /year 2022", + "pred": "ngày date 22 tháng month 04 năm year 2022" + }, + "20221027_154706.json": { + "label": "ngày /date 27 tháng /month 10 năm/year 2014", + "pred": "ngày date 27 tháng month 10 năm/year. 
2014" + }, + "a88ae66aed272b79723621.json": { + "label": "ngày /date 27 tháng /month 10 năm/year 2014", + "pred": "" + }, + "4378298-95583fbb1edb703a6c5bbc1744246058-1-1.json": { + "label": "ngày/d 24 tháng /month 1 năm/year 2015", + "pred": "ngày/d / 24 than chuonth 1 1 năm/year 2015" + }, + "20221027_155704.json": { + "label": "ngày /date 22 tháng /month 04 năm/year 2022", + "pred": "ngày date 22 tháng month 04 năm hear 2022" + }, + "179642128_1945301335636182_5557211235870766646_n.json": { + "label": "ngày/date #✪thá# # onth 10 năm/year 201", + "pred": "ngowdan (Ký - touth 10 năm/year 201" + }, + "20221027_154734.json": { + "label": "ngày /date 16 tháng /month 06 năm /year 2020", + "pred": "ngày date to tháng month 06 năm year 2020" + }, + "20221027_155542.json": { + "label": "ngày /date 05 tháng /month 04 năm/year 2016", + "pred": "ngày /date 05 tháng month 04 năm/year 2016" + } +} \ No newline at end of file diff --git a/cope2n-ai-fi/common/json2xml.py b/cope2n-ai-fi/common/json2xml.py new file mode 100755 index 0000000..8b643fd --- /dev/null +++ b/cope2n-ai-fi/common/json2xml.py @@ -0,0 +1,196 @@ +import xml.etree.ElementTree as ET +from datetime import datetime +ET.register_namespace('', "http://www.w3.org/2000/09/xmldsig#") + + +xml_template3 = """ + + + + None + None + None + None + None + None + None + None + None + None + + + + None + None + None + None + + + None + None + None + None + None + + + + None + None + None + None + None + None + None + None + None + None + + + + + + None + None + None + + + None + None + None + None + None + + + + None + None + +""" + +def replace_xml_values(xml_str, replacement_dict): + """ replace xml values + + Args: + xml_str (_type_): _description_ + replacement_dict (_type_): _description_ + + Returns: + _type_: _description_ + """ + try: + root = ET.fromstring(xml_str) + for key, value in replacement_dict.items(): + if not value: + continue + if key == "TToan": + ttoan_element = root.find(".//TToan") + tsuat_element = ttoan_element.find(".//TgTThue") + tthue_element = ttoan_element.find(".//TgTTTBSo") + tthuebchu_element = ttoan_element.find(".//TgTTTBChu") + if value["TgTThue"]: + tsuat_element.text = value["TgTThue"] + if value["TgTTTBSo"]: + tthue_element.text = value["TgTTTBSo"] + if value["TgTTTBChu"]: + tthuebchu_element.text = value["TgTTTBChu"] + elif key == "NMua": + nmua_element = root.find(".//NMua") + for key_ in ["DChi", "SDThoai", "MST", "Ten", "HVTNMHang"]: + if value.get(key_, None): + nmua_element_key = nmua_element.find(f".//{key_}") + nmua_element_key.text = value[key_] + elif key == "NBan": + nban_element = root.find(".//NBan") + for key_ in ["DChi", "SDThoai", "MST", "Ten"]: + if value.get(key_, None): + nban_element_key = nban_element.find(f".//{key_}") + nban_element_key.text = value[key_] + elif key == "HHDVu": + dshhdvu_element = root.find(".//DSHHDVu") + hhdvu_template = root.find(".//HHDVu") + if hhdvu_template is not None and dshhdvu_element is not None: + dshhdvu_element.remove(hhdvu_template) # Remove the template + for hhdvu_data in value: + hhdvu_element = ET.SubElement(dshhdvu_element, "HHDVu") + for h_key, h_value in hhdvu_data.items(): + h_element = ET.SubElement(hhdvu_element, h_key) + h_element.text = h_value if h_value is not None else "None" + elif key == "NLap": + nlap_element = root.find(".//NLap") + if nlap_element is not None: + # Convert the date to yyyy-mm-dd format + try: + date_obj = datetime.strptime(value, "%d/%m/%Y") + formatted_date = date_obj.strftime("%Y-%m-%d") + nlap_element.text = 
formatted_date + except ValueError: + print(f"Invalid date format for {key}: {value}") + nlap_element.text = value + else: + element = root.find(f".//{key}") + if element is not None: + element.text = value + ET.register_namespace("", "http://www.w3.org/2000/09/xmldsig#") + return ET.tostring(root, encoding="unicode") + except ET.ParseError as e: + print(f"Error parsing XML: {e}") + return None + + +def convert_key_names(original_dict): + """_summary_ + + Args: + original_dict (_type_): _description_ + + Returns: + _type_: _description_ + """ + key_mapping = { + "table": "HHDVu", + "Mặt hàng": "THHDVu", + "Đơn vị tính": "DVTinh", + "Số lượng": "SLuong", + "Đơn giá": "DGia", + "Doanh số mua chưa có thuế": "ThTien", + "buyer_address_value": "NMua.DChi", + 'buyer_company_name_value': 'NMua.Ten', + 'buyer_personal_name_value': 'NMua.HVTNMHang', + 'buyer_tax_code_value': 'NMua.MST', + 'buyer_tel_value': 'NMua.SDThoai', + 'seller_address_value': 'NBan.DChi', + 'seller_company_name_value': 'NBan.Ten', + 'seller_tax_code_value': 'NBan.MST', + 'seller_tel_value': 'NBan.SDThoai', + 'date_value': 'NLap', + 'form_value': 'KHMSHDon', + 'no_value': 'SHDon', + 'serial_value': 'KHHDon', + 'tax_amount_value': 'TToan.TgTThue', + 'total_in_words_value': 'TToan.TgTTTBChu', + 'total_value': 'TToan.TgTTTBSo' + } + + converted_dict = {} + for key, value in original_dict.items(): + new_key = key_mapping.get(key, key) + if "." in new_key: + parts = new_key.split(".") + current_dict = converted_dict + for i, part in enumerate(parts): + if i == len(parts) - 1: + current_dict[part] = value + else: + current_dict.setdefault(part, {}) + current_dict = current_dict[part] + else: + if key == "table": + lconverted_table_values = [] + for table_value in value: + converted_table_value = convert_key_names(table_value) + lconverted_table_values.append(converted_table_value) + converted_dict[new_key] = lconverted_table_values + else: + converted_dict[new_key] = value + + return converted_dict \ No newline at end of file diff --git a/cope2n-ai-fi/common/ocr.py b/cope2n-ai-fi/common/ocr.py new file mode 100755 index 0000000..5399a26 --- /dev/null +++ b/cope2n-ai-fi/common/ocr.py @@ -0,0 +1,37 @@ +from common.utils.ocr_yolox import OcrEngineForYoloX_ID_Driving +from common.utils.word_formation import Word, words_to_lines + +det_ckpt = "yolox-s-general-text-pretrain-20221226" +cls_ckpt = "satrn-lite-general-pretrain-20230106" + +engine = OcrEngineForYoloX_ID_Driving(det_ckpt, cls_ckpt) + + +def ocr_predict(image): + """Predict text from image + + Args: + image_path (str): _description_ + + Returns: + list: list of words + """ + try: + lbboxes, lwords = engine.run_image(image) + lWords = [Word(text=word, bndbox=bbox) for word, bbox in zip(lwords, lbboxes)] + list_lines, _ = words_to_lines(lWords) + return list_lines + except AssertionError as e: + print(e) + list_lines = [] + return list_lines + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--image", type=str, required=True) + args = parser.parse_args() + + list_lines = ocr_predict(args.image) diff --git a/cope2n-ai-fi/common/post_processing_datetime.py b/cope2n-ai-fi/common/post_processing_datetime.py new file mode 100755 index 0000000..ad5cd27 --- /dev/null +++ b/cope2n-ai-fi/common/post_processing_datetime.py @@ -0,0 +1,113 @@ +import re +from datetime import datetime +from sklearn.metrics import classification_report +from common.utils.utils import read_json +from underthesea import word_tokenize + + +class 
DatetimeCorrector: + @staticmethod + def verify_and_convert_date(date_str): + # Try to parse the date string using the datetime module + try: + date = datetime.strptime(date_str, "%d/%m/%Y") # TODO: fix this + except ValueError: + # If the date string is not in a valTid format, return False + return "" + + # If the date string is in the correct format, check if it is already in the "dd/mm/yyyy" format + if date_str[:2] == "dd" and date_str[3:5] == "mm" and date_str[6:] == "yyyy": + # If the date string is already in the correct format, return it as is + return date_str + else: + # If the date string is not in the correct format, use the strftime method to convert it + return date.strftime("%d/%m/%Y") + + @staticmethod + def get_date_from_date_string_by_prefix(date_string_, prefix_): + prefix = prefix_.lower() + date_string = date_string_.lower() + if prefix in date_string: + try: + if prefix == "năm": + match = re.split( + r"năm[^\d]*(\d{4}|\d{1}[\s.]*\d{3}|\d{3}[\s.]*\d{1}|\d{2}[\s.]*\d{2}|\d{2}[\s.]*\d{1}[\s.]*\d{1}|\d{1}[\s.]*\d{2}[\s.]*\d{1}|\d{1}[\s.]*\d{1}[\s.]*\d{2}|\d{1}[\s.]*\d{1}[\s.]*\d{1}[\s.]*\d{1})[\s.]*\b", + date_string) # match "năm" following with all combination of 4 numbers and whitespace/dot such as 1111; 111.1; 111 1; 11 2 1, 2 2 2.2; ... + elif prefix == "ngày": + match = re.split(r"ngày[^\d]*(\d{2}|\d{1}[\s.]*\d{1}|\d{1})[\s.]*\b", date_string) + else: + match = re.split(r"tháng[^\d]*(\d{2}|\d{1}[\s.]*\d{1}|\d{1})[\s.]*\b", date_string) + num = match[1] + remain_string = match[2] if prefix != "năm" else match[0] + return num, remain_string + except: + return "", date_string_ + else: + return '', date_string_ + + @staticmethod + def get_date_by_pattern(date_string): + match = re.findall(r"([^\d\s]+)?\s*(\d{1}\s*\d?\s+|\d{2}\s+|\d+\s*\b)", date_string) + if not match: + return "" + if len(match) > 3: + day = match[0][-1].replace(" ", "") + year = match[-1][-1].replace(" ", "") + # since in the VIETNAMESE DRIVER LICENSE, the tháng/month is behind the stamp and can be recognized as any thing => mistạken number may be in range (1->-3) => choose month to be -2 + month = match[-2][-1].replace(" ", "") + return "/".join([day, month, year]) + else: + return "/".join([m[-1].replace(" ", "") for m in match]) + + @staticmethod + def extract_date_from_string(date_string): + remain_str = date_string + ldate = [] + for d in ["năm", "ngày", "tháng"]: + date, remain_str = DatetimeCorrector.get_date_from_date_string_by_prefix(date_string, d) + if not date: + return DatetimeCorrector.get_date_by_pattern(date_string) + ldate.append(date.strip().replace(" ", "").replace(".", "")) + return "/".join([ldate[1], ldate[2], ldate[0]]) + + @staticmethod + def correct(date_string): + # Extract the day, month, and year from the string using regular expressions + date_string = date_string.lower().replace("✪", " ") + date_string = " ".join(word_tokenize(date_string)) + parsed_date_string_ = DatetimeCorrector.verify_and_convert_date(date_string) # if already in datetime format + if parsed_date_string_: + return parsed_date_string_ + extracted_date = DatetimeCorrector.extract_date_from_string(date_string) + parsed_date_string_ = DatetimeCorrector.verify_and_convert_date(extracted_date) + return parsed_date_string_ if parsed_date_string_ else date_string + + @staticmethod + def eval(): + data = read_json("common/dates_gplx.json") + type_column = "GPLX" # Invoice/GPLX + y_true, y_pred = [], [] + lexcludes = {} + ddata = {} + for k, d in data.items(): + if k in lexcludes: + continue + if k == 
"inv_SDV_215": + print("debugging") + pred = DatetimeCorrector.correct(d["pred"]) + label = DatetimeCorrector.correct(d["label"]) + ddata[k] = {} + data[k]["Type"] = type_column + ddata[k]["Predict"] = d["pred"] + ddata[k]["Label"] = d["label"] + ddata[k]["Post-processed"] = pred + y_pred.append(pred == label) + y_true.append(1) + if k == "invoice_1219_000": + print("\n", k, '-' * 50) + print(pred, "------", d["pred"]) + print(label, "------", d["label"]) + print(classification_report(y_true, y_pred)) + import pandas as pd + df = pd.DataFrame.from_dict(ddata, orient="index") + df.to_csv(f"result/datetime_post_processed_{type_column}.csv") \ No newline at end of file diff --git a/cope2n-ai-fi/common/post_processing_driver.py b/cope2n-ai-fi/common/post_processing_driver.py new file mode 100755 index 0000000..49b7423 --- /dev/null +++ b/cope2n-ai-fi/common/post_processing_driver.py @@ -0,0 +1,51 @@ +from common.utils.word_formation import words_to_lines +from Kie_AHung.prediction import KIE_LABELS, IGNORE_KIE_LABEL +from common.post_processing_datetime import DatetimeCorrector + + +def merge_bbox(list_bbox): + if not list_bbox: + return list_bbox + left = min(list_bbox, key=lambda x: x[0])[0] + top = min(list_bbox, key=lambda x: x[1])[1] + right = max(list_bbox, key=lambda x: x[2])[2] + bot = max(list_bbox, key=lambda x: x[3])[3] + return [left, top, right, bot] + + +def create_result_kie_dict(): + return { + KIE_LABELS[i]: {} + for i in range(len(KIE_LABELS)) + if KIE_LABELS[i] != IGNORE_KIE_LABEL + } + + +def create_empty_kie_dict(): + return { + KIE_LABELS[i]: [] + for i in range(len(KIE_LABELS)) + if KIE_LABELS[i] != IGNORE_KIE_LABEL + } + + +def create_kie_dict(list_words): + kie_dict = create_empty_kie_dict() + # append each word to respected dict + for word in list_words: + if word.kie_label in kie_dict: + kie_dict[word.kie_label].append(word) + word.text = word.text.strip() + # construct line from words for each kie_label + result_dict = create_result_kie_dict() + for kie_label in result_dict: + list_lines, _ = words_to_lines(kie_dict[kie_label]) + text = "\n ".join([line.text.strip() for line in list_lines]) + if kie_label == "date": + # text = post_processing_datetime(text) + text = DatetimeCorrector.correct(text) + result_dict[kie_label]["text"] = text + result_dict[kie_label]["bbox"] = merge_bbox( + [line.boundingbox for line in list_lines] + ) + return result_dict diff --git a/cope2n-ai-fi/common/post_processing_id.py b/cope2n-ai-fi/common/post_processing_id.py new file mode 100755 index 0000000..82c0597 --- /dev/null +++ b/cope2n-ai-fi/common/post_processing_id.py @@ -0,0 +1,51 @@ +from common.utils.word_formation import words_to_lines +from Kie_AHung_ID.prediction import KIE_LABELS, IGNORE_KIE_LABEL +from common.post_processing_datetime import DatetimeCorrector + + +def merge_bbox(list_bbox): + if not list_bbox: + return list_bbox + left = min(list_bbox, key=lambda x: x[0])[0] + top = min(list_bbox, key=lambda x: x[1])[1] + right = max(list_bbox, key=lambda x: x[2])[2] + bot = max(list_bbox, key=lambda x: x[3])[3] + return [left, top, right, bot] + + +def create_result_kie_dict(): + return { + KIE_LABELS[i]: {} + for i in range(len(KIE_LABELS)) + if KIE_LABELS[i] != IGNORE_KIE_LABEL + } + + +def create_empty_kie_dict(): + return { + KIE_LABELS[i]: [] + for i in range(len(KIE_LABELS)) + if KIE_LABELS[i] != IGNORE_KIE_LABEL + } + + +def create_kie_dict(list_words): + kie_dict = create_empty_kie_dict() + # append each word to respected dict + for word in list_words: + if 
word.kie_label in kie_dict: + kie_dict[word.kie_label].append(word) + word.text = word.text.strip() + # construct line from words for each kie_label + result_dict = create_result_kie_dict() + for kie_label in result_dict: + list_lines, _ = words_to_lines(kie_dict[kie_label]) + text = "\n ".join([line.text.strip() for line in list_lines]) + if kie_label == "date": + # text = post_processing_datetime(text) + text = DatetimeCorrector.correct(text) + result_dict[kie_label]["text"] = text + result_dict[kie_label]["bbox"] = merge_bbox( + [line.boundingbox for line in list_lines] + ) + return result_dict diff --git a/cope2n-ai-fi/common/process_pdf.py b/cope2n-ai-fi/common/process_pdf.py new file mode 100755 index 0000000..4ad48b1 --- /dev/null +++ b/cope2n-ai-fi/common/process_pdf.py @@ -0,0 +1,252 @@ +import os +import json + +from common import json2xml +from common.json2xml import convert_key_names, replace_xml_values +from common.utils_kvu.split_docs import split_docs, merge_sbt_output + +# from api.OCRBase.prediction import predict as ocr_predict +# from api.Kie_Invoice_AP.prediction_sap import predict +# from api.Kie_Invoice_AP.prediction_fi import predict_fi +# from api.manulife.predict_manulife import predict as predict_manulife +from api.sdsap_sbt.prediction_sbt import predict as predict_sbt + +os.environ['PYTHONPATH'] = '/home/thucpd/thucpd/cope2n-ai/cope2n-ai/' + +def check_label_exists(array, target_label): + for obj in array: + if obj["label"] == target_label: + return True # Label exists in the array + return False # Label does not exist in the array + +def compile_output(list_url): + """_summary_ + + Args: + pdf_extracted (list): list: [{ + "1": url},{"2": url}, + ...] + Raises: + NotImplementedError: _description_ + + Returns: + dict: output compiled + """ + + results = { + "model":{ + "name":"Invoice", + "confidence": 1.0, + "type": "finance/invoice", + "isValid": True, + "shape": "letter", + } + } + compile_outputs = [] + compiled = [] + for page in list_url: + output_model = predict(page['page_number'], page['file_url']) + for field in output_model['fields']: + if field['value'] != "" and not check_label_exists(compiled, field['label']): + element = { + 'label': field['label'], + 'value': field['value'], + } + compiled.append(element) + elif field['label'] == 'table' and check_label_exists(compiled, "table"): + for index, obj in enumerate(compiled): + if obj['label'] == 'table': + compiled[index]['value'].append(field['value']) + compile_output = { + 'page_index': page['page_number'], + 'request_file_id': page['request_file_id'], + 'fields': output_model['fields'] + } + compile_outputs.append(compile_output) + results['combine_results'] = compiled + results['pages'] = compile_outputs + return results + +def update_null_values(kvu_result, next_kvu_result): + for key, value in kvu_result.items(): + if value is None and next_kvu_result.get(key) is not None: + kvu_result[key] = next_kvu_result[key] + +def replace_empty_null_values(my_dict): + for key, value in my_dict.items(): + if value == '': + my_dict[key] = None + return my_dict + +def compile_output_fi(list_url): + """_summary_ + + Args: + pdf_extracted (list): list: [{ + "1": url},{"2": url}, + ...] 
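+            Each page dict is expected to provide at least "page_number" and "file_url",
+            e.g. (hypothetical values): [{"page_number": 1, "file_url": "https://example.com/invoice_p1.jpg"}]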
+ Raises: + NotImplementedError: _description_ + + Returns: + dict: output compiled + """ + + results = { + "model":{ + "name":"Invoice", + "confidence": 1.0, + "type": "finance/invoice", + "isValid": True, + "shape": "letter", + } + } + # Loop through the list_url to update kvu_result + for i in range(len(list_url) - 1): + page = list_url[i] + next_page = list_url[i + 1] + kvu_result, output_kie = predict_fi(page['page_number'], page['file_url']) + next_kvu_result, next_output_kie = predict_fi(next_page['page_number'], next_page['file_url']) + + update_null_values(kvu_result, next_kvu_result) + output_kie = replace_empty_null_values(output_kie) + next_output_kie = replace_empty_null_values(next_output_kie) + update_null_values(output_kie, next_output_kie) + + # Handle the last item in the list_url + if list_url: + page = list_url[-1] + kvu_result, output_kie = predict_fi(page['page_number'], page['file_url']) + + converted_dict = convert_key_names(kvu_result) + converted_dict.update(convert_key_names(output_kie)) + output_fi = replace_xml_values(json2xml.xml_template3, converted_dict) + field_fi = { + "xml": output_fi, + } + results['combine_results'] = field_fi + # results['combine_results'] = converted_dict + # results['combine_results_kie'] = output_kie + return results + +def compile_output_ocr_base(list_url): + """Compile output of OCRBase + + Args: + list_url (list): List string url of image + + Returns: + dict: dict of output + """ + + results = { + "model":{ + "name":"OCRBase", + "confidence": 1.0, + "type": "ocrbase", + "isValid": True, + "shape": "letter", + } + } + compile_outputs = [] + for page in list_url: + output_model = ocr_predict(page['page_number'], page['file_url']) + compile_output = { + 'page_index': page['page_number'], + 'request_file_id': page['request_file_id'], + 'fields': output_model['fields'] + } + compile_outputs.append(compile_output) + results['pages'] = compile_outputs + return results + +def compile_output_manulife(list_url): + """_summary_ + + Args: + pdf_extracted (list): list: [{ + "1": url},{"2": url}, + ...] + Raises: + NotImplementedError: _description_ + + Returns: + dict: output compiled + """ + + results = { + "model":{ + "name":"Invoice", + "confidence": 1.0, + "type": "finance/invoice", + "isValid": True, + "shape": "letter", + } + } + + outputs = [] + for page in list_url: + output_model = predict_manulife(page['page_number'], page['file_url']) # gotta be predict_manulife(), for the time being, this function is not avaible, we just leave a dummy function here instead + print("output_model", output_model) + outputs.append(output_model) + print("outputs", outputs) + documents = split_docs(outputs) + print("documents", documents) + results = { + "total_pages": len(list_url), + "ocr_num_pages": len(list_url), + "document": documents + } + return results + +def compile_output_sbt(list_url): + """_summary_ + + Args: + pdf_extracted (list): list: [{ + "1": url},{"2": url}, + ...] 
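+            Each page dict is expected to provide "page_number" and "file_url"; an optional
+            "doc_type" key is copied into the per-page output, e.g. (hypothetical values):
+            [{"page_number": 1, "file_url": "https://example.com/p1.jpg", "doc_type": "invoice"}]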
+ Raises: + NotImplementedError: _description_ + + Returns: + dict: output compiled + """ + + results = { + "model":{ + "name":"Invoice", + "confidence": 1.0, + "type": "finance/invoice", + "isValid": True, + "shape": "letter", + } + } + + + outputs = [] + for page in list_url: + output_model = predict_sbt(page['page_number'], page['file_url']) + if "doc_type" in page: + output_model['doc_type'] = page['doc_type'] + outputs.append(output_model) + documents = merge_sbt_output(outputs) + results = { + "total_pages": len(list_url), + "ocr_num_pages": len(list_url), + "document": documents + } + return results + + +def main(): + """ + main function + """ + list_url = [{"file_url": "https://www.irs.gov/pub/irs-pdf/fw9.pdf", "page_number": 1, "request_file_id": 1}, ...] + results = compile_output(list_url) + with open('output.json', 'w', encoding='utf-8') as file: + json.dump(results, file, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/cope2n-ai-fi/common/serve_model.py b/cope2n-ai-fi/common/serve_model.py new file mode 100755 index 0000000..8d1cc3c --- /dev/null +++ b/cope2n-ai-fi/common/serve_model.py @@ -0,0 +1,93 @@ +import cv2 +from common.ocr import ocr_predict +from common.crop_location import crop_location +from Kie_AHung.prediction import infer_driving_license +from Kie_AHung_ID.prediction import infer_id_card +from common.post_processing_datetime import DatetimeCorrector +from transformers import ( + LayoutXLMTokenizer, + LayoutLMv2FeatureExtractor, + LayoutXLMProcessor + ) + +MAX_SEQ_LENGTH = 512 # TODO Fix this hard code +tokenizer = LayoutXLMTokenizer.from_pretrained( + "./Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH +) + +feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) +processor = LayoutXLMProcessor(feature_extractor, tokenizer) + +max_n_words = 100 + +def predict(rq_id, sub_id, folder_name, list_url, user_id, infer_name): + """Predict text from image + + Args: + image_path (str): path to image + + Returns: + dict: dict result of prediction + """ + + results = { + "model":{ + "name":infer_name, + "confidence": 1.0, + "type": "finance/invoice", + "isValid": True, + "shape": "letter", + } + } + compile_outputs = [] + for page in list_url: + image_location = crop_location(page['file_url']) + if image_location is None: + compile_output = { + 'page_index': page['page_number'], + 'path_image_croped': None, + 'request_file_id': page['request_file_id'], + 'fields': None + } + compile_outputs.append(compile_output) + + elif image_location is not None: + path_image_croped = "/app/media/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id) + cv2.imwrite("/users/{}/subscriptions/{}/requests/{}/{}/image_croped.jpg".format(user_id,sub_id,folder_name,rq_id), image_location) + list_line = ocr_predict(image_location) + + if infer_name == "driving_license": + from common.post_processing_driver import create_kie_dict + _, _, _, list_words = infer_driving_license(image_location, list_line, max_n_words, processor) + result_dict = create_kie_dict(list_words) + elif infer_name == "id_card": + from common.post_processing_id import create_kie_dict + _, _, _, list_words = infer_id_card(image_location, list_line, max_n_words, processor) + result_dict = create_kie_dict(list_words) + + fields = [] + for kie_label in result_dict: + if result_dict[kie_label]["text"] != "": + if kie_label == "Date Range": + text = 
DatetimeCorrector.correct(result_dict[kie_label]["text"]) + else: + text = result_dict[kie_label]["text"] + + field = { + "label": kie_label, + "value": text.replace("✪", " ") if "✪" in text else text, + "box": result_dict[kie_label]["bbox"], + "confidence": 0.99 #TODO: add confidence + } + fields.append(field) + + compile_output = { + 'page_index': page['page_number'], + 'path_image_croped': str(path_image_croped), + 'request_file_id': page['request_file_id'], + 'fields': fields + } + + compile_outputs.append(compile_output) + results['pages'] = compile_outputs + return results diff --git a/cope2n-ai-fi/common/utils/blurry_detection.py b/cope2n-ai-fi/common/utils/blurry_detection.py new file mode 100755 index 0000000..7cbf1af --- /dev/null +++ b/cope2n-ai-fi/common/utils/blurry_detection.py @@ -0,0 +1,35 @@ +import cv2 +import urllib +import numpy as np + +class BlurryDetection: + def __init__(self): + # initialize the detector + pass + + def variance_of_laplacian(self, image): + # compute the Laplacian of the image and then return the focus + # measure, which is simply the variance of the Laplacian + return cv2.Laplacian(image, cv2.CV_64F).var() + + def __call__(self, img, thr=100): + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + fm = self.variance_of_laplacian(gray) + + if fm >= thr: + return "non_blurry", fm + else: + return "blurry", fm + + +detector = BlurryDetection() + + +def check_blur(image_url): + req = urllib.request.urlopen(image_url) + arr = np.asarray(bytearray(req.read()), dtype=np.uint8) + img = cv2.imdecode(arr, -1) + pred = detector(img, thr=10) + score = pred[0] + return score diff --git a/cope2n-ai-fi/common/utils/global_variables.py b/cope2n-ai-fi/common/utils/global_variables.py new file mode 100755 index 0000000..c59f7df --- /dev/null +++ b/cope2n-ai-fi/common/utils/global_variables.py @@ -0,0 +1,71 @@ +MAX_SEQ_LENGTH = 512 +DEVICE = "cuda:0" +KIE_LABELS = [ + "other", + "form_key", + "form_value", + "serial_key", + "serial_value", + "no_key", + "no_value", + "date", + "seller_name_value", + "seller_name_key", + "seller_tax_code_key", + "seller_tax_code_value", + "seller_address_value", + "seller_address_key", + "seller_mobile_key", + "buyer_name_key", + "buyer_company_name_key", + "buyer_company_name_value", + "buyer_tax_code_key", + "buyer_tax_code_value", + "buyer_address_value", + "buyer_address_key", + "VAT_amount_key", + "VAT_amount_value", + "total_key", + "total_value", + "total_in_words_key", + "total_in_words_value", + "seller_mobile_value", + "buyer_name_value", + "buyer_mobile_key", + "buyer_mobile_value", +] + +BRIEF_LABELS = [ + "o", + "fk", + "fv", + "sk", + "sv", + "nk", + "nv", + "d", + "snv", + "snk", + "stck", + "stcv", + "sav", + "sak", + "smk", + "bnk", + "bcnk", + "bcnv", + "btck", + "btcv", + "bav", + "bak", + "VATk", + "VATv", + "tk", + "tv", + "tiwk", + "tiwv", + "smv", + "bnv", + "bmk", + "bmv", +] diff --git a/cope2n-ai-fi/common/utils/layoutLM_utils.py b/cope2n-ai-fi/common/utils/layoutLM_utils.py new file mode 100755 index 0000000..854c33f --- /dev/null +++ b/cope2n-ai-fi/common/utils/layoutLM_utils.py @@ -0,0 +1,78 @@ +from config import config as cfg +import json +import glob +from sklearn.model_selection import train_test_split +import os +import pandas as pd + + +def load_kie_labels_yolo(label_path): + with open(label_path, "r") as f: + lines = f.read().splitlines() + words, boxes, labels = [], [], [] + for line in lines: + x1, y1, x2, y2, text, kie = line.split("\t") + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + if text 
!= " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + labels.append(kie) + return words, boxes, labels + + +def create_empty_kie_dict(): + return { + cfg.KIE_LABELS[i]: [] + for i in range(len(cfg.KIE_LABELS)) + if cfg.KIE_LABELS[i] != cfg.IGNORE_KIE_LABEL + } + + +def write_to_json_(file_path, content): + with open(file_path, mode="w", encoding="utf8") as f: + json.dump(content, f, ensure_ascii=False) + + +def load_train_val_id_cards(train_root, label_path): + train_labels = glob.glob(os.path.join(label_path, "*.txt")) + img_names = [ + os.path.basename(train_label).replace(".txt", ".jpg") + for train_label in train_labels + ] + train_paths = [os.path.join(train_root, img_name) for img_name in img_names] + train_df = pd.DataFrame.from_dict( + {"image_path": train_paths, "label": train_labels} + ) + train, test = train_test_split(train_df, test_size=0.2, random_state=cfg.SEED) + return train, test + + +def read_json(file_path): + with open(file_path, "r") as f: + return json.load(f) + + +def get_name(file_path, ext: bool = True): + file_path_ = os.path.basename(file_path) + return file_path_ if ext else os.path.splitext(file_path_)[0] + + +def construct_file_path(dir, file_path, ext=""): + """ + args: + dir: /path/to/dir + file_path /example_path/to/file.txt + ext = '.json' + return + /path/to/dir/file.json + """ + return ( + os.path.join(dir, get_name(file_path, True)) + if ext == "" + else os.path.join(dir, get_name(file_path, False)) + ext + ) + + +def write_to_txt_(file_path, content): + with open(file_path, "w") as f: + f.write(content) diff --git a/cope2n-ai-fi/common/utils/merge_box.py b/cope2n-ai-fi/common/utils/merge_box.py new file mode 100755 index 0000000..c184232 --- /dev/null +++ b/cope2n-ai-fi/common/utils/merge_box.py @@ -0,0 +1,163 @@ +import cv2 +import numpy as np + +# tuplify +def tup(point): + return (point[0], point[1]) + + +# returns true if the two boxes overlap +def overlap(source, target): + # unpack points + tl1, br1 = source + tl2, br2 = target + + # checks + if tl1[0] >= br2[0] or tl2[0] >= br1[0]: + return False + if tl1[1] >= br2[1] or tl2[1] >= br1[1]: + return False + return True + + +# returns all overlapping boxes +def getAllOverlaps(boxes, bounds, index): + overlaps = [] + for a in range(len(boxes)): + if a != index and overlap(bounds, boxes[a]): + overlaps.append(a) + return overlaps + + +img = cv2.imread("test.png") +orig = np.copy(img) +blue, green, red = cv2.split(img) + + +def medianCanny(img, thresh1, thresh2): + median = np.median(img) + img = cv2.Canny(img, int(thresh1 * median), int(thresh2 * median)) + return img + + +blue_edges = medianCanny(blue, 0, 1) +green_edges = medianCanny(green, 0, 1) +red_edges = medianCanny(red, 0, 1) + +edges = blue_edges | green_edges | red_edges + +# I'm using OpenCV 3.4. 
This returns (contours, hierarchy) in OpenCV 2 and 4 +_, contours, hierarchy = cv2.findContours( + edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE +) + +# go through the contours and save the box edges +boxes = [] +# each element is [[top-left], [bottom-right]]; +hierarchy = hierarchy[0] +for component in zip(contours, hierarchy): + currentContour = component[0] + currentHierarchy = component[1] + x, y, w, h = cv2.boundingRect(currentContour) + if currentHierarchy[3] < 0: + cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 1) + boxes.append([[x, y], [x + w, y + h]]) + +# filter out excessively large boxes +filtered = [] +max_area = 30000 +for box in boxes: + w = box[1][0] - box[0][0] + h = box[1][1] - box[0][1] + if w * h < max_area: + filtered.append(box) +boxes = filtered + +# go through the boxes and start merging +merge_margin = 15 + +# this is gonna take a long time +finished = False +highlight = [[0, 0], [1, 1]] +points = [[[0, 0]]] +while not finished: + # set end con + finished = True + + # check progress + print("Len Boxes: " + str(len(boxes))) + + # draw boxes # comment this section out to run faster + copy = np.copy(orig) + for box in boxes: + cv2.rectangle(copy, tup(box[0]), tup(box[1]), (0, 200, 0), 1) + cv2.rectangle(copy, tup(highlight[0]), tup(highlight[1]), (0, 0, 255), 2) + for point in points: + point = point[0] + cv2.circle(copy, tup(point), 4, (255, 0, 0), -1) + cv2.imshow("Copy", copy) + key = cv2.waitKey(1) + if key == ord("q"): + break + + # loop through boxes + index = len(boxes) - 1 + while index >= 0: + # grab current box + curr = boxes[index] + + # add margin + tl = curr[0][:] + br = curr[1][:] + tl[0] -= merge_margin + tl[1] -= merge_margin + br[0] += merge_margin + br[1] += merge_margin + + # get matching boxes + overlaps = getAllOverlaps(boxes, [tl, br], index) + + # check if empty + if len(overlaps) > 0: + # combine boxes + # convert to a contour + con = [] + overlaps.append(index) + for ind in overlaps: + tl, br = boxes[ind] + con.append([tl]) + con.append([br]) + con = np.array(con) + + # get bounding rect + x, y, w, h = cv2.boundingRect(con) + + # stop growing + w -= 1 + h -= 1 + merged = [[x, y], [x + w, y + h]] + + # highlights + highlight = merged[:] + points = con + + # remove boxes from list + overlaps.sort(reverse=True) + for ind in overlaps: + del boxes[ind] + boxes.append(merged) + + # set flag + finished = False + break + + # increment + index -= 1 +cv2.destroyAllWindows() + +# show final +copy = np.copy(orig) +for box in boxes: + cv2.rectangle(copy, tup(box[0]), tup(box[1]), (0, 200, 0), 1) +cv2.imshow("Final", copy) +cv2.waitKey(0) diff --git a/cope2n-ai-fi/common/utils/ocr_yolox.py b/cope2n-ai-fi/common/utils/ocr_yolox.py new file mode 100755 index 0000000..f41cbba --- /dev/null +++ b/cope2n-ai-fi/common/utils/ocr_yolox.py @@ -0,0 +1,77 @@ +import numpy as np +from .utils import get_crop_img_and_bbox +from sdsvtr import StandaloneSATRNRunner +from sdsvtd import StandaloneYOLOXRunner +import urllib +import cv2 + + +class YoloX: + def __init__(self, checkpoint): + self.model = StandaloneYOLOXRunner(checkpoint, device = "cuda:0") + + def inference(self, img=None): + runner = self.model + return runner(img) + + +class Classifier_SATRN: + def __init__(self, checkpoint): + self.model = StandaloneSATRNRunner(checkpoint, return_confident=True, device = "cuda:0") + + def inference(self, numpy_image): + model_inference = self.model + result = model_inference(numpy_image) + preds_str = result[0] + confidence = result[1] + return preds_str, 
confidence + +class OcrEngineForYoloX_Invoice: + def __init__(self, det_ckpt, cls_ckpt): + self.det = YoloX(det_ckpt) + self.cls = Classifier_SATRN(cls_ckpt) + + def run_image(self, img): + + pred_det = self.det.inference(img) + pred_det = pred_det[0] + + pred_det = sorted(pred_det, key=lambda box: [box[1], box[0]]) + if len(pred_det) == 0: + return [], [] + else: + bboxes = np.vstack(pred_det) + lbboxes = [] + lcropped_img = [] + assert len(bboxes) != 0, f"No bbox found in image, skipped" + for bbox in bboxes: + try: + crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True) + lbboxes.append(bbox_) + lcropped_img.append(crop_img) + except AssertionError as e: + print(e) + print(f"[ERROR]: Skipping invalid bbox in image") + lwords, _ = self.cls.inference(lcropped_img) + return lbboxes, lwords + +class OcrEngineForYoloX_ID_Driving: + def __init__(self, det_ckpt, cls_ckpt): + self.det = YoloX(det_ckpt) + self.cls = Classifier_SATRN(cls_ckpt) + + def run_image(self, img): + pred_det = self.det.inference(img) + bboxes = np.vstack(pred_det) + lbboxes = [] + lcropped_img = [] + assert len(bboxes) != 0, f"No bbox found in image, skipped" + for bbox in bboxes: + try: + crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True) + lbboxes.append(bbox_) + lcropped_img.append(crop_img) + except AssertionError: + print(f"[ERROR]: Skipping invalid bbox image in ") + lwords, _ = self.cls.inference(lcropped_img) + return lbboxes, lwords diff --git a/cope2n-ai-fi/common/utils/process_label.py b/cope2n-ai-fi/common/utils/process_label.py new file mode 100755 index 0000000..6f76e33 --- /dev/null +++ b/cope2n-ai-fi/common/utils/process_label.py @@ -0,0 +1,139 @@ +import os +import cv2 as cv +import glob +from xml.dom.expatbuilder import parseString +from lxml.etree import Element, tostring, SubElement +import tqdm +from common.utils.global_variables import * + + +def boxes_to_xml(boxes_lst, xml_pth, img_pth=""): + """_summary_ + + Args: + boxes_lst (_type_): _description_ + xml_pth (_type_): _description_ + img_pth (str, optional): _description_. Defaults to ''. 
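+        When img_pth is given, the image is re-saved next to xml_pth and its dimensions are
+        written into the <size> node; otherwise width and height default to 0.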
+ """ + node_root = Element("annotation") + + node_folder = SubElement(node_root, "folder") + node_folder.text = "images" + + node_filename = SubElement(node_root, "filename") + node_filename.text = os.path.basename(img_pth) + + # insert size of image + if img_pth == "": + width, height = 0, 0 + else: + img = cv.imread(img_pth) + new_path = xml_pth[:-3] + "jpg" + cv.imwrite(new_path, img) + width, height = img.shape[:2] + + node_size = SubElement(node_root, "size") + + node_width = SubElement(node_size, "width") + node_width.text = str(width) + + node_height = SubElement(node_size, "height") + node_height.text = str(height) + + node_depth = SubElement(node_size, "depth") + node_depth.text = "3" + + node_segmented = SubElement(node_root, "segmented") + node_segmented.text = "0" + + for box in boxes_lst: + left, top, right, bottom = box.xmin, box.ymin, box.xmax, box.ymax + left, top, right, bottom = str(left), str(top), str(right), str(bottom) + label = box.label + if label == None: + label = "" + + node_object = SubElement(node_root, "object") + node_name = SubElement(node_object, "name") + node_name.text = label + + node_pose = SubElement(node_object, "pose") + node_pose.text = "Unspecified" + node_truncated = SubElement(node_object, "truncated") + node_truncated.text = "0" + node_difficult = SubElement(node_object, "difficult") + node_difficult.text = "0" + + # insert bounding box + node_bndbox = SubElement(node_object, "bndbox") + node_xmin = SubElement(node_bndbox, "xmin") + node_xmin.text = left + node_ymin = SubElement(node_bndbox, "ymin") + node_ymin.text = top + node_xmax = SubElement(node_bndbox, "xmax") + node_xmax.text = right + node_ymax = SubElement(node_bndbox, "ymax") + node_ymax.text = bottom + + xml = tostring(node_root, pretty_print=True) + dom = parseString(xml) + with open(xml_pth, "w+", encoding="utf-8") as f: + dom.writexml(f, indent="\t", addindent="\t", encoding="utf-8") + + +class Box: + def __init__(self): + self.xmax = 0 + self.ymax = 0 + self.xmin = 0 + self.ymin = 0 + self.label = "" + self.kie_label = "" + + +def check_iou(box1: Box, box2: Box, threshold=0.9): + area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin) + area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin) + xmin_intersect = max(box1.xmin, box2.xmin) + ymin_intersect = max(box1.ymin, box2.ymin) + xmax_intersect = min(box1.xmax, box2.xmax) + ymax_intersect = min(box1.ymax, box2.ymax) + if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: + area_intersect = 0 + else: + area_intersect = (xmax_intersect - xmin_intersect) * ( + ymax_intersect * ymin_intersect + ) + union = area1 + area2 - area_intersect + print(union) + iou = area_intersect / area1 + if iou > threshold: + return True + return False + + +DATA_ROOT = "/home/sds/hoangmd/TokenClassification/images/infer" +PSEUDO_LABEL = "/home/sds/hoangmd/TokenClassification/infer/" +list_files = glob.glob(PSEUDO_LABEL + "*.txt") + +for file in tqdm.tqdm(list_files): + xml_path = os.path.join("generated_label/", os.path.basename(file)[:-3] + "xml") + img_path = os.path.join(DATA_ROOT, os.path.basename(file)[:-3] + "jpg") + if not os.path.exists(img_path): + continue + f = open(file, "r", encoding="utf-8") + boxes = [] + for line in f.readlines(): + xmin, ymin, xmax, ymax, label = line.split("\t") + label = label[:-1] + box = Box() + box.xmin = int(float(xmin)) # left , top , right, bottom + box.ymin = int(float(ymin)) + box.xmax = int(float(xmax)) + box.ymax = int(float(ymax)) + box.label = label + boxes.append(box) + f.close() + 
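+    # sort boxes into reading order (top-to-bottom by ymin, then left-to-right by xmin) before writing the VOC-style XML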
boxes.sort(key=lambda x: [x.ymin, x.xmin]) + + boxes_to_xml(boxes, xml_path, img_path) diff --git a/cope2n-ai-fi/common/utils/utils.py b/cope2n-ai-fi/common/utils/utils.py new file mode 100755 index 0000000..f812c1a --- /dev/null +++ b/cope2n-ai-fi/common/utils/utils.py @@ -0,0 +1,180 @@ +import os +import json +import glob +import random +import cv2 + + +def read_txt(file): + with open(file, "r", encoding="utf8") as f: + data = [line.strip() for line in f] + return data + + +def write_txt(file, data): + with open(file, "w", encoding="utf8") as f: + for item in data: + f.write(item + "\n") + + +def write_json(file, data): + with open(file, "w", encoding="utf8") as f: + json.dump(data, f, ensure_ascii=False, sort_keys=True) + + +def read_json(file): + with open(file, "r", encoding="utf8") as f: + data = json.load(f) + return data + + +def get_colors(kie_labels): + + random.seed(1997) + colors = [] + for _ in range(len(kie_labels)): + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + colors.append(color) + + return colors + + +def normalize_box(box, width, height): + assert ( + max(box) <= width or max(box) <= height + ), "box must smaller than width, height; max box = {}, width = {}, height = {}".format( + max(box), width, height + ) + return [ + int(1000 * (box[0] / width)), + int(1000 * (box[1] / height)), + int(1000 * (box[2] / width)), + int(1000 * (box[3] / height)), + ] + + +def unnormalize_box(bbox, width, height): + return [ + width * (bbox[0] / 1000), + height * (bbox[1] / 1000), + width * (bbox[2] / 1000), + height * (bbox[3] / 1000), + ] + + +def load_image_paths_and_labels(data_dir): + r"""Load (image path, label) pairs into a DataFrame with keys ``image_path`` and ``label`` + + @todo Add OCR paths here + """ + + img_paths = [path for path in glob.glob(data_dir + "/*") if ".txt" not in path] + label_paths = [os.path.splitext(path)[0] + ".txt" for path in img_paths] + + return img_paths, label_paths + + +import cv2 + + +def read_image_file(img_path): + image = cv2.imread(img_path) + return image + + +def normalize_bbox(x1, y1, x2, y2, w, h): + x1 = int(float(min(max(0, x1), w))) + x2 = int(float(min(max(0, x2), w))) + y1 = int(float(min(max(0, y1), h))) + y2 = int(float(min(max(0, y2), h))) + return (x1, y1, x2, y2) + + +def extend_crop_img( + left, top, right, bottom, margin_l=0, margin_t=0.03, margin_r=0.02, margin_b=0.05 +): + top = top - (bottom - top) * margin_t + bottom = bottom + (bottom - top) * margin_b + left = left - (right - left) * margin_l + right = right + (right - left) * margin_r + return left, top, right, bottom + + +def get_crop_img_and_bbox(img, bbox, extend: bool): + """ + img : numpy array img + bbox : should be xyxy format + """ + if len(bbox) == 5: + left, top, right, bottom, _conf = bbox + elif len(bbox) == 4: + left, top, right, bottom = bbox + if extend: + left, top, right, bottom = extend_crop_img(left, top, right, bottom) + left, top, right, bottom = normalize_bbox( + left, top, right, bottom, img.shape[1], img.shape[0] + ) + assert (bottom - top) * (right - left) > 0, "bbox is invalid" + crop_img = img[top:bottom, left:right] + return crop_img, (left, top, right, bottom) + + +import json +import os + + + +def load_kie_labels_yolo(label_path): + with open(label_path, 'r') as f: + lines = f.read().splitlines() + words, boxes, labels = [], [], [] + for line in lines: + x1, y1, x2, y2, text, kie = line.split("\t") + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + if text != " ": + words.append(text) + 
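+            # keep the box and KIE label aligned with the accepted (non-blank) word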
boxes.append((x1, y1, x2, y2)) + labels.append(kie) + return words, boxes, labels + + +def create_empty_kie_dict(): + return {cfg.KIE_LABELS[i]: [] for i in range(len(cfg.KIE_LABELS)) if cfg.KIE_LABELS[i] != cfg.IGNORE_KIE_LABEL} + + +def write_to_json_(file_path, content): + with open(file_path, mode='w', encoding='utf8') as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, 'r') as f: + return json.load(f) + + +def get_name(file_path, ext: bool = True): + file_path_ = os.path.basename(file_path) + return file_path_ if ext else os.path.splitext(file_path_)[0] + + +def construct_file_path(dir, file_path, ext=''): + ''' + args: + dir: /path/to/dir + file_path /example_path/to/file.txt + ext = '.json' + return + /path/to/dir/file.json + ''' + return os.path.join( + dir, get_name(file_path, + True)) if ext == '' else os.path.join( + dir, get_name(file_path, + False)) + ext + + +def write_to_txt_(file_path, content): + with open(file_path, 'w') as f: + f.write(content) + + diff --git a/cope2n-ai-fi/common/utils/word_formation.py b/cope2n-ai-fi/common/utils/word_formation.py new file mode 100755 index 0000000..ed5b071 --- /dev/null +++ b/cope2n-ai-fi/common/utils/word_formation.py @@ -0,0 +1,599 @@ +from builtins import dict +from common.utils.global_variables import * + +MIN_IOU_HEIGHT = 0.7 +MIN_WIDTH_LINE_RATIO = 0.05 + + +class Word: + def __init__( + self, + text="", + image=None, + conf_detect=0.0, + conf_cls=0.0, + bndbox=None, + kie_label="", + ): + self.type = "word" + self.text = text + self.image = image + self.conf_detect = conf_detect + self.conf_cls = conf_cls + self.boundingbox = bndbox if bndbox is not None else [-1, -1, -1, -1]# [left, top,right,bot] coordinate of top-left and bottom-right point + self.word_id = 0 # id of word + self.word_group_id = 0 # id of word_group which instance belongs to + self.line_id = 0 # id of line which instance belongs to + self.paragraph_id = 0 # id of line which instance belongs to + self.kie_label = kie_label + + def invalid_size(self): + return (self.boundingbox[2] - self.boundingbox[0]) * ( + self.boundingbox[3] - self.boundingbox[1] + ) > 0 + + def is_special_word(self): + left, top, right, bottom = self.boundingbox + width, height = right - left, bottom - top + text = self.text + + if text is None: + return True + + if len(text) >= 7: + no_digits = sum(c.isdigit() for c in text) + return no_digits / len(text) >= 0.3 + + return False + + +class Word_group: + def __init__(self): + self.type = "word_group" + self.list_words = [] # dict of word instances + self.word_group_id = 0 # word group id + self.line_id = 0 # id of line which instance belongs to + self.paragraph_id = 0 # id of paragraph which instance belongs to + self.text = "" + self.boundingbox = [-1, -1, -1, -1] + self.kie_label = "" + + def add_word(self, word: Word): # add a word instance to the word_group + if word.text != "✪": + for w in self.list_words: + if word.word_id == w.word_id: + print("Word id collision") + return False + word.word_group_id = self.word_group_id # + word.line_id = self.line_id + word.paragraph_id = self.paragraph_id + self.list_words.append(word) + self.text += " " + word.text + if self.boundingbox == [-1, -1, -1, -1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [ + min(self.boundingbox[0], word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3]), + ] + return True + else: + 
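+            # "✪" is used as a placeholder token; such words are never added to a word group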
return False + + def update_word_group_id(self, new_word_group_id): + self.word_group_id = new_word_group_id + for i in range(len(self.list_words)): + self.list_words[i].word_group_id = new_word_group_id + + def update_kie_label(self): + list_kie_label = [word.kie_label for word in self.list_words] + dict_kie = dict() + for label in list_kie_label: + if label not in dict_kie: + dict_kie[label] = 1 + else: + dict_kie[label] += 1 + max_value = max(list(dict_kie.values())) + list_keys = list(dict_kie.keys()) + list_values = list(dict_kie.values()) + self.kie_label = list_keys[list_values.index(max_value)] + + def update_text(self): # update text after changing positions of words in list word + text = "" + for word in self.list_words: + text += " " + word.text + self.text = text + + +class Line: + def __init__(self): + self.type = "line" + self.list_word_groups = [] # list of Word_group instances in the line + self.line_id = 0 # id of line in the paragraph + self.paragraph_id = 0 # id of paragraph which instance belongs to + self.text = "" + self.boundingbox = [-1, -1, -1, -1] + + def add_group(self, word_group: Word_group): # add a word_group instance + if word_group.list_words is not None: + for wg in self.list_word_groups: + if word_group.word_group_id == wg.word_group_id: + print("Word_group id collision") + return False + + self.list_word_groups.append(word_group) + self.text += word_group.text + word_group.paragraph_id = self.paragraph_id + word_group.line_id = self.line_id + + for i in range(len(word_group.list_words)): + word_group.list_words[ + i + ].paragraph_id = self.paragraph_id # set paragraph_id for word + word_group.list_words[i].line_id = self.line_id # set line_id for word + return True + return False + + def update_line_id(self, new_line_id): + self.line_id = new_line_id + for i in range(len(self.list_word_groups)): + self.list_word_groups[i].line_id = new_line_id + for j in range(len(self.list_word_groups[i].list_words)): + self.list_word_groups[i].list_words[j].line_id = new_line_id + + def merge_word(self, word): # word can be a Word instance or a Word_group instance + if word.text != "✪": + if self.boundingbox == [-1, -1, -1, -1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [ + min(self.boundingbox[0], word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3]), + ] + self.list_word_groups.append(word) + self.text += " " + word.text + return True + return False + + def __cal_ratio(self, top1, bottom1, top2, bottom2): + sorted_vals = sorted([top1, bottom1, top2, bottom2]) + intersection = sorted_vals[2] - sorted_vals[1] + min_height = min(bottom1 - top1, bottom2 - top2) + if min_height == 0: + return -1 + ratio = intersection / min_height + return ratio + + def __cal_ratio_height(self, top1, bottom1, top2, bottom2): + + height1, height2 = top1 - bottom1, top2 - bottom2 + ratio_height = float(max(height1, height2)) / float(min(height1, height2)) + return ratio_height + + def in_same_line(self, input_line, thresh=0.7): + # calculate iou in vertical direction + _, top1, _, bottom1 = self.boundingbox + _, top2, _, bottom2 = input_line.boundingbox + + ratio = self.__cal_ratio(top1, bottom1, top2, bottom2) + ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2) + + if ( + (top1 in range(top2, bottom2) or top2 in range(top1, bottom1)) + and ratio >= thresh + and (ratio_height < 2) + ): + return True + return False + + +class Paragraph: 
+ def __init__(self, id=0, lines=None): + self.list_lines = ( + lines if lines is not None else [] + ) # list of all lines in the paragraph + self.paragraph_id = id # index of paragraph in the ist of paragraph + self.text = "" + self.boundingbox = [-1, -1, -1, -1] + + def add_line(self, line: Line): # add a line instance + if line.list_word_groups is not None: + for l in self.list_lines: + if line.line_id == l.line_id: + print("Line id collision") + return False + for i in range(len(line.list_word_groups)): + line.list_word_groups[ + i + ].paragraph_id = ( + self.paragraph_id + ) # set paragraph id for every word group in line + for j in range(len(line.list_word_groups[i].list_words)): + line.list_word_groups[i].list_words[ + j + ].paragraph_id = ( + self.paragraph_id + ) # set paragraph id for every word in word groups + line.paragraph_id = self.paragraph_id # set paragraph id for line + self.list_lines.append(line) # add line to paragraph + self.text += " " + line.text + return True + else: + return False + + def update_paragraph_id( + self, new_paragraph_id + ): # update new paragraph_id for all lines, word_groups, words inside paragraph + self.paragraph_id = new_paragraph_id + for i in range(len(self.list_lines)): + self.list_lines[ + i + ].paragraph_id = new_paragraph_id # set new paragraph_id for line + for j in range(len(self.list_lines[i].list_word_groups)): + self.list_lines[i].list_word_groups[ + j + ].paragraph_id = new_paragraph_id # set new paragraph_id for word_group + for k in range(len(self.list_lines[i].list_word_groups[j].list_words)): + self.list_lines[i].list_word_groups[j].list_words[ + k + ].paragraph_id = new_paragraph_id # set new paragraph id for word + return True + + +def resize_to_original( + boundingbox, scale +): # resize coordinates to match size of original image + left, top, right, bottom = boundingbox + left *= scale[1] + right *= scale[1] + top *= scale[0] + bottom *= scale[0] + return [left, top, right, bottom] + + +def check_iomin(word: Word, word_group: Word_group): + min_height = min( + word.boundingbox[3] - word.boundingbox[1], + word_group.boundingbox[3] - word_group.boundingbox[1], + ) + intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max( + word.boundingbox[1], word_group.boundingbox[1] + ) + if intersect / min_height > 0.7: + return True + return False + + +def prepare_line(words): + lines = [] + visited = [False] * len(words) + for id_word, word in enumerate(words): + if word.invalid_size() == 0: + continue + new_line = True + for i in range(len(lines)): + if ( + lines[i].in_same_line(word) and not visited[id_word] + ): # check if word is in the same line with lines[i] + lines[i].merge_word(word) + new_line = False + visited[id_word] = True + + if new_line == True: + new_line = Line() + new_line.merge_word(word) + lines.append(new_line) + + # print(len(lines)) + # sort line from top to bottom according top coordinate + lines.sort(key=lambda x: x.boundingbox[1]) + return lines + + +def __create_word_group(word, word_group_id): + new_word_group = Word_group() + new_word_group.word_group_id = word_group_id + new_word_group.add_word(word) + + return new_word_group + + +def __sort_line(line): + line.list_word_groups.sort( + key=lambda x: x.boundingbox[0] + ) # sort word in lines from left to right + + return line + + +def __merge_text_for_line(line): + line.text = "" + for word in line.list_word_groups: + line.text += " " + word.text + + return line + + +def __update_list_word_groups(line, word_group_id, word_id, line_width): + + 
old_list_word_group = line.list_word_groups
+    list_word_groups = []
+
+    initial_word_group = __create_word_group(old_list_word_group[0], word_group_id)
+    old_list_word_group[0].word_id = word_id
+    list_word_groups.append(initial_word_group)
+    word_group_id += 1
+    word_id += 1
+
+    for word in old_list_word_group[1:]:
+        check_word_group = True
+        word.word_id = word_id
+        word_id += 1
+
+        # merge into the previous word group when it does not end with ":", the
+        # horizontal gap is small relative to the line width, and the boxes overlap vertically
+        if (
+            (not list_word_groups[-1].text.endswith(":"))
+            and (
+                (word.boundingbox[0] - list_word_groups[-1].boundingbox[2]) / line_width
+                < MIN_WIDTH_LINE_RATIO
+            )
+            and check_iomin(word, list_word_groups[-1])
+        ):
+            list_word_groups[-1].add_word(word)
+            check_word_group = False
+
+        if check_word_group:
+            new_word_group = __create_word_group(word, word_group_id)
+            list_word_groups.append(new_word_group)
+            word_group_id += 1
+    line.list_word_groups = list_word_groups
+    return line, word_group_id, word_id
+
+
+def construct_word_groups_in_each_line(lines):
+    line_id = 0
+    word_group_id = 0
+    word_id = 0
+    for i in range(len(lines)):
+        if len(lines[i].list_word_groups) == 0:
+            continue
+
+        # left, top, right, bottom
+        line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0]  # right - left
+
+        lines[i] = __sort_line(lines[i])
+
+        # update text for lines after sorting
+        lines[i] = __merge_text_for_line(lines[i])
+
+        lines[i], word_group_id, word_id = __update_list_word_groups(
+            lines[i], word_group_id, word_id, line_width
+        )
+        lines[i].update_line_id(line_id)
+        line_id += 1
+    return lines
+
+
+def words_to_lines(words, check_special_lines=True):  # words is list of Word instance
+    # sort words by top, then left coordinate
+    words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0]))
+    number_of_word = len(words)
+    # print(number_of_word)
+    # group the sorted words into lines (word_groups are not constructed yet)
+    lines = prepare_line(words)
+
+    # construct word_groups in each line
+    lines = construct_word_groups_in_each_line(lines)
+    return lines, number_of_word
+
+
+def near(word_group1: Word_group, word_group2: Word_group):
+    min_height = min(
+        word_group1.boundingbox[3] - word_group1.boundingbox[1],
+        word_group2.boundingbox[3] - word_group2.boundingbox[1],
+    )
+    overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max(
+        word_group1.boundingbox[1], word_group2.boundingbox[1]
+    )
+
+    if overlap > 0:
+        return True
+    if abs(overlap / min_height) < 1.5:
+        print("near enough", abs(overlap / min_height), overlap, min_height)
+        return True
+    return False
+
+
+def calculate_iou_and_near(wg1: Word_group, wg2: Word_group):
+    min_height = min(
+        wg1.boundingbox[3] - wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1]
+    )
+    overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max(
+        wg1.boundingbox[1], wg2.boundingbox[1]
+    )
+    iou = overlap / min_height
+    distance = min(
+        abs(wg1.boundingbox[0] - wg2.boundingbox[2]),
+        abs(wg1.boundingbox[2] - wg2.boundingbox[0]),
+    )
+    if iou > 0.7 and distance < 0.5 * (wg1.boundingbox[2] - wg1.boundingbox[0]):
+        return True
+    return False
+
+
+def construct_word_groups_to_kie_label(list_word_groups: list):
+    kie_dict = dict()
+    for wg in list_word_groups:
+        if wg.kie_label == "other":
+            continue
+        if wg.kie_label not in kie_dict:
+            kie_dict[wg.kie_label] = [wg]
+        else:
+            kie_dict[wg.kie_label].append(wg)
+
+    new_dict = dict()
+    for key, value in kie_dict.items():
+        if len(value) == 1:
+            new_dict[key] = value
+            continue
+
+        value.sort(key=lambda x: x.boundingbox[1])
+        new_dict[key] = value
+    return new_dict
+
+
+def
invoice_construct_word_groups_to_kie_label(list_word_groups: list): + kie_dict = dict() + + for wg in list_word_groups: + if wg.kie_label == "other": + continue + if wg.kie_label not in kie_dict: + kie_dict[wg.kie_label] = [wg] + else: + kie_dict[wg.kie_label].append(wg) + + return kie_dict + + +def postprocess_total_value(kie_dict): + if "total_in_words_value" not in kie_dict: + return kie_dict + + for k, value in kie_dict.items(): + if k == "total_in_words_value": + continue + l = [] + for v in value: + if v.boundingbox[3] <= kie_dict["total_in_words_value"][0].boundingbox[3]: + l.append(v) + + if len(l) != 0: + kie_dict[k] = l + + return kie_dict + + +def postprocess_tax_code_value(kie_dict): + if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict: + return kie_dict + + kie_dict["buyer_tax_code_value"] = [] + for v in kie_dict["seller_tax_code_value"]: + if "buyer_name_key" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_key"][0]) + ): + kie_dict["buyer_tax_code_value"].append(v) + continue + + if "buyer_name_value" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_value"][0]) + ): + kie_dict["buyer_tax_code_value"].append(v) + continue + + if "buyer_address_value" in kie_dict and near( + kie_dict["buyer_address_value"][0], v + ): + kie_dict["buyer_tax_code_value"].append(v) + return kie_dict + + +def postprocess_tax_code_key(kie_dict): + if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict: + return kie_dict + kie_dict["buyer_tax_code_key"] = [] + for v in kie_dict["seller_tax_code_key"]: + if "buyer_name_key" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_key"][0]) + ): + kie_dict["buyer_tax_code_key"].append(v) + continue + + if "buyer_name_value" in kie_dict and ( + v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] + or near(v, kie_dict["buyer_name_value"][0]) + ): + kie_dict["buyer_tax_code_key"].append(v) + continue + + if "buyer_address_value" in kie_dict and near( + kie_dict["buyer_address_value"][0], v + ): + kie_dict["buyer_tax_code_key"].append(v) + + return kie_dict + + +def invoice_postprocess(kie_dict: dict): + # all keys or values which are below total_in_words_value will be thrown away + kie_dict = postprocess_total_value(kie_dict) + kie_dict = postprocess_tax_code_value(kie_dict) + kie_dict = postprocess_tax_code_key(kie_dict) + return kie_dict + + +def throw_overlapping_words(list_words): + new_list = [list_words[0]] + for word in list_words: + overlap = False + area = (word.boundingbox[2] - word.boundingbox[0]) * ( + word.boundingbox[3] - word.boundingbox[1] + ) + for word2 in new_list: + area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * ( + word2.boundingbox[3] - word2.boundingbox[1] + ) + xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0]) + xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2]) + ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1]) + ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3]) + if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: + continue + + area_intersect = (xmax_intersect - xmin_intersect) * ( + ymax_intersect - ymin_intersect + ) + if area_intersect / area > 0.7 or area_intersect / area2 > 0.7: + overlap = True + if overlap == False: + new_list.append(word) + return new_list + + +class Box: + 
def __init__(self, xmin=0, ymin=0, xmax=0, ymax=0, label="", kie_label=""):
+        self.xmax = xmax
+        self.ymax = ymax
+        self.xmin = xmin
+        self.ymin = ymin
+        self.label = label
+        self.kie_label = kie_label
+
+
+def check_iou(box1: Word, box2: Box, threshold=0.9):
+    area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * (
+        box1.boundingbox[3] - box1.boundingbox[1]
+    )
+    area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin)
+    xmin_intersect = max(box1.boundingbox[0], box2.xmin)
+    ymin_intersect = max(box1.boundingbox[1], box2.ymin)
+    xmax_intersect = min(box1.boundingbox[2], box2.xmax)
+    ymax_intersect = min(box1.boundingbox[3], box2.ymax)
+    if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect:
+        area_intersect = 0
+    else:
+        area_intersect = (xmax_intersect - xmin_intersect) * (
+            ymax_intersect - ymin_intersect
+        )
+    union = area1 + area2 - area_intersect
+    iou = area_intersect / union
+    if iou > threshold:
+        return True
+    return False
diff --git a/cope2n-ai-fi/common/utils_invoice/load_model.py b/cope2n-ai-fi/common/utils_invoice/load_model.py
new file mode 100755
index 0000000..915a2f3
--- /dev/null
+++ b/cope2n-ai-fi/common/utils_invoice/load_model.py
@@ -0,0 +1,61 @@
+from torch import nn
+import torch
+from transformers import (
+    LayoutXLMTokenizer,
+    LayoutLMv2FeatureExtractor,
+    LayoutXLMProcessor,
+    LayoutLMv2ForTokenClassification,
+)
+
+
+class PositionalEncoding(nn.Module):
+    """Positional encoding."""
+
+    def __init__(self, num_hiddens, max_len=10000):
+        super(PositionalEncoding, self).__init__()
+        # Create a long enough `P`
+        self.num_hiddens = num_hiddens
+
+    def forward(self, inputs):
+        max_len = inputs.shape[1]
+        P = torch.zeros((1, max_len, self.num_hiddens))
+        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
+            10000,
+            torch.arange(0, self.num_hiddens, 2, dtype=torch.float32)
+            / self.num_hiddens,
+        )
+        P[:, :, 0::2] = torch.sin(X)
+        P[:, :, 1::2] = torch.cos(X)
+        return P.to(inputs.device)
+
+
+def load_layoutlmv2_custom_model(
+    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
+):
+
+    model, processor = load_layoutlmv2(weight_dir, tokenizer_dir, max_seq_len, classes)
+    # fix for longer sequence length: replace the learned position embeddings
+    model.layoutlmv2.embeddings.position_embeddings = PositionalEncoding(
+        num_hiddens=768, max_len=max_seq_len
+    )
+    model.layoutlmv2.embeddings.max_position_embeddings = max_seq_len
+    model.config.max_position_embeddings = max_seq_len
+    model.layoutlmv2.embeddings.register_buffer(
+        "position_ids", torch.arange(max_seq_len).expand((1, -1))
+    )
+
+    return model, processor
+
+
+def load_layoutlmv2(
+    weight_dir: str, tokenizer_dir: str, max_seq_len: int, classes: list
+):
+    tokenizer = LayoutXLMTokenizer.from_pretrained(
+        pretrained_model_name_or_path=tokenizer_dir, model_max_length=max_seq_len
+    )
+    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
+    processor = LayoutXLMProcessor(feature_extractor, tokenizer)
+    model = LayoutLMv2ForTokenClassification.from_pretrained(
+        weight_dir, num_labels=len(classes)
+    )
+    return model, processor
diff --git a/cope2n-ai-fi/common/utils_invoice/run_ocr.py b/cope2n-ai-fi/common/utils_invoice/run_ocr.py
new file mode 100755
index 0000000..01b3df9
--- /dev/null
+++ b/cope2n-ai-fi/common/utils_invoice/run_ocr.py
@@ -0,0 +1,13 @@
+from ..utils.ocr_yolox import OcrEngineForYoloX_Invoice
+
+
+det_ckpt = "yolox-s-general-text-pretrain-20221226"
+cls_ckpt = "satrn-lite-general-pretrain-20230106"
+
+ocr_engine = OcrEngineForYoloX_Invoice(det_ckpt, cls_ckpt)
+
+
+def
ocr_predict(image_url): + + bboxes, texts = ocr_engine.run_image(image_url) + return bboxes, texts \ No newline at end of file diff --git a/cope2n-ai-fi/common/utils_kvu/split_docs.py b/cope2n-ai-fi/common/utils_kvu/split_docs.py new file mode 100755 index 0000000..cf0d16f --- /dev/null +++ b/cope2n-ai-fi/common/utils_kvu/split_docs.py @@ -0,0 +1,149 @@ +import os +import glob +import json +from tqdm import tqdm + +def longestCommonSubsequence(text1: str, text2: str) -> int: + # https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis + dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)] + for i, c in enumerate(text1): + for j, d in enumerate(text2): + dp[i + 1][j + 1] = 1 + \ + dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j]) + return dp[-1][-1] + +def write_to_json(file_path, content): + with open(file_path, mode="w", encoding="utf8") as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, "r") as f: + return json.load(f) + +def check_label_exists(array, target_label): + for obj in array: + if obj["label"] == target_label: + return True # Label exists in the array + return False # Label does not exist in the array + +def merged_kvu_outputs(loutputs: list) -> dict: + compiled = [] + for output_model in loutputs: + for field in output_model: + if field['value'] != "" and not check_label_exists(compiled, field['label']): + element = { + 'label': field['label'], + 'value': field['value'], + } + compiled.append(element) + elif field['label'] == 'table' and check_label_exists(compiled, "table"): + for index, obj in enumerate(compiled): + if obj['label'] == 'table' and len(field['value']) > 0: + compiled[index]['value'].append(field['value']) + return compiled + + +def split_docs(doc_data: list, threshold: float=0.6) -> list: + num_pages = len(doc_data) + outputs = [] + kvu_content = [] + doc_data = sorted(doc_data, key=lambda x: int(x['page_number'])) + for data in doc_data: + page_id = int(data['page_number']) + doc_type = data['document_type'] + doc_class = data['document_class'] + fields = data['fields'] + if page_id == 0: + prev_title = doc_type + start_page_id = page_id + prev_class = doc_class + curr_title = doc_type if doc_type != "unknown" else prev_title + curr_class = doc_class if doc_class != "unknown" else "other" + kvu_content.append(fields) + similarity_score = longestCommonSubsequence(curr_title, prev_title) / len(prev_title) + if similarity_score < threshold: + end_page_id = page_id - 1 + outputs.append({ + "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title, + "start_page": start_page_id, + "end_page": end_page_id, + "content": merged_kvu_outputs(kvu_content[:-1]) + }) + prev_title = curr_title + prev_class = curr_class + start_page_id = page_id + kvu_content = kvu_content[-1:] + if page_id == num_pages - 1: # end_page + outputs.append({ + "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title, + "start_page": start_page_id, + "end_page": page_id, + "content": merged_kvu_outputs(kvu_content) + }) + elif page_id == num_pages - 1: # end_page + outputs.append({ + "doc_type": f"({prev_class}) {prev_title}" if prev_class != "other" else prev_title, + "start_page": start_page_id, + "end_page": page_id, + "content": merged_kvu_outputs(kvu_content) + }) + return outputs + + +def merge_sbt_output(loutputs): + # TODO: This function is too circumlocutory, need to 
refactor the whole flow + def dict_to_list_of_dict(the_dict): + output = [] + for k,v in the_dict.items(): + output.append({ + 'label': k, + 'value': v, + }) + return output + + merged_output = [] + combined_output = {"retailername": None, + "sold_to_party": None, + "purchase_date": [], + "imei_number": []} # place holder for the output + for output in loutputs: + fields = output['fields'] + if "doc_type" not in output: # Should not contain more than 1 page + for field in fields: + combined_output[field["label"]] = field["value"] + combined_output["imei_number"] = [combined_output["imei_number"]] + break + else: + if output['doc_type'] == "imei": + for field in fields: + if field["label"] == "imei_number": + combined_output[field["label"]].append(field["value"]) + if output['doc_type'] == "invoice": + for field in fields: + if field["label"] in ["retailername", "sold_to_party", "purchase_date"] : + if isinstance(combined_output[field["label"]], list): + if field["value"] is not None: + if isinstance(field["value"], list): + combined_output[field["label"]] += field["value"] + else: + combined_output[field["label"]].append(field["value"]) + else: + combined_output[field["label"]] = field["value"] + + merged_output.append({ + "doc_type": "sbt_document", + "start_page": 1, + "end_page": len(loutputs), + "content": dict_to_list_of_dict(combined_output) + }) + return merged_output + +if __name__ == "__main__": + threshold = 0.9 + json_path = "/home/sds/tuanlv/02-KVU/02-KVU_test/visualize/manulife_v2/json_outputs/HS_YCBT_No_IP_HMTD.json" + doc_data = read_json(json_path) + + outputs = split_docs(doc_data, threshold) + + write_to_json(os.path.join(os.path.dirname(json_path), "splited_doc.json"), outputs) \ No newline at end of file diff --git a/cope2n-ai-fi/common/utils_ocr/create_kie_labels.py b/cope2n-ai-fi/common/utils_ocr/create_kie_labels.py new file mode 100755 index 0000000..2199ff1 --- /dev/null +++ b/cope2n-ai-fi/common/utils_ocr/create_kie_labels.py @@ -0,0 +1,67 @@ +# %% +# from pathlib import Path # add Fiintrade path to import config, required to run main() +import sys + +# TODO: Why??? for what reason ??????????????? +sys.path.append(".") # add Fiintrade/ to path + + +from srcc.tools.utils import ( + load_kie_labels_yolo, + create_empty_kie_dict, + write_to_json_, + load_train_val_id_cards, +) +import glob +from OCRBase.config import config as cfg +import os +import pandas as pd + +sys.path.append("/home/sds/hoangmd/TokenClassification") # TODO: Why there are bunch of absolute path here +from src.experiments.word_formation import * +from process_label import * + +KIE_LABEL_DIR = "data/label/207/kie" +KIE_LABEL_LINE_PATH = "/home/sds/hungbnt/KIE_pretrained/data/label/207/json" # TODO: Absolute path ????? 
+ +# %% + + +def create_kie_dict(list_words): + kie_dict = create_empty_kie_dict() + list_words = throw_overlapping_words(list_words) + for word in list_words: + if word.kie_label in kie_dict: + kie_dict[word.kie_label].append(word) + word.text = word.text.strip() + for kie_label in kie_dict: + list_lines, _ = words_to_lines(kie_dict[kie_label]) + kie_dict[kie_label] = "\n ".join([line.text.strip() for line in list_lines]) + return kie_dict + + +# %% + + +def main(): + label_paths = glob.glob(f"{KIE_LABEL_DIR}/*.txt") + for label_path in label_paths: + words, bboxes, kie_labels = load_kie_labels_yolo(label_path) + list_words = [] + for i, kie_label in enumerate(kie_labels): + list_words.append( + Word(text=words[i], bndbox=bboxes[i], kie_label=kie_label) + ) + + kie_dict = create_kie_dict(list_words) + kie_path = os.path.join( + KIE_LABEL_LINE_PATH, os.path.basename(label_path).replace(".txt", ".json") + ) + write_to_json_(kie_path, kie_dict) + + +# %% + + +if __name__ == "__main__": + main() diff --git a/cope2n-ai-fi/configs/config_id_dr/__init__.py b/cope2n-ai-fi/configs/config_id_dr/__init__.py new file mode 100755 index 0000000..dc8da6b --- /dev/null +++ b/cope2n-ai-fi/configs/config_id_dr/__init__.py @@ -0,0 +1,8 @@ +from ...common.configs.config import BASE_CONFIG, V2, V3, ID_CARD + +__mapping__ = { + "base": BASE_CONFIG, + "v2": V2, + "v3": V3, + "id_card": ID_CARD, +} diff --git a/cope2n-ai-fi/configs/config_id_dr/config.py b/cope2n-ai-fi/configs/config_id_dr/config.py new file mode 100755 index 0000000..d6853a9 --- /dev/null +++ b/cope2n-ai-fi/configs/config_id_dr/config.py @@ -0,0 +1,212 @@ +# GLOBAL VARS +DEVICE = "cuda:0" +IGNORE_KIE_LABEL = "others" +KIE_LABELS = [ + "id", + "name", + "dob", + "home", + "add", + "sex", + "nat", + "exp", + "eth", + "rel", + "date", + "org", + IGNORE_KIE_LABEL, + "rank", +] +SEED = 42 +NAME_LABEL = "microsoft/layoutxlm-base" +########################################## +BASE_CONFIG = { + "global": { + "device": DEVICE, + "kie_labels": KIE_LABELS, + }, + "data": { + "custom": True, + "path": "src/custom/load_data.py", + "method": "load_data", + "train_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/synthesis_for_train/", + "val_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/SDV_Meddoc_BirthCert/", + "max_seq_len": 512, + "batch_size": 8, + "pretrained_processor": NAME_LABEL, + "kie_labels": KIE_LABELS, + "device": DEVICE, + }, + "model": { + "custom": True, + "path": "src/custom/load_model.py", + "method": "load_model", + "pretrained_model": NAME_LABEL, + "kie_labels": KIE_LABELS, + "device": DEVICE, + }, + "optimizer": { + "custom": True, + "path": "src/custom/load_optimizer.py", + "method": "load_optimizer", + "lr": 5e-6, + "weight_decay": 0, + "betas": (0.9, 0.999), + }, + "trainer": { + "custom": True, + "path": "src/custom/load_trainer.py", + "method": "load_trainer", + "kie_labels": KIE_LABELS, + "save_dir": "weights", + "n_epoches": 100, + }, +} + +ID_CARD = BASE_CONFIG +ID_CARD["data"] = { + "custom": True, + "path": "src/custom/load_data_id_card.py", + "method": "load_data", + "train_path": "/home/sds/hungbnt/KIE_pretrained/data/207/idcard_cmnd_8-9-2022", + "label_path": "/home/sds/hungbnt/KIE_pretrained/data/207/label/", + "max_seq_len": 512, + "batch_size": 8, + "pretrained_processor": NAME_LABEL, + "kie_labels": KIE_LABELS, + "device": DEVICE, +} + + +# GLOBAL VARS +DEVICE = "cuda:1" +# DEVICE = "cpu" +# DEVICE = "cpu" # for debugging 
https://stackoverflow.com/questions/51691563/cuda-runtime-error-59-device-side-assert-triggered +# DEVICE = "cpu" +# KIE_LABELS = ['gen', 'nk', 'nv', 'dobk', 'dobv', 'other'] +IGNORE_KIE_LABEL = 'others' +# KIE_LABELS = ['id', 'name', 'dob', 'home', 'add', 'sex', 'nat', 'exp', 'eth', 'rel', 'date', 'org', IGNORE_KIE_LABEL] +# KIE_WEIGHTS = "/home/sds/hungbnt/KIE_pretrained/weights/ID_CARD_145_train_300_val_0.02_char_0.06_word" +# TODO: current yield index error if pass to gplx['data]['kie_label] (maybe mismatch with somewhere else) => fix this so that kie_label in gplx can be made global +KIE_LABELS = ['id', 'name', 'dob', 'home', 'add', 'sex', 'nat', + 'exp', 'eth', 'rel', 'date', 'org', IGNORE_KIE_LABEL, 'rank'] +KIE_WEIGHTS = 'weights/driver_license' +SEED = 42 + +########################################## +BASE_CONFIG = { + 'global': { + 'device': DEVICE, + 'kie_labels': KIE_LABELS, + }, + "data": { + "custom": True, + "path": "src/custom/load_data.py", + "method": "load_data", + "train_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/synthesis_for_train/", + "val_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/SDV_Meddoc_BirthCert/", + # "size": 320, + "max_seq_len": 512, + "batch_size": 8, + # "workers": 10, + 'pretrained_processor': 'microsoft/layoutxlm-base', + 'kie_labels': KIE_LABELS, + 'device': DEVICE, + }, + + "model": { + "custom": True, + "path": "src/custom/load_model.py", + "method": "load_model", + "pretrained_model": 'microsoft/layoutxlm-base', + 'kie_labels': KIE_LABELS, + 'device': DEVICE, + }, + + "optimizer": { + "custom": True, + "path": "src/custom/load_optimizer.py", + "method": "load_optimizer", + "lr": 5e-6, + "weight_decay": 0, # default = 0 + "betas": (0.9, 0.999), # beta1 in transformer, default = 0.9 + }, + + "trainer": { + "custom": True, + "path": "src/custom/load_trainer.py", + "method": "load_trainer", + "kie_labels": KIE_LABELS, + "save_dir": 'weights', + "n_epoches": 100, + }, +} + +V2 = BASE_CONFIG +# V2['data'] = { +# "custom": True, +# "pretrained_model": 'microsoft/layoutxlm-base', +# 'kie_labels': KIE_LABELS, +# 'device': DEVICE, +# } + +V3 = BASE_CONFIG +# V3["data"] = { +# "custom": True, +# "path": "src/custom/load_data_v3.py", +# "method": "load_data", +# "train_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/synthesis_for_train/", +# "val_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/SDV_Meddoc_BirthCert/", +# # "size": 320, +# "max_seq_len": 512, +# "batch_size": 8, +# # "workers": 10, +# 'pretrained_processor': "microsoft/layoutlmv3-base", +# 'kie_labels': KIE_LABELS, +# 'device': DEVICE, +# } +# V3['model'] = {; +# "custom": False, +# 'name': 'layoutlm_v3', +# "pretrained_model": 'microsoft/layoutlmv3-base', +# 'kie_labels': KIE_LABELS, +# 'device': DEVICE, +# } + +ID_CARD = BASE_CONFIG +ID_CARD['data'] = { + "custom": True, + "path": "src/custom/load_data_id_card.py", + "method": "load_data", + "train_path": "/home/sds/hungbnt/KIE_pretrained/data/207/idcard_cmnd_8-9-2022", + "label_path": "/home/sds/hungbnt/KIE_pretrained/data/207/label/", + # "size": 320, + "max_seq_len": 512, + "batch_size": 8, + # "workers": 10, + 'pretrained_processor': 'microsoft/layoutxlm-base', + 'kie_labels': KIE_LABELS, + 'device': DEVICE, +} + + +GPLX = BASE_CONFIG +GPLX['data'] = { + "custom": True, + "path": "srcc/custom/load_data_gplx.py", + "method": "load_data", + "train_path": "/home/sds/hungbnt/KIE_pretrained/data/GPLX/train/crop_blx_10_10_2022", + "val_path": 
"/home/sds/hungbnt/KIE_pretrained/data/GPLX/val/crop_blx_5_10_2022", + "train_label_path": "/home/sds/hungbnt/KIE_pretrained/data/label/GPLX/kie/train", + "val_label_path": "/home/sds/hungbnt/KIE_pretrained/data/label/GPLX/kie/val", + # "size": 320, + "max_seq_len": 512, + "batch_size": 8, + # "workers": 10, + 'pretrained_processor': 'microsoft/layoutxlm-base', + 'kie_labels': KIE_LABELS, + 'device': DEVICE, +} + + + diff --git a/cope2n-ai-fi/configs/config_invoice/layoutxlm_base_invoice.py b/cope2n-ai-fi/configs/config_invoice/layoutxlm_base_invoice.py new file mode 100755 index 0000000..292d85c --- /dev/null +++ b/cope2n-ai-fi/configs/config_invoice/layoutxlm_base_invoice.py @@ -0,0 +1,67 @@ +CONFIF_PATH = __file__ +TRAIN_DIR = "/home/sds/hoanglv/Projects/TokenClassification_invoice/DATA/train" +TEST_DIR = "/home/sds/hoanglv/Projects/TokenClassification_invoice/DATA/test" +TOKENIZER_DIR = "Kie_Hoanglv/model/layoutxlm-base-tokenizer" +TOKENIZER_NAME = "microsoft/layoutxlm-base" +MODEL_WEIGHT = "microsoft/layoutxlm-base" +# pretrained model hyperparameter +MAX_SEQ_LENGTH = 512 +IMG_SIZE = 224 # default + +VN_list_char = "aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!#$%&()*+,-./:;<=>?@[\]^_`{|}~" + +DEVICE = "cuda:0" +SAVE_DIR = "runs/layoutxlm-base-17-10-2022-maxwords150_samplingv2" +BATCH_SIZE = 8 +NUM_WORKER = 0 +EPOCHS = 100 +SAVE_INTERVAL = 1000 +LR_RATE = 5e-6 # ori: 5e-5 + +# infer +MAX_N_WORDS = 150 +TRAINED_DIR = "Kie_Hoanglv/model/layoutxlm-base-17-10-2022-maxwords150_samplingv2/last" +PRED_DIR = "/home/sds/hoanglv/Projects/TokenClassification_invoice/runs/infer/kie_e2e_pred_17-10-2022-maxwords150_samplingv2_rm_dup_boxes_test" +VISUALIZE_DIR = PRED_DIR + "/visualize" + +KIE_LABELS = [ + # id invoice + "no_key", + "no_value", + "form_key", + "form_value", + "serial_key", + "serial_value", + "date", + # seller info + "seller_company_name_key", + "seller_company_name_value", + "seller_tax_code_key", + "seller_tax_code_value", + "seller_address_value", + "seller_address_key", + "seller_mobile_key", + "seller_mobile_value", + # buyer info + "buyer_name_key", + "buyer_name_value", + "buyer_company_name_value", + "buyer_company_name_key", + "buyer_tax_code_key", + "buyer_tax_code_value", + "buyer_address_key", + "buyer_address_value", + "buyer_mobile_key", + "buyer_mobile_value", + # money info + "VAT_amount_key", + "VAT_amount_value", + "total_key", + "total_value", + "total_in_words_key", + "total_in_words_value", + "other", +] + + +SKIP_LABEL_EVAL = ["buyer_mobile_value"] diff --git a/cope2n-ai-fi/configs/config_ocr/__init__.py b/cope2n-ai-fi/configs/config_ocr/__init__.py new file mode 100755 index 0000000..9f5a9de --- /dev/null +++ b/cope2n-ai-fi/configs/config_ocr/__init__.py @@ -0,0 +1,8 @@ +from .config import BASE_CONFIG, V2, V3, ID_CARD + +__mapping__ = { + "base": BASE_CONFIG, + "v2": V2, + "v3": V3, + "id_card": ID_CARD, +} diff --git a/cope2n-ai-fi/configs/config_ocr/config.py b/cope2n-ai-fi/configs/config_ocr/config.py new file mode 100755 index 0000000..a29dfc8 --- /dev/null +++ b/cope2n-ai-fi/configs/config_ocr/config.py @@ -0,0 +1,79 @@ +# GLOBAL VARS +DEVICE = "cuda:0" +IGNORE_KIE_LABEL = "others" +KIE_LABELS = [ + "id", + "name", + "dob", + "home", + "add", + "sex", + "nat", + "exp", + "eth", + "rel", + "date", + "org", + IGNORE_KIE_LABEL, +] + +SEED = 42 +NAME_LABEL = "microsoft/layoutxlm-base" + 
+########################################## +BASE_CONFIG = { + "global": { + "device": DEVICE, + "kie_labels": KIE_LABELS, + }, + "data": { + "custom": True, + "path": "src/custom/load_data.py", + "method": "load_data", + "train_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/synthesis_for_train/", + "val_path": "/home/sds/hoangmd/TokenClassification_copy/giaykhaisinh/SDV_Meddoc_BirthCert/", + "max_seq_len": 512, + "batch_size": 8, + "pretrained_processor": NAME_LABEL , + "kie_labels": KIE_LABELS, + "device": DEVICE, + }, + "model": { + "custom": True, + "path": "src/custom/load_model.py", + "method": "load_model", + "pretrained_model": NAME_LABEL, + "kie_labels": KIE_LABELS, + "device": DEVICE, + }, + "optimizer": { + "custom": True, + "path": "src/custom/load_optimizer.py", + "method": "load_optimizer", + "lr": 5e-6, + "weight_decay": 0, + "betas": (0.9, 0.999), + }, + "trainer": { + "custom": True, + "path": "src/custom/load_trainer.py", + "method": "load_trainer", + "kie_labels": KIE_LABELS, + "save_dir": "weights", + "n_epoches": 100, + }, +} + +ID_CARD = BASE_CONFIG +ID_CARD["data"] = { + "custom": True, + "path": "src/custom/load_data_id_card.py", + "method": "load_data", + "train_path": "/home/sds/hungbnt/KIE_pretrained/data/207/idcard_cmnd_8-9-2022", + "label_path": "/home/sds/hungbnt/KIE_pretrained/data/207/label/", + "max_seq_len": 512, + "batch_size": 8, + "pretrained_processor": NAME_LABEL, + "kie_labels": KIE_LABELS, + "device": DEVICE, +} diff --git a/cope2n-ai-fi/configs/default_env.py b/cope2n-ai-fi/configs/default_env.py new file mode 100644 index 0000000..f3fe0b6 --- /dev/null +++ b/cope2n-ai-fi/configs/default_env.py @@ -0,0 +1,3 @@ +CELERY_BROKER = "" +SAP_KIE_MODEL = "" +FI_KIE_MODEL = "" diff --git a/cope2n-ai-fi/configs/manulife/__init__.py b/cope2n-ai-fi/configs/manulife/__init__.py new file mode 100644 index 0000000..624fa5c --- /dev/null +++ b/cope2n-ai-fi/configs/manulife/__init__.py @@ -0,0 +1,3 @@ +from .configs import device +from .configs import ocr_engine as ocr_cfg +from .configs import kvu_model as kvu_cfg \ No newline at end of file diff --git a/cope2n-ai-fi/configs/manulife/configs.py b/cope2n-ai-fi/configs/manulife/configs.py new file mode 100644 index 0000000..4e73b92 --- /dev/null +++ b/cope2n-ai-fi/configs/manulife/configs.py @@ -0,0 +1,35 @@ +device = "cuda:0" +ocr_engine = { + "detector": { + "version": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsvtd/epoch_100_params.pth", + "rotator_version": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsvtd/best_bbox_mAP_epoch_30_lite.pth", + "device": device + }, + "recognizer": { + "version": "/workspace/cope2n-ai-fi/weights/models/sdsvtr/hub/jxqhbem4to.pth", + "device": device + }, + "deskew": { + "enable": True, + "text_detector": { + "config": "/workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/det.yaml", + "weight": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsv_dewarp/ch_PP-OCRv3_det_infer" + }, + "text_cls": { + "config": "/workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/cls.yaml", + "weight": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsv_dewarp/ch_ppocr_mobile_v2.0_cls_infer" + }, + "device": device + } +} + +kvu_model = { + "device": device, + "mode": 3, + "option": "manulife", + "model": { + "pretrained_model_path": "/workspace/cope2n-ai-fi/weights/layoutxlm-base", + "config": "/workspace/cope2n-ai-fi/weights/models/sdsvkvu/key_value_understanding-20231024-125646_manulife2/base.yaml", + "checkpoint": 
"/workspace/cope2n-ai-fi/weights/models/sdsvkvu/key_value_understanding-20231024-125646_manulife2/checkpoints/best_model.pth" + } +} \ No newline at end of file diff --git a/cope2n-ai-fi/configs/sdsap_sbt/__init__.py b/cope2n-ai-fi/configs/sdsap_sbt/__init__.py new file mode 100644 index 0000000..624fa5c --- /dev/null +++ b/cope2n-ai-fi/configs/sdsap_sbt/__init__.py @@ -0,0 +1,3 @@ +from .configs import device +from .configs import ocr_engine as ocr_cfg +from .configs import kvu_model as kvu_cfg \ No newline at end of file diff --git a/cope2n-ai-fi/configs/sdsap_sbt/configs.py b/cope2n-ai-fi/configs/sdsap_sbt/configs.py new file mode 100644 index 0000000..e23e9c1 --- /dev/null +++ b/cope2n-ai-fi/configs/sdsap_sbt/configs.py @@ -0,0 +1,35 @@ +device = "cuda:0" +ocr_engine = { + "detector": { + "version": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsvtd/epoch_100_params.pth", + "rotator_version": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsvtd/best_bbox_mAP_epoch_30_lite.pth", + "device": device + }, + "recognizer": { + "version": "/workspace/cope2n-ai-fi/weights/models/sdsvtr/hub/jxqhbem4to.pth", + "device": device + }, + "deskew": { + "enable": True, + "text_detector": { + "config": "/workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/det.yaml", + "weight": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsv_dewarp/ch_PP-OCRv3_det_infer" + }, + "text_cls": { + "config": "/workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/cls.yaml", + "weight": "/workspace/cope2n-ai-fi/weights/models/ocr_engine/sdsv_dewarp/ch_ppocr_mobile_v2.0_cls_infer" + }, + "device": device + } +} + +kvu_model = { + "device": device, + "mode": 4, + "option": "sbt_v2", + "model": { + "pretrained_model_path": "/workspace/cope2n-ai-fi/weights/layoutxlm-base", + "config": "/workspace/cope2n-ai-fi/weights/models/sdsvkvu/key_value_understanding_for_sbt-20231118-175013/base.yaml", + "checkpoint": "/workspace/cope2n-ai-fi/weights/models/sdsvkvu/key_value_understanding_for_sbt-20231118-175013/checkpoints/best_model.pth" + } +} \ No newline at end of file diff --git a/cope2n-ai-fi/docker-compose.yaml b/cope2n-ai-fi/docker-compose.yaml new file mode 100755 index 0000000..7a2590e --- /dev/null +++ b/cope2n-ai-fi/docker-compose.yaml @@ -0,0 +1,48 @@ +services: + cope2n-fi: + build: + context: . + shm_size: 10gb + dockerfile: Dockerfile + shm_size: 10gb + image: tuanlv/cope2n-ai-fi + container_name: "tuanlv-cope2n-ai-fi-dev" + network_mode: "host" + privileged: true + volumes: + - /mnt/hdd4T/OCR/tuanlv/05-copen-ai/cope2n-ai-fi:/workspace/cope2n-ai-fi # for dev container only + - /mnt/hdd2T/dxtan/TannedCung/OCR/cope2n-api:/workspace/cope2n-api + - /mnt/hdd2T/dxtan/TannedCung/OCR/cope2n-fe:/workspace/cope2n-fe + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + command: bash -c "tail -f > /dev/null" + + # train_component: + # build: + # context: . 
+ # shm_size: 10gb + # args: + # - NODE_ENV=local + # dockerfile: Dockerfile + # shm_size: 10gb + # image: tannedcung/kubeflow-text-recognition + # container_name: "TannedCung-kubeflow-TextRecognition-Train" + # network_mode: "host" + # privileged: true + # depends_on: + # data_preparation_component: + # condition: service_completed_successfully + # volumes: + # # - /mnt/hdd2T/dxtan/TannedCung/VI/vi-vision-inspection-kubeflow/components/text_recognition:/workspace + # - /mnt/ssd500/datnt/mmocr/logs/satrn_lite_2023-04-13_fwd_finetuned:/weights/ + # - /mnt/hdd2T/dxtan/TannedCung/OCR/TextRecognition/test_input/:/test_input/ + # - /mnt/hdd2T/dxtan/TannedCung/OCR/TextRecognition/train_output/:/train_output/ + # - /mnt/hdd2T/dxtan/TannedCung/Data/:/Data + # - /mnt/hdd2T/dxtan/TannedCung/VI/vi-vision-inspection-kubeflow/components/text_recognition/configs:/configs + # command: bash -c "python /workspace/tools/train.py --config=/workspace/configs/satrn_lite.py --load_from=/weights/textrecog_fwd_tuned_20230413_params.pth --gpu_id=1 --img_path_prefix=/Data --vimlops_token=123 --total_epochs=5 --batch_size=32 --work_dir=/train_output" + # command: bash -c "tail -f > /dev/null" \ No newline at end of file diff --git a/cope2n-ai-fi/dockerfile_old b/cope2n-ai-fi/dockerfile_old new file mode 100755 index 0000000..926fc35 --- /dev/null +++ b/cope2n-ai-fi/dockerfile_old @@ -0,0 +1,27 @@ +FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 + +ARG UID=1000 +ARG GID=1000 +ARG USERNAME=container-user +RUN groupadd --gid ${GID} ${USERNAME} \ + && useradd --uid ${UID} --gid ${GID} -m ${USERNAME} \ + && apt-get update \ + && apt-get install -y sudo \ + && apt install -y python3-pip ffmpeg libsm6 libxext6 \ + && apt install git -y \ + && echo ${USERNAME} ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/${USERNAME} \ + && chmod 0440 /etc/sudoers.d/${USERNAME} + +USER ${UID} +ADD --chown=${UID}:${GID} . /cope2n-ai + +WORKDIR /cope2n-ai +RUN pip3 install -r requirements.txt --no-cache-dir +RUN python3 -m pip install -e detectron2 +RUN cd /cope2n-ai/sdsvtd && pip install -v -e . +RUN cd /cope2n-ai/sdsvtr && pip install -v -e . +RUN pip install -U openmim && mim install mmcv-full==1.7.0 +RUN cd /cope2n-ai +RUN export PYTHONPATH="." 
+ +CMD ["sh", "run.sh"] \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/setting.yml b/cope2n-ai-fi/modules/TemplateMatching/setting.yml new file mode 100755 index 0000000..d8c0933 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/setting.yml @@ -0,0 +1,21 @@ +text_detection: + setting: TemplateMatching/textdetection/setting.yml + +text_recognition: + setting: TemplateMatching/textrecognition/setting.yml + +document_classification: + setting: TemplateMatching/documentclassification/setting.yml + +template_based_extraction: + setting: TemplateMatching/templatebasedextraction/setting.yml + +id_card_detection: + setting: TemplateMatching/idcarddetection/setting.yml + +checkbox_detection: + setting: TemplateMatching/checkboxdetection/setting.yml + + +deploy: + port: 7979 \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/src/ocr_master.py b/cope2n-ai-fi/modules/TemplateMatching/src/ocr_master.py new file mode 100755 index 0000000..54aa2ae --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/src/ocr_master.py @@ -0,0 +1,87 @@ +import requests +import yaml +from time import time +import numpy as np +import re + +from TemplateMatching.templatebasedextraction.src.serve_model import Predictor +from TemplateMatching.textdetection.serve_model import Predictor as TextDetector +from TemplateMatching.textrecognition.src.serve_model import Predictor as TextRecognizer + +class Extractor: + def __init__(self): + with open("./TemplateMatching/setting.yml") as f: + self.setting = yaml.safe_load(f) + self.predictor = Predictor(self.setting["template_based_extraction"]["setting"]) + self.text_detector = TextDetector(self.setting["text_detection"]["setting"]) + self.text_recognizer = TextRecognizer( + self.setting["text_recognition"]["setting"] + ) + + def _format_output(self, document): + result = dict() + for field, values in document.items(): + print(values["value"]) + if "✪" in values["value"]: + values = values["value"].replace("✪", " ") + result[field] = values + else: + values = values["value"] + result[field] = values + return result + + def _extract_idcard_info(self, images): + id_card_crops = self.idcard_detector(np.array(images)) + processed_images = [] + for i in range(len(id_card_crops)): + aligned_img = id_card_crops[i] + if aligned_img is not None: + processed_images.append(aligned_img) + else: + processed_images.append(images[i]) + return processed_images + + def _extract_id_no(self, doc): + page = doc["page_data"][0] + content = " ".join(page["contents"]) + result1 = re.findall("[0-9]{12}", content) + if len(result1) == 0: + result2 = re.findall("[0-9]{9}", content) + if len(result2) == 0: + return None + return result2 + return result1[0] + + def image_alige(self, images, tmp_json): + template_image_dir = "/" + template_name = tmp_json["template_name"] + + image_aliged = self.predictor.align_image( + images[0], tmp_json, template_image_dir, template_name + ) + + return image_aliged + + def extract_information(self, image_aliged, tmp_json): + image_aligeds = [image_aliged] + batch_boxes = self.text_detector(image_aligeds) + cropped_images = [ + image_aliged[int(y1) : int(y2), int(x1) : int(x2)] + for x1, y1, x2, y2 in batch_boxes[0] + ] + texts = self.text_recognizer(cropped_images) + texts = [res for res in texts] + + doc_page = dict() + doc_page["boxes"] = batch_boxes + doc_page["contents"] = texts + doc_page["types"] = ["word"] * len(batch_boxes) + doc_page["image"] = image_aliged + + documents_with_info = 
self.predictor.template_based_extractor( + batch_boxes, texts, doc_page, tmp_json + ) + + result = self._format_output(documents_with_info) + + return result \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/.gitignore b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/.gitignore new file mode 100755 index 0000000..e6cb784 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/.gitignore @@ -0,0 +1,206 @@ +# Created by https://www.toptal.com/developers/gitignore/api/jupyternotebooks,visualstudiocode,ssh,python +# Edit at https://www.toptal.com/developers/gitignore?templates=jupyternotebooks,visualstudiocode,ssh,python + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + + +### SSH ### +**/.ssh/id_* +**/.ssh/*_id_* +**/.ssh/known_hosts + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/jupyternotebooks,visualstudiocode,ssh,python \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/setting.yml b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/setting.yml new file mode 100755 index 0000000..667c844 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/setting.yml @@ -0,0 +1,17 @@ +templates: + template_im_dir: /mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/templates + + config: + { + "min_match_count": 4, + "flann_based_matcher_config": {"algorithm": 0, "trees": 5}, + "matching_topk": 2, + "distance_threshold": 0.6, + "ransac_threshold": 5.0, + "valid_size_ratio_margin": 0.15, + "valid_area_threshold": 0.75, + "image_max_size": 1024, + "similar_triangle_threshold": 4, + "roi_to_template_box_ratio": 3.0, + "default_image_size": [1654, 2368] + } \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/config/line_parser.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/config/line_parser.py new file mode 100755 index 0000000..4782c5a --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/config/line_parser.py @@ -0,0 +1,236 @@ +TEMPLATE_BOXES = { + "POS01": { + "page_1": [ + { + "name": "field", + "type": "text", + "position": {"top": 1951, "left": 1173}, + "size": {"width": 1224, "height": 110}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1607, "left": 457}, + "size": {"width": 787, "height": 119}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1092, "left": 1621}, + "size": {"width": 748, "height": 110}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1620, "left": 1506}, + "size": {"width": 358, "height": 79}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1875, "left": 1062}, + "size": {"width": 727, "height": 84}, + }, + { + "name": "field", + "type": "text", + 
"position": {"top": 1872, "left": 487}, + "size": {"width": 387, "height": 85}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1781, "left": 665}, + "size": {"width": 886, "height": 97}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1625, "left": 2085}, + "size": {"width": 301, "height": 86}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1192, "left": 608}, + "size": {"width": 752, "height": 72}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1345, "left": 415}, + "size": {"width": 1922, "height": 120}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1712, "left": 501}, + "size": {"width": 749, "height": 79}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1546, "left": 1725}, + "size": {"width": 703, "height": 87}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1261, "left": 1599}, + "size": {"width": 731, "height": 88}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1263, "left": 667}, + "size": {"width": 735, "height": 94}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1189, "left": 1549}, + "size": {"width": 785, "height": 79}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1103, "left": 524}, + "size": {"width": 835, "height": 101}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1006, "left": 657}, + "size": {"width": 1820, "height": 111}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 603, "left": 876}, + "size": {"width": 1456, "height": 114}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 691, "left": 1041}, + "size": {"width": 1299, "height": 110}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 512, "left": 1567}, + "size": {"width": 729, "height": 90}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 504, "left": 673}, + "size": {"width": 598, "height": 105}, + }, + ], + "page_2": [ + { + "name": "field", + "type": "text", + "position": {"top": 3055, "left": 1193}, + "size": {"width": 649, "height": 106}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 3055, "left": 526}, + "size": {"width": 535, "height": 95}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 2815, "left": 360}, + "size": {"width": 371, "height": 79}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 2707, "left": 433}, + "size": {"width": 805, "height": 125}, + }, + ], + }, + "POS04": { + "page_1": [ + { + "name": "field", + "type": "text", + "position": {"top": 430, "left": 583}, + "size": {"width": 958, "height": 66}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 714, "left": 844}, + "size": {"width": 348, "height": 65}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 492, "left": 689}, + "size": {"width": 858, "height": 67}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 370, "left": 1037}, + "size": {"width": 488, "height": 65}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 368, "left": 447}, + "size": {"width": 399, "height": 63}, + }, + ], + "page_2": [ + { + "name": "field", + "type": "text", + "position": {"top": 1639, "left": 287}, + "size": {"width": 263, "height": 62}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1643, "left": 1368}, + "size": {"width": 203, "height": 52}, + }, + { + "name": "field", + "type": "text", + "position": 
{"top": 1556, "left": 330}, + "size": {"width": 554, "height": 95}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1639, "left": 982}, + "size": {"width": 251, "height": 57}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1575, "left": 1024}, + "size": {"width": 550, "height": 69}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1800, "left": 843}, + "size": {"width": 493, "height": 80}, + }, + { + "name": "field", + "type": "text", + "position": {"top": 1798, "left": 391}, + "size": {"width": 363, "height": 81}, + }, + ], + }, +} diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/config/sift_based_aligner.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/config/sift_based_aligner.py new file mode 100755 index 0000000..dc3f6c5 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/config/sift_based_aligner.py @@ -0,0 +1,178 @@ +config = { + "template_info": { + "edit_info_1": { + "image_path": "./assest/form_1_edit_personal_info/Scan47_0.jpg", + "anchors": [ + { + "id": "4a1d57e6-2403-4884-b6dd-39c844e32efc", + "position": {"top": 77, "left": 207}, + "size": {"width": 226, "height": 198}, + }, + { + "id": "eca7c11a-f35e-4924-9488-53fbeaeee867", + "position": {"top": 61, "left": 1948}, + "size": {"width": 406, "height": 257}, + }, + { + "id": "befc0ae6-cdb7-42df-b474-22cfb2795a05", + "position": {"top": 3017, "left": 210}, + "size": {"width": 2161, "height": 285}, + }, + ], + "fields": [ + { + "id": "c1be03a0-6556-4f98-a9e1-8687c57ad4de", + "position": {"top": 512, "left": 1567}, + "size": {"width": 729, "height": 90}, + }, + { + "id": "5bcc329b-b367-484a-9952-e20e859404b1", + "position": {"top": 1622, "left": 447}, + "size": {"width": 797, "height": 106}, + }, + { + "id": "8979b1ea-9c6a-4327-b355-22f3e6ce3a7a", + "position": {"top": 1951, "left": 1173}, + "size": {"width": 1152, "height": 122}, + }, + { + "id": "a72ca9cb-d4f0-4554-954b-e8e984a8ad6a", + "position": {"top": 700, "left": 1041}, + "size": {"width": 1287, "height": 91}, + }, + { + "id": "043da3cb-81ce-4cf8-966f-da91d72d3358", + "position": {"top": 608, "left": 876}, + "size": {"width": 1450, "height": 94}, + }, + { + "id": "6e55e678-58fd-4e75-976f-fb7852bdd6bc", + "position": {"top": 504, "left": 673}, + "size": {"width": 598, "height": 105}, + }, + ], + }, + "edit_info_2": { + "image_path": "./assest/form_1_edit_personal_info/Scan47_1.jpg", + "anchors": [ + { + "id": "1cf5b737-06ca-492f-b90b-87bda100b045", + "position": {"top": 3274, "left": 2078}, + "size": {"width": 234, "height": 162}, + }, + { + "id": "4825e878-1331-48ac-8fad-90cdf6cf412d", + "position": {"top": 3247, "left": 203}, + "size": {"width": 800, "height": 176}, + }, + { + "id": "8fb8d488-8280-4a56-93e0-67f8659a738c", + "position": {"top": 52, "left": 208}, + "size": {"width": 1063, "height": 183}, + }, + ], + "fields": [ + { + "id": "2db4507e-820f-48b9-8ec9-424a0207d5ca", + "position": {"top": 2815, "left": 360}, + "size": {"width": 371, "height": 79}, + }, + { + "id": "02cb11ed-ed6e-4125-88ec-cb8043d122a5", + "position": {"top": 2707, "left": 433}, + "size": {"width": 805, "height": 125}, + }, + ], + }, + "restore_contract_1": { + "image_path": "./assest/form_4/8_0.jpg", + "anchors": [ + { + "id": "7b3bcdaa-d6ab-40ed-b70b-96513887cef4", + "position": {"top": 1443, "left": 145}, + "size": {"width": 1427, "height": 162}, + }, + { + "id": "3457d3d4-3b4c-4464-b2f4-f94a4088f7f8", + "position": {"top": 53, "left": 133}, + 
"size": {"width": 152, "height": 129}, + }, + { + "id": "17c3b519-e4a6-423e-8149-e21b502c255f", + "position": {"top": 51, "left": 1294}, + "size": {"width": 267, "height": 154}, + }, + ], + "fields": [ + { + "id": "deb466f7-09ee-4f0d-9d95-6cb0d620dbe6", + "position": {"top": 493, "left": 689}, + "size": {"width": 864, "height": 56}, + }, + { + "id": "ba00580f-bd9a-40df-bfef-b88a77bf96c2", + "position": {"top": 370, "left": 1037}, + "size": {"width": 488, "height": 65}, + }, + { + "id": "339b766c-1079-46d1-bc4d-d4dc5cb4ca01", + "position": {"top": 368, "left": 447}, + "size": {"width": 399, "height": 63}, + }, + { + "id": "b6f87b41-e151-4fb3-be7f-75a80fcefde0", + "position": {"top": 714, "left": 844}, + "size": {"width": 343, "height": 65}, + }, + { + "id": "cf787f24-47f5-40d6-ac68-6e9a77b61731", + "position": {"top": 430, "left": 583}, + "size": {"width": 969, "height": 66}, + }, + ], + }, + "restore_contract_2": { + "image_path": "./assest/form_4/8_1.jpg", + "anchors": [ + { + "id": "8077ead5-32bc-4be7-bcd5-c0bf31ffb56f", + "position": {"top": 1373, "left": 952}, + "size": {"width": 551, "height": 82}, + }, + { + "id": "a87a5bfc-2fcc-479c-b0d0-f6a5cbc1f841", + "position": {"top": 1369, "left": 384}, + "size": {"width": 306, "height": 89}, + }, + { + "id": "4fc086b2-42e6-4132-9af4-338558501cd6", + "position": {"top": 34, "left": 180}, + "size": {"width": 674, "height": 120}, + }, + ], + "fields": [ + { + "id": "fe118028-8c0e-4b23-9254-eb7605c41d52", + "position": {"top": 1556, "left": 330}, + "size": {"width": 542, "height": 95}, + }, + { + "id": "6109edfc-963a-419f-9cd2-456e0660c28b", + "position": {"top": 1639, "left": 287}, + "size": {"width": 263, "height": 54}, + }, + ], + }, + }, + "min_match_count": 4, + "flann_based_matcher_config": {"algorithm": 0, "trees": 5}, + "matching_topk": 2, + "distance_threshold": 0.6, + "ransac_threshold": 5.0, + "valid_size_ratio_margin": 0.15, + "valid_area_threshold": 0.75, + "image_max_size": 1024, + "similar_triangle_threshold": 4, + "roi_to_template_box_ratio": 3.0, + "default_image_size": (1654, 2368), +} diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/field_module.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/field_module.py new file mode 100755 index 0000000..8ab1271 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/field_module.py @@ -0,0 +1,1194 @@ +import cv2 + + +class FieldParser: + def __init__(self): + pass + + def parse(self, ocr_output, field_infos, iou_threshold=0.7): + """parse field infor from template + + Args: + ocr_output (list[dict]): + [ + { + 'box': [xmin, ymin, xmax, ymax] + 'text': str + } + ] + + + field_infos (list[dict]): _description_ + - example: + [ + { + 'id' : 'field_1' + 'box': [xmin, ymin, xmax, ymax], + } + [ + + Returns: + field text: + [ + { + 'id' : 'field_1' + 'box': [xmin, ymin, xmax, ymax], + 'text': 'abc' + } + [ + """ + for field_item in field_infos: + if "list_words" not in field_item: + field_item["list_words"] = [] + + for ocr_item in ocr_output: + box = ocr_item["box"] + for field_item in field_infos: + field_name = field_item["id"] + field_box = field_item["box"] + iou = self.cal_iou_custom(box, field_box) + # if iou > 0: + # print(iou, ocr_item) + if iou > iou_threshold: + field_item["list_words"].append(ocr_item) + break # break if find field box + + for field_item in field_infos: + list_words = field_item["list_words"] + list_words = sorted(list_words, key=lambda item: item["box"][0]) 
+ field_text = " ".join([item["text"] for item in list_words]) + field_item["text"] = field_text + + return field_infos + + def cal_iou_custom(self, box_A, box_B): + """calculate iou between two boxes + union = smaller box between two boxes + + Args: + box_A (list): _description_ + box_B (list): _description_ + + Returns: + (float): iou value + """ + + area1 = (box_A[2] - box_A[0]) * (box_A[3] - box_A[1]) + area2 = (box_B[2] - box_B[0]) * (box_B[3] - box_B[1]) + + xmin_intersect = max(box_A[0], box_B[0]) + ymin_intersect = max(box_A[1], box_B[1]) + xmax_intersect = min(box_A[2], box_B[2]) + ymax_intersect = min(box_A[3], box_B[3]) + if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: + area_intersect = 0 + else: + area_intersect = (xmax_intersect - xmin_intersect) * ( + ymax_intersect - ymin_intersect + ) + + # union = area1 + area2 - area_intersect + union = min(area1, area2) + if union == 0: + return 0 + + iou = area_intersect / union + + return iou + + +def format_field_info(data, id): + """{'name': 'field', + 'type': 'text', + 'position': {'top': 1951, 'left': 1173}, + 'size': {'width': 1224, 'height': 110}} + + Args: + data (_type_): _description_ + """ + + output = {} + output["id"] = data["name"] + "_" + str(id) + xmin, ymin, w, h = ( + data["position"]["left"], + data["position"]["top"], + data["size"]["width"], + data["size"]["height"], + ) + output["box"] = [xmin, ymin, xmin + w, ymin + h] + return output + + +def vis_field(img, field_infos): + for field_item in field_infos: + box = field_item["box"] + cv2.rectangle( + img, (box[0], box[1]), (box[2], box[3]), color=(0, 255, 0), thickness=1 + ) + + return img + + +def merge_field(image, field_infos): + pass + + +if __name__ == "__main__": + ocr_output = ( + [ + (1972, 101, 2324, 231), + (1980, 238, 2213, 283), + (607, 277, 702, 317), + (525, 279, 600, 317), + (242, 280, 345, 323), + (400, 280, 518, 317), + (819, 281, 887, 324), + (348, 282, 395, 324), + (895, 282, 988, 318), + (997, 282, 1076, 325), + (712, 282, 812, 317), + (1083, 283, 1174, 318), + (241, 325, 327, 366), + (331, 328, 390, 367), + (397, 330, 684, 370), + (245, 388, 441, 457), + (778, 390, 945, 458), + (616, 390, 762, 457), + (965, 395, 1176, 459), + (458, 395, 599, 457), + (1198, 396, 1422, 472), + (1574, 397, 1677, 459), + (1697, 398, 1883, 460), + (1440, 398, 1557, 458), + (1584, 518, 2283, 608), + (976, 527, 1056, 593), + (1046, 530, 1093, 591), + (687, 531, 742, 590), + (826, 532, 881, 596), + (904, 533, 956, 594), + (751, 534, 805, 597), + (1638, 537, 1687, 597), + (1563, 537, 1621, 597), + (508, 550, 600, 589), + (606, 551, 665, 592), + (1295, 551, 1352, 594), + (327, 551, 423, 597), + (429, 552, 502, 590), + (240, 552, 322, 597), + (1354, 554, 1436, 597), + (1442, 555, 1549, 597), + (1040, 619, 1228, 709), + (1504, 626, 1631, 707), + (617, 635, 710, 675), + (719, 636, 873, 681), + (239, 636, 299, 682), + (537, 636, 611, 675), + (1276, 637, 1440, 690), + (371, 637, 444, 674), + (302, 638, 365, 675), + (450, 644, 532, 675), + (1188, 703, 1374, 778), + (1631, 704, 1743, 781), + (1423, 707, 1566, 777), + (675, 718, 768, 759), + (777, 720, 917, 764), + (238, 720, 299, 766), + (595, 720, 669, 759), + (492, 721, 589, 764), + (371, 721, 486, 764), + (301, 721, 365, 759), + (925, 722, 1038, 760), + (1663, 789, 1718, 841), + (1488, 791, 1583, 835), + (1256, 792, 1347, 840), + (562, 792, 656, 835), + (1587, 793, 1661, 835), + (483, 794, 558, 835), + (737, 794, 809, 834), + (238, 795, 313, 840), + (814, 795, 897, 834), + (1721, 795, 1806, 835), + 
(659, 795, 734, 842), + (1181, 796, 1254, 836), + (317, 796, 391, 835), + (1350, 796, 1383, 837), + (1812, 796, 1895, 841), + (1012, 797, 1074, 841), + (905, 797, 1007, 834), + (1079, 797, 1175, 841), + (1386, 797, 1482, 839), + (396, 802, 477, 833), + (239, 868, 300, 926), + (336, 876, 418, 922), + (302, 876, 332, 917), + (424, 877, 521, 922), + (822, 877, 897, 922), + (657, 877, 726, 918), + (730, 877, 817, 919), + (527, 878, 650, 922), + (681, 957, 741, 1007), + (1806, 959, 1869, 1007), + (1355, 960, 1425, 1006), + (1695, 960, 1732, 1002), + (1427, 960, 1488, 1001), + (506, 960, 568, 1007), + (752, 960, 821, 1007), + (824, 960, 887, 1002), + (300, 961, 367, 1007), + (370, 961, 430, 1001), + (433, 961, 503, 1001), + (1734, 961, 1803, 1001), + (1631, 962, 1692, 1002), + (1029, 963, 1090, 1002), + (1493, 963, 1626, 1006), + (892, 964, 1023, 1006), + (298, 1042, 356, 1086), + (359, 1045, 439, 1089), + (443, 1046, 510, 1085), + (515, 1046, 647, 1090), + (303, 1128, 519, 1176), + (1360, 1129, 1610, 1176), + (519, 1210, 607, 1260), + (1358, 1211, 1463, 1253), + (300, 1212, 393, 1255), + (393, 1213, 515, 1254), + (1467, 1214, 1540, 1259), + (1992, 1260, 2326, 1347), + (934, 1269, 1415, 1353), + (990, 1280, 1037, 1345), + (912, 1284, 961, 1350), + (243, 1285, 303, 1341), + (1415, 1291, 1481, 1341), + (828, 1291, 873, 1347), + (694, 1292, 738, 1343), + (1484, 1293, 1594, 1341), + (759, 1294, 799, 1345), + (540, 1294, 662, 1344), + (394, 1294, 535, 1341), + (303, 1295, 389, 1342), + (242, 1371, 301, 1423), + (303, 1379, 416, 1422), + (239, 1474, 302, 1533), + (302, 1478, 348, 1527), + (670, 1480, 742, 1524), + (439, 1480, 538, 1529), + (745, 1480, 851, 1523), + (349, 1481, 434, 1530), + (857, 1481, 959, 1524), + (541, 1483, 666, 1528), + (1561, 1558, 1618, 1606), + (1248, 1558, 1305, 1609), + (946, 1559, 1005, 1608), + (527, 1559, 586, 1608), + (835, 1561, 933, 1607), + (1305, 1562, 1367, 1612), + (1006, 1563, 1124, 1607), + (589, 1563, 664, 1607), + (757, 1563, 832, 1607), + (240, 1565, 330, 1608), + (1368, 1565, 1434, 1606), + (442, 1566, 517, 1608), + (1129, 1566, 1234, 1606), + (1440, 1566, 1556, 1604), + (335, 1567, 438, 1607), + (1621, 1570, 1723, 1611), + (669, 1573, 752, 1606), + (1855, 1642, 1909, 1693), + (1246, 1642, 1309, 1694), + (1913, 1645, 1992, 1688), + (302, 1647, 362, 1698), + (1996, 1648, 2084, 1688), + (1416, 1648, 1506, 1689), + (1311, 1648, 1412, 1695), + (1581, 1651, 1608, 1690), + (364, 1652, 441, 1692), + (1667, 1652, 1700, 1689), + (304, 1733, 407, 1775), + (410, 1733, 499, 1781), + (244, 1811, 300, 1863), + (302, 1813, 359, 1861), + (359, 1815, 441, 1866), + (442, 1817, 490, 1859), + (559, 1818, 658, 1859), + (491, 1818, 557, 1864), + (979, 1896, 1065, 1948), + (904, 1897, 977, 1942), + (408, 1900, 493, 1948), + (300, 1901, 405, 1951), + (647, 1905, 676, 1942), + (556, 1905, 584, 1942), + (518, 1909, 542, 1939), + (598, 1909, 627, 1940), + (1786, 1941, 1889, 2034), + (1922, 1949, 2024, 2054), + (2043, 1951, 2191, 2041), + (1450, 1952, 1599, 2038), + (1298, 1957, 1420, 2046), + (1641, 1968, 1765, 2036), + (243, 1973, 305, 2033), + (1078, 1978, 1169, 2029), + (794, 1980, 840, 2025), + (715, 1981, 793, 2029), + (941, 1982, 1023, 2030), + (305, 1982, 409, 2033), + (842, 1983, 938, 2031), + (413, 1984, 659, 2032), + (1024, 1986, 1077, 2030), + (660, 1988, 713, 2031), + (306, 2055, 344, 2097), + (239, 2057, 309, 2097), + (2274, 2083, 2328, 2118), + (2208, 2083, 2272, 2119), + (1980, 2083, 2053, 2123), + (2125, 2084, 2203, 2118), + (1326, 2085, 1376, 2126), + (2055, 2085, 
2123, 2118), + (1911, 2085, 1979, 2124), + (1249, 2085, 1325, 2121), + (1639, 2086, 1739, 2124), + (1479, 2086, 1551, 2120), + (1170, 2086, 1247, 2126), + (1824, 2086, 1908, 2119), + (1025, 2086, 1091, 2122), + (354, 2087, 428, 2127), + (1379, 2087, 1478, 2120), + (783, 2087, 861, 2122), + (1743, 2087, 1821, 2119), + (1555, 2087, 1633, 2119), + (1093, 2087, 1168, 2121), + (945, 2087, 1022, 2122), + (863, 2087, 942, 2122), + (593, 2087, 674, 2126), + (677, 2088, 780, 2125), + (520, 2089, 591, 2124), + (430, 2089, 466, 2124), + (467, 2089, 518, 2126), + (816, 2127, 853, 2166), + (639, 2128, 694, 2167), + (945, 2128, 1050, 2165), + (695, 2129, 734, 2166), + (351, 2129, 416, 2165), + (735, 2129, 815, 2165), + (597, 2130, 639, 2164), + (518, 2130, 596, 2163), + (418, 2130, 516, 2166), + (853, 2131, 942, 2165), + (335, 2142, 357, 2170), + (1768, 2163, 1820, 2198), + (2230, 2163, 2268, 2200), + (2270, 2165, 2329, 2199), + (1126, 2165, 1194, 2205), + (1659, 2165, 1728, 2203), + (2133, 2165, 2227, 2202), + (1823, 2165, 1890, 2202), + (2092, 2165, 2130, 2198), + (1984, 2165, 2027, 2199), + (1536, 2166, 1598, 2199), + (1411, 2166, 1465, 2205), + (1893, 2166, 1981, 2202), + (1730, 2166, 1766, 2199), + (2030, 2167, 2089, 2197), + (1196, 2167, 1288, 2201), + (1342, 2167, 1410, 2205), + (1045, 2167, 1123, 2204), + (991, 2167, 1044, 2206), + (882, 2168, 953, 2206), + (1467, 2168, 1533, 2199), + (355, 2168, 410, 2204), + (679, 2169, 768, 2206), + (531, 2169, 616, 2203), + (1291, 2169, 1340, 2201), + (770, 2169, 816, 2203), + (816, 2169, 881, 2203), + (953, 2169, 990, 2204), + (617, 2169, 678, 2204), + (461, 2169, 529, 2204), + (1601, 2170, 1657, 2199), + (411, 2170, 459, 2204), + (333, 2172, 353, 2202), + (1657, 2206, 1705, 2243), + (2000, 2206, 2071, 2239), + (1772, 2206, 1837, 2239), + (1913, 2206, 1997, 2243), + (1252, 2206, 1305, 2244), + (1841, 2207, 1908, 2241), + (1560, 2207, 1654, 2239), + (1427, 2207, 1486, 2240), + (1489, 2207, 1557, 2240), + (807, 2207, 854, 2246), + (1068, 2207, 1153, 2241), + (1156, 2207, 1250, 2243), + (987, 2208, 1066, 2245), + (1707, 2208, 1768, 2242), + (596, 2208, 668, 2247), + (668, 2208, 720, 2243), + (722, 2208, 807, 2242), + (926, 2208, 986, 2242), + (855, 2208, 925, 2246), + (1309, 2208, 1424, 2243), + (501, 2209, 594, 2247), + (436, 2210, 499, 2244), + (353, 2216, 434, 2248), + (1462, 2245, 1521, 2285), + (1734, 2245, 1812, 2284), + (1213, 2246, 1288, 2281), + (1092, 2246, 1148, 2287), + (1335, 2246, 1388, 2280), + (1816, 2247, 1919, 2283), + (897, 2247, 977, 2287), + (1148, 2247, 1210, 2282), + (1618, 2247, 1680, 2285), + (737, 2247, 806, 2283), + (1391, 2247, 1459, 2285), + (1683, 2247, 1731, 2280), + (979, 2248, 1090, 2287), + (675, 2248, 735, 2282), + (809, 2248, 894, 2282), + (1923, 2248, 1991, 2279), + (1525, 2248, 1614, 2284), + (1291, 2249, 1332, 2281), + (613, 2249, 674, 2288), + (545, 2250, 611, 2288), + (472, 2250, 542, 2288), + (417, 2250, 470, 2284), + (356, 2250, 415, 2284), + (335, 2252, 352, 2284), + (697, 2326, 793, 2369), + (259, 2326, 356, 2378), + (635, 2326, 694, 2374), + (467, 2327, 538, 2368), + (543, 2329, 631, 2368), + (365, 2332, 462, 2375), + (947, 2401, 1008, 2456), + (1247, 2404, 1306, 2455), + (1564, 2406, 1619, 2449), + (839, 2407, 934, 2451), + (1307, 2408, 1368, 2457), + (1010, 2409, 1125, 2452), + (535, 2409, 587, 2463), + (1131, 2410, 1236, 2451), + (761, 2411, 834, 2451), + (1443, 2411, 1558, 2448), + (1372, 2412, 1437, 2449), + (594, 2413, 666, 2452), + (244, 2413, 331, 2453), + (1623, 2415, 1725, 2455), + (339, 2416, 439, 2452), 
+ (447, 2416, 518, 2453), + (673, 2420, 754, 2451), + (1930, 2518, 1980, 2565), + (1879, 2519, 1926, 2565), + (1639, 2521, 1720, 2558), + (1726, 2521, 1775, 2565), + (1780, 2522, 1873, 2563), + (759, 2524, 808, 2569), + (812, 2525, 862, 2563), + (672, 2525, 754, 2562), + (690, 2613, 798, 2723), + (1626, 2632, 1837, 2756), + (465, 2727, 650, 2873), + (878, 2734, 997, 2826), + (695, 2756, 839, 2841), + (1868, 2760, 1997, 2853), + (1686, 2785, 1833, 2862), + (1488, 2787, 1667, 2942), + (2031, 3036, 2080, 3080), + (1819, 3037, 1870, 3082), + (2212, 3040, 2305, 3081), + (1683, 3040, 1797, 3084), + (1879, 3040, 1979, 3084), + (1275, 3042, 1337, 3082), + (2310, 3042, 2336, 3082), + (2089, 3042, 2204, 3081), + (1056, 3042, 1137, 3082), + (1982, 3042, 2011, 3084), + (1394, 3043, 1481, 3080), + (1608, 3043, 1675, 3080), + (1489, 3044, 1601, 3080), + (1341, 3045, 1389, 3082), + (1000, 3045, 1050, 3083), + (1191, 3046, 1270, 3087), + (1142, 3046, 1186, 3083), + (903, 3048, 994, 3088), + (739, 3048, 822, 3089), + (246, 3048, 322, 3084), + (534, 3049, 592, 3083), + (829, 3049, 896, 3084), + (331, 3049, 408, 3089), + (417, 3050, 526, 3083), + (675, 3050, 733, 3083), + (598, 3051, 668, 3089), + (346, 3097, 412, 3131), + (247, 3100, 335, 3131), + (1685, 3136, 1776, 3175), + (1276, 3138, 1494, 3182), + (1782, 3139, 1841, 3180), + (1606, 3140, 1678, 3176), + (1502, 3141, 1599, 3181), + (1197, 3141, 1268, 3178), + (544, 3142, 625, 3178), + (1030, 3143, 1102, 3178), + (793, 3144, 845, 3179), + (959, 3144, 1022, 3183), + (852, 3144, 953, 3178), + (716, 3144, 788, 3184), + (491, 3146, 538, 3185), + (632, 3146, 709, 3178), + (414, 3147, 484, 3179), + (283, 3148, 404, 3186), + (1109, 3149, 1189, 3178), + (245, 3150, 273, 3180), + (2117, 3183, 2206, 3220), + (1666, 3185, 1827, 3224), + (1961, 3187, 2028, 3221), + (2038, 3187, 2106, 3220), + (1482, 3188, 1578, 3230), + (2216, 3188, 2283, 3220), + (1586, 3189, 1658, 3225), + (1885, 3189, 1953, 3228), + (1222, 3189, 1283, 3233), + (1396, 3190, 1476, 3231), + (1097, 3190, 1163, 3227), + (1168, 3190, 1218, 3228), + (1034, 3192, 1093, 3228), + (1288, 3192, 1389, 3231), + (484, 3192, 547, 3227), + (909, 3193, 1027, 3233), + (829, 3193, 903, 3228), + (638, 3194, 718, 3233), + (554, 3195, 630, 3227), + (1832, 3195, 1880, 3224), + (726, 3195, 820, 3233), + (420, 3195, 477, 3228), + (290, 3196, 411, 3234), + (243, 3197, 281, 3230), + (1376, 3236, 1426, 3274), + (1299, 3238, 1371, 3279), + (1730, 3239, 1808, 3278), + (1671, 3239, 1724, 3273), + (1555, 3239, 1664, 3278), + (1432, 3239, 1547, 3278), + (1200, 3240, 1292, 3280), + (971, 3241, 1083, 3280), + (1091, 3241, 1192, 3275), + (865, 3241, 962, 3275), + (779, 3242, 858, 3280), + (656, 3243, 707, 3275), + (604, 3243, 650, 3275), + (522, 3243, 598, 3281), + (714, 3243, 771, 3275), + (329, 3244, 362, 3276), + (370, 3245, 462, 3283), + (246, 3246, 320, 3276), + (470, 3250, 517, 3281), + (2185, 3354, 2332, 3379), + (876, 3357, 948, 3387), + (734, 3358, 807, 3387), + (811, 3358, 871, 3391), + (522, 3358, 592, 3388), + (679, 3359, 729, 3391), + (599, 3359, 674, 3387), + (461, 3361, 517, 3388), + (250, 3362, 326, 3393), + (367, 3362, 455, 3388), + (331, 3363, 362, 3393), + (2290, 3391, 2332, 3416), + (380, 3398, 626, 3427), + (249, 3400, 374, 3427), + ], + [ + "FWD", + "insurance", + "hiểm", + "Bảo", + "Công", + "TNHH", + "thọ", + "ty", + "FWD", + "Việt", + "Nhân", + "Nam", + "Mẫu", + "số:", + "POS01_2022.09", + "Phiếu", + "Điều", + "Cầu", + "Chỉnh", + "Yêu", + "Thông", + "Cá", + "Nhân", + "Tin", + "0357788028", + "3", + "0", + 
"13", + "2", + "3", + "4", + "13", + "10", + "hiểm", + "số:", + "Số", + "đồng", + "bảo", + "Hợp", + "điện", + "thoại:", + "Nguyễn", + "Hiệp", + "hiểm", + "(BMBH):", + "Họ", + "bảo", + "hoan", + "Bên", + "tên", + "mua", + "Nguyễn", + "Hiệp", + "Hoàng", + "hiểm", + "(NĐBH)", + "Họ", + "bảo", + "được", + "Người", + "tên", + "chính:", + "(x)", + "đánh", + "(cắc)", + "hiểm", + "dấu", + "bảo", + "cầu", + "Tôi,", + "điều", + "dưới", + "yêu", + "của", + "Bên", + "ô", + "đây:", + "nội", + "chỉnh", + "dung", + "được", + "mua", + "X", + "Cập", + "L", + "Nhật", + "Lạc", + "Tin", + "Liên", + "Thông", + "0", + "lạc", + "Địa", + "&", + "chỉ", + "lạc", + "Địa", + "chỉ", + "Địa", + "chỉ", + "liên", + "liên", + "trú", + "trú", + "thường", + "thường", + "Số", + "nhà,", + "tên", + "đường:", + "Phường/Xã:", + "Quận/Huyện:", + "phố:", + "Quốc", + "Tỉnh/", + "Thành", + "gia:", + "T", + "345968", + "3", + "7", + "8", + "(cố", + "2", + "o", + "định):", + "3", + "động):", + "thoại(di", + "Điện", + "0", + "Email:", + "X", + "II.", + "Tin", + "Nhật", + "Nhân", + "Cập", + "Thân", + "Thông", + "bổ", + "0", + "0", + "0", + "hiểm", + "Họ", + "NĐBH", + "Bên", + "bảo", + "Điều", + "tên", + "cho", + "chính", + "NĐBH", + "chỉnh", + "sung:", + "mua", + "0", + "0", + "Giới", + "Họ", + "tính:", + "sinh:", + "Ngày", + "/", + "tên:", + "J", + "Quốc", + "tịch:", + "0", + "Số", + "giấy", + "tờ", + "thân:", + "tùy", + "cấp:", + "Nơi", + "cấp:", + "Ngày", + "/", + "/", + "✪", + "_.", + "thể", + "điện", + "thoại", + "doan", + "Kinh", + "Sum", + "_", + "thể):", + "tả", + "(mô", + "Việc", + "Nghề", + "công", + "nghiệp/Chức", + "cụ", + "vụ", + "ý:", + "Lưu", + "thẻ", + "Các", + "Giấy", + "sinh/", + "Hộ", + "khai", + "đội/", + "dân/", + "Chứng", + "Khai", + "công", + "Quân", + "Căn", + "Giấy", + "chiếu/", + "minh", + "minh", + "sinh/", + "cước", + "dân/", + "nhân", + "gồm:", + "Chứng", + "thân", + "tờ", + "tùy", + "lý", + "giá", + "đương", + "trị", + "ban", + "pháp", + "CÓ", + "khác", + "ngành", + "tương", + "✪", + "thể", + "từ", + "liên", + "Quý", + "giấy", + "chứng", + "hiện", + "và", + "tin", + "bản", + "gửi", + "thông", + "tờ", + "mới", + "khách", + "lòng", + "thân,", + "tùy", + "giấy", + "kèm", + "Đối", + "thông", + "chỉnh", + "vui", + "tin", + "trên", + "tờ", + "các", + "điều", + "sao", + "với", + "-", + "Họ", + "sinh.", + "Giới", + "Ngày", + "địa", + "tính,", + "chỉnh:", + "nếu", + "điều", + "hộ", + "chính", + "quyền", + "nhận", + "tên,", + "định", + "cải", + "chính", + "xác", + "tịch;", + "phương", + "Quyết", + "như", + "quan", + "đổi,", + "nghề", + "hiểm", + "phí", + "thể", + "nghiệp", + "nghề", + "bảo", + "ứng", + "điểu", + "thay", + "với", + "nghiệp,", + "cầu", + "chỉnh", + "mới.", + "tương", + "CÓ", + "yêu", + "hiện", + "thực", + "khi", + "Sau", + ".", + "Mẫu", + "XII.", + "Ký", + "Đổi", + "Chữ", + "Thay", + "X", + "0", + "bổ", + "hiểm", + "Họ", + "NĐBH", + "X", + "chính", + "bảo", + "NĐBH", + "tên", + "Bên", + "Điều", + "sung:", + "chỉnh", + "cho", + "mua", + "lại", + "ký", + "Chữ", + "ký", + "đăng", + "ký", + "cũ", + "Chữ", + "Hiệp", + "Huệp", + "Nguyễn", + "Hiệp", + "Hoàng", + "Hiệp", + "Hoàng", + "Nguyễn", + "0", + "0", + "đồng", + "phẩm:", + "Đồng", + "tắc", + "Ý", + "Không", + "hiểu", + "Ý", + "Điều", + "sản", + "khoản", + "và", + "đã", + "Quy", + "rõ", + "nhận", + "lòng", + "Nếu", + "lăn", + "xác", + "Quý", + "khách", + "vui", + "tay,", + "Kết", + "Cam", + "hiểm", + "hiểm/Người", + "ký.", + "bảo", + "được", + "bảo", + "mẫu", + "Bên", + "do", + "tôi,", + "chính", + "đây", + "ký", + "trên", + "chữ", + "Những", + "mua", + 
"1.", + "biểm", + "hiểm/Hồ", + "cầu", + "bảo", + "đồng", + "nêu", + "bảo", + "yêu", + "ghi", + "Hợp", + "tiết", + "đã", + "chi", + "trong", + "tiết", + "những", + "như", + "đây,", + "trên", + "SƠ", + "cũng", + "chi", + "Những", + "2.", + "về", + "luật", + "này.", + "tin", + "thông", + "những", + "pháp", + "nhiệm", + "trước", + "trách", + "chịu", + "tôi", + "và", + "thật", + "xin", + "là", + "đúng", + "trên", + "sự", + "V1.092022", + "Nam", + "FWD", + "Việt", + "hiểm", + "thọ", + "Nhân", + "Bảo", + "Công", + "TNHH", + "ty", + "1/2", + "www.fwd.com.vn", + "Website:", + ], + ) + + ocr_output = [ + {"box": box, "text": text} for (box, text) in zip(ocr_output[0], ocr_output[1]) + ] + + field_infos = [ + {"id": "field_0", "box": [1173, 1951, 2397, 2061]}, + {"id": "field_1", "box": [457, 1607, 1244, 1726]}, + {"id": "field_2", "box": [1621, 1092, 2369, 1202]}, + {"id": "field_3", "box": [1506, 1620, 1864, 1699]}, + {"id": "field_4", "box": [1062, 1875, 1789, 1959]}, + {"id": "field_5", "box": [487, 1872, 874, 1957]}, + {"id": "field_6", "box": [665, 1781, 1551, 1878]}, + {"id": "field_7", "box": [2085, 1625, 2386, 1711]}, + {"id": "field_8", "box": [608, 1192, 1360, 1264]}, + {"id": "field_9", "box": [415, 1345, 2337, 1465]}, + {"id": "field_10", "box": [501, 1712, 1250, 1791]}, + {"id": "field_11", "box": [1725, 1546, 2428, 1633]}, + {"id": "field_12", "box": [1599, 1261, 2330, 1349]}, + {"id": "field_13", "box": [667, 1263, 1402, 1357]}, + {"id": "field_14", "box": [1549, 1189, 2334, 1268]}, + {"id": "field_15", "box": [524, 1103, 1359, 1204]}, + {"id": "field_16", "box": [657, 1006, 2477, 1117]}, + {"id": "field_17", "box": [876, 603, 2332, 717]}, + {"id": "field_18", "box": [1041, 691, 2340, 801]}, + {"id": "field_19", "box": [1567, 512, 2296, 602]}, + {"id": "field_20", "box": [673, 504, 1271, 609]}, + ] + + # field_infos = [{'name': 'field', + # 'type': 'text', + # 'position': {'top': 1951, 'left': 1173}, + # 'size': {'width': 1224, 'height': 110}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1607, 'left': 457}, + # 'size': {'width': 787, 'height': 119}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1092, 'left': 1621}, + # 'size': {'width': 748, 'height': 110}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1620, 'left': 1506}, + # 'size': {'width': 358, 'height': 79}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1875, 'left': 1062}, + # 'size': {'width': 727, 'height': 84}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1872, 'left': 487}, + # 'size': {'width': 387, 'height': 85}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1781, 'left': 665}, + # 'size': {'width': 886, 'height': 97}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1625, 'left': 2085}, + # 'size': {'width': 301, 'height': 86}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1192, 'left': 608}, + # 'size': {'width': 752, 'height': 72}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1345, 'left': 415}, + # 'size': {'width': 1922, 'height': 120}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1712, 'left': 501}, + # 'size': {'width': 749, 'height': 79}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1546, 'left': 1725}, + # 'size': {'width': 703, 'height': 87}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1261, 'left': 1599}, + # 'size': {'width': 731, 'height': 88}}, + # {'name': 
'field', + # 'type': 'text', + # 'position': {'top': 1263, 'left': 667}, + # 'size': {'width': 735, 'height': 94}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1189, 'left': 1549}, + # 'size': {'width': 785, 'height': 79}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1103, 'left': 524}, + # 'size': {'width': 835, 'height': 101}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 1006, 'left': 657}, + # 'size': {'width': 1820, 'height': 111}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 603, 'left': 876}, + # 'size': {'width': 1456, 'height': 114}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 691, 'left': 1041}, + # 'size': {'width': 1299, 'height': 110}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 512, 'left': 1567}, + # 'size': {'width': 729, 'height': 90}}, + # {'name': 'field', + # 'type': 'text', + # 'position': {'top': 504, 'left': 673}, + # 'size': {'width': 598, 'height': 105}}] + + # field_infos = [ + # format_field_info(field_item, idx) + # for idx, field_item in enumerate(field_infos) + # ] + + print(field_infos) + + img = cv2.imread( + "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/assest/form_1_edit_personal_info/Scan47_0.jpg" + ) + img = vis_field(img, field_infos) + cv2.imwrite("vis_field.jpg", img) + print(ocr_output[0]) + print(field_infos[0]) + parser = FieldParser() + field_outputs = parser.parse(ocr_output, field_infos) + + for field_item in field_outputs: + if len(field_item["list_words"]) > 0: + print(field_item["id"], field_item["text"]) diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/line_parser.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/line_parser.py new file mode 100755 index 0000000..32cf756 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/line_parser.py @@ -0,0 +1,34 @@ +# @brief Line parser based on template boxes, i.e. simply crop text boxes given the original image and text boxes' coordinates +class TemplateBoxParser: + def __init__(self): + pass + + ## + # @brief Run line parser + # + # @param images: Refer to interface + # @param metadata: Refer to interface. Each metadata dict, i.e. correspond to an image, has format + # + # { + # "boxes": [ + # (top, left, w, h), + # (top, left, w, h), + # ... 
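    #              NOTE: despite the (top, left, w, h) form sketched above, run() below unpacks each
    #              box directly as absolute (x1, y1, x2, y2) corners (the w/h conversion is left
    #              commented out inside run()). A minimal, hypothetical call:
    #              crops = TemplateBoxParser().run([image], [{"boxes": [(100, 200, 400, 260)]}])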
+ # ] + # } + # + # where coordinates are absolute coordinates + # + # @return cropped_images: Refer to interface + def run(self, images, metadata): + cropped_images = [] + for image, _metadata in zip(images, metadata): + _cropped_images = [] + for box in _metadata["boxes"]: + # y, x, w, h = box + # x1, y1, x2, y2 = x, y, x + w, y + h + x1, y1, x2, y2 = box + # print(bo) + _cropped_images.append(image[y1:y2, x1:x2].copy()) + cropped_images.append(_cropped_images) + return cropped_images diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/ocr_module.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/ocr_module.py new file mode 100755 index 0000000..e1d108c --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/ocr_module.py @@ -0,0 +1,126 @@ +# temp #for debug +import glob +import os +from mmdet.apis import inference_detector, init_detector +from mmocr.apis import init_detector as init_classifier +from mmocr.apis.inference import model_inference +import cv2 +import numpy as np + +# from src.tools.utils import * + +import time +from src.utils.visualize import visualize_ocr_output + + +def clip_box(x1, y1, x2, y2, w, h): + x1 = int(float(min(max(0, x1), w))) + x2 = int(float(min(max(0, x2), w))) + y1 = int(float(min(max(0, y1), h))) + y2 = int(float(min(max(0, y2), h))) + return (x1, y1, x2, y2) + + +def get_crop_img_and_bbox(img, bbox, extend: bool = False): + """ + img : numpy array img + bbox : should be xyxy format + """ + if len(bbox) == 5: + left, top, right, bottom, _conf = bbox + elif len(bbox) == 4: + left, top, right, bottom = bbox + left, top, right, bottom = clip_box( + left, top, right, bottom, img.shape[1], img.shape[0] + ) + # assert (bottom - top) * (right - left) > 0, "bbox is invalid" + crop_img = img[top:bottom, left:right] + return crop_img, (left, top, right, bottom) + + +class YoloX: + def __init__(self, config, checkpoint, device="cuda:0"): + self.model = init_detector(config, checkpoint, device=device) + + def inference(self, img=None): + t1 = time.time() + output = inference_detector(self.model, img) + print("Time det: ", time.time() - t1) + return output + + +class Classifier_SATRN: + def __init__(self, config, checkpoint, device="cuda:0"): + self.model = init_classifier(config, checkpoint, device) + + def inference(self, numpy_image): + t1 = time.time() + result = model_inference(self.model, numpy_image, batch_mode=True) + preds_str = [r["text"] for r in result] + confidence = [r["score"] for r in result] + + print("Time reg: ", time.time() - t1) + return preds_str, confidence + + +class OcrEngine: + def __init__(self, det_cfg, det_ckpt, cls_cfg, cls_ckpt, device="cuda:0"): + self.det = YoloX(det_cfg, det_ckpt, device) + self.cls = Classifier_SATRN(cls_cfg, cls_ckpt, device) + + def run_image(self, img): + pred_det = self.det.inference(img) + + pred_det = pred_det[0] # batch_size=1 + + pred_det = sorted(pred_det, key=lambda box: [box[1], box[0]]) + bboxes = np.vstack(pred_det) + + lbboxes = [] + lcropped_img = [] + assert len(bboxes) != 0, f"No bbox found in {img_path}, skipped" + for bbox in bboxes: + try: + crop_img, bbox_ = get_crop_img_and_bbox(img, bbox, extend=True) + lbboxes.append(bbox_) + lcropped_img.append(crop_img) + except AssertionError: + print(f"[ERROR]: Skipping invalid bbox {bbox} in ", img_path) + lwords, _ = self.cls.inference(lcropped_img) + return lbboxes, lwords + + +def visualize(image, boxes, color=(0, 255, 0)): + for box in boxes: + 
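+        # draw the (x1, y1, x2, y2) outline of each word box on the image in place
+        # (the color argument declared above is intended to be forwarded to cv2.rectangle here)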
cv2.rectangle(image, (box[0], box[1]), (box[2], box[3])) + + +if __name__ == "__main__": + det_cfg = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/weights/yolox_s_8x8_300e_cocotext_1280.py" + det_ckpt = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/weights/best_bbox_mAP_epoch_100.pth" + cls_cfg = "/home/sds/datnt/mmocr/logs/satrn_big_2022-04-25/satrn_big.py" + cls_ckpt = "/home/sds/datnt/mmocr/logs/satrn_big_2022-04-25/best.pth" + + engine = OcrEngine(det_cfg, det_ckpt, cls_cfg, cls_ckpt, device="cuda:0") + + # img_path = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/assest/form_1_edit_personal_info/Scan47_0.jpg" + + img_dir = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/raw_images/POS01" + out_dir = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/outputs/visualize_ocr/POS01" + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + img_paths = glob.glob(img_dir + "/*") + for img_path in img_paths: + img = cv2.imread(img_path) + t1 = time.time() + res = engine.run_image(img) + + visualize_ocr_output( + res, + img, + vis_dir=out_dir, + prefix_name=os.path.splitext(os.path.basename(img_path))[0], + font_path="./assest/visualize/times.ttf", + is_vis_kie=False, + ) diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/satrn_classifier.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/satrn_classifier.py new file mode 100755 index 0000000..be33149 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/satrn_classifier.py @@ -0,0 +1,15 @@ +from mmocr.apis import init_detector as init_classifier +from mmocr.apis.inference import model_inference +import numpy as np +from .utils import * + + +class Classifier_SATRN: + def __init__(self, config, checkpoint, device="cuda:0"): + self.model = init_classifier(config, checkpoint, device) + + def inference(self, numpy_image): + result = model_inference(self.model, numpy_image, batch_mode=True) + preds_str = [r["text"] for r in result] + confidence = [r["score"] for r in result] + return preds_str, confidence diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/sift_based_aligner.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/sift_based_aligner.py new file mode 100755 index 0000000..67fd059 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/modules/sift_based_aligner.py @@ -0,0 +1,385 @@ +import os +import time + +import numpy as np +import cv2 + +from scipy.spatial import ConvexHull, convex_hull_plot_2d +from shapely.geometry import Polygon + +from ..utils.image_calib import check_similar_triangle, check_angle_between_2_lines +import tqdm + +## +# @brief Document classifier based on SIFT-based template matching +# @note This classifier can support varying illumination, varying out-of-plane rotation angles (up to roughly 60 degrees compared with the template image's orientation), and invariant to scale. These constraints are given in [this slides](http://vision.stanford.edu/teaching/cs231a_autumn1112/lecture/lecture12_SIFT_single_obj_recog_cs231a_marked.pdf) + + +class SIFTBasedAligner: + ## + # @brief Initializer + # + # @param template_info (dict): Mapping from class names to template info lists. 
The format of this dict is given as + # + # ``` + # { + # "doc_type": { + # "image_path": str, # path to template image + # "anchors": [ + # { + # "id": str, # anchor ID in template image + # "position": {"top": int, "left": int}, + # "size": {"width": , "height": float} + # }, + # { + # "id": str, # anchor ID in template image + # "position": {"top": int, "left": int}, + # "size": {"width": , "height": float} + # }, + # ... + # ] + # }, + # "doc_type": { + # "image_path": str, # path to template image + # "anchors": [ + # { + # "id": str, # anchor ID in template image + # "position": {"top": int, "left": int}, + # "size": {"width": , "height": float} + # }, + # { + # "id": str, # anchor ID in template image + # "position": {"top": int, "left": int}, + # "size": {"width": , "height": float} + # }, + # ... + # ] + # }, + # ... + # } + # + # ``` + # + # @param min_match_count (int): Minimum number of matched points for a template to be considered "FOUND" in the input image + # @param flann_based_matcher_config (dict): Configurations for cv2.FlannBasedMatcher, i.e. refer to the this class for more details + # @param matching_topk (int): This must be 2 + # @param distance_threshold (float): Upper threshold for top-1-distance-over-top-2-distance ratio after kNN matching + # @param ransac_threshold (float): RANSAC threshold for locating the template within the input image (if found) + # @param valid_size_ratio_margin (float): Valid max-size-to-min-size ratio margin of the min-area surrounding box of the found template in the image, within which the found template is considered valid + # @param valid_area_threshold (float): Valid area threhsold, above which the found template is considered valid + # @param image_max_size (int): Maximum size for the input image. 
None if no limit + # @param similar_triangle_threshold (float): Threshold, above which two triangles are considered non-similar + # @param roi_to_template_box_ratio (float): Ratio to scale the template anchor to estimate the ROI, within which the template anchor is likely to exist + def __init__( + self, + template_info, + min_match_count, + flann_based_matcher_config, + matching_topk, + distance_threshold, + ransac_threshold, + valid_size_ratio_margin, + valid_area_threshold, + image_max_size, + similar_triangle_threshold, + roi_to_template_box_ratio, + default_image_size, + template_im_dir, + ): + assert matching_topk == 2, "Invalid matching_topk" + + # SIFT feature extractor + # self.sift = cv2.xfeatures2d.SIFT_create() + self.sift = cv2.SIFT_create() + + self.template_im_dir = template_im_dir + + # load templates + ( + self.template_images, + self.template_anchors, + self.template_features, + self.template_metadata, + ) = self._load_template(template_info) + # kNN feature matcher + self.matcher = cv2.FlannBasedMatcher(flann_based_matcher_config, {}) + + # other arguments + self.image_max_size = image_max_size + self.min_match_count = min_match_count + self.matching_topk = matching_topk + self.distance_threshold = distance_threshold + self.ransac_threshold = ransac_threshold + self.roi_to_template_box_ratio = roi_to_template_box_ratio + self.default_image_size = default_image_size + + # validity thresholds + self.valid_size_ratio_interval = [ + 1 - valid_size_ratio_margin, + 1 + valid_size_ratio_margin, + ] + self.valid_area_threshold = valid_area_threshold + self.similar_triangle_threshold = similar_triangle_threshold + + def _load_template(self, template_info): + r"""Load template images from paths and extract features""" + template_images, template_anchors = {}, {} + template_features, template_metadata = {}, {} + # print(template_info) + # for doc_type in template_info: + template_anchors = [] + template_features = [] + template_metadata = [] + template_im_path = os.path.join( + self.template_im_dir, template_info["image_path"] + ) + print(template_im_path) + assert os.path.exists(template_im_path), print(template_im_path) + template_image = cv2.imread(template_im_path) + for anchor in template_info["anchors"]: + # extract anchor + anchor = [int(float(item)) for item in anchor] + x1, y1, x2, y2 = anchor + template_anchor = template_image[y1:y2, x1:x2] + template_anchor = cv2.cvtColor(template_anchor, cv2.COLOR_BGR2GRAY) + template_kpts, template_desc = self.sift.detectAndCompute( + template_anchor, None + ) + + # append to dict + max_size = np.max(template_anchor.shape[:2]) + min_size = np.min(template_anchor.shape[:2]) + + template_images = template_image + template_anchors.append(template_anchor) + template_features.append( + { + "kpts": template_kpts, + "desc": template_desc, + "ratio": max_size / min_size, + } + ) + template_metadata.append({"box": [x1, y1, x2, y2]}) + return (template_images, template_anchors, template_features, template_metadata) + + + def run_alige(self, images, metadata=None): + transformed_images = [] + for image, _metadata in tqdm.tqdm(zip(images, metadata)): + # find templates + doc_type = _metadata["doc_type"] + template_image = self.template_images + gray_image = self._preprocess(image) + + # match against all templates of the given doc type + anchor_centers, anchor_locations = [], [] + found_centers, found_locations = [], [] + for template_anchor, template_feature, template_metadata in zip( + self.template_anchors, + self.template_features, + 
self.template_metadata, + ): + ( + anchor_center, + anchor_location, + found_center, + found_location, + ) = self._find_template( + gray_image, + template_image, + template_anchor, + template_feature, + template_metadata, + ) + if found_location is not None: + anchor_centers.append(anchor_center) + anchor_locations.append(anchor_location) + found_centers.append(found_center) + found_locations.append(found_location) + + if len(found_locations) == 0: + print("len(found_locations) == 0") + transformed_images.append(None) + continue + else: + found_locations = np.concatenate(found_locations, axis=0) + anchor_locations = np.concatenate(anchor_locations, axis=0) + + # check calibration ability + calib_success = False + if len(found_centers) < 3: + print(found_centers) + print("len(found_centers) < 3") + transformed_images.append(None) + continue + else: + calib_success = check_similar_triangle( + anchor_centers, + found_centers, + diff_thres=self.similar_triangle_threshold, + ) + + # align image + # TODO: calib_success = False even when calib is successful + calib_success = True + if calib_success: + perspective_trans, _ = cv2.findHomography( + found_locations, anchor_locations + ) + + transformed_image = cv2.warpPerspective( + image, + perspective_trans, + (template_image.shape[1], template_image.shape[0]), + ) + transformed_images.append(transformed_image) + else: + print("calib_success is False") + transformed_images.append(None) + + return transformed_images + + def _preprocess(self, image): + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + return gray_image + + def _find_template( + self, + gray_image, + gray_template, + gray_template_anchor, + template_feature, + template_metadata, + ): + r"""Find templates and return template center""" + # parser inputs + template_kpts = template_feature["kpts"] + template_desc = template_feature["desc"] + template_box = template_metadata["box"] + + # resize ROI + gray_image, shift, scale = self._crop_roi_and_resize( + gray_image, gray_template, template_box + ) + + # extract features + image_kpts, image_desc = self.sift.detectAndCompute(gray_image, None) + # if image_desc is None: + # print("Error matching") + # return None, None, None, None + # print(image_desc) + try: + # knnMatch to get top-K then sort by their distance + matches = self.matcher.knnMatch( + template_desc, image_desc, self.matching_topk + ) + except Exception as err: + print(err) + return None, None, None, None + matches = sorted(matches, key=lambda x: x[0].distance) + + # ratio test, to get good matches. + # idea: good matches should uniquely match each other, i.e. 
top-1 and top-2 distances are much difference + good = [ + m1 + for (m1, m2) in matches + if m1.distance < self.distance_threshold * m2.distance + ] + + # find homography matrix + if len(good) > self.min_match_count: + # (queryIndex for the small object, trainIndex for the scene ) + src_pts = np.float32([template_kpts[m.queryIdx].pt for m in good]).reshape( + -1, 1, 2 + ) + dst_pts = np.float32([image_kpts[m.trainIdx].pt for m in good]).reshape( + -1, 1, 2 + ) + + # find homography matrix in cv2.RANSAC using good match points + M, mask = cv2.findHomography( + src_pts, dst_pts, cv2.RANSAC, self.ransac_threshold + ) + if M is not None: + # get template center in original image + h, w = gray_template_anchor.shape[:2] + pts = np.float32( + [[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]] + ).reshape(-1, 1, 2) + dst = cv2.perspectiveTransform(pts, M) + + # get convex hull of the match and its min-area surrounding box + hull = ConvexHull(dst[:, 0, :]).vertices + hull = dst[hull][:, 0, :] + hull_rect = cv2.minAreaRect(hull[:, None, :]) + hull_box = cv2.boxPoints(hull_rect) + + # compute sizes of the hull box + hull_box_size = ( + np.sqrt(np.sum((hull_box[0] - hull_box[1]) ** 2, axis=-1)), + np.sqrt(np.sum((hull_box[1] - hull_box[2]) ** 2, axis=-1)), + ) + + # verify max-size-over-min-size ratio + hull_box_ratio = np.max(hull_box_size) / np.min(hull_box_size) + template_ratio = template_feature["ratio"] + is_valid_ratio = ( + hull_box_ratio > self.valid_size_ratio_interval[0] * template_ratio + ) and ( + hull_box_ratio < self.valid_size_ratio_interval[1] * template_ratio + ) + + # verify hull-area-to-hull-box-area ratio + hull_area = Polygon(hull).area + hull_box_area = Polygon(hull_box).area + is_valid_hull_area = ( + hull_area >= self.valid_area_threshold * hull_box_area + ) + + # return score as average of inverse distance to closest match + if is_valid_hull_area and is_valid_ratio: + pts[..., 0] += template_box[0] + pts[..., 1] += template_box[1] + anchor_center = np.mean(pts[:, 0, :], axis=0).tolist() + anchor_location = pts[:, 0, :] + + dst[..., 0] = dst[..., 0] / scale + shift[0] + dst[..., 1] = dst[..., 1] / scale + shift[1] + found_center = np.mean(dst[:, 0, :], axis=0).tolist() + found_location = dst[:, 0, :] + + return ( + anchor_center, + anchor_location, + found_center, + found_location, + ) + return None, None, None, None + + def _crop_roi_and_resize(self, query_image, template_image, box): + r"""Crop ROI which possibly containing template anchor and resize it""" + # get template anchor box coordinates relative to template image size + x1, y1, x2, y2 = box + x, y = x1 / template_image.shape[1], y1 / template_image.shape[0] + w = (x2 - x1) / template_image.shape[1] + h = (y2 - y1) / template_image.shape[0] + + # crop ROI + pad_ratio = (self.roi_to_template_box_ratio - 1.0) / 2 + x1 = max(min(x - w * pad_ratio, 1.0), 0.0) + y1 = max(min(y - h * pad_ratio, 1.0), 0.0) + x2 = max(min(x + w * self.roi_to_template_box_ratio, 1.0), 0.0) + y2 = max(min(y + h * self.roi_to_template_box_ratio, 1.0), 0.0) + x1, y1 = int(x1 * query_image.shape[1]), int(y1 * query_image.shape[0]) + x2, y2 = int(x2 * query_image.shape[1]), int(y2 * query_image.shape[0]) + query_image = query_image[y1:y2, x1:x2] + + # resize ROI + query_image_max_size = max(query_image.shape[:2]) + if self.image_max_size and query_image_max_size > self.image_max_size: + ratio = self.image_max_size / query_image_max_size + query_image = cv2.resize(query_image, (0, 0), fx=ratio, fy=ratio) + else: + ratio = 1.0 + + return query_image, 
(x1, y1), ratio diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/infer_img_template_aligner_delete.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/infer_img_template_aligner_delete.py new file mode 100755 index 0000000..b9ccbf8 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/infer_img_template_aligner_delete.py @@ -0,0 +1,86 @@ +from src.config.sift_based_aligner import config +from src.modules.sift_based_aligner import SIFTBasedAligner +from src.utils.common import read_json +from argparse import ArgumentParser +import os +import cv2 + +num_pages_dict = {"pos01": 2, "pos04": 2} + +exception_files = {"pos01": ["SKM_458e Ag22101217490_0.png"]} # only page_1 + +template_path_dict = { + "pos01": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos01.json", + "pos04": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos04.json", +} + + +def reformat(config, doc_id): + # template_infos = config['template_info'] + + template_path = template_path_dict[doc_id] + template_info = read_json(template_path) + + # num_page = num_pages_dict[doc_id] + + config["template_info"] = template_info + return config + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--img_dir") + parser.add_argument("--output") + parser.add_argument("--doc_id", help="pos01/pos04", default="pos01") + args = parser.parse_args() + + # make dir + if not os.path.exists(args.output): + os.makedirs(args.output, exist_ok=True) + error_dir = args.output + "_error" + if not os.path.exists(error_dir): + os.makedirs(error_dir, exist_ok=True) + + doc_id = args.doc_id + print("DOCID: ", doc_id) + + # load img paths + img_paths = [args.img_dir] + images = [cv2.imread(img_path) for img_path in img_paths] + print("total samples: ", len(img_paths)) + + # reformat config + config = reformat(config, doc_id=args.doc_id) + + # aligner init + aligner = SIFTBasedAligner(**config) + + metadata = [{"doc_type": "{}_1".format(doc_id), "img_path": img_paths[0]}] + # metadata = [{'doc_type': 'edit_form_1_1' if "_0.jpg" in img_path else 'edit_form_1_2', 'img_path': img_path} for img_path in img_paths] + transformed_images = aligner.run(images, metadata) + + print(len(img_paths), len(transformed_images)) + + error_count = 0 + for idx in range(len(transformed_images)): + img_name = os.path.basename(img_paths[idx]) + img_outpath = os.path.join(args.output, img_name) + img_out = transformed_images[idx] + doc_type = metadata[idx]["doc_type"] + field_boxes = config["template_info"][doc_type]["fields"] + + for bbox in field_boxes: + x, y = bbox["position"]["left"], bbox["position"]["top"] + w, h = bbox["size"]["width"], bbox["size"]["height"] + x1, y1, x2, y2 = int(x), int(y), int(x + w), int(y + h) + cv2.rectangle(img_out, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2) + + if img_out is not None: + print("Write to: ", img_outpath) + cv2.imwrite(img_outpath, img_out) + else: + error_count += 1 + print("Image None: ", img_paths[idx]) + cv2.imwrite(os.path.join(error_dir, img_name), images[idx]) + + print("Num error cases: ", error_count) diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/run_crop_lines.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/run_crop_lines.py new file mode 100755 index 0000000..70b5d85 --- /dev/null +++ 
b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/run_crop_lines.py @@ -0,0 +1,93 @@ +from argparse import ArgumentParser +import cv2 +import glob +import os +import time +import tqdm + +from src.modules.line_parser import TemplateBoxParser +from src.config.line_parser import TEMPLATE_BOXES +from src.utils.common import read_json, get_doc_id_with_page + + +""" +Crop per form (2 page) +""" + +template_path_dict = { + "pos01": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos01.json", + "pos04": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos04.json", + "pos02": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos02.json", + "pos03": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos03.json", + "pos08": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos08.json", + "pos05": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos05.json", + "pos06": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos06.json", +} + + +def identity_page(img_path, doc_id): + page_type = "" + if "_0.jpg" in img_path: + page_number = 1 + elif "_1.jpg" in img_path: + page_number = 2 + else: + idx = int(float(img_path.split(".jpg")[0].split("_")[-1])) + if idx % 2 == 0: + page_number = 1 + else: + page_number = 2 + + doc_template_id = "{}_page_{}".format(doc_id, page_number) + + return doc_template_id # page_1 / page_2 + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--img_dir") + parser.add_argument("--out_dir") + parser.add_argument("--doc_id", help="pos01/pos02", default="pos01") + args = parser.parse_args() + + line_parser = TemplateBoxParser() + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + img_paths = glob.glob(args.img_dir + "/*") + print("Len imgs: ", len(img_paths)) + + template_info = read_json(template_path_dict[args.doc_id]) + + crop_metadata = {} + pages = ["page_"] + for page in pages: + metadata = {"boxes": [], "box_types": []} + + crop_metadata[page] = metadata + + count = 0 + for idx, img_path in tqdm.tqdm(enumerate(img_paths)): + aligned_images = cv2.imread(img_path) + + doc_template_id = get_doc_id_with_page(img_path, args.doc_id) + # print(img_path, doc_template_id, aligned_images) + + cropped_images = line_parser.run( + [aligned_images], + metadata=[{"boxes": template_info[doc_template_id]["fields"]}], + ) + + count += len(cropped_images[0]) + for id_img, crop_img in enumerate(cropped_images[0]): + out_path = os.path.join( + args.out_dir, + os.path.splitext(os.path.basename(img_path))[0] + + "_" + + str(id_img) + + ".jpg", + ) + cv2.imwrite(out_path, crop_img) + + print("Total: ", count) diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/run_ocr.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/run_ocr.py new file mode 100755 index 0000000..2a2106c --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/samples/run_ocr.py @@ -0,0 +1,46 @@ +from src.modules.ocr_module import OcrEngine +from argparse import ArgumentParser +import os +import glob +import cv2 +import time +from src.utils.visualize import visualize_ocr_output + + +def main(img_dir, out_dir, device="cuda:0"): + det_cfg = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/weights/yolox_s_8x8_300e_cocotext_1280.py" + det_ckpt = 
"/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/weights/best_bbox_mAP_epoch_100.pth" + cls_cfg = "/home/sds/datnt/mmocr/logs/satrn_big_2022-04-25/satrn_big.py" + cls_ckpt = "/home/sds/datnt/mmocr/logs/satrn_big_2022-04-25/best.pth" + + engine = OcrEngine(det_cfg, det_ckpt, cls_cfg, cls_ckpt, device=device) + + # img_dir = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/raw_images/POS01" + # out_dir = "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/outputs/visualize_ocr/POS01" + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + img_paths = glob.glob(img_dir + "/*") + for img_path in img_paths: + img = cv2.imread(img_path) + t1 = time.time() + res = engine.run_image(img) + + visualize_ocr_output( + res, + img, + vis_dir=out_dir, + prefix_name=os.path.splitext(os.path.basename(img_path))[0], + font_path="./assest/visualize/times.ttf", + is_vis_kie=False, + ) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--img_dir") + parser.add_argument("--out_dir") + parser.add_argument("--device", help="cuda:0 / cuda:0") + args = parser.parse_args() + + main(args.img_dir, args.out_dir, args.device) diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/serve_model.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/serve_model.py new file mode 100755 index 0000000..731c213 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/serve_model.py @@ -0,0 +1,139 @@ +import yaml +import numpy as np + +from .config.sift_based_aligner import config +from .modules.sift_based_aligner import SIFTBasedAligner +from .utils.common import read_json +from common.utils.word_formation import Word, words_to_lines + + +def calc_pct_overlapped_area(bboxes1, bboxes2): + # assert True + assert len(bboxes1.shape) == 2 and bboxes1.shape[1] == 4 + assert len(bboxes2.shape) == 2 and bboxes2.shape[1] == 4 + + bboxes1 = bboxes1.copy() + bboxes2 = bboxes2.copy() + + x11, y11, x12, y12 = np.split(bboxes1, 4, axis=1) + x21, y21, x22, y22 = np.split(bboxes2, 4, axis=1) + xA = np.maximum(x11, np.transpose(x21)) + yA = np.maximum(y11, np.transpose(y21)) + xB = np.minimum(x12, np.transpose(x22)) + yB = np.minimum(y12, np.transpose(y22)) + interArea = np.maximum((xB - xA + 1), 0) * np.maximum((yB - yA + 1), 0) + boxBArea = (x22 - x21 + 1) * (y22 - y21 + 1) + boxBArea = np.tile(boxBArea, (1, len(bboxes1))) + iou = interArea / boxBArea.T + return iou + + +class Predictor: + def __init__(self, setting_file="setting.yml"): + with open(setting_file) as f: + # use safe_load instead load + self.setting = yaml.safe_load(f) + self.config = self.setting["templates"]["config"] + + def _align(self, config, temp_name, image): + # init aligner + aligner = SIFTBasedAligner(**config) + metadata = [{"doc_type": temp_name}] + aligned_images = aligner.run_alige([image], metadata) + aligned_image = aligned_images[0] + return aligned_image + + def _reorder_words(self, boxes): + arr_x1 = boxes[:, 0] + return np.argsort(arr_x1) + + def _asign_words_to_field( + self, boxes, contents, types, page_template_info, threshold=0.8 + ): + field_coords = [element["box"] for element in page_template_info["fields"]] + field_coords = np.array(field_coords) + field_coords = field_coords.astype(float) + field_coords = field_coords.astype(int) + field_names = [element["label"] for element in page_template_info["fields"]] + field_types = [ + "checkbox" if element["label"].startswith("checkbox") else "word" + for element in 
page_template_info["fields"] + ] + boxes = np.array(boxes[0]) + print(field_coords) + print(boxes) + print(field_coords.shape, boxes.shape) + area_pct = calc_pct_overlapped_area(field_coords, boxes) + + results = dict() + for row_score, field, _type in zip(area_pct, field_names, field_types): + if _type == "checkbox": + inds = np.where(row_score > threshold)[0] + inds = [i for i in inds if types[i] == "checkbox"] + results[field] = dict() + results[field]["value"] = contents[inds[0]] if len(inds) > 0 else None + results[field]["boxes"] = boxes[inds[0]] if len(inds) > 0 else None + else: + inds = np.where(row_score > threshold)[0] + field_word_boxes = boxes[inds] + sorted_inds = inds[self._reorder_words(field_word_boxes)] + + results[field] = dict() + results[field]["words"] = [contents[i] for i in sorted_inds] + lines = self._get_line_content(boxes[sorted_inds], results[field]["words"]) + results[field]["value"] = '\n'.join(lines).strip() + results[field]["boxes"] = boxes[sorted_inds] + return results + + def _get_line_content(self, boxes, contents): + list_words = [] + for box, text in zip(boxes, contents): + bndbox = [int(j) for j in box] + list_words.append( + Word( + text=text, + bndbox=bndbox, + ) + ) + list_lines, _ = words_to_lines(list_words) + line_texts = [line.text for line in list_lines] + return line_texts + + + def align_image(self, image, template_json, template_image_dir, temp_name): + """Run TemplateMaching main + + Args: + documents (dict): document then document classification + template_json (dict): + example: + { + "pos01": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/add_fields/pos01.json", + "pos04": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/add_fields/pos04.json", + "pos02": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/add_fields/pos02.json", + "pos03": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/pos03_fields_checkbox.json", + "pos08": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/add_fields/pos08.json", + "pos05": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/add_fields/pos05.json", + "pos06": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/add_fields/pos06.json", + "cccd_front": "/mnt/ssd500/hoanglv/Projects/FWD/template_matching_hoanglv/data/json/cccd_front.json", + } + template_image_dir (str): path to template image dir + + Returns: + dict: content then template matching + """ + + config = self.config.copy() + config["template_info"] = template_json + config["template_im_dir"] = template_image_dir + aligned_image = self._align(config, temp_name, image) + return aligned_image + + def template_based_extractor(self, batch_boxes, texts, doc_page, template_json): + field_data = self._asign_words_to_field( + batch_boxes, + texts, + doc_page["types"], + template_json, + ) + return field_data \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/common.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/common.py new file mode 100755 index 0000000..26e769c --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/common.py @@ -0,0 +1,24 @@ +import os +import json + + +def get_doc_id_with_page(img_path, doc_id): + if "_0.jpg" in img_path: + doc_page = "{}_page_1".format(doc_id) + elif "_1.jpg" in img_path: + doc_page = "{}_page_2".format(doc_id) + else: + idx = 
int(os.path.splitext(os.path.basename(img_path))[0].split("_")[-1]) + # idx = int(float(img_path.split(".jpg")[0].split("_")[-1])) + if idx % 2 == 0: + doc_page = "{}_page_1".format(doc_id) + else: + doc_page = "{}_page_2".format(doc_id) + + return doc_page + + +def read_json(json_path): + with open(json_path, "r", encoding="utf8") as f: + data = json.load(f) + return data diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/image_calib.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/image_calib.py new file mode 100755 index 0000000..2d713bf --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/image_calib.py @@ -0,0 +1,803 @@ +import time, os +import cv2, math +import numpy as np +import pathlib +import math + + +RADIAN_PER_DEGREE = 0.0174532 +debug = False + + +def crop_image(input_img, bbox, bbox_ratio=1.0, offset_x=0, offset_y=0): + left = int(bbox_ratio * bbox[0]) + top = int(bbox_ratio * bbox[1]) + width = int(bbox_ratio * bbox[2]) + height = int(bbox_ratio * bbox[3]) + crop_img = input_img[ + top + offset_y : top + height + offset_y, + left + offset_x : left + width + offset_x, + ] + return crop_img + + +def resize_normalize(img, normalize_width=1654): + w = img.shape[1] + h = img.shape[0] + resize_ratio = normalize_width / w + normalize_height = round(h * resize_ratio) + resize_img = cv2.resize( + img, (normalize_width, normalize_height), interpolation=cv2.INTER_CUBIC + ) + # cv2.imshow('resize img', resize_img) + # cv2.waitKey(0) + return resize_ratio, resize_img + + +def draw_bboxes(img, bboxes, window_name="draw bboxes"): + # e.g: bboxes= [(0,0),(0,5),(5,5),(5,0)] + if len(img.shape) != 3: + img_RGB = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + else: + img_RGB = img + color_red = (0, 0, 255) + for idx, bbox in enumerate(bboxes): + cv2.line(img_RGB, bbox[0], bbox[1], color=color_red, thickness=2) + cv2.line(img_RGB, bbox[1], bbox[2], color=color_red, thickness=2) + cv2.line(img_RGB, bbox[2], bbox[3], color=color_red, thickness=2) + cv2.line(img_RGB, bbox[3], bbox[0], color=color_red, thickness=2) + + font = cv2.FONT_HERSHEY_SIMPLEX + + # org + org = (bbox[0][0], bbox[0][1] - 5) + + # fontScale + fontScale = 1.5 + + # Blue color in BGR + color = (255, 0, 0) + + # Line thickness of 2 px + thickness = 2 + + # Using cv2.putText() method + img_RGB = cv2.putText( + img_RGB, str(idx), org, font, fontScale, color, thickness, cv2.LINE_AA + ) + + img_RGB = cv2.resize( + img_RGB, (int(img_RGB.shape[1] / 2), int(img_RGB.shape[0] / 2)) + ) + cv2.imshow(window_name, img_RGB) + cv2.waitKey(0) + + +class Template_info: + def __init__( + self, + name, + template_path, + field_bboxes, + field_rois_extend=(1.0, 1.0), + field_search_areas=None, + confidence=0.7, + scales=(0.9, 1.1, 0.1), + rotations=(-2, 2, 2), + normalize_width=1654, + ): # 1654 + self.name = name + self.template_img = cv2.imread(template_path, 0) + self.normalize_width = normalize_width + self.resize_ratio, self.template_img = resize_normalize( + self.template_img, normalize_width + ) + self.template_width = self.template_img.shape[1] + self.template_height = self.template_img.shape[0] + self.confidence = confidence + self.field_bboxes = field_bboxes + self.field_rois_extend = field_rois_extend + self.field_search_areas = field_search_areas + self.field_locs = [] + self.list_field_samples = [] + for idx, bbox in enumerate(self.field_bboxes): + bbox = self.resize_bbox(bbox, self.resize_ratio) + + field = dict() + field["name"] = str(idx) 
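+ # field["loc"] is the centre of the anchor-field bbox in resized-template coordinates;
+ # when no field_search_areas are supplied, the search window computed below is this bbox
+ # (with a 50 px floor per side) expanded by field_rois_extend, clamped to the template size.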
+ field["loc"] = (bbox[0] + (bbox[2] - 1) / 2, bbox[1] + (bbox[3] - 1) / 2) + self.field_locs.append(field["loc"]) + field["search_area"] = None + if field_search_areas is not None: + field["search_area"] = self.resize_bbox( + field_search_areas[idx], self.resize_ratio + ) + else: + field["data"] = self.crop_image(self.template_img, bbox) + # cv2.imwrite(field['name']+'.jpg', field['data']) + field_w = max(field["data"].shape[1], 50) + field_h = max(field["data"].shape[0], 50) + extend_x = int(self.field_rois_extend[0] * field_w) + extend_y = int(self.field_rois_extend[1] * field_h) + left = max(int(field["loc"][0] - field_w / 2 - extend_x), 0) + top = max(int(field["loc"][1] - field_h / 2 - extend_y), 0) + right = min( + int(field["loc"][0] + field_w / 2 + extend_x), self.template_width + ) + bottom = min( + int(field["loc"][1] + field_h / 2 + extend_y), self.template_height + ) + width = right - left + height = bottom - top + field["search_area"] = [left, top, width, height] + + self.createSamples(field, scales, rotations) + self.list_field_samples.append(field) + + def resize_bbox(self, bbox, resize_ratio): + for i in range(len(bbox)): + bbox[i] = round(bbox[i] * resize_ratio) + return bbox + + def crop_image(self, input_img, bbox, offset_x=0, offset_y=0): + # logger.info('crop') + crop_img = input_img[ + bbox[1] + offset_y : bbox[1] + bbox[3] + offset_y, + bbox[0] + offset_x : bbox[0] + bbox[2] + offset_x, + ] + return crop_img + + def createSamples(self, field, scales, rotations): + # logger.info('Add_template', field['name']) + list_scales = [] + list_rotations = [] + + num_scales = round((scales[1] - scales[0]) / scales[2]) + 1 + num_rotations = round((rotations[1] - rotations[0]) / rotations[2]) + 1 + for i in range(num_scales): + list_scales.append(round(scales[0] + i * scales[2], 4)) + for i in range(num_rotations): + list_rotations.append(round(rotations[0] + i * rotations[2], 4)) + + field["list_samples"] = [] + field_data = field["data"] + w = field_data.shape[1] + h = field_data.shape[0] + bgr_val = int( + ( + int(field_data[0][0]) + + int(field_data[0][w - 1]) + + int(field_data[h - 1][w - 1]) + + int(field_data[h - 1][0]) + ) + / 4 + ) + for rotation in list_rotations: + abs_rotation = abs(rotation) + if w < h: + if abs_rotation <= 45: + sa = math.sin(abs_rotation * RADIAN_PER_DEGREE) + ca = math.cos(abs_rotation * RADIAN_PER_DEGREE) + newHeight = (int)((h - w * sa) / ca) + # newHeight = newHeight - ((h - newHeight) % 2) + szOutput = (w, newHeight) + else: + sa = math.sin((90 - abs_rotation) * RADIAN_PER_DEGREE) + ca = math.cos((90 - abs_rotation) * RADIAN_PER_DEGREE) + newWidth = (int)((h - w * sa) / ca) + # newWidth = newWidth - ((w - newWidth) % 2) + szOutput = (newWidth, w) + else: + if abs_rotation <= 45: + sa = math.sin(abs_rotation * RADIAN_PER_DEGREE) + ca = math.cos(abs_rotation * RADIAN_PER_DEGREE) + newWidth = (int)((w - h * sa) / ca) + # newWidth = newWidth - ((w - newWidth) % 2) + szOutput = (newWidth, h) + else: + sa = math.sin((90 - rotation) * RADIAN_PER_DEGREE) + ca = math.cos((90 - rotation) * RADIAN_PER_DEGREE) + newHeight = (int)((w - h * sa) / ca) + # newHeight = newHeight - ((h - newHeight) % 2) + szOutput = (h, newHeight) + + (h, w) = field_data.shape[:2] + (cX, cY) = (w / 2, h / 2) + M = cv2.getRotationMatrix2D((cX, cY), -rotation, 1.0) + cos = np.abs(M[0, 0]) + sin = np.abs(M[0, 1]) + nW = int((h * sin) + (w * cos)) + nH = int((h * cos) + (w * sin)) + M[0, 2] += (nW / 2) - cX + M[1, 2] += (nH / 2) - cY + rotated = cv2.warpAffine(field_data, M, (nW, 
nH), borderValue=bgr_val) + + # (h_rot, w_rot) = rotated.shape[:2] + # (cX_rot, cY_rot) = (w_rot // 2, h_rot // 2) + # pt1=(int(cX_rot-3), int(cY_rot-3)) + # pt2=(int(cX_rot+3), int(cY_rot+3)) + # pt3=(int(cX_rot-3), int(cY_rot+3)) + # pt4=(int(cX_rot+3), int(cY_rot-3)) + # cv2.line(rotated,pt1,pt2,color=255) + # cv2.line(rotated,pt3,pt4,color=255) + + offset_X = int((nW - szOutput[0]) / 2) + offset_Y = int((nH - szOutput[1]) / 2) + + crop_rotated = rotated[ + offset_Y : nH - offset_Y - 1, offset_X : nW - offset_X - 1 + ] + crop_w = crop_rotated.shape[1] + crop_h = crop_rotated.shape[0] + # rint('origin size', crop_w, crop_h) + + for scale in list_scales: + temp = dict() + temp["rotation"] = rotation + temp["scale"] = scale + # logger.info('scale', scale, ', rotation', rotation) + crop_rotate_resize = cv2.resize( + crop_rotated, (int(scale * crop_w), int(scale * crop_h)) + ) + # logger.info('resize size', int(scale * crop_w), int(scale * crop_h)) + temp["data"] = crop_rotate_resize + if debug: + cv2.imshow("result", crop_rotated) + cv2.imshow("result_crop", crop_rotate_resize) + ch = cv2.waitKey(0) + if ch == 27: + cv2.imwrite("result.jpg", crop_rotated) + break + field["list_samples"].append(temp) + + def draw_template(self, src_img=None, crop=False, crop_dir=""): + list_bboxes = [] + for idx, bbox in enumerate(self.field_bboxes): + left = bbox[0] + top = bbox[1] + right = bbox[0] + bbox[2] + bottom = bbox[1] + bbox[3] + if crop: + crop_img = crop_image(self.template_img, bbox) + cv2.imwrite( + os.path.join(crop_dir, self.name + "_field_" + str(idx) + ".jpg"), + crop_img, + ) + bboxes = [(left, top), (right, top), (right, bottom), (left, bottom)] + list_bboxes.append(bboxes) + if src_img is None: + draw_bboxes(self.template_img, list_bboxes) + else: + draw_bboxes(src_img, list_bboxes, window_name="new") + + def get_template_img(self): + return self.template_img + + +class MatchingTemplate: + def __init__(self, initTemplate=False): + self.template_dir = "" + self.template_names = [] + self.template_list = [] + self.template_dir = os.path.join( + pathlib.Path(__file__).parent.absolute(), "templates" + ) + if initTemplate: + self.initTemplate() + self.matching_results = [] + self.activate_template = "" + + def initTemplate(self, template_dir=None, list_template_name=[]): + kk = 1 + + def add_template( + self, + template_name, + template_path, + field_bboxes, + field_rois_extend=(1.0, 1.0), + field_search_areas=None, + confidence=0.7, + scales=(0.9, 1.1, 0.1), + rotations=(-2, 2, 2), + normalize_width=1654, + ): + if not os.path.exists(template_path): + print("MatchingTemplate. No template path:", template_path) + return + print("MatchingTemplate. Init template", "[" + str(template_name) + "]") + temp = Template_info( + template_name, + template_path, + field_bboxes, + field_rois_extend, + field_search_areas, + confidence, + scales, + rotations, + normalize_width=normalize_width, + ) + self.template_list.append(temp) + + def clear_template(self): + self.template_list.clear() + + def check_template(self, template_name): + template_data = None + for template in self.template_list: + if template.name == template_name: + self.activate_template = template_name + template_data = template + break + if template_data is None: + print("MatchingTemplate. 
No template name", template_name) + # logger.info('Cannot find template', template_name, 'in database') + return template_data + + def draw_template(self, template_name, src_img=None, crop=False, crop_dir=""): + template_data = self.check_template(template_name) + if template_data is None: + return + template_data.draw_template(src_img, crop=crop, crop_dir=crop_dir) + + def get_matching_result(self, final_locx, final_locy, final_sample): + x0 = final_locx + y0 = final_locy + + x1 = x0 - (final_sample["data"].shape[1] / 2) * final_sample["scale"] + y1 = y0 - (final_sample["data"].shape[0] / 2) * final_sample["scale"] + x2 = x0 + (final_sample["data"].shape[1] / 2) * final_sample["scale"] + y2 = y0 + (final_sample["data"].shape[0] / 2) * final_sample["scale"] + + ## + ca = math.cos(final_sample["rotation"] * RADIAN_PER_DEGREE) + sa = math.sin(final_sample["rotation"] * RADIAN_PER_DEGREE) + rx1 = round((x0 + (x1 - x0) * ca - (y1 - y0) * sa)) + ry1 = round((y0 + (x1 - x0) * sa + (y1 - y0) * ca)) + rx2 = round((x0 + (x2 - x0) * ca - (y1 - y0) * sa)) + ry2 = round((y0 + (x2 - x0) * sa + (y1 - y0) * ca)) + rx3 = round((x0 + (x2 - x0) * ca - (y2 - y0) * sa)) + ry3 = round((y0 + (x2 - x0) * sa + (y2 - y0) * ca)) + rx4 = round((x0 + (x1 - x0) * ca - (y2 - y0) * sa)) + ry4 = round((y0 + (x1 - x0) * sa + (y2 - y0) * ca)) + return [(rx1, ry1), (rx2, ry2), (rx3, ry3), (rx4, ry4)] + + def find_field( + self, input_img, field, thres=0.3, fast=True, method="cv2.TM_CCORR_NORMED" + ): + max_conf = 0 + final_locx, final_locy = -1, -1 + final_sample = None + + process_img = input_img.copy() + if len(input_img.shape) == 3: # BGR + process_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2GRAY) + + if fast: + left = field["search_area"][0] + top = field["search_area"][1] + right = field["search_area"][0] + field["search_area"][2] + bottom = field["search_area"][1] + field["search_area"][3] + process_img = process_img[top:bottom, left:right] + try: + if not os.path.exists(os.path.join(self.template_dir, "crop")): + os.makedirs(os.path.join(self.template_dir, "crop")) + # print('MatchingTemplate. find_field. Write process image to', + # os.path.join(self.template_dir, 'crop', self.activate_template + '_' + field['name'] + '.jpg')) + # cv2.imwrite(os.path.join(self.template_dir, 'crop', self.activate_template + '_' + field['name'] + '.jpg'), + # process_img) + except: + print("Except find field : make_dir func") + pass + for sample in field["list_samples"]: + sample_data = sample["data"] + res = cv2.matchTemplate(process_img, sample_data, 5) + min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) + logger.info( + "Score:", + round(max_val, 4), + "Scale:", + sample["scale"], + "Angle:", + sample["rotation"], + max_loc[0] + sample_data.shape[1] / 2, + max_loc[1] + sample_data.shape[0] / 2, + ) + if max_val > max_conf: + max_conf = max_val + final_locx, final_locy = ( + max_loc[0] + sample_data.shape[1] / 2, + max_loc[1] + sample_data.shape[0] / 2, + ) + final_sample = sample + if fast: + final_locx, final_locy = ( + final_locx + field["search_area"][0], + final_locy + field["search_area"][1], + ) + + if max_conf >= thres: + print( + "Score:", + round(max_conf, 4), + "Scale:", + final_sample["scale"], + "Angle:", + final_sample["rotation"], + "Location:", + final_locx, + final_locy, + ) + else: # cannot find field + print( + "MatchingTemplate. find_field. Cannot find field! 
Max score:", + round(max_conf, 4), + ) + return 0, -1, -1 + + self.matching_results = self.get_matching_result( + final_locx, final_locy, final_sample + ) + # draw_bboxes(input_img, [self.matching_results], field['name']) + + return max_conf, final_locx, final_locy + + def find_template( + self, template_name, src_img, fast=True, threshold=0.7 + ): # src_img is cv2 image + # logger.info('\nCalib template', template_name) + template_data = self.check_template(template_name) + if template_data is None: + return + + resize_ratio, src_img = resize_normalize(src_img, template_data.normalize_width) + gray_img = src_img + if len(src_img.shape) == 3: # BGR + gray_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY) + list_pts = [] + + for idx, field in enumerate(template_data.list_field_samples): + # logger.info(field['name']) + conf, loc_x, loc_y = self.find_field( + gray_img, field, fast=fast, thres=template_data.confidence + ) + if conf > threshold: + list_pts.append((loc_x, loc_y)) + return list_pts + + def calib_template( + self, + template_name, + src_img, + fast=True, + simi_triangle_thres=4, + simi_line_thres=3, + ): # src_img is cv2 image + template_data = self.check_template(template_name) + if template_data is None: + return False, None + print( + "MatchingTemplate. Calib template", + template_name, + ", width", + template_data.template_width, + ", height", + template_data.template_height, + ) + + # src_img = cv2.resize(src_img, (template_data.template_width, template_data.template_height)) + resize_ratio, src_img = resize_normalize(src_img, template_data.normalize_width) + gray_img = src_img + if len(src_img.shape) == 3: # BGR + gray_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY) + list_pts = [] + + for idx, field in enumerate(template_data.list_field_samples): + # logger.info(field['name']) + import time + + begin = time.time() + conf, loc_x, loc_y = self.find_field( + gray_img, field, fast=fast, thres=template_data.confidence + ) + end = time.time() + print("calib_template. 
find field time:", 1000 * (end - begin), "ms") + list_pts.append((loc_x, loc_y)) + + src_pts = np.asarray(list_pts, dtype=np.float32) + dst_pts = np.asarray(template_data.field_locs, dtype=np.float32) + trans_img = src_img + calib_success = True + if len(src_pts) == 2: # affine transformation with 1 synthetic point + calib_success = check_angle_between_2_lines( + template_data.field_locs, list_pts, diff_thres=simi_line_thres + ) + inter_pts = ( + list_pts[0][0] + list_pts[0][1] - list_pts[1][1], + list_pts[0][1] + list_pts[1][0] - list_pts[0][0], + ) + list_pts.append(inter_pts) + inter_field_pts = [template_data.field_locs[0], template_data.field_locs[1]] + inter_field_pts.append( + ( + template_data.field_locs[0][0] + + template_data.field_locs[0][1] + - template_data.field_locs[1][1], + template_data.field_locs[0][1] + + template_data.field_locs[1][0] + - template_data.field_locs[0][0], + ) + ) + + src_pts = np.asarray(list_pts, dtype=np.float32) + dst_pts = np.asarray(inter_field_pts, dtype=np.float32) + print("dst_pts", dst_pts) + affine_trans = cv2.getAffineTransform(src_pts, dst_pts) + trans_img = cv2.warpAffine( + src_img, + affine_trans, + (template_data.template_width, template_data.template_height), + ) + elif len(src_pts) == 3: # affine transformation + calib_success = check_similar_triangle( + template_data.field_locs, list_pts, diff_thres=simi_triangle_thres + ) + affine_trans = cv2.getAffineTransform(src_pts, dst_pts) + trans_img = cv2.warpAffine( + src_img, + affine_trans, + (template_data.template_width, template_data.template_height), + ) + elif len(src_pts) > 3: # perspective transformation + perspective_trans, status = cv2.findHomography(src_pts, dst_pts) + w, h = template_data.template_width, template_data.template_height + trans_img = cv2.warpPerspective(src_img, perspective_trans, (w, h)) + else: + kk = 1 + return calib_success, trans_img + + def calib_template_2( + self, + template_name, + src_img, + fast=True, + simi_triangle_thres=4, + simi_line_thres=3, + ): # src_img is cv2 image + template_data = self.check_template(template_name) + if template_data is None: + return False, None, None + print( + "MatchingTemplate. Calib template", + template_name, + ", width", + template_data.template_width, + ", height", + template_data.template_height, + ) + + # src_img = cv2.resize(src_img, (template_data.template_width, template_data.template_height)) + resize_ratio, src_img = resize_normalize(src_img, template_data.normalize_width) + gray_img = src_img + if len(src_img.shape) == 3: # BGR + gray_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY) + list_pts = [] + calib_success = True + for idx, field in enumerate(template_data.list_field_samples): + # logger.info(field['name']) + import time + + begin = time.time() + conf, loc_x, loc_y = self.find_field( + gray_img, field, fast=fast, thres=template_data.confidence + ) + end = time.time() + print("calib_template. 
find field time:", 1000 * (end - begin), "ms") + list_pts.append((loc_x, loc_y)) + + src_pts = np.asarray(list_pts, dtype=np.float32) + dst_pts = np.asarray(template_data.field_locs, dtype=np.float32) + trans_img = src_img + + if len(src_pts) == 2: # affine transformation with 1 synthetic point + calib_success = check_angle_between_2_lines( + template_data.field_locs, list_pts, diff_thres=simi_line_thres + ) + inter_pts = ( + list_pts[0][0] + list_pts[0][1] - list_pts[1][1], + list_pts[0][1] + list_pts[1][0] - list_pts[0][0], + ) + list_pts.append(inter_pts) + inter_field_pts = [template_data.field_locs[0], template_data.field_locs[1]] + inter_field_pts.append( + ( + template_data.field_locs[0][0] + + template_data.field_locs[0][1] + - template_data.field_locs[1][1], + template_data.field_locs[0][1] + + template_data.field_locs[1][0] + - template_data.field_locs[0][0], + ) + ) + + src_pts = np.asarray(list_pts, dtype=np.float32) + dst_pts = np.asarray(inter_field_pts, dtype=np.float32) + # print('dst_pts', dst_pts) + affine_trans = cv2.getAffineTransform(src_pts, dst_pts) + trans_img = cv2.warpAffine( + src_img, + affine_trans, + (template_data.template_width, template_data.template_height), + borderValue=(255, 255, 255), + ) + elif len(src_pts) == 3: # affine transformation + calib_success = check_similar_triangle( + template_data.field_locs, list_pts, diff_thres=simi_triangle_thres + ) + affine_trans = cv2.getAffineTransform(src_pts, dst_pts) + trans_img = cv2.warpAffine( + src_img, + affine_trans, + (template_data.template_width, template_data.template_height), + ) + elif len(src_pts) > 3: # perspective transformation + perspective_trans, status = cv2.findHomography(src_pts, dst_pts) + w, h = template_data.template_width, template_data.template_height + trans_img = cv2.warpPerspective(src_img, perspective_trans, (w, h)) + else: + kk = 1 + return calib_success, trans_img, dst_pts + + def crop_image(self, input_img, bbox, offset_x=0, offset_y=0): + logger.info("crop") + crop_img = input_img[ + bbox[1] + offset_y : bbox[1] + bbox[3] + offset_y, + bbox[0] + offset_x : bbox[0] + bbox[2] + offset_x, + ] + return crop_img + + +def test_calib_multi(template_name, src_img_dir): + list_files = get_list_file_in_folder(src_img_dir) + for idx, f in enumerate(list_files): + print(idx, f) + test_calib(template_name, os.path.join(src_img_dir, f)) + + +def test_calib(template_name, src_img_path): + src_img = cv2.imread(src_img_path) + begin_init = time.time() + + match = MatchingTemplate(initTemplate=True) + # match.add_template(template_name=template_name, + # template_path='C:/Users/titik/Desktop/idcard_2June/test_MireaAsset/contract.JPG', + # field_bboxes=[[184, 1256, 242, 142]], + # field_rois_extend = (10.0,0.3), + # field_search_areas=None, + # # confidence=0.7, + # # scales=(0.95, 1.05, 0.05), + # # rotations=(-1, 1, 1)) + # confidence=0.2, + # scales=(1.0, 1.0, 0.1), + # rotations=(0, 0, 1)) + end_init = time.time() + logger.info("Time init:", end_init - begin_init, "seconds") + # match.draw_template(template_name) + begin = time.time() + calib_success, calib_img = match.calib_template(template_name, src_img, fast=True) + + # base_name = os.path.basename(src_img_path) + # cv2.imwrite(os.path.join(output_dir, base_name.replace('.jpg', '_trans.jpg')), calib_img) + end = time.time() + print("Time:", end - begin, "seconds") + logger.info("Time:", end - begin, "seconds") + + debug = True + if debug: + # src_img_with_box = 
visualize_boxes('/home/aicr/cuongnd/text_recognition/data/SDV_invoices_mod/006.txt', src_img, + # debug=False, offset_x=-20, offset_y=-20) + # src_img_with_box = cv2.resize(src_img, (int(src_img.shape[1] / 2), int(src_img.shape[0] / 2))) + # cv2.imshow('src with boxes', src_img_with_box) + # trans_img_with_box = visualize_boxes('/home/aicr/cuongnd/text_recognition/data/SDV_invoices_mod/006.txt', + # calib_img, debug=False, offset_x=-20, offset_y=-20) + trans_img_with_box = cv2.resize( + calib_img, (int(calib_img.shape[1] / 2), int(calib_img.shape[0] / 2)) + ) + trans_img_with_box = cv2.resize( + calib_img, (calib_img.shape[1], calib_img.shape[0]) + ) + cv2.imshow("transform_with_boxes", trans_img_with_box) + base_name = os.path.basename(src_img_path) + # cv2.imwrite(src_img_path.replace(base_name, 'transform/' + base_name.replace('.jpg', '_trans.jpg')), + # trans_img_with_box) + cv2.waitKey(0) + return calib_img + + +def get_list_file_in_folder(dir, ext=["jpg", "png", "JPG", "PNG"]): + included_extensions = ext + file_names = [ + fn + for fn in os.listdir(dir) + if any(fn.endswith(ext) for ext in included_extensions) + ] + return file_names + + +def getAngle(a, b, c): + ang = math.fabs( + math.degrees( + math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0]) + ) + ) + return ang + 360 if ang < 0 else ang + + +def simi_aaa(a1, a2, diff_thres): + a1 = [float(i) for i in a1] + a2 = [float(i) for i in a2] + a1.sort() + a2.sort() + + # Check for AAA + diff_1 = math.fabs(a1[0] - a2[0]) + diff_2 = math.fabs(a1[1] - a2[1]) + diff_3 = math.fabs(a1[2] - a2[2]) + max_diff = max(diff_1, max(diff_2, diff_3)) + if diff_1 < diff_thres and diff_2 < diff_thres and diff_3 < diff_thres: + return max_diff, True + return max_diff, False + + +def check_similar_triangle(list_pts1, list_pts2, diff_thres=4): + list_ang1 = [ + getAngle(list_pts1[0], list_pts1[1], list_pts1[2]), + getAngle(list_pts1[1], list_pts1[2], list_pts1[0]), + getAngle(list_pts1[2], list_pts1[0], list_pts1[1]), + ] + + list_ang2 = [ + getAngle(list_pts2[0], list_pts2[1], list_pts2[2]), + getAngle(list_pts2[1], list_pts2[2], list_pts2[0]), + getAngle(list_pts2[2], list_pts2[0], list_pts2[1]), + ] + + max_diff, is_similar = simi_aaa(list_ang1, list_ang2, diff_thres) + # print('check_similar_triangle. max diff:',max_diff) + return is_similar + + +def dot_product(vA, vB): + return vA[0] * vB[0] + vA[1] * vB[1] + + +def check_angle_between_2_lines(lineA, lineB, diff_thres=2): + # Get nicer vector form + try: + vA = [(lineA[0][0] - lineA[1][0]), (lineA[0][1] - lineA[1][1])] + vB = [(lineB[0][0] - lineB[1][0]), (lineB[0][1] - lineB[1][1])] + # Get dot prod + dot_prod = dot_product(vA, vB) + # Get magnitudes + magA = dot_product(vA, vA) ** 0.5 + magB = dot_product(vB, vB) ** 0.5 + # Get cosine value + cos_ = dot_prod / magA / magB + # Get angle in radians and then convert to degrees + angle = math.acos(dot_prod / magB / magA) + # Basically doing angle <- angle mod 360 + ang_deg = math.degrees(angle) % 360 + + if ang_deg >= 180: + ang_deg = 360 - ang_deg + if ang_deg > 90: + ang_deg = 180 - ang_deg + print("check_angle_between_2_lines. angle:", ang_deg) + + if ang_deg < diff_thres: + return True + else: + return False + except: + print("check_angle_between_2_lines. 
something wrong") + return False diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/pdf2image.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/pdf2image.py new file mode 100755 index 0000000..1e2c30d --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/pdf2image.py @@ -0,0 +1,45 @@ +import fitz # PyMuPDF, imported as fitz for backward compatibility reasons +import os +import glob +from tqdm import tqdm +import argparse +import cv2 +from PIL import Image + + +def convert_pdf2image(file_path, outdir, img_max_size=None): + if not os.path.exists(outdir): + os.makedirs(outdir) + doc = fitz.open(file_path) # open document + # dpi = 300 # choose desired dpi here + zoom = 2 # zoom factor, standard: 72 dpi + magnify = fitz.Matrix(zoom, zoom) + for idx, page in enumerate(doc): + pix = page.get_pixmap(matrix=magnify) # render page to an image + outpath = os.path.join( + outdir, + os.path.splitext(os.path.basename(file_path))[0] + "_" + str(idx) + ".png", + ) + pix.save(outpath) + + img = Image.open(outpath) + img = img.convert("L") + # img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img.save(outpath) + # if status: + # print("OK") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pdf_dir", type=str) + parser.add_argument("--out_dir", type=str) + args = parser.parse_args() + # pdf_dir = "/home/sds/hoanglv/FWD_Raw_Data/Form POS01" + # outdir = "/home/sds/hoanglv/Projects/FWD/assets/test/test_image_transformer/template_aligner/pdf2image" + + pdf_paths = glob.glob(args.pdf_dir + "/*.pdf") + print(pdf_paths[:5]) + + for pdf_path in tqdm(pdf_paths): + convert_pdf2image(pdf_path, args.out_dir) diff --git a/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/visualize.py b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/visualize.py new file mode 100755 index 0000000..73978fd --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/templatebasedextraction/src/utils/visualize.py @@ -0,0 +1,271 @@ +import os +import glob +import math +import json +import random +from sys import prefix + +import cv2 +import numpy as np +import pandas as pd +from PIL import Image, ImageDraw, ImageFont + + +def visualize_ocr_output( + inputs, + image, + vis_dir, + prefix_name="img_visualize", + font_path="./times.ttf", + is_vis_kie=False, +): + """ + Visualize ocr output (box + text) and kie output (optional) + params: + inputs (dict/list[list,list]): keys {ocr, kie} + - ocr value format: list of item (polygon box, label, prob/kie_label) + - kie value format: not implemented + image (np.ndarray): BGR image + vis_dir (str): save directory + name_vis_image (str): prefix name of save image + font_path (str): path of font + is_vis_kie (bool): if True, third item is kie label + return: + + """ + # table_reconstruct_result = ehr_res['table_reconstruct_result'] + # assert 'ocr' in inputs, "not found 'ocr' field in inputs" + + # identity input format + if len(inputs) == 2 and isinstance(inputs[1][0], str): + ocr_result = [ + [box if isinstance(box[0], list) else box2poly(box), text, 1.0] + for box, text in zip(inputs[0], inputs[1]) + ] + else: + ocr_result = inputs["ocr"] + + if not os.path.exists(vis_dir): + print("Creating {} dir".format(vis_dir)) + os.makedirs(vis_dir) + + img_visual = draw_ocr_box_txt( + image=image, + annos=ocr_result, + font_path=font_path, + table_boxes=None, + cell_boxes=None, + para_boxes=None, + is_vis_kie=is_vis_kie, + ) + + 
paths = sorted( + glob.glob(vis_dir + "/" + prefix_name + "*"), + key=lambda path: int(path.split(".jpg")[0].split("_")[-1]), + ) + if len(paths) == 0: + idx_name = "1" + else: + idx_name = str(int(paths[-1].split(".jpg")[0].split("_")[-1]) + 1) + cv2.imwrite( + os.path.join(vis_dir, prefix_name + "_" + idx_name + ".jpg"), img_visual + ) + + +def export_to_csv(table_reconstruct_text, vis_dir, csv_name="table_text_reconstruct"): + paths = sorted( + glob.glob(vis_dir + "/" + csv_name + "*"), + key=lambda path: int(path.split(".csv")[0].split("_")[-1]), + ) + if len(paths) == 0: + idx_name = "1" + else: + idx_name = str(int(paths[-1].split(".csv")[0].split("_")[-1]) + 1) + df = pd.DataFrame(table_reconstruct_text) + df.to_csv(os.path.join(vis_dir, csv_name + "_" + idx_name + ".csv"), index=False) + + +def save_json(data, vis_dir, json_name="ehr_result"): + """save dictionary to json file + Args: + data (dict): + vis_dir (str): path to save json + json_name (str, optional): json name. Defaults to 'ehr_result'. + """ + paths = sorted( + glob.glob(vis_dir + "/" + json_name + "*"), + key=lambda path: int(path.split(".json")[0].split("_")[-1]), + ) + if len(paths) == 0: + idx_name = "1" + else: + idx_name = str(int(paths[-1].split(".json")[0].split("_")[-1]) + 1) + outpath = os.path.join(vis_dir, json_name + "_" + idx_name + ".json") + with open(outpath, "w", encoding="utf8") as f: + json.dump(data, f, ensure_ascii=False) + + +def draw_ocr_box_txt( + image, + annos, + scores=None, + drop_score=0.5, + font_path="test/fonts/latin.ttf", + table_boxes=None, + cell_boxes=None, + para_boxes=None, + is_vis_kie=False, +): + """ + Args: + image (np.ndarray / PIL): BGR image or PIL image + annos (list): (box, text, label/prob) + scores (list, optional): probality. Defaults to None. + drop_score (float, optional): . Defaults to 0.5. + font_path (str, optional): Path of font. Defaults to "test/fonts/latin.ttf". 
+ Returns: + np.ndarray: BGR image + """ + + if is_vis_kie: + kie_labels = set([item[2] for item in annos]) + colors = { + label: ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) + for label in kie_labels + } + + color_vis = { + "table": (255, 192, 70), + "cell": (218, 66, 15), + "paragraph": (0, 187, 148), + } + + random.seed(0) + + if isinstance(image, np.ndarray): + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt, meta_data) in enumerate(annos): + if scores is not None and scores[idx] < drop_score: + continue + + if is_vis_kie: + color = colors[meta_data] + else: + color = ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) + draw_left.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + fill=color, + ) + draw_right.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + outline=color, + ) + box_height = math.sqrt( + (box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2 + ) + box_width = math.sqrt( + (box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2 + ) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + cur_y = box[0][1] + for c in txt: + char_size = font.getsize(c) + draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) + cur_y += char_size[1] + else: + font_size = max(int(box_height * 0.6), 20) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_left = Image.blend(image, img_left, 0.5) + + if table_boxes is not None: + img_left = draw_rectangle_pil( + img_left, table_boxes, color=color_vis["table"], width=6, label="table" + ) + if cell_boxes is not None: + img_left = draw_rectangle_pil( + img_left, cell_boxes, color=color_vis["cell"], width=5, label="cell" + ) + if para_boxes is not None: + img_left = draw_rectangle_pil( + img_left, para_boxes, color=color_vis["paragraph"], width=2, label="para" + ) + + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + img_show = cv2.cvtColor(np.array(img_show), cv2.COLOR_RGB2BGR) + return img_show + + +def draw_rectangle_pil( + pil_image, boxes, color, width=1, label=None, font_path="test/fonts/latin.ttf" +): + """ + Args: + pil_image ([type]): [description] + boxes (list): list of [xmin, ymim, xmax, ymax] + color (list): list of (R, G, B) + """ + drawer = ImageDraw.Draw(pil_image) + color = tuple((int(color[0]), int(color[1]), int(color[2]))) + for box in boxes: + drawer.rectangle( + [(int(box[0]), int(box[1])), (int(box[2]), int(box[3]))], + outline=color, + width=width, + ) + + if label: + font_size = 35 + font = ImageFont.truetype(font_path, size=32, encoding="utf-8") + drawer.text( + [int(box[0]) + 5, int(box[1]) - font_size - 5], + label, + fill=color, + font=font, + ) + return pil_image + + +def box2poly(box): + """ + Convert box format to polygon format: xyxy to xyxyxyxy + """ + xmin, ymin, xmax, ymax = box + poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + return poly diff 
--git a/cope2n-ai-fi/modules/TemplateMatching/textdetection/serve_model.py b/cope2n-ai-fi/modules/TemplateMatching/textdetection/serve_model.py new file mode 100755 index 0000000..d9acece --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/textdetection/serve_model.py @@ -0,0 +1,87 @@ +import os +import yaml +from pathlib import Path +from PIL import Image +from io import BytesIO +import numpy as np +import torch +from sdsvtd import StandaloneYOLOXRunner + +from common.utils.word_formation import Word, words_to_lines + +def read_imagefile(file) -> Image.Image: + image = Image.open(BytesIO(file)) + return image + +def sort_bboxes(lbboxes)->tuple[list, list]: + lWords = [Word(bndbox = bbox) for bbox in lbboxes] + list_lines, _ = words_to_lines(lWords) + lbboxes_ = list() + for line in list_lines: + for word_group in line.list_word_groups: + for word in word_group.list_words: + lbboxes_.append(word.boundingbox) + return lbboxes_ + +class Predictor: + def __init__(self, setting_file='./setting.yml'): + with open(setting_file) as f: + # use safe_load instead load + self.setting = yaml.safe_load(f) + + base_path = Path(__file__).parent + model_config_path = os.path.join(base_path, '../' , self.setting['model_config']) + self.mode = self.setting['mode'] + device = self.setting['device'] + + if self.mode == 'trt': + import sys + sys.path.append(self.setting['mmdeploy_path']) + from mmdeploy.utils import get_input_shape, load_config + from mmdeploy.apis.utils import build_task_processor + + deploy_config_path = os.path.join(base_path, '../' , self.setting['deploy_config']) + + class TensorRTInfer: + def __init__(self, deploy_config_path, model_config_path, checkpoint_path, device='cuda:0'): + deploy_cfg, model_cfg = load_config(deploy_config_path, model_config_path) + self.task_processor = build_task_processor(model_cfg, deploy_cfg, device) + self.model = self.task_processor.init_backend_model([checkpoint_path]) + self.input_shape = get_input_shape(deploy_cfg) + + def __call__(self, images): + model_input, _ = self.task_processor.create_input(images, self.input_shape) + with torch.no_grad(): + results = self.model(return_loss=False, rescale=True, **model_input) + return results + + checkpoint_path = self.setting['checkpoint'] + self.trt_infer = TensorRTInfer(deploy_config_path, model_config_path, checkpoint_path, device=device) + elif self.mode == 'torch': + self.runner = StandaloneYOLOXRunner(version=self.setting['model_config'], device=device) + else: + raise ValueError('No such inference mode') + + def __call__(self, images): + if self.mode == 'torch': + result = [] + for image in images: + result.append(self.runner(image)) + elif self.mode == 'tensorrt': + result = self.trt_infer(images) + + sorted_result = [] + for res, image in zip(result, images): + h, w = image.shape[:2] + res = res[0][:, :4] # leave out confidence score + + # clip inside image range + res[:, 0] = np.clip(res[:, 0], a_min=0, a_max=w) + res[:, 2] = np.clip(res[:, 2], a_min=0, a_max=w) + res[:, 1] = np.clip(res[:, 1], a_min=0, a_max=h) + res[:, 3] = np.clip(res[:, 3], a_min=0, a_max=h) + + res = res.astype(int).tolist() + res = sort_bboxes(res) + sorted_result.append(res) + return sorted_result \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/textdetection/setting.yml b/cope2n-ai-fi/modules/TemplateMatching/textdetection/setting.yml new file mode 100755 index 0000000..ff183e5 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/textdetection/setting.yml @@ -0,0 +1,7 @@ +mode: torch 
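+# 'torch' runs the sdsvtd StandaloneYOLOXRunner built from model_config below;
+# 'trt' builds an mmdeploy TensorRT task processor from mmdeploy_path, deploy_config and checkpoint.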
+mmdeploy_path: /home/sds/hoangmd/mmdeploy +deploy_config: configs/detection_custom_tensorrt_dynamic-320x320-1344x1344.py +model_config: yolox-s-general-text-pretrain-20221226 +# checkpoint: /home/sds/hoangmd/mmdeploy/yolox_trt_fp16/end2end.engine +# checkpoint: /home/sds/datnt/mmdetection/logs/textdet-fwd-20221226/best_lite.pth +device: cuda:0 \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/textrecognition/configs/satrn_big.py b/cope2n-ai-fi/modules/TemplateMatching/textrecognition/configs/satrn_big.py new file mode 100755 index 0000000..3f3b7d1 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/textrecognition/configs/satrn_big.py @@ -0,0 +1,1115 @@ +checkpoint_config = dict(interval=1) +log_config = dict(interval=50, hooks=[dict(type="TextLoggerHook")]) +dist_params = dict(backend="nccl") +log_level = "INFO" +load_from = None +resume_from = "logs/satrn_big_2022-10-31/last.pth" +workflow = [("train", 1)] +opencv_num_threads = 0 +mp_start_method = "fork" +img_h = 32 +img_w = 128 +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type="LoadImageFromFile"), + dict( + type="ResizeOCR", + height=32, + min_width=128, + max_width=128, + keep_aspect_ratio=False, + width_downsample_ratio=0.25, + ), + dict(type="ShearOCR", p=0.5, shear_limit=45), + dict( + type="ColorJitterOCR", + p=0.5, + brightness=0.25, + contrast=0.25, + saturation=0.25, + hue=0.25, + ), + dict(type="GaussianNoiseOCR", p=0.5), + dict(type="GaussianBlurOCR", blur=(3, 5), p=0.5), + dict(type="BlackBoxAttackOCR", p=0.5, box_size=12), + dict(type="DotAttackOCR", p=0.5, dot_size=(1, 3), dot_space=(5, 8)), + dict(type="LineAttackOCR", p=0.5, line_size=(1, 3), line_space=(5, 8)), + dict(type="InvertOCR", p=0.2), + dict(type="ToTensorOCR"), + dict(type="NormalizeOCR", mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + dict( + type="Collect", + keys=["img"], + meta_keys=[ + "filename", + "ori_shape", + "img_shape", + "text", + "valid_ratio", + "resize_shape", + ], + ), +] +test_pipeline = [ + dict(type="LoadImageFromFile"), + dict( + type="MultiRotateAugOCR", + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type="ResizeOCR", + height=32, + min_width=128, + max_width=128, + keep_aspect_ratio=False, + width_downsample_ratio=0.25, + ), + dict(type="ToTensorOCR"), + dict( + type="NormalizeOCR", + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + dict( + type="Collect", + keys=["img"], + meta_keys=[ + "filename", + "ori_shape", + "img_shape", + "valid_ratio", + "resize_shape", + "img_norm_cfg", + "ori_filename", + ], + ), + ], + ), +] +dataset_type = "OCRDataset" +img_path_prefix = "data/Recognition/Real/" +dataset_list = "data/AnnFiles/current-dirs/2022-10-19/" +default_loader = dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", keys=["filename", "text"], keys_idx=[0, 1], separator=" " + ), +) +default_dataset = dict( + type="OCRDataset", + img_prefix=None, + ann_file=None, + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +handwriten_train = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Handwritten_Train/",), + ann_file="data/AnnFiles/current-dirs/2022-10-19/text_recognition__Train_Handwritten_Train.txt", + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", 
"text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +printed_train = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Printed_Train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Train_Printed_Train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +handwriten_val = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Handwritten_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Handwritten_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +printed_val = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Printed_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Printed_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +synthetic = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Synthetic/Using/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Synthetic_Using.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +blank_space = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Blank/Train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Blank_Train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +captcha_train = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Captcha_Train/DONE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Captcha_Train_DONE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +captcha_val = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Captcha_Val/DONE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Captcha_Val_DONE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +kie_train = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/KIE_Train/KIE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_KIE_Train_KIE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +kie_val = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/KIE_Val/KIE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_KIE_Val_KIE.txt", + ), + loader=dict( + 
type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +gplx_train = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/GPLX_Train/train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_GPLX_Train_train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +gplx_val = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/GPLX_Val/val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_GPLX_Val_val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +vietocr = dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/VietOCR_Train/Data/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_VietOCR_Train_Data.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, +) +train_list = [ + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Handwritten_Train/",), + ann_file="data/AnnFiles/current-dirs/2022-10-19/text_recognition__Train_Handwritten_Train.txt", + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Printed_Train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Train_Printed_Train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Synthetic/Using/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Synthetic_Using.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Blank/Train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Blank_Train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Captcha_Train/DONE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Captcha_Train_DONE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/KIE_Train/KIE/",), + ann_file=( + 
"data/AnnFiles/current-dirs/2022-10-19/text_recognition_KIE_Train_KIE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/GPLX_Train/train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_GPLX_Train_train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/VietOCR_Train/Data/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_VietOCR_Train_Data.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), +] +val_list = [ + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Handwritten_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Handwritten_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Printed_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Printed_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Captcha_Val/DONE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Captcha_Val_DONE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/KIE_Val/KIE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_KIE_Val_KIE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/GPLX_Val/val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_GPLX_Val_val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), +] +test_list = [ + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Handwritten_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Handwritten_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + 
img_prefix=("data/Recognition/Real/Val/Printed_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Printed_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), +] +fp16 = dict(loss_scale="dynamic") +label_convertor = dict(type="AttnConvertor", dict_type="DICT224", with_unknown=False) +model = dict( + type="SATRN", + backbone=dict(type="ResNetABI", in_channels=3, stem_channels=16, base_channels=16), + encoder=dict( + type="SatrnEncoder", + n_layers=12, + n_head=8, + d_k=32, + d_v=32, + d_model=256, + n_position=100, + d_inner=1024, + dropout=0.1, + ), + decoder=dict( + type="NRTRDecoder", + n_layers=12, + d_embedding=256, + n_head=8, + d_model=256, + d_inner=1024, + d_k=32, + d_v=32, + ), + loss=dict(type="TFLoss"), + label_convertor=dict(type="AttnConvertor", dict_type="DICT224", with_unknown=False), + max_seq_len=25, +) +optimizer = dict(type="Adam", lr=0.001) +optimizer_config = dict(grad_clip=None) +lr_config = dict(policy="poly", power=0.9, min_lr=1e-06, by_epoch=False) +total_epochs = 15 +custom_hooks = [ + dict( + type="ExpMomentumEMAHook", + total_iter=20000, + resume_from=None, + momentum=0.0001, + priority=49, + ) +] +data = dict( + samples_per_gpu=160, + workers_per_gpu=16, + val_dataloader=dict(samples_per_gpu=400), + test_dataloader=dict(samples_per_gpu=400), + train=dict( + type="UniformConcatDataset", + datasets=[ + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Handwritten_Train/",), + ann_file="data/AnnFiles/current-dirs/2022-10-19/text_recognition__Train_Handwritten_Train.txt", + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Printed_Train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Train_Printed_Train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Synthetic/Using/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Synthetic_Using.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Blank/Train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Blank_Train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/Captcha_Train/DONE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Captcha_Train_DONE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + 
img_prefix=("data/Recognition/Real/Train/KIE_Train/KIE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_KIE_Train_KIE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/GPLX_Train/train/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_GPLX_Train_train.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Train/VietOCR_Train/Data/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_VietOCR_Train_Data.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + ], + pipeline=[ + dict(type="LoadImageFromFile"), + dict( + type="ResizeOCR", + height=32, + min_width=128, + max_width=128, + keep_aspect_ratio=False, + width_downsample_ratio=0.25, + ), + dict(type="ShearOCR", p=0.5, shear_limit=45), + dict( + type="ColorJitterOCR", + p=0.5, + brightness=0.25, + contrast=0.25, + saturation=0.25, + hue=0.25, + ), + dict(type="GaussianNoiseOCR", p=0.5), + dict(type="GaussianBlurOCR", blur=(3, 5), p=0.5), + dict(type="BlackBoxAttackOCR", p=0.5, box_size=12), + dict(type="DotAttackOCR", p=0.5, dot_size=(1, 3), dot_space=(5, 8)), + dict(type="LineAttackOCR", p=0.5, line_size=(1, 3), line_space=(5, 8)), + dict(type="InvertOCR", p=0.2), + dict(type="ToTensorOCR"), + dict( + type="NormalizeOCR", + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + dict( + type="Collect", + keys=["img"], + meta_keys=[ + "filename", + "ori_shape", + "img_shape", + "text", + "valid_ratio", + "resize_shape", + ], + ), + ], + ), + val=dict( + type="UniformConcatDataset", + datasets=[ + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Handwritten_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Handwritten_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Printed_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Printed_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Captcha_Val/DONE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_Captcha_Val_DONE.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/KIE_Val/KIE/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_KIE_Val_KIE.txt", + ), 
+ loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/GPLX_Val/val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition_GPLX_Val_val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + ], + pipeline=[ + dict(type="LoadImageFromFile"), + dict( + type="MultiRotateAugOCR", + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type="ResizeOCR", + height=32, + min_width=128, + max_width=128, + keep_aspect_ratio=False, + width_downsample_ratio=0.25, + ), + dict(type="ToTensorOCR"), + dict( + type="NormalizeOCR", + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + dict( + type="Collect", + keys=["img"], + meta_keys=[ + "filename", + "ori_shape", + "img_shape", + "valid_ratio", + "resize_shape", + "img_norm_cfg", + "ori_filename", + ], + ), + ], + ), + ], + ), + test=dict( + type="UniformConcatDataset", + datasets=[ + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Handwritten_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Handwritten_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + dict( + type="OCRDataset", + img_prefix=("data/Recognition/Real/Val/Printed_Val/",), + ann_file=( + "data/AnnFiles/current-dirs/2022-10-19/text_recognition__Val_Printed_Val.txt", + ), + loader=dict( + type="AnnFileLoader", + repeat=1, + parser=dict( + type="LineStrParser", + keys=["filename", "text"], + keys_idx=[0, 1], + separator=" ", + ), + ), + pipeline=None, + test_mode=False, + ), + ], + pipeline=[ + dict(type="LoadImageFromFile"), + dict( + type="MultiRotateAugOCR", + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type="ResizeOCR", + height=32, + min_width=128, + max_width=128, + keep_aspect_ratio=False, + width_downsample_ratio=0.25, + ), + dict(type="ToTensorOCR"), + dict( + type="NormalizeOCR", + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + dict( + type="Collect", + keys=["img"], + meta_keys=[ + "filename", + "ori_shape", + "img_shape", + "valid_ratio", + "resize_shape", + "img_norm_cfg", + "ori_filename", + ], + ), + ], + ), + ], + ), +) +evaluation = dict(interval=1, metric="acc") +work_dir = "logs/satrn_big_2022-10-31/" +gpu_ids = [0] diff --git a/cope2n-ai-fi/modules/TemplateMatching/textrecognition/setting.yml b/cope2n-ai-fi/modules/TemplateMatching/textrecognition/setting.yml new file mode 100755 index 0000000..b10b83e --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/textrecognition/setting.yml @@ -0,0 +1,6 @@ +# config: configs/satrn_big.py +config: /home/sds/datnt/mmocr/logs/satrn_lite_2023-01-08-handwritten/satrn_lite.py +# checkpoint: /home/sds/datnt/mmocr/logs/satrn_big_2022-10-31/best.pth +checkpoint: /home/sds/datnt/mmocr/logs/satrn_lite_2023-01-08-handwritten/best.pth +batch_size: 256 +device: cuda:0 \ No newline at end of file diff --git a/cope2n-ai-fi/modules/TemplateMatching/textrecognition/src/serve_model.py b/cope2n-ai-fi/modules/TemplateMatching/textrecognition/src/serve_model.py new file mode 100755 index 
0000000..10321a3 --- /dev/null +++ b/cope2n-ai-fi/modules/TemplateMatching/textrecognition/src/serve_model.py @@ -0,0 +1,20 @@ +# dirty path export +from sdsvtr import StandaloneSATRNRunner +import yaml + +class Predictor: + def __init__(self, setting_file='./setting.yml'): + with open(setting_file) as f: + # use safe_load instead load + self.setting = yaml.safe_load(f) + + self.batch_size = self.setting['batch_size'] + self.runner = StandaloneSATRNRunner(version='satrn-lite-general-pretrain-20230106', + return_confident=True, device=self.setting['device']) + + def __call__(self, images): + results = [] + for i in range(0, len(images), self.batch_size): + result = self.runner(images[i:i+self.batch_size]) + results += result[0] + return results diff --git a/cope2n-ai-fi/modules/__init__.py b/cope2n-ai-fi/modules/__init__.py new file mode 100644 index 0000000..fa2d966 --- /dev/null +++ b/cope2n-ai-fi/modules/__init__.py @@ -0,0 +1 @@ +from modules.sdsvkvu.sdsvkvu import * \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/.gitignore b/cope2n-ai-fi/modules/_sdsvkvu/.gitignore new file mode 100644 index 0000000..17d9c1a --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/.gitignore @@ -0,0 +1,24 @@ +# Model weights +weights/ +microsoft/ +nltk_data/ + +# Visualize +visualize + +# External +sdsvkvu/externals/ocr_engine_deskew/externals/ + +# +__pycache__ +*/__pycache__ +*/*/__pycache__ + +# +.git_temp/ + +# Packages +build/ +dist/ + + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/.gitmodules b/cope2n-ai-fi/modules/_sdsvkvu/.gitmodules new file mode 100644 index 0000000..eb6d053 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/.gitmodules @@ -0,0 +1,4 @@ + +[submodule "sdsvkvu/externals/basic_ocr"] + path = sdsvkvu/externals/basic_ocr + url = https://code.sdsdev.co.kr/tuanlv/IDP-BasicOCR.git diff --git a/cope2n-ai-fi/modules/_sdsvkvu/LICENSE b/cope2n-ai-fi/modules/_sdsvkvu/LICENSE new file mode 100644 index 0000000..21ffbaa --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/LICENSE @@ -0,0 +1,13 @@ +Copyright 2023 tuanlv + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/MANIFEST.in b/cope2n-ai-fi/modules/_sdsvkvu/MANIFEST.in new file mode 100644 index 0000000..e4bd723 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/MANIFEST.in @@ -0,0 +1,2 @@ +include sdsvkvu/weights/*/*.yaml +include sdsvkvu/weights/*/checkpoints/best_model.pth \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/README.md b/cope2n-ai-fi/modules/_sdsvkvu/README.md new file mode 100644 index 0000000..d8b1dd1 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/README.md @@ -0,0 +1,122 @@ +

+
+# SDSVKVU
+
+***Feature***
+- Extract key-value pairs from documents: Invoice/Receipt, Forms, Government documents (ID cards, driver licenses, birth certificates)
+- Language: VI + EN
+
+***What's new***
+### - Ver 0.0.1:
+- Support inputs: image, PDF file (single or multiple pages)
+- Extract all key-value pairs and return raw_outputs
+  + Weights: weights/key_value_understanding-20230716-085549_final
+- For VAT invoices: Extract 14 specific fields
+  + Weights: weights/key_value_understanding-20230627-164536_fi
+- For SBT invoices ("sbt" option): Extract the table in an SBT invoice
+  + Weights: weights/key_value_understanding-20230812-170826_sbt_2
+### - Ver 0.0.2: Add option: "vtb" - Vietin Bank
+- For Vietin Bank documents ("vtb" option): Extract 6 specific fields
+  + Weights: weights/key_value_understanding-20230824-164236_vietin
+### - Ver 0.0.3: Add default option:
+- Return all potential key-value pairs, titles, key-only items, triplets, and tables with raw keys
+### - Ver 0.0.4: Add option: "manulife" - Manulife Insurance
+- For Manulife Insurance documents ("manulife" option): Extract all potential key-value pairs, titles, key-only items, triplets, and tables with raw keys, plus the type of medical document
+  + Weights: weights/key_value_understanding-20231024-125646_manulife2
+### - Ver 0.1.0: Modify KVU model for SBT; add option: "sbt_v2" - SBT project
+- For SBT IMEI/invoice ("sbt_v2" option): Extract 4 specific fields
+  + Weights: weights/key_value_understanding_for_sbt-20231108-143935
+
+## I. Setup
+***Dependencies***
+- Python: 3.10
+- Torch: 1.11.3
+- CUDA: 11.6
+- transformers: 4.30.0
+```bash
+pip install -v -e .
+```
+
+## II. Inference
+Run: `python test.py`
+```python
+import os
+from sdsvkvu import load_engine, process_img
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
+if __name__ == "__main__":
+    kwargs = {"device": "cuda:0"}
+    img_dir = "/mnt/ssd1T/tuanlv/02-KVU/sdsvkvu/visualize/test_img/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg"
+    save_dir = "/mnt/ssd1T/tuanlv/02-KVU/sdsvkvu/visualize/test2/"
+    engine = load_engine(kwargs)
+    # option: "vat" for VAT invoice outputs, "sbt" for SBT invoice outputs, anything else for raw outputs
+    outputs = process_img(img_dir, save_dir, engine, export_all=False, option="vat")
+```
+
+# Project structure
+.
+├── sdsvkvu
+│   ├── main.py
+├── externals
+│   │   ├── __init__.py
+│   │   ├── basic_ocr
+│   │   │   ├── ...
+│   │   ├── ocr_engine
+│   │   │   ├── ...
+│   │   ├── ocr_engine_deskew
+│   │   │   ├── ...
+ │   ├── model + │   │   ├── combined_model.py + │   │   ├── document_kvu_model.py + │   │   ├── __init__.py + │   │   ├── kvu_model.py + │   │   └── relation_extractor.py + │   ├── modules + │   │   ├── __init__.py + │   │   ├── predictor.py + │   │   ├── preprocess.py + │   │   └── run_ocr.py + │   ├── requirements.txt + │   ├── settings.yml + │   ├── sources + │   │   ├── __init__.py + │   │   ├── kvu.py + │   │   └── utils.py + │   ├── utils + │   │   ├── dictionary + │   │   │   ├── __init__.py + │   │   │   ├── sbt.py + │   │   │   └── vat.py + │   │   │   └── vtb.py + │   │   │   ├── manulife.py + │   │   │   ├── sbt_v2.py + │   │   ├── __init__.py + │   │   ├── post_processing.py + │   │   ├── query + │   │   │   ├── __init__.py + │   │   │   ├── sbt.py + │   │   │   └── vat.py + │   │   │   └── vtb.py + │   │   │   ├── all.py + │   │   │   ├── manulife.py + │   │   │   ├── sbt_v2.py + │   │   └── utils.py + ├── weights + │   └── key_value_understanding-20230627-164536_fi + │   ├── key_value_understanding-20230812-170826_sbt_2 + │   └── key_value_understanding-20230716-085549_final + │   └── key_value_understanding-20230824-164236_vietin + │   └── key_value_understanding-20231024-125646_manulife2 + │   └── key_value_understanding_for_sbt-20231108-143935 + ├── LICENSE + ├── MANIFEST.in + ├── pyproject.toml + ├── README.md + ├── scripts + │   └── run.sh + ├── setup.cfg + ├── setup.py + ├── test.py + └── visualize \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/__init__.py new file mode 100644 index 0000000..1621835 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/__init__.py @@ -0,0 +1,8 @@ +import os +import sys +from pathlib import Path +cur_dir = str(Path(__file__).parents[0]) +sys.path.append(cur_dir) +sys.path.append(os.path.join(cur_dir, "sdsvkvu")) + +from sdsvkvu import load_engine, process_img, process_pdf, process_dir \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/draw_img.jpg b/cope2n-ai-fi/modules/_sdsvkvu/draw_img.jpg new file mode 100644 index 0000000..f1581e1 Binary files /dev/null and b/cope2n-ai-fi/modules/_sdsvkvu/draw_img.jpg differ diff --git a/cope2n-ai-fi/modules/_sdsvkvu/pyproject.toml b/cope2n-ai-fi/modules/_sdsvkvu/pyproject.toml new file mode 100644 index 0000000..ed1b0c7 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=65", + "wheel" +] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/requirements.txt b/cope2n-ai-fi/modules/_sdsvkvu/requirements.txt new file mode 100644 index 0000000..8616d6d --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/requirements.txt @@ -0,0 +1,28 @@ +nltk +six +deskew +jdeskew +pdf2image +omegaconf +imagesize +xmltodict +dicttoxml +terminaltables +Pillow>=9.4.0 +nptyping==1.4.2 +opencv-python==4.5.4.60 ## +opencv-python-headless==4.5.4.60 +overrides==4.1.2 +# transformers==4.30.0 +sentencepiece==0.1.99 +seqeval==0.0.12 +tensorboard>=2.2.0 +scipy==1.9.1 +# code-style +isort==5.9.3 +black==21.9b0 +# pytorch +# --find-links https://download.pytorch.org/whl/torch_stable.html +# torch==1.13.1+cu116 +# torchvision==0.14.1+cu116 +tldextract==5.1.1 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/scripts/run.sh b/cope2n-ai-fi/modules/_sdsvkvu/scripts/run.sh new file mode 100644 index 0000000..69c3092 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/scripts/run.sh @@ -0,0 +1,26 @@ +cd 
/mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu +export CUDA_VISIBLE_DEVICES=0 + +python sdsvkvu/main.py \ + --img_dir /mnt/hdd4T/OCR/tuanlv/00-Datasets/SBT_DATA/invoice_validation \ + --save_dir /mnt/hdd4T/OCR/tuanlv/02-KVU/02-KVU_test/visualize/sbt_invoice \ + --kvu_params "{\"device\":\"cuda:0\"}" \ + --doc_type "sbt_v2" \ + --export_img 1 + + +# python sdsvkvu/main.py \ +# --img_dir /mnt/hdd4T/OCR/tuanlv/00-Datasets/SBT_DATA/imei_validation \ +# --save_dir /mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu/visualize/test_sbt_imei \ +# --kvu_params "{\"device\":\"cuda:0\"}" \ +# --doc_type "sbt_v2" \ +# --export_img 1 + + + +# python sdsvkvu/main.py \ +# --img_dir /mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu/visualize/test_sbt2 \ +# --save_dir /mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu/visualize/test_sbt2 \ +# --kvu_params "{\"device\":\"cuda:0\"}" \ +# --doc_type "sbt_v2" \ +# --export_img 1 \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/PKG-INFO b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/PKG-INFO new file mode 100644 index 0000000..d9e424f --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/PKG-INFO @@ -0,0 +1,122 @@ +Metadata-Version: 2.1 +Name: sdsvkvu +Version: 0.0.1 +Summary: SDSV OCR Team: Key-value understanding +Home-page: https://github.com/open-mmlab/mmocr +Author: tuanlv +Author-email: lv.tuan3@samsung.com +License: Apache License 2.0 +Classifier: Development Status :: 4 - Beta +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3.9 +Requires-Python: >=3.9 +Description-Content-Type: text/markdown +License-File: LICENSE + +

+
+# SDSVKVU
+ + ***Feature*** + - Extract pairs of key-value in documents: Invoice/Receipt, Forms, Government documents (Id cards, driver license, birth's certificate) + - Language: VI + EN + + ***What's news*** + ### - Ver 0.0.1: + - Support inputs: image, PDF file (single or multi pages) + - Extract all pairs key-value return raw_outputs + + Weights: sdsvkvu/weights/key_value_understanding-20230716-085549_final + - For VAT invoices : Extract 14 specific fields + + Weights: sdsvkvu/weights/key_value_understanding-20230627-164536_fi + - For SBT invoices ("sbt" option): Extract table in SBT invoice + + Weights: sdsvkvu/weights/key_value_understanding-20230617-162324_sbt + ### - Ver 0.0.2: Add more option: "vtb" - Vietin Bank + - For Vietin Bank document ("vtb" option): Extract 6 specific fileds + + Weights: sdsvkvu/weights/key_value_understanding-20230824-164236_vietin + ### - Ver 0.0.3: Add default option: + - Return all potential pairs of key-value, title, only key, triplet, and table with raw key + + ## I. Setup + ***Dependencies*** + - Python: 3.10 + - Torch: 1.11.3 + - CUDA: 11.6 + - transformers: 4.30.0 + ``` + pip install -v -e . + ``` + + + ## II. Inference + run cmd: python test.py + ``` + import os + from sdsvkvu import load_engine, process_img + os.environ["CUDA_VISIBLE_DEVICES"]="1" + + if __name__ == "__main__": + kwargs = {"device": "cuda:0"} + img_dir = "/mnt/ssd1T/tuanlv/02-KVU/sdsvkvu/visualize/test_img/RedInvoice_WaterPurfier_Feb_PVI_829_0.jpg" + save_dir = "/mnt/ssd1T/tuanlv/02-KVU/sdsvkvu/visualize/test2/" + engine = load_engine(kwargs) + # option: "vat" for vat invoice outputs, "sbt": sbt invoice outputs, else for raw outputs + outputs = process_img(img_dir, save_dir, engine, export_all=False, option="vat") + ``` + + # Structure project + . + ├── sdsvkvu + │   ├── main.py + ├── externals + │   │   ├── __init__.py + │   │   ├── ocr_engine + │   │   │   ├── ... + │   │   ├── ocr_engine_deskew + │   │   │   ├── ... 
+ │   ├── model + │   │   ├── combined_model.py + │   │   ├── document_kvu_model.py + │   │   ├── __init__.py + │   │   ├── kvu_model.py + │   │   └── relation_extractor.py + │   ├── modules + │   │   ├── __init__.py + │   │   ├── predictor.py + │   │   ├── preprocess.py + │   │   └── run_ocr.py + │   ├── requirements.txt + │   ├── settings.yml + │   ├── sources + │   │   ├── __init__.py + │   │   ├── kvu.py + │   │   └── utils.py + │   ├── utils + │   │   ├── dictionary + │   │   │   ├── __init__.py + │   │   │   ├── sbt.py + │   │   │   └── vat.py + │   │   │   └── vtb.py + │   │   ├── __init__.py + │   │   ├── post_processing.py + │   │   ├── query + │   │   │   ├── __init__.py + │   │   │   ├── sbt.py + │   │   │   └── vat.py + │   │   │   └── vtb.py + │   │   └── utils.py + │   └── weights + │   └── key_value_understanding-20230627-164536_fi + │   ├── key_value_understanding-20230617-162324_sbt + │   └── key_value_understanding-20230716-085549_final + │   └── key_value_understanding-20230824-164236_vietin + ├── LICENSE + ├── MANIFEST.in + ├── pyproject.toml + ├── README.md + ├── scripts + │   └── run.sh + ├── setup.cfg + ├── setup.py + ├── test.py + └── visualize diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/SOURCES.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/SOURCES.txt new file mode 100644 index 0000000..fc2dccf --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/SOURCES.txt @@ -0,0 +1,49 @@ +LICENSE +MANIFEST.in +README.md +pyproject.toml +setup.cfg +setup.py +sdsvkvu/__init__.py +sdsvkvu/main.py +sdsvkvu.egg-info/PKG-INFO +sdsvkvu.egg-info/SOURCES.txt +sdsvkvu.egg-info/dependency_links.txt +sdsvkvu.egg-info/not-zip-safe +sdsvkvu.egg-info/requires.txt +sdsvkvu.egg-info/top_level.txt +sdsvkvu/externals/__init__.py +sdsvkvu/externals/basic_ocr/__init__.py +sdsvkvu/externals/basic_ocr/run.py +sdsvkvu/externals/ocr_engine/__init__.py +sdsvkvu/externals/ocr_engine/run.py +sdsvkvu/externals/ocr_engine_deskew/__init__.py +sdsvkvu/externals/ocr_engine_deskew/run.py +sdsvkvu/model/__init__.py +sdsvkvu/model/combined_model.py +sdsvkvu/model/document_kvu_model.py +sdsvkvu/model/kvu_model.py +sdsvkvu/model/relation_extractor.py +sdsvkvu/model/sbt_model.py +sdsvkvu/modules/__init__.py +sdsvkvu/modules/predictor.py +sdsvkvu/modules/preprocess.py +sdsvkvu/modules/run_ocr.py +sdsvkvu/sources/__init__.py +sdsvkvu/sources/kvu.py +sdsvkvu/sources/utils.py +sdsvkvu/utils/__init__.py +sdsvkvu/utils/post_processing.py +sdsvkvu/utils/utils.py +sdsvkvu/utils/word2line.py +sdsvkvu/utils/dictionary/__init__.py +sdsvkvu/utils/dictionary/sbt.py +sdsvkvu/utils/dictionary/sbt_v2.py +sdsvkvu/utils/dictionary/vat.py +sdsvkvu/utils/dictionary/vtb.py +sdsvkvu/utils/query/__init__.py +sdsvkvu/utils/query/all.py +sdsvkvu/utils/query/sbt.py +sdsvkvu/utils/query/sbt_v2.py +sdsvkvu/utils/query/vat.py +sdsvkvu/utils/query/vtb.py \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/dependency_links.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/not-zip-safe b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/not-zip-safe new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git 
a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/requires.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/requires.txt new file mode 100644 index 0000000..dac7835 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/requires.txt @@ -0,0 +1,21 @@
+nltk
+six
+deskew
+jdeskew
+pdf2image
+omegaconf
+imagesize
+xmltodict
+dicttoxml
+terminaltables
+Pillow==9.4.0
+nptyping==1.4.2
+opencv-python==4.5.4.60
+opencv-python-headless==4.5.4.60
+overrides==4.1.2
+sentencepiece==0.1.99
+seqeval==0.0.12
+tensorboard>=2.2.0
+scipy==1.9.1
+isort==5.9.3
+black==21.9b0
diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/top_level.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/top_level.txt new file mode 100644 index 0000000..f6f7634 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu.egg-info/top_level.txt @@ -0,0 +1 @@
+sdsvkvu
diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/__init__.py new file mode 100644 index 0000000..b90f78f --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/__init__.py @@ -0,0 +1,4 @@
+from .main import load_engine
+from .main import process_img
+from .main import process_pdf
+from .main import process_dir
diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/__init__.py new file mode 100644 index 0000000..e69de29
diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/.gitignore b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/.gitignore new file mode 100644 index 0000000..1250e95 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/.gitignore @@ -0,0 +1,6 @@
+__pycache__/
+visualize/
+results/
+*.jpeg
+*.jpg
+*.png
diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/README.md b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/README.md new file mode 100644 index 0000000..ca1b349 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/README.md @@ -0,0 +1,47 @@
+# OCR Engine
+
+OCR Engine is a Python package that combines text detection and recognition models from [mmdet](https://github.com/open-mmlab/mmdetection) and [mmocr](https://github.com/open-mmlab/mmocr) to perform Optical Character Recognition (OCR) on various inputs. The package currently supports three types of input: a single image, a recursive directory, or a CSV file.
+
+## Installation
+
+To install OCR Engine, clone the repository and install the required packages:
+
+```bash
+git clone git@github.com:mrlasdt/ocr-engine.git
+cd ocr-engine
+pip install -r requirements.txt
+```
+
+## Usage
+
+To use OCR Engine, run the `ocr_engine.py` script with the desired input type and input path. For example, to perform OCR on a single image:
+
+```bash
+python ocr_engine.py --input_type image --input_path /path/to/image.jpg
+```
+
+To perform OCR on a recursive directory:
+
+```bash
+python ocr_engine.py --input_type directory --input_path /path/to/directory/
+```
+
+To perform OCR on a CSV file:
+
+```bash
+python ocr_engine.py --input_type csv --input_path /path/to/file.csv
+```
+
+OCR Engine automatically detects and recognizes text in the input and writes the results to a CSV file named `ocr_results.csv`. A scripted batch-usage sketch is shown below, after the Contributing section.
+
+## Contributing
+
+If you would like to contribute to OCR Engine, please fork the repository and submit a pull request. We welcome contributions of all types, including bug fixes, new features, and documentation improvements.
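+
+## Scripted batch usage
+
+If several inputs need to be processed in one go, the CLI shown above can be driven from a short Python loop. The snippet below is a minimal sketch only: it relies solely on the `--input_type` and `--input_path` flags documented in the Usage section and assumes `ocr_engine.py` is invoked from the repository root.
+
+```python
+import subprocess
+
+# (input_type, input_path) pairs covering the three supported input types.
+jobs = [
+    ("image", "/path/to/image.jpg"),
+    ("directory", "/path/to/directory/"),
+    ("csv", "/path/to/file.csv"),
+]
+
+for input_type, input_path in jobs:
+    # Each run writes its results to ocr_results.csv, as described above.
+    subprocess.run(
+        ["python", "ocr_engine.py",
+         "--input_type", input_type,
+         "--input_path", input_path],
+        check=True,
+    )
+```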
+ +## License + +OCR Engine is released under the [MIT License](https://opensource.org/licenses/MIT). See the LICENSE file for more information. diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/TODO.todo b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/TODO.todo new file mode 100644 index 0000000..34df095 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/TODO.todo @@ -0,0 +1,10 @@ +☐ refactor argument parser of run.py +☐ add timer level, logging level and write_mode to argumments +☐ add paddleocr deskew to the code +☐ fix the deskew code to resize the image only for detecting the angle, we want to feed the original size image to the text detection pipeline so that the bounding boxes would be mapped back to the original size +☐ ocr engine import took too long +☐ add word level to write_mode +☐ add word group and line +change max_x_dist from pixel to percentage of box width +☐ visualization: adjust fontsize dynamically + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/__init__.py new file mode 100644 index 0000000..aabc310 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/__init__.py @@ -0,0 +1,11 @@ +# # Define package-level variables +# __version__ = '0.0' + +# Import modules +from .src.ocr import OcrEngine +# from .src.word_formation import words_to_lines +from .src.word_formation import words_to_lines_tesseract as words_to_lines +from .src.utils import ImageReader, read_ocr_result_from_txt +from .src.dto import Word, Line, Page, Document, Box +# Expose package contents +__all__ = ["OcrEngine", "Box", "Word", "Line", "Page", "Document", "words_to_lines", "ImageReader", "read_ocr_result_from_txt"] diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/.gitignore b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/.gitignore new file mode 100644 index 0000000..e56dbb2 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/.gitignore @@ -0,0 +1,9 @@ +output* +*.pyc +*.jpg +check +weights/ +workdirs/ +__pycache__* +test_hungbnt.py +libs* \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/README.md b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/README.md new file mode 100644 index 0000000..d1b9637 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/README.md @@ -0,0 +1,29 @@ +

+
+# Dewarp
+ +***Feature*** +- Align document + + +## I. Setup +***Dependencies*** +- Python: 3.8 +- Torch: 1.10.2 +- CUDA: 11.6 +- transformers: 4.28.1 +### 1. Install PaddlePaddle +``` +python -m pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +### 2. Install sdsv_dewarp +``` +pip install -v -e . +``` + + +## II. Test +``` +python test.py --input samples --out demo/outputs --device 'cuda' +``` diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/config/cls.yaml b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/config/cls.yaml new file mode 100644 index 0000000..4c4cfdc --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/config/cls.yaml @@ -0,0 +1,3 @@ +model_dir: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_ppocr_mobile_v2.0_cls_infer +gpu_mem: 3000 +max_batch_size: 32 \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/config/det.yaml b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/config/det.yaml new file mode 100644 index 0000000..f218ef1 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/config/det.yaml @@ -0,0 +1,8 @@ +model_dir: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_PP-OCRv3_det_infer +gpu_mem: 3000 +det_limit_side_len: 1560 +det_limit_type: max +det_db_unclip_ratio: 1.85 +det_db_thresh: 0.3 +det_db_box_thresh: 0.5 +det_db_score_mode: fast \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/requirements.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/requirements.txt new file mode 100644 index 0000000..768fde9 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/requirements.txt @@ -0,0 +1,7 @@ + +paddleocr>=2.0.1 +opencv-contrib-python +opencv-python +numpy +gdown==3.13.0 +imgaug==0.4.0 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/PKG-INFO b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/PKG-INFO new file mode 100644 index 0000000..634708d --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/PKG-INFO @@ -0,0 +1,45 @@ +Metadata-Version: 2.1 +Name: sdsv-dewarp +Version: 1.0.0 +Summary: Dewarp document +Home-page: +License: Apache License 2.0 +Classifier: Development Status :: 4 - Beta +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Description-Content-Type: text/markdown + +

+
+# Dewarp
+ +***Feature*** +- Align document + + +## I. Setup +***Dependencies*** +- Python: 3.8 +- Torch: 1.10.2 +- CUDA: 11.6 +- transformers: 4.28.1 +### 1. Install PaddlePaddle +``` +python -m pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +### 2. Install sdsv_dewarp +``` +pip install -v -e . +``` + + +## II. Test +``` +python test.py --input samples --out demo/outputs --device 'cuda' +``` diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/SOURCES.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/SOURCES.txt new file mode 100644 index 0000000..953a123 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/SOURCES.txt @@ -0,0 +1,15 @@ +README.md +setup.py +sdsv_dewarp/__init__.py +sdsv_dewarp/api.py +sdsv_dewarp/config.py +sdsv_dewarp/factory.py +sdsv_dewarp/models.py +sdsv_dewarp/utils.py +sdsv_dewarp/version.py +sdsv_dewarp.egg-info/PKG-INFO +sdsv_dewarp.egg-info/SOURCES.txt +sdsv_dewarp.egg-info/dependency_links.txt +sdsv_dewarp.egg-info/not-zip-safe +sdsv_dewarp.egg-info/requires.txt +sdsv_dewarp.egg-info/top_level.txt \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/dependency_links.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/not-zip-safe b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/not-zip-safe new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/requires.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/requires.txt new file mode 100644 index 0000000..89816c8 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/requires.txt @@ -0,0 +1,6 @@ +paddleocr>=2.0.1 +opencv-contrib-python +opencv-python +numpy +gdown==3.13.0 +imgaug==0.4.0 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/top_level.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/top_level.txt new file mode 100644 index 0000000..a5ce4e8 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp.egg-info/top_level.txt @@ -0,0 +1 @@ +sdsv_dewarp diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/api.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/api.py new file mode 100644 index 0000000..d71ddc5 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/api.py @@ -0,0 +1,200 @@ +import math +import numpy as np +from typing import List +import cv2 +import collections +import logging +import imgaug.augmenters as iaa +from imgaug.augmentables.polys import Polygon, PolygonsOnImage + +from sdsv_dewarp.models import PaddleTextClassifier, PaddleTextDetector +from sdsv_dewarp.config import Cfg +from .utils import * + + +MIN_LONG_EDGE = 40**2 +NUMBER_BOX_FOR_ALIGNMENT = 200 +MAX_ANGLE = 180 +MIN_ANGLE = 1 +MIN_NUM_BOX_TEXT = 3 +CROP_SIZE = 3000 + +logging.basicConfig(level=logging.INFO) +LOGGER = logging.getLogger(__name__) + + +class AlignImage: + """Rotate image to 0 degree + Args: + text_detector (deepmodel): Text detection model + text_cls (deepmodel): Text classification model (0 or 180) + + Return: + is_blank (bool): Blank image when haven't boxes text + image_align: Image after alignment + angle_align: Degree of angle alignment + """ + + def __init__(self, text_detector: dict, text_cls: dict, device: str = 'cpu'): + self.text_detector = None + self.text_cls = None + self.use_gpu = True if device != 'cpu' else False + + self._init_model(text_detector, text_cls) + + def _init_model(self, text_detector, text_cls): + det_config = Cfg.load_config_from_file(text_detector['config']) + det_config['model_dir'] = text_detector['weight'] + cls_config = Cfg.load_config_from_file(text_cls['config']) + cls_config['model_dir'] = text_cls['weight'] + + self.text_detector = PaddleTextDetector(config=det_config, use_gpu=self.use_gpu) + self.text_cls = PaddleTextClassifier(config=cls_config, use_gpu=self.use_gpu) + + def _cal_width(self, poly_box): + """Calculate width of a polygon [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]""" + tl, tr, br, bl = poly_box + edge_s, edge_l = distance(tl, tr), distance(tr, br) + + return max(edge_s, edge_l) + + def _get_most_frequent(self, values): + values = np.array(values) + # create the histogram + hist, bins = np.histogram(values, bins=np.arange(0, 181, 10)) + + # get the index of the most frequent angle + index = np.argmax(hist) + + # get the most frequent angle + most_frequent_angle = (bins[index] + bins[index + 1]) / 2 + + return most_frequent_angle + + def _cal_angle(self, poly_box): + """Calculate the angle between two point""" + a = poly_box[0] + b = poly_box[1] + c = poly_box[2] + + # Get the longer edge + if distance(a, b) >= distance(b, c): + x, y = a, b + else: + x, y = b, c + + angle = math.degrees(math.atan2(-(y[1] - x[1]), y[0] - x[0])) + + if angle < 0: + angle = 180 - abs(angle) + + return angle + + def _reject_outliers(self, data, m=5.0): + """Remove noise angle""" + list_index = np.arange(len(data)) + d = np.abs(data - np.median(data)) + mdev = np.median(d) + s = d / (mdev if mdev else 1.0) + + return list_index[s < m], data[s < m] + + def __call__(self, image): + """image (np.ndarray): BGR image""" + + # Crop center image to increase speed of text detection + + image_resized = crop_image(image, crop_size=CROP_SIZE).copy() if max(image.shape) > CROP_SIZE else image.copy() + poly_box_texts = self.text_detector(image_resized) + + # draw_img = vis_ocr( + # image_resized, + # poly_box_texts, + # ) + # cv2.imwrite("draw_img.jpg", draw_img) + + is_blank = False + 
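+        # The remainder of this method: (1) return early if too few text boxes
+        # were detected (blank page), (2) keep the widest detected boxes and
+        # estimate a skew angle from each, (3) rotate the resized copy by the
+        # robust mean angle, (4) classify the rotated text crops as 0 or 180
+        # degrees and apply the final angle to the original image.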
+ # Check image is blank + if len(poly_box_texts) <= MIN_NUM_BOX_TEXT: + is_blank = True + return image, is_blank, 0 + + # # Crop document + # poly_np = np.array(poly_box_texts) + # min_x = poly_box_texts[:, 0].min() + # max_x = poly_box_texts[:, 2].max() + # min_y = poly_box_texts[:, 1].min() + # max_y = poly_box_texts[:, 3].max() + + # Filter small poly + poly_box_areas = [ + [self._cal_width(poly_box), id] + for id, poly_box in enumerate(poly_box_texts) + ] + + poly_box_areas = sorted(poly_box_areas)[-NUMBER_BOX_FOR_ALIGNMENT:] + poly_box_areas = [poly_box_texts[id[1]] for id in poly_box_areas] + + # Calculate angle + list_angle = [self._cal_angle(poly_box) for poly_box in poly_box_areas] + list_angle = [angle if angle >= MIN_ANGLE else 180 for angle in list_angle] + + # LOGGER.info(f"List angle before reject outlier: {list_angle}") + list_angle = np.array(list_angle) + list_index, list_angle = self._reject_outliers(list_angle) + # LOGGER.info(f"List angle after reject outlier: {list_angle}") + + if len(list_angle): + + frequent_angle = self._get_most_frequent(list_angle) + list_angle = [angle for angle in list_angle if abs(angle - frequent_angle) <= 45] + # LOGGER.info(f"List angle after reject angle: {list_angle}") + angle = np.mean(list_angle) + else: + angle = 0 + + # LOGGER.info(f"Avg angle: {angle}") + + # Reuse poly boxes detected by text detection + polys_org = PolygonsOnImage( + [Polygon(poly_box_areas[index]) for index in list_index], + shape=image_resized.shape, + ) + seq_augment = iaa.Sequential([iaa.Rotate(angle, fit_output=True, order=3)]) + + # Rotate image by degree + if angle >= MIN_ANGLE and angle <= MAX_ANGLE: + image_resized, polys_aug = seq_augment( + image=image_resized, polygons=polys_org + ) + else: + angle = 0 + image_resized, polys_aug = image_resized, polys_org + + # cv2.imwrite("image_resized.jpg", image_resized) + + # Classify image 0 or 180 degree + list_poly = [poly.coords for poly in polys_aug] + + image_crop_list = [ + dewarp_by_polygon(image_resized, poly)[0] for poly in list_poly + ] + + cls_res = self.text_cls(image_crop_list) + cls_labels = [cls_[0] for cls_ in cls_res[1]] + # LOGGER.info(f"Angle lines: {cls_labels}") + counter = collections.Counter(cls_labels) + + angle_align = angle + if counter["0"] <= counter["180"]: + aug = iaa.Rotate(angle + 180, fit_output=True, order=3) + angle_align = angle + 180 + else: + aug = iaa.Rotate(angle, fit_output=True, order=3) + + # Rotate the image by degree + image = aug.augment_image(image) + + return image, is_blank, angle_align + # return image diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/config.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/config.py new file mode 100644 index 0000000..204c2c0 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/config.py @@ -0,0 +1,41 @@ +import yaml +import pprint +import os +import json + + +def load_from_yaml(fname): + with open(fname, encoding='utf-8') as f: + base_config = yaml.safe_load(f) + return base_config + +def load_from_json(fname): + with open(fname, "r", encoding='utf-8') as f: + base_config = json.load(f) + return base_config + +class Cfg(dict): + def __init__(self, config_dict): + super(Cfg, self).__init__(**config_dict) + self.__dict__ = self + + @staticmethod + def load_config_from_file(fname, download_base=False): + if not os.path.exists(fname): + raise FileNotFoundError("Not found 
config at {}".format(fname)) + if fname.endswith(".yaml") or fname.endswith(".yml"): + return Cfg(load_from_yaml(fname)) + elif fname.endswith(".json"): + return Cfg(load_from_json(fname)) + else: + raise Exception(f"{fname} not supported") + + + def save(self, fname): + with open(fname, 'w', encoding='utf-8') as outfile: + yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True) + + # @property + def pretty_text(self): + return pprint.PrettyPrinter().pprint(self) + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/factory.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/factory.py new file mode 100644 index 0000000..65e4bbd --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/factory.py @@ -0,0 +1,75 @@ +import os +import shutil +import hashlib +import warnings + +def sha256sum(filename): + h = hashlib.sha256() + b = bytearray(128*1024) + mv = memoryview(b) + with open(filename, 'rb', buffering=0) as f: + for n in iter(lambda : f.readinto(mv), 0): + h.update(mv[:n]) + return h.hexdigest() + + +online_model_factory = { + 'yolox-s-general-text-pretrain-20221226': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/62j266xm8r.pth', + 'hash': '89bff792685af454d0cfea5d6d673be6914d614e4c2044e786da6eddf36f8b50'}, + 'yolox-s-checkbox-20220726': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/1647d7eys7.pth', + 'hash': '7c1e188b7375dcf0b7b9d317675ebd92a86fdc29363558002249867249ee10f8'}, + 'yolox-s-idcard-5c-20221027': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/jr0egad3ix.pth', + 'hash': '73a7772594c1f6d3f6d6a98b6d6e4097af5026864e3bd50531ad9e635ae795a7'}, + 'yolox-s-handwritten-text-line-20230228': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/rb07rtwmgi.pth', + 'hash': 'a31d1bf8fc880479d2e11463dad0b4081952a13e553a02919109b634a1190ef1'} +} + +__hub_available_versions__ = online_model_factory.keys() + +def _get_from_hub(file_path, version, version_url): + os.system(f'wget -O {file_path} {version_url}') + assert os.path.exists(file_path), \ + 'wget failed while trying to retrieve from hub.' 
+ downloaded_hash = sha256sum(file_path) + if downloaded_hash != online_model_factory[version]['hash']: + os.remove(file_path) + raise ValueError('sha256 hash doesnt match for version retrieved from hub.') + +def _get(version): + use_online = version in __hub_available_versions__ + + if not use_online and not os.path.exists(version): + raise ValueError(f'Model version {version} not found online and not found local.') + + hub_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'hub') + if not os.path.exists(hub_path): + os.makedirs(hub_path) + if use_online: + version_url = online_model_factory[version]['url'] + file_path = os.path.join(hub_path, os.path.basename(version_url)) + else: + file_path = os.path.join(hub_path, os.path.basename(version)) + + if not os.path.exists(file_path): + if use_online: + _get_from_hub(file_path, version, version_url) + else: + shutil.copy2(version, file_path) + else: + if use_online: + downloaded_hash = sha256sum(file_path) + if downloaded_hash != online_model_factory[version]['hash']: + os.remove(file_path) + warnings.warn('existing hub version sha256 hash doesnt match, now re-download from hub.') + _get_from_hub(file_path, version, version_url) + else: + if sha256sum(file_path) != sha256sum(version): + os.remove(file_path) + warnings.warn('existing local version sha256 hash doesnt match, now replace with new local version.') + shutil.copy2(version, file_path) + + return \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/models.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/models.py new file mode 100644 index 0000000..64cb88c --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/models.py @@ -0,0 +1,73 @@ + +from paddleocr.tools.infer.predict_det import TextDetector +from paddleocr.tools.infer.predict_cls import TextClassifier +from paddleocr.paddleocr import parse_args +from sdsv_dewarp.config import Cfg + +class PaddleTextDetector(object): + def __init__( + self, + # config_path: str, + config: dict, + use_gpu=False + ): + # config = Cfg.load_config_from_file(config_path) + + self.args = parse_args(mMain=False) + self.args.__dict__.update( + det_model_dir=config['model_dir'], + gpu_mem=config['gpu_mem'], + use_gpu=use_gpu, + use_zero_copy_run=True, + max_batch_size=1, + det_limit_side_len=config['det_limit_side_len'], #960 + det_limit_type=config['det_limit_type'], #'max' + det_db_unclip_ratio=config['det_db_unclip_ratio'], + det_db_thresh=config['det_db_thresh'], + det_db_box_thresh=config['det_db_box_thresh'], + det_db_score_mode=config['det_db_score_mode'], + ) + self.text_detector = TextDetector(self.args) + + def __call__(self, image): + """ + + Args: + image (np.ndarray): BGR images + + Returns: + np.ndarray: numpy array of poly boxes - shape 4x2 + """ + dt_boxes, time_infer = self.text_detector(image) + return dt_boxes + + +class PaddleTextClassifier(object): + def __init__( + self, + # config_path: str, + config: str, + use_gpu=False + ): + # config = Cfg.load_config_from_file(config_path) + + self.args = parse_args(mMain=False) + self.args.__dict__.update( + cls_model_dir=config['model_dir'], + gpu_mem=config['gpu_mem'], + use_gpu=use_gpu, + use_zero_copy_run=True, + cls_batch_num=config['max_batch_size'], + ) + self.text_classifier = TextClassifier(self.args) + + def __call__(self, images): + """ + Args: + images (np.ndarray): list of BGR images 
+
+        Returns:
+            img_list, cls_res, elapse : cls_res format = (label, conf)
+        """
+        out = self.text_classifier(images)
+        return out
\ No newline at end of file
diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/utils.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/utils.py new file mode 100644 index 0000000..ae02cc1 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/utils.py @@ -0,0 +1,212 @@
+import math
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+import random
+
+
+def distance(p1, p2):
+    """Calculate the Euclidean distance between two points"""
+    x1, y1 = p1
+    x2, y2 = p2
+    dist = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+
+    return dist
+
+
+def crop_image(image, crop_size=1280):
+    """Crop the center of the image"""
+    h, w = image.shape[:2]
+    x_center, y_center = w // 2, h // 2
+    half_size = crop_size // 2
+
+    xmin, ymin = x_center - half_size, y_center - half_size
+    xmax, ymax = x_center + half_size, y_center + half_size
+
+    xmin = max(xmin, 0)
+    ymin = max(ymin, 0)
+    xmax = min(xmax, w)
+    ymax = min(ymax, h)
+
+    return image[ymin:ymax, xmin:xmax]
+
+
+def _closest_point(corners, A):
+    """Find the corner closest to point A"""
+    distances = [distance(A, p) for p in corners]
+    return corners[np.argmin(distances)]
+
+
+def _re_order_corners(image_size, corners) -> list:
+    """Order the corners clockwise (top-left, top-right, bottom-right, bottom-left)"""
+    h, w = image_size
+    tl = _closest_point(corners, (0, 0))
+    tr = _closest_point(corners, (w, 0))
+    br = _closest_point(corners, (w, h))
+    bl = _closest_point(corners, (0, h))
+
+    return [tl, tr, br, bl]
+
+
+def _validate_corner(corners, ratio_thres=0.5, epsilon=1e-3) -> bool:
+    """Check that the corners are valid
+    Invalid: 3 points, duplicate points, ....
+    """
+    c_tl, c_tr, c_br, c_bl = corners
+    e_top = distance(c_tl, c_tr)
+    e_right = distance(c_tr, c_br)
+    e_bottom = distance(c_br, c_bl)
+    e_left = distance(c_bl, c_tl)
+
+    min_tb = min(e_top, e_bottom)
+    max_tb = max(e_top, e_bottom)
+    min_lr = min(e_left, e_right)
+    max_lr = max(e_left, e_right)
+
+    # If any points coincide, the corresponding edge lengths will be 0
+    if min(max_tb, max_lr) < epsilon:
+        return False
+
+    ratio = min(min_tb / max_tb, min_lr / max_lr)
+    if ratio < ratio_thres:
+        return False
+
+    return True
+
+
+def dewarp_by_polygon(
+    image, corners, need_validate=False, need_reorder=True, trace_trans=None
+):
+    """Crop and dewarp the image from its 4 corners
+
+    Args:
+        image (np.array)
+        corners (list): Ex : [(3347, 512), (3379, 2427), (638, 2524), (647, 495)]
+        need_validate (bool, optional): validate the 4 points. Defaults to False.
+        need_reorder (bool, optional): reorder the 4 points. Defaults to True.
+ + Returns: + dewarped: image after dewarp + corners: location of 4 corners after reorder + """ + h, w = image.shape[:2] + + if need_reorder: + corners = _re_order_corners((h, w), corners) + + dewarped = image + + if need_validate: + validate = _validate_corner(corners) + else: + validate = True + + if validate: + # perform dewarp + target_w = int( + max(distance(corners[0], corners[1]), distance(corners[2], corners[3])) + ) + target_h = int( + max(distance(corners[0], corners[3]), distance(corners[1], corners[2])) + ) + target_corners = [ + [0, 0], + [target_w, 0], + [target_w, target_h], + [0, target_h], + ] + + pts1 = np.float32(corners) + pts2 = np.float32(target_corners) + transform_matrix = cv2.getPerspectiveTransform(pts1, pts2) + + dewarped = cv2.warpPerspective(image, transform_matrix, (target_w, target_h)) + if trace_trans is not None: + trace_trans["dewarp_method"]["polygon"][ + "transform_matrix" + ] = transform_matrix + + return (dewarped, corners, trace_trans) + + +def vis_ocr(image, boxes, txts=[], scores=None, drop_score=0.5): + """ + Args: + image (np.ndarray / PIL): BGR image or PIL image + boxes (list / np.ndarray): list of polygon boxes + txts (list): list of text labels + scores (list, optional): probality. Defaults to None. + drop_score (float, optional): . Defaults to 0.5. + font_path (str, optional): Path of font. Defaults to "test/fonts/latin.ttf". + Returns: + np.ndarray: BGR image + """ + + if len(txts) == 0: + txts = [""] * len(boxes) + + if isinstance(image, np.ndarray): + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + if isinstance(boxes, list): + boxes = np.array(boxes) + + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + if scores is not None and scores[idx] < drop_score: + continue + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + draw_left.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + fill=color, + ) + draw_right.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + outline=color, + ) + box_height = math.sqrt( + (box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2 + ) + box_width = math.sqrt( + (box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2 + ) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.load_default() + cur_y = box[0][1] + for c in txt: + char_size = font.getsize(c) + draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) + cur_y += char_size[1] + else: + font_size = max(int(box_height * 0.8), 10) + font = ImageFont.load_default() + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_left = Image.blend(image, img_left, 0.5) + + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + img_show = cv2.cvtColor(np.array(img_show), cv2.COLOR_RGB2BGR) + return img_show diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/version.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/version.py new file mode 100644 index 0000000..a1570ac --- /dev/null +++ 
b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/sdsv_dewarp/version.py @@ -0,0 +1 @@ +__version__="1.0.0" \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/setup.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/setup.py new file mode 100644 index 0000000..7887e8f --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/setup.py @@ -0,0 +1,187 @@ +import os +import os.path as osp +import shutil +import sys +import warnings +from setuptools import find_packages, setup + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +version_file = 'sdsv_dewarp/version.py' +is_windows = sys.platform == 'win32' + + +def add_mim_extention(): + """Add extra files that are required to support MIM into the package. + + These files will be added by creating a symlink to the originals if the + package is installed in `editable` mode (e.g. pip install -e .), or by + copying from the originals otherwise. + """ + + # parse installment mode + if 'develop' in sys.argv: + # installed by `pip install -e .` + mode = 'symlink' + elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + # installed by `pip install .` + # or create source distribution by `python setup.py sdist` + mode = 'copy' + else: + return + + filenames = ['tools', 'configs', 'model-index.yml'] + repo_path = osp.dirname(__file__) + mim_path = osp.join(repo_path, 'mmocr', '.mim') + os.makedirs(mim_path, exist_ok=True) + + for filename in filenames: + if osp.exists(filename): + src_path = osp.join(repo_path, filename) + tar_path = osp.join(mim_path, filename) + + if osp.isfile(tar_path) or osp.islink(tar_path): + os.remove(tar_path) + elif osp.isdir(tar_path): + shutil.rmtree(tar_path) + + if mode == 'symlink': + src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) + try: + os.symlink(src_relpath, tar_path) + except OSError: + # Creating a symbolic link on windows may raise an + # `OSError: [WinError 1314]` due to privilege. If + # the error happens, the src file will be copied + mode = 'copy' + warnings.warn( + f'Failed to create a symbolic link for {src_relpath}, ' + f'and it will be copied to {tar_path}') + else: + continue + + if mode == 'copy': + if osp.isfile(src_path): + shutil.copyfile(src_path, tar_path) + elif osp.isdir(src_path): + shutil.copytree(src_path, tar_path) + else: + warnings.warn(f'Cannot copy file {src_path}.') + else: + raise ValueError(f'Invalid mode {mode}') + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + import sys + + # return short version for sdist + if 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + return locals()['short_version'] + else: + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """Parse the package dependencies listed in a requirements file but strip + specific version information. + + Args: + fname (str): Path to requirements file. + with_version (bool, default=False): If True, include version specs. + Returns: + info (list[str]): List of requirements items. 
+ CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import re + import sys + from os.path import exists + require_fpath = fname + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + for info in parse_line(line): + yield info + + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) + return packages + + +if __name__ == '__main__': + setup( + name='sdsv_dewarp', + version=get_version(), + description='Dewarp document', + long_description=readme(), + long_description_content_type='text/markdown', + packages=find_packages(exclude=('configs', 'tools', 'demo')), + include_package_data=True, + url='', + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + ], + license='Apache License 2.0', + install_requires=parse_requirements('requirements.txt'), + zip_safe=False) diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/test.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/test.py new file mode 100644 index 0000000..4bfdbc7 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/externals/sdsv_dewarp/test.py @@ -0,0 +1,47 @@ +from sdsv_dewarp.api import AlignImage +import cv2 +import glob +import os +import tqdm +import time +import argparse + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input") + parser.add_argument("--out") + parser.add_argument("--device", type=str, default="cuda:1") + + args = parser.parse_args() + model = AlignImage(device=args.device) + + + img_dir = args.input + out_dir = args.out + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + img_paths = glob.glob(img_dir + "/*") + + times = [] + for img_path in tqdm.tqdm(img_paths): + t1 = 
time.time() + img = cv2.imread(img_path) + if img is None: + print(img_path) + continue + + aligned_img, is_blank, angle_align = model(img) + + times.append(time.time() - t1) + + if not is_blank: + cv2.imwrite(os.path.join(out_dir, os.path.basename(img_path)), aligned_img) + else: + cv2.imwrite(os.path.join(out_dir, os.path.basename(img_path)), img) + + + times = times[1:] + print("Avg time: ", sum(times) / len(times)) \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/requirements.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/requirements.txt new file mode 100644 index 0000000..b247e52 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/requirements.txt @@ -0,0 +1,82 @@ +addict==2.4.0 +asttokens==2.2.1 +autopep8==1.6.0 +backcall==0.2.0 +backports.functools-lru-cache==1.6.4 +brotlipy==0.7.0 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==2.0.4 +click==8.1.3 +colorama==0.4.6 +cryptography==39.0.1 +debugpy==1.5.1 +decorator==5.1.1 +docopt==0.6.2 +entrypoints==0.4 +executing==1.2.0 +flit_core==3.6.0 +idna==3.4 +importlib-metadata==6.0.0 +ipykernel==6.15.0 +ipython==8.11.0 +jedi==0.18.2 +jupyter-client==7.0.6 +jupyter_core==4.12.0 +Markdown==3.4.1 +markdown-it-py==2.2.0 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mkl-fft==1.3.1 +mkl-random==1.2.2 +mkl-service==2.4.0 +mmcv-full==1.7.1 +model-index==0.1.11 +nest-asyncio==1.5.6 +numpy==1.23.5 +opencv-python==4.7.0.72 +openmim==0.3.6 +ordered-set==4.1.0 +packaging==23.0 +pandas==1.5.3 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.4.0 +pip==22.3.1 +pipdeptree==2.5.2 +prompt-toolkit==3.0.38 +psutil==5.9.0 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pycodestyle==2.10.0 +pycparser==2.21 +Pygments==2.14.0 +pyOpenSSL==23.0.0 +PySocks==1.7.1 +python-dateutil==2.8.2 +pytz==2022.7.1 +PyYAML==6.0 +pyzmq==19.0.2 +requests==2.28.1 +rich==13.3.1 +sdsvtd==0.1.1 +sdsvtr==0.0.5 +setuptools==65.6.3 +Shapely==1.8.4 +six==1.16.0 +stack-data==0.6.2 +tabulate==0.9.0 +toml==0.10.2 +torch==1.13.1 +torchvision==0.14.1 +tornado==6.1 +tqdm==4.65.0 +traitlets==5.9.0 +typing_extensions==4.4.0 +urllib3==1.26.14 +wcwidth==0.2.6 +wheel==0.38.4 +yapf==0.32.0 +yarg==0.1.9 +zipp==3.15.0 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/run.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/run.py new file mode 100644 index 0000000..3b19879 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/run.py @@ -0,0 +1,200 @@ +""" +see scripts/run_ocr.sh to run +""" +# from pathlib import Path # add parent path to run debugger +# import sys +# FILE = Path(__file__).absolute() +# sys.path.append(FILE.parents[2].as_posix()) + + +from src.utils import construct_file_path, ImageReader +from src.dto import Line +from src.ocr import OcrEngine +import argparse +import tqdm +import pandas as pd +from pathlib import Path +import json +import os +import numpy as np +from typing import Union, Tuple, List, Optional +from collections import defaultdict + +current_dir = os.getcwd() + + +def get_args(): + parser = argparse.ArgumentParser() + # parser image + parser.add_argument( + "--image", + type=str, + required=True, + help="path to input image/directory/csv file", + ) + parser.add_argument( + "--save_dir", type=str, required=True, help="path to save directory" + ) + parser.add_argument( + "--include", type=str, nargs="+", default=[], help="files/folders to include" + ) + parser.add_argument( + "--exclude", type=str, nargs="+", default=[], 
help="files/folders to exclude" + ) + parser.add_argument( + "--base_dir", + type=str, + required=False, + default=current_dir, + help="used when --image and --save_dir are relative paths to a base directory, default to current directory", + ) + parser.add_argument( + "--export_csv", + type=str, + required=False, + default="", + help="used when --image is a directory. If set, a csv file contains image_path, ocr_path and label will be exported to save_dir.", + ) + parser.add_argument( + "--export_img", + type=bool, + required=False, + default=False, + help="whether to save the visualize img", + ) + parser.add_argument("--ocr_kwargs", type=str, required=False, default="") + opt = parser.parse_args() + return opt + + +def load_engine(opt) -> OcrEngine: + print("[INFO] Loading engine...") + kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {} + engine = OcrEngine(**kw) + print("[INFO] Engine loaded") + return engine + + +def convert_relative_path_to_positive_path(tgt_dir: Path, base_dir: Path) -> Path: + return tgt_dir if tgt_dir.is_absolute() else base_dir.joinpath(tgt_dir) + + +def get_paths_from_opt(opt) -> Tuple[Path, Path]: + # BC\ kiem\ tra\ y\ te -> BC kiem tra y te + img_path = opt.image.replace("\\ ", " ").strip() + save_dir = opt.save_dir.replace("\\ ", " ").strip() + base_dir = opt.base_dir.replace("\\ ", " ").strip() + input_image = convert_relative_path_to_positive_path(Path(img_path), Path(base_dir)) + save_dir = convert_relative_path_to_positive_path(Path(save_dir), Path(base_dir)) + if not save_dir.exists(): + save_dir.mkdir() + print("[INFO]: Creating folder ", save_dir) + return input_image, save_dir + + +def process_img( + img: Union[str, np.ndarray], + save_dir_or_path: str, + engine: OcrEngine, + export_img: bool, + save_path_deskew: Optional[str] = None, +) -> None: + save_dir_or_path = Path(save_dir_or_path) + if isinstance(img, np.ndarray): + if save_dir_or_path.is_dir(): + raise ValueError("numpy array input require a save path, not a save dir") + page = engine(img) + save_path = ( + str(save_dir_or_path.joinpath(Path(img).stem + ".txt")) + if save_dir_or_path.is_dir() + else str(save_dir_or_path) + ) + page.write_to_file("word", save_path) + if export_img: + page.save_img( + save_path.replace(".txt", ".jpg"), + is_vnese=True, + save_path_deskew=save_path_deskew, + ) + + +def process_dir( + dir_path: str, + save_dir: str, + engine: OcrEngine, + export_img: bool, + lexcludes: List[str] = [], + lincludes: List[str] = [], + ddata=defaultdict(list), +) -> None: + pdir_path = Path(dir_path) + print(pdir_path) + # save_dir_sub = Path(construct_file_path(save_dir, dir_path, ext="")) + psave_dir = Path(save_dir) + psave_dir.mkdir(exist_ok=True) + for img_path in (pbar := tqdm.tqdm(pdir_path.iterdir())): + pbar.set_description(f"Processing {pdir_path}") + if (lincludes and img_path.name not in lincludes) or ( + img_path.name in lexcludes + ): + continue # only process desired files/foders + if img_path.is_dir(): + psave_dir_sub = psave_dir.joinpath(img_path.stem) + process_dir(img_path, str(psave_dir_sub), engine, ddata) + elif img_path.suffix.lower() in ImageReader.supported_ext: + simg_path = str(img_path) + # try: + img = ( + ImageReader.read(simg_path) + if img_path.suffix != ".pdf" + else ImageReader.read(simg_path)[0] + ) + save_path = str(Path(psave_dir).joinpath(img_path.stem + ".txt")) + save_path_deskew = str( + Path(psave_dir).joinpath(img_path.stem + "_deskewed.jpg") + ) + process_img(img, save_path, engine, export_img, save_path_deskew) + # except Exception as 
e: + # print('[ERROR]: ', e, ' at ', simg_path) + # continue + ddata["img_path"].append(simg_path) + ddata["ocr_path"].append(save_path) + if Path(save_path_deskew).exists(): + ddata["save_path_deskew"].append(save_path) + ddata["label"].append(pdir_path.stem) + # ddata.update({"img_path": img_path, "save_path": save_path, "label": dir_path.stem}) + return ddata + + +def process_csv(csv_path: str, engine: OcrEngine) -> None: + df = pd.read_csv(csv_path) + if not "image_path" in df.columns or not "ocr_path" in df.columns: + raise AssertionError("Cannot fing image_path in df headers") + for row in df.iterrows(): + process_img(row.image_path, row.ocr_path, engine) + + +if __name__ == "__main__": + opt = get_args() + engine = load_engine(opt) + print("[INFO]: OCR engine settings:", engine.settings) + img, save_dir = get_paths_from_opt(opt) + + lskip_dir = [] + if img.is_dir(): + ddata = process_dir( + img, save_dir, engine, opt.export_img, opt.exclude, opt.include + ) + if opt.export_csv: + pd.DataFrame.from_dict(ddata).to_csv( + Path(save_dir).joinpath(opt.export_csv) + ) + elif img.suffix in ImageReader.supported_ext: + process_img(str(img), save_dir, engine, opt.export_img) + elif img.suffix == ".csv": + print( + "[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used" + ) + process_csv(img, engine) + else: + raise NotImplementedError("[ERROR]: Unsupported file {}".format(img)) diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/scripts/run_deskew.sh b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/scripts/run_deskew.sh new file mode 100644 index 0000000..34ecd8c --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/scripts/run_deskew.sh @@ -0,0 +1,9 @@ +export CUDA_VISIBLE_DEVICES=1 +# export PATH=/usr/local/cuda-11.6/bin${PATH:+:${PATH}} +# export LD_LIBRARY_PATH=/usr/local/cuda-11.6/lib64\ {LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +# export CUDA_HOME=/usr/local/cuda-11.6 +# export PATH=/usr/local/cuda-11.6/bin:$PATH +# export CPATH=/usr/local/cuda-11.6/include:$CPATH +# export LIBRARY_PATH=/usr/local/cuda-11.6/lib64:$LIBRARY_PATH +# export LD_LIBRARY_PATH=/usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/extras/CUPTI/lib64:$LD_LIBRARY_PATH +python test/test_deskew_dir.py \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/scripts/run_ocr.sh b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/scripts/run_ocr.sh new file mode 100644 index 0000000..8ade432 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/scripts/run_ocr.sh @@ -0,0 +1,49 @@ + + +#bash scripts/run_ocr.sh -i /mnt/hdd2T/AICR/Projects/2023/FWD/Forms/PDFs/ -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr -e out.csv -k "{\"device\":\"cuda:1\"}" -p True -n Passport 'So\ HK' +#bash scripts/run_ocr.sh -i '/mnt/hdd2T/AICR/Projects/2023/FWD/Forms/PDFs/So\ HK' -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr -e out.csv -k "{\"device\":\"cuda:1\"}" -p True +#-n and -x do not accept multiple argument currently + + +# bash scripts/run_ocr.sh -i /mnt/hdd4T/OCR/hoangdc/End_to_end/ICDAR2013/data/images_receipt_5images/ -o visualize/ -e out.csv -k "{\"device\":\"cuda:1\"}" -p True + +export PYTHONWARNINGS="ignore" + +while getopts i:o:b:e:p:k:n:x: flag +do + case "${flag}" in + i) img=${OPTARG};; + o) out_dir=${OPTARG};; + b) base_dir=${OPTARG};; + e) export_csv=${OPTARG};; + p) export_img=${OPTARG};; + k) ocr_kwargs=${OPTARG};; + n) 
include=("${OPTARG[@]}");; + x) exclude=("${OPTARG[@]}");; + esac +done + +cmd="python run.py \ + --image $img \ + --save_dir $out_dir \ + --export_csv $export_csv \ + --export_img $export_img \ + --ocr_kwargs $ocr_kwargs" + +if [ ${#include[@]} -gt 0 ]; then + cmd+=" --include" + for item in "${include[@]}"; do + cmd+=" $item" + done +fi + +if [ ${#exclude[@]} -gt 0 ]; then + cmd+=" --exclude" + for item in "${exclude[@]}"; do + cmd+=" $item" + done +fi + + +echo $cmd +exec $cmd diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/settings.yml b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/settings.yml new file mode 100644 index 0000000..9107bb1 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/settings.yml @@ -0,0 +1,35 @@ +device: &device cuda:0 +max_img_size: [1920,1920] #text det default size: 1280x1280 #[] = originla size, TODO: fix the deskew code to resize the image only for detecting the angle, we want to feed the original size image to the text detection pipeline so that the bounding boxes would be mapped back to the original size +extend_bbox: [0, 0.0, 0.0, 0.0] # left, top, right, bottom +batch_size: 1 #1 means batch_mode = False +detector: + # version: /mnt/hdd2T/datnt/datnt_from_ssd1T/mmdetection/wild_receipt_finetune_weights_c_lite.pth + version: /mnt/hdd4T/OCR/datnt/mmdetection/logs/textdet-baseline-Oct04-wildreceiptv4-sdsapv1-mcocr-ssreceipt/epoch_100_params.pth + auto_rotate: True + rotator_version: /mnt/hdd2T/datnt/datnt_from_ssd1T/mmdetection/logs/textdet-with-rotate-20230317/best_bbox_mAP_epoch_30_lite.pth + device: *device + +recognizer: + version: satrn-lite-general-pretrain-20230106 + max_seq_len_overwrite: 24 #default = 12 + return_confident: True + device: *device +#extend the bbox to avoid losing accent mark in vietnames, if using ocr for only english, disable it + +deskew: + enable: True + text_detector: + config: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/config/det.yaml + weight: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_PP-OCRv3_det_infer + text_cls: + config: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/config/cls.yaml + weight: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_ppocr_mobile_v2.0_cls_infer + device: *device + + +words_to_lines: + gradient: 0.6 + max_x_dist: 20 + max_running_y_shift_degree: 10 #degrees + y_overlap_threshold: 0.5 + word_formation_mode: line diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/dto.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/dto.py new file mode 100644 index 0000000..8dae901 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/dto.py @@ -0,0 +1,534 @@ +import numpy as np +from typing import Optional, List, Union +import cv2 +from PIL import Image +from pathlib import Path +from .utils import visualize_bbox_and_label + + +class Box: + def __init__( + self, x1: int, y1: int, x2: int, y2: int, conf: float = -1.0, label: str = "" + ): + self._x1 = x1 + self._y1 = y1 + self._x2 = x2 + self._y2 = y2 + self._conf = conf + self._label = label + + def __repr__(self) -> str: + return str(self.bbox) + + def __str__(self) -> str: + return str(self.bbox) + + def get(self, return_confidence=False) -> Union[list[int], list[Union[float, int]]]: + return self.bbox if not return_confidence else self.xyxyc + + def __getitem__(self, key): + return self.bbox[key] + + 
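As a quick illustration of how the `Box` helper above is meant to be used (clamping to image bounds, margin extension, integer conversion and cropping); the import path is an assumption and may need adjusting to the actual package layout:

```python
import numpy as np
from src.dto import Box  # path assumed; adjust to the actual layout

box = Box(-10, 5, 120, 60, conf=0.9, label="word")
box.clamp_by_img_wh(width=100, height=100)         # coordinates clipped to the image -> [0, 5, 100, 60]
wider = box.get_extend_bbox([0.1, 0.0, 0.1, 0.0])  # grow the left/right margins by ~10% of the width
img = np.zeros((100, 100, 3), dtype=np.uint8)
crop = box.to_int().crop_img(img)                  # numpy slice img[5:60, 0:100]
print(box.bbox, wider.bbox, crop.shape)
```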
@property + def width(self): + return max(self._x2 - self._x1, -1) + + @property + def height(self): + return max(self._y2 - self._y1, -1) + + @property + def bbox(self) -> list[int]: + return [self._x1, self._y1, self._x2, self._y2] + + @bbox.setter + def bbox(self, bbox_: list[int]): + self._x1, self._y1, self._x2, self._y2 = bbox_ + + @property + def xyxyc(self) -> list[Union[float, int]]: + return [self._x1, self._y1, self._x2, self._y2, self._conf] + + @staticmethod + def normalize_bbox(bbox: list[int]) -> list[int]: + return [int(b) for b in bbox] + + def to_int(self): + self._x1, self._y1, self._x2, self._y2 = self.normalize_bbox( + [self._x1, self._y1, self._x2, self._y2] + ) + return self + + @staticmethod + def clamp_bbox_by_img_wh(bbox: list, width: int, height: int) -> list[int]: + x1, y1, x2, y2 = bbox + x1 = min(max(0, x1), width) + x2 = min(max(0, x2), width) + y1 = min(max(0, y1), height) + y2 = min(max(0, y2), height) + return [x1, y1, x2, y2] + + def clamp_by_img_wh(self, width: int, height: int): + self._x1, self._y1, self._x2, self._y2 = self.clamp_bbox_by_img_wh( + [self._x1, self._y1, self._x2, self._y2], width, height + ) + return self + + @staticmethod + def extend_bbox(bbox: list, margin: list): # -> Self (python3.11) + margin_l, margin_t, margin_r, margin_b = margin + l, t, r, b = bbox # left, top, right, bottom + t = t - (b - t) * margin_t + b = b + (b - t) * margin_b + l = l - (r - l) * margin_l + r = r + (r - l) * margin_r + return [l, t, r, b] + + def get_extend_bbox(self, margin: list): + extended_bbox = self.extend_bbox(self.bbox, margin) + return Box(*extended_bbox, label=self._label) + + @staticmethod + def bbox_is_valid(bbox: list[int]) -> bool: + if bbox == [-1, -1, -1, -1]: + raise ValueError("Empty bounding box found") + l, t, r, b = bbox # left, top, right, bottom + return True if (b - t) * (r - l) > 0 else False + + def is_valid(self) -> bool: + return self.bbox_is_valid(self.bbox) + + @staticmethod + def crop_img_by_bbox(img: np.ndarray, bbox: list) -> np.ndarray: + l, t, r, b = bbox + return img[t:b, l:r] + + def crop_img(self, img: np.ndarray) -> np.ndarray: + return self.crop_img_by_bbox(img, self.bbox) + + +class Word: + def __init__( + self, + image=None, + text="", + conf_cls=-1.0, + bbox_obj: Box = Box(-1, -1, -1, -1), + conf_detect=-1.0, + kie_label="", + ): + # self.type = "word" + self._text = text + self._image = image + self._conf_det = conf_detect + self._conf_cls = conf_cls + # [left, top,right,bot] coordinate of top-left and bottom-right point + self._bbox_obj = bbox_obj + # self.word_id = 0 # id of word + # self.word_group_id = 0 # id of word_group which instance belongs to + # self.line_id = 0 # id of line which instance belongs to + # self.paragraph_id = 0 # id of line which instance belongs to + self._kie_label = kie_label + + @property + def bbox(self) -> list[int]: + return self._bbox_obj.bbox + + @property + def text(self) -> str: + return self._text + + @property + def height(self): + return self._bbox_obj.height + + @property + def width(self): + return self._bbox_obj.width + + def __repr__(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + def is_valid(self) -> bool: + return self._bbox_obj.is_valid() + + # def is_special_word(self): + # if not self._text: + # raise ValueError("Cannot validatie size of empty bounding box") + + # # if len(text) > 7: + # # return True + # if len(self._text) >= 7: + # no_digits = sum(c.isdigit() for c in text) + # return no_digits / len(text) >= 0.3 + + # 
return False + + +class WordGroup: + def __init__( + self, + list_words: List[Word] = list(), + text: str = "", + boundingbox: Box = Box(-1, -1, -1, -1), + conf_cls: float = -1, + conf_det: float = -1, + ): + # self.type = "word_group" + self._list_words = list_words # dict of word instances + # self.word_group_id = 0 # word group id + # self.line_id = 0 # id of line which instance belongs to + # self.paragraph_id = 0 # id of paragraph which instance belongs to + self._text = text + self._bbox_obj = boundingbox + self._kie_label = "" + self._conf_cls = conf_cls + self._conf_det = conf_det + + @property + def bbox(self) -> list[int]: + return self._bbox_obj.bbox + + @property + def text(self) -> str: + return self._text + + @property + def list_words(self) -> list[Word]: + return self._list_words + + def __repr__(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + # def add_word(self, word: Word): # add a word instance to the word_group + # if word._text != "✪": + # for w in self._list_words: + # if word.word_id == w.word_id: + # print("Word id collision") + # return False + # word.word_group_id = self.word_group_id # + # word.line_id = self.line_id + # word.paragraph_id = self.paragraph_id + # self._list_words.append(word) + # self._text += " " + word._text + # if self.bbox_obj == [-1, -1, -1, -1]: + # self.bbox_obj = word._bbox_obj + # else: + # self.bbox_obj = [ + # min(self.bbox_obj[0], word._bbox_obj[0]), + # min(self.bbox_obj[1], word._bbox_obj[1]), + # max(self.bbox_obj[2], word._bbox_obj[2]), + # max(self.bbox_obj[3], word._bbox_obj[3]), + # ] + # return True + # else: + # return False + + # def update_word_group_id(self, new_word_group_id): + # self.word_group_id = new_word_group_id + # for i in range(len(self._list_words)): + # self._list_words[i].word_group_id = new_word_group_id + + # def update_kie_label(self): + # list_kie_label = [word._kie_label for word in self._list_words] + # dict_kie = dict() + # for label in list_kie_label: + # if label not in dict_kie: + # dict_kie[label] = 1 + # else: + # dict_kie[label] += 1 + # total = len(list(dict_kie.values())) + # max_value = max(list(dict_kie.values())) + # list_keys = list(dict_kie.keys()) + # list_values = list(dict_kie.values()) + # self.kie_label = list_keys[list_values.index(max_value)] + + # def update_text(self): # update text after changing positions of words in list word + # text = "" + # for word in self._list_words: + # text += " " + word._text + # self._text = text + + +class Line: + def __init__( + self, + list_word_groups: List[WordGroup] = [], + text: str = "", + boundingbox: Box = Box(-1, -1, -1, -1), + conf_cls: float = -1, + conf_det: float = -1, + ): + # self.type = "line" + self._list_word_groups = ( + list_word_groups # list of Word_group instances in the line + ) + # self.line_id = 0 # id of line in the paragraph + # self.paragraph_id = 0 # id of paragraph which instance belongs to + self._text = text + self._bbox_obj = boundingbox + self._conf_cls = conf_cls + self._conf_det = conf_det + + @property + def bbox(self) -> list[int]: + return self._bbox_obj.bbox + + @property + def text(self) -> str: + return self._text + + @property + def list_word_groups(self) -> List[WordGroup]: + return self._list_word_groups + + @property + def list_words(self) -> list[Word]: + return [ + word + for word_group in self._list_word_groups + for word in word_group.list_words + ] + + def __repr__(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + # def 
add_group(self, word_group: WordGroup): # add a word_group instance + # if word_group._list_words is not None: + # for wg in self.list_word_groups: + # if word_group.word_group_id == wg.word_group_id: + # print("Word_group id collision") + # return False + + # self.list_word_groups.append(word_group) + # self.text += word_group._text + # word_group.paragraph_id = self.paragraph_id + # word_group.line_id = self.line_id + + # for i in range(len(word_group._list_words)): + # word_group._list_words[ + # i + # ].paragraph_id = self.paragraph_id # set paragraph_id for word + # word_group._list_words[i].line_id = self.line_id # set line_id for word + # return True + # return False + + # def update_line_id(self, new_line_id): + # self.line_id = new_line_id + # for i in range(len(self.list_word_groups)): + # self.list_word_groups[i].line_id = new_line_id + # for j in range(len(self.list_word_groups[i]._list_words)): + # self.list_word_groups[i]._list_words[j].line_id = new_line_id + + # def merge_word(self, word): # word can be a Word instance or a Word_group instance + # if word.text != "✪": + # if self.boundingbox == [-1, -1, -1, -1]: + # self.boundingbox = word.boundingbox + # else: + # self.boundingbox = [ + # min(self.boundingbox[0], word.boundingbox[0]), + # min(self.boundingbox[1], word.boundingbox[1]), + # max(self.boundingbox[2], word.boundingbox[2]), + # max(self.boundingbox[3], word.boundingbox[3]), + # ] + # self.list_word_groups.append(word) + # self.text += " " + word.text + # return True + # return False + + # def __cal_ratio(self, top1, bottom1, top2, bottom2): + # sorted_vals = sorted([top1, bottom1, top2, bottom2]) + # intersection = sorted_vals[2] - sorted_vals[1] + # min_height = min(bottom1 - top1, bottom2 - top2) + # if min_height == 0: + # return -1 + # ratio = intersection / min_height + # return ratio + + # def __cal_ratio_height(self, top1, bottom1, top2, bottom2): + + # height1, height2 = top1 - bottom1, top2 - bottom2 + # ratio_height = float(max(height1, height2)) / float(min(height1, height2)) + # return ratio_height + + # def in_same_line(self, input_line, thresh=0.7): + # # calculate iou in vertical direction + # _, top1, _, bottom1 = self.boundingbox + # _, top2, _, bottom2 = input_line.boundingbox + + # ratio = self.__cal_ratio(top1, bottom1, top2, bottom2) + # ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2) + + # if ( + # (top2 <= top1 <= bottom2) or (top1 <= top2 <= bottom1) + # and ratio >= thresh + # and (ratio_height < 2) + # ): + # return True + # return False + + +# class Paragraph: +# def __init__(self, id=0, lines=None): +# self.list_lines = lines if lines is not None else [] # list of all lines in the paragraph +# self.paragraph_id = id # index of paragraph in the ist of paragraph +# self.text = "" +# self.boundingbox = [-1, -1, -1, -1] + +# @property +# def bbox(self): +# return self.boundingbox + +# def __repr__(self) -> str: +# return self.text + +# def __str__(self) -> str: +# return self.text + +# def add_line(self, line: Line): # add a line instance +# if line.list_word_groups is not None: +# for l in self.list_lines: +# if line.line_id == l.line_id: +# print("Line id collision") +# return False +# for i in range(len(line.list_word_groups)): +# line.list_word_groups[ +# i +# ].paragraph_id = ( +# self.paragraph_id +# ) # set paragraph id for every word group in line +# for j in range(len(line.list_word_groups[i]._list_words)): +# line.list_word_groups[i]._list_words[ +# j +# ].paragraph_id = ( +# self.paragraph_id +# ) # set 
paragraph id for every word in word groups +# line.paragraph_id = self.paragraph_id # set paragraph id for line +# self.list_lines.append(line) # add line to paragraph +# self.text += " " + line.text +# return True +# else: +# return False + +# def update_paragraph_id( +# self, new_paragraph_id +# ): # update new paragraph_id for all lines, word_groups, words inside paragraph +# self.paragraph_id = new_paragraph_id +# for i in range(len(self.list_lines)): +# self.list_lines[ +# i +# ].paragraph_id = new_paragraph_id # set new paragraph_id for line +# for j in range(len(self.list_lines[i].list_word_groups)): +# self.list_lines[i].list_word_groups[ +# j +# ].paragraph_id = new_paragraph_id # set new paragraph_id for word_group +# for k in range(len(self.list_lines[i].list_word_groups[j].list_words)): +# self.list_lines[i].list_word_groups[j].list_words[ +# k +# ].paragraph_id = new_paragraph_id # set new paragraph id for word +# return True + + +class Page: + def __init__( + self, + word_segments: Union[List[WordGroup], List[Line]], + image: np.ndarray, + deskewed_image: Optional[np.ndarray] = None, + ) -> None: + self._word_segments = word_segments + self._image = image + self._deskewed_image = deskewed_image + self._drawed_image: Optional[np.ndarray] = None + + @property + def word_segments(self): + return self._word_segments + + @property + def list_words(self) -> list[Word]: + return [ + word + for word_segment in self._word_segments + for word in word_segment.list_words + ] + + @property + def image(self): + return self._image + + @property + def PIL_image(self): + return Image.fromarray(self._image) + + @property + def drawed_image(self): + return self._drawed_image + + @property + def deskewed_image(self): + return self._deskewed_image + + def visualize_bbox_and_label(self, **kwargs: dict): + if self._drawed_image is not None: + return self._drawed_image + bboxes = list() + texts = list() + for word in self.list_words: + bboxes.append([int(float(b)) for b in word.bbox]) + texts.append(word._text) + img = visualize_bbox_and_label( + self._deskewed_image if self._deskewed_image is not None else self._image, + bboxes, + texts, + **kwargs + ) + self._drawed_image = img + return self._drawed_image + + def save_img(self, save_path: str, **kwargs: dict) -> None: + save_path_deskew = kwargs.pop("save_path_deskew", Path(save_path).with_stem(Path(save_path).stem + "_deskewed").as_posix()) + if self._deskewed_image is not None: + # save_path_deskew: str = kwargs.pop("save_path_deskew", Path(save_path).with_stem(Path(save_path).stem + "_deskewed").as_posix()) # type: ignore + cv2.imwrite(save_path_deskew, self._deskewed_image) + + img = self.visualize_bbox_and_label(**kwargs) + cv2.imwrite(save_path, img) + + + def write_to_file(self, mode: str, save_path: str) -> None: + f = open(save_path, "w+", encoding="utf-8") + for word_segment in self._word_segments: + if mode == "segment": + xmin, ymin, xmax, ymax = word_segment.bbox + f.write( + "{}\t{}\t{}\t{}\t{}\n".format( + xmin, ymin, xmax, ymax, word_segment._text + ) + ) + elif mode == "word": + for word in word_segment.list_words: + # xmin, ymin, xmax, ymax = word.bbox + xmin, ymin, xmax, ymax = [int(float(b)) for b in word.bbox] + f.write( + "{}\t{}\t{}\t{}\t{}\n".format( + xmin, ymin, xmax, ymax, word._text + ) + ) + else: + raise NotImplementedError("Unknown mode: {}".format(mode)) + f.close() + + +class Document: + def __init__(self, lpages: List[Page]) -> None: + self.lpages = lpages diff --git 
a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/ocr.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/ocr.py new file mode 100644 index 0000000..280d0b2 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/ocr.py @@ -0,0 +1,258 @@ +from typing import Union, overload, List, Optional, Tuple +from PIL import Image +import torch +import numpy as np +import yaml +from pathlib import Path +import mmcv +from sdsvtd import StandaloneYOLOXRunner +from sdsvtr import StandaloneSATRNRunner +from sdsv_dewarp.api import AlignImage + +from .utils import ImageReader, chunks, Timer, post_process_recog # rotate_bbox + +# from .utils import jdeskew as deskew +# from externals.deskew.sdsv_dewarp import pdeskew as deskew +# from .utils import deskew +from .dto import Word, Line, Page, Document, Box, WordGroup + +# from .word_formation import words_to_lines as words_to_lines +# from .word_formation import wo rds_to_lines_mmocr as words_to_lines +from .word_formation import words_formation_mmocr_tesseract as word_formation + +DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml" + + +class OcrEngine: + def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs): + """Warper of text detection and text recognition + :param settings_file: path to default setting file + :param kwargs: keyword arguments to overwrite the default settings file + """ + with open(settings_file) as f: + # use safe_load instead load + self._settings = yaml.safe_load(f) + self._update_configs(kwargs) + + self._ensure_device() + self._detector = StandaloneYOLOXRunner(**self._settings["detector"]) + self._recognizer = StandaloneSATRNRunner(**self._settings["recognizer"]) + self._deskewer = self._load_deskewer() + + def _update_configs(self, params): + for key, para in params.items(): # overwrite default settings by keyword arguments + if key not in self._settings: + raise ValueError("Invalid setting found in OcrEngine: ", k) + if key == "device": + self._settings[key] = para + self._settings["detector"][key] = para + self._settings["recognizer"][key] = para + self._settings["deskew"][key] = para + else: + for k, v in para.items(): + if isinstance(v, dict): + for sub_key, sub_value in v.items(): + self._settings[key][k][sub_key] = sub_value + else: + self._settings[key][k] = v + + def _load_deskewer(self) -> Optional[AlignImage]: + if self._settings["deskew"]["enable"]: + deskewer = AlignImage( + **{k: v for k, v in self._settings["deskew"].items() if k != "enable"} + ) + print( + "[WARNING]: Deskew is enabled. The bounding boxes prediction may not be aligned with the original image. In case of using these predictions for pseudo-label, turn on save_deskewed option and use the saved deskewed images instead for further proceed." 
+ ) + return deskewer + return None + + def _ensure_device(self): + if "cuda" in self._settings["device"]: + if not torch.cuda.is_available(): + print("[WARNING]: CUDA is not available, running with cpu instead") + self._settings["device"] = "cpu" + + @property + def version(self): + return { + "detector": self._settings["detector"], + "recognizer": self._settings["recognizer"], + } + + @property + def settings(self): + return self._settings + + # @staticmethod + # def xyxyc_to_xyxy_c(xyxyc: np.ndarray) -> Tuple[List[list], list]: + # ''' + # convert sdsvtd yoloX detection output to list of bboxes and list of confidences + # @param xyxyc: array of shape (n, 5) + # ''' + # xyxy = xyxyc[:, :4].tolist() + # confs = xyxyc[:, 4].tolist() + # return xyxy, confs + # -> Tuple[np.ndarray, List[Box]]: + + def preprocess(self, img: np.ndarray) -> tuple[np.ndarray, bool, float]: + img_ = img.copy() + if self._settings["max_img_size"]: + img_ = mmcv.imrescale( + img, + tuple(self._settings["max_img_size"]), + return_scale=False, + interpolation="bilinear", + backend="cv2", + ) + is_blank = False + if self._deskewer: + with Timer("deskew"): + img_, is_blank, angle = self._deskewer(img_) + return img, is_blank, angle # replace img_ to img + # for i, bbox in enumerate(bboxes): + # rotated_bbox = rotate_bbox(bbox, angle, img.shape[:2]) + # bboxes[i].bbox = rotated_bbox + return img, is_blank, 0 + + def run_detect( + self, img: np.ndarray, return_raw: bool = False + ) -> Tuple[np.ndarray, Union[List[Box], List[list]]]: + """ + run text detection and return list of xyxyc if return_confidence is True, otherwise return a list of xyxy + """ + pred_det = self._detector(img) + if self._settings["detector"]["auto_rotate"]: + img, pred_det = pred_det + pred_det = pred_det[0] # only image at a time + return ( + (img, pred_det.tolist()) + if return_raw + else (img, [Box(*xyxyc) for xyxyc in pred_det.tolist()]) + ) + + def run_recog( + self, imgs: List[np.ndarray] + ) -> Union[List[str], List[Tuple[str, float]]]: + if len(imgs) == 0: + return list() + pred_rec = self._recognizer(imgs) + return [ + (post_process_recog(word), conf) + for word, conf in zip(pred_rec[0], pred_rec[1]) + ] + + def read_img(self, img: str) -> np.ndarray: + return ImageReader.read(img) + + def get_cropped_imgs( + self, img: np.ndarray, bboxes: Union[List[Box], List[list]] + ) -> Tuple[List[np.ndarray], List[bool]]: + """ + img: np image + bboxes: list of xyxy + """ + lcropped_imgs = list() + mask = list() + for bbox in bboxes: + bbox = Box(*bbox) if isinstance(bbox, list) else bbox + bbox = bbox.get_extend_bbox(self._settings["extend_bbox"]) + + bbox.clamp_by_img_wh(img.shape[1], img.shape[0]) + bbox.to_int() + if not bbox.is_valid(): + mask.append(False) + continue + cropped_img = bbox.crop_img(img) + lcropped_imgs.append(cropped_img) + mask.append(True) + return lcropped_imgs, mask + + def read_page( + self, img: np.ndarray, bboxes: Union[List[Box], List[list]] + ) -> Union[List[WordGroup], List[Line]]: + if len(bboxes) == 0: # no bbox found + return list() + with Timer("cropped imgs"): + lcropped_imgs, mask = self.get_cropped_imgs(img, bboxes) + with Timer("recog"): + # batch_mode for efficiency + pred_recs = self.run_recog(lcropped_imgs) + with Timer("construct words"): + lwords = list() + for i in range(len(pred_recs)): + if not mask[i]: + continue + text, conf_rec = pred_recs[i][0], pred_recs[i][1] + bbox = Box(*bboxes[i]) if isinstance(bboxes[i], list) else bboxes[i] + lwords.append( + Word( + image=img, + text=text, + 
conf_cls=conf_rec, + bbox_obj=bbox, + conf_detect=bbox._conf, + ) + ) + with Timer("word formation"): + return word_formation( + lwords, img.shape[1], **self._settings["words_to_lines"] + )[0] + + # https://stackoverflow.com/questions/48127642/incompatible-types-in-assignment-on-union + + @overload + def __call__(self, img: Union[str, np.ndarray, Image.Image]) -> Page: + ... + + @overload + def __call__(self, img: List[Union[str, np.ndarray, Image.Image]]) -> Document: + ... + + def __call__(self, img): # type: ignore #ignoring type before implementing batch_mode + """ + Accept an image or list of them, return ocr result as a page or document + """ + with Timer("read image"): + img = ImageReader.read(img) + if self._settings["batch_size"] == 1: + if isinstance(img, list): + if len(img) == 1: + img = img[0] # in case input type is a 1 page pdf + else: + raise AssertionError( + "list input can only be used with batch_mode enabled" + ) + img_deskewed, is_blank, angle = self.preprocess(img) + + if is_blank: + print( + "[WARNING]: Blank image detected" + ) # TODO: should we stop the execution here? + with Timer("detect"): + img_deskewed, bboxes = self.run_detect(img_deskewed) + with Timer("read_page"): + lsegments = self.read_page(img_deskewed, bboxes) + return Page(lsegments, img, img_deskewed if angle != 0 else None) + else: + # lpages = [] + # # chunks to reduce memory footprint + # for imgs in chunks(img, self._batch_size): + # # pred_dets = self._detector(imgs) + # # TEMP: use list comprehension because sdsvtd do not support batch mode of text detection + # img = self.preprocess(img) + # img, bboxes = self.run_detect(img) + # for img_, bboxes_ in zip(imgs, bboxes): + # llines = self.read_page(img, bboxes_) + # page = Page(llines, img) + # lpages.append(page) + # return Document(lpages) + raise NotImplementedError("Batch mode is currently not supported") + + +if __name__ == "__main__": + img_path = "/mnt/ssd1T/hungbnt/Cello/data/PH/Sea7/Sea_7_1.jpg" + engine = OcrEngine(device="cuda:0") + # https://stackoverflow.com/questions/66435480/overload-following-optional-argument + page = engine(img_path) # type: ignore + print(page._word_segments) diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/utils.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/utils.py new file mode 100644 index 0000000..3c4332e --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/utils.py @@ -0,0 +1,369 @@ +from PIL import ImageFont, ImageDraw, Image, ImageOps +# import matplotlib.pyplot as plt +import numpy as np +import cv2 +import os +import time +from typing import Generator, Union, List, overload, Tuple, Callable +import glob +import math +from pathlib import Path +from pdf2image import convert_from_path +# from deskew import determine_skew +# from jdeskew.estimator import get_angle +# from jdeskew.utility import rotate as jrotate + + +def post_process_recog(text: str) -> str: + text = text.replace("✪", " ") + return text + + +def find_maximum_without_outliers(lst: list[int], threshold: float = 1.): + ''' + To find the maximum number in a list while excluding its outlier values, you can follow these steps: + Determine the range within which you consider values as outliers. This can be based on a specific threshold or a statistical measure such as the interquartile range (IQR). + Iterate through the list and filter out the outlier values based on the defined range. Keep track of the non-outlier values. 
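The IQR-based filtering described above is easy to sanity-check on a tiny list. A self-contained sketch of the same computation with illustrative values (the function name here is not the module's):

```python
import numpy as np

def max_without_outliers(lst, threshold=1.0):
    q1, q3 = np.percentile(lst, 25), np.percentile(lst, 75)
    iqr = q3 - q1
    kept = [x for x in lst if q1 - threshold * iqr <= x <= q3 + threshold * iqr]
    return max(kept)

# Word heights of 10-13 px plus one spurious 500 px detection:
print(max_without_outliers([10, 11, 12, 13, 500]))  # -> 13, the outlier is ignored
```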
+ Find the maximum value among the non-outlier values. + ''' + # Calculate the lower and upper boundaries for outliers + q1 = np.percentile(lst, 25) + q3 = np.percentile(lst, 75) + iqr = q3 - q1 + lower_bound = q1 - threshold * iqr + upper_bound = q3 + threshold * iqr + + # Filter out outlier values + non_outliers = [x for x in lst if lower_bound <= x <= upper_bound] + + # Find the maximum value among non-outliers + max_value = max(non_outliers) + + return max_value + + +class Timer: + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, func: Callable, *args): + self.end_time = time.perf_counter() + self.elapsed_time = self.end_time - self.start_time + print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds") + + +# def rotate( +# image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]] +# ) -> np.ndarray: +# old_width, old_height = image.shape[:2] +# angle_radian = math.radians(angle) +# width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width) +# height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height) +# image_center = tuple(np.array(image.shape[1::-1]) / 2) +# rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) +# rot_mat[1, 2] += (width - old_width) / 2 +# rot_mat[0, 2] += (height - old_height) / 2 +# return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background) + + +# def rotate_bbox(bbox: list, angle: float) -> list: +# # Compute the center point of the bounding box +# cx = bbox[0] + bbox[2] / 2 +# cy = bbox[1] + bbox[3] / 2 + +# # Define the scale factor for the rotated bounding box +# scale = 1.0 # following the deskew and jdeskew function +# angle_radian = math.radians(angle) + +# # Obtain the rotation matrix using cv2.getRotationMatrix2D() +# M = cv2.getRotationMatrix2D((cx, cy), angle_radian, scale) + +# # Apply the rotation matrix to the four corners of the bounding box +# corners = np.array([[bbox[0], bbox[1]], +# [bbox[0] + bbox[2], bbox[1]], +# [bbox[0] + bbox[2], bbox[1] + bbox[3]], +# [bbox[0], bbox[1] + bbox[3]]], dtype=np.float32) +# rotated_corners = cv2.transform(np.array([corners]), M)[0] + +# # Compute the bounding box of the rotated corners +# x = int(np.min(rotated_corners[:, 0])) +# y = int(np.min(rotated_corners[:, 1])) +# w = int(np.max(rotated_corners[:, 0]) - np.min(rotated_corners[:, 0])) +# h = int(np.max(rotated_corners[:, 1]) - np.min(rotated_corners[:, 1])) +# rotated_bbox = [x, y, w, h] + +# return rotated_bbox + +# def rotate_bbox(bbox: List[int], angle: float, old_shape: Tuple[int, int]) -> List[int]: +# # https://medium.com/@pokomaru/image-and-bounding-box-rotation-using-opencv-python-2def6c39453 +# bbox_ = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]] +# h, w = old_shape +# cx, cy = (int(w / 2), int(h / 2)) + +# bbox_tuple = [ +# (bbox_[0], bbox_[1]), +# (bbox_[2], bbox_[3]), +# (bbox_[4], bbox_[5]), +# (bbox_[6], bbox_[7]), +# ] # put x and y coordinates in tuples, we will iterate through the tuples and perform rotation + +# rotated_bbox = [] + +# for i, coord in enumerate(bbox_tuple): +# M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) +# cos, sin = abs(M[0, 0]), abs(M[0, 1]) +# newW = int((h * sin) + (w * cos)) +# newH = int((h * cos) + (w * sin)) +# M[0, 2] += (newW / 2) - cx +# M[1, 2] += (newH / 2) - cy +# v = [coord[0], coord[1], 1] +# adjusted_coord = np.dot(M, v) +# 
rotated_bbox.insert(i, (adjusted_coord[0], adjusted_coord[1])) +# result = [int(x) for t in rotated_bbox for x in t] +# return [result[i] for i in [0, 1, 2, -1]] # reformat to xyxy + + +# def deskew(image: np.ndarray) -> Tuple[np.ndarray, float]: +# grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) +# angle = 0. +# try: +# angle = determine_skew(grayscale) +# except Exception: +# pass +# rotated = rotate(image, angle, (0, 0, 0)) if angle else image +# return rotated, angle + + +# def jdeskew(image: np.ndarray) -> Tuple[np.ndarray, float]: +# angle = 0. +# try: +# angle = get_angle(image) +# except Exception: +# pass +# # TODO: change resize = True and scale the bounding box +# rotated = jrotate(image, angle, resize=False) if angle else image +# return rotated, angle +# def deskew() + +class ImageReader: + """ + accept anything, return numpy array image + """ + supported_ext = [".png", ".jpg", ".jpeg", ".pdf", ".gif"] + + @staticmethod + def validate_img_path(img_path: str) -> None: + if not os.path.exists(img_path): + raise FileNotFoundError(img_path) + if os.path.isdir(img_path): + raise IsADirectoryError(img_path) + if not Path(img_path).suffix.lower() in ImageReader.supported_ext: + raise NotImplementedError("Not supported extension at {}".format(img_path)) + + @overload + @staticmethod + def read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray: ... + + @overload + @staticmethod + def read(img: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: ... + + @overload + @staticmethod + def read(img: str) -> List[np.ndarray]: ... # for pdf or directory + + @staticmethod + def read(img): + if isinstance(img, list): + return ImageReader.from_list(img) + elif isinstance(img, str) and os.path.isdir(img): + return ImageReader.from_dir(img) + elif isinstance(img, str) and img.endswith(".pdf"): + return ImageReader.from_pdf(img) + else: + return ImageReader._read(img) + + @staticmethod + def from_dir(dir_path: str) -> List[np.ndarray]: + if os.path.isdir(dir_path): + image_files = glob.glob(os.path.join(dir_path, "*")) + return ImageReader.from_list(image_files) + else: + raise NotADirectoryError(dir_path) + + @staticmethod + def from_str(img_path: str) -> np.ndarray: + ImageReader.validate_img_path(img_path) + return ImageReader.from_PIL(Image.open(img_path)) + + @staticmethod + def from_np(img_array: np.ndarray) -> np.ndarray: + return img_array + + @staticmethod + def from_PIL(img_pil: Image.Image, transpose=True) -> np.ndarray: + # if img_pil.is_animated: + # raise NotImplementedError("Only static images are supported, animated image found") + if transpose: + img_pil = ImageOps.exif_transpose(img_pil) + if img_pil.mode != "RGB": + img_pil = img_pil.convert("RGB") + + return np.array(img_pil) + + @staticmethod + def from_list(img_list: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: + limgs = list() + for img_path in img_list: + try: + if isinstance(img_path, str): + ImageReader.validate_img_path(img_path) + limgs.append(ImageReader._read(img_path)) + except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e: + print("[ERROR]: ", e) + print("[INFO]: Skipping image {}".format(img_path)) + return limgs + + @staticmethod + def from_pdf(pdf_path: str, start_page: int = 0, end_page: int = 0) -> List[np.ndarray]: + pdf_file = convert_from_path(pdf_path) + if end_page is not None: + end_page = min(len(pdf_file), end_page + 1) + limgs = [np.array(pdf_page) for pdf_page in pdf_file[start_page:end_page]] + return limgs + + @staticmethod + def _read(img: 
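The overloads above boil down to: a single image-like input yields one numpy array, while a PDF, a directory, or a list yields a list of arrays. A hedged usage sketch (the import path and the commented file names are illustrative):

```python
import numpy as np
from PIL import Image
from src.utils import ImageReader  # path assumed; adjust to the actual layout

arr = ImageReader.read(np.zeros((32, 32, 3), dtype=np.uint8))  # ndarray is passed through unchanged
rgb = ImageReader.read(Image.new("L", (32, 32)))               # PIL input is EXIF-transposed and converted to RGB
# pages = ImageReader.read("scan.pdf")       # one array per PDF page
# images = ImageReader.read("./image_dir")   # list of arrays; unreadable files are skipped with a warning
print(arr.shape, rgb.shape)
```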
Union[str, np.ndarray, Image.Image]) -> np.ndarray: + if isinstance(img, str): + return ImageReader.from_str(img) + elif isinstance(img, Image.Image): + return ImageReader.from_PIL(img) + elif isinstance(img, np.ndarray): + return ImageReader.from_np(img) + else: + raise ValueError("Invalid img argument type: ", type(img)) + + +def get_name(file_path, ext: bool = True): + file_path_ = os.path.basename(file_path) + return file_path_ if ext else os.path.splitext(file_path_)[0] + + +def construct_file_path(dir, file_path, ext=''): + ''' + args: + dir: /path/to/dir + file_path /example_path/to/file.txt + ext = '.json' + return + /path/to/dir/file.json + ''' + return os.path.join( + dir, get_name(file_path, + True)) if ext == '' else os.path.join( + dir, get_name(file_path, + False)) + ext + + +def chunks(lst: list, n: int) -> Generator: + """ + Yield successive n-sized chunks from lst. + https://stackoverflow.com/questions/312443/how-do-i-split-a-list-into-equally-sized-chunks + """ + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +def read_ocr_result_from_txt(file_path: str) -> Tuple[list, list]: + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + boxes, words = [], [] + for line in lines: + if line == "": + continue + x1, y1, x2, y2, text = line.split("\t") + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + + +def get_xyxywh_base_on_format(bbox, format): + if format == "xywh": + x1, y1, w, h = bbox[0], bbox[1], bbox[2], bbox[3] + x2, y2 = x1 + w, y1 + h + elif format == "xyxy": + x1, y1, x2, y2 = bbox + w, h = x2 - x1, y2 - y1 + else: + raise NotImplementedError("Invalid format {}".format(format)) + return (x1, y1, x2, y2, w, h) + + +def get_dynamic_params_for_bbox_of_label(text, x1, y1, w, h, img_h, img_w, font, font_scale_offset=1): + font_scale_factor = img_h / (img_w + img_h) * font_scale_offset + font_scale = w / (w + h) * font_scale_factor # adjust font scale by width height + thickness = int(font_scale_factor) + 1 + (text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0] + text_offset_x = x1 + text_offset_y = y1 - thickness + box_coords = ((text_offset_x, text_offset_y + 1), (text_offset_x + text_width - 2, text_offset_y - text_height - 2)) + return (font_scale, thickness, text_height, box_coords) + + +def visualize_bbox_and_label( + img, bboxes, texts, bbox_color=(200, 180, 60), + text_color=(0, 0, 0), + format="xyxy", is_vnese=False, draw_text=True): + ori_img_type = type(img) + if is_vnese: + img = Image.fromarray(img) if ori_img_type is np.ndarray else img + draw = ImageDraw.Draw(img) + img_w, img_h = img.size + font_pil_str = "fonts/arial.ttf" + font_cv2 = cv2.FONT_HERSHEY_SIMPLEX + else: + img_h, img_w = img.shape[0], img.shape[1] + font_cv2 = cv2.FONT_HERSHEY_SIMPLEX + for i in range(len(bboxes)): + text = texts[i] # text = "{}: {:.0f}%".format(LABELS[classIDs[i]], confidences[i]*100) + x1, y1, x2, y2, w, h = get_xyxywh_base_on_format(bboxes[i], format) + font_scale, thickness, text_height, box_coords = get_dynamic_params_for_bbox_of_label( + text, x1, y1, w, h, img_h, img_w, font=font_cv2) + if is_vnese: + font_pil = ImageFont.truetype(font_pil_str, size=text_height) # type: ignore + fdraw_text = draw.text # type: ignore + fdraw_bbox = draw.rectangle # type: ignore + # Pil use different coordinate => y = y+thickness = y-thickness + 
2*thickness + arg_text = ((box_coords[0][0], box_coords[1][1]), text) + kwarg_text = {"font": font_pil, "fill": text_color, "width": thickness} + arg_rec = ((x1, y1, x2, y2),) + kwarg_rec = {"outline": bbox_color, "width": thickness} + arg_rec_text = ((box_coords[0], box_coords[1]),) + kwarg_rec_text = {"fill": bbox_color, "width": thickness} + else: + # cv2.rectangle(img, box_coords[0], box_coords[1], color, cv2.FILLED) + # cv2.putText(img, text, (text_offset_x, text_offset_y), font, fontScale=font_scale, color=(50, 0,0), thickness=thickness) + # cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness) + fdraw_text = cv2.putText + fdraw_bbox = cv2.rectangle + arg_text = (img, text, box_coords[0]) + kwarg_text = {"fontFace": font_cv2, "fontScale": font_scale, "color": text_color, "thickness": thickness} + arg_rec = (img, (x1, y1), (x2, y2)) + kwarg_rec = {"color": bbox_color, "thickness": thickness} + arg_rec_text = (img, box_coords[0], box_coords[1]) + kwarg_rec_text = {"color": bbox_color, "thickness": cv2.FILLED} + # draw a bounding box rectangle and label on the img + fdraw_bbox(*arg_rec, **kwarg_rec) # type: ignore + if draw_text: + fdraw_bbox(*arg_rec_text, **kwarg_rec_text) # type: ignore + fdraw_text(*arg_text, **kwarg_text) # type: ignore # text have to put in front of rec_text + return np.array(img) if ori_img_type is np.ndarray and is_vnese else img diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/word_formation.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/word_formation.py new file mode 100644 index 0000000..3e64b97 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/externals/basic_ocr/src/word_formation.py @@ -0,0 +1,903 @@ +from builtins import dict +from .dto import Word, Line, WordGroup, Box +from .utils import find_maximum_without_outliers +import numpy as np +from typing import Optional, List, Tuple, Union + +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ +### WORDS TO LINES ALGORITHMS FROM MMOCR AND TESSERACT ############################################################################################################################################################################### +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ + +DEGREE_TO_RADIAN_COEF = np.pi / 180 +MAX_INT = int(2e10 + 9) +MIN_INT = -MAX_INT + + +def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8): + """Check if two boxes are on the same line by their y-axis coordinates. + + Two boxes are on the same line if they overlap vertically, and the length + of the overlapping line segment is greater than min_y_overlap_ratio * the + height of either of the boxes. 
+ + Args: + box_a (list), box_b (list): Two bounding boxes to be checked + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for boxes in the same line + + Returns: + The bool flag indicating if they are on the same line + """ + a_y_min = np.min(box_a[1::2]) + b_y_min = np.min(box_b[1::2]) + a_y_max = np.max(box_a[1::2]) + b_y_max = np.max(box_b[1::2]) + + # Make sure that box a is always the box above another + if a_y_min > b_y_min: + a_y_min, b_y_min = b_y_min, a_y_min + a_y_max, b_y_max = b_y_max, a_y_max + + if b_y_min <= a_y_max: + if min_y_overlap_ratio is not None: + sorted_y = sorted([b_y_min, b_y_max, a_y_max]) + overlap = sorted_y[1] - sorted_y[0] + min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio + min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio + return overlap >= min_a_overlap or \ + overlap >= min_b_overlap + else: + return True + return False + + +def merge_bboxes_to_group(bboxes_group, x_sorted_boxes): + merged_bboxes = [] + for box_group in bboxes_group: + merged_box = {} + merged_box['text'] = ' '.join( + [x_sorted_boxes[idx]['text'] for idx in box_group]) + x_min, y_min = float('inf'), float('inf') + x_max, y_max = float('-inf'), float('-inf') + for idx in box_group: + x_max = max(np.max(x_sorted_boxes[idx]['box'][::2]), x_max) + x_min = min(np.min(x_sorted_boxes[idx]['box'][::2]), x_min) + y_max = max(np.max(x_sorted_boxes[idx]['box'][1::2]), y_max) + y_min = min(np.min(x_sorted_boxes[idx]['box'][1::2]), y_min) + merged_box['box'] = [ + x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max + ] + merged_box['list_words'] = [x_sorted_boxes[idx]['word'] + for idx in box_group] + merged_bboxes.append(merged_box) + return merged_bboxes + + +def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.3): + """Stitch fragmented boxes of words into lines. 
+ + Note: part of its logic is inspired by @Johndirr + (https://github.com/faustomorales/keras-ocr/issues/22) + + Args: + boxes (list): List of ocr results to be stitched + max_x_dist (int): The maximum horizontal distance between the closest + edges of neighboring boxes in the same line + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for any pairs of neighboring boxes in the same line + + Returns: + merged_boxes(List[dict]): List of merged boxes and texts + """ + + if len(boxes) <= 1: + if len(boxes) == 1: + boxes[0]["list_words"] = [boxes[0]["word"]] + return boxes + + # merged_groups = [] + merged_lines = [] + + # sort groups based on the x_min coordinate of boxes + x_sorted_boxes = sorted(boxes, key=lambda x: np.min(x['box'][::2])) + # store indexes of boxes which are already parts of other lines + skip_idxs = set() + + i = 0 + # locate lines of boxes starting from the leftmost one + for i in range(len(x_sorted_boxes)): + if i in skip_idxs: + continue + # the rightmost box in the current line + rightmost_box_idx = i + line = [rightmost_box_idx] + for j in range(i + 1, len(x_sorted_boxes)): + if j in skip_idxs: + continue + if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'], + x_sorted_boxes[j]['box'], min_y_overlap_ratio): + line.append(j) + skip_idxs.add(j) + rightmost_box_idx = j + + # split line into lines if the distance between two neighboring + # sub-lines' is greater than max_x_dist + # groups = [] + # line_idx = 0 + # groups.append([line[0]]) + # for k in range(1, len(line)): + # curr_box = x_sorted_boxes[line[k]] + # prev_box = x_sorted_boxes[line[k - 1]] + # dist = np.min(curr_box['box'][::2]) - np.max(prev_box['box'][::2]) + # if dist > max_x_dist: + # line_idx += 1 + # groups.append([]) + # groups[line_idx].append(line[k]) + + # # Get merged boxes + merged_line = merge_bboxes_to_group([line], x_sorted_boxes) + merged_lines.extend(merged_line) + # merged_group = merge_bboxes_to_group(groups,x_sorted_boxes) + # merged_groups.extend(merged_group) + + merged_lines = sorted(merged_lines, key=lambda x: np.min(x['box'][1::2])) + # merged_groups = sorted(merged_groups, key=lambda x: np.min(x['box'][1::2])) + return merged_lines # , merged_groups + +# REFERENCE +# https://vigneshgig.medium.com/bounding-box-sorting-algorithm-for-text-detection-and-object-detection-from-left-to-right-and-top-cf2c523c8a85 +# https://huggingface.co/spaces/tomofi/MMOCR/blame/main/mmocr/utils/box_util.py + + +def words_to_lines_mmocr(words: List[Word], *args) -> Tuple[List[Line], Optional[int]]: + bboxes = [{"box": [w.bbox[0], w.bbox[1], w.bbox[2], w.bbox[1], w.bbox[2], w.bbox[3], w.bbox[0], w.bbox[3]], + "text":w._text, "word":w} for w in words] + merged_lines = stitch_boxes_into_lines(bboxes) + merged_groups = merged_lines # TODO: fix code to return both word group and line + lwords_groups = [WordGroup(list_words=merged_box["list_words"], + text=merged_box["text"], + boundingbox=[merged_box["box"][i] for i in [0, 1, 2, -1]]) + for merged_box in merged_groups] + + llines = [Line(text=word_group._text, list_word_groups=[word_group], boundingbox=word_group._bbox_obj) + for word_group in lwords_groups] + + return llines, None # same format with the origin words_to_lines + # lines = [Line() for merged] + + +# def most_overlapping_row(rows, top, bottom, y_shift): +# max_overlap = -1 +# max_overlap_idx = -1 +# for i, row in enumerate(rows): +# row_top, row_bottom = row +# overlap = min(top + y_shift, row_top) - max(bottom + y_shift, row_bottom) +# if overlap > 
max_overlap: +# max_overlap = overlap +# max_overlap_idx = i +# return max_overlap_idx +def most_overlapping_row(rows, row_words, bottom, top, y_shift, max_row_size, y_overlap_threshold=0.5): + max_overlap = -1 + max_overlap_idx = -1 + overlapping_rows = [] + + for i, row in enumerate(rows): + row_bottom, row_top = row + overlap = min(bottom - y_shift[i], row_bottom) - \ + max(top - y_shift[i], row_top) + + if overlap > max_overlap: + max_overlap = overlap + max_overlap_idx = i + + # if at least overlap 1 pixel and not (overlap too much and overlap too little) + if (row_top <= bottom and row_bottom >= top) and not (bottom - top - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold): + overlapping_rows.append(i) + + # Merge overlapping rows if necessary + if len(overlapping_rows) > 1: + merge_bottom = max(rows[i][0] for i in overlapping_rows) + merge_top = min(rows[i][1] for i in overlapping_rows) + + if merge_bottom - merge_top <= max_row_size: + # Merge rows + merged_row = (merge_bottom, merge_top) + merged_words = [] + # Remove other overlapping rows + + for row_idx in overlapping_rows[:0:-1]: # [1,2,3] -> 3,2 + merged_words.extend(row_words[row_idx]) + del rows[row_idx] + del row_words[row_idx] + + rows[overlapping_rows[0]] = merged_row + row_words[overlapping_rows[0]].extend(merged_words[::-1]) + max_overlap_idx = overlapping_rows[0] + + if bottom - top - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold: + max_overlap_idx = -1 + return max_overlap_idx + + +def stitch_boxes_into_lines_tesseract(words: list[Word], max_running_y_shift: int, + gradient: float, y_overlap_threshold: float) -> Tuple[list[list[Word]], float]: + sorted_words = sorted(words, key=lambda x: x.bbox[0]) + rows = [] + row_words = [] + max_row_size = find_maximum_without_outliers([word.height for word in sorted_words]) + running_y_shift = [] + for _i, word in enumerate(sorted_words): + bbox, _text = word.bbox, word._text + _x1, y1, _x2, y2 = bbox + bottom, top = y2, y1 + max_row_size = max(max_row_size, bottom - top) + overlap_row_idx = most_overlapping_row( + rows, row_words, bottom, top, running_y_shift, max_row_size, y_overlap_threshold) + + if overlap_row_idx == -1: # No overlapping row found + new_row = (bottom, top) + rows.append(new_row) + row_words.append([word]) + running_y_shift.append(0) + else: # Overlapping row found + row_bottom, row_top = rows[overlap_row_idx] + new_bottom = max(row_bottom, bottom) + new_top = min(row_top, top) + rows[overlap_row_idx] = (new_bottom, new_top) + row_words[overlap_row_idx].append(word) + new_shift = (top + bottom) / 2 - (row_top + row_bottom) / 2 + running_y_shift[overlap_row_idx] = min( + gradient * running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift, max_running_y_shift) # update and clamp + + # Sort rows and row_texts based on the top y-coordinate + sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][1]) + _sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data) + # /_|<- the perpendicular line of the horizontal line and the skew line of the page + page_skew_dist = sum(running_y_shift) / len(running_y_shift) + return sorted_row_words, page_skew_dist + + +def construct_word_groups_tesseract(sorted_row_words: list[list[Word]], + max_x_dist: int, page_skew_dist: float) -> list[list[list[Word]]]: + # approximate page_skew_angle by page_skew_dist + corrected_max_x_dist = max_x_dist * abs(np.cos(page_skew_dist * DEGREE_TO_RADIAN_COEF)) + 
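# With page skew, the allowed horizontal gap is scaled by cos(skew); the loop below splits each row into word groups wherever the gap between consecutive words exceeds corrected_max_x_dist. +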
constructed_row_word_groups = [] + for row_words in sorted_row_words: + lword_groups = [] + line_idx = 0 + lword_groups.append([row_words[0]]) + for k in range(1, len(row_words)): + curr_box = row_words[k].bbox + prev_box = row_words[k - 1].bbox + dist = curr_box[0] - prev_box[2] + if dist > corrected_max_x_dist: + line_idx += 1 + lword_groups.append([]) + lword_groups[line_idx].append(row_words[k]) + constructed_row_word_groups.append(lword_groups) + return constructed_row_word_groups + + +def group_bbox_and_text(lwords: Union[list[Word], list[WordGroup]]) -> tuple[Box, tuple[str, float]]: + text = ' '.join([word._text for word in lwords]) + x_min, y_min = MAX_INT, MAX_INT + x_max, y_max = MIN_INT, MIN_INT + conf_det = 0 + conf_cls = 0 + for word in lwords: + x_max = int(max(np.max(word.bbox[::2]), x_max)) + x_min = int(min(np.min(word.bbox[::2]), x_min)) + y_max = int(max(np.max(word.bbox[1::2]), y_max)) + y_min = int(min(np.min(word.bbox[1::2]), y_min)) + conf_det += word._conf_det + conf_cls += word._conf_cls + bbox = Box(x_min, y_min, x_max, y_max, conf=conf_det / len(lwords)) + return bbox, (text, conf_cls / len(lwords)) + + +def words_to_lines_tesseract(words: List[Word], + page_width: int, max_running_y_shift_degree: int, gradient: float, max_x_dist: int, + y_overlap_threshold: float) -> Tuple[List[Line], + Optional[float]]: + max_running_y_shift = page_width * np.tan(max_running_y_shift_degree * DEGREE_TO_RADIAN_COEF) + sorted_row_words, page_skew_dist = stitch_boxes_into_lines_tesseract( + words, max_running_y_shift, gradient, y_overlap_threshold) + constructed_row_word_groups = construct_word_groups_tesseract( + sorted_row_words, max_x_dist, page_skew_dist) + llines = [] + for row in constructed_row_word_groups: + lwords_row = [] + lword_groups = [] + for word_group in row: + bbox_word_group, text_word_group = group_bbox_and_text(word_group) + lwords_row.extend(word_group) + lword_groups.append( + WordGroup( + list_words=word_group, text=text_word_group[0], + conf_cls=text_word_group[1], + boundingbox=bbox_word_group)) + bbox_line, text_line = group_bbox_and_text(lwords_row) + llines.append( + Line( + list_word_groups=lword_groups, text=text_line[0], + boundingbox=bbox_line, conf_cls=text_line[1])) + return llines, page_skew_dist + + + + +### WORDS TO WORDGROUPS ######################################################################################################################################################################################################################### + + +def merge_overlapping_word_groups( + rows: list[list[int]], + row_words: list[list[Word]], + overlapping_rows: list[int], + max_row_size: int) -> bool: + # Merge found overlapping rows if necessary + merge_top = max(rows[i][1] for i in overlapping_rows) + merge_bottom = min(rows[i][3] for i in overlapping_rows) + merge_left = min(rows[i][0] for i in overlapping_rows) + merge_right = max(rows[i][2] for i in overlapping_rows) + + if merge_top - merge_bottom <= max_row_size: + # Merge rows + merged_row = [merge_left, merge_top, merge_right, merge_bottom] + merged_words = [] + # Remove other overlapping rows + + for row_idx in overlapping_rows[:0:-1]: # [1,2,3] -> 3,2 + merged_words.extend(row_words[row_idx]) + del rows[row_idx] + del row_words[row_idx] + + rows[overlapping_rows[0]] = merged_row + row_words[overlapping_rows[0]].extend(merged_words[::-1]) + return True + return False + + +def most_overlapping_word_groups( + rows, row_words, curr_word_bbox, y_shift, max_row_size, y_overlap_threshold, 
max_x_dist): + max_overlap = -1 + max_overlap_idx = -1 + overlapping_rows = [] + left, top, right, bottom = curr_word_bbox + for i, row in enumerate(rows): + row_left, row_top, row_right, row_bottom = row + top_shift = top - y_shift[i] + bottom_shift = bottom - y_shift[i] + + # find the most overlapping row + overlap = min(bottom_shift, row_bottom) - max(top_shift, row_top) + if overlap > max_overlap and min(right - row_left, left - row_right) < max_x_dist: + max_overlap = overlap + max_overlap_idx = i + +
# Exclusive process to handle cases where there are multiple satisfying overlapping rows. For example, some rows are not initially overlapping, but as the appended words get progressively more skewed, the end of one row may reach the beginning of another row
+ # if (row_top <= bottom and row_bottom >= top) and not (bottom - top - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold): + if (row_top <= bottom_shift and row_bottom >= top_shift) \ + and min(right - row_left, left - row_right) < max_x_dist \ + and not (bottom - top - overlap > max_row_size * y_overlap_threshold) \ + and not (overlap < max_row_size * y_overlap_threshold): + # explain: + # (row_top <= bottom_shift and row_bottom >= top_shift) -> overlap at least 1 pixel + # not (bottom - top - overlap > max_row_size * y_overlap_threshold) -> curr_word is not too big to overlap (to exclude figures containing words) + # not (overlap < max_row_size * y_overlap_threshold) -> overlap too little should not be merged + # min(right - row_left, row_right - left) < max_x_dist -> either the curr_word is close enough to left or right of the curr_row + overlapping_rows.append(i) + +
if len(overlapping_rows) > 1 and merge_overlapping_word_groups(rows, row_words, overlapping_rows, max_row_size): + max_overlap_idx = overlapping_rows[0] + if bottom - top - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold: + max_overlap_idx = -1 + return max_overlap_idx + + +
def update_overlapping_word_group_bbox(rows: list[list[int]], overlap_row_idx: int, curr_word_bbox: list[int]) -> None: + left, top, right, bottom = curr_word_bbox + row_left, row_top, row_right, row_bottom = rows[overlap_row_idx] + new_bottom = max(row_bottom, bottom) + new_top = min(row_top, top) + new_left = min(row_left, left) + new_right = max(row_right, right) + rows[overlap_row_idx] = [new_left, new_top, new_right, new_bottom] + + +
def update_word_group_running_y_shift( + running_y_shift: list[float], + overlap_row_idx: int, curr_row_bbox: list[int], + curr_word_bbox: list[int], + gradient: float, max_running_y_shift: float) -> None: + _, top, _, bottom = curr_word_bbox + _, row_top, _, row_bottom = curr_row_bbox + new_shift = (top + bottom) / 2 - (row_top + row_bottom) / 2 + running_y_shift[overlap_row_idx] = min( + gradient * running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift, max_running_y_shift) # update and clamp + + +
def stitch_boxes_into_word_groups_tesseract(words: list[Word], + max_running_y_shift: int, gradient: float, y_overlap_threshold: float, + max_x_dist: int) -> Tuple[list[WordGroup], float]: + sorted_words = sorted(words, key=lambda x: x.bbox[0]) + rows = [] + row_words = [] + max_row_size = sorted_words[0].height + running_y_shift = [] + for word in sorted_words: + bbox: list[int] = word.bbox + max_row_size = max(max_row_size, bbox[3] - bbox[1]) + if bbox[-1] < 200 and word.text == "Nguyễn": + print("DEBUGGING") + overlap_row_idx =
most_overlapping_word_groups( + rows, row_words, bbox, running_y_shift, max_row_size, y_overlap_threshold, max_x_dist) + if overlap_row_idx == -1: # No overlapping row found + rows.append(bbox) # new row + row_words.append([word]) # new row_word + running_y_shift.append(0) + else: # Overlapping row found + # row_bottom, row_top = rows[overlap_row_idx] + update_overlapping_word_group_bbox(rows, overlap_row_idx, bbox) + row_words[overlap_row_idx].append(word) # update row_words + update_word_group_running_y_shift( + running_y_shift, overlap_row_idx, rows[overlap_row_idx], + bbox, gradient, max_running_y_shift) + + # Sort rows and row_texts based on the top y-coordinate + sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][1]) + _sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data) + lword_groups = [] + for word_group in sorted_row_words: + bbox_word_group, text_word_group = group_bbox_and_text(word_group) + lword_groups.append( + WordGroup( + list_words=word_group, text=text_word_group[0], + conf_cls=text_word_group[1], + boundingbox=bbox_word_group)) + # /_|<- the perpendicular line of the horizontal line and the skew line of the page + page_skew_dist = sum(running_y_shift) / len(running_y_shift) + + return lword_groups, page_skew_dist + + +def is_on_same_line_mmocr_tesseract(box_a: list[int], box_b: list[int], min_y_overlap_ratio: float) -> bool: + a_y_min = box_a[1] + b_y_min = box_b[1] + a_y_max = box_a[3] + b_y_max = box_b[3] + + # Make sure that box a is always the box above another + if a_y_min > b_y_min: + a_y_min, b_y_min = b_y_min, a_y_min + a_y_max, b_y_max = b_y_max, a_y_max + + if b_y_min <= a_y_max: + if min_y_overlap_ratio is not None: + sorted_y = sorted([b_y_min, b_y_max, a_y_max]) + overlap = sorted_y[1] - sorted_y[0] + min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio + min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio + return overlap >= min_a_overlap or \ + overlap >= min_b_overlap + else: + return True + return False + + +def stitch_word_groups_into_lines_mmocr_tesseract( + lword_groups: list[WordGroup], + min_y_overlap_ratio: float) -> list[Line]: + merged_lines = [] + + # sort groups based on the x_min coordinate of boxes + # store indexes of boxes which are already parts of other lines + sorted_word_groups = sorted(lword_groups, key=lambda x: x.bbox[0]) + skip_idxs = set() + + i = 0 + # locate lines of boxes starting from the leftmost one + for i in range(len(sorted_word_groups)): + if i in skip_idxs: + continue + # the rightmost box in the current line + rightmost_box_idx = i + line = [rightmost_box_idx] + for j in range(i + 1, len(sorted_word_groups)): + if j in skip_idxs: + continue + if is_on_same_line_mmocr_tesseract(sorted_word_groups[rightmost_box_idx].bbox, + sorted_word_groups[j].bbox, min_y_overlap_ratio): + line.append(j) + skip_idxs.add(j) + rightmost_box_idx = j + + lword_groups_in_line = [sorted_word_groups[k] for k in line] + bbox_line, text_line = group_bbox_and_text(lword_groups_in_line) + merged_lines.append( + Line( + list_word_groups=lword_groups_in_line, text=text_line[0], + conf_cls=text_line[1], + boundingbox=bbox_line)) + merged_lines = sorted(merged_lines, key=lambda x: x.bbox[1]) + return merged_lines + + +def words_formation_mmocr_tesseract(words: List[Word], page_width: int, word_formation_mode: str, max_running_y_shift_degree: int, gradient: float, + max_x_dist: int, y_overlap_threshold: float) -> Tuple[Union[List[WordGroup], list[Line]], + Optional[float]]: + if len(words) == 0: + return [], 0 + 
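# Convert the maximum allowed skew angle (degrees) into a running y-shift budget in pixels across the page width. +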
max_running_y_shift = page_width * np.tan(max_running_y_shift_degree * DEGREE_TO_RADIAN_COEF) + lword_groups, page_skew_dist = stitch_boxes_into_word_groups_tesseract( + words, max_running_y_shift, gradient, y_overlap_threshold, max_x_dist) + if word_formation_mode == "word_group": + return lword_groups, page_skew_dist + elif word_formation_mode == "line": + llines = stitch_word_groups_into_lines_mmocr_tesseract(lword_groups, y_overlap_threshold) + return llines, page_skew_dist + else: + raise NotImplementedError("Word formation mode not supported: {}".format(word_formation_mode)) + +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ +### END WORDS TO LINES ALGORITHMS FROM MMOCR AND TESSERACT ############################################################################################################################################################################### +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ + +# MIN_IOU_HEIGHT = 0.7 +# MIN_WIDTH_LINE_RATIO = 0.05 + + +# def resize_to_original( +# boundingbox, scale +# ): # resize coordinates to match size of original image +# left, top, right, bottom = boundingbox +# left *= scale[1] +# right *= scale[1] +# top *= scale[0] +# bottom *= scale[0] +# return [left, top, right, bottom] + + +# def check_iomin(word: Word, word_group: Word_group): +# min_height = min( +# word.boundingbox[3] - word.boundingbox[1], +# word_group.boundingbox[3] - word_group.boundingbox[1], +# ) +# intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max( +# word.boundingbox[1], word_group.boundingbox[1] +# ) +# if intersect / min_height > 0.7: +# return True +# return False + + +# def prepare_line(words): +# lines = [] +# visited = [False] * len(words) +# for id_word, word in enumerate(words): +# if word.invalid_size() == 0: +# continue +# new_line = True +# for i in range(len(lines)): +# if ( +# lines[i].in_same_line(word) and not visited[id_word] +# ): # check if word is in the same line with lines[i] +# lines[i].merge_word(word) +# new_line = False +# visited[id_word] = True + +# if new_line == True: +# new_line = Line() +# new_line.merge_word(word) +# lines.append(new_line) + +# # print(len(lines)) +# # sort line from top to bottom according top coordinate +# lines.sort(key=lambda x: x.boundingbox[1]) +# return lines + + +# def __create_word_group(word, word_group_id): +# new_word_group_ = Word_group() +# new_word_group_.list_words = list() +# new_word_group_.word_group_id = word_group_id +# new_word_group_.add_word(word) + +# return new_word_group_ + + +# def __sort_line(line): +# line.list_word_groups.sort( +# key=lambda x: x.boundingbox[0] +# ) # sort word in lines from left to right + +# return line + + +# def __merge_text_for_line(line): +# line.text = "" +# for word 
in line.list_word_groups: +# line.text += " " + word.text + +# return line + + +# def __update_list_word_groups(line, word_group_id, word_id, line_width): + +# old_list_word_group = line.list_word_groups +# list_word_groups = [] + +# inital_word_group = __create_word_group( +# old_list_word_group[0], word_group_id) +# old_list_word_group[0].word_id = word_id +# list_word_groups.append(inital_word_group) +# word_group_id += 1 +# word_id += 1 + +# for word in old_list_word_group[1:]: +# check_word_group = True +# word.word_id = word_id +# word_id += 1 + +# if ( +# (not list_word_groups[-1].text.endswith(":")) +# and ( +# (word.boundingbox[0] - list_word_groups[-1].boundingbox[2]) +# / line_width +# < MIN_WIDTH_LINE_RATIO +# ) +# and check_iomin(word, list_word_groups[-1]) +# ): +# list_word_groups[-1].add_word(word) +# check_word_group = False + +# if check_word_group: +# new_word_group = __create_word_group(word, word_group_id) +# list_word_groups.append(new_word_group) +# word_group_id += 1 +# line.list_word_groups = list_word_groups +# return line, word_group_id, word_id + + +# def construct_word_groups_in_each_line(lines): +# line_id = 0 +# word_group_id = 0 +# word_id = 0 +# for i in range(len(lines)): +# if len(lines[i].list_word_groups) == 0: +# continue + +# # left, top ,right, bottom +# line_width = lines[i].boundingbox[2] - \ +# lines[i].boundingbox[0] # right - left +# line_width = 1 # TODO: to remove +# lines[i] = __sort_line(lines[i]) + +# # update text for lines after sorting +# lines[i] = __merge_text_for_line(lines[i]) + +# lines[i], word_group_id, word_id = __update_list_word_groups( +# lines[i], +# word_group_id, +# word_id, +# line_width) +# lines[i].update_line_id(line_id) +# line_id += 1 +# return lines + + +# def words_to_lines(words, check_special_lines=True): # words is list of Word instance +# # sort word by top +# words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0])) +# # words.sort(key=lambda x: (sum(x.bbox))) +# number_of_word = len(words) +# # print(number_of_word) +# # sort list words to list lines, which have not contained word_group yet +# lines = prepare_line(words) + +# # construct word_groups in each line +# lines = construct_word_groups_in_each_line(lines) +# return lines, number_of_word + + +# def near(word_group1: Word_group, word_group2: Word_group): +# min_height = min( +# word_group1.boundingbox[3] - word_group1.boundingbox[1], +# word_group2.boundingbox[3] - word_group2.boundingbox[1], +# ) +# overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max( +# word_group1.boundingbox[1], word_group2.boundingbox[1] +# ) + +# if overlap > 0: +# return True +# if abs(overlap / min_height) < 1.5: +# print("near enough", abs(overlap / min_height), overlap, min_height) +# return True +# return False + + +# def calculate_iou_and_near(wg1: Word_group, wg2: Word_group): +# min_height = min( +# wg1.boundingbox[3] - +# wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1] +# ) +# overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max( +# wg1.boundingbox[1], wg2.boundingbox[1] +# ) +# iou = overlap / min_height +# distance = min( +# abs(wg1.boundingbox[0] - wg2.boundingbox[2]), +# abs(wg1.boundingbox[2] - wg2.boundingbox[0]), +# ) +# if iou > 0.7 and distance < 0.5 * (wg1.boundingboxp[2] - wg1.boundingbox[0]): +# return True +# return False + + +# def construct_word_groups_to_kie_label(list_word_groups: list): +# kie_dict = dict() +# for wg in list_word_groups: +# if wg.kie_label == "other": +# continue +# if wg.kie_label not in 
kie_dict: +# kie_dict[wg.kie_label] = [wg] +# else: +# kie_dict[wg.kie_label].append(wg) + +# new_dict = dict() +# for key, value in kie_dict.items(): +# if len(value) == 1: +# new_dict[key] = value +# continue + +# value.sort(key=lambda x: x.boundingbox[1]) +# new_dict[key] = value +# return new_dict + + +# def invoice_construct_word_groups_to_kie_label(list_word_groups: list): +# kie_dict = dict() + +# for wg in list_word_groups: +# if wg.kie_label == "other": +# continue +# if wg.kie_label not in kie_dict: +# kie_dict[wg.kie_label] = [wg] +# else: +# kie_dict[wg.kie_label].append(wg) + +# return kie_dict + + +# def postprocess_total_value(kie_dict): +# if "total_in_words_value" not in kie_dict: +# return kie_dict + +# for k, value in kie_dict.items(): +# if k == "total_in_words_value": +# continue +# l = [] +# for v in value: +# if v.boundingbox[3] <= kie_dict["total_in_words_value"][0].boundingbox[3]: +# l.append(v) + +# if len(l) != 0: +# kie_dict[k] = l + +# return kie_dict + + +# def postprocess_tax_code_value(kie_dict): +# if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict: +# return kie_dict + +# kie_dict["buyer_tax_code_value"] = [] +# for v in kie_dict["seller_tax_code_value"]: +# if "buyer_name_key" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_key"][0]) +# ): +# kie_dict["buyer_tax_code_value"].append(v) +# continue + +# if "buyer_name_value" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_value"][0]) +# ): +# kie_dict["buyer_tax_code_value"].append(v) +# continue + +# if "buyer_address_value" in kie_dict and near( +# kie_dict["buyer_address_value"][0], v +# ): +# kie_dict["buyer_tax_code_value"].append(v) +# return kie_dict + + +# def postprocess_tax_code_key(kie_dict): +# if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict: +# return kie_dict +# kie_dict["buyer_tax_code_key"] = [] +# for v in kie_dict["seller_tax_code_key"]: +# if "buyer_name_key" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_key"][0]) +# ): +# kie_dict["buyer_tax_code_key"].append(v) +# continue + +# if "buyer_name_value" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_value"][0]) +# ): +# kie_dict["buyer_tax_code_key"].append(v) +# continue + +# if "buyer_address_value" in kie_dict and near( +# kie_dict["buyer_address_value"][0], v +# ): +# kie_dict["buyer_tax_code_key"].append(v) + +# return kie_dict + + +# def invoice_postprocess(kie_dict: dict): +# # all keys or values which are below total_in_words_value will be thrown away +# kie_dict = postprocess_total_value(kie_dict) +# kie_dict = postprocess_tax_code_value(kie_dict) +# kie_dict = postprocess_tax_code_key(kie_dict) +# return kie_dict + + +# def throw_overlapping_words(list_words): +# new_list = [list_words[0]] +# for word in list_words: +# overlap = False +# area = (word.boundingbox[2] - word.boundingbox[0]) * ( +# word.boundingbox[3] - word.boundingbox[1] +# ) +# for word2 in new_list: +# area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * ( +# word2.boundingbox[3] - word2.boundingbox[1] +# ) +# xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0]) +# xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2]) +# ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1]) +# 
ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3]) +# if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: +# continue + +# area_intersect = (xmax_intersect - xmin_intersect) * ( +# ymax_intersect - ymin_intersect +# ) +# if area_intersect / area > 0.7 or area_intersect / area2 > 0.7: +# overlap = True +# if overlap == False: +# new_list.append(word) +# return new_list + + +# def check_iou(box1: Word, box2: Box, threshold=0.9): +# area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * ( +# box1.boundingbox[3] - box1.boundingbox[1] +# ) +# area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin) +# xmin_intersect = max(box1.boundingbox[0], box2.xmin) +# ymin_intersect = max(box1.boundingbox[1], box2.ymin) +# xmax_intersect = min(box1.boundingbox[2], box2.xmax) +# ymax_intersect = min(box1.boundingbox[3], box2.ymax) +# if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: +# area_intersect = 0 +# else: +# area_intersect = (xmax_intersect - xmin_intersect) * ( +# ymax_intersect - ymin_intersect +# ) +# union = area1 + area2 - area_intersect +# iou = area_intersect / union +# if iou > threshold: +# return True +# return False diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/main.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/main.py new file mode 100644 index 0000000..db5f3b6 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/main.py @@ -0,0 +1,147 @@ +import os +import glob +import cv2 +import json +import argparse +import numpy as np +from tqdm import tqdm +from PIL import Image +from datetime import datetime +from sdsvkvu.sources.kvu import KVUEngine +from sdsvkvu.sources.utils import export_kvu_outputs, export_sbt_outputs, draw_kvu_outputs +from sdsvkvu.utils.utils import create_dir, write_to_json, pdf2img +from sdsvkvu.utils.query.vat import export_kvu_for_VAT_invoice, merged_kvu_for_VAT_invoice_for_multi_pages +from sdsvkvu.utils.query.sbt import export_kvu_for_SDSAP, merged_kvu_for_SDSAP_for_multi_pages +from sdsvkvu.utils.query.vtb import export_kvu_for_vietin, merged_kvu_for_vietin_for_multi_pages +from sdsvkvu.utils.query.all import export_kvu_for_all, merged_kvu_for_all_for_multi_pages +from sdsvkvu.utils.query.manulife import export_kvu_for_manulife, merged_kvu_for_manulife_for_multi_pages +from sdsvkvu.utils.query.sbt_v2 import export_kvu_for_SBT, merged_kvu_for_SBT_for_multi_pages + + +def get_args(): + args = argparse.ArgumentParser(description='Main file') + args.add_argument('--img_dir', type=str, required=True, + help='path to input image/directory file') + args.add_argument('--save_dir', type=str, required=True, + help='path to save directory') + args.add_argument('--doc_type', type=str, default="vat", + help='type of document') + args.add_argument('--export_img', type=bool, default=False, + help='export image of output visualization') + args.add_argument('--kvu_params', type=str, required=False, default="") + return args.parse_args() + + +def load_engine(kwargs) -> KVUEngine: + print('[INFO] Loading Key-Value Understanding model ...') + if not isinstance(kwargs, dict): + kwargs = json.loads(kwargs) if kwargs else {} + engine = KVUEngine(**kwargs) + print("[INFO] Loaded model") + print("[INFO] KVU engine settings: \n", engine._settings) + return engine + + +def process_img(img_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str) -> dict: + assert (engine._settings.mode == 4 and option == "sbt_v2") \ + or (engine._settings.mode != 4 and option != "sbt_v2"), \ + "[ERROR] Mode (4) has just supported 
option \"sbt_v2\"" + + print("="*5, os.path.basename(img_path)) + create_dir(save_dir) + fname, img_ext = os.path.splitext(os.path.basename(img_path)) + out_ext = ".json" + image, lbbox, lwords, pr_class_words, pr_relations = engine.predict(img_path) + + if len(lbbox) != 1: + raise ValueError( + f"Not support to predict each separated window: {len(lbbox)}" + ) + + for i in range(len(lbbox)): + if engine._settings.mode in range(4): + raw_outputs = export_kvu_outputs(lwords[i], lbbox[i], pr_class_words[i], pr_relations[i], engine._settings.class_names) + elif engine._settings.mode == 4: + raw_outputs = export_sbt_outputs(lwords[i], lbbox[i], pr_class_words[i], pr_relations[i], engine._settings.class_names) + + if export_all: + save_path = os.path.join(save_dir, 'kvu_results') + create_dir(save_path) + write_to_json(os.path.join(save_path, fname + out_ext), raw_outputs) + # image = Image.open(img_path) + image = np.array(image) + image = draw_kvu_outputs(image, lbbox[i], pr_class_words[i], pr_relations[i], class_names=engine._settings.class_names) + cv2.imwrite(os.path.join(save_path, fname + img_ext), image) + + + if option == "vat": + outputs = export_kvu_for_VAT_invoice(raw_outputs) + elif option == "sbt": + outputs = export_kvu_for_SDSAP(raw_outputs) + elif option == "vtb": + outputs = export_kvu_for_vietin(raw_outputs) + elif option == "manulife": + outputs = export_kvu_for_manulife(raw_outputs) + elif option == "sbt_v2": + outputs = export_kvu_for_SBT(raw_outputs) + else: + outputs = export_kvu_for_all(raw_outputs) + write_to_json(os.path.join(save_dir, fname + out_ext), outputs) + return outputs + + +def process_pdf(pdf_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str, n_pages: int = -1) -> dict: + out_ext = ".json" + fname, pdf_ext = os.path.splitext(os.path.basename(pdf_path)) + img_dirname = '_'.join([os.path.basename(os.path.dirname(pdf_path)), fname]) + img_save_dir = os.path.join(save_dir, img_dirname) + create_dir(img_save_dir) + list_img_files = pdf2img(pdf_path, img_save_dir, n_pages=n_pages, return_fname=True) + outputs = [] + for img_path in list_img_files: + print("=====", os.path.basename(img_path)) + _outputs = process_img(img_path, img_save_dir, engine, export_all=export_all, option=option) + outputs.append(_outputs) + if option == "vat": + outputs = merged_kvu_for_VAT_invoice_for_multi_pages(outputs) + elif option == "sbt": + outputs = merged_kvu_for_SDSAP_for_multi_pages(outputs) + elif option == "vtb": + outputs = merged_kvu_for_vietin_for_multi_pages(outputs) + elif option == "manulife": + outputs = merged_kvu_for_manulife_for_multi_pages(outputs) + elif option == "sbt_v2": + outputs = merged_kvu_for_SBT_for_multi_pages(outputs) + else: + outputs = merged_kvu_for_all_for_multi_pages(outputs) + write_to_json(os.path.join(save_dir, fname + out_ext), outputs) + return outputs + + +def process_dir(dir_path: str, save_dir: str, engine: KVUEngine, export_all: bool, option: str, dir_level: int = 0) -> None: + list_images = [] + for ext in ['JPG', 'PNG', 'jpeg', 'jpg', 'png', 'pdf']: + list_images += glob.glob(os.path.join(dir_path, f"{'*/'*dir_level}*.{ext}")) + print('No. 
images:', len(list_images)) + for file_path in tqdm(list_images): + if os.path.splitext(file_path)[1] == ".pdf": + outputs = process_pdf(file_path, save_dir, engine, export_all=export_all, option=option, n_pages=-1) + else: + outputs = process_img(file_path, save_dir, engine, export_all=export_all, option=option) + + +def Predictor_KVU(img: str, save_dir: str, engine: KVUEngine) -> dict: + curr_datetime = datetime.now().strftime('%Y-%m-%d %H-%M-%S') + image_path = "/home/thucpd/thucpd/PV2-2023/tmp_image/{}.jpg".format(curr_datetime) + cv2.imwrite(image_path, img) + vat_outputs = process_img(image_path, save_dir, engine, export_all=False, option="vat") + return vat_outputs + + +if __name__ == "__main__": + args = get_args() + engine = load_engine(args.kvu_params) + # vat_outputs = process_img(args.img_dir, args.save_dir, engine, export_all=True, option="vat") + # vat_outputs = process_pdf(args.img_dir, args.save_dir, engine, export_all=True, option="vat") + process_dir(args.img_dir, args.save_dir, engine, export_all=args.export_img, option=args.doc_type) + print('[INFO] Done') diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/__init__.py new file mode 100644 index 0000000..d4e48aa --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/__init__.py @@ -0,0 +1,45 @@ + +import os +import torch +from sdsvkvu.model.kvu_model import KVUModel +from sdsvkvu.model.combined_model import ComKVUModel +from sdsvkvu.model.document_kvu_model import DocKVUModel +from sdsvkvu.model.sbt_model import SBTModel + +def get_model(cfg): + if cfg.mode == 0 or cfg.mode == 1: + model = ComKVUModel(cfg=cfg) + elif cfg.mode == 2: + model = KVUModel(cfg=cfg) + elif cfg.mode == 3: + model = DocKVUModel(cfg=cfg) + elif cfg.mode == 4: + model = SBTModel(cfg=cfg) + else: + raise ValueError(f'[ERROR] Model mode of {cfg.mode} is not supported') + return model + +def load_checkpoint(ckpt_path, model, key_include): + assert os.path.exists(ckpt_path) == True, f"Ckpt path at {ckpt_path} not exist!" + state_dict = torch.load(ckpt_path, 'cpu')['state_dict'] + for key in list(state_dict.keys()): + if f'.{key_include}.' not in key: + del state_dict[key] + else: + state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key] # remove net.something. 
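+ # the prefixed entry is deleted below, so only the stripped key created above remains in the state dict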
+ del state_dict[key] + model.load_state_dict(state_dict, strict=True) + print(f"Load checkpoint at {ckpt_path}") + return model + +def load_model_weight(net, pretrained_model_file): + pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[ + "state_dict" + ] + new_state_dict = {} + for k, v in pretrained_model_state_dict.items(): + new_k = k + if new_k.startswith("net."): + new_k = new_k[len("net.") :] + new_state_dict[new_k] = v + net.load_state_dict(new_state_dict) \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/combined_model.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/combined_model.py new file mode 100644 index 0000000..5c32797 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/combined_model.py @@ -0,0 +1,71 @@ +import os +import torch +from torch import nn + +from sdsvkvu.model.kvu_model import KVUModel +# from model import load_checkpoint + + +class ComKVUModel(KVUModel): + def __init__(self, cfg): + super().__init__(cfg) + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.finetune_only = cfg.train.finetune_only + + self._get_backbones(self.model_cfg.backbone) + self._create_head() + + # if os.path.exists(self.model_cfg.ckpt_model_file): + # self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm') + # self.itc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.itc_layer, 'itc_layer') + # self.stc_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.stc_layer, 'stc_layer') + # self.relation_layer = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer, 'relation_layer') + # self.relation_layer_from_key = load_checkpoint(self.model_cfg.ckpt_model_file, self.relation_layer_from_key, 'relation_layer_from_key') + + self.loss_func = nn.CrossEntropyLoss() + + # if self.freeze: + # for name, param in self.named_parameters(): + # if 'backbone' in name: + # param.requires_grad = False + # if self.finetune_only == 'EE': + # for name, param in self.named_parameters(): + # if 'itc_layer' not in name and 'stc_layer' not in name: + # param.requires_grad = False + # if self.finetune_only == 'EL': + # for name, param in self.named_parameters(): + # if 'relation_layer' not in name or 'relation_layer_from_key' in name: + # param.requires_grad = False + # if self.finetune_only == 'ELK': + # for name, param in self.named_parameters(): + # if 'relation_layer_from_key' not in name: + # param.requires_grad = False + + + def forward(self, batch): + image = batch["image"] + input_ids_layoutxlm = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask_layoutxlm = batch["attention_mask_layoutxlm"] + + backbone_outputs_layoutxlm = self.backbone_layoutxlm( + image=image, input_ids=input_ids_layoutxlm, bbox=bbox, attention_mask=attention_mask_layoutxlm) + + last_hidden_states = backbone_outputs_layoutxlm.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0) + head_outputs = {"itc_outputs": itc_outputs, "stc_outputs": stc_outputs, + "el_outputs": el_outputs, "el_outputs_from_key": el_outputs_from_key} + + loss 
= 0.0 + if any(['labels' in key for key in batch.keys()]): + loss = self._get_loss(head_outputs, batch) + + return head_outputs, loss + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/document_kvu_model.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/document_kvu_model.py new file mode 100644 index 0000000..b278fb5 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/document_kvu_model.py @@ -0,0 +1,162 @@ +import torch +from torch import nn +from sdsvkvu.model.relation_extractor import RelationExtractor +from sdsvkvu.model.kvu_model import KVUModel +# from model import load_checkpoint + + +class DocKVUModel(KVUModel): + def __init__(self, cfg): + super().__init__(cfg) + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.train_cfg = cfg.train + self.n_classes = len(self.model_cfg.class_names) + + self._get_backbones(self.model_cfg.backbone) + + self._create_head() + + self.loss_func = nn.CrossEntropyLoss() + + def _create_head(self): + self.backbone_hidden_size = self.backbone_config.hidden_size + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + # self.n_classes = self.model_cfg.n_classes + 1 + self.repr_hiddent_size = self.backbone_hidden_size + + # (1) Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # Classfication Layer for whole document + # (1) Initial token classification + self.itc_layer_document = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + # (3) Linking token classification + self.relation_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + # (4) Linking token classification + self.relation_layer_from_key_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + self.relation_layer_from_key.apply(self._init_weight) + + self.itc_layer_document.apply(self._init_weight) + 
self.stc_layer_document.apply(self._init_weight) + self.relation_layer_document.apply(self._init_weight) + self.relation_layer_from_key_document.apply(self._init_weight) + + + def forward(self, batches): + head_outputs_list = [] + loss = 0.0 + for batch in batches["windows"]: + image = batch["image"] + input_ids = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask = batch["attention_mask_layoutxlm"] + + if self.freeze: + for param in self.backbone.parameters(): + param.requires_grad = False + + if self.model_cfg.backbone == 'layoutxlm': + backbone_outputs = self.backbone( + image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask + ) + else: + backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask) + + last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key(last_hidden_states, last_hidden_states).squeeze(0) + + window_repr = last_hidden_states.transpose(0, 1).contiguous() + + head_outputs = {"window_repr": window_repr, + "itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + head_outputs_list.append(head_outputs) + + batch = batches["documents"] + + document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1) + document_repr = document_repr.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0) + el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key_document(document_repr, document_repr).squeeze(0) + + head_outputs = {"itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + return head_outputs, loss + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/kvu_model.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/kvu_model.py new file mode 100644 index 0000000..0277d1b --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/kvu_model.py @@ -0,0 +1,300 @@ +import os +import torch +from torch import nn +from pathlib import Path +from transformers import ( + LayoutLMConfig, + LayoutLMModel, + LayoutLMTokenizer, +) +from transformers import ( + LayoutLMv2Config, + LayoutLMv2Model, + LayoutLMv2FeatureExtractor, + LayoutXLMTokenizer, +) +from transformers import ( + XLMRobertaConfig, + AutoTokenizer, + XLMRobertaModel +) + +# from model import load_checkpoint +from sdsvkvu.sources.utils import merged_token_embeddings +from sdsvkvu.model.relation_extractor import RelationExtractor + + +class KVUModel(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.finetune_only = cfg.train.finetune_only + self.n_classes = len(self.model_cfg.class_names) + + 
self._get_backbones(self.model_cfg.backbone) + self._create_head() + + # if (cfg.stage == 2) and (os.path.exists(self.model_cfg.ckpt_model_file)): + # self.backbone_layoutxlm = load_checkpoint(self.model_cfg.ckpt_model_file, self.backbone_layoutxlm, 'backbone_layoutxlm') + + self._create_head() + self.loss_func = nn.CrossEntropyLoss() + + if self.freeze: + for name, param in self.named_parameters(): + if "backbone" in name: + param.requires_grad = False + + def _create_head(self): + self.backbone_hidden_size = 768 + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + # self.n_classes = self.model_cfg.n_classes + 1 + + # (1) Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, # 1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, # 1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (4) Linking token classification + self.relation_layer_from_key = RelationExtractor( + n_relations=1, # 1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + + # def _get_backbones(self, config_type): + # self.tokenizer_layoutxlm = LayoutXLMTokenizer.from_pretrained('microsoft/layoutxlm-base') + # self.feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False) + # self.backbone_layoutxlm = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base') + + def _get_backbones(self, config_type): + configs = { + "layoutlm": { + "config": LayoutLMConfig, + "tokenizer": LayoutLMTokenizer, + "backbone": LayoutLMModel, + "feature_extrator": LayoutLMv2FeatureExtractor, + }, + "layoutxlm": { + "config": LayoutLMv2Config, + "tokenizer": LayoutXLMTokenizer, + "backbone": LayoutLMv2Model, + "feature_extrator": LayoutLMv2FeatureExtractor, + }, + "xlm-roberta": { + "config": XLMRobertaConfig, + "tokenizer": AutoTokenizer, + "backbone": XLMRobertaModel, + "feature_extrator": LayoutLMv2FeatureExtractor, + }, + } + + self.backbone_config = configs[config_type]["config"].from_pretrained( + self.model_cfg.pretrained_model_path + ) + if config_type != "xlm-roberta": + self.tokenizer = configs[config_type]["tokenizer"].from_pretrained( + self.model_cfg.pretrained_model_path + ) + else: + self.tokenizer = configs[config_type]["tokenizer"].from_pretrained( + self.model_cfg.pretrained_model_path, use_fast=False + ) + self.feature_extractor = configs[config_type]["feature_extrator"]( + apply_ocr=False + ) + self.backbone = configs[config_type]["backbone"].from_pretrained( + self.model_cfg.pretrained_model_path + ) + + @staticmethod + def _init_weight(module): + init_std = 0.02 + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, 0.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.LayerNorm): + nn.init.normal_(module.weight, 
1.0, init_std) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + + def forward(self, lbatches): + windows = lbatches["windows"] + token_embeddings_windows = [] + lvalids = [] + loverlaps = [] + + for i, batch in enumerate(windows): + batch = { + k: v.cuda() for k, v in batch.items() if k not in ("img_path", "words") + } + image = batch["image"] + input_ids_layoutxlm = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask_layoutxlm = batch["attention_mask_layoutxlm"] + + backbone_outputs_layoutxlm = self.backbone_layoutxlm( + image=image, + input_ids=input_ids_layoutxlm, + bbox=bbox, + attention_mask=attention_mask_layoutxlm, + ) + + last_hidden_states_layoutxlm = backbone_outputs_layoutxlm.last_hidden_state[ + :, :512, : + ] + + lvalids.append(batch["len_valid_tokens"]) + loverlaps.append(batch["len_overlap_tokens"]) + token_embeddings_windows.append(last_hidden_states_layoutxlm) + + token_embeddings = merged_token_embeddings( + token_embeddings_windows, loverlaps, lvalids, average=False + ) + + token_embeddings = token_embeddings.transpose(0, 1).contiguous().cuda() + itc_outputs = self.itc_layer(token_embeddings).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(token_embeddings, token_embeddings).squeeze(0) + el_outputs = self.relation_layer(token_embeddings, token_embeddings).squeeze(0) + el_outputs_from_key = self.relation_layer_from_key( + token_embeddings, token_embeddings + ).squeeze(0) + head_outputs = { + "itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs, + "el_outputs_from_key": el_outputs_from_key, + "embedding_tokens": token_embeddings.transpose(0, 1) + .contiguous() + .detach() + .cpu() + .numpy(), + } + + loss = 0.0 + if any(["labels" in key for key in lbatches.keys()]): + labels = { + k: v.cuda() + for k, v in lbatches["documents"].items() + if k not in ("img_path") + } + loss = self._get_loss(head_outputs, labels) + + return head_outputs, loss + + def _get_loss(self, head_outputs, batch): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + itc_loss = self._get_itc_loss(itc_outputs, batch) + stc_loss = self._get_stc_loss(stc_outputs, batch) + el_loss = self._get_el_loss(el_outputs, batch) + el_loss_from_key = self._get_el_loss(el_outputs_from_key, batch, from_key=True) + + loss = itc_loss + stc_loss + el_loss + el_loss_from_key + + return loss + + def _get_itc_loss(self, itc_outputs, batch): + itc_mask = batch["are_box_first_tokens"].view(-1) + + itc_logits = itc_outputs.view(-1, self.model_cfg.n_classes + 1) + itc_logits = itc_logits[itc_mask] + + itc_labels = batch["itc_labels"].view(-1) + itc_labels = itc_labels[itc_mask] + + itc_loss = self.loss_func(itc_logits, itc_labels) + + return itc_loss + + def _get_stc_loss(self, stc_outputs, batch): + inv_attention_mask = 1 - batch["attention_mask_layoutxlm"] + + bsz, max_seq_length = inv_attention_mask.shape + device = inv_attention_mask.device + + invalid_token_mask = torch.cat( + [inv_attention_mask, torch.zeros([bsz, 1]).to(device)], axis=1 + ).bool() + + stc_outputs.masked_fill_(invalid_token_mask[:, None, :], -10000.0) + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + stc_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + stc_mask = batch["attention_mask_layoutxlm"].view(-1).bool() + stc_logits = stc_outputs.view(-1, max_seq_length + 1) + 
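# Keep only the positions covered by the attention mask before computing the loss. +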
stc_logits = stc_logits[stc_mask] + + stc_labels = batch["stc_labels"].view(-1) + stc_labels = stc_labels[stc_mask] + + stc_loss = self.loss_func(stc_logits, stc_labels) + + return stc_loss + + def _get_el_loss(self, el_outputs, batch, from_key=False): + bsz, max_seq_length = batch["attention_mask_layoutxlm"].shape + + device = batch["attention_mask_layoutxlm"].device + + self_token_mask = ( + torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + ) + + box_first_token_mask = torch.cat( + [ + (batch["are_box_first_tokens"] == False), + torch.zeros([bsz, 1], dtype=torch.bool).to(device), + ], + axis=1, + ) + el_outputs.masked_fill_(box_first_token_mask[:, None, :], -10000.0) + el_outputs.masked_fill_(self_token_mask[None, :, :], -10000.0) + + mask = batch["are_box_first_tokens"].view(-1) + + logits = el_outputs.view(-1, max_seq_length + 1) + logits = logits[mask] + + if from_key: + el_labels = batch["el_labels_from_key"] + else: + el_labels = batch["el_labels"] + labels = el_labels.view(-1) + labels = labels[mask] + + loss = self.loss_func(logits, labels) + return loss diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/relation_extractor.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/relation_extractor.py new file mode 100644 index 0000000..40a169e --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/relation_extractor.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + + +class RelationExtractor(nn.Module): + def __init__( + self, + n_relations, + backbone_hidden_size, + head_hidden_size, + head_p_dropout=0.1, + ): + super().__init__() + + self.n_relations = n_relations + self.backbone_hidden_size = backbone_hidden_size + self.head_hidden_size = head_hidden_size + self.head_p_dropout = head_p_dropout + + self.drop = nn.Dropout(head_p_dropout) + self.q_net = nn.Linear( + self.backbone_hidden_size, self.n_relations * self.head_hidden_size + ) + + self.k_net = nn.Linear( + self.backbone_hidden_size, self.n_relations * self.head_hidden_size + ) + + self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size)) + nn.init.normal_(self.dummy_node) + + def forward(self, h_q, h_k): + h_q = self.q_net(self.drop(h_q)) + + dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1) + h_k = torch.cat([h_k, dummy_vec], axis=0) + h_k = self.k_net(self.drop(h_k)) + + head_q = h_q.view( + h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size + ) + head_k = h_k.view( + h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size + ) + + relation_score = torch.einsum("ibnd,jbnd->nbij", (head_q, head_k)) + + return relation_score diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/sbt_model.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/sbt_model.py new file mode 100644 index 0000000..736d223 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/sbt_model.py @@ -0,0 +1,156 @@ +import torch +from torch import nn +from transformers import LayoutLMConfig, LayoutLMModel, LayoutLMTokenizer, LayoutLMv2FeatureExtractor +from transformers import LayoutLMv2Config, LayoutLMv2Model +from sdsvkvu.model.relation_extractor import RelationExtractor +from sdsvkvu.model.kvu_model import KVUModel +# from utils import load_checkpoint + + +class SBTModel(KVUModel): + def __init__(self, cfg): + super().__init__(cfg=cfg) + + self.model_cfg = cfg.model + self.freeze = cfg.train.freeze + self.train_cfg = cfg.train + self.n_classes = len(self.model_cfg.class_names) + + self._get_backbones(self.model_cfg.backbone) + self._create_head() + + 
self.loss_func = nn.CrossEntropyLoss() + + + def _create_head(self): + self.backbone_hidden_size = self.backbone_config.hidden_size + self.head_hidden_size = self.model_cfg.head_hidden_size + self.head_p_dropout = self.model_cfg.head_p_dropout + # self.n_classes = self.model_cfg.n_classes + 1 + # self.relations = self.model_cfg.n_relations + self.repr_hiddent_size = self.backbone_hidden_size + + # (1) Initial token classification + self.itc_layer = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.backbone_hidden_size, self.n_classes), + ) + # (2) Subsequent token classification + self.stc_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # (3) Linking token classification + self.relation_layer = RelationExtractor( + n_relations=1, #1 + backbone_hidden_size=self.backbone_hidden_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + # Classfication Layer for whole document + self.itc_layer_document = nn.Sequential( + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.repr_hiddent_size), + nn.Dropout(self.head_p_dropout), + nn.Linear(self.repr_hiddent_size, self.n_classes), + ) + + self.stc_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.relation_layer_document = RelationExtractor( + n_relations=1, + backbone_hidden_size=self.repr_hiddent_size, + head_hidden_size=self.head_hidden_size, + head_p_dropout=self.head_p_dropout, + ) + + self.itc_layer.apply(self._init_weight) + self.stc_layer.apply(self._init_weight) + self.relation_layer.apply(self._init_weight) + + self.itc_layer_document.apply(self._init_weight) + self.stc_layer_document.apply(self._init_weight) + self.relation_layer_document.apply(self._init_weight) + + + + def forward(self, batches): + head_outputs_list = [] + loss = 0. 
+ for batch in batches["windows"]: + image = batch["image"] + input_ids = batch["input_ids_layoutxlm"] + bbox = batch["bbox"] + attention_mask = batch["attention_mask_layoutxlm"] + + if self.freeze: + for param in self.backbone.parameters(): + param.requires_grad = False + + if self.model_cfg.backbone == 'layoutxlm': + backbone_outputs = self.backbone( + image=image, input_ids=input_ids, bbox=bbox, attention_mask=attention_mask + ) + else: + backbone_outputs = self.backbone(input_ids, attention_mask=attention_mask) + + last_hidden_states = backbone_outputs.last_hidden_state[:, :512, :] + last_hidden_states = last_hidden_states.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer(last_hidden_states).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer(last_hidden_states, last_hidden_states).squeeze(0) + el_outputs = self.relation_layer(last_hidden_states, last_hidden_states).squeeze(0) + + window_repr = last_hidden_states.transpose(0, 1).contiguous() + + head_outputs = {"window_repr": window_repr, + "itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs,} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + head_outputs_list.append(head_outputs) + + batch = batches["documents"] + + document_repr = torch.cat([w['window_repr'] for w in head_outputs_list], dim=1) + document_repr = document_repr.transpose(0, 1).contiguous() + + itc_outputs = self.itc_layer_document(document_repr).transpose(0, 1).contiguous() + stc_outputs = self.stc_layer_document(document_repr, document_repr).squeeze(0) + el_outputs = self.relation_layer_document(document_repr, document_repr).squeeze(0) + + head_outputs = {"itc_outputs": itc_outputs, + "stc_outputs": stc_outputs, + "el_outputs": el_outputs} + + if any(['labels' in key for key in batch.keys()]): + loss += self._get_loss(head_outputs, batch) + + return head_outputs, loss + + def _get_loss(self, head_outputs, batch): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + + itc_loss = self._get_itc_loss(itc_outputs, batch) + stc_loss = self._get_stc_loss(stc_outputs, batch) + el_loss = self._get_el_loss(el_outputs, batch, from_key=False) + + loss = itc_loss + stc_loss + el_loss + + return loss diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/predictor.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/predictor.py new file mode 100644 index 0000000..2861a31 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/predictor.py @@ -0,0 +1,225 @@ +import torch +from pathlib import Path +from omegaconf import OmegaConf + +import os +from sdsvkvu.sources.utils import parse_initial_words, parse_subsequent_words, parse_relations +from sdsvkvu.model import get_model, load_model_weight + + +class KVUPredictor: + def __init__(self, configs): + self.mode = configs.mode + self.device = configs.device + self.pretrained_model_path = configs.model.pretrained_model_path + net, cfg = self._load_model(configs.model.config, + configs.model.checkpoint) + + self.model = net + self.class_names = cfg.model.class_names + self.max_seq_length = cfg.train.max_seq_length + self.backbone_type = cfg.model.backbone + + if self.mode in (3, 4): + self.slice_interval = 0 + self.window_size = cfg.train.window_size + 
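The document-level heads above run on the concatenation of per-window token representations. A toy illustration of that stitching with invented sizes (three windows of 512 tokens, hidden size 768):

    import torch

    window_reprs = [torch.randn(1, 512, 768) for _ in range(3)]      # one window_repr per window
    document_repr = torch.cat(window_reprs, dim=1)                    # (1, 1536, 768)
    document_repr = document_repr.transpose(0, 1).contiguous()        # (1536, 1, 768) for the document heads
    print(document_repr.shape)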
self.max_window_count = cfg.train.max_window_count + self.dummy_idx = self.max_seq_length * self.max_window_count + + else: + self.slice_interval = cfg.train.slice_interval + self.window_size = cfg.train.max_num_words + self.max_window_count = 1 + if self.mode == 2: + self.dummy_idx = 0 # dynamic dummy + else: + self.dummy_idx = self.max_seq_length # 512 + + + def get_process_configs(self): + _settings = { + # "tokenizer_layoutxlm": self.model.tokenizer_layoutxlm, + # "feature_extractor": self.model.feature_extractor, + "class_names": self.class_names, + "backbone_type": self.backbone_type, + "window_size": self.window_size, + "slice_interval": self.slice_interval, + "max_window_count": self.max_window_count, + "max_seq_length": self.max_seq_length, + "device": self.device, + "mode": self.mode + } + + feature_extractor = self.model.feature_extractor + if self.mode in (3, 4): + tokenizer_layoutxlm = self.model.tokenizer + else: + tokenizer_layoutxlm = self.model.tokenizer_layoutxlm + + return OmegaConf.create(_settings), tokenizer_layoutxlm, feature_extractor + + + def _load_model(self, cfg_path, ckpt_path): + cfg = OmegaConf.load(cfg_path) + + if self.pretrained_model_path is not None and os.path.exists(self.pretrained_model_path): + cfg.model.pretrained_model_path = self.pretrained_model_path + print("[INFO] Load pretrained backbone at:", cfg.model.pretrained_model_path) + + cfg.mode = self.mode + net = get_model(cfg) + load_model_weight(net, ckpt_path) + net.to(self.device) + net.eval() + return net, cfg + + def predict(self, input_sample): + if self.mode == 0: # Normal + bbox, lwords, pr_class_words, pr_relations = self.com_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 1: # Full - tokens + bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 2: # Sliding + bbox, lwords, pr_class_words, pr_relations = [], [], [], [] + for window in input_sample['windows']: + _bbox, _lwords, _pr_class_words, _pr_relations = self.com_predict(window) + bbox.append(_bbox) + lwords.append(_lwords) + pr_class_words.append(_pr_class_words) + pr_relations.append(_pr_relations) + return bbox, lwords, pr_class_words, pr_relations + + elif self.mode == 3: # Document + bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + + elif self.mode == 4: # SBT + bbox, lwords, pr_class_words, pr_relations = self.sbt_predict(input_sample) + return [bbox], [lwords], [pr_class_words], [pr_relations] + else: + raise ValueError(f"Not supported mode: {self.mode }") + + def doc_predict(self, input_sample): + lwords = input_sample['documents']['words'] + for idx, window in enumerate(input_sample['windows']): + input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')} + + with torch.no_grad(): + head_outputs, _ = self.model(input_sample) + + input_sample = input_sample['documents'] + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + # input_sample = {k: v.detach().cpu() for k, v in input_sample.items()} + + bbox = input_sample['bbox'].squeeze(0) + pr_class_words, pr_relations = self.kvu_parser(input_sample, head_outputs) + + return bbox, lwords, pr_class_words, pr_relations + + + def com_predict(self, input_sample): + lwords = input_sample['words'] + input_sample = {k: v.unsqueeze(0) for k, v in 
input_sample.items() if k not in ('words', 'img_path')} + input_sample = {k: v.to(self.device) for k, v in input_sample.items()} + + with torch.no_grad(): + head_outputs, _ = self.model(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + input_sample = {k: v.detach().cpu() for k, v in input_sample.items()} + + + bbox = input_sample['bbox'].squeeze(0) + pr_class_words, pr_relations = self.kvu_parser(input_sample, head_outputs) + + return bbox, lwords, pr_class_words, pr_relations + + + def cat_predict(self, input_sample): + lwords = input_sample['documents']['words'] + inputs = [] + for window in input_sample['windows']: + inputs.append({k: v.unsqueeze(0).cuda() for k, v in window.items() if k not in ('words', 'img_path')}) + input_sample['windows'] = inputs + + with torch.no_grad(): + head_outputs, _ = self.model(input_sample) + + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items() if k not in ('embedding_tokens')} + + + input_sample = {k: v.unsqueeze(0) for k, v in input_sample["documents"].items()} + + bbox = input_sample['bbox'].squeeze(0) + self.dummy_idx = bbox.shape[0] + pr_class_words, pr_relations = self.kvu_parser(input_sample, head_outputs) + return bbox, lwords, pr_class_words, pr_relations + + + def kvu_parser(self, input_sample, head_outputs): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + el_outputs_from_key = head_outputs["el_outputs_from_key"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + pr_el_from_key = torch.argmax(el_outputs_from_key, -1).squeeze(0) + + box_first_token_mask = input_sample['are_box_first_tokens'].squeeze(0) + attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, self.dummy_idx + ) + + pr_relations_from_header = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx) + pr_relations_from_key = parse_relations(pr_el_from_key, box_first_token_mask, self.dummy_idx) + pr_relations = pr_relations_from_header | pr_relations_from_key + + return pr_class_words, pr_relations + + + def sbt_predict(self, input_sample): + lwords = input_sample['documents']['words'] + for idx, window in enumerate(input_sample['windows']): + input_sample['windows'][idx] = {k: v.unsqueeze(0).to(self.device) for k, v in window.items() if k not in ('words', 'n_empty_windows')} + + with torch.no_grad(): + head_outputs, _ = self.model(input_sample) + + input_sample = input_sample['documents'] + head_outputs = {k: v.detach().cpu() for k, v in head_outputs.items()} + # input_sample = {k: v.detach().cpu() for k, v in input_sample.items()} + + bbox = input_sample['bbox'].squeeze(0) + pr_class_words, pr_relations = self.sbt_parser(input_sample, head_outputs) + + return bbox, lwords, pr_class_words, pr_relations + + + def sbt_parser(self, input_sample, head_outputs): + itc_outputs = head_outputs["itc_outputs"] + stc_outputs = head_outputs["stc_outputs"] + el_outputs = head_outputs["el_outputs"] + + pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0) + pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0) + pr_el_label = torch.argmax(el_outputs, -1).squeeze(0) + + box_first_token_mask = 
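kvu_parser reduces each head's scores to discrete predictions with a plain argmax before handing them to the parsing utilities; a toy decode with random scores:

    import torch

    seq_len, n_classes = 8, 4
    itc_outputs = torch.randn(1, seq_len, n_classes)          # initial-token class scores
    stc_outputs = torch.randn(1, seq_len, seq_len + 1)        # "previous token" pointers (+1 dummy column)

    pr_itc_label = torch.argmax(itc_outputs, -1).squeeze(0)   # (seq_len,) class id per token
    pr_stc_label = torch.argmax(stc_outputs, -1).squeeze(0)   # (seq_len,) pointer per token
    print(pr_itc_label.tolist(), pr_stc_label.tolist())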
input_sample['are_box_first_tokens'].squeeze(0) + attention_mask = input_sample['attention_mask_layoutxlm'].squeeze(0) + + pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, self.class_names) + pr_class_words = parse_subsequent_words( + pr_stc_label, attention_mask, pr_init_words, self.dummy_idx + ) + + pr_relations = parse_relations(pr_el_label, box_first_token_mask, self.dummy_idx) + + return pr_class_words, pr_relations \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/preprocess.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/preprocess.py new file mode 100644 index 0000000..c499a61 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/preprocess.py @@ -0,0 +1,479 @@ +import torch +import itertools +import numpy as np + +from sdsvkvu.sources.utils import sliding_windows + + +class KVUProcessor: + def __init__( + self, + tokenizer_layoutxlm, + feature_extractor, + backbone_type, + class_names, + slice_interval, + window_size, + max_seq_length, + mode, + **kwargs, + ): + self.mode = mode + self.class_names = class_names + self.backbone_type = backbone_type + + self.window_size = window_size + self.slice_interval = slice_interval + self.max_seq_length = max_seq_length + + self.tokenizer_layoutxlm = tokenizer_layoutxlm + self.feature_extractor = feature_extractor + + self.pad_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids( + tokenizer_layoutxlm._pad_token + ) + self.cls_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids( + tokenizer_layoutxlm._cls_token + ) + self.sep_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids( + tokenizer_layoutxlm._sep_token + ) + self.unk_token_id_layoutxlm = tokenizer_layoutxlm.convert_tokens_to_ids( + self.tokenizer_layoutxlm._unk_token + ) + + self.class_idx_dic = dict( + [(class_name, idx) for idx, class_name in enumerate(self.class_names)] + ) + + def __call__(self, lbboxes: list, lwords: list, image, width, height) -> dict: + image = torch.from_numpy( + self.feature_extractor(image)["pixel_values"][0].copy() + ) + + bbox_windows = sliding_windows(lbboxes, self.window_size, self.slice_interval) + word_windows = sliding_windows(lwords, self.window_size, self.slice_interval) + assert len(bbox_windows) == len(word_windows), \ + f"Shape of lbboxes and lwords after sliding window is not the same {len(bbox_windows)} # {len(word_windows)}" + + if self.mode == 0: # First 512 tokens + output = self.preprocess_window( + bounding_boxes=lbboxes, + words=lwords, + image_features={"image": image, "width": width, "height": height}, + max_seq_length=self.max_seq_length, + ) + elif self.mode == 1: # Get full tokens + output = {} + windows = [] + for i in range(len(bbox_windows)): + windows.append( + self.preprocess_window( + bounding_boxes=bbox_windows[i], + words=word_windows[i], + image_features={"image": image, "width": width, "height": height}, + max_seq_length=self.max_seq_length, + ) + ) + + output["windows"] = windows + elif self.mode == 2: # Sliding window + output = {} + windows = [] + output["doduments"] = self.preprocess_window( + bounding_boxes=lbboxes, + words=lwords, + image_features={"image": image, "width": width, "height": height}, + max_seq_length=2048, + ) + for i in range(len(bbox_windows)): + windows.append( + self.preprocess( + bounding_boxes=bbox_windows[i], + words=word_windows[i], + image_features={"image": image, "width": width, "height": height}, + max_seq_length=self.max_seq_length, + ) + ) + + output["windows"] = windows + else: + 
raise ValueError(f"Not supported mode: {self.mode }") + return output + + def preprocess_window(self, bounding_boxes, words, image_features, max_seq_length): + list_word_objects = [] + for bb, text in zip(bounding_boxes, words): + boundingBox = [ + [bb[0], bb[1]], + [bb[2], bb[1]], + [bb[2], bb[3]], + [bb[0], bb[3]], + ] + tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids( + self.tokenizer_layoutxlm.tokenize(text) + ) + list_word_objects.append( + {"layoutxlm_tokens": tokens, "boundingBox": boundingBox, "text": text} + ) + + ( + bbox, + input_ids, + attention_mask, + are_box_first_tokens, + box_to_token_indices, + box2token_span_map, + lwords, + len_valid_tokens, + len_non_overlap_tokens, + len_list_tokens, + ) = self.parser_words( + list_word_objects, + self.max_seq_length, + image_features["width"], + image_features["height"], + ) + + assert len_list_tokens == len_valid_tokens + 2 + len_overlap_tokens = len_valid_tokens - len_non_overlap_tokens + + ntokens = max_seq_length if max_seq_length == 512 else len_valid_tokens + 2 + + input_ids = input_ids[:ntokens] + attention_mask = attention_mask[:ntokens] + bbox = bbox[:ntokens] + are_box_first_tokens = are_box_first_tokens[:ntokens] + + input_ids = torch.from_numpy(input_ids) + attention_mask = torch.from_numpy(attention_mask) + bbox = torch.from_numpy(bbox) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + + len_valid_tokens = torch.tensor(len_valid_tokens) + len_overlap_tokens = torch.tensor(len_overlap_tokens) + return_dict = { + "words": lwords, + "len_overlap_tokens": len_overlap_tokens, + "len_valid_tokens": len_valid_tokens, + "image": image_features["image"], + "input_ids_layoutxlm": input_ids, + "attention_mask_layoutxlm": attention_mask, + "are_box_first_tokens": are_box_first_tokens, + "bbox": bbox, + } + return return_dict + + def parser_words(self, words, max_seq_length, width, height): + list_bbs = [] + list_words = [] + list_tokens = [] + cls_bbs = [0.0] * 8 + box2token_span_map = [] + box_to_token_indices = [] + lwords = [""] * max_seq_length + + cum_token_idx = 0 + len_valid_tokens = 0 + len_non_overlap_tokens = 0 + + input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id_layoutxlm + bbox = np.zeros((max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_) + + for word_idx, word in enumerate(words): + this_box_token_indices = [] + + tokens = word["layoutxlm_tokens"] + bb = word["boundingBox"] + text = word["text"] + + len_valid_tokens += len(tokens) + if word_idx < self.slice_interval: + len_non_overlap_tokens += len(tokens) + + if len(tokens) == 0: + tokens.append(self.unk_token_id) + + if len(list_tokens) + len(tokens) > max_seq_length - 2: + break + + box2token_span_map.append( + [len(list_tokens) + 1, len(list_tokens) + len(tokens) + 1] + ) # including st_idx + list_tokens += tokens + + # min, max clipping + for coord_idx in range(4): + bb[coord_idx][0] = max(0.0, min(bb[coord_idx][0], width)) + bb[coord_idx][1] = max(0.0, min(bb[coord_idx][1], height)) + + bb = list(itertools.chain(*bb)) + bbs = [bb for _ in range(len(tokens))] + texts = [text for _ in range(len(tokens))] + + for _ in tokens: + cum_token_idx += 1 + this_box_token_indices.append(cum_token_idx) + + list_bbs.extend(bbs) + list_words.extend(texts) #### + box_to_token_indices.append(this_box_token_indices) + + sep_bbs = [width, height] * 4 + + # For [CLS] and [SEP] + list_tokens = ( + [self.cls_token_id_layoutxlm] + + 
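parser_words expands each axis-aligned OCR box into a 4-point polygon, clips it to the page, and flattens it to 8 values; a standalone check with an invented box that overflows the page width:

    import itertools

    width, height = 100, 50
    bb = [30, 10, 120, 40]                               # (x0, y0, x1, y1), x1 overflows the page
    quad = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
    for corner in quad:                                  # min/max clipping, as in parser_words
        corner[0] = max(0.0, min(corner[0], width))
        corner[1] = max(0.0, min(corner[1], height))
    print(list(itertools.chain(*quad)))                  # [30, 10, 100, 10, 100, 40, 30, 40]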
list_tokens[: max_seq_length - 2] + + [self.sep_token_id_layoutxlm] + ) + if len(list_bbs) == 0: + # When len(json_obj["words"]) == 0 (no OCR result) + list_bbs = [cls_bbs] + [sep_bbs] + else: # len(list_bbs) > 0 + list_bbs = [cls_bbs] + list_bbs[: max_seq_length - 2] + [sep_bbs] + # list_words = ['CLS'] + list_words[: max_seq_length - 2] + ['SEP'] ### + # if len(list_words) < 510: + # list_words.extend(['

' for _ in range(510 - len(list_words))]) + list_words = ( + [self.tokenizer_layoutxlm._cls_token] + + list_words[: max_seq_length - 2] + + [self.tokenizer_layoutxlm._sep_token] + ) + + len_list_tokens = len(list_tokens) + input_ids[:len_list_tokens] = list_tokens + attention_mask[:len_list_tokens] = 1 + + bbox[:len_list_tokens, :] = list_bbs + lwords[:len_list_tokens] = list_words + + # Normalize bbox -> 0 ~ 1 + bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width + bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height + + if self.backbone_type in ("layoutlm", "layoutxlm"): + bbox = bbox[:, [0, 1, 4, 5]] + bbox = bbox * 1000 + bbox = bbox.astype(int) + else: + assert False + + st_indices = [ + indices[0] + for indices in box_to_token_indices + if indices[0] < max_seq_length + ] + are_box_first_tokens[st_indices] = True + + return ( + bbox, + input_ids, + attention_mask, + are_box_first_tokens, + box_to_token_indices, + box2token_span_map, + lwords, + len_valid_tokens, + len_non_overlap_tokens, + len_list_tokens, + ) + + +class DocKVUProcessor(KVUProcessor): + def __init__( + self, + tokenizer_layoutxlm, + feature_extractor, + backbone_type, + class_names, + max_window_count, + slice_interval, + window_size, + max_seq_length, + mode, + **kwargs, + ): + super().__init__( + tokenizer_layoutxlm=tokenizer_layoutxlm, + feature_extractor=feature_extractor, + backbone_type=backbone_type, + class_names=class_names, + slice_interval=slice_interval, + window_size=window_size, + max_seq_length=max_seq_length, + mode=mode, + ) + + self.max_window_count = max_window_count + self.pad_token_id = self.pad_token_id_layoutxlm + self.cls_token_id = self.cls_token_id_layoutxlm + self.sep_token_id = self.sep_token_id_layoutxlm + self.unk_token_id = self.unk_token_id_layoutxlm + self.tokenizer = self.tokenizer_layoutxlm + + def __call__(self, lbboxes: list, lwords: list, images, width, height) -> dict: + image_features = torch.from_numpy( + self.feature_extractor(images)["pixel_values"][0].copy() + ) + output = self.preprocess_document( + bounding_boxes=lbboxes, + words=lwords, + image_features={"image": image_features, "width": width, "height": height}, + max_seq_length=self.max_seq_length, + ) + return output + + def preprocess_document(self, bounding_boxes, words, image_features, max_seq_length): + n_words = len(words) + output_dicts = {"windows": [], "documents": []} + n_empty_windows = 0 + + for i in range(self.max_window_count): + input_ids = np.ones(max_seq_length, dtype=int) * self.pad_token_id + bbox = np.zeros((max_seq_length, 8), dtype=np.float32) + attention_mask = np.zeros(max_seq_length, dtype=int) + are_box_first_tokens = np.zeros(max_seq_length, dtype=np.bool_) + + if n_words == 0: + n_empty_windows += 1 + output_dicts["windows"].append( + { + "image": image_features["image"], + "input_ids_layoutxlm": torch.from_numpy(input_ids), + "bbox": torch.from_numpy(bbox), + "words": [], + "attention_mask_layoutxlm": torch.from_numpy(attention_mask), + "are_box_first_tokens": torch.from_numpy(are_box_first_tokens), + } + ) + continue + + start_word_idx = i * self.window_size + stop_word_idx = min(n_words, (i + 1) * self.window_size) + + if start_word_idx >= stop_word_idx: + n_empty_windows += 1 + output_dicts["windows"].append(output_dicts["windows"][-1]) + continue + + list_word_objects = [] + for bb, text in zip( + bounding_boxes[start_word_idx:stop_word_idx], + words[start_word_idx:stop_word_idx], + ): + boundingBox = [ + [bb[0], bb[1]], + [bb[2], bb[1]], + [bb[2], bb[3]], + [bb[0], bb[3]], + ] + 
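The flattened polygons are normalised by page size and rescaled to the 0-1000 grid LayoutXLM expects, keeping only the top-left and bottom-right corners. A worked example with one invented box on an 800x600 page:

    import numpy as np

    width, height = 800, 600
    bbox = np.array([[40, 30, 200, 30, 200, 60, 40, 60]], dtype=np.float32)   # one flattened quad
    bbox[:, [0, 2, 4, 6]] = bbox[:, [0, 2, 4, 6]] / width
    bbox[:, [1, 3, 5, 7]] = bbox[:, [1, 3, 5, 7]] / height
    bbox = (bbox[:, [0, 1, 4, 5]] * 1000).astype(int)
    print(bbox)                                                               # [[ 50  50 250 100]]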
tokens = self.tokenizer_layoutxlm.convert_tokens_to_ids( + self.tokenizer_layoutxlm.tokenize(text) + ) + list_word_objects.append( + { + "layoutxlm_tokens": tokens, + "boundingBox": boundingBox, + "text": text, + } + ) + + ( + bbox, + input_ids, + attention_mask, + are_box_first_tokens, + box_to_token_indices, + box2token_span_map, + lwords, + len_valid_tokens, + len_non_overlap_tokens, + len_list_layoutxlm_tokens, + ) = self.parser_words( + list_word_objects, + self.max_seq_length, + image_features["width"], + image_features["height"], + ) + + input_ids = torch.from_numpy(input_ids) + bbox = torch.from_numpy(bbox) + attention_mask = torch.from_numpy(attention_mask) + are_box_first_tokens = torch.from_numpy(are_box_first_tokens) + + return_dict = { + "bbox": bbox, + "words": lwords, + "image": image_features["image"], + "input_ids_layoutxlm": input_ids, + "attention_mask_layoutxlm": attention_mask, + "are_box_first_tokens": are_box_first_tokens, + } + output_dicts["windows"].append(return_dict) + + attention_mask = torch.cat( + [o["attention_mask_layoutxlm"] for o in output_dicts["windows"]] + ) + are_box_first_tokens = torch.cat( + [o["are_box_first_tokens"] for o in output_dicts["windows"]] + ) + if n_empty_windows > 0: + attention_mask[ + self.max_seq_length * (self.max_window_count - n_empty_windows) : + ] = torch.from_numpy( + np.zeros(self.max_seq_length * n_empty_windows, dtype=int) + ) + are_box_first_tokens[ + self.max_seq_length * (self.max_window_count - n_empty_windows) : + ] = torch.from_numpy( + np.zeros(self.max_seq_length * n_empty_windows, dtype=np.bool_) + ) + bbox = torch.cat([o["bbox"] for o in output_dicts["windows"]]) + words = [] + for o in output_dicts["windows"]: + words.extend(o["words"]) + + return_dict = { + "bbox": bbox, + "words": words, + "attention_mask_layoutxlm": attention_mask, + "are_box_first_tokens": are_box_first_tokens, + "n_empty_windows": n_empty_windows, + } + output_dicts["documents"] = return_dict + + return output_dicts + + +class SBTProcessor(DocKVUProcessor): + def __init__( + self, + tokenizer_layoutxlm, + feature_extractor, + backbone_type, + class_names, + max_window_count, + slice_interval, + window_size, + max_seq_length, + mode, + **kwargs, + ): + super().__init__( + tokenizer_layoutxlm, + feature_extractor, + backbone_type, + class_names, + max_window_count, + slice_interval, + window_size, + max_seq_length, + mode, + **kwargs, + ) + + def __call__(self, lbboxes: list, lwords: list, images, width, height) -> dict: + return super().__call__(lbboxes, lwords, images, width, height) \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/run_ocr.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/run_ocr.py new file mode 100644 index 0000000..8bce812 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/modules/run_ocr.py @@ -0,0 +1,25 @@ +import numpy as np +from pathlib import Path +from typing import Union, Tuple, List +from sdsvkvu.externals.basic_ocr.src.ocr import OcrEngine + + +def load_ocr_engine(opt) -> OcrEngine: + print("[INFO] Loading engine...") + engine = OcrEngine(**opt) + print("[INFO] Engine loaded") + return engine + + +def process_img(img: Union[str, np.ndarray], engine: OcrEngine) -> List: # For OCR integrated deskew using paddle + page = engine(img) + bboxes = [] + texts = [] + for word_segment in page._word_segments: + for word in word_segment.list_words: + bboxes.append(word.bbox[:]) + texts.append(word.text) + + image = page.deskewed_image if page.deskewed_image is not None else 
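When a document fills fewer than max_window_count windows, preprocess_document repeats the last real window and then zeroes the attention mask (and the box-first-token flags) over the padded tail; the masking step reduces to:

    import torch

    max_seq_length, max_window_count, n_empty_windows = 512, 3, 1
    attention_mask = torch.ones(max_seq_length * max_window_count, dtype=torch.long)
    attention_mask[max_seq_length * (max_window_count - n_empty_windows):] = 0
    print(attention_mask.sum().item())        # 1024; only the two real windows stay attended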
page.image + return bboxes, texts, image + \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/settings.yml b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/settings.yml new file mode 100644 index 0000000..685f9d3 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/settings.yml @@ -0,0 +1,21 @@ +device: "cuda:0" # "cuda:0" +mode: 4 # best option to infer +model: + pretrained_model_path: /mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu/microsoft/layoutxlm-base # default: "" + config: /home/sds/tuanlv/02-KVU/03-KVU_sbt/experiments/key_value_understanding_for_sbt-20231121-085847/base.yaml + checkpoint: /home/sds/tuanlv/02-KVU/03-KVU_sbt/experiments/key_value_understanding_for_sbt-20231121-085847/checkpoints/best_model.pth +ocr_engine: + detector: + # version: /home/sds/datnt/mmdetection/wild_receipt_finetune_weights_c_lite.pth + version: /mnt/hdd4T/OCR/datnt/mmdetection/logs/textdet-baseline-Nov3-wildreceiptv4-sdsapv1-mcocr-ssreceipt1_Imei/epoch_100_params.pth + rotator_version: /home/sds/datnt/mmdetection/logs/textdet-with-rotate-20230317/best_bbox_mAP_epoch_30_lite.pth + recognizer: + version: satrn-lite-general-pretrain-20230106 + deskew: + enable: True + text_detector: + config: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/config/det.yaml + weight: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_PP-OCRv3_det_infer + text_cls: + config: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/config/cls.yaml + weight: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_ppocr_mobile_v2.0_cls_infer diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/kvu.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/kvu.py new file mode 100644 index 0000000..1b560b0 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/kvu.py @@ -0,0 +1,73 @@ +import imagesize +from PIL import Image +from pathlib import Path +from omegaconf import OmegaConf + +from sdsvkvu.modules.predictor import KVUPredictor +from sdsvkvu.modules.preprocess import KVUProcessor, DocKVUProcessor, SBTProcessor +from sdsvkvu.modules.run_ocr import load_ocr_engine, process_img +from sdsvkvu.utils.utils import post_process_basic_ocr +from sdsvkvu.sources.utils import revert_scale_bbox, Timer + +DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml" + + +class KVUEngine: + def __init__(self, setting_file: str = DEFAULT_SETTING_PATH, ocr_engine=None, **kwargs) -> None: + configs = OmegaConf.load(setting_file) + for key, param in kwargs.items(): # overwrite default settings by keyword arguments + if key not in configs: + raise ValueError("Invalid setting found in KVUEngine: ", key) + if isinstance(param, dict): + for k, v in param.items(): + if k not in configs[key]: + raise ValueError("Invalid setting found in KVUEngine: ", key, k) + configs[key][k] = v + else: + configs[key] = param + + self.predictor = KVUPredictor(configs) + self._settings, tokenizer_layoutxlm, feature_extractor = self.predictor.get_process_configs() + mode = self._settings.mode + if mode in (0, 1, 2): + self.processor = KVUProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm, + feature_extractor=feature_extractor, + **self._settings) + elif mode == 3: + self.processor = DocKVUProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm, + 
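KVUEngine merges keyword overrides one level deep into the settings loaded from settings.yml. A minimal sketch of that merge, assuming a local copy of the settings file; the checkpoint path is a placeholder, not a real artifact:

    from omegaconf import OmegaConf

    configs = OmegaConf.load("settings.yml")                      # local copy of the file above
    overrides = {"device": "cpu", "model": {"checkpoint": "/path/to/best_model.pth"}}

    for key, param in overrides.items():
        if isinstance(param, dict):
            for k, v in param.items():
                configs[key][k] = v                               # nested override, as in KVUEngine.__init__
        else:
            configs[key] = param
    print(configs.device, configs.model.checkpoint)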
feature_extractor=feature_extractor, + **self._settings) + elif mode == 4: + self.processor = SBTProcessor(tokenizer_layoutxlm=tokenizer_layoutxlm, + feature_extractor=feature_extractor, + **self._settings) + else: + raise ValueError(f'[ERROR] Inferencing mode of {mode} is not supported') + + if ocr_engine is None: + print("[INFO] Load internal OCR Engine") + configs.ocr_engine.device = configs.device + self.ocr_engine = load_ocr_engine(configs.ocr_engine) + else: + print("[INFO] Load external OCR Engine") + self.ocr_engine = ocr_engine + + def predict(self, img_path): + lbboxes, lwords, image = process_img(img_path, self.ocr_engine) + lwords = post_process_basic_ocr(lwords) + + if len(lbboxes) == 0: + print("[WARNING] Empty document") + return image, [[]], [[]], [[]], [[]] + + height, width, _ = image.shape + image = Image.fromarray(image) + + inputs = self.processor(lbboxes, lwords, image, width=width, height=height) + + with Timer("kvu"): + lbbox, lwords, pr_class_words, pr_relations = self.predictor.predict(inputs) + + for i in range(len(lbbox)): + lbbox[i] = [revert_scale_bbox(bb, width=width, height=height) for bb in lbbox[i]] + return image, lbbox, lwords, pr_class_words, pr_relations \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/utils.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/utils.py new file mode 100644 index 0000000..53bc12e --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/sources/utils.py @@ -0,0 +1,610 @@ +import os +import cv2 +import copy +import time +import torch +import math +import numpy as np +from typing import Callable +from sdsvkvu.utils.post_processing import get_string, get_string_by_deduplicate_bbox, get_string_with_word2line + +# def get_colormap(): +# return { +# 'others': (0, 0, 255), # others: red +# 'title': (0, 255, 255), # title: yellow +# 'key': (255, 0, 0), # key: blue +# 'value': (0, 255, 0), # value: green +# 'header': (233, 197, 15), # header +# 'group': (0, 128, 128), # group +# 'relation': (0, 0, 255)# (128, 128, 128), # relation +# } + + +class Timer: + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, func: Callable, *args): + self.end_time = time.perf_counter() + self.elapsed_time = self.end_time - self.start_time + print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds") + +def get_colormap(): + return { + "others": (0, 0, 255), # others: red + "title": (0, 255, 255), # title: yellow + "key": (255, 0, 0), # key: blue + "value": (0, 255, 0), # value: green + "header": (233, 197, 15), # header + "group": (0, 128, 128), # group + "relation": (0, 0, 255), # (128, 128, 128), # relation + + # "others": (187, 125, 250), # pink + "seller": (183, 50, 255), # bold pink + "date_key": (128, 51, 115), # orange + "date_value": (55, 250, 250), # yellow + "product_name": (245, 61, 61), # blue + "product_code": (233, 197, 17), # header + "quantity": (102, 255, 102), # green + "sn_key": (179, 134, 89), + "sn_value": (51, 153, 204), + "invoice_number_key": (40, 90, 144), + "invoice_number_value": (162, 239, 204), + "sold_key": (74, 180, 150), + "sold_value": (14, 184, 53), + "voucher": (39, 86, 103), + "website": (207, 19, 85), + "hotline": (153, 224, 56), + # "group": (0, 128, 128), # brown + # "relation": (0, 0, 255), # (128, 128, 128), # red + } + +def convert_image(image): + exif = image._getexif() + orientation = None + if exif is not None: + orientation = exif.get(0x0112) + # 
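Assuming the checkpoints referenced in settings.yml are in place, end-to-end usage is short; the image path below is a placeholder, and an image with no OCR output falls back to the empty lists handled above.

    from sdsvkvu.sources.kvu import KVUEngine

    engine = KVUEngine()                     # loads settings.yml and the internal OCR engine
    image, lbbox, lwords, pr_class_words, pr_relations = engine.predict("sample_invoice.jpg")
    print(len(lwords[0]), "words,", len(pr_relations[0]), "predicted links")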
Convert the PIL image to OpenCV format + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + # Rotate the image in OpenCV if necessary + if orientation == 3: + image = cv2.rotate(image, cv2.ROTATE_180) + elif orientation == 6: + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + elif orientation == 8: + image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) + else: + image = np.asarray(image) + + if len(image.shape) == 2: + image = np.repeat(image[:, :, np.newaxis], 3, axis=2) + assert len(image.shape) == 3 + + return image, orientation + +def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1): + # image, orientation = convert_image(image) + + # if orientation is not None and orientation == 6: + # width, height, _ = image.shape + # else: + # height, width, _ = image.shape + + if len(pr_class_words) > 0: + id2label = {k: labels[k] for k in range(len(labels))} + for lb, groups in enumerate(pr_class_words): + if lb == 0: + continue + for group_id, group in enumerate(groups): + for i, word_id in enumerate(group): + # x0, y0, x1, y1 = revert_scale_bbox(bbox[word_id], width, height) + x0, y0, x1, y1 = bbox[word_id] + cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness) + + if i == 0: + x_center0, y_center0 = int((x0+x1)/2), int((y0+y1)/2) + else: + x_center1, y_center1 = int((x0+x1)/2), int((y0+y1)/2) + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness) + x_center0, y_center0 = x_center1, y_center1 + + if len(pr_relations) > 0: + for pair in pr_relations: + # xyxy0 = revert_scale_bbox(bbox[pair[0]], width, height) + # xyxy1 = revert_scale_bbox(bbox[pair[1]], width, height) + xyxy0 = bbox[pair[0]] + xyxy1 = bbox[pair[1]] + + x_center0, y_center0 = int((xyxy0[0] + xyxy0[2])/2), int((xyxy0[1] + xyxy0[3])/2) + x_center1, y_center1 = int((xyxy1[0] + xyxy1[2])/2), int((xyxy1[1] + xyxy1[3])/2) + + cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness) + + return image + +def revert_scale_bbox(box, width, height): + return [ + int((box[0] / 1000) * width), + int((box[1] / 1000) * height), + int((box[2] / 1000) * width), + int((box[3] / 1000) * height) + ] + + +def draw_kvu_outputs(image: np.ndarray, bbox: list, pr_class_words: list, pr_relations: list, class_names: list = ['others', 'title', 'key', 'value', 'header'], thickness: int = 1): + color_map = get_colormap() + image = visualize(image, bbox, pr_class_words, pr_relations, color_map, class_names, thickness) + if (image.shape[2] == 2): + image = cv2.cvtColor(image, cv2.COLOR_BGR5652BGR) + return cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + +def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list: + element_windows = [] + + if len(elements) > window_size: + max_step = math.ceil((len(elements) - window_size)/slice_interval) + + for i in range(0, max_step + 1): + if (i*slice_interval+window_size) >= len(elements): + _window = copy.deepcopy(elements[i*slice_interval:]) + else: + _window = copy.deepcopy(elements[i*slice_interval: i*slice_interval+window_size]) + element_windows.append(_window) + return element_windows + else: + return [elements] + + +def merged_token_embeddings(lpatches: list, loverlaps:list, lvalids: list, average: bool) -> torch.tensor: + start_pos = 1 + end_pos = start_pos + lvalids[0] + embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...]) + 
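sliding_windows overlaps consecutive windows by window_size minus slice_interval elements and lets the last window run to the end of the list; a quick check on toy inputs:

    from sdsvkvu.sources.utils import sliding_windows

    print(sliding_windows(list(range(7)), window_size=4, slice_interval=2))
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6]]
    print(sliding_windows(list(range(3)), window_size=4, slice_interval=2))
    # [[0, 1, 2]]  -- shorter inputs come back as a single window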
cls_token = copy.deepcopy(lpatches[0][:, :1, ...]) + sep_token = copy.deepcopy(lpatches[0][:, -1:, ...]) + + for i in range(1, len(lpatches)): + start_pos = 1 + end_pos = start_pos + lvalids[i] + + overlap_gap = copy.deepcopy(loverlaps[i-1]) + window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...]) + + if overlap_gap != 0: + prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...]) + curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...]) + assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}" + + if average: + avg_overlap = ( + prev_overlap + curr_overlap + ) / 2. + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1 + ) + else: + embedding_tokens = torch.cat( + [embedding_tokens, window], dim=1 + ) + return torch.cat([cls_token, embedding_tokens, sep_token], dim=1) + + +def parse_initial_words(itc_label, box_first_token_mask, class_names): + itc_label_np = itc_label.cpu().numpy() + box_first_token_mask_np = box_first_token_mask.cpu().numpy() + + outputs = [[] for _ in range(len(class_names))] + + for token_idx, label in enumerate(itc_label_np): + if box_first_token_mask_np[token_idx] and label != 0: + outputs[label].append(token_idx) + + return outputs + + +def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx): + max_connections = 50 + + valid_stc_label = stc_label * attention_mask.bool() + valid_stc_label = valid_stc_label.cpu().numpy() + stc_label_np = stc_label.cpu().numpy() + + valid_token_indices = np.where( + (valid_stc_label != dummy_idx) * (valid_stc_label != 0) + ) + + next_token_idx_dict = {} + for token_idx in valid_token_indices[0]: + next_token_idx_dict[stc_label_np[token_idx]] = token_idx + + outputs = [] + for init_token_indices in init_words: + sub_outputs = [] + for init_token_idx in init_token_indices: + cur_token_indices = [init_token_idx] + for _ in range(max_connections): + if cur_token_indices[-1] in next_token_idx_dict: + if ( + next_token_idx_dict[cur_token_indices[-1]] + not in init_token_indices + ): + cur_token_indices.append( + next_token_idx_dict[cur_token_indices[-1]] + ) + else: + break + else: + break + sub_outputs.append(tuple(cur_token_indices)) + + outputs.append(sub_outputs) + + return outputs + +def parse_relations(el_label, box_first_token_mask, dummy_idx): + valid_el_labels = el_label * box_first_token_mask + valid_el_labels = valid_el_labels.cpu().numpy() + el_label_np = el_label.cpu().numpy() + + max_token = box_first_token_mask.shape[0] - 1 + + valid_token_indices = np.where( + ((valid_el_labels != dummy_idx) * (valid_el_labels != 0)) ### + ) + + link_map_tuples = [] + for token_idx in valid_token_indices[0]: + link_map_tuples.append((el_label_np[token_idx], token_idx)) + + return set(link_map_tuples) + + +def get_pairs(json: list, rel_from: str, rel_to: str) -> dict: + outputs = {} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] in (rel_from, rel_to): + is_rel[element['class']]['status'] = 1 + is_rel[element['class']]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']] + return outputs + +def 
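parse_initial_words simply buckets every box-first token whose predicted class is non-zero; with toy tensors:

    import torch
    from sdsvkvu.sources.utils import parse_initial_words

    itc_label = torch.tensor([0, 2, 0, 3, 0])
    box_first_token_mask = torch.tensor([True, True, False, True, True])
    class_names = ["others", "title", "key", "value"]
    print(parse_initial_words(itc_label, box_first_token_mask, class_names))
    # [[], [], [1], [3]]  -- token 1 opens a "key" group, token 3 opens a "value" group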
get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict: + list_keys = list(header_key_pairs.keys()) + relations = {k: [] for k in list_keys} + for pair in json: + is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}} + for element in pair: + if element['class'] == rel_from and element['group_id'] in list_keys: + is_rel[rel_from]['status'] = 1 + is_rel[rel_from]['value'] = element + if element['class'] == rel_to: + is_rel[rel_to]['status'] = 1 + is_rel[rel_to]['value'] = element + if all([v['status'] == 1 for _, v in is_rel.items()]): + relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id']) + return relations + +def get_key2values_relations(key_value_pairs: dict): + triple_linkings = {} + for value_group_id, key_value_pair in key_value_pairs.items(): + key_group_id = key_value_pair[0] + if key_group_id not in list(triple_linkings.keys()): + triple_linkings[key_group_id] = [] + triple_linkings[key_group_id].append(value_group_id) + return triple_linkings + + +def get_wordgroup_bbox(lbbox: list, lword_ids: list) -> list: + points = [lbbox[i] for i in lword_ids] + x_min, y_min = min(points, key=lambda x: x[0])[0], min(points, key=lambda x: x[1])[1] + x_max, y_max = max(points, key=lambda x: x[2])[2], max(points, key=lambda x: x[3])[3] + return [x_min, y_min, x_max, y_max] + + +def merged_token_to_wordgroup(class_words: list, lwords: list, lbbox: list, labels: list) -> dict: + word_groups = {} + id2class = {i: labels[i] for i in range(len(labels))} + for class_id, lwgroups_in_class in enumerate(class_words): + for ltokens_in_wgroup in lwgroups_in_class: + group_id = ltokens_in_wgroup[0] + ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup] + ltokens_to_lbboxes = [lbbox[token] for token in ltokens_in_wgroup] + # text_string = get_string(ltokens_to_ltexts) + text_string = get_string_by_deduplicate_bbox(ltokens_to_ltexts, ltokens_to_lbboxes) + # text_string = get_string_with_word2line(ltokens_to_ltexts, ltokens_to_lbboxes) + group_bbox = get_wordgroup_bbox(lbbox, ltokens_in_wgroup) + word_groups[group_id] = { + 'group_id': group_id, + 'text': text_string, + 'class': id2class[class_id], + 'tokens': ltokens_in_wgroup, + 'bbox': group_bbox + } + return word_groups + +def verify_linking_id(word_groups: dict, linking_id: int) -> int: + if linking_id not in list(word_groups): + for wg_id, _word_group in word_groups.items(): + if linking_id in _word_group['tokens']: + return wg_id + return linking_id + +def matched_wordgroup_relations(word_groups:dict, lrelations: list) -> list: + outputs = [] + for pair in lrelations: + wg_from = verify_linking_id(word_groups, pair[0]) + wg_to = verify_linking_id(word_groups, pair[1]) + try: + outputs.append([word_groups[wg_from], word_groups[wg_to]]) + except Exception as e: + print('Not valid pair:', wg_from, wg_to) + return outputs + + +def get_single_entity(word_groups: dict, lrelations: list, labels: list) -> list: + # single_entity = {'title': [], 'key': [], 'value': [], 'header': []} + single_entity = {lb: [] for lb in labels} + list_linked_ids = [] + for pair in lrelations: + list_linked_ids.extend(pair) + + for word_group_id, word_group in word_groups.items(): + if word_group_id not in list_linked_ids: + single_entity[word_group['class']].append(word_group) + return single_entity + + +def export_kvu_outputs(lwords, lbbox, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']): + word_groups = merged_token_to_wordgroup(class_words, lwords, 
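get_wordgroup_bbox is a union-of-boxes reduction over the word ids of a group; with two invented word boxes:

    from sdsvkvu.sources.utils import get_wordgroup_bbox

    lbbox = [[10, 10, 50, 30], [60, 12, 120, 28]]
    print(get_wordgroup_bbox(lbbox, [0, 1]))   # [10, 10, 120, 30]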
lbbox, labels) + linking_pairs = matched_wordgroup_relations(word_groups, lrelations) + + header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key') # => {key_group_id: [header_group_id, key_group_id]} + header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value') # => {value_group_id: [header_group_id, value_group_id]} + key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value') # => {value_group_id: [key_group_id, value_group_id]} + + single_entity = get_single_entity(word_groups, lrelations, labels=labels) + + # table_relations = get_table_relations(linking_pairs, header_key) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + key2values_relations = get_key2values_relations(key_value) # => {key_group_id: [value_group_id1, value_groupid2, ...]} + + triplet_pairs = [] + single_pairs = [] + table = [] + # print('key2values_relations', key2values_relations) + for key_group_id, list_value_group_ids in key2values_relations.items(): + if len(list_value_group_ids) == 0: continue + elif len(list_value_group_ids) == 1: + value_group_id = list_value_group_ids[0] + single_pairs.append({word_groups[key_group_id]['text']: { + 'id': value_group_id, + 'class': "value", + 'text': word_groups[value_group_id]['text'], + 'bbox': word_groups[value_group_id]['bbox'], + "key_bbox": word_groups[key_group_id]["bbox"], + }}) + else: + item = [] + for value_group_id in list_value_group_ids: + if value_group_id not in header_value.keys(): + header_name_for_value = "non-header" + else: + header_group_id = header_value[value_group_id][0] + header_name_for_value = word_groups[header_group_id]['text'] + item.append({ + 'id': value_group_id, + 'class': 'value', + 'header': header_name_for_value, + 'text': word_groups[value_group_id]['text'], + 'bbox': word_groups[value_group_id]['bbox'], + "key_bbox": word_groups[key_group_id]["bbox"], + "header_bbox": word_groups[header_group_id]["bbox"] + if header_group_id != -1 else [0, 0, 0, 0], + }) + if key_group_id not in list(header_key.keys()): + triplet_pairs.append({ + word_groups[key_group_id]['text']: item + }) + else: + header_group_id = header_key[key_group_id][0] + header_name_for_key = word_groups[header_group_id]['text'] + item.append({ + 'id': key_group_id, + 'class': 'key', + 'header': header_name_for_key, + 'text': word_groups[key_group_id]['text'], + 'bbox': word_groups[key_group_id]['bbox'], + "key_bbox": word_groups[key_group_id]["bbox"], + "header_bbox": word_groups[header_group_id]["bbox"], + }) + table.append({key_group_id: item}) + + + # Add entity without linking + single_entity_dict = {} + for class_name, single_items in single_entity.items(): + single_entity_dict[class_name] = [] + for single_item in single_items: + single_entity_dict[class_name].append({ + 'text': single_item['text'], + 'id': single_item['group_id'], + 'class': class_name, + 'bbox': single_item['bbox'] + }) + + + if len(table) > 0: + table = sorted(table, key=lambda x: list(x.keys())[0]) + table = [v for item in table for k, v in item.items()] + + outputs = {} + outputs['title'] = sorted( + single_entity_dict["title"], key=lambda x: x["id"] + ) + outputs['key'] = sorted( + single_entity_dict["key"], key=lambda x: x["id"] + ) + outputs['value'] = sorted( + single_entity_dict["value"], key=lambda x: x["id"] + ) + outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id']))) + outputs['triplet'] = triplet_pairs + outputs['table'] = table + return outputs + + + +def export_sbt_outputs( + lwords, + lbboxes, + 
class_words, + lrelations, + labels, +): + word_groups = merged_token_to_wordgroup(class_words, lwords, lbboxes, labels) + linking_pairs = matched_wordgroup_relations(word_groups, lrelations) + + date_key_value_pairs = get_pairs( + linking_pairs, rel_from="date_key", rel_to="date_value" + ) # => {date_value_group_id: [date_key_group_id, date_value_group_id]} + # product_name_code_pairs = get_pairs( + # linking_pairs, rel_to="product_name", rel_from="product_code" + # ) # => {product_name_group_id: [product_code_group_id, product_name_group_id]} + # product_name_quantity_pairs = get_pairs( + # linking_pairs, rel_to="product_name", rel_from="quantity" + # ) # => {product_name_group_id: [quantity_group_id, product_name_group_id]} + serial_key_value_pairs = get_pairs( + linking_pairs, rel_from="sn_key", rel_to="sn_value" + ) # => {sn_value_group_id: [sn_key_group_id, sn_value_group_id]} + + sold_key_value_pairs = get_pairs( + linking_pairs, rel_from="sold_key", rel_to="sold_value" + ) # => {sold_value_group_id: [sold_key_group_id, sold_value_group_id]} + + single_entity = get_single_entity(word_groups, lrelations, labels=labels) + + date_value = [] + sold_value = [] + serial_imei = [] + table = [] + # print('key2values_relations', key2values_relations) + date_relations = get_key2values_relations(date_key_value_pairs) + for key_group_id, list_value_group_id in date_relations.items(): + for value_group_id in list_value_group_id: + date_value.append( + { + "text": word_groups[value_group_id]["text"], + "id": value_group_id, + "class": "date_value", + "bbox": word_groups[value_group_id]["bbox"], + "key_bbox": word_groups[key_group_id]["bbox"], + "raw_key_name": word_groups[key_group_id]["text"], + } + ) + + sold_relations = get_key2values_relations(sold_key_value_pairs) + for key_group_id, list_value_group_id in sold_relations.items(): + for value_group_id in list_value_group_id: + sold_value.append( + { + "text": word_groups[value_group_id]["text"], + "id": value_group_id, + "class": "sold_value", + "bbox": word_groups[value_group_id]["bbox"], + "key_bbox": word_groups[key_group_id]["bbox"], + "raw_key_name": word_groups[key_group_id]["text"], + } + ) + + + serial_relations = get_key2values_relations(serial_key_value_pairs) + for key_group_id, list_value_group_id in serial_relations.items(): + for value_group_id in list_value_group_id: + serial_imei.append( + { + "text": word_groups[value_group_id]["text"], + "id": value_group_id, + "class": "sn_value", + "bbox": word_groups[value_group_id]["bbox"], + "key_bbox": word_groups[key_group_id]["bbox"], + "raw_key_name": word_groups[key_group_id]["text"], + } + ) + + + single_entity_dict = {} + for class_name, single_items in single_entity.items(): + single_entity_dict[class_name] = [] + for single_item in single_items: + single_entity_dict[class_name].append( + { + "text": single_item["text"], + "id": single_item["group_id"], + "class": class_name, + "bbox": single_item["bbox"], + } + ) + + # list_product_name_group_ids = set( + # list(product_name_code_pairs.keys()) + # + list(product_name_quantity_pairs.keys()) + # + [x["id"] for x in single_entity_dict["product_name"]] + # ) + # for product_name_group_id in list_product_name_group_ids: + # item = {"productname": [], "modelnumber": [], "qty": []} + # item["productname"].append( + # { + # "text": word_groups[product_name_group_id]["text"], + # "id": product_name_group_id, + # "class": "product_name", + # "bbox": word_groups[product_name_group_id]["bbox"], + # } + # ) + # if product_name_group_id in 
product_name_code_pairs: + # product_code_group_id = product_name_code_pairs[product_name_group_id][0] + # item["modelnumber"].append( + # { + # "text": word_groups[product_code_group_id]["text"], + # "id": product_code_group_id, + # "class": "product_code", + # "bbox": word_groups[product_code_group_id]["bbox"], + # } + # ) + # if product_name_group_id in product_name_quantity_pairs: + # quantity_group_id = product_name_quantity_pairs[product_name_group_id][0] + # item["qty"].append( + # { + # "text": word_groups[quantity_group_id]["text"], + # "id": quantity_group_id, + # "class": "quantity", + # "bbox": word_groups[quantity_group_id]["bbox"], + # } + # ) + # table.append(item) + + # if len(table) > 0: + # table = sorted(table, key=lambda x: x["productname"][0]["id"]) + + if len(serial_imei) > 0: + serial_imei = sorted(serial_imei, key=lambda x: x["id"]) + + outputs = {} + outputs["seller"] = single_entity_dict["seller"] + outputs["voucher"] = single_entity_dict["voucher"] + outputs["website"] = single_entity_dict["website"] + outputs["hotline"] = single_entity_dict["hotline"] + outputs["sold_value"] = sold_value + single_entity_dict["sold_key"] + single_entity_dict["sold_value"] + outputs["date_value"] = date_value + single_entity_dict["date_value"] + single_entity_dict["date_key"] + outputs["serial_imei"] = serial_imei + single_entity_dict["sn_value"] + single_entity_dict["sn_key"] + # outputs["table"] = table + return outputs diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/list_retailers.txt b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/list_retailers.txt new file mode 100644 index 0000000..dbce48a --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/list_retailers.txt @@ -0,0 +1,167 @@ +1 FUSION TELECOM +3 Mobile +A STAR MOBILE TRADING +A2Z +ACES MOBILE +ARS DIGITAL WORLD +AV Intelligence +AZ Telecommunication +Active Electronics +Addon Systems +Alpha Telecom +Amazon.sg +Arrow Communication +Audio Ace Electronics +Audio House +BEST DIGITAL LIFESTYLE +BLAZING HANDPHONE SHOP +Best Denki +Best Denki - Great World +Best Denki - Ngee Ann +Best Denki - Online shop +Best Denki - Vivo +C. K. 
Tang +CALLMOBILE +Hachi +Challenger +Challenger Corporate +Circles.Life +Courts +Courts Heeren +Courts Megastore +Cyber Jip +DV Tech +Easytone +Everjoint Electrical +GIANT MOBILE +GRAPES COMMUNICATION +Gadget Affair +Gain City +Gain City Best-Electric +Gain City - Sungei Kadut +Gain City - Marina Square +Gain City - Sungei Kadut +Garphil Enterprise +Goh Ah Bee +HI MOBILE +Han's Communication +Handphone Shop +Harvey Norman +Harvey Norman - Millenia Walk +Harvey Norman - Pertama Merchandising +Harvey Norman - Millenia Walk +Harvey Norman - Northpoint +Hi-Life +I.COMM MOBILE PLUS +ING Mobile +IT Mobile  +IT TALENT TRADING +Ingram Micro +Isetan +Ivan Mobile & Jewelleries +J2 MOBILE STORE +KASIA Mobile +Kong Tai +Kris Shop +Lazada +Lazada - Samsung Brand Store +Lion City Company +Lucky Store +M1 Exclusive Partners +M1 Shop +MAGNA MOBILE +MEGA TELESHOP +MELA SHOPPE +MG MOBILE COMMUNICATION +MOBILE SQUARE +MOBILE X +MOBILEHUB SERVICE +MOBILERELATION 1 +MOBY +MOHAMED MUSTAFA & SAMSUDDIN +MY MOBILE HOUSE +Magnify +Mega Discount Store +My Mobile +My Mobile House +MyRepublic +NAIN INTERNATIONAL TRADING +NARANJAN ELECTRONICS +NTUC +Naranjan Int Mobile +New Sound Electrical Dept +One Dream Telecom +One2Free Mobile +Onephone Online +PHONEVIBES +POD CONTACT TELECOMMUNICATIONS +PROVIDER STORES +Parisilk Electronics & Computers +Planet Telecoms +Pod Contact +Poorvika (TV) +Popular Book +Provider Stores +RED WHITE MOBILE +REMO COMM +Red White Mobile +Rigel Telecom +SK Mobile +SMART PLAY TRADING +SMART TECH MOBILE +SOLULAR PLUS +SONIC CONNECTION ENTERPRISES +SPRINT - CASS (HQ) +SUMMER TELECOM +Lazada Samsung Brand Store +Shopee Samsung Brand Store +Lazada Samsung Certified Store +Shopee Samsung Certified Store +Samsung Brand Store +Samsung Customer Service Plaza Singapura +Samsung EDU Store +Samsung EPP Store +Samsung Experience Store +Samsung Experience Store - 313 Somerset +Samsung Experience Store - Bedok Mall +Samsung Experience Store - Bugis Junction +Samsung Experience Store - Causeway Point +Samsung Experience Store - ION +Samsung Experience Store - Jurong Point +Samsung Experience Store - Nex +Samsung Experience Store - Northpoint City +Samsung Experience Store - Tampines Mall +Samsung Experience Store - VivoCity +Samsung Experience Store - Westgate +Samsung Experience Store - delivery +Samsung Experience Store - pickup +Samsung Online Shop +Samsung Online Shop +Singtel Exclusive Retailers +Singtel Shop +Sprint-Cass +StarHub Exclusive Partners +StarHub Shop +T2 Electronics +TANGS +Takashimaya +Takashimaya Singapore +Telemobile +Telestation Infocomm +Trends n Trendies +U MOBILE SHOP +U-First Mobile +UNITED MOBILE SERVICES +UWIN - CHEERS COMMUNICATIONS +UWIN COMMUNICATION +Univercell +Urban Republic +VMCS Pte Ltd +VV MOBILE +Vi Mobile +WEN MOBILE TRADING +YOSHI MOBILE +ZYM Official Store +Zalora +i.Comm Mobile +iShopChangi diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/manulife.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/manulife.py new file mode 100644 index 0000000..6d172ae --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/manulife.py @@ -0,0 +1,69 @@ +key_dict = { + "Document type": ["documenttype", "loaichungtu"], + "Document name": ["documentname", "tenchungtu"], + "Patient Name": ["patientname", "tenbenhnhan"], + "Date of Birth/Year of birth": [ + "dateofbirth", + "yearofbirth", + "ngaythangnamsinh", + "namsinh", + ], + "Age": ["age", "tuoi"], + "Gender": ["gender", "gioitinh"], + "Social insurance card No.": ["socialinsurancecardno", 
"sothebhyt"], + "Medical service provider name": ["medicalserviceprovidername", "tencosoyte"], + "Department name": ["departmentname", "tenkhoadieutri"], + "Diagnosis description": ["diagnosisdescription", "motachandoan"], + "Diagnosis code": ["diagnosiscode", "machandoan"], + "Admission date": ["admissiondate", "ngaynhapvien"], + "Discharge date": ["dischargedate", "ngayxuatvien"], + "Treatment method": ["treatmentmethod", "phuongphapdieutri"], + "Treatment date": ["treatmentdate", "ngaydieutri", "ngaykham"], + "Follow up treatment date": ["followuptreatmentdate", "ngaytaikham"], + # "Name of precribed medicine": [], + # "Quantity of prescribed medicine": [], + # "Dosage for each medicine": [] + "Medical expense": ["Medical expense", "chiphiyte"], + "Invoice No.": ["invoiceno", "sohoadon"], +} + +title_dict = { + "Chứng từ y tế": [ + "giayravien", + "giaychungnhanphauthuat", + "cachthucphauthuat", + "phauthuat", + "tomtathosobenhan", + "donthuoc", + "toathuoc", + "donbosung", + "ketquaconghuongtu" + "ketqua", + "phieuchidinh", + "phieudangkykham", + "giayhenkhamlai", + "phieukhambenh", + "phieukhambenhvaovien", + "phieuxetnghiem", + "phieuketquaxetnghiem", + "phieuchidinhxetnghiem", + "ketquasieuam", + "phieuchidinhxetnghiem" + ], + "Chứng từ thanh toán": [ + "hoadon", + "hoadongiatrigiatang", + "hoadongiatrigiatangchuyendoituhoadondientu", + "bangkechiphibaohiem", + "bienlaithutien", + "bangkechiphidieutrinoitru" + ], +} + +def get_dict(type: str): + if type == "key": + return key_dict + elif type == "title": + return title_dict + else: + raise ValueError(f"[ERROR] Dictionary type of {type} is not supported") \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/sbt.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/sbt.py new file mode 100644 index 0000000..edb33f8 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/sbt.py @@ -0,0 +1,32 @@ +header_dict = { + 'productname': ['description', 'productdescription', 'articledescription', 'descriptionofgood', 'itemdescription', + 'brandmodel', 'itemdepartment', 'departmentbrand', 'department', 'certificateno', + 'product', 'modelname', 'paticulars', 'device', 'items', 'itemno'], + 'modelnumber': ['serialno', 'serial', 'articles', 'simimeiserial', 'article', 'articlenumber', 'articleidmaterialcode', + 'itemcode', 'code', 'mcode', 'productcode', 'model', 'product', 'imeiccid', 'transaction'], + 'qty': ['quantity', 'invoicequantity'] +} + +key_dict = { + 'purchase_date': ['date', 'purchasedate', 'datetime', 'orderdate', 'orderdatetime', 'invoicedate', 'dateredeemed', 'issuedate', 'billingdocdate'], + 'retailername': ['retailer', 'retailername', 'ownedoperatedby'], + 'serial_number': ['serialnumber', 'serialno'], + 'imei_number': ['imeiesim', 'imeislot1', 'imeislot2', 'imei', 'imei1', 'imei2'] +} + +extra_dict = { + 'serial_number': ['sn'], + 'imei_number': ['imel', 'imed'], + 'modelnumber': ['sku', 'sn', 'imei'], + 'qty': ['qty'] +} + +def get_dict(type: str): + if type == "key": + return key_dict + elif type == "header": + return header_dict + elif type == "extra": + return extra_dict + else: + raise ValueError(f'[ERROR] Dictionary type of {type} is not supported') \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/sbt_v2.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/sbt_v2.py new file mode 100644 index 0000000..ad3141e --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/sbt_v2.py @@ -0,0 
+1,116 @@ +import os +from pathlib import Path + +cur_dir = str(Path(__file__).parents[0]) +from sdsvkvu.utils.utils import read_txt +from sdsvkvu.utils.post_processing import preprocessing + + +header_dict = { + "productname": [ + "description", + "productdescription", + "articledescription", + "descriptionofgood", + "itemdescription", + "brandmodel", + "itemdepartment", + "departmentbrand", + "department", + "certificateno", + "product", + "modelname", + "paticulars", + "device", + "items", + "itemno", + ], + "modelnumber": [ + "serialno", + "serial", + "articles", + "simimeiserial", + "article", + "articlenumber", + "articleidmaterialcode", + "itemcode", + "code", + "mcode", + "productcode", + "model", + "product", + "imeiccid", + "transaction", + ], + "qty": ["quantity", "invoicequantity"], +} + +date_dict = { + "purchase_date": [ + "date", + "purchasedate", + "datetime", + "orderdate", + "orderdatetime", + "invoicedate", + "dateredeemed", + "issuedate", + "billingdocdate", + "placedon", + "transactiondatetime", + "creationdate", + "ordertime", + "dateofissue", + ] +} + +imei_dict = { + "serial_number": ["serialnumber", "serialno"], + "imei_number": ["imeiesim", "imeislot1", "imeislot2", "imei", "imei1", "imei2"], +} + +sold_dict = { + "sold_by_party": ["soldtoparty"], + "sold_by": ["soldby"] +} + +extra_dict = { + "serial_number": ["sn"], + "imei_number": ["imel", "imed"], + "modelnumber": ["sku", "sn", "imei"], + "qty": ["qty"], +} + +seller_mapping = { + "Samsung Experience Store": ["G-FORCE NETWORK PTE LTD", "eSmart Mobile", "eSmart Mobile Pte Ltd"], + "Samsung Online Store": ["SAMSUNG ELECTRONICS SINGAPORE PTE LTD"], + "Samsung Brand Store": ["SAMSUNG OFFICIAL STORE"], + "Harvey Norman": ["PERTAMA MERCHANDISING PTE LTD"], + "Shopee": ["shopee mall"], + "Lazada": ["laz mall", "lazmall"], + "LTD": ["limited"], + "PTE": ["private"], +} + + +seller_list = read_txt(os.path.join(cur_dir, "list_retailers.txt")) +seller_dict = {seller.upper(): [preprocessing(seller)] for seller in seller_list} + + +def get_dict(type: str): + if type == "date": + return date_dict + elif type == "imei": + return imei_dict + elif type == "sold_by": + return sold_dict + elif type == "header": + return header_dict + elif type == "extra": + return extra_dict + elif type == "seller": + return seller_dict + elif type == "seller_mapping": + return seller_mapping + else: + raise ValueError(f"[ERROR] Dictionary type of {type} is not supported") diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/vat.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/vat.py new file mode 100644 index 0000000..8945200 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/vat.py @@ -0,0 +1,69 @@ +DKVU2XML = { + "Ký hiệu mẫu hóa đơn": "form_no", + "Ký hiệu hóa đơn": "serial_no", + "Số hóa đơn": "invoice_no", + "Ngày, tháng, năm lập hóa đơn": "issue_date", + "Tên người bán": "seller_name", + "Mã số thuế người bán": "seller_tax_code", + "Thuế suất": "tax_rate", + "Thuế GTGT đủ điều kiện khấu trừ thuế": "VAT_input_amount", + "Mặt hàng": "item", + "Đơn vị tính": "unit", + "Số lượng": "quantity", + "Đơn giá": "unit_price", + "Doanh số mua chưa có thuế": "amount" +} + +key_dict = { + 'Ký hiệu mẫu hóa đơn': ['mausoformno', 'mauso'], + 'Ký hiệu hóa đơn': ['kyhieuserialno', 'kyhieuserial', 'kyhieu'], + 'Số hóa đơn': ['soinvoiceno', 'invoiceno'], + 'Ngày, tháng, năm lập hóa đơn': [], + 'Tên người bán': ['donvibanseller', 'donvibanhangsalesunit', 'donvibanhangseller', 'kyboisignedby'], + 'Mã số thuế 
người bán': ['masothuetaxcode', 'maxsothuetaxcodenumber', 'masothue'], + 'Thuế suất': ['thuesuatgtgttaxrate', 'thuesuatgtgt'], + 'Thuế GTGT đủ điều kiện khấu trừ thuế': ['tienthuegtgtvatamount', 'tienthuegtgt'], + # 'Ghi chú': [], + # 'Ngày': ['ngayday', 'ngay', 'day'], + # 'Tháng': ['thangmonth', 'thang', 'month'], + # 'Năm': ['namyear', 'nam', 'year'] +} + +header_dict = { + 'Mặt hàng': ['tenhanghoa,dichvu', 'danhmuc,dichvu', 'dichvusudung', 'tenquycachhanghoa','description', 'descriptionofgood', 'itemdescription'], + 'Đơn vị tính': ['donvitinh', 'dvtunit'], + 'Số lượng': ['soluong', 'quantity', 'invoicequantity', 'soluongquantity'], + 'Đơn giá': ['dongia', 'dongiaprice'], + 'Doanh số mua chưa có thuế': ['thanhtien', 'thanhtientruocthuegtgt', 'tienchuathue'], +} + +extra_dict = { + 'Số hóa đơn': ['sono', 'so'], + 'Mã số thuế người bán': ['mst'], + 'Tên người bán': ['kyboi'], + 'Ngày, tháng, năm lập hóa đơn': ['kyngay', 'kyngaydate'], + 'Mặt hàng': ['tenhanghoa','sanpham'], + 'Số lượng': ['sl', 'qty'], + 'Đơn vị tính': ['dvt'], +} + +date_dict = { + 'day': ['ngayday', 'ngaydate', 'ngay', 'day'], + 'month': ['thangmonth', 'thang', 'month'], + 'year': ['namyear', 'nam', 'year'] +} + +def get_dict(type: str): + if type == "key": + return key_dict + elif type == "header": + return header_dict + elif type == "extra": + return extra_dict + elif type == "date": + return date_dict + elif type == "kvu2xml": + return DKVU2XML + else: + raise ValueError(f'[ERROR] Dictionary type of {type} is not supported') + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/vtb.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/vtb.py new file mode 100644 index 0000000..a4dd15b --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/dictionary/vtb.py @@ -0,0 +1,33 @@ +key_dict = { + "number": ["kemtheoquyetdinhso", "quyetdinhso"], + "title": [], + "date": [], + "signee": ['botruong', 'thutruong', 'giamdoc', 'phogiamdoc', 'chunhiem', 'phochunhiem', + 'hieutruong', 'vientruong', 'thuky', 'chutich', 'phochutich', 'bithu', 'chutoa', + 'daidien', 'truongban', 'tongcuctruong', 'photongcuctruong', 'cuctruong', 'cucpho', + 'thuky', 'chanhthanhtra', 'thutruongdonvi', 'thutuong', + 'kiemtoanvien', 'canbokekhai'], + "sender": ['kinhgui'], + "receiver": ['noinhan', 'noigui'] + } + +extra_dict = { + "number": ['so'], + # "sender": ['dien'] +} + +date_dict = { + "day": ['ngayday', 'ngaydate', 'ngay', 'day'], + "month": ['thangmonth', 'thang', 'month'], + "year": ['namyear', 'nam', 'year'] +} + +def get_dict(type: str): + if type == "key": + return key_dict + elif type == "extra": + return extra_dict + elif type == "date": + return date_dict + else: + raise ValueError(f'[ERROR] Dictionary type of {type} is not supported') \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/post_processing.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/post_processing.py new file mode 100644 index 0000000..14b6bec --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/post_processing.py @@ -0,0 +1,362 @@ +import re +import nltk +import string +import tldextract +from dateutil import parser +from datetime import datetime +# nltk.download('words') +try: + nltk.data.find("corpora/words") +except LookupError: + nltk.download('words') +words = set(nltk.corpus.words.words()) + +from sdsvkvu.utils.word2line import Word, words_to_lines + +s1 = 
u'ÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝàáâãèéêìíòóôõùúýĂăĐđĨĩŨũƠơƯưẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ' +s0 = u'AAAAEEEIIOOOOUUYaaaaeeeiioooouuyAaDdIiUuOoUuAaAaAaAaAaAaAaAaAaAaAaAaEeEeEeEeEeEeEeEeIiIiOoOoOoOoOoOoOoOoOoOoOoOoUuUuUuUuUuUuUuYyYyYyYy' + +# def clean_text(text): +# return re.sub(r"[^A-Za-z(),!?\'\`]", " ", text) + +def get_string(lwords: list): + unique_list = [] + for item in lwords: + if item.isdigit() and len(item) == 1: + unique_list.append(item) + elif item not in unique_list: + unique_list.append(item) + return ' '.join(unique_list) + +def get_string_by_deduplicate_bbox(lwords: list, lbboxes: list): + unique_list = [] + prev_bbox = [-1, -1, -1, -1] + for word, bbox in zip(lwords, lbboxes): + if bbox != prev_bbox: + unique_list.append(word) + prev_bbox = bbox + return ' '.join(unique_list) + +def get_string_with_word2line(lwords: list, lbboxes: list): + list_words = [] + unique_list = [] + list_sorted_words = [] + + prev_bbox = [-1, -1, -1, -1] + for word, bbox in zip(lwords, lbboxes): + if bbox != prev_bbox: + prev_bbox = bbox + list_words.append(Word(image=None, text=word, conf_cls=-1, bndbox=bbox, conf_detect=-1)) + unique_list.append(word) + + llines = words_to_lines(list_words)[0] + + for line in llines: + for _word_group in line.list_word_groups: + for _word in _word_group.list_words: + list_sorted_words.append(_word.text) + + string_from_model = ' '.join(unique_list) + string_after_word2line = ' '.join(list_sorted_words) + + # if string_from_model != string_after_word2line: + # print("[Warning] Word group from model is different with word2line module") + # print("Model: ", ' '.join(unique_list)) + # print("Word2line: ", ' '.join(list_sorted_words)) + + return string_after_word2line + +def date_regexing(inp_str): + patterns = { + 'ngay': r"ngày\d+", + 'thang': r"tháng\d+", + 'nam': r"năm\d+" + } + inp_str = inp_str.replace(" ", "").lower() + outputs = {k: '' for k in patterns} + for key, pattern in patterns.items(): + matches = re.findall(pattern, inp_str) + if len(matches) > 0: + element = set([match[len(key):] for match in matches]) + outputs[key] = list(element)[0] + return outputs['ngay'], outputs['thang'], outputs['nam'] + + +def parse_date1(date_str): + # remove space + date_str = re.sub(r"[\[\]\{\}\(\)\.\,]", " ", date_str) + date_str = re.sub(r"/\s+", "/", date_str) + date_str = re.sub(r"-\s+", "-", date_str) + + is_parser_error = False + try: + date_obj = parser.parse(date_str, fuzzy=True) + year_str = str(date_obj.year) + day_str = str(date_obj.day) + # date_formated = date_obj.strftime("%d/%m/%Y") + date_formated = date_obj.strftime("%Y-%m-%d") + except Exception as err: + # date_str = sorted(date_str.split(" "), key=lambda x: len(x), reverse=True)[0] + # date_str, is_match = date_regexing(date_str) + is_match = False + if is_match: + date_formated = date_str + is_parser_error = False + return date_formated, is_parser_error + else: + print(f"Error parse date: err = {err}, date = {date_str}") + date_formated = date_str + is_parser_error = True + return date_formated, is_parser_error + + if len(normalize_number(date_str)) == 6: + year_str = year_str[-2:] + try: + year_index = date_str.index(str(year_str)) + day_index = date_str.index(str(day_str)) + if year_index > day_index: + date_obj = parser.parse(date_str, fuzzy=True, dayfirst=True) + + # date_formated = date_obj.strftime("%d/%m/%Y") + date_formated = date_obj.strftime("%Y-%m-%d") + except Exception as err: + print(f"Error check dayfirst: err = {err}, date = 
{date_str}") + + return date_formated, is_parser_error + + +def parse_date(date_str): + # remove space + date_str = re.sub(r"[\[\]\{\}\(\)\.\,]", " ", date_str) + date_str = re.sub(r"/\s+", "/", date_str) + date_str = re.sub(r"-\s+", "-", date_str) + date_str = re.sub(r"\-+", "-", date_str) + date_str = date_str.lower().replace("0ct", "oct") + + is_parser_error = False + try: + date_obj = parser.parse(date_str, fuzzy=True) + except Exception as err: + print(f"1.Error parse date: err = {err}, date = {date_str}") + try: + date_str = sorted(date_str.split(" "), key=lambda x: len(x), reverse=True)[0] + date_obj = parser.parse(date_str, fuzzy=True) + except Exception as err: + print(f"2.Error parse date: err = {err}, date = {date_str}") + is_parser_error = True + return [date_str], is_parser_error + + year_str = int(date_obj.year) + month_str = int(date_obj.month) + day_str = int(date_obj.day) + + current_year = int(datetime.now().year) + if year_str > current_year or year_str < 2010: # invalid year + date_obj = date_obj.replace(year=current_year) + + formated_date = date_obj.strftime("%Y-%m-%d") + revert_formated_date = date_obj.strftime("%Y-%d-%m") + + if any(txt in date_str for txt in ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec']): + return [formated_date], is_parser_error + if month_str <= 12 and day_str <= 12: + return [formated_date, revert_formated_date], is_parser_error + return [formated_date], is_parser_error + + + +def normalize_imei(imei): + imei = imei.replace(" ", "") + imei = imei.split("/")[0] + return imei + +def normalize_seller(seller): + # if isinstance(seller, str): + # seller = seller + return seller + +def normalize_website(website): + if isinstance(website, str): + # website = website.lower().replace("www.", "").replace("ww.", "").replace(".com", "") + website = website.lower() + website = website.split(".com")[0] + website = tldextract.extract(website).domain + return website + +def normalize_hotline(hotline): + if isinstance(hotline, str): + hotline = hotline.lower().replace("hotline", "") + return hotline + +def normalize_voucher(voucher): + if isinstance(voucher, str): + voucher = voucher.lower().replace("voucher", "") + return voucher + + + +def normalize_number( + text_str: str, reserve_dot=False, reserve_plus=False, reserve_minus=False +): + """ + Normalize a string of numbers by removing non-numeric characters + + """ + assert isinstance(text_str, str), "input must be str" + reserver_chars = "" + if reserve_dot: + reserver_chars += ".," + if reserve_plus: + reserver_chars += "+" + if reserve_minus: + reserver_chars += "-" + regex_fomula = "[^0-9{}]".format(reserver_chars) + normalized_text_str = re.sub(r"{}".format(regex_fomula), "", text_str) + return normalized_text_str + +def remove_bullet_points_and_punctuation(text): + # Remove bullet points (e.g., • or -) + text = re.sub(r'^\s*[\•\-\*]\s*', '', text, flags=re.MULTILINE) + text = text.strip() + # # Remove end-of-sentence punctuation (e.g., ., !, ?) 
+ # text = re.sub(r'[.!?]', '', text) + if len(text) > 0 and text[0] in (',', '.', ':', ';', '?', '!'): + text = text[1:] + if len(text) > 0 and text[-1] in (',', '.', ':', ';', '?', '!'): + text = text[:-1] + return text.strip() + +def split_key_value_by_colon(key: str, value: str) -> list: + key_string = key if key is not None else "" + value_string = value if value is not None else "" + text_string = key_string + " " + value_string + elements = text_string.split(":") + if len(elements) > 1 and not bool(re.search(r'\d', elements[0])): + return elements[0], text_string[len(elements[0])+1 :].strip() + return key, value + + +def is_string_in_range(s): + try: + num = int(s) + return 0 <= num <= 9 + except ValueError: + return False + +def remove_english_words(text): + _word = [w.lower() for w in nltk.wordpunct_tokenize(text) if w.lower() not in words] + return ' '.join(_word) + +def remove_punctuation(text): + return text.translate(str.maketrans(" ", " ", string.punctuation)) + +def remove_accents(input_str, s0, s1): + s = '' + # print input_str.encode('utf-8') + for c in input_str: + if c in s1: + s += s0[s1.index(c)] + else: + s += c + return s + +def remove_spaces(text): + return text.replace(' ', '') + +def preprocessing(text: str): + # text = remove_english_words(text) if table else text + text = remove_punctuation(text) + text = remove_accents(text, s0, s1) + text = remove_spaces(text) + return text.lower() + +def longestCommonSubsequence(text1: str, text2: str) -> int: + # https://leetcode.com/problems/longest-common-subsequence/discuss/351689/JavaPython-3-Two-DP-codes-of-O(mn)-and-O(min(m-n))-spaces-w-picture-and-analysis + dp = [[0] * (len(text2) + 1) for _ in range(len(text1) + 1)] + for i, c in enumerate(text1): + for j, d in enumerate(text2): + dp[i + 1][j + 1] = 1 + \ + dp[i][j] if c == d else max(dp[i][j + 1], dp[i + 1][j]) + return dp[-1][-1] + + +def longest_common_subsequence_with_idx(X, Y): + """ + This implementation uses dynamic programming to calculate the length of the LCS, and uses a path array to keep track of the characters in the LCS. + The longest_common_subsequence function takes two strings as input, and returns a tuple with three values: + the length of the LCS, + the index of the first character of the LCS in the first string, + and the index of the last character of the LCS in the first string. + """ + m, n = len(X), len(Y) + L = [[0 for i in range(n + 1)] for j in range(m + 1)] + + # Following steps build L[m+1][n+1] in bottom up fashion. 
Note + # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1] + right_idx = 0 + max_lcs = 0 + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0: + L[i][j] = 0 + elif X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + if L[i][j] > max_lcs: + max_lcs = L[i][j] + right_idx = i + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + + # Create a string variable to store the lcs string + lcs = L[i][j] + # Start from the right-most-bottom-most corner and + # one by one store characters in lcs[] + i = m + j = n + # right_idx = 0 + while i > 0 and j > 0: + # If current character in X[] and Y are same, then + # current character is part of LCS + if X[i - 1] == Y[j - 1]: + + i -= 1 + j -= 1 + # If not same, then find the larger of two and + # go in the direction of larger value + elif L[i - 1][j] > L[i][j - 1]: + # right_idx = i if not right_idx else right_idx #the first change in L should be the right index of the lcs + i -= 1 + else: + j -= 1 + return lcs, i, max(i + lcs, right_idx) + + + +def longest_common_substring(X, Y): + m = len(X) + n = len(Y) + + # Create a 2D array to store the lengths of common substrings + dp = [[0] * (n + 1) for _ in range(m + 1)] + + # Variables to store the length of the longest common substring + max_length = 0 + end_index = 0 + + # Build the dp array bottom-up + for i in range(1, m + 1): + for j in range(1, n + 1): + if X[i - 1] == Y[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + 1 + + # Update the length and ending index of the common substring + if dp[i][j] > max_length: + max_length = dp[i][j] + end_index = i - 1 + else: + dp[i][j] = 0 + + # The longest common substring is X[end_index - max_length + 1:end_index + 1] + + return len(X[end_index - max_length + 1: end_index + 1]) + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/__init__.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/all.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/all.py new file mode 100644 index 0000000..c1a909a --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/all.py @@ -0,0 +1,75 @@ +from sdsvkvu.utils.post_processing import split_key_value_by_colon, remove_bullet_points_and_punctuation + + +def normalize_kvu_output(raw_outputs: dict) -> dict: + outputs = {} + for key, values in raw_outputs.items(): + if key == "table": + table = [] + for row in values: + item = {} + for k, v in row.items(): + k = remove_bullet_points_and_punctuation(k) + if v is not None and len(v) > 0: + v = remove_bullet_points_and_punctuation(v) + item[k] = v + table.append(item) + outputs[key] = table + else: + key = remove_bullet_points_and_punctuation(key) + if isinstance(values, list): + values = [remove_bullet_points_and_punctuation(v) for v in values] + elif values is not None and len(values) > 0: + values = remove_bullet_points_and_punctuation(values) + outputs[key] = values + return outputs + + +def export_kvu_for_all(raw_outputs: dict) -> dict: + outputs = {} + # Title + outputs["title"] = ( + raw_outputs["title"][0]["text"] if len(raw_outputs["title"]) > 0 else None + ) + + # Pairs of key-value + for pair in raw_outputs["single"]: + for key, values in pair.items(): + # outputs[key] = values["text"] + elements = split_key_value_by_colon(key, values["text"]) + outputs[elements[0]] = elements[1] + + # Only key fields + for key in raw_outputs["key"]: + # outputs[key["text"]] = None + elements = 
split_key_value_by_colon(key["text"], None) + outputs[elements[0]] = elements[1] + + # Triplet data + for triplet in raw_outputs["triplet"]: + for key, list_value in triplet.items(): + outputs[key] = [value["text"] for value in list_value] + + # Table data + table = [] + for row in raw_outputs["table"]: + item = {} + for cell in row: + item[cell["header"]] = cell["text"] + table.append(item) + outputs["table"] = table + outputs = normalize_kvu_output(outputs) + return outputs + + +def merged_kvu_for_all_for_multi_pages(loutputs: list) -> dict: + merged_outputs = {} + table = [] + for outputs in loutputs: + for key, value in outputs.items(): + if key == "table": + table.append(value) + else: + merged_outputs[key] = value + merged_outputs['table'] = table + return merged_outputs \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/manulife.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/manulife.py new file mode 100644 index 0000000..543b89b --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/manulife.py @@ -0,0 +1,133 @@ +from sdsvkvu.utils.dictionary.manulife import get_dict +from sdsvkvu.utils.post_processing import ( + split_key_value_by_colon, + remove_bullet_points_and_punctuation, + longestCommonSubsequence, + preprocessing + ) + + + +def manulife_key_matching(text: str, threshold: float, dict_type: str): + dictionary = get_dict(type=dict_type) + processed_text = preprocessing(text) + + for key, candidates in dictionary.items(): + + if any([txt == processed_text for txt in candidates]): + return True, key, 5 * (1 + len(processed_text)), processed_text + + if any([txt in processed_text for txt in candidates]): + return True, key, 5, processed_text + + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + if len(v) == 0: + continue + scores[k] = max( + [ + longestCommonSubsequence(processed_text, key) / len(key) + for key in dictionary[k] + ] + ) + key, score = max(scores.items(), key=lambda x: x[1]) + return score > threshold, key if score > threshold else text, score, processed_text + +def normalize_kvu_output_for_manulife(raw_outputs: dict) -> dict: + outputs = {} + for key, values in raw_outputs.items(): + if key == "tables" and len(values) > 0: + table_list = [] + for table in values: + headers, data = [], [] + headers = [ + remove_bullet_points_and_punctuation(header).upper() + for header in table["headers"] + ] + for row in table["data"]: + item = [] + for k, v in row.items(): + if v is not None and len(v) > 0: + item.append(remove_bullet_points_and_punctuation(v)) + else: + item.append(v) + data.append(item) + table_list.append({"headers": headers, "data": data}) + outputs[key] = table_list + else: + key = remove_bullet_points_and_punctuation(key).capitalize() + if isinstance(values, list): + values = [remove_bullet_points_and_punctuation(v) for v in values] + elif values is not None and len(values) > 0: + values = remove_bullet_points_and_punctuation(values) + outputs[key] = values + return outputs + +def export_kvu_for_manulife(raw_outputs: dict) -> dict: + outputs = {} + # Title + title_list = [] + for title in raw_outputs["title"]: + is_match, title_name, score, proceessed_text = manulife_key_matching(title["text"], threshold=0.6, dict_type="title") + title_list.append({ + 'documment_type': title_name if is_match else "", + 'content': title['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': title['id'] + }) + + if len(title_list) > 0: + selected_element = 
max(title_list, key=lambda x: x['lcs_score']) + outputs["title"] = f"({selected_element['documment_type']}) {selected_element['content']}" + else: + outputs["title"] = None + + # Pairs of key-value + for pair in raw_outputs["single"]: + for key, values in pair.items(): + # outputs[key] = values["text"] + elements = split_key_value_by_colon(key, values["text"]) + outputs[elements[0]] = elements[1] + + # Only key fields + for key in raw_outputs["key"]: + # outputs[key["text"]] = None + elements = split_key_value_by_colon(key["text"], None) + outputs[elements[0]] = elements[1] + + # Triplet data + for triplet in raw_outputs["triplet"]: + for key, list_value in triplet.items(): + outputs[key] = [value["text"] for value in list_value] + + # Table data + table = [] + header_list = {cell['header']: cell['header_bbox'] for row in raw_outputs['table'] for cell in row} + if header_list: + header_list = dict(sorted(header_list.items(), key=lambda x: int(x[1][0]))) + # print("Header_list:", header_list.keys()) + + for row in raw_outputs["table"]: + item = {header: None for header in list(header_list.keys())} + for cell in row: + item[cell["header"]] = cell["text"] + table.append(item) + outputs["tables"] = [{"headers": list(header_list.keys()), "data": table}] + else: + outputs["tables"] = [] + outputs = normalize_kvu_output_for_manulife(outputs) + return outputs + + +def merged_kvu_for_manulife_for_multi_pages(loutputs: list) -> dict: + merged_outputs = {} + table = [] + for outputs in loutputs: + for key, value in outputs.items(): + if key == "tables": + table.append(value) + else: + merged_outputs[key] = value + merged_outputs['tables'] = table + return merged_outputs \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/sbt.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/sbt.py new file mode 100644 index 0000000..bc7a5e4 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/sbt.py @@ -0,0 +1,186 @@ +from sdsvkvu.utils.post_processing import longestCommonSubsequence, preprocessing, is_string_in_range +from sdsvkvu.utils.dictionary.sbt import get_dict + +# For SBT project +def sbt_key_matching(text: str, threshold: float, dict_type: str): + dictionary = get_dict(type=dict_type) + processed_text = preprocessing(text) + + # Step 1: Exactly matching + extra_dict = get_dict("extra") + for key, candidates in dictionary.items(): + candidates = candidates + extra_dict[key] if key in extra_dict.keys() else candidates + + if any([processed_text == txt for txt in candidates]): + return key, 10, processed_text + + # Step 2: LCS score + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score >= threshold else text, score, processed_text + +def get_sbt_table_info(outputs): + table = [] + for single_item in outputs['table']: + item = {k: [] for k in get_dict("header").keys()} + for cell in single_item: + header_name, score, proceessed_text = sbt_key_matching(cell['header'], threshold=0.8, dict_type="header") + # print(f"{cell['header']} ==> {proceessed_text} ==> {header_name} : {score} - {cell['text']}") + is_header_valid = False + if header_name in list(item.keys()): + if header_name != "productname": + is_header_valid = True + elif cell['class'] == 'key': # Header with name of itemno as productname only when key + is_header_valid = True + _, _, proceessed_text 
= sbt_key_matching(cell['text'], threshold=0.8, dict_type="header") + if any([txt in proceessed_text for txt in ["originalreceipt", "homeclubvoucher", "ippuob"]]): + # print(proceessed_text) + is_header_valid = False + else: + is_header_valid = False + + if is_header_valid: + item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + table.append(item) + return table + +def get_sbt_triplet_info(outputs): + triplet_pairs = [] + for single_item in outputs['triplet']: + item = {k: [] for k in get_dict("header").keys()} + is_item_valid = 0 + for key_name, list_value in single_item.items(): + for value in list_value: + if value['header'] == "non-header": + continue + header_name, score, proceessed_text = sbt_key_matching(value['header'], threshold=0.8, dict_type="header") + if header_name in list(item.keys()): + is_item_valid = 1 + item[header_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + if is_item_valid == 1: + for header_name, value in item.items(): + if len(value) == 0: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + item['productname'] = key_name + # triplet_pairs.append({key_name: new_item}) + triplet_pairs.append(item) + + # else: ## Triplet => key as productname + # item['productname'] = key_name + # for value in list_value: + # # print(value) + # if is_string_in_range(value['text']): + # item['qty'] = value['text'] + # triplet_pairs.append(item) + return triplet_pairs + + +def get_sbt_info(outputs): + single_pairs = {k: [] for k in get_dict("key").keys()} + for pair in outputs['single']: + for key_name, value in pair.items(): + key_name, score, proceessed_text = sbt_key_matching(key_name, threshold=0.8, dict_type="key") + # print(f"{key} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'] + }) + + ### Get single_pair of serial_number if it predict as a table (Product Information) + is_product_info = False + for table_row in outputs['table']: + pair = {"key": None, 'value': None} + for cell in table_row: + _, _, proceessed_text = sbt_key_matching(cell['header'], threshold=0.8, dict_type="key") + if any(txt in proceessed_text for txt in ['product', 'information', 'productinformation']): + is_product_info = True + if cell['class'] in pair: + pair[cell['class']] = cell + + if all(v is not None for k, v in pair.items()) and is_product_info == True: + key_name, score, proceessed_text = sbt_key_matching(pair['key']['text'], threshold=0.8, dict_type="key") + # print(f"{pair['key']['text']} ==> {proceessed_text} ==> {key_name} : {score} - {pair['value']['text']}") + + if key_name in list(single_pairs): + single_pairs[key_name].append({ + 'content': pair['value']['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['value']['id'] + }) + ### block_end + + ap_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in single_pairs.items(): + if 
len(list_potential_value) == 0: continue + if key_name == "imei_number": + # print('list_potential_value', list_potential_value) + ap_outputs[key_name] = [] + for v in list_potential_value: + imei = v['content'].replace(' ', '') + if imei.isdigit() and len(imei) > 5: # imei is number and have more 5 digits + ap_outputs[key_name].append(imei) + else: + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + ap_outputs[key_name] = selected_value['content'] + + return ap_outputs + +def export_kvu_for_SDSAP(outputs): + # List of items in table + table = get_sbt_table_info(outputs) + triplet_pairs = get_sbt_triplet_info(outputs) + table = table + triplet_pairs + + ap_outputs = get_sbt_info(outputs) + + ap_outputs['table'] = table + return ap_outputs + +def merged_kvu_for_SDSAP_for_multi_pages(lvat_outputs: list): + merged_outputs = {k: [] for k in get_dict("key").keys()} + merged_outputs['table'] = [] + for outputs in lvat_outputs: + for key_name, value in outputs.items(): + if key_name == "table": + merged_outputs[key_name].extend(value) + else: + merged_outputs[key_name].append(value) + + for key, value in merged_outputs.items(): + if key == "table": + continue + if len(value) == 0: + merged_outputs[key] = None + else: + merged_outputs[key] = value[0] + + return merged_outputs \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/sbt_v2.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/sbt_v2.py new file mode 100644 index 0000000..31d3002 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/sbt_v2.py @@ -0,0 +1,320 @@ +import re +from sdsvkvu.utils.post_processing import ( + longestCommonSubsequence, + longest_common_substring, + preprocessing, + parse_date, + split_key_value_by_colon, + normalize_imei, + normalize_website, + normalize_hotline, + normalize_seller, + normalize_voucher +) +from sdsvkvu.utils.dictionary.sbt_v2 import get_dict + +def post_process_date(list_dates): + if len(list_dates) == 0: + return None + selected_value = max(list_dates, key=lambda x: x["lcs_score"]) # Get max lsc score + if not isinstance(selected_value["content"], str): + is_parser_error = True + date_formated = None + else: + date_formated, is_parser_error = parse_date(selected_value["content"]) + return date_formated + + + +def post_process_serial(list_serials): + if len(list_serials) == 0: + return None + selected_value = max( + list_serials, key=lambda x: x["lcs_score"] + ) # Get max lsc score + return selected_value["content"].strip() + + +def post_process_imei(list_imeis): + imeis = [] + for v in list_imeis: + if not isinstance(v["content"], str): + continue + imei = v["content"].replace(" ", "") + if imei.isdigit() and len(imei) > 5: # imei is number and have more 5 digits + imeis.append({ + "content": imei, + "token_id": v['token_id'] + }) + + if len(imeis) > 0: + return sorted(imeis, key=lambda x: int(x["token_id"]))[0]['content'].strip() + return None + + +def post_process_qty(inp_str: str) -> str: + pattern = r"\d" + match = re.search(pattern, inp_str) + if match: + return match.group() + return inp_str + + +def post_process_seller(list_sellers): + seller_mapping = get_dict(type="seller_mapping") + vote_list = {} + for seller in list_sellers: + seller_name = seller['content'] + if seller_name not in vote_list: + vote_list[seller_name] = 0 + + vote_list[seller_name] += seller['lcs_score'] + + if len(vote_list) > 0: + selected_value = max( + vote_list, key=lambda x: vote_list[x] + ) # Get major voting + + 
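+ # Map the majority-voted raw seller string onto a canonical retailer name using seller_mapping: an exact match after preprocessing wins immediately; otherwise any known alias found in the lower-cased string is replaced by its canonical form.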
for norm_seller, candidates in seller_mapping.items(): + if any(preprocessing(txt) == preprocessing(selected_value) for txt in candidates): + selected_value = norm_seller + break + + selected_value = selected_value.lower() + for txt in candidates: + txt = txt.lower() + if txt in selected_value: + selected_value = selected_value.replace(txt, norm_seller) + + return selected_value.strip().title() + return None + +def post_process_subsidiary(list_subsidiaries): + if len(list_subsidiaries) > 0: + selected_value = max( + list_subsidiaries, key=lambda x: x["lcs_score"] + ) # Get max lsc score + return selected_value["content"] + return None + +def sbt_key_matching(text: str, threshold: float, dict_type: str): + dictionary = get_dict(type=dict_type) + processed_text = preprocessing(text) + + scores = {k: 0.0 for k in dictionary} + # Step 1: LCS score + for k, v in dictionary.items(): + score1 = max([ + longestCommonSubsequence(processed_text, key) / + max(len(key), len(processed_text)) + for key in dictionary[k]]) + + score2 = max([ + longest_common_substring(processed_text, key) / + max(len(key), len(processed_text)) + for key in dictionary[k]]) + + scores[k] = score1 if score1 > score2 else score2 + + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score >= threshold else text, score, processed_text + + +def get_date_value(list_dates): + date_outputs = [] + for date_obj in list_dates: + if "raw_key_name" in date_obj: + date_key, date_value = split_key_value_by_colon(date_obj['raw_key_name'], date_obj['text']) + else: + date_key, date_value = split_key_value_by_colon( + date_obj['text'] if date_obj['class'] == "date_key" else None, + date_obj['text'] if date_obj['class'] == "date_value" else None + ) + # print(f"======{date_key} : {date_value}") + + if date_key is None and date_obj['class'] == "date_value": + date_value = date_obj['text'] + proceessed_text, score = "", len(date_value) if isinstance(date_value, str) else 0 + else: + key_name, score, proceessed_text = sbt_key_matching( + date_key, threshold=0.8, dict_type="date" + ) + # print(f"{date_key} ==> {proceessed_text} ==> {key_name} : {score} - {date_value}") + date_outputs.append( + { + "content": date_value, + "processed_key_name": proceessed_text, + "lcs_score": score, + "token_id": date_obj["id"], + } + ) + return date_outputs + +def get_serial_imei(list_sn): + sn_outputs = {"serial_number": [], "imei_number": []} + for sn_obj in list_sn: + if "raw_key_name" in sn_obj: + sn_key, sn_value = split_key_value_by_colon(sn_obj['raw_key_name'], sn_obj['text']) + else: + sn_key, sn_value = split_key_value_by_colon( + sn_obj['text'] if sn_obj['class'] == "sn_key" else None, + sn_obj['text'] if sn_obj['class'] == "sn_value" else None + ) + # print(f"====== {sn_key} : {sn_value}") + + if sn_key is None and sn_obj['class'] == "sn_value": + sn_value = sn_obj['text'] + key_name, proceessed_text, score = None, "", 0.8 + else: + key_name, score, proceessed_text = sbt_key_matching( + sn_key, threshold=0.8, dict_type="imei" + ) + # print(f"{sn_key} ==> {proceessed_text} ==> {key_name} : {score} - {sn_value}") + + value = { + "content": sn_value, + "processed_key_name": proceessed_text, + "lcs_score": score, + "token_id": sn_obj["id"], + } + + if key_name is None: + if normalize_imei(sn_value).isdigit(): + sn_outputs['imei_number'].append(value) + else: + sn_outputs['serial_number'].append(value) + elif key_name in ['imei_number', 'serial_number']: + sn_outputs[key_name].append(value) + return sn_outputs + + +def 
get_product_info(list_items): + table = [] + for row in list_items: + item = {} + for key, value in row.items(): + item[key] = None + if len(value) > 0: + if key == "qty": + item[key] = post_process_qty(value[0]["text"]) + else: + item[key] = value[0]["text"] + table.append(item) + return table + + + +def get_seller(outputs): # Post processing to combine seller and extra information (voucher, hotline, website) + seller_outputs = [] + voucher_info = [] + + for key_field in ["seller", "website", "hotline", "voucher", "sold_by"]: + threshold = 0.7 + func_name = f"normalize_{key_field}" + for potential_seller in outputs[key_field]: + seller_name, score, processed_text = sbt_key_matching(eval(func_name)(potential_seller["text"]), threshold=threshold, dict_type="seller") + print(f"{potential_seller['text']} ==> {processed_text} ==> {seller_name} : {score}") + + if key_field in ("voucher"): + voucher_info.append(potential_seller['text']) + + seller_outputs.append( + { + "content": seller_name, + "raw_seller_name": potential_seller["text"], + "processed_seller_name": processed_text, + "lcs_score": score, + "info": key_field + } + ) + + for voucher in voucher_info: + for i in range(len(seller_outputs)): + if voucher.lower() not in seller_outputs[i]['content'].lower(): + seller_outputs[i]['content'] = f"{voucher} {seller_outputs[i]['content']}" + + return seller_outputs + + +def get_subsidiary(list_subsidiaries): + subsidiary_outputs = [] + sold_by_info = [] + for sold_obj in list_subsidiaries: + if "raw_key_name" in sold_obj: + sold_key, sold_value = split_key_value_by_colon(sold_obj['raw_key_name'], sold_obj['text']) + else: + sold_key, sold_value = split_key_value_by_colon( + sold_obj['text'] if sold_obj['class'] == "sold_key" else None, + sold_obj['text'] if sold_obj['class'] == "sold_value" else None + ) + # print(f"======{sold_key} : {sold_value}") + + + if sold_key is None and sold_obj['class'] == "sold_value": + sold_value = sold_obj['text'] + key_name, proceessed_text, score = "unknown", "", 0.8 + else: + key_name, score, proceessed_text = sbt_key_matching( + sold_key, threshold=0.8, dict_type="sold_by" + ) + # print(f"{sold_key} ==> {proceessed_text} ==> {key_name} : {score} - {sold_value}") + + if key_name == "sold_by": + sold_by_info.append(sold_obj) + else: + subsidiary_outputs.append( + { + "content": sold_value, + "processed_key_name": proceessed_text, + "lcs_score": score, + "token_id": sold_obj["id"], + } + ) + return subsidiary_outputs, sold_by_info + + +def export_kvu_for_SBT(outputs): + # sold to party + list_subsidiaries, sold_by_info = get_subsidiary(outputs['sold_value']) + # seller + outputs['sold_by'] = sold_by_info + list_sellers = get_seller(outputs) + # date + list_dates = get_date_value(outputs["date_value"]) + # serial_number or imei + list_serial_imei = get_serial_imei(outputs["serial_imei"]) + + serial_number = post_process_serial(list_serial_imei["serial_number"]) + imei_number = post_process_imei(list_serial_imei["imei_number"]) + # table + # list_items = get_product_info(outputs["table"]) + + ap_outputs = {} + ap_outputs["retailername"] = post_process_seller(list_sellers) + ap_outputs["sold_to_party"] = post_process_subsidiary(list_subsidiaries) + ap_outputs["purchase_date"] = post_process_date(list_dates) + ap_outputs["imei_number"] = imei_number if imei_number is not None else serial_number + # ap_outputs["table"] = list_items + + return ap_outputs + + +def merged_kvu_for_SBT_for_multi_pages(lvat_outputs: list): + merged_outputs = {k: [] for k in 
get_dict("key").keys()} + merged_outputs['table'] = [] + for outputs in lvat_outputs: + for key_name, value in outputs.items(): + if key_name == "table": + merged_outputs[key_name].extend(value) + else: + merged_outputs[key_name].append(value) + + for key, value in merged_outputs.items(): + if key == "table": + continue + if len(value) == 0: + merged_outputs[key] = None + else: + merged_outputs[key] = value[0] + + return merged_outputs \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/vat.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/vat.py new file mode 100644 index 0000000..424ec35 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/vat.py @@ -0,0 +1,237 @@ +import re +from sdsvkvu.utils.post_processing import longestCommonSubsequence, preprocessing +from sdsvkvu.utils.dictionary.vat import get_dict + + +# For FI-VAT project +def vat_key_replacing(vat_outputs: dict) -> dict: + outputs = {} + DKVU2XML = get_dict("kvu2xml") + for key, value in vat_outputs.items(): + if key != "table": + outputs[DKVU2XML[key]] = value + else: + list_items = [] + for item in value: + list_items.append({ + DKVU2XML[item_key]: item_value for item_key, item_value in item.items() + }) + outputs['table'] = list_items + return outputs + + +def vat_key_matching(text: str, threshold: float, dict_type: str): + dictionary = get_dict(dict_type) + processed_text = preprocessing(text) + + # Step 1: Exactly matching + date_dict = get_dict("date") + for time_key, candidates in date_dict.items(): + if any([processed_text == txt for txt in candidates]): + return "Ngày, tháng, năm lập hóa đơn", 5, time_key + + extra_dict = get_dict("extra") + for key, candidates in dictionary.items(): + candidates = candidates + extra_dict[key] if key in extra_dict.keys() else candidates + + if key == 'Tên người bán' and processed_text == "kyboi": + return key, 8, processed_text + + if any([processed_text == txt for txt in candidates]): + return key, 10, processed_text + + # Step 2: LCS score + scores = {k: 0.0 for k in dictionary} + for k, v in dictionary.items(): + if k in ("Ngày, tháng, năm lập hóa đơn"): continue + scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score > threshold else text, score, processed_text + + +def normalize_number_format(s: str) -> float: + s = s.replace(' ', '').replace('O', '0').replace('o', '0') + if s.endswith(",00") or s.endswith(".00"): + s = s[:-3] + if all([delimiter in s for delimiter in [',', '.']]): + s = s.replace('.', '').split(',') + remain_value = s[1].split('0')[0] + return int(s[0]) + int(remain_value) * 1 / (10**len(remain_value)) + else: + s = s.replace(',', '').replace('.', '') + return int(s) + + +def post_process_item(item: dict) -> dict: + check_keys = ['Số lượng', 'Đơn giá', 'Doanh số mua chưa có thuế'] + mis_key = [] + for key in check_keys: + if item[key] in (None, '0'): + mis_key.append(key) + if len(mis_key) == 1: + try: + if mis_key[0] == check_keys[0] and normalize_number_format(item[check_keys[1]]) != 0: + item[mis_key[0]] = round(normalize_number_format(item[check_keys[2]]) / normalize_number_format(item[check_keys[1]])).__str__() + elif mis_key[0] == check_keys[1] and normalize_number_format(item[check_keys[0]]) != 0: + item[mis_key[0]] = (normalize_number_format(item[check_keys[2]]) / normalize_number_format(item[check_keys[0]])).__str__() + elif mis_key[0] == check_keys[2]: + item[mis_key[0]] = 
(normalize_number_format(item[check_keys[0]]) * normalize_number_format(item[check_keys[1]])).__str__() + except Exception as e: + print("Cannot post process this item with error:", e) + return item + + +def get_vat_table_info(outputs): + table = [] + for single_item in outputs['table']: + item = {k: [] for k in get_dict("header").keys()} + for cell in single_item: + header_name, score, proceessed_text = vat_key_matching(cell['header'], threshold=0.75, dict_type="header") + if header_name in list(item.keys()): + # item[header_name] = value['text'] + item[header_name].append({ + 'content': cell['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': cell['id'] + }) + + for header_name, value in item.items(): + if len(value) == 0: + if header_name in ("Số lượng", "Doanh số mua chưa có thuế"): + item[header_name] = '0' + else: + item[header_name] = None + continue + item[header_name] = max(value, key=lambda x: x['lcs_score'])['content'] # Get max lsc score + + item = post_process_item(item) + + if item["Mặt hàng"] == None: + continue + table.append(item) + return table + +def get_vat_info(outputs): + # VAT Information + single_pairs = {k: [] for k in get_dict("key").keys()} + for pair in outputs['single']: + for raw_key_name, value in pair.items(): + key_name, score, proceessed_text = vat_key_matching(raw_key_name, threshold=0.8, dict_type="key") + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value['id'], + }) + + for triplet in outputs['triplet']: + for key, value_list in triplet.items(): + if len(value_list) == 1: + key_name, score, proceessed_text = vat_key_matching(key, threshold=0.8, dict_type="key") + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value_list[0]['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': value_list[0]['id'] + }) + + for pair in value_list: + key_name, score, proceessed_text = vat_key_matching(pair['header'], threshold=0.8, dict_type="key") + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': pair['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'] + }) + + for table_row in outputs['table']: + for pair in table_row: + key_name, score, proceessed_text = vat_key_matching(pair['header'], threshold=0.8, dict_type="key") + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': pair['text'], + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'] + }) + + return single_pairs + + +def post_process_tax_code(tax_code_raw: str): + if len(tax_code_raw.replace(' ', '')) not in (10, 13): # to remove the first/last number dupicated + tax_code_raw = tax_code_raw.split(' ') + tax_code_raw = sorted(tax_code_raw, key=lambda x: len(x), reverse=True)[0] + return tax_code_raw.replace(' ', '') + + +def merge_vat_info(single_pairs): + vat_outputs = {k: None for k in list(single_pairs)} + for key_name, list_potential_value in 
single_pairs.items(): + if key_name in ("Ngày, tháng, năm lập hóa đơn"): + if len(list_potential_value) == 1: + vat_outputs[key_name] = list_potential_value[0]['content'] + else: + date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'} + for value in list_potential_value: + date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content']) + vat_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}" + else: + if len(list_potential_value) == 0: continue + if key_name in ("Mã số thuế người bán"): + selected_value = min(list_potential_value, key=lambda x: x['token_id']) # Get first tax code + vat_outputs[key_name] = post_process_tax_code(selected_value['content']) + + else: + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc score + vat_outputs[key_name] = selected_value['content'] + return vat_outputs + + +def export_kvu_for_VAT_invoice(outputs): + vat_outputs = {} + # List of items in table + table = get_vat_table_info(outputs) + # VAT Information + single_pairs = get_vat_info(outputs) + vat_outputs = merge_vat_info(single_pairs) + # Combine VAT information and table + vat_outputs['table'] = table + return vat_outputs + + +def merged_kvu_for_VAT_invoice_for_multi_pages(lvat_outputs: list): + merged_outputs = {k: [] for k in get_dict("key").keys()} + merged_outputs['table'] = [] + for outputs in lvat_outputs: + for key_name, value in outputs.items(): + if key_name == "table": + merged_outputs[key_name].extend(value) + else: + if value == None or value == "dd/mm/yyyy": + # print(key_name, value) + continue + merged_outputs[key_name].append(value) + + for key, value in merged_outputs.items(): + if key == "table": + continue + if len(value) == 0: + merged_outputs[key] = None + else: + merged_outputs[key] = value[0] + + return merged_outputs + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/vtb.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/vtb.py new file mode 100644 index 0000000..14694af --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/query/vtb.py @@ -0,0 +1,153 @@ +import re +from sdsvkvu.utils.post_processing import preprocessing, date_regexing, remove_bullet_points_and_punctuation +from sdsvkvu.utils.dictionary.vtb import get_dict + +# For Vietin Bank project +def vietin_key_matching(text: str, threshold: float, dict_type: str): + dictionary = get_dict(type=dict_type) + processed_text = preprocessing(text) + + # Step 1: Exactly matching + date_dict = get_dict("date") + for time_key, candidates in date_dict.items(): + if any([txt in processed_text for txt in candidates]): + return "date", 5, time_key + + extra_dict = get_dict("extra") + for key, candidates in dictionary.items(): + candidates = candidates + extra_dict[key] if key in extra_dict.keys() else candidates + + if processed_text[-4:] == "dien": # EX: Bộ trưởng Bộ GTVT điện: A, B, C + return "sender", 15, processed_text + + if any([txt in processed_text for txt in candidates]): + return key, 10, processed_text + + # Step 2: LCS score + scores = {k: 0.0 for k in dictionary} + ## Disable temporarily + # for k, v in dictionary.items(): + # if k in ("date", "title", "number", 'signee', 'sender', 'receiver'): continue + # scores[k] = max([longestCommonSubsequence(processed_text, key)/len(key) for key in dictionary[k]]) + key, score = max(scores.items(), key=lambda x: x[1]) + return key if score > threshold else text, score, processed_text + + +def get_vietin_info(outputs): + # Vietin Information + single_pairs = {k: [] 
for k in get_dict(type="key").keys()} + for pair in outputs['single']: + for raw_key_name, value in pair.items(): + key_name, score, proceessed_text = vietin_key_matching(raw_key_name, threshold=0.8, dict_type="key") + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value['text']}") + + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': value['text'], + 'processed_key_name': proceessed_text, + 'raw_key_name': raw_key_name, + 'lcs_score': score, + 'token_id': value['id'], + 'single_entity': False + }) + + + for single_item in outputs['key'] + outputs['value']: + key_name, score, proceessed_text = vietin_key_matching(single_item['text'], threshold=0.8, dict_type="key") + # print(f"{single_item['text']} ==> {proceessed_text} ==> {key_name} : {score} - {single_item['text']}") + + # if key_name not in ('number', 'date'): continue + if key_name in list(single_pairs.keys()): + single_pairs[key_name].append({ + 'content': single_item['text'], + 'processed_key_name': proceessed_text, + 'raw_key_name': single_item['text'], + 'lcs_score': score, + 'token_id': single_item['id'], + 'single_entity': True + }) + + + # Sender and receiver are usually in triplet + for triplet in outputs['triplet']: + for raw_key_name, value_list in triplet.items(): + key_name, score, proceessed_text = vietin_key_matching(raw_key_name, threshold=0.8, dict_type="key") + # print(f"{raw_key_name} ==> {proceessed_text} ==> {key_name} : {score} - {value_list[0]['text']}") + + if key_name in list(single_pairs.keys()): + for pair in value_list: + single_pairs[key_name].append({ + 'content': pair['text'], + 'raw_key_name': raw_key_name, + 'processed_key_name': proceessed_text, + 'lcs_score': score, + 'token_id': pair['id'], + 'single_entity': False + }) + return single_pairs + +def post_process_vietin_info(single_pairs): + vietin_outputs = {k: None for k in get_dict(type="key").keys()} + for key_name, list_potential_value in single_pairs.items(): + if key_name in ("date"): + if len(list_potential_value) == 1: + check_string = list_potential_value[0]['content'].replace(" ", "") + if check_string.replace('/', '').isdigit(): + vietin_outputs[key_name] = check_string + else: + # date_time = {'day': 'dd', 'month': 'mm', 'year': 'yyyy'} + # if len(list_potential_value) == 3: + # for value in list_potential_value: + # date_time[value['processed_key_name']] = re.sub("[^0-9]", "", value['content']) + # vietin_outputs[key_name] = f"{date_time['day']}/{date_time['month']}/{date_time['year']}" + # else: + list_potential_value = sorted(list_potential_value, key=lambda x: x['token_id'], reverse=False) + full_string = ' '.join([v['raw_key_name'] + v['content'] for v in list_potential_value]) + d, m, y = date_regexing(full_string) + vietin_outputs[key_name] = f"{d}/{m}/{y}" + # print(full_string) + # print(d, m, y) + elif key_name in ("receiver", "sender"): + list_potential_value = sorted(list_potential_value, key=lambda x: x['token_id'], reverse=False) + vietin_outputs[key_name] = [remove_bullet_points_and_punctuation(value['content']) for value in list_potential_value] + elif key_name in ("signee"): + list_potential_value = sorted(list_potential_value, key=lambda x: x['token_id'], reverse=False) + vietin_outputs[key_name] = [f"{value['content']} - {value['raw_key_name']}" for value in list_potential_value if value['single_entity'] == False] + else: + if len(list_potential_value) == 0: continue + selected_value = max(list_potential_value, key=lambda x: x['lcs_score']) # Get max lsc 
score + vietin_outputs[key_name] = selected_value['content'] + if key_name in ("number"): + number = re.sub("[^0-9]", "", selected_value['raw_key_name']) + start_idx = selected_value['content'].find(number) + if start_idx != -1: + vietin_outputs[key_name] = selected_value['content'].replace(" ", "")[start_idx:] + else: + vietin_outputs[key_name] = number + selected_value['content'].replace(" ", "") + + return vietin_outputs + +def export_kvu_for_vietin(outputs): + single_pairs = get_vietin_info(outputs) + vietin_outputs = post_process_vietin_info(single_pairs) + vietin_outputs['title'] = [title['text'] for title in outputs["title"]] + return vietin_outputs + +def merged_kvu_for_vietin_for_multi_pages(lvietin_outputs: list): + merged_outputs = {k: [] for k in get_dict("key").keys()} + for outputs in lvietin_outputs: + for key_name, value in outputs.items(): + if value == None or value == "dd/mm/yyyy": + # print(key_name, value) + continue + merged_outputs[key_name].append(value) + + for key, value in merged_outputs.items(): + if len(value) == 0: + merged_outputs[key] = None + elif key == "receiver": + merged_outputs[key] = value[-1] + elif key in ("number", "title", "date", "signee", "sender"): + merged_outputs[key] = value[0] + + return merged_outputs \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/utils.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/utils.py new file mode 100644 index 0000000..5019adb --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/utils.py @@ -0,0 +1,129 @@ +import os +import json +import glob +from tqdm import tqdm +from pdf2image import convert_from_path +from dicttoxml import dicttoxml + + +def create_dir(save_dir=''): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + # else: + # print("DIR already existed.") + # print('Save dir : {}'.format(save_dir)) + +def convert_pdf2img(pdf_dir, save_dir): + pdf_files = glob.glob(f'{pdf_dir}/*.pdf') + print('No. 
pdf files:', len(pdf_files)) + print(pdf_files) + + for file in tqdm(pdf_files): + pdf2img(file, save_dir, n_pages=-1, return_fname=False) + # pages = convert_from_path(file, 500) + # for i, page in enumerate(pages): + # page.save(os.path.join(save_dir, os.path.basename(file).replace('.pdf', f'_{i}.jpg')), 'JPEG') + print('Done!!!') + +def pdf2img(pdf_path, save_dir, n_pages=-1, return_fname=False): + file_names = [] + pages = convert_from_path(pdf_path) + if n_pages != -1: + pages = pages[:n_pages] + for i, page in enumerate(pages): + _save_path = os.path.join(save_dir, os.path.basename(pdf_path).replace('.pdf', f'_{i}.jpg')) + page.save(_save_path, 'JPEG') + file_names.append(_save_path) + if return_fname: + return file_names + +def xyxy2xywh(bbox): + return [ + float(bbox[0]), + float(bbox[1]), + float(bbox[2]) - float(bbox[0]), + float(bbox[3]) - float(bbox[1]), + ] + +def write_to_json(file_path, content): + with open(file_path, mode='w', encoding='utf8') as f: + json.dump(content, f, ensure_ascii=False) + + +def read_json(file_path): + with open(file_path, 'r') as f: + return json.load(f) + +def read_xml(file_path): + with open(file_path, 'r') as xml_file: + return xml_file.read() + +def write_to_xml(file_path, content): + with open(file_path, mode="w", encoding='utf8') as f: + f.write(content) + +def write_to_xml_from_dict(file_path, content): + xml = dicttoxml(content) + xml = content + xml_decode = xml.decode() + + with open(file_path, mode="w") as f: + f.write(xml_decode) + +def read_txt(ocr_path): + with open(ocr_path, "r") as f: + lines = f.read().splitlines() + return lines + +def load_ocr_result(ocr_path): + with open(ocr_path, 'r') as f: + lines = f.read().splitlines() + + preds = [] + for line in lines: + preds.append(line.split('\t')) + return preds + +def post_process_basic_ocr(lwords: list) -> list: + pp_lwords = [] + for word in lwords: + pp_lwords.append(word.replace("✪", " ")) + return pp_lwords + +def read_ocr_result_from_txt(file_path: str): + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + + boxes, words = [], [] + for line in lines: + if line == "": + continue + word_info = line.split("\t") + if len(word_info) == 6: + x1, y1, x2, y2, text, _ = word_info + elif len(word_info) == 5: + x1, y1, x2, y2, text = word_info + + x1, y1, x2, y2 = int(float(x1)), int(float(y1)), int(float(x2)), int(float(y2)) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + + + + + + + + + + + + + + + diff --git a/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/word2line.py b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/word2line.py new file mode 100644 index 0000000..d8380ef --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/utils/word2line.py @@ -0,0 +1,226 @@ +class Word(): + def __init__(self, text="",image=None, conf_detect=0.0, conf_cls=0.0, bndbox = [-1,-1,-1,-1], kie_label =""): + self.type = "word" + self.text =text + self.image = image + self.conf_detect = conf_detect + self.conf_cls = conf_cls + self.boundingbox = bndbox # [left, top,right,bot] coordinate of top-left and bottom-right point + self.word_id = 0 # id of word + self.word_group_id = 0 # id of word_group which instance belongs to + self.line_id = 0 #id of line which instance belongs to + self.paragraph_id = 0 #id of line which instance belongs to + self.kie_label = kie_label + def invalid_size(self): + return (self.boundingbox[2] - self.boundingbox[0]) * (self.boundingbox[3] - 
self.boundingbox[1]) > 0 + def is_special_word(self): + left, top, right, bottom = self.boundingbox + width, height = right - left, bottom - top + text = self.text + + if text is None: + return True + + # if len(text) > 7: + # return True + if len(text) >= 7: + no_digits = sum(c.isdigit() for c in text) + return no_digits / len(text) >= 0.3 + + return False + +class Word_group(): + def __init__(self): + self.type = "word_group" + self.list_words = [] # dict of word instances + self.word_group_id = 0 # word group id + self.line_id = 0 #id of line which instance belongs to + self.paragraph_id = 0# id of paragraph which instance belongs to + self.text ="" + self.boundingbox = [-1,-1,-1,-1] + self.kie_label ="" + def add_word(self, word:Word): #add a word instance to the word_group + if word.text != "✪": + for w in self.list_words: + if word.word_id == w.word_id: + print("Word id collision") + return False + word.word_group_id = self.word_group_id # + word.line_id = self.line_id + word.paragraph_id = self.paragraph_id + self.list_words.append(word) + self.text += ' '+ word.text + if self.boundingbox == [-1,-1,-1,-1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [min(self.boundingbox[0], word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3])] + return True + else: + return False + + def update_word_group_id(self, new_word_group_id): + self.word_group_id = new_word_group_id + for i in range(len(self.list_words)): + self.list_words[i].word_group_id = new_word_group_id + + def update_kie_label(self): + list_kie_label = [word.kie_label for word in self.list_words] + dict_kie = dict() + for label in list_kie_label: + if label not in dict_kie: + dict_kie[label]=1 + else: + dict_kie[label]+=1 + total = len(list(dict_kie.values())) + max_value = max(list(dict_kie.values())) + list_keys = list(dict_kie.keys()) + list_values = list(dict_kie.values()) + self.kie_label = list_keys[list_values.index(max_value)] + +class Line(): + def __init__(self): + self.type = "line" + self.list_word_groups = [] # list of Word_group instances in the line + self.line_id = 0 #id of line in the paragraph + self.paragraph_id = 0 # id of paragraph which instance belongs to + self.text = "" + self.boundingbox = [-1,-1,-1,-1] + def add_group(self, word_group:Word_group): # add a word_group instance + if word_group.list_words is not None: + for wg in self.list_word_groups: + if word_group.word_group_id == wg.word_group_id: + print("Word_group id collision") + return False + + self.list_word_groups.append(word_group) + self.text += word_group.text + word_group.paragraph_id = self.paragraph_id + word_group.line_id = self.line_id + + for i in range(len(word_group.list_words)): + word_group.list_words[i].paragraph_id = self.paragraph_id #set paragraph_id for word + word_group.list_words[i].line_id = self.line_id #set line_id for word + return True + return False + def update_line_id(self, new_line_id): + self.line_id = new_line_id + for i in range(len(self.list_word_groups)): + self.list_word_groups[i].line_id = new_line_id + for j in range(len(self.list_word_groups[i].list_words)): + self.list_word_groups[i].list_words[j].line_id = new_line_id + + + def merge_word(self, word): # word can be a Word instance or a Word_group instance + if word.text != "✪": + if self.boundingbox == [-1,-1,-1,-1]: + self.boundingbox = word.boundingbox + else: + self.boundingbox = [min(self.boundingbox[0], 
word.boundingbox[0]), + min(self.boundingbox[1], word.boundingbox[1]), + max(self.boundingbox[2], word.boundingbox[2]), + max(self.boundingbox[3], word.boundingbox[3])] + self.list_word_groups.append(word) + self.text += ' ' + word.text + return True + return False + + + def in_same_line(self, input_line, thresh=0.7): + # calculate iou in vertical direction + left1, top1, right1, bottom1 = self.boundingbox + left2, top2, right2, bottom2 = input_line.boundingbox + + sorted_vals = sorted([top1, bottom1, top2, bottom2]) + intersection = sorted_vals[2] - sorted_vals[1] + union = sorted_vals[3]-sorted_vals[0] + min_height = min(bottom1-top1, bottom2-top2) + if min_height==0: + return False + ratio = intersection / min_height + height1, height2 = top1-bottom1, top2-bottom2 + ratio_height = float(max(height1, height2))/float(min(height1, height2)) + # height_diff = (float(top1-bottom1))/(float(top2-bottom2)) + + + if (top1 in range(top2, bottom2) or top2 in range(top1, bottom1)) and ratio >= thresh and (ratio_height<2): + return True + return False + +def check_iomin(word:Word, word_group:Word_group): + min_height = min(word.boundingbox[3]-word.boundingbox[1],word_group.boundingbox[3]-word_group.boundingbox[1]) + intersect = min(word.boundingbox[3],word_group.boundingbox[3]) - max(word.boundingbox[1],word_group.boundingbox[1]) + if intersect/min_height > 0.7: + return True + return False + +def words_to_lines(words, check_special_lines=True): #words is list of Word instance + #sort word by top + words.sort(key = lambda x: (x.boundingbox[1], x.boundingbox[0])) + number_of_word = len(words) + #sort list words to list lines, which have not contained word_group yet + lines = [] + for i, word in enumerate(words): + if word.invalid_size()==0: + continue + new_line = True + for i in range(len(lines)): + if lines[i].in_same_line(word): #check if word is in the same line with lines[i] + lines[i].merge_word(word) + new_line = False + + if new_line ==True: + new_line = Line() + new_line.merge_word(word) + lines.append(new_line) + + # print(len(lines)) + #sort line from top to bottom according top coordinate + lines.sort(key = lambda x: x.boundingbox[1]) + + #construct word_groups in each line + line_id = 0 + word_group_id =0 + word_id = 0 + for i in range(len(lines)): + if len(lines[i].list_word_groups)==0: + continue + #left, top ,right, bottom + line_width = lines[i].boundingbox[2] - lines[i].boundingbox[0] # right - left + # print("line_width",line_width) + lines[i].list_word_groups.sort(key = lambda x: x.boundingbox[0]) #sort word in lines from left to right + + #update text for lines after sorting + lines[i].text ="" + for word in lines[i].list_word_groups: + lines[i].text += " "+word.text + + list_word_groups=[] + inital_word_group = Word_group() + inital_word_group.word_group_id= word_group_id + word_group_id +=1 + lines[i].list_word_groups[0].word_id=word_id + inital_word_group.add_word(lines[i].list_word_groups[0]) + word_id+=1 + list_word_groups.append(inital_word_group) + for word in lines[i].list_word_groups[1:]: #iterate through each word object in list_word_groups (has not been construted to word_group yet) + check_word_group= True + #set id for each word in each line + word.word_id = word_id + word_id+=1 + if (not list_word_groups[-1].text.endswith(':')) and ((word.boundingbox[0]-list_word_groups[-1].boundingbox[2])/line_width <0.05) and check_iomin(word, list_word_groups[-1]): + list_word_groups[-1].add_word(word) + check_word_group=False + if check_word_group ==True: + new_word_group = 
Word_group() + new_word_group.word_group_id= word_group_id + word_group_id +=1 + new_word_group.add_word(word) + list_word_groups.append(new_word_group) + lines[i].list_word_groups = list_word_groups + # set id for lines from top to bottom + lines[i].update_line_id(line_id) + line_id +=1 + return lines, number_of_word \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/setup.cfg b/cope2n-ai-fi/modules/_sdsvkvu/setup.cfg new file mode 100644 index 0000000..ce66352 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/setup.cfg @@ -0,0 +1,49 @@ +[tool:pytest] +norecursedirs = + .git + dist + build +addopts = + --strict + --doctest-modules + --durations=0 + +[coverage:report] +exclude_lines = + pragma: no-cover + pass + +[flake8] +max-line-length = 120 +exclude = .tox,*.egg,build,temp +select = E,W,F +doctests = True +verbose = 2 +# https://pep8.readthedocs.io/en/latest/intro.html#error-codes +format = pylint +# see: https://www.flake8rules.com/ +ignore = + E731 # Do not assign a lambda expression, use a def + W504 # Line break occurred after a binary operator + F401 # Module imported but unused + F841 # Local variable name is assigned to but never used + W605 # Invalid escape sequence 'x' + +# setup.cfg or tox.ini +[check-manifest] +ignore = + *.yml + .github + .github/* + +[metadata] +license_file = LICENSE +description-file = README.md +author = tuanlv +author_email = lv.tuan3@samsung.com +# long_description = file:README.md +# long_description_content_type = text/markdown +[options] +packages = find: +python_requires = >=3.9 +include_package_data = True \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/setup.py b/cope2n-ai-fi/modules/_sdsvkvu/setup.py new file mode 100644 index 0000000..8fde133 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/setup.py @@ -0,0 +1,181 @@ +import os +import os.path as osp +import shutil +import sys +import warnings +from setuptools import find_packages, setup + + +version_file = "sdsvkvu/utils/version.py" +is_windows = sys.platform == 'win32' + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + +def add_mim_extention(): + """Add extra files that are required to support MIM into the package. + + These files will be added by creating a symlink to the originals if the + package is installed in `editable` mode (e.g. pip install -e .), or by + copying from the originals otherwise. + """ + + # parse installment mode + if 'develop' in sys.argv: + # installed by `pip install -e .` + mode = 'symlink' + elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + # installed by `pip install .` + # or create source distribution by `python setup.py sdist` + mode = 'copy' + else: + return + + filenames = ['tools', 'configs', 'model-index.yml'] + repo_path = osp.dirname(__file__) + mim_path = osp.join(repo_path, 'mmocr', '.mim') + os.makedirs(mim_path, exist_ok=True) + + for filename in filenames: + if osp.exists(filename): + src_path = osp.join(repo_path, filename) + tar_path = osp.join(mim_path, filename) + + if osp.isfile(tar_path) or osp.islink(tar_path): + os.remove(tar_path) + elif osp.isdir(tar_path): + shutil.rmtree(tar_path) + + if mode == 'symlink': + src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) + try: + os.symlink(src_relpath, tar_path) + except OSError: + # Creating a symbolic link on windows may raise an + # `OSError: [WinError 1314]` due to privilege. 
If + # the error happens, the src file will be copied + mode = 'copy' + warnings.warn( + f'Failed to create a symbolic link for {src_relpath}, ' + f'and it will be copied to {tar_path}') + else: + continue + + if mode == 'copy': + if osp.isfile(src_path): + shutil.copyfile(src_path, tar_path) + elif osp.isdir(src_path): + shutil.copytree(src_path, tar_path) + else: + warnings.warn(f'Cannot copy file {src_path}.') + else: + raise ValueError(f'Invalid mode {mode}') + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + import sys + + # return short version for sdist + if 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + return locals()['short_version'] + else: + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """Parse the package dependencies listed in a requirements file but strip + specific version information. + + Args: + fname (str): Path to requirements file. + with_version (bool, default=False): If True, include version specs. + Returns: + info (list[str]): List of requirements items. + CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import re + import sys + from os.path import exists + require_fpath = fname + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + for info in parse_line(line): + yield info + + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) + return packages + +if __name__ == '__main__': + setup( + name='sdsvkvu', + # version=get_version(), + version="0.0.1", + description='SDSV OCR Team: Key-value understanding', + long_description=readme(), + long_description_content_type='text/markdown', + packages=find_packages(), # exclude=('configs', 'tools', 'demo') + include_package_data=True, + url='https://github.com/open-mmlab/mmocr', + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3.9', + ], + license='Apache License 2.0', + 
install_requires=parse_requirements('requirements.txt'), + zip_safe=False) \ No newline at end of file diff --git a/cope2n-ai-fi/modules/_sdsvkvu/test.py b/cope2n-ai-fi/modules/_sdsvkvu/test.py new file mode 100644 index 0000000..5be2845 --- /dev/null +++ b/cope2n-ai-fi/modules/_sdsvkvu/test.py @@ -0,0 +1,18 @@ +import os +from sdsvkvu import load_engine, process_img, process_pdf, process_dir +from sdsvkvu.modules.run_ocr import load_ocr_engine +os.environ["CUDA_VISIBLE_DEVICES"]="1" +# os.environ["NLTK_DATA"]="/mnt/ssd1T/tuanlv/02-KVU/sdsvkvu/nltk_data" + +if __name__ == "__main__": + # ocr_engine = load_ocr_engine({"device": "cuda:0"}) + kwargs = {"device": "cuda:0", "ocr_engine": None} + img_dir = "/mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu/visualize/test_manulife" + save_dir = "/mnt/hdd4T/OCR/tuanlv/02-KVU/sdsvkvu/visualize/test_manulife" + engine = load_engine(kwargs) + # option: "vat" for vat invoice outputs, "sbt": sbt invoice outputs, else for raw outputs + # outputs = process_img(img_dir, save_dir, engine, export_all=False, option="vat") + # outputs = process_pdf(img_dir, save_dir, engine, export_all=True, option="vat") + process_dir(img_dir, save_dir, engine, export_all=True, option="manulife") + # process_dir(img_dir, save_dir, engine, export_all=True, option="") + diff --git a/cope2n-ai-fi/modules/ocr_engine/.gitignore b/cope2n-ai-fi/modules/ocr_engine/.gitignore new file mode 100644 index 0000000..1250e95 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +visualize/ +results/ +*.jpeg +*.jpg +*.png diff --git a/cope2n-ai-fi/modules/ocr_engine/README.md b/cope2n-ai-fi/modules/ocr_engine/README.md new file mode 100644 index 0000000..ca1b349 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/README.md @@ -0,0 +1,47 @@ +# OCR Engine + +OCR Engine is a Python package that combines text detection and recognition models from [mmdet](https://github.com/open-mmlab/mmdetection) and [mmocr](https://github.com/open-mmlab/mmocr) to perform Optical Character Recognition (OCR) on various inputs. The package currently supports three types of input: a single image, a recursive directory, or a csv file. + +## Installation + +To install OCR Engine, clone the repository and install the required packages: + +```bash +git clone git@github.com:mrlasdt/ocr-engine.git +cd ocr-engine +pip install -r requirements.txt + +``` + + +## Usage + +To use OCR Engine, simply run the `ocr_engine.py` script with the desired input type and input path. For example, to perform OCR on a single image: + +```css +python ocr_engine.py --input_type image --input_path /path/to/image.jpg +``` + +To perform OCR on a recursive directory: + +```css +python ocr_engine.py --input_type directory --input_path /path/to/directory/ + +``` + +To perform OCR on a csv file: + + +``` +python ocr_engine.py --input_type csv --input_path /path/to/file.csv +``` + +OCR Engine will automatically detect and recognize text in the input and output the results in a CSV file named `ocr_results.csv`. + +## Contributing + +If you would like to contribute to OCR Engine, please fork the repository and submit a pull request. We welcome contributions of all types, including bug fixes, new features, and documentation improvements. + +## License + +OCR Engine is released under the [MIT License](https://opensource.org/licenses/MIT). See the LICENSE file for more information. 
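The README above documents only the CLI entry point. For completeness, here is a minimal programmatic sketch of the same engine, based on the `run.py` and `__init__.py` files added later in this diff; it is not part of the committed code. Assumptions: the package is importable as `ocr_engine`, `OcrEngine` accepts the keyword arguments normally passed through `--ocr_kwargs` (e.g. `device`), and the input/output paths are hypothetical.

```python
# Minimal programmatic sketch (editorial, not part of the diff): mirrors run.py.
# Assumptions: the module is importable as `ocr_engine`; OcrEngine takes the same
# kwargs passed via --ocr_kwargs (e.g. device); file paths below are hypothetical.
from ocr_engine import OcrEngine, ImageReader

engine = OcrEngine(device="cuda:0")           # kwargs forwarded as in run.py's load_engine()
img = ImageReader.read("/path/to/image.jpg")  # for PDFs this returns a list of pages (see run.py)
page = engine(img)                            # detection + recognition -> Page object
page.write_to_file("word", "/path/to/image.txt")  # word-level results, as run.py writes them
```

As in `run.py`, the resulting `Page` can also be visualized with `save_img(...)` when `--export_img` is set.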
diff --git a/cope2n-ai-fi/modules/ocr_engine/TODO.todo b/cope2n-ai-fi/modules/ocr_engine/TODO.todo new file mode 100644 index 0000000..34df095 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/TODO.todo @@ -0,0 +1,10 @@ +☐ refactor argument parser of run.py +☐ add timer level, logging level and write_mode to argumments +☐ add paddleocr deskew to the code +☐ fix the deskew code to resize the image only for detecting the angle, we want to feed the original size image to the text detection pipeline so that the bounding boxes would be mapped back to the original size +☐ ocr engine import took too long +☐ add word level to write_mode +☐ add word group and line +change max_x_dist from pixel to percentage of box width +☐ visualization: adjust fontsize dynamically + diff --git a/cope2n-ai-fi/modules/ocr_engine/__init__.py b/cope2n-ai-fi/modules/ocr_engine/__init__.py new file mode 100644 index 0000000..433d5ff --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/__init__.py @@ -0,0 +1,19 @@ +# # Define package-level variables +# __version__ = '0.0' + +import os +import sys +from pathlib import Path +cur_dir = str(Path(__file__).parents[0]) +sys.path.append(cur_dir) +sys.path.append(os.path.join(cur_dir, "externals")) + +# Import modules +from .run import load_engine +from .src.ocr import OcrEngine +# from .src.word_formation import words_to_lines +from .src.word_formation import words_to_lines_tesseract as words_to_lines +from .src.utils import ImageReader, read_ocr_result_from_txt +from .src.dto import Word, Line, Page, Document, Box +# Expose package contents +__all__ = ["OcrEngine", "Box", "Word", "Line", "Page", "Document", "words_to_lines", "ImageReader", "read_ocr_result_from_txt"] diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/.gitignore b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/.gitignore new file mode 100644 index 0000000..e56dbb2 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/.gitignore @@ -0,0 +1,9 @@ +output* +*.pyc +*.jpg +check +weights/ +workdirs/ +__pycache__* +test_hungbnt.py +libs* \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/README.md b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/README.md new file mode 100644 index 0000000..d1b9637 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/README.md @@ -0,0 +1,29 @@ +
+Dewarp
+ +***Feature*** +- Align document + + +## I. Setup +***Dependencies*** +- Python: 3.8 +- Torch: 1.10.2 +- CUDA: 11.6 +- transformers: 4.28.1 +### 1. Install PaddlePaddle +``` +python -m pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +### 2. Install sdsv_dewarp +``` +pip install -v -e . +``` + + +## II. Test +``` +python test.py --input samples --out demo/outputs --device 'cuda' +``` diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/cls.yaml b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/cls.yaml new file mode 100644 index 0000000..4c4cfdc --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/cls.yaml @@ -0,0 +1,3 @@ +model_dir: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_ppocr_mobile_v2.0_cls_infer +gpu_mem: 3000 +max_batch_size: 32 \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/det.yaml b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/det.yaml new file mode 100644 index 0000000..f218ef1 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/det.yaml @@ -0,0 +1,8 @@ +model_dir: /mnt/hdd4T/OCR/tuanlv/01-BasicOCR/ocr-engine-deskew/externals/sdsv_dewarp/weights/ch_PP-OCRv3_det_infer +gpu_mem: 3000 +det_limit_side_len: 1560 +det_limit_type: max +det_db_unclip_ratio: 1.85 +det_db_thresh: 0.3 +det_db_box_thresh: 0.5 +det_db_score_mode: fast \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/requirements.txt b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/requirements.txt new file mode 100644 index 0000000..768fde9 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/requirements.txt @@ -0,0 +1,7 @@ + +paddleocr>=2.0.1 +opencv-contrib-python +opencv-python +numpy +gdown==3.13.0 +imgaug==0.4.0 diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/PKG-INFO b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/PKG-INFO new file mode 100644 index 0000000..634708d --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/PKG-INFO @@ -0,0 +1,45 @@ +Metadata-Version: 2.1 +Name: sdsv-dewarp +Version: 1.0.0 +Summary: Dewarp document +Home-page: +License: Apache License 2.0 +Classifier: Development Status :: 4 - Beta +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Description-Content-Type: text/markdown + +
+Dewarp
+ +***Feature*** +- Align document + + +## I. Setup +***Dependencies*** +- Python: 3.8 +- Torch: 1.10.2 +- CUDA: 11.6 +- transformers: 4.28.1 +### 1. Install PaddlePaddle +``` +python -m pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +### 2. Install sdsv_dewarp +``` +pip install -v -e . +``` + + +## II. Test +``` +python test.py --input samples --out demo/outputs --device 'cuda' +``` diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/SOURCES.txt b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/SOURCES.txt new file mode 100644 index 0000000..953a123 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/SOURCES.txt @@ -0,0 +1,15 @@ +README.md +setup.py +sdsv_dewarp/__init__.py +sdsv_dewarp/api.py +sdsv_dewarp/config.py +sdsv_dewarp/factory.py +sdsv_dewarp/models.py +sdsv_dewarp/utils.py +sdsv_dewarp/version.py +sdsv_dewarp.egg-info/PKG-INFO +sdsv_dewarp.egg-info/SOURCES.txt +sdsv_dewarp.egg-info/dependency_links.txt +sdsv_dewarp.egg-info/not-zip-safe +sdsv_dewarp.egg-info/requires.txt +sdsv_dewarp.egg-info/top_level.txt \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/dependency_links.txt b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/not-zip-safe b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/not-zip-safe new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/requires.txt b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/requires.txt new file mode 100644 index 0000000..89816c8 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/requires.txt @@ -0,0 +1,6 @@ +paddleocr>=2.0.1 +opencv-contrib-python +opencv-python +numpy +gdown==3.13.0 +imgaug==0.4.0 diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/top_level.txt b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/top_level.txt new file mode 100644 index 0000000..a5ce4e8 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp.egg-info/top_level.txt @@ -0,0 +1 @@ +sdsv_dewarp diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/__init__.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/api.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/api.py new file mode 100644 index 0000000..d71ddc5 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/api.py @@ -0,0 +1,200 @@ +import math +import numpy as np +from typing import List +import cv2 +import collections +import logging +import imgaug.augmenters as iaa +from imgaug.augmentables.polys import Polygon, PolygonsOnImage + +from 
sdsv_dewarp.models import PaddleTextClassifier, PaddleTextDetector +from sdsv_dewarp.config import Cfg +from .utils import * + + +MIN_LONG_EDGE = 40**2 +NUMBER_BOX_FOR_ALIGNMENT = 200 +MAX_ANGLE = 180 +MIN_ANGLE = 1 +MIN_NUM_BOX_TEXT = 3 +CROP_SIZE = 3000 + +logging.basicConfig(level=logging.INFO) +LOGGER = logging.getLogger(__name__) + + +class AlignImage: + """Rotate image to 0 degree + Args: + text_detector (deepmodel): Text detection model + text_cls (deepmodel): Text classification model (0 or 180) + + Return: + is_blank (bool): Blank image when haven't boxes text + image_align: Image after alignment + angle_align: Degree of angle alignment + """ + + def __init__(self, text_detector: dict, text_cls: dict, device: str = 'cpu'): + self.text_detector = None + self.text_cls = None + self.use_gpu = True if device != 'cpu' else False + + self._init_model(text_detector, text_cls) + + def _init_model(self, text_detector, text_cls): + det_config = Cfg.load_config_from_file(text_detector['config']) + det_config['model_dir'] = text_detector['weight'] + cls_config = Cfg.load_config_from_file(text_cls['config']) + cls_config['model_dir'] = text_cls['weight'] + + self.text_detector = PaddleTextDetector(config=det_config, use_gpu=self.use_gpu) + self.text_cls = PaddleTextClassifier(config=cls_config, use_gpu=self.use_gpu) + + def _cal_width(self, poly_box): + """Calculate width of a polygon [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]""" + tl, tr, br, bl = poly_box + edge_s, edge_l = distance(tl, tr), distance(tr, br) + + return max(edge_s, edge_l) + + def _get_most_frequent(self, values): + values = np.array(values) + # create the histogram + hist, bins = np.histogram(values, bins=np.arange(0, 181, 10)) + + # get the index of the most frequent angle + index = np.argmax(hist) + + # get the most frequent angle + most_frequent_angle = (bins[index] + bins[index + 1]) / 2 + + return most_frequent_angle + + def _cal_angle(self, poly_box): + """Calculate the angle between two point""" + a = poly_box[0] + b = poly_box[1] + c = poly_box[2] + + # Get the longer edge + if distance(a, b) >= distance(b, c): + x, y = a, b + else: + x, y = b, c + + angle = math.degrees(math.atan2(-(y[1] - x[1]), y[0] - x[0])) + + if angle < 0: + angle = 180 - abs(angle) + + return angle + + def _reject_outliers(self, data, m=5.0): + """Remove noise angle""" + list_index = np.arange(len(data)) + d = np.abs(data - np.median(data)) + mdev = np.median(d) + s = d / (mdev if mdev else 1.0) + + return list_index[s < m], data[s < m] + + def __call__(self, image): + """image (np.ndarray): BGR image""" + + # Crop center image to increase speed of text detection + + image_resized = crop_image(image, crop_size=CROP_SIZE).copy() if max(image.shape) > CROP_SIZE else image.copy() + poly_box_texts = self.text_detector(image_resized) + + # draw_img = vis_ocr( + # image_resized, + # poly_box_texts, + # ) + # cv2.imwrite("draw_img.jpg", draw_img) + + is_blank = False + + # Check image is blank + if len(poly_box_texts) <= MIN_NUM_BOX_TEXT: + is_blank = True + return image, is_blank, 0 + + # # Crop document + # poly_np = np.array(poly_box_texts) + # min_x = poly_box_texts[:, 0].min() + # max_x = poly_box_texts[:, 2].max() + # min_y = poly_box_texts[:, 1].min() + # max_y = poly_box_texts[:, 3].max() + + # Filter small poly + poly_box_areas = [ + [self._cal_width(poly_box), id] + for id, poly_box in enumerate(poly_box_texts) + ] + + poly_box_areas = sorted(poly_box_areas)[-NUMBER_BOX_FOR_ALIGNMENT:] + poly_box_areas = [poly_box_texts[id[1]] for id in 
poly_box_areas] + + # Calculate angle + list_angle = [self._cal_angle(poly_box) for poly_box in poly_box_areas] + list_angle = [angle if angle >= MIN_ANGLE else 180 for angle in list_angle] + + # LOGGER.info(f"List angle before reject outlier: {list_angle}") + list_angle = np.array(list_angle) + list_index, list_angle = self._reject_outliers(list_angle) + # LOGGER.info(f"List angle after reject outlier: {list_angle}") + + if len(list_angle): + + frequent_angle = self._get_most_frequent(list_angle) + list_angle = [angle for angle in list_angle if abs(angle - frequent_angle) <= 45] + # LOGGER.info(f"List angle after reject angle: {list_angle}") + angle = np.mean(list_angle) + else: + angle = 0 + + # LOGGER.info(f"Avg angle: {angle}") + + # Reuse poly boxes detected by text detection + polys_org = PolygonsOnImage( + [Polygon(poly_box_areas[index]) for index in list_index], + shape=image_resized.shape, + ) + seq_augment = iaa.Sequential([iaa.Rotate(angle, fit_output=True, order=3)]) + + # Rotate image by degree + if angle >= MIN_ANGLE and angle <= MAX_ANGLE: + image_resized, polys_aug = seq_augment( + image=image_resized, polygons=polys_org + ) + else: + angle = 0 + image_resized, polys_aug = image_resized, polys_org + + # cv2.imwrite("image_resized.jpg", image_resized) + + # Classify image 0 or 180 degree + list_poly = [poly.coords for poly in polys_aug] + + image_crop_list = [ + dewarp_by_polygon(image_resized, poly)[0] for poly in list_poly + ] + + cls_res = self.text_cls(image_crop_list) + cls_labels = [cls_[0] for cls_ in cls_res[1]] + # LOGGER.info(f"Angle lines: {cls_labels}") + counter = collections.Counter(cls_labels) + + angle_align = angle + if counter["0"] <= counter["180"]: + aug = iaa.Rotate(angle + 180, fit_output=True, order=3) + angle_align = angle + 180 + else: + aug = iaa.Rotate(angle, fit_output=True, order=3) + + # Rotate the image by degree + image = aug.augment_image(image) + + return image, is_blank, angle_align + # return image diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/config.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/config.py new file mode 100644 index 0000000..204c2c0 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/config.py @@ -0,0 +1,41 @@ +import yaml +import pprint +import os +import json + + +def load_from_yaml(fname): + with open(fname, encoding='utf-8') as f: + base_config = yaml.safe_load(f) + return base_config + +def load_from_json(fname): + with open(fname, "r", encoding='utf-8') as f: + base_config = json.load(f) + return base_config + +class Cfg(dict): + def __init__(self, config_dict): + super(Cfg, self).__init__(**config_dict) + self.__dict__ = self + + @staticmethod + def load_config_from_file(fname, download_base=False): + if not os.path.exists(fname): + raise FileNotFoundError("Not found config at {}".format(fname)) + if fname.endswith(".yaml") or fname.endswith(".yml"): + return Cfg(load_from_yaml(fname)) + elif fname.endswith(".json"): + return Cfg(load_from_json(fname)) + else: + raise Exception(f"{fname} not supported") + + + def save(self, fname): + with open(fname, 'w', encoding='utf-8') as outfile: + yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True) + + # @property + def pretty_text(self): + return pprint.PrettyPrinter().pprint(self) + diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/factory.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/factory.py new file 
mode 100644 index 0000000..65e4bbd --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/factory.py @@ -0,0 +1,75 @@ +import os +import shutil +import hashlib +import warnings + +def sha256sum(filename): + h = hashlib.sha256() + b = bytearray(128*1024) + mv = memoryview(b) + with open(filename, 'rb', buffering=0) as f: + for n in iter(lambda : f.readinto(mv), 0): + h.update(mv[:n]) + return h.hexdigest() + + +online_model_factory = { + 'yolox-s-general-text-pretrain-20221226': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/62j266xm8r.pth', + 'hash': '89bff792685af454d0cfea5d6d673be6914d614e4c2044e786da6eddf36f8b50'}, + 'yolox-s-checkbox-20220726': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/1647d7eys7.pth', + 'hash': '7c1e188b7375dcf0b7b9d317675ebd92a86fdc29363558002249867249ee10f8'}, + 'yolox-s-idcard-5c-20221027': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/jr0egad3ix.pth', + 'hash': '73a7772594c1f6d3f6d6a98b6d6e4097af5026864e3bd50531ad9e635ae795a7'}, + 'yolox-s-handwritten-text-line-20230228': { + 'url': 'https://github.com/moewiee/satrn-model-factory/raw/main/rb07rtwmgi.pth', + 'hash': 'a31d1bf8fc880479d2e11463dad0b4081952a13e553a02919109b634a1190ef1'} +} + +__hub_available_versions__ = online_model_factory.keys() + +def _get_from_hub(file_path, version, version_url): + os.system(f'wget -O {file_path} {version_url}') + assert os.path.exists(file_path), \ + 'wget failed while trying to retrieve from hub.' + downloaded_hash = sha256sum(file_path) + if downloaded_hash != online_model_factory[version]['hash']: + os.remove(file_path) + raise ValueError('sha256 hash doesnt match for version retrieved from hub.') + +def _get(version): + use_online = version in __hub_available_versions__ + + if not use_online and not os.path.exists(version): + raise ValueError(f'Model version {version} not found online and not found local.') + + hub_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'hub') + if not os.path.exists(hub_path): + os.makedirs(hub_path) + if use_online: + version_url = online_model_factory[version]['url'] + file_path = os.path.join(hub_path, os.path.basename(version_url)) + else: + file_path = os.path.join(hub_path, os.path.basename(version)) + + if not os.path.exists(file_path): + if use_online: + _get_from_hub(file_path, version, version_url) + else: + shutil.copy2(version, file_path) + else: + if use_online: + downloaded_hash = sha256sum(file_path) + if downloaded_hash != online_model_factory[version]['hash']: + os.remove(file_path) + warnings.warn('existing hub version sha256 hash doesnt match, now re-download from hub.') + _get_from_hub(file_path, version, version_url) + else: + if sha256sum(file_path) != sha256sum(version): + os.remove(file_path) + warnings.warn('existing local version sha256 hash doesnt match, now replace with new local version.') + shutil.copy2(version, file_path) + + return \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/models.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/models.py new file mode 100644 index 0000000..64cb88c --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/models.py @@ -0,0 +1,73 @@ + +from paddleocr.tools.infer.predict_det import TextDetector +from paddleocr.tools.infer.predict_cls import TextClassifier +from paddleocr.paddleocr import parse_args +from sdsv_dewarp.config import Cfg + +class 
PaddleTextDetector(object): + def __init__( + self, + # config_path: str, + config: dict, + use_gpu=False + ): + # config = Cfg.load_config_from_file(config_path) + + self.args = parse_args(mMain=False) + self.args.__dict__.update( + det_model_dir=config['model_dir'], + gpu_mem=config['gpu_mem'], + use_gpu=use_gpu, + use_zero_copy_run=True, + max_batch_size=1, + det_limit_side_len=config['det_limit_side_len'], #960 + det_limit_type=config['det_limit_type'], #'max' + det_db_unclip_ratio=config['det_db_unclip_ratio'], + det_db_thresh=config['det_db_thresh'], + det_db_box_thresh=config['det_db_box_thresh'], + det_db_score_mode=config['det_db_score_mode'], + ) + self.text_detector = TextDetector(self.args) + + def __call__(self, image): + """ + + Args: + image (np.ndarray): BGR images + + Returns: + np.ndarray: numpy array of poly boxes - shape 4x2 + """ + dt_boxes, time_infer = self.text_detector(image) + return dt_boxes + + +class PaddleTextClassifier(object): + def __init__( + self, + # config_path: str, + config: str, + use_gpu=False + ): + # config = Cfg.load_config_from_file(config_path) + + self.args = parse_args(mMain=False) + self.args.__dict__.update( + cls_model_dir=config['model_dir'], + gpu_mem=config['gpu_mem'], + use_gpu=use_gpu, + use_zero_copy_run=True, + cls_batch_num=config['max_batch_size'], + ) + self.text_classifier = TextClassifier(self.args) + + def __call__(self, images): + """ + Args: + images (np.ndarray): list of BGR images + + Returns: + img_list, cls_res, elapse : cls_res format = (label, conf) + """ + out= self.text_classifier(images) + return out \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/utils.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/utils.py new file mode 100644 index 0000000..ae02cc1 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/utils.py @@ -0,0 +1,212 @@ +import math +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import random + + +def distance(p1, p2): + """Calculate Euclid distance""" + x1, y1 = p1 + x2, y2 = p2 + dist = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + + return dist + + +def crop_image(image, crop_size=1280): + """Crop center image""" + h, w = image.shape[:2] + x_center, y_center = w // 2, h // 2 + half_size = crop_size // 2 + + xmin, ymin = x_center - half_size, y_center - half_size + xmax, ymax = x_center + half_size, y_center + half_size + + xmin = max(xmin, 0) + ymin = max(ymin, 0) + xmax = min(xmax, w) + ymax = min(ymax, h) + + return image[ymin:ymax, xmin:xmax] + + +def _closest_point(corners, A): + """Find closest A in corrers point""" + distances = [distance(A, p) for p in corners] + return corners[np.argmin(distances)] + + +def _re_order_corners(image_size, corners) -> list: + """Order by corners by clockwise angle""" + h, w = image_size + tl = _closest_point(corners, (0, 0)) + tr = _closest_point(corners, (w, 0)) + br = _closest_point(corners, (w, h)) + bl = _closest_point(corners, (0, h)) + + return [tl, tr, br, bl] + + +def _validate_corner(corners, ratio_thres=0.5, epsilon=1e-3) -> bool: + """Check corners is valid + Invalid: 3 points, duplicate points, .... 
+ """ + c_tl, c_tr, c_br, c_bl = corners + e_top = distance(c_tl, c_tr) + e_right = distance(c_tr, c_br) + e_bottom = distance(c_br, c_bl) + e_left = distance(c_bl, c_tl) + + min_tb = min(e_top, e_bottom) + max_tb = max(e_top, e_bottom) + min_lr = min(e_left, e_right) + max_lr = max(e_left, e_right) + + # Nếu các điểm trùng nhau thì độ dài các cạnh sẽ bằng 0 + if min(max_tb, max_lr) < epsilon: + return False + + ratio = min(min_tb / max_tb, min_lr / max_lr) + if ratio < ratio_thres: + return False + + return True + + +def dewarp_by_polygon( + image, corners, need_validate=False, need_reorder=True, trace_trans=None +): + """Crop and dewarp from 4 corners of images + + Args: + image (np.array) + corners (list): Ex : [(3347, 512), (3379, 2427), (638, 2524), (647, 495)] + need_validate (bool, optional): validate 4 points. Defaults to False. + need_reorder (bool, optional): validate 4 points. Defaults to True. + + Returns: + dewarped: image after dewarp + corners: location of 4 corners after reorder + """ + h, w = image.shape[:2] + + if need_reorder: + corners = _re_order_corners((h, w), corners) + + dewarped = image + + if need_validate: + validate = _validate_corner(corners) + else: + validate = True + + if validate: + # perform dewarp + target_w = int( + max(distance(corners[0], corners[1]), distance(corners[2], corners[3])) + ) + target_h = int( + max(distance(corners[0], corners[3]), distance(corners[1], corners[2])) + ) + target_corners = [ + [0, 0], + [target_w, 0], + [target_w, target_h], + [0, target_h], + ] + + pts1 = np.float32(corners) + pts2 = np.float32(target_corners) + transform_matrix = cv2.getPerspectiveTransform(pts1, pts2) + + dewarped = cv2.warpPerspective(image, transform_matrix, (target_w, target_h)) + if trace_trans is not None: + trace_trans["dewarp_method"]["polygon"][ + "transform_matrix" + ] = transform_matrix + + return (dewarped, corners, trace_trans) + + +def vis_ocr(image, boxes, txts=[], scores=None, drop_score=0.5): + """ + Args: + image (np.ndarray / PIL): BGR image or PIL image + boxes (list / np.ndarray): list of polygon boxes + txts (list): list of text labels + scores (list, optional): probality. Defaults to None. + drop_score (float, optional): . Defaults to 0.5. + font_path (str, optional): Path of font. Defaults to "test/fonts/latin.ttf". 
+ Returns: + np.ndarray: BGR image + """ + + if len(txts) == 0: + txts = [""] * len(boxes) + + if isinstance(image, np.ndarray): + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + if isinstance(boxes, list): + boxes = np.array(boxes) + + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + if scores is not None and scores[idx] < drop_score: + continue + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + draw_left.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + fill=color, + ) + draw_right.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + outline=color, + ) + box_height = math.sqrt( + (box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2 + ) + box_width = math.sqrt( + (box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2 + ) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.load_default() + cur_y = box[0][1] + for c in txt: + char_size = font.getsize(c) + draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) + cur_y += char_size[1] + else: + font_size = max(int(box_height * 0.8), 10) + font = ImageFont.load_default() + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_left = Image.blend(image, img_left, 0.5) + + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + img_show = cv2.cvtColor(np.array(img_show), cv2.COLOR_RGB2BGR) + return img_show diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/version.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/version.py new file mode 100644 index 0000000..a1570ac --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/sdsv_dewarp/version.py @@ -0,0 +1 @@ +__version__="1.0.0" \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/setup.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/setup.py new file mode 100644 index 0000000..7887e8f --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/setup.py @@ -0,0 +1,187 @@ +import os +import os.path as osp +import shutil +import sys +import warnings +from setuptools import find_packages, setup + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +version_file = 'sdsv_dewarp/version.py' +is_windows = sys.platform == 'win32' + + +def add_mim_extention(): + """Add extra files that are required to support MIM into the package. + + These files will be added by creating a symlink to the originals if the + package is installed in `editable` mode (e.g. pip install -e .), or by + copying from the originals otherwise. 
+ """ + + # parse installment mode + if 'develop' in sys.argv: + # installed by `pip install -e .` + mode = 'symlink' + elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + # installed by `pip install .` + # or create source distribution by `python setup.py sdist` + mode = 'copy' + else: + return + + filenames = ['tools', 'configs', 'model-index.yml'] + repo_path = osp.dirname(__file__) + mim_path = osp.join(repo_path, 'mmocr', '.mim') + os.makedirs(mim_path, exist_ok=True) + + for filename in filenames: + if osp.exists(filename): + src_path = osp.join(repo_path, filename) + tar_path = osp.join(mim_path, filename) + + if osp.isfile(tar_path) or osp.islink(tar_path): + os.remove(tar_path) + elif osp.isdir(tar_path): + shutil.rmtree(tar_path) + + if mode == 'symlink': + src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) + try: + os.symlink(src_relpath, tar_path) + except OSError: + # Creating a symbolic link on windows may raise an + # `OSError: [WinError 1314]` due to privilege. If + # the error happens, the src file will be copied + mode = 'copy' + warnings.warn( + f'Failed to create a symbolic link for {src_relpath}, ' + f'and it will be copied to {tar_path}') + else: + continue + + if mode == 'copy': + if osp.isfile(src_path): + shutil.copyfile(src_path, tar_path) + elif osp.isdir(src_path): + shutil.copytree(src_path, tar_path) + else: + warnings.warn(f'Cannot copy file {src_path}.') + else: + raise ValueError(f'Invalid mode {mode}') + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + import sys + + # return short version for sdist + if 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + return locals()['short_version'] + else: + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """Parse the package dependencies listed in a requirements file but strip + specific version information. + + Args: + fname (str): Path to requirements file. + with_version (bool, default=False): If True, include version specs. + Returns: + info (list[str]): List of requirements items. 
+ CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import re + import sys + from os.path import exists + require_fpath = fname + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + for info in parse_line(line): + yield info + + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) + return packages + + +if __name__ == '__main__': + setup( + name='sdsv_dewarp', + version=get_version(), + description='Dewarp document', + long_description=readme(), + long_description_content_type='text/markdown', + packages=find_packages(exclude=('configs', 'tools', 'demo')), + include_package_data=True, + url='', + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + ], + license='Apache License 2.0', + install_requires=parse_requirements('requirements.txt'), + zip_safe=False) diff --git a/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/test.py b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/test.py new file mode 100644 index 0000000..4bfdbc7 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/test.py @@ -0,0 +1,47 @@ +from sdsv_dewarp.api import AlignImage +import cv2 +import glob +import os +import tqdm +import time +import argparse + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input") + parser.add_argument("--out") + parser.add_argument("--device", type=str, default="cuda:1") + + args = parser.parse_args() + model = AlignImage(device=args.device) + + + img_dir = args.input + out_dir = args.out + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + img_paths = glob.glob(img_dir + "/*") + + times = [] + for img_path in tqdm.tqdm(img_paths): + t1 = time.time() + img = cv2.imread(img_path) + if img is None: + print(img_path) + 
continue + + aligned_img, is_blank, angle_align = model(img) + + times.append(time.time() - t1) + + if not is_blank: + cv2.imwrite(os.path.join(out_dir, os.path.basename(img_path)), aligned_img) + else: + cv2.imwrite(os.path.join(out_dir, os.path.basename(img_path)), img) + + + times = times[1:] + print("Avg time: ", sum(times) / len(times)) \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/requirements.txt b/cope2n-ai-fi/modules/ocr_engine/requirements.txt new file mode 100644 index 0000000..7bfd9c5 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/requirements.txt @@ -0,0 +1,82 @@ +addict==2.4.0 +asttokens==2.2.1 +autopep8==1.6.0 +backcall==0.2.0 +backports.functools-lru-cache==1.6.4 +brotlipy==0.7.0 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==2.0.4 +click==8.1.3 +colorama==0.4.6 +cryptography==39.0.1 +debugpy==1.5.1 +decorator==5.1.1 +docopt==0.6.2 +entrypoints==0.4 +executing==1.2.0 +flit_core==3.6.0 +idna==3.4 +importlib-metadata==6.0.0 +ipykernel==6.15.0 +ipython==8.11.0 +jedi==0.18.2 +jupyter-client==7.0.6 +jupyter_core==4.12.0 +Markdown==3.4.1 +markdown-it-py==2.2.0 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mkl-fft==1.3.6 +mkl-random==1.2.2 +mkl-service==2.4.0 +mmcv-full==1.7.1 +model-index==0.1.11 +nest-asyncio==1.5.6 +numpy==1.24.3 +opencv-python==4.7.0.72 +openmim==0.3.6 +ordered-set==4.1.0 +packaging==23.0 +pandas==1.5.3 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.4.0 +pip==22.3.1 +pipdeptree==2.5.2 +prompt-toolkit==3.0.38 +psutil==5.9.0 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pycodestyle==2.10.0 +pycparser==2.21 +Pygments==2.14.0 +pyOpenSSL==23.0.0 +PySocks==1.7.1 +python-dateutil==2.8.2 +pytz==2022.7.1 +PyYAML==6.0 +pyzmq==19.0.2 +requests==2.28.1 +rich==13.3.1 +sdsvtd==0.1.1 +sdsvtr==0.0.5 +setuptools==65.6.3 +Shapely==1.8.4 +six==1.16.0 +stack-data==0.6.2 +tabulate==0.9.0 +toml==0.10.2 +torch==1.13.1 +torchvision==0.14.1 +tornado==6.1 +tqdm==4.65.0 +traitlets==5.9.0 +typing_extensions==4.4.0 +urllib3==1.26.14 +wcwidth==0.2.6 +wheel==0.38.4 +yapf==0.32.0 +yarg==0.1.9 +zipp==3.15.0 diff --git a/cope2n-ai-fi/modules/ocr_engine/run.py b/cope2n-ai-fi/modules/ocr_engine/run.py new file mode 100644 index 0000000..3b19879 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/run.py @@ -0,0 +1,200 @@ +""" +see scripts/run_ocr.sh to run +""" +# from pathlib import Path # add parent path to run debugger +# import sys +# FILE = Path(__file__).absolute() +# sys.path.append(FILE.parents[2].as_posix()) + + +from src.utils import construct_file_path, ImageReader +from src.dto import Line +from src.ocr import OcrEngine +import argparse +import tqdm +import pandas as pd +from pathlib import Path +import json +import os +import numpy as np +from typing import Union, Tuple, List, Optional +from collections import defaultdict + +current_dir = os.getcwd() + + +def get_args(): + parser = argparse.ArgumentParser() + # parser image + parser.add_argument( + "--image", + type=str, + required=True, + help="path to input image/directory/csv file", + ) + parser.add_argument( + "--save_dir", type=str, required=True, help="path to save directory" + ) + parser.add_argument( + "--include", type=str, nargs="+", default=[], help="files/folders to include" + ) + parser.add_argument( + "--exclude", type=str, nargs="+", default=[], help="files/folders to exclude" + ) + parser.add_argument( + "--base_dir", + type=str, + required=False, + default=current_dir, + help="used when --image and --save_dir are relative paths to a base directory, default to current 
directory", + ) + parser.add_argument( + "--export_csv", + type=str, + required=False, + default="", + help="used when --image is a directory. If set, a csv file contains image_path, ocr_path and label will be exported to save_dir.", + ) + parser.add_argument( + "--export_img", + type=bool, + required=False, + default=False, + help="whether to save the visualize img", + ) + parser.add_argument("--ocr_kwargs", type=str, required=False, default="") + opt = parser.parse_args() + return opt + + +def load_engine(opt) -> OcrEngine: + print("[INFO] Loading engine...") + kw = json.loads(opt.ocr_kwargs) if opt.ocr_kwargs else {} + engine = OcrEngine(**kw) + print("[INFO] Engine loaded") + return engine + + +def convert_relative_path_to_positive_path(tgt_dir: Path, base_dir: Path) -> Path: + return tgt_dir if tgt_dir.is_absolute() else base_dir.joinpath(tgt_dir) + + +def get_paths_from_opt(opt) -> Tuple[Path, Path]: + # BC\ kiem\ tra\ y\ te -> BC kiem tra y te + img_path = opt.image.replace("\\ ", " ").strip() + save_dir = opt.save_dir.replace("\\ ", " ").strip() + base_dir = opt.base_dir.replace("\\ ", " ").strip() + input_image = convert_relative_path_to_positive_path(Path(img_path), Path(base_dir)) + save_dir = convert_relative_path_to_positive_path(Path(save_dir), Path(base_dir)) + if not save_dir.exists(): + save_dir.mkdir() + print("[INFO]: Creating folder ", save_dir) + return input_image, save_dir + + +def process_img( + img: Union[str, np.ndarray], + save_dir_or_path: str, + engine: OcrEngine, + export_img: bool, + save_path_deskew: Optional[str] = None, +) -> None: + save_dir_or_path = Path(save_dir_or_path) + if isinstance(img, np.ndarray): + if save_dir_or_path.is_dir(): + raise ValueError("numpy array input require a save path, not a save dir") + page = engine(img) + save_path = ( + str(save_dir_or_path.joinpath(Path(img).stem + ".txt")) + if save_dir_or_path.is_dir() + else str(save_dir_or_path) + ) + page.write_to_file("word", save_path) + if export_img: + page.save_img( + save_path.replace(".txt", ".jpg"), + is_vnese=True, + save_path_deskew=save_path_deskew, + ) + + +def process_dir( + dir_path: str, + save_dir: str, + engine: OcrEngine, + export_img: bool, + lexcludes: List[str] = [], + lincludes: List[str] = [], + ddata=defaultdict(list), +) -> None: + pdir_path = Path(dir_path) + print(pdir_path) + # save_dir_sub = Path(construct_file_path(save_dir, dir_path, ext="")) + psave_dir = Path(save_dir) + psave_dir.mkdir(exist_ok=True) + for img_path in (pbar := tqdm.tqdm(pdir_path.iterdir())): + pbar.set_description(f"Processing {pdir_path}") + if (lincludes and img_path.name not in lincludes) or ( + img_path.name in lexcludes + ): + continue # only process desired files/foders + if img_path.is_dir(): + psave_dir_sub = psave_dir.joinpath(img_path.stem) + process_dir(img_path, str(psave_dir_sub), engine, ddata) + elif img_path.suffix.lower() in ImageReader.supported_ext: + simg_path = str(img_path) + # try: + img = ( + ImageReader.read(simg_path) + if img_path.suffix != ".pdf" + else ImageReader.read(simg_path)[0] + ) + save_path = str(Path(psave_dir).joinpath(img_path.stem + ".txt")) + save_path_deskew = str( + Path(psave_dir).joinpath(img_path.stem + "_deskewed.jpg") + ) + process_img(img, save_path, engine, export_img, save_path_deskew) + # except Exception as e: + # print('[ERROR]: ', e, ' at ', simg_path) + # continue + ddata["img_path"].append(simg_path) + ddata["ocr_path"].append(save_path) + if Path(save_path_deskew).exists(): + ddata["save_path_deskew"].append(save_path) + 
ddata["label"].append(pdir_path.stem) + # ddata.update({"img_path": img_path, "save_path": save_path, "label": dir_path.stem}) + return ddata + + +def process_csv(csv_path: str, engine: OcrEngine) -> None: + df = pd.read_csv(csv_path) + if not "image_path" in df.columns or not "ocr_path" in df.columns: + raise AssertionError("Cannot fing image_path in df headers") + for row in df.iterrows(): + process_img(row.image_path, row.ocr_path, engine) + + +if __name__ == "__main__": + opt = get_args() + engine = load_engine(opt) + print("[INFO]: OCR engine settings:", engine.settings) + img, save_dir = get_paths_from_opt(opt) + + lskip_dir = [] + if img.is_dir(): + ddata = process_dir( + img, save_dir, engine, opt.export_img, opt.exclude, opt.include + ) + if opt.export_csv: + pd.DataFrame.from_dict(ddata).to_csv( + Path(save_dir).joinpath(opt.export_csv) + ) + elif img.suffix in ImageReader.supported_ext: + process_img(str(img), save_dir, engine, opt.export_img) + elif img.suffix == ".csv": + print( + "[WARNING]: Running with csv file will ignore the save_dir argument. Instead, the ocr_path in the csv would be used" + ) + process_csv(img, engine) + else: + raise NotImplementedError("[ERROR]: Unsupported file {}".format(img)) diff --git a/cope2n-ai-fi/modules/ocr_engine/scripts/run_deskew.sh b/cope2n-ai-fi/modules/ocr_engine/scripts/run_deskew.sh new file mode 100644 index 0000000..34ecd8c --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/scripts/run_deskew.sh @@ -0,0 +1,9 @@ +export CUDA_VISIBLE_DEVICES=1 +# export PATH=/usr/local/cuda-11.6/bin${PATH:+:${PATH}} +# export LD_LIBRARY_PATH=/usr/local/cuda-11.6/lib64\ {LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +# export CUDA_HOME=/usr/local/cuda-11.6 +# export PATH=/usr/local/cuda-11.6/bin:$PATH +# export CPATH=/usr/local/cuda-11.6/include:$CPATH +# export LIBRARY_PATH=/usr/local/cuda-11.6/lib64:$LIBRARY_PATH +# export LD_LIBRARY_PATH=/usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/extras/CUPTI/lib64:$LD_LIBRARY_PATH +python test/test_deskew_dir.py \ No newline at end of file diff --git a/cope2n-ai-fi/modules/ocr_engine/scripts/run_ocr.sh b/cope2n-ai-fi/modules/ocr_engine/scripts/run_ocr.sh new file mode 100644 index 0000000..8ade432 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/scripts/run_ocr.sh @@ -0,0 +1,49 @@ + + +#bash scripts/run_ocr.sh -i /mnt/hdd2T/AICR/Projects/2023/FWD/Forms/PDFs/ -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr -e out.csv -k "{\"device\":\"cuda:1\"}" -p True -n Passport 'So\ HK' +#bash scripts/run_ocr.sh -i '/mnt/hdd2T/AICR/Projects/2023/FWD/Forms/PDFs/So\ HK' -o /mnt/ssd1T/hungbnt/DocumentClassification/results/ocr -e out.csv -k "{\"device\":\"cuda:1\"}" -p True +#-n and -x do not accept multiple argument currently + + +# bash scripts/run_ocr.sh -i /mnt/hdd4T/OCR/hoangdc/End_to_end/ICDAR2013/data/images_receipt_5images/ -o visualize/ -e out.csv -k "{\"device\":\"cuda:1\"}" -p True + +export PYTHONWARNINGS="ignore" + +while getopts i:o:b:e:p:k:n:x: flag +do + case "${flag}" in + i) img=${OPTARG};; + o) out_dir=${OPTARG};; + b) base_dir=${OPTARG};; + e) export_csv=${OPTARG};; + p) export_img=${OPTARG};; + k) ocr_kwargs=${OPTARG};; + n) include=("${OPTARG[@]}");; + x) exclude=("${OPTARG[@]}");; + esac +done + +cmd="python run.py \ + --image $img \ + --save_dir $out_dir \ + --export_csv $export_csv \ + --export_img $export_img \ + --ocr_kwargs $ocr_kwargs" + +if [ ${#include[@]} -gt 0 ]; then + cmd+=" --include" + for item in "${include[@]}"; do + cmd+=" $item" + done +fi + +if [ ${#exclude[@]} -gt 0 ]; then + 
cmd+=" --exclude" + for item in "${exclude[@]}"; do + cmd+=" $item" + done +fi + + +echo $cmd +exec $cmd diff --git a/cope2n-ai-fi/modules/ocr_engine/settings.yml b/cope2n-ai-fi/modules/ocr_engine/settings.yml new file mode 100644 index 0000000..055e16f --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/settings.yml @@ -0,0 +1,36 @@ +device: &device cuda:0 +max_img_size: [1920,1920] #text det default size: 1280x1280 #[] = originla size, TODO: fix the deskew code to resize the image only for detecting the angle, we want to feed the original size image to the text detection pipeline so that the bounding boxes would be mapped back to the original size +extend_bbox: [0, 0.0, 0.0, 0.0] # left, top, right, bottom +batch_size: 1 #1 means batch_mode = False +detector: + # version: /mnt/hdd2T/datnt/datnt_from_ssd1T/mmdetection/wild_receipt_finetune_weights_c_lite.pth + version: /workspace/cope2n-ai-fi/weights/models/sdsap_sbt/ocr_engine/sdsvtd/epoch_100_params.pth + auto_rotate: True + rotator_version: /workspace/cope2n-ai-fi/weights/models/sdsap_sbt/ocr_engine/sdsvtd/best_bbox_mAP_epoch_30_lite.pth + device: *device + +recognizer: + # version: satrn-lite-general-pretrain-20230106 + version: /workspace/cope2n-ai-fi/weights/models/sdsvtr/hub/jxqhbem4to.pth + max_seq_len_overwrite: 24 #default = 12 + return_confident: True + device: *device +#extend the bbox to avoid losing accent mark in vietnames, if using ocr for only english, disable it + +deskew: + enable: True + text_detector: + config: /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/det.yaml + weight: /workspace/cope2n-ai-fi/weights/models/sdsap_sbt/ocr_engine/sdsv_dewarp/ch_PP-OCRv3_det_infer + text_cls: + config: /workspace/cope2n-ai-fi/modules/ocr_engine/externals/sdsv_dewarp/config/cls.yaml + weight: /workspace/cope2n-ai-fi/weights/models/sdsap_sbt/ocr_engine/sdsv_dewarp/ch_ppocr_mobile_v2.0_cls_infer + device: *device + + +words_to_lines: + gradient: 0.6 + max_x_dist: 20 + max_running_y_shift_degree: 10 #degrees + y_overlap_threshold: 0.5 + word_formation_mode: line diff --git a/cope2n-ai-fi/modules/ocr_engine/src/dto.py b/cope2n-ai-fi/modules/ocr_engine/src/dto.py new file mode 100644 index 0000000..8dae901 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/src/dto.py @@ -0,0 +1,534 @@ +import numpy as np +from typing import Optional, List, Union +import cv2 +from PIL import Image +from pathlib import Path +from .utils import visualize_bbox_and_label + + +class Box: + def __init__( + self, x1: int, y1: int, x2: int, y2: int, conf: float = -1.0, label: str = "" + ): + self._x1 = x1 + self._y1 = y1 + self._x2 = x2 + self._y2 = y2 + self._conf = conf + self._label = label + + def __repr__(self) -> str: + return str(self.bbox) + + def __str__(self) -> str: + return str(self.bbox) + + def get(self, return_confidence=False) -> Union[list[int], list[Union[float, int]]]: + return self.bbox if not return_confidence else self.xyxyc + + def __getitem__(self, key): + return self.bbox[key] + + @property + def width(self): + return max(self._x2 - self._x1, -1) + + @property + def height(self): + return max(self._y2 - self._y1, -1) + + @property + def bbox(self) -> list[int]: + return [self._x1, self._y1, self._x2, self._y2] + + @bbox.setter + def bbox(self, bbox_: list[int]): + self._x1, self._y1, self._x2, self._y2 = bbox_ + + @property + def xyxyc(self) -> list[Union[float, int]]: + return [self._x1, self._y1, self._x2, self._y2, self._conf] + + @staticmethod + def normalize_bbox(bbox: list[int]) -> list[int]: + return 
[int(b) for b in bbox] + + def to_int(self): + self._x1, self._y1, self._x2, self._y2 = self.normalize_bbox( + [self._x1, self._y1, self._x2, self._y2] + ) + return self + + @staticmethod + def clamp_bbox_by_img_wh(bbox: list, width: int, height: int) -> list[int]: + x1, y1, x2, y2 = bbox + x1 = min(max(0, x1), width) + x2 = min(max(0, x2), width) + y1 = min(max(0, y1), height) + y2 = min(max(0, y2), height) + return [x1, y1, x2, y2] + + def clamp_by_img_wh(self, width: int, height: int): + self._x1, self._y1, self._x2, self._y2 = self.clamp_bbox_by_img_wh( + [self._x1, self._y1, self._x2, self._y2], width, height + ) + return self + + @staticmethod + def extend_bbox(bbox: list, margin: list): # -> Self (python3.11) + margin_l, margin_t, margin_r, margin_b = margin + l, t, r, b = bbox # left, top, right, bottom + t = t - (b - t) * margin_t + b = b + (b - t) * margin_b + l = l - (r - l) * margin_l + r = r + (r - l) * margin_r + return [l, t, r, b] + + def get_extend_bbox(self, margin: list): + extended_bbox = self.extend_bbox(self.bbox, margin) + return Box(*extended_bbox, label=self._label) + + @staticmethod + def bbox_is_valid(bbox: list[int]) -> bool: + if bbox == [-1, -1, -1, -1]: + raise ValueError("Empty bounding box found") + l, t, r, b = bbox # left, top, right, bottom + return True if (b - t) * (r - l) > 0 else False + + def is_valid(self) -> bool: + return self.bbox_is_valid(self.bbox) + + @staticmethod + def crop_img_by_bbox(img: np.ndarray, bbox: list) -> np.ndarray: + l, t, r, b = bbox + return img[t:b, l:r] + + def crop_img(self, img: np.ndarray) -> np.ndarray: + return self.crop_img_by_bbox(img, self.bbox) + + +class Word: + def __init__( + self, + image=None, + text="", + conf_cls=-1.0, + bbox_obj: Box = Box(-1, -1, -1, -1), + conf_detect=-1.0, + kie_label="", + ): + # self.type = "word" + self._text = text + self._image = image + self._conf_det = conf_detect + self._conf_cls = conf_cls + # [left, top,right,bot] coordinate of top-left and bottom-right point + self._bbox_obj = bbox_obj + # self.word_id = 0 # id of word + # self.word_group_id = 0 # id of word_group which instance belongs to + # self.line_id = 0 # id of line which instance belongs to + # self.paragraph_id = 0 # id of line which instance belongs to + self._kie_label = kie_label + + @property + def bbox(self) -> list[int]: + return self._bbox_obj.bbox + + @property + def text(self) -> str: + return self._text + + @property + def height(self): + return self._bbox_obj.height + + @property + def width(self): + return self._bbox_obj.width + + def __repr__(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + def is_valid(self) -> bool: + return self._bbox_obj.is_valid() + + # def is_special_word(self): + # if not self._text: + # raise ValueError("Cannot validatie size of empty bounding box") + + # # if len(text) > 7: + # # return True + # if len(self._text) >= 7: + # no_digits = sum(c.isdigit() for c in text) + # return no_digits / len(text) >= 0.3 + + # return False + + +class WordGroup: + def __init__( + self, + list_words: List[Word] = list(), + text: str = "", + boundingbox: Box = Box(-1, -1, -1, -1), + conf_cls: float = -1, + conf_det: float = -1, + ): + # self.type = "word_group" + self._list_words = list_words # dict of word instances + # self.word_group_id = 0 # word group id + # self.line_id = 0 # id of line which instance belongs to + # self.paragraph_id = 0 # id of paragraph which instance belongs to + self._text = text + self._bbox_obj = boundingbox + self._kie_label = "" 
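+        # Group-level confidences: conf_cls is the recognition confidence and
+        # conf_det the detection confidence, typically averaged over the member
+        # words by the caller (see word_formation.group_bbox_and_text).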
+ self._conf_cls = conf_cls + self._conf_det = conf_det + + @property + def bbox(self) -> list[int]: + return self._bbox_obj.bbox + + @property + def text(self) -> str: + return self._text + + @property + def list_words(self) -> list[Word]: + return self._list_words + + def __repr__(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + # def add_word(self, word: Word): # add a word instance to the word_group + # if word._text != "✪": + # for w in self._list_words: + # if word.word_id == w.word_id: + # print("Word id collision") + # return False + # word.word_group_id = self.word_group_id # + # word.line_id = self.line_id + # word.paragraph_id = self.paragraph_id + # self._list_words.append(word) + # self._text += " " + word._text + # if self.bbox_obj == [-1, -1, -1, -1]: + # self.bbox_obj = word._bbox_obj + # else: + # self.bbox_obj = [ + # min(self.bbox_obj[0], word._bbox_obj[0]), + # min(self.bbox_obj[1], word._bbox_obj[1]), + # max(self.bbox_obj[2], word._bbox_obj[2]), + # max(self.bbox_obj[3], word._bbox_obj[3]), + # ] + # return True + # else: + # return False + + # def update_word_group_id(self, new_word_group_id): + # self.word_group_id = new_word_group_id + # for i in range(len(self._list_words)): + # self._list_words[i].word_group_id = new_word_group_id + + # def update_kie_label(self): + # list_kie_label = [word._kie_label for word in self._list_words] + # dict_kie = dict() + # for label in list_kie_label: + # if label not in dict_kie: + # dict_kie[label] = 1 + # else: + # dict_kie[label] += 1 + # total = len(list(dict_kie.values())) + # max_value = max(list(dict_kie.values())) + # list_keys = list(dict_kie.keys()) + # list_values = list(dict_kie.values()) + # self.kie_label = list_keys[list_values.index(max_value)] + + # def update_text(self): # update text after changing positions of words in list word + # text = "" + # for word in self._list_words: + # text += " " + word._text + # self._text = text + + +class Line: + def __init__( + self, + list_word_groups: List[WordGroup] = [], + text: str = "", + boundingbox: Box = Box(-1, -1, -1, -1), + conf_cls: float = -1, + conf_det: float = -1, + ): + # self.type = "line" + self._list_word_groups = ( + list_word_groups # list of Word_group instances in the line + ) + # self.line_id = 0 # id of line in the paragraph + # self.paragraph_id = 0 # id of paragraph which instance belongs to + self._text = text + self._bbox_obj = boundingbox + self._conf_cls = conf_cls + self._conf_det = conf_det + + @property + def bbox(self) -> list[int]: + return self._bbox_obj.bbox + + @property + def text(self) -> str: + return self._text + + @property + def list_word_groups(self) -> List[WordGroup]: + return self._list_word_groups + + @property + def list_words(self) -> list[Word]: + return [ + word + for word_group in self._list_word_groups + for word in word_group.list_words + ] + + def __repr__(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + # def add_group(self, word_group: WordGroup): # add a word_group instance + # if word_group._list_words is not None: + # for wg in self.list_word_groups: + # if word_group.word_group_id == wg.word_group_id: + # print("Word_group id collision") + # return False + + # self.list_word_groups.append(word_group) + # self.text += word_group._text + # word_group.paragraph_id = self.paragraph_id + # word_group.line_id = self.line_id + + # for i in range(len(word_group._list_words)): + # word_group._list_words[ + # i + # ].paragraph_id = 
self.paragraph_id # set paragraph_id for word + # word_group._list_words[i].line_id = self.line_id # set line_id for word + # return True + # return False + + # def update_line_id(self, new_line_id): + # self.line_id = new_line_id + # for i in range(len(self.list_word_groups)): + # self.list_word_groups[i].line_id = new_line_id + # for j in range(len(self.list_word_groups[i]._list_words)): + # self.list_word_groups[i]._list_words[j].line_id = new_line_id + + # def merge_word(self, word): # word can be a Word instance or a Word_group instance + # if word.text != "✪": + # if self.boundingbox == [-1, -1, -1, -1]: + # self.boundingbox = word.boundingbox + # else: + # self.boundingbox = [ + # min(self.boundingbox[0], word.boundingbox[0]), + # min(self.boundingbox[1], word.boundingbox[1]), + # max(self.boundingbox[2], word.boundingbox[2]), + # max(self.boundingbox[3], word.boundingbox[3]), + # ] + # self.list_word_groups.append(word) + # self.text += " " + word.text + # return True + # return False + + # def __cal_ratio(self, top1, bottom1, top2, bottom2): + # sorted_vals = sorted([top1, bottom1, top2, bottom2]) + # intersection = sorted_vals[2] - sorted_vals[1] + # min_height = min(bottom1 - top1, bottom2 - top2) + # if min_height == 0: + # return -1 + # ratio = intersection / min_height + # return ratio + + # def __cal_ratio_height(self, top1, bottom1, top2, bottom2): + + # height1, height2 = top1 - bottom1, top2 - bottom2 + # ratio_height = float(max(height1, height2)) / float(min(height1, height2)) + # return ratio_height + + # def in_same_line(self, input_line, thresh=0.7): + # # calculate iou in vertical direction + # _, top1, _, bottom1 = self.boundingbox + # _, top2, _, bottom2 = input_line.boundingbox + + # ratio = self.__cal_ratio(top1, bottom1, top2, bottom2) + # ratio_height = self.__cal_ratio_height(top1, bottom1, top2, bottom2) + + # if ( + # (top2 <= top1 <= bottom2) or (top1 <= top2 <= bottom1) + # and ratio >= thresh + # and (ratio_height < 2) + # ): + # return True + # return False + + +# class Paragraph: +# def __init__(self, id=0, lines=None): +# self.list_lines = lines if lines is not None else [] # list of all lines in the paragraph +# self.paragraph_id = id # index of paragraph in the ist of paragraph +# self.text = "" +# self.boundingbox = [-1, -1, -1, -1] + +# @property +# def bbox(self): +# return self.boundingbox + +# def __repr__(self) -> str: +# return self.text + +# def __str__(self) -> str: +# return self.text + +# def add_line(self, line: Line): # add a line instance +# if line.list_word_groups is not None: +# for l in self.list_lines: +# if line.line_id == l.line_id: +# print("Line id collision") +# return False +# for i in range(len(line.list_word_groups)): +# line.list_word_groups[ +# i +# ].paragraph_id = ( +# self.paragraph_id +# ) # set paragraph id for every word group in line +# for j in range(len(line.list_word_groups[i]._list_words)): +# line.list_word_groups[i]._list_words[ +# j +# ].paragraph_id = ( +# self.paragraph_id +# ) # set paragraph id for every word in word groups +# line.paragraph_id = self.paragraph_id # set paragraph id for line +# self.list_lines.append(line) # add line to paragraph +# self.text += " " + line.text +# return True +# else: +# return False + +# def update_paragraph_id( +# self, new_paragraph_id +# ): # update new paragraph_id for all lines, word_groups, words inside paragraph +# self.paragraph_id = new_paragraph_id +# for i in range(len(self.list_lines)): +# self.list_lines[ +# i +# ].paragraph_id = new_paragraph_id # set new 
paragraph_id for line +# for j in range(len(self.list_lines[i].list_word_groups)): +# self.list_lines[i].list_word_groups[ +# j +# ].paragraph_id = new_paragraph_id # set new paragraph_id for word_group +# for k in range(len(self.list_lines[i].list_word_groups[j].list_words)): +# self.list_lines[i].list_word_groups[j].list_words[ +# k +# ].paragraph_id = new_paragraph_id # set new paragraph id for word +# return True + + +class Page: + def __init__( + self, + word_segments: Union[List[WordGroup], List[Line]], + image: np.ndarray, + deskewed_image: Optional[np.ndarray] = None, + ) -> None: + self._word_segments = word_segments + self._image = image + self._deskewed_image = deskewed_image + self._drawed_image: Optional[np.ndarray] = None + + @property + def word_segments(self): + return self._word_segments + + @property + def list_words(self) -> list[Word]: + return [ + word + for word_segment in self._word_segments + for word in word_segment.list_words + ] + + @property + def image(self): + return self._image + + @property + def PIL_image(self): + return Image.fromarray(self._image) + + @property + def drawed_image(self): + return self._drawed_image + + @property + def deskewed_image(self): + return self._deskewed_image + + def visualize_bbox_and_label(self, **kwargs: dict): + if self._drawed_image is not None: + return self._drawed_image + bboxes = list() + texts = list() + for word in self.list_words: + bboxes.append([int(float(b)) for b in word.bbox]) + texts.append(word._text) + img = visualize_bbox_and_label( + self._deskewed_image if self._deskewed_image is not None else self._image, + bboxes, + texts, + **kwargs + ) + self._drawed_image = img + return self._drawed_image + + def save_img(self, save_path: str, **kwargs: dict) -> None: + save_path_deskew = kwargs.pop("save_path_deskew", Path(save_path).with_stem(Path(save_path).stem + "_deskewed").as_posix()) + if self._deskewed_image is not None: + # save_path_deskew: str = kwargs.pop("save_path_deskew", Path(save_path).with_stem(Path(save_path).stem + "_deskewed").as_posix()) # type: ignore + cv2.imwrite(save_path_deskew, self._deskewed_image) + + img = self.visualize_bbox_and_label(**kwargs) + cv2.imwrite(save_path, img) + + + def write_to_file(self, mode: str, save_path: str) -> None: + f = open(save_path, "w+", encoding="utf-8") + for word_segment in self._word_segments: + if mode == "segment": + xmin, ymin, xmax, ymax = word_segment.bbox + f.write( + "{}\t{}\t{}\t{}\t{}\n".format( + xmin, ymin, xmax, ymax, word_segment._text + ) + ) + elif mode == "word": + for word in word_segment.list_words: + # xmin, ymin, xmax, ymax = word.bbox + xmin, ymin, xmax, ymax = [int(float(b)) for b in word.bbox] + f.write( + "{}\t{}\t{}\t{}\t{}\n".format( + xmin, ymin, xmax, ymax, word._text + ) + ) + else: + raise NotImplementedError("Unknown mode: {}".format(mode)) + f.close() + + +class Document: + def __init__(self, lpages: List[Page]) -> None: + self.lpages = lpages diff --git a/cope2n-ai-fi/modules/ocr_engine/src/ocr.py b/cope2n-ai-fi/modules/ocr_engine/src/ocr.py new file mode 100644 index 0000000..280d0b2 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/src/ocr.py @@ -0,0 +1,258 @@ +from typing import Union, overload, List, Optional, Tuple +from PIL import Image +import torch +import numpy as np +import yaml +from pathlib import Path +import mmcv +from sdsvtd import StandaloneYOLOXRunner +from sdsvtr import StandaloneSATRNRunner +from sdsv_dewarp.api import AlignImage + +from .utils import ImageReader, chunks, Timer, post_process_recog # 
rotate_bbox + +# from .utils import jdeskew as deskew +# from externals.deskew.sdsv_dewarp import pdeskew as deskew +# from .utils import deskew +from .dto import Word, Line, Page, Document, Box, WordGroup + +# from .word_formation import words_to_lines as words_to_lines +# from .word_formation import wo rds_to_lines_mmocr as words_to_lines +from .word_formation import words_formation_mmocr_tesseract as word_formation + +DEFAULT_SETTING_PATH = str(Path(__file__).parents[1]) + "/settings.yml" + + +class OcrEngine: + def __init__(self, settings_file: str = DEFAULT_SETTING_PATH, **kwargs): + """Warper of text detection and text recognition + :param settings_file: path to default setting file + :param kwargs: keyword arguments to overwrite the default settings file + """ + with open(settings_file) as f: + # use safe_load instead load + self._settings = yaml.safe_load(f) + self._update_configs(kwargs) + + self._ensure_device() + self._detector = StandaloneYOLOXRunner(**self._settings["detector"]) + self._recognizer = StandaloneSATRNRunner(**self._settings["recognizer"]) + self._deskewer = self._load_deskewer() + + def _update_configs(self, params): + for key, para in params.items(): # overwrite default settings by keyword arguments + if key not in self._settings: + raise ValueError("Invalid setting found in OcrEngine: ", k) + if key == "device": + self._settings[key] = para + self._settings["detector"][key] = para + self._settings["recognizer"][key] = para + self._settings["deskew"][key] = para + else: + for k, v in para.items(): + if isinstance(v, dict): + for sub_key, sub_value in v.items(): + self._settings[key][k][sub_key] = sub_value + else: + self._settings[key][k] = v + + def _load_deskewer(self) -> Optional[AlignImage]: + if self._settings["deskew"]["enable"]: + deskewer = AlignImage( + **{k: v for k, v in self._settings["deskew"].items() if k != "enable"} + ) + print( + "[WARNING]: Deskew is enabled. The bounding boxes prediction may not be aligned with the original image. In case of using these predictions for pseudo-label, turn on save_deskewed option and use the saved deskewed images instead for further proceed." 
+ ) + return deskewer + return None + + def _ensure_device(self): + if "cuda" in self._settings["device"]: + if not torch.cuda.is_available(): + print("[WARNING]: CUDA is not available, running with cpu instead") + self._settings["device"] = "cpu" + + @property + def version(self): + return { + "detector": self._settings["detector"], + "recognizer": self._settings["recognizer"], + } + + @property + def settings(self): + return self._settings + + # @staticmethod + # def xyxyc_to_xyxy_c(xyxyc: np.ndarray) -> Tuple[List[list], list]: + # ''' + # convert sdsvtd yoloX detection output to list of bboxes and list of confidences + # @param xyxyc: array of shape (n, 5) + # ''' + # xyxy = xyxyc[:, :4].tolist() + # confs = xyxyc[:, 4].tolist() + # return xyxy, confs + # -> Tuple[np.ndarray, List[Box]]: + + def preprocess(self, img: np.ndarray) -> tuple[np.ndarray, bool, float]: + img_ = img.copy() + if self._settings["max_img_size"]: + img_ = mmcv.imrescale( + img, + tuple(self._settings["max_img_size"]), + return_scale=False, + interpolation="bilinear", + backend="cv2", + ) + is_blank = False + if self._deskewer: + with Timer("deskew"): + img_, is_blank, angle = self._deskewer(img_) + return img, is_blank, angle # replace img_ to img + # for i, bbox in enumerate(bboxes): + # rotated_bbox = rotate_bbox(bbox, angle, img.shape[:2]) + # bboxes[i].bbox = rotated_bbox + return img, is_blank, 0 + + def run_detect( + self, img: np.ndarray, return_raw: bool = False + ) -> Tuple[np.ndarray, Union[List[Box], List[list]]]: + """ + run text detection and return list of xyxyc if return_confidence is True, otherwise return a list of xyxy + """ + pred_det = self._detector(img) + if self._settings["detector"]["auto_rotate"]: + img, pred_det = pred_det + pred_det = pred_det[0] # only image at a time + return ( + (img, pred_det.tolist()) + if return_raw + else (img, [Box(*xyxyc) for xyxyc in pred_det.tolist()]) + ) + + def run_recog( + self, imgs: List[np.ndarray] + ) -> Union[List[str], List[Tuple[str, float]]]: + if len(imgs) == 0: + return list() + pred_rec = self._recognizer(imgs) + return [ + (post_process_recog(word), conf) + for word, conf in zip(pred_rec[0], pred_rec[1]) + ] + + def read_img(self, img: str) -> np.ndarray: + return ImageReader.read(img) + + def get_cropped_imgs( + self, img: np.ndarray, bboxes: Union[List[Box], List[list]] + ) -> Tuple[List[np.ndarray], List[bool]]: + """ + img: np image + bboxes: list of xyxy + """ + lcropped_imgs = list() + mask = list() + for bbox in bboxes: + bbox = Box(*bbox) if isinstance(bbox, list) else bbox + bbox = bbox.get_extend_bbox(self._settings["extend_bbox"]) + + bbox.clamp_by_img_wh(img.shape[1], img.shape[0]) + bbox.to_int() + if not bbox.is_valid(): + mask.append(False) + continue + cropped_img = bbox.crop_img(img) + lcropped_imgs.append(cropped_img) + mask.append(True) + return lcropped_imgs, mask + + def read_page( + self, img: np.ndarray, bboxes: Union[List[Box], List[list]] + ) -> Union[List[WordGroup], List[Line]]: + if len(bboxes) == 0: # no bbox found + return list() + with Timer("cropped imgs"): + lcropped_imgs, mask = self.get_cropped_imgs(img, bboxes) + with Timer("recog"): + # batch_mode for efficiency + pred_recs = self.run_recog(lcropped_imgs) + with Timer("construct words"): + lwords = list() + for i in range(len(pred_recs)): + if not mask[i]: + continue + text, conf_rec = pred_recs[i][0], pred_recs[i][1] + bbox = Box(*bboxes[i]) if isinstance(bboxes[i], list) else bboxes[i] + lwords.append( + Word( + image=img, + text=text, + 
conf_cls=conf_rec, + bbox_obj=bbox, + conf_detect=bbox._conf, + ) + ) + with Timer("word formation"): + return word_formation( + lwords, img.shape[1], **self._settings["words_to_lines"] + )[0] + + # https://stackoverflow.com/questions/48127642/incompatible-types-in-assignment-on-union + + @overload + def __call__(self, img: Union[str, np.ndarray, Image.Image]) -> Page: + ... + + @overload + def __call__(self, img: List[Union[str, np.ndarray, Image.Image]]) -> Document: + ... + + def __call__(self, img): # type: ignore #ignoring type before implementing batch_mode + """ + Accept an image or list of them, return ocr result as a page or document + """ + with Timer("read image"): + img = ImageReader.read(img) + if self._settings["batch_size"] == 1: + if isinstance(img, list): + if len(img) == 1: + img = img[0] # in case input type is a 1 page pdf + else: + raise AssertionError( + "list input can only be used with batch_mode enabled" + ) + img_deskewed, is_blank, angle = self.preprocess(img) + + if is_blank: + print( + "[WARNING]: Blank image detected" + ) # TODO: should we stop the execution here? + with Timer("detect"): + img_deskewed, bboxes = self.run_detect(img_deskewed) + with Timer("read_page"): + lsegments = self.read_page(img_deskewed, bboxes) + return Page(lsegments, img, img_deskewed if angle != 0 else None) + else: + # lpages = [] + # # chunks to reduce memory footprint + # for imgs in chunks(img, self._batch_size): + # # pred_dets = self._detector(imgs) + # # TEMP: use list comprehension because sdsvtd do not support batch mode of text detection + # img = self.preprocess(img) + # img, bboxes = self.run_detect(img) + # for img_, bboxes_ in zip(imgs, bboxes): + # llines = self.read_page(img, bboxes_) + # page = Page(llines, img) + # lpages.append(page) + # return Document(lpages) + raise NotImplementedError("Batch mode is currently not supported") + + +if __name__ == "__main__": + img_path = "/mnt/ssd1T/hungbnt/Cello/data/PH/Sea7/Sea_7_1.jpg" + engine = OcrEngine(device="cuda:0") + # https://stackoverflow.com/questions/66435480/overload-following-optional-argument + page = engine(img_path) # type: ignore + print(page._word_segments) diff --git a/cope2n-ai-fi/modules/ocr_engine/src/utils.py b/cope2n-ai-fi/modules/ocr_engine/src/utils.py new file mode 100644 index 0000000..3c4332e --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/src/utils.py @@ -0,0 +1,369 @@ +from PIL import ImageFont, ImageDraw, Image, ImageOps +# import matplotlib.pyplot as plt +import numpy as np +import cv2 +import os +import time +from typing import Generator, Union, List, overload, Tuple, Callable +import glob +import math +from pathlib import Path +from pdf2image import convert_from_path +# from deskew import determine_skew +# from jdeskew.estimator import get_angle +# from jdeskew.utility import rotate as jrotate + + +def post_process_recog(text: str) -> str: + text = text.replace("✪", " ") + return text + + +def find_maximum_without_outliers(lst: list[int], threshold: float = 1.): + ''' + To find the maximum number in a list while excluding its outlier values, you can follow these steps: + Determine the range within which you consider values as outliers. This can be based on a specific threshold or a statistical measure such as the interquartile range (IQR). + Iterate through the list and filter out the outlier values based on the defined range. Keep track of the non-outlier values. + Find the maximum value among the non-outlier values. 
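+
+    Example: for lst = [10, 11, 12, 13, 500] and threshold = 1.0, q1 = 11 and
+    q3 = 13, so iqr = 2; 500 falls outside [9, 15] and the returned maximum is 13.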
+ ''' + # Calculate the lower and upper boundaries for outliers + q1 = np.percentile(lst, 25) + q3 = np.percentile(lst, 75) + iqr = q3 - q1 + lower_bound = q1 - threshold * iqr + upper_bound = q3 + threshold * iqr + + # Filter out outlier values + non_outliers = [x for x in lst if lower_bound <= x <= upper_bound] + + # Find the maximum value among non-outliers + max_value = max(non_outliers) + + return max_value + + +class Timer: + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, func: Callable, *args): + self.end_time = time.perf_counter() + self.elapsed_time = self.end_time - self.start_time + print(f"[INFO]: {self.name} took : {self.elapsed_time:.6f} seconds") + + +# def rotate( +# image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]] +# ) -> np.ndarray: +# old_width, old_height = image.shape[:2] +# angle_radian = math.radians(angle) +# width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width) +# height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height) +# image_center = tuple(np.array(image.shape[1::-1]) / 2) +# rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) +# rot_mat[1, 2] += (width - old_width) / 2 +# rot_mat[0, 2] += (height - old_height) / 2 +# return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background) + + +# def rotate_bbox(bbox: list, angle: float) -> list: +# # Compute the center point of the bounding box +# cx = bbox[0] + bbox[2] / 2 +# cy = bbox[1] + bbox[3] / 2 + +# # Define the scale factor for the rotated bounding box +# scale = 1.0 # following the deskew and jdeskew function +# angle_radian = math.radians(angle) + +# # Obtain the rotation matrix using cv2.getRotationMatrix2D() +# M = cv2.getRotationMatrix2D((cx, cy), angle_radian, scale) + +# # Apply the rotation matrix to the four corners of the bounding box +# corners = np.array([[bbox[0], bbox[1]], +# [bbox[0] + bbox[2], bbox[1]], +# [bbox[0] + bbox[2], bbox[1] + bbox[3]], +# [bbox[0], bbox[1] + bbox[3]]], dtype=np.float32) +# rotated_corners = cv2.transform(np.array([corners]), M)[0] + +# # Compute the bounding box of the rotated corners +# x = int(np.min(rotated_corners[:, 0])) +# y = int(np.min(rotated_corners[:, 1])) +# w = int(np.max(rotated_corners[:, 0]) - np.min(rotated_corners[:, 0])) +# h = int(np.max(rotated_corners[:, 1]) - np.min(rotated_corners[:, 1])) +# rotated_bbox = [x, y, w, h] + +# return rotated_bbox + +# def rotate_bbox(bbox: List[int], angle: float, old_shape: Tuple[int, int]) -> List[int]: +# # https://medium.com/@pokomaru/image-and-bounding-box-rotation-using-opencv-python-2def6c39453 +# bbox_ = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]] +# h, w = old_shape +# cx, cy = (int(w / 2), int(h / 2)) + +# bbox_tuple = [ +# (bbox_[0], bbox_[1]), +# (bbox_[2], bbox_[3]), +# (bbox_[4], bbox_[5]), +# (bbox_[6], bbox_[7]), +# ] # put x and y coordinates in tuples, we will iterate through the tuples and perform rotation + +# rotated_bbox = [] + +# for i, coord in enumerate(bbox_tuple): +# M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) +# cos, sin = abs(M[0, 0]), abs(M[0, 1]) +# newW = int((h * sin) + (w * cos)) +# newH = int((h * cos) + (w * sin)) +# M[0, 2] += (newW / 2) - cx +# M[1, 2] += (newH / 2) - cy +# v = [coord[0], coord[1], 1] +# adjusted_coord = np.dot(M, v) +# rotated_bbox.insert(i, (adjusted_coord[0], 
adjusted_coord[1])) +# result = [int(x) for t in rotated_bbox for x in t] +# return [result[i] for i in [0, 1, 2, -1]] # reformat to xyxy + + +# def deskew(image: np.ndarray) -> Tuple[np.ndarray, float]: +# grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) +# angle = 0. +# try: +# angle = determine_skew(grayscale) +# except Exception: +# pass +# rotated = rotate(image, angle, (0, 0, 0)) if angle else image +# return rotated, angle + + +# def jdeskew(image: np.ndarray) -> Tuple[np.ndarray, float]: +# angle = 0. +# try: +# angle = get_angle(image) +# except Exception: +# pass +# # TODO: change resize = True and scale the bounding box +# rotated = jrotate(image, angle, resize=False) if angle else image +# return rotated, angle +# def deskew() + +class ImageReader: + """ + accept anything, return numpy array image + """ + supported_ext = [".png", ".jpg", ".jpeg", ".pdf", ".gif"] + + @staticmethod + def validate_img_path(img_path: str) -> None: + if not os.path.exists(img_path): + raise FileNotFoundError(img_path) + if os.path.isdir(img_path): + raise IsADirectoryError(img_path) + if not Path(img_path).suffix.lower() in ImageReader.supported_ext: + raise NotImplementedError("Not supported extension at {}".format(img_path)) + + @overload + @staticmethod + def read(img: Union[str, np.ndarray, Image.Image]) -> np.ndarray: ... + + @overload + @staticmethod + def read(img: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: ... + + @overload + @staticmethod + def read(img: str) -> List[np.ndarray]: ... # for pdf or directory + + @staticmethod + def read(img): + if isinstance(img, list): + return ImageReader.from_list(img) + elif isinstance(img, str) and os.path.isdir(img): + return ImageReader.from_dir(img) + elif isinstance(img, str) and img.endswith(".pdf"): + return ImageReader.from_pdf(img) + else: + return ImageReader._read(img) + + @staticmethod + def from_dir(dir_path: str) -> List[np.ndarray]: + if os.path.isdir(dir_path): + image_files = glob.glob(os.path.join(dir_path, "*")) + return ImageReader.from_list(image_files) + else: + raise NotADirectoryError(dir_path) + + @staticmethod + def from_str(img_path: str) -> np.ndarray: + ImageReader.validate_img_path(img_path) + return ImageReader.from_PIL(Image.open(img_path)) + + @staticmethod + def from_np(img_array: np.ndarray) -> np.ndarray: + return img_array + + @staticmethod + def from_PIL(img_pil: Image.Image, transpose=True) -> np.ndarray: + # if img_pil.is_animated: + # raise NotImplementedError("Only static images are supported, animated image found") + if transpose: + img_pil = ImageOps.exif_transpose(img_pil) + if img_pil.mode != "RGB": + img_pil = img_pil.convert("RGB") + + return np.array(img_pil) + + @staticmethod + def from_list(img_list: List[Union[str, np.ndarray, Image.Image]]) -> List[np.ndarray]: + limgs = list() + for img_path in img_list: + try: + if isinstance(img_path, str): + ImageReader.validate_img_path(img_path) + limgs.append(ImageReader._read(img_path)) + except (FileNotFoundError, NotImplementedError, IsADirectoryError) as e: + print("[ERROR]: ", e) + print("[INFO]: Skipping image {}".format(img_path)) + return limgs + + @staticmethod + def from_pdf(pdf_path: str, start_page: int = 0, end_page: int = 0) -> List[np.ndarray]: + pdf_file = convert_from_path(pdf_path) + if end_page is not None: + end_page = min(len(pdf_file), end_page + 1) + limgs = [np.array(pdf_page) for pdf_page in pdf_file[start_page:end_page]] + return limgs + + @staticmethod + def _read(img: Union[str, np.ndarray, Image.Image]) -> 
np.ndarray: + if isinstance(img, str): + return ImageReader.from_str(img) + elif isinstance(img, Image.Image): + return ImageReader.from_PIL(img) + elif isinstance(img, np.ndarray): + return ImageReader.from_np(img) + else: + raise ValueError("Invalid img argument type: ", type(img)) + + +def get_name(file_path, ext: bool = True): + file_path_ = os.path.basename(file_path) + return file_path_ if ext else os.path.splitext(file_path_)[0] + + +def construct_file_path(dir, file_path, ext=''): + ''' + args: + dir: /path/to/dir + file_path /example_path/to/file.txt + ext = '.json' + return + /path/to/dir/file.json + ''' + return os.path.join( + dir, get_name(file_path, + True)) if ext == '' else os.path.join( + dir, get_name(file_path, + False)) + ext + + +def chunks(lst: list, n: int) -> Generator: + """ + Yield successive n-sized chunks from lst. + https://stackoverflow.com/questions/312443/how-do-i-split-a-list-into-equally-sized-chunks + """ + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +def read_ocr_result_from_txt(file_path: str) -> Tuple[list, list]: + ''' + return list of bounding boxes, list of words + ''' + with open(file_path, 'r') as f: + lines = f.read().splitlines() + boxes, words = [], [] + for line in lines: + if line == "": + continue + x1, y1, x2, y2, text = line.split("\t") + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + if text and text != " ": + words.append(text) + boxes.append((x1, y1, x2, y2)) + return boxes, words + + +def get_xyxywh_base_on_format(bbox, format): + if format == "xywh": + x1, y1, w, h = bbox[0], bbox[1], bbox[2], bbox[3] + x2, y2 = x1 + w, y1 + h + elif format == "xyxy": + x1, y1, x2, y2 = bbox + w, h = x2 - x1, y2 - y1 + else: + raise NotImplementedError("Invalid format {}".format(format)) + return (x1, y1, x2, y2, w, h) + + +def get_dynamic_params_for_bbox_of_label(text, x1, y1, w, h, img_h, img_w, font, font_scale_offset=1): + font_scale_factor = img_h / (img_w + img_h) * font_scale_offset + font_scale = w / (w + h) * font_scale_factor # adjust font scale by width height + thickness = int(font_scale_factor) + 1 + (text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0] + text_offset_x = x1 + text_offset_y = y1 - thickness + box_coords = ((text_offset_x, text_offset_y + 1), (text_offset_x + text_width - 2, text_offset_y - text_height - 2)) + return (font_scale, thickness, text_height, box_coords) + + +def visualize_bbox_and_label( + img, bboxes, texts, bbox_color=(200, 180, 60), + text_color=(0, 0, 0), + format="xyxy", is_vnese=False, draw_text=True): + ori_img_type = type(img) + if is_vnese: + img = Image.fromarray(img) if ori_img_type is np.ndarray else img + draw = ImageDraw.Draw(img) + img_w, img_h = img.size + font_pil_str = "fonts/arial.ttf" + font_cv2 = cv2.FONT_HERSHEY_SIMPLEX + else: + img_h, img_w = img.shape[0], img.shape[1] + font_cv2 = cv2.FONT_HERSHEY_SIMPLEX + for i in range(len(bboxes)): + text = texts[i] # text = "{}: {:.0f}%".format(LABELS[classIDs[i]], confidences[i]*100) + x1, y1, x2, y2, w, h = get_xyxywh_base_on_format(bboxes[i], format) + font_scale, thickness, text_height, box_coords = get_dynamic_params_for_bbox_of_label( + text, x1, y1, w, h, img_h, img_w, font=font_cv2) + if is_vnese: + font_pil = ImageFont.truetype(font_pil_str, size=text_height) # type: ignore + fdraw_text = draw.text # type: ignore + fdraw_bbox = draw.rectangle # type: ignore + # Pil use different coordinate => y = y+thickness = y-thickness + 2*thickness + arg_text = ((box_coords[0][0], 
box_coords[1][1]), text) + kwarg_text = {"font": font_pil, "fill": text_color, "width": thickness} + arg_rec = ((x1, y1, x2, y2),) + kwarg_rec = {"outline": bbox_color, "width": thickness} + arg_rec_text = ((box_coords[0], box_coords[1]),) + kwarg_rec_text = {"fill": bbox_color, "width": thickness} + else: + # cv2.rectangle(img, box_coords[0], box_coords[1], color, cv2.FILLED) + # cv2.putText(img, text, (text_offset_x, text_offset_y), font, fontScale=font_scale, color=(50, 0,0), thickness=thickness) + # cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness) + fdraw_text = cv2.putText + fdraw_bbox = cv2.rectangle + arg_text = (img, text, box_coords[0]) + kwarg_text = {"fontFace": font_cv2, "fontScale": font_scale, "color": text_color, "thickness": thickness} + arg_rec = (img, (x1, y1), (x2, y2)) + kwarg_rec = {"color": bbox_color, "thickness": thickness} + arg_rec_text = (img, box_coords[0], box_coords[1]) + kwarg_rec_text = {"color": bbox_color, "thickness": cv2.FILLED} + # draw a bounding box rectangle and label on the img + fdraw_bbox(*arg_rec, **kwarg_rec) # type: ignore + if draw_text: + fdraw_bbox(*arg_rec_text, **kwarg_rec_text) # type: ignore + fdraw_text(*arg_text, **kwarg_text) # type: ignore # text have to put in front of rec_text + return np.array(img) if ori_img_type is np.ndarray and is_vnese else img diff --git a/cope2n-ai-fi/modules/ocr_engine/src/word_formation.py b/cope2n-ai-fi/modules/ocr_engine/src/word_formation.py new file mode 100644 index 0000000..3e64b97 --- /dev/null +++ b/cope2n-ai-fi/modules/ocr_engine/src/word_formation.py @@ -0,0 +1,903 @@ +from builtins import dict +from .dto import Word, Line, WordGroup, Box +from .utils import find_maximum_without_outliers +import numpy as np +from typing import Optional, List, Tuple, Union + +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ +### WORDS TO LINES ALGORITHMS FROM MMOCR AND TESSERACT ############################################################################################################################################################################### +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ + +DEGREE_TO_RADIAN_COEF = np.pi / 180 +MAX_INT = int(2e10 + 9) +MIN_INT = -MAX_INT + + +def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8): + """Check if two boxes are on the same line by their y-axis coordinates. + + Two boxes are on the same line if they overlap vertically, and the length + of the overlapping line segment is greater than min_y_overlap_ratio * the + height of either of the boxes. 
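+    For example, box_a = [0, 0, 10, 0, 10, 10, 0, 10] (y-range 0-10) and
+    box_b = [20, 5, 30, 5, 30, 15, 20, 15] (y-range 5-15) overlap by 5 px:
+    with min_y_overlap_ratio=0.3 they count as the same line (5 >= 0.3 * 10),
+    while with the default 0.8 they do not.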
+ + Args: + box_a (list), box_b (list): Two bounding boxes to be checked + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for boxes in the same line + + Returns: + The bool flag indicating if they are on the same line + """ + a_y_min = np.min(box_a[1::2]) + b_y_min = np.min(box_b[1::2]) + a_y_max = np.max(box_a[1::2]) + b_y_max = np.max(box_b[1::2]) + + # Make sure that box a is always the box above another + if a_y_min > b_y_min: + a_y_min, b_y_min = b_y_min, a_y_min + a_y_max, b_y_max = b_y_max, a_y_max + + if b_y_min <= a_y_max: + if min_y_overlap_ratio is not None: + sorted_y = sorted([b_y_min, b_y_max, a_y_max]) + overlap = sorted_y[1] - sorted_y[0] + min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio + min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio + return overlap >= min_a_overlap or \ + overlap >= min_b_overlap + else: + return True + return False + + +def merge_bboxes_to_group(bboxes_group, x_sorted_boxes): + merged_bboxes = [] + for box_group in bboxes_group: + merged_box = {} + merged_box['text'] = ' '.join( + [x_sorted_boxes[idx]['text'] for idx in box_group]) + x_min, y_min = float('inf'), float('inf') + x_max, y_max = float('-inf'), float('-inf') + for idx in box_group: + x_max = max(np.max(x_sorted_boxes[idx]['box'][::2]), x_max) + x_min = min(np.min(x_sorted_boxes[idx]['box'][::2]), x_min) + y_max = max(np.max(x_sorted_boxes[idx]['box'][1::2]), y_max) + y_min = min(np.min(x_sorted_boxes[idx]['box'][1::2]), y_min) + merged_box['box'] = [ + x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max + ] + merged_box['list_words'] = [x_sorted_boxes[idx]['word'] + for idx in box_group] + merged_bboxes.append(merged_box) + return merged_bboxes + + +def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.3): + """Stitch fragmented boxes of words into lines. 
+ + Note: part of its logic is inspired by @Johndirr + (https://github.com/faustomorales/keras-ocr/issues/22) + + Args: + boxes (list): List of ocr results to be stitched + max_x_dist (int): The maximum horizontal distance between the closest + edges of neighboring boxes in the same line + min_y_overlap_ratio (float): The minimum vertical overlapping ratio + allowed for any pairs of neighboring boxes in the same line + + Returns: + merged_boxes(List[dict]): List of merged boxes and texts + """ + + if len(boxes) <= 1: + if len(boxes) == 1: + boxes[0]["list_words"] = [boxes[0]["word"]] + return boxes + + # merged_groups = [] + merged_lines = [] + + # sort groups based on the x_min coordinate of boxes + x_sorted_boxes = sorted(boxes, key=lambda x: np.min(x['box'][::2])) + # store indexes of boxes which are already parts of other lines + skip_idxs = set() + + i = 0 + # locate lines of boxes starting from the leftmost one + for i in range(len(x_sorted_boxes)): + if i in skip_idxs: + continue + # the rightmost box in the current line + rightmost_box_idx = i + line = [rightmost_box_idx] + for j in range(i + 1, len(x_sorted_boxes)): + if j in skip_idxs: + continue + if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'], + x_sorted_boxes[j]['box'], min_y_overlap_ratio): + line.append(j) + skip_idxs.add(j) + rightmost_box_idx = j + + # split line into lines if the distance between two neighboring + # sub-lines' is greater than max_x_dist + # groups = [] + # line_idx = 0 + # groups.append([line[0]]) + # for k in range(1, len(line)): + # curr_box = x_sorted_boxes[line[k]] + # prev_box = x_sorted_boxes[line[k - 1]] + # dist = np.min(curr_box['box'][::2]) - np.max(prev_box['box'][::2]) + # if dist > max_x_dist: + # line_idx += 1 + # groups.append([]) + # groups[line_idx].append(line[k]) + + # # Get merged boxes + merged_line = merge_bboxes_to_group([line], x_sorted_boxes) + merged_lines.extend(merged_line) + # merged_group = merge_bboxes_to_group(groups,x_sorted_boxes) + # merged_groups.extend(merged_group) + + merged_lines = sorted(merged_lines, key=lambda x: np.min(x['box'][1::2])) + # merged_groups = sorted(merged_groups, key=lambda x: np.min(x['box'][1::2])) + return merged_lines # , merged_groups + +# REFERENCE +# https://vigneshgig.medium.com/bounding-box-sorting-algorithm-for-text-detection-and-object-detection-from-left-to-right-and-top-cf2c523c8a85 +# https://huggingface.co/spaces/tomofi/MMOCR/blame/main/mmocr/utils/box_util.py + + +def words_to_lines_mmocr(words: List[Word], *args) -> Tuple[List[Line], Optional[int]]: + bboxes = [{"box": [w.bbox[0], w.bbox[1], w.bbox[2], w.bbox[1], w.bbox[2], w.bbox[3], w.bbox[0], w.bbox[3]], + "text":w._text, "word":w} for w in words] + merged_lines = stitch_boxes_into_lines(bboxes) + merged_groups = merged_lines # TODO: fix code to return both word group and line + lwords_groups = [WordGroup(list_words=merged_box["list_words"], + text=merged_box["text"], + boundingbox=[merged_box["box"][i] for i in [0, 1, 2, -1]]) + for merged_box in merged_groups] + + llines = [Line(text=word_group._text, list_word_groups=[word_group], boundingbox=word_group._bbox_obj) + for word_group in lwords_groups] + + return llines, None # same format with the origin words_to_lines + # lines = [Line() for merged] + + +# def most_overlapping_row(rows, top, bottom, y_shift): +# max_overlap = -1 +# max_overlap_idx = -1 +# for i, row in enumerate(rows): +# row_top, row_bottom = row +# overlap = min(top + y_shift, row_top) - max(bottom + y_shift, row_bottom) +# if overlap > 
max_overlap: +# max_overlap = overlap +# max_overlap_idx = i +# return max_overlap_idx +def most_overlapping_row(rows, row_words, bottom, top, y_shift, max_row_size, y_overlap_threshold=0.5): + max_overlap = -1 + max_overlap_idx = -1 + overlapping_rows = [] + + for i, row in enumerate(rows): + row_bottom, row_top = row + overlap = min(bottom - y_shift[i], row_bottom) - \ + max(top - y_shift[i], row_top) + + if overlap > max_overlap: + max_overlap = overlap + max_overlap_idx = i + + # if at least overlap 1 pixel and not (overlap too much and overlap too little) + if (row_top <= bottom and row_bottom >= top) and not (bottom - top - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold): + overlapping_rows.append(i) + + # Merge overlapping rows if necessary + if len(overlapping_rows) > 1: + merge_bottom = max(rows[i][0] for i in overlapping_rows) + merge_top = min(rows[i][1] for i in overlapping_rows) + + if merge_bottom - merge_top <= max_row_size: + # Merge rows + merged_row = (merge_bottom, merge_top) + merged_words = [] + # Remove other overlapping rows + + for row_idx in overlapping_rows[:0:-1]: # [1,2,3] -> 3,2 + merged_words.extend(row_words[row_idx]) + del rows[row_idx] + del row_words[row_idx] + + rows[overlapping_rows[0]] = merged_row + row_words[overlapping_rows[0]].extend(merged_words[::-1]) + max_overlap_idx = overlapping_rows[0] + + if bottom - top - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold: + max_overlap_idx = -1 + return max_overlap_idx + + +def stitch_boxes_into_lines_tesseract(words: list[Word], max_running_y_shift: int, + gradient: float, y_overlap_threshold: float) -> Tuple[list[list[Word]], float]: + sorted_words = sorted(words, key=lambda x: x.bbox[0]) + rows = [] + row_words = [] + max_row_size = find_maximum_without_outliers([word.height for word in sorted_words]) + running_y_shift = [] + for _i, word in enumerate(sorted_words): + bbox, _text = word.bbox, word._text + _x1, y1, _x2, y2 = bbox + bottom, top = y2, y1 + max_row_size = max(max_row_size, bottom - top) + overlap_row_idx = most_overlapping_row( + rows, row_words, bottom, top, running_y_shift, max_row_size, y_overlap_threshold) + + if overlap_row_idx == -1: # No overlapping row found + new_row = (bottom, top) + rows.append(new_row) + row_words.append([word]) + running_y_shift.append(0) + else: # Overlapping row found + row_bottom, row_top = rows[overlap_row_idx] + new_bottom = max(row_bottom, bottom) + new_top = min(row_top, top) + rows[overlap_row_idx] = (new_bottom, new_top) + row_words[overlap_row_idx].append(word) + new_shift = (top + bottom) / 2 - (row_top + row_bottom) / 2 + running_y_shift[overlap_row_idx] = min( + gradient * running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift, max_running_y_shift) # update and clamp + + # Sort rows and row_texts based on the top y-coordinate + sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][1]) + _sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data) + # /_|<- the perpendicular line of the horizontal line and the skew line of the page + page_skew_dist = sum(running_y_shift) / len(running_y_shift) + return sorted_row_words, page_skew_dist + + +def construct_word_groups_tesseract(sorted_row_words: list[list[Word]], + max_x_dist: int, page_skew_dist: float) -> list[list[list[Word]]]: + # approximate page_skew_angle by page_skew_dist + corrected_max_x_dist = max_x_dist * abs(np.cos(page_skew_dist * DEGREE_TO_RADIAN_COEF)) + 
constructed_row_word_groups = [] + for row_words in sorted_row_words: + lword_groups = [] + line_idx = 0 + lword_groups.append([row_words[0]]) + for k in range(1, len(row_words)): + curr_box = row_words[k].bbox + prev_box = row_words[k - 1].bbox + dist = curr_box[0] - prev_box[2] + if dist > corrected_max_x_dist: + line_idx += 1 + lword_groups.append([]) + lword_groups[line_idx].append(row_words[k]) + constructed_row_word_groups.append(lword_groups) + return constructed_row_word_groups + + +def group_bbox_and_text(lwords: Union[list[Word], list[WordGroup]]) -> tuple[Box, tuple[str, float]]: + text = ' '.join([word._text for word in lwords]) + x_min, y_min = MAX_INT, MAX_INT + x_max, y_max = MIN_INT, MIN_INT + conf_det = 0 + conf_cls = 0 + for word in lwords: + x_max = int(max(np.max(word.bbox[::2]), x_max)) + x_min = int(min(np.min(word.bbox[::2]), x_min)) + y_max = int(max(np.max(word.bbox[1::2]), y_max)) + y_min = int(min(np.min(word.bbox[1::2]), y_min)) + conf_det += word._conf_det + conf_cls += word._conf_cls + bbox = Box(x_min, y_min, x_max, y_max, conf=conf_det / len(lwords)) + return bbox, (text, conf_cls / len(lwords)) + + +def words_to_lines_tesseract(words: List[Word], + page_width: int, max_running_y_shift_degree: int, gradient: float, max_x_dist: int, + y_overlap_threshold: float) -> Tuple[List[Line], + Optional[float]]: + max_running_y_shift = page_width * np.tan(max_running_y_shift_degree * DEGREE_TO_RADIAN_COEF) + sorted_row_words, page_skew_dist = stitch_boxes_into_lines_tesseract( + words, max_running_y_shift, gradient, y_overlap_threshold) + constructed_row_word_groups = construct_word_groups_tesseract( + sorted_row_words, max_x_dist, page_skew_dist) + llines = [] + for row in constructed_row_word_groups: + lwords_row = [] + lword_groups = [] + for word_group in row: + bbox_word_group, text_word_group = group_bbox_and_text(word_group) + lwords_row.extend(word_group) + lword_groups.append( + WordGroup( + list_words=word_group, text=text_word_group[0], + conf_cls=text_word_group[1], + boundingbox=bbox_word_group)) + bbox_line, text_line = group_bbox_and_text(lwords_row) + llines.append( + Line( + list_word_groups=lword_groups, text=text_line[0], + boundingbox=bbox_line, conf_cls=text_line[1])) + return llines, page_skew_dist + + + + +### WORDS TO WORDGROUPS ######################################################################################################################################################################################################################### + + +def merge_overlapping_word_groups( + rows: list[list[int]], + row_words: list[list[Word]], + overlapping_rows: list[int], + max_row_size: int) -> bool: + # Merge found overlapping rows if necessary + merge_top = max(rows[i][1] for i in overlapping_rows) + merge_bottom = min(rows[i][3] for i in overlapping_rows) + merge_left = min(rows[i][0] for i in overlapping_rows) + merge_right = max(rows[i][2] for i in overlapping_rows) + + if merge_top - merge_bottom <= max_row_size: + # Merge rows + merged_row = [merge_left, merge_top, merge_right, merge_bottom] + merged_words = [] + # Remove other overlapping rows + + for row_idx in overlapping_rows[:0:-1]: # [1,2,3] -> 3,2 + merged_words.extend(row_words[row_idx]) + del rows[row_idx] + del row_words[row_idx] + + rows[overlapping_rows[0]] = merged_row + row_words[overlapping_rows[0]].extend(merged_words[::-1]) + return True + return False + + +def most_overlapping_word_groups( + rows, row_words, curr_word_bbox, y_shift, max_row_size, y_overlap_threshold, 
max_x_dist):
+    max_overlap = -1
+    max_overlap_idx = -1
+    overlapping_rows = []
+    left, top, right, bottom = curr_word_bbox
+    for i, row in enumerate(rows):
+        row_left, row_top, row_right, row_bottom = row
+        top_shift = top - y_shift[i]
+        bottom_shift = bottom - y_shift[i]
+
+        # find the most overlapping row
+        overlap = min(bottom_shift, row_bottom) - max(top_shift, row_top)
+        if overlap > max_overlap and min(right - row_left, left - row_right) < max_x_dist:
+            max_overlap = overlap
+            max_overlap_idx = i
+
+        # extra handling for cases where several rows satisfy the overlap condition. For example, some rows do not overlap at first, but as the appended words get increasingly skewed there is a chance that the end of one row reaches the beginning of another row
+        # if (row_top <= bottom and row_bottom >= top) and not (bottom - top - max_overlap > max_row_size * y_overlap_threshold) and not (max_overlap < max_row_size * y_overlap_threshold):
+        if (row_top <= bottom_shift and row_bottom >= top_shift) \
+            and min(right - row_left, left - row_right) < max_x_dist \
+            and not (bottom - top - overlap > max_row_size * y_overlap_threshold) \
+            and not (overlap < max_row_size * y_overlap_threshold):
+            # explain:
+            # (row_top <= bottom_shift and row_bottom >= top_shift) -> overlap by at least 1 pixel
+            # not (bottom - top - overlap > max_row_size * y_overlap_threshold) -> curr_word is not too big for the overlap (to exclude figures containing words)
+            # not (overlap < max_row_size * y_overlap_threshold) -> words that overlap too little should not be merged
+            # min(right - row_left, left - row_right) < max_x_dist -> the curr_word is horizontally close enough to the current row (on either side)
+            overlapping_rows.append(i)
+
+    if len(overlapping_rows) > 1 and merge_overlapping_word_groups(rows, row_words, overlapping_rows, max_row_size):
+        max_overlap_idx = overlapping_rows[0]
+    if bottom - top - max_overlap > max_row_size * y_overlap_threshold and max_overlap < max_row_size * y_overlap_threshold:
+        max_overlap_idx = -1
+    return max_overlap_idx
+
+
+def update_overlapping_word_group_bbox(rows: list[list[int]], overlap_row_idx: int, curr_word_bbox: list[int]) -> None:
+    left, top, right, bottom = curr_word_bbox
+    row_left, row_top, row_right, row_bottom = rows[overlap_row_idx]
+    new_bottom = max(row_bottom, bottom)
+    new_top = min(row_top, top)
+    new_left = min(row_left, left)
+    new_right = max(row_right, right)
+    rows[overlap_row_idx] = [new_left, new_top, new_right, new_bottom]
+
+
+def update_word_group_running_y_shift(
+        running_y_shift: list[float],
+        overlap_row_idx: int, curr_row_bbox: list[int],
+        curr_word_bbox: list[int],
+        gradient: float, max_running_y_shift: float) -> None:
+    _, top, _, bottom = curr_word_bbox
+    _, row_top, _, row_bottom = curr_row_bbox
+    new_shift = (top + bottom) / 2 - (row_top + row_bottom) / 2
+    running_y_shift[overlap_row_idx] = min(
+        gradient * running_y_shift[overlap_row_idx] + (1 - gradient) * new_shift, max_running_y_shift)  # update and clamp
+
+
+def stitch_boxes_into_word_groups_tesseract(words: list[Word],
+                                            max_running_y_shift: int, gradient: float, y_overlap_threshold: float,
+                                            max_x_dist: int) -> Tuple[list[WordGroup], float]:
+    sorted_words = sorted(words, key=lambda x: x.bbox[0])
+    rows = []
+    row_words = []
+    max_row_size = sorted_words[0].height
+    running_y_shift = []
+    for word in sorted_words:
+        bbox: list[int] = word.bbox
+        max_row_size = max(max_row_size, bbox[3] - bbox[1])
+        overlap_row_idx = 
most_overlapping_word_groups( + rows, row_words, bbox, running_y_shift, max_row_size, y_overlap_threshold, max_x_dist) + if overlap_row_idx == -1: # No overlapping row found + rows.append(bbox) # new row + row_words.append([word]) # new row_word + running_y_shift.append(0) + else: # Overlapping row found + # row_bottom, row_top = rows[overlap_row_idx] + update_overlapping_word_group_bbox(rows, overlap_row_idx, bbox) + row_words[overlap_row_idx].append(word) # update row_words + update_word_group_running_y_shift( + running_y_shift, overlap_row_idx, rows[overlap_row_idx], + bbox, gradient, max_running_y_shift) + + # Sort rows and row_texts based on the top y-coordinate + sorted_rows_data = sorted(zip(rows, row_words), key=lambda x: x[0][1]) + _sorted_rows_idx, sorted_row_words = zip(*sorted_rows_data) + lword_groups = [] + for word_group in sorted_row_words: + bbox_word_group, text_word_group = group_bbox_and_text(word_group) + lword_groups.append( + WordGroup( + list_words=word_group, text=text_word_group[0], + conf_cls=text_word_group[1], + boundingbox=bbox_word_group)) + # /_|<- the perpendicular line of the horizontal line and the skew line of the page + page_skew_dist = sum(running_y_shift) / len(running_y_shift) + + return lword_groups, page_skew_dist + + +def is_on_same_line_mmocr_tesseract(box_a: list[int], box_b: list[int], min_y_overlap_ratio: float) -> bool: + a_y_min = box_a[1] + b_y_min = box_b[1] + a_y_max = box_a[3] + b_y_max = box_b[3] + + # Make sure that box a is always the box above another + if a_y_min > b_y_min: + a_y_min, b_y_min = b_y_min, a_y_min + a_y_max, b_y_max = b_y_max, a_y_max + + if b_y_min <= a_y_max: + if min_y_overlap_ratio is not None: + sorted_y = sorted([b_y_min, b_y_max, a_y_max]) + overlap = sorted_y[1] - sorted_y[0] + min_a_overlap = (a_y_max - a_y_min) * min_y_overlap_ratio + min_b_overlap = (b_y_max - b_y_min) * min_y_overlap_ratio + return overlap >= min_a_overlap or \ + overlap >= min_b_overlap + else: + return True + return False + + +def stitch_word_groups_into_lines_mmocr_tesseract( + lword_groups: list[WordGroup], + min_y_overlap_ratio: float) -> list[Line]: + merged_lines = [] + + # sort groups based on the x_min coordinate of boxes + # store indexes of boxes which are already parts of other lines + sorted_word_groups = sorted(lword_groups, key=lambda x: x.bbox[0]) + skip_idxs = set() + + i = 0 + # locate lines of boxes starting from the leftmost one + for i in range(len(sorted_word_groups)): + if i in skip_idxs: + continue + # the rightmost box in the current line + rightmost_box_idx = i + line = [rightmost_box_idx] + for j in range(i + 1, len(sorted_word_groups)): + if j in skip_idxs: + continue + if is_on_same_line_mmocr_tesseract(sorted_word_groups[rightmost_box_idx].bbox, + sorted_word_groups[j].bbox, min_y_overlap_ratio): + line.append(j) + skip_idxs.add(j) + rightmost_box_idx = j + + lword_groups_in_line = [sorted_word_groups[k] for k in line] + bbox_line, text_line = group_bbox_and_text(lword_groups_in_line) + merged_lines.append( + Line( + list_word_groups=lword_groups_in_line, text=text_line[0], + conf_cls=text_line[1], + boundingbox=bbox_line)) + merged_lines = sorted(merged_lines, key=lambda x: x.bbox[1]) + return merged_lines + + +def words_formation_mmocr_tesseract(words: List[Word], page_width: int, word_formation_mode: str, max_running_y_shift_degree: int, gradient: float, + max_x_dist: int, y_overlap_threshold: float) -> Tuple[Union[List[WordGroup], list[Line]], + Optional[float]]: + if len(words) == 0: + return [], 0 + 
max_running_y_shift = page_width * np.tan(max_running_y_shift_degree * DEGREE_TO_RADIAN_COEF) + lword_groups, page_skew_dist = stitch_boxes_into_word_groups_tesseract( + words, max_running_y_shift, gradient, y_overlap_threshold, max_x_dist) + if word_formation_mode == "word_group": + return lword_groups, page_skew_dist + elif word_formation_mode == "line": + llines = stitch_word_groups_into_lines_mmocr_tesseract(lword_groups, y_overlap_threshold) + return llines, page_skew_dist + else: + raise NotImplementedError("Word formation mode not supported: {}".format(word_formation_mode)) + +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ +### END WORDS TO LINES ALGORITHMS FROM MMOCR AND TESSERACT ############################################################################################################################################################################### +############################################################################################################################################################################################################################ +############################################################################################################################################################################################################################ + +# MIN_IOU_HEIGHT = 0.7 +# MIN_WIDTH_LINE_RATIO = 0.05 + + +# def resize_to_original( +# boundingbox, scale +# ): # resize coordinates to match size of original image +# left, top, right, bottom = boundingbox +# left *= scale[1] +# right *= scale[1] +# top *= scale[0] +# bottom *= scale[0] +# return [left, top, right, bottom] + + +# def check_iomin(word: Word, word_group: Word_group): +# min_height = min( +# word.boundingbox[3] - word.boundingbox[1], +# word_group.boundingbox[3] - word_group.boundingbox[1], +# ) +# intersect = min(word.boundingbox[3], word_group.boundingbox[3]) - max( +# word.boundingbox[1], word_group.boundingbox[1] +# ) +# if intersect / min_height > 0.7: +# return True +# return False + + +# def prepare_line(words): +# lines = [] +# visited = [False] * len(words) +# for id_word, word in enumerate(words): +# if word.invalid_size() == 0: +# continue +# new_line = True +# for i in range(len(lines)): +# if ( +# lines[i].in_same_line(word) and not visited[id_word] +# ): # check if word is in the same line with lines[i] +# lines[i].merge_word(word) +# new_line = False +# visited[id_word] = True + +# if new_line == True: +# new_line = Line() +# new_line.merge_word(word) +# lines.append(new_line) + +# # print(len(lines)) +# # sort line from top to bottom according top coordinate +# lines.sort(key=lambda x: x.boundingbox[1]) +# return lines + + +# def __create_word_group(word, word_group_id): +# new_word_group_ = Word_group() +# new_word_group_.list_words = list() +# new_word_group_.word_group_id = word_group_id +# new_word_group_.add_word(word) + +# return new_word_group_ + + +# def __sort_line(line): +# line.list_word_groups.sort( +# key=lambda x: x.boundingbox[0] +# ) # sort word in lines from left to right + +# return line + + +# def __merge_text_for_line(line): +# line.text = "" +# for word 
in line.list_word_groups: +# line.text += " " + word.text + +# return line + + +# def __update_list_word_groups(line, word_group_id, word_id, line_width): + +# old_list_word_group = line.list_word_groups +# list_word_groups = [] + +# inital_word_group = __create_word_group( +# old_list_word_group[0], word_group_id) +# old_list_word_group[0].word_id = word_id +# list_word_groups.append(inital_word_group) +# word_group_id += 1 +# word_id += 1 + +# for word in old_list_word_group[1:]: +# check_word_group = True +# word.word_id = word_id +# word_id += 1 + +# if ( +# (not list_word_groups[-1].text.endswith(":")) +# and ( +# (word.boundingbox[0] - list_word_groups[-1].boundingbox[2]) +# / line_width +# < MIN_WIDTH_LINE_RATIO +# ) +# and check_iomin(word, list_word_groups[-1]) +# ): +# list_word_groups[-1].add_word(word) +# check_word_group = False + +# if check_word_group: +# new_word_group = __create_word_group(word, word_group_id) +# list_word_groups.append(new_word_group) +# word_group_id += 1 +# line.list_word_groups = list_word_groups +# return line, word_group_id, word_id + + +# def construct_word_groups_in_each_line(lines): +# line_id = 0 +# word_group_id = 0 +# word_id = 0 +# for i in range(len(lines)): +# if len(lines[i].list_word_groups) == 0: +# continue + +# # left, top ,right, bottom +# line_width = lines[i].boundingbox[2] - \ +# lines[i].boundingbox[0] # right - left +# line_width = 1 # TODO: to remove +# lines[i] = __sort_line(lines[i]) + +# # update text for lines after sorting +# lines[i] = __merge_text_for_line(lines[i]) + +# lines[i], word_group_id, word_id = __update_list_word_groups( +# lines[i], +# word_group_id, +# word_id, +# line_width) +# lines[i].update_line_id(line_id) +# line_id += 1 +# return lines + + +# def words_to_lines(words, check_special_lines=True): # words is list of Word instance +# # sort word by top +# words.sort(key=lambda x: (x.boundingbox[1], x.boundingbox[0])) +# # words.sort(key=lambda x: (sum(x.bbox))) +# number_of_word = len(words) +# # print(number_of_word) +# # sort list words to list lines, which have not contained word_group yet +# lines = prepare_line(words) + +# # construct word_groups in each line +# lines = construct_word_groups_in_each_line(lines) +# return lines, number_of_word + + +# def near(word_group1: Word_group, word_group2: Word_group): +# min_height = min( +# word_group1.boundingbox[3] - word_group1.boundingbox[1], +# word_group2.boundingbox[3] - word_group2.boundingbox[1], +# ) +# overlap = min(word_group1.boundingbox[3], word_group2.boundingbox[3]) - max( +# word_group1.boundingbox[1], word_group2.boundingbox[1] +# ) + +# if overlap > 0: +# return True +# if abs(overlap / min_height) < 1.5: +# print("near enough", abs(overlap / min_height), overlap, min_height) +# return True +# return False + + +# def calculate_iou_and_near(wg1: Word_group, wg2: Word_group): +# min_height = min( +# wg1.boundingbox[3] - +# wg1.boundingbox[1], wg2.boundingbox[3] - wg2.boundingbox[1] +# ) +# overlap = min(wg1.boundingbox[3], wg2.boundingbox[3]) - max( +# wg1.boundingbox[1], wg2.boundingbox[1] +# ) +# iou = overlap / min_height +# distance = min( +# abs(wg1.boundingbox[0] - wg2.boundingbox[2]), +# abs(wg1.boundingbox[2] - wg2.boundingbox[0]), +# ) +# if iou > 0.7 and distance < 0.5 * (wg1.boundingboxp[2] - wg1.boundingbox[0]): +# return True +# return False + + +# def construct_word_groups_to_kie_label(list_word_groups: list): +# kie_dict = dict() +# for wg in list_word_groups: +# if wg.kie_label == "other": +# continue +# if wg.kie_label not in 
kie_dict: +# kie_dict[wg.kie_label] = [wg] +# else: +# kie_dict[wg.kie_label].append(wg) + +# new_dict = dict() +# for key, value in kie_dict.items(): +# if len(value) == 1: +# new_dict[key] = value +# continue + +# value.sort(key=lambda x: x.boundingbox[1]) +# new_dict[key] = value +# return new_dict + + +# def invoice_construct_word_groups_to_kie_label(list_word_groups: list): +# kie_dict = dict() + +# for wg in list_word_groups: +# if wg.kie_label == "other": +# continue +# if wg.kie_label not in kie_dict: +# kie_dict[wg.kie_label] = [wg] +# else: +# kie_dict[wg.kie_label].append(wg) + +# return kie_dict + + +# def postprocess_total_value(kie_dict): +# if "total_in_words_value" not in kie_dict: +# return kie_dict + +# for k, value in kie_dict.items(): +# if k == "total_in_words_value": +# continue +# l = [] +# for v in value: +# if v.boundingbox[3] <= kie_dict["total_in_words_value"][0].boundingbox[3]: +# l.append(v) + +# if len(l) != 0: +# kie_dict[k] = l + +# return kie_dict + + +# def postprocess_tax_code_value(kie_dict): +# if "buyer_tax_code_value" in kie_dict or "seller_tax_code_value" not in kie_dict: +# return kie_dict + +# kie_dict["buyer_tax_code_value"] = [] +# for v in kie_dict["seller_tax_code_value"]: +# if "buyer_name_key" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_key"][0]) +# ): +# kie_dict["buyer_tax_code_value"].append(v) +# continue + +# if "buyer_name_value" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_value"][0]) +# ): +# kie_dict["buyer_tax_code_value"].append(v) +# continue + +# if "buyer_address_value" in kie_dict and near( +# kie_dict["buyer_address_value"][0], v +# ): +# kie_dict["buyer_tax_code_value"].append(v) +# return kie_dict + + +# def postprocess_tax_code_key(kie_dict): +# if "buyer_tax_code_key" in kie_dict or "seller_tax_code_key" not in kie_dict: +# return kie_dict +# kie_dict["buyer_tax_code_key"] = [] +# for v in kie_dict["seller_tax_code_key"]: +# if "buyer_name_key" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_key"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_key"][0]) +# ): +# kie_dict["buyer_tax_code_key"].append(v) +# continue + +# if "buyer_name_value" in kie_dict and ( +# v.boundingbox[3] > kie_dict["buyer_name_value"][0].boundingbox[3] +# or near(v, kie_dict["buyer_name_value"][0]) +# ): +# kie_dict["buyer_tax_code_key"].append(v) +# continue + +# if "buyer_address_value" in kie_dict and near( +# kie_dict["buyer_address_value"][0], v +# ): +# kie_dict["buyer_tax_code_key"].append(v) + +# return kie_dict + + +# def invoice_postprocess(kie_dict: dict): +# # all keys or values which are below total_in_words_value will be thrown away +# kie_dict = postprocess_total_value(kie_dict) +# kie_dict = postprocess_tax_code_value(kie_dict) +# kie_dict = postprocess_tax_code_key(kie_dict) +# return kie_dict + + +# def throw_overlapping_words(list_words): +# new_list = [list_words[0]] +# for word in list_words: +# overlap = False +# area = (word.boundingbox[2] - word.boundingbox[0]) * ( +# word.boundingbox[3] - word.boundingbox[1] +# ) +# for word2 in new_list: +# area2 = (word2.boundingbox[2] - word2.boundingbox[0]) * ( +# word2.boundingbox[3] - word2.boundingbox[1] +# ) +# xmin_intersect = max(word.boundingbox[0], word2.boundingbox[0]) +# xmax_intersect = min(word.boundingbox[2], word2.boundingbox[2]) +# ymin_intersect = max(word.boundingbox[1], word2.boundingbox[1]) +# 
ymax_intersect = min(word.boundingbox[3], word2.boundingbox[3]) +# if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: +# continue + +# area_intersect = (xmax_intersect - xmin_intersect) * ( +# ymax_intersect - ymin_intersect +# ) +# if area_intersect / area > 0.7 or area_intersect / area2 > 0.7: +# overlap = True +# if overlap == False: +# new_list.append(word) +# return new_list + + +# def check_iou(box1: Word, box2: Box, threshold=0.9): +# area1 = (box1.boundingbox[2] - box1.boundingbox[0]) * ( +# box1.boundingbox[3] - box1.boundingbox[1] +# ) +# area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin) +# xmin_intersect = max(box1.boundingbox[0], box2.xmin) +# ymin_intersect = max(box1.boundingbox[1], box2.ymin) +# xmax_intersect = min(box1.boundingbox[2], box2.xmax) +# ymax_intersect = min(box1.boundingbox[3], box2.ymax) +# if xmax_intersect < xmin_intersect or ymax_intersect < ymin_intersect: +# area_intersect = 0 +# else: +# area_intersect = (xmax_intersect - xmin_intersect) * ( +# ymax_intersect - ymin_intersect +# ) +# union = area1 + area2 - area_intersect +# iou = area_intersect / union +# if iou > threshold: +# return True +# return False diff --git a/cope2n-ai-fi/modules/sdsvkvu b/cope2n-ai-fi/modules/sdsvkvu new file mode 160000 index 0000000..b93fa59 --- /dev/null +++ b/cope2n-ai-fi/modules/sdsvkvu @@ -0,0 +1 @@ +Subproject commit b93fa59908b3329074a475aaf3a6333b937f34e7 diff --git a/cope2n-ai-fi/requirements.txt b/cope2n-ai-fi/requirements.txt new file mode 100755 index 0000000..d62a544 --- /dev/null +++ b/cope2n-ai-fi/requirements.txt @@ -0,0 +1,10 @@ +django-environ + +sdsv_dewarp +sdsvtd +sdsvtr +sdsvkie +sdsvkvu + +pymupdf +easydict \ No newline at end of file diff --git a/cope2n-ai-fi/run.sh b/cope2n-ai-fi/run.sh new file mode 100755 index 0000000..14df2ea --- /dev/null +++ b/cope2n-ai-fi/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# cd /cope2n-ai-fi/sdsvkie +# pip3 install -v -e . +# cd /cope2n-ai-fi/sdsvtd +# pip3 install -v -e . +# cd /cope2n-ai-fi/sdsvtr +# pip3 install -v -e . +# cd /cope2n-ai-fi +bash -c "celery -A celery_worker.worker_fi worker --loglevel=INFO --pool=solo" \ No newline at end of file
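
The word-formation utilities added above rest on a few small geometric ideas. The sketches below restate them as standalone Python so they can be read and run in isolation; none of this code is part of the diff, and every name in it is local to the sketch rather than taken from the repository.

group_bbox_and_text merges a run of words by joining their texts, taking the union of their boxes, and averaging the detection and classification confidences. A simplified restatement over plain tuples, assuming [x1, y1, x2, y2] pixel boxes:

# Simplified, self-contained restatement of the aggregation in group_bbox_and_text.
# Plain tuples stand in for the repository's Word/Box classes (an assumption made
# only for this sketch).
def group_plain(words):
    # words: list of (text, [x1, y1, x2, y2], conf_det, conf_cls)
    text = " ".join(w[0] for w in words)
    x1 = min(w[1][0] for w in words)
    y1 = min(w[1][1] for w in words)
    x2 = max(w[1][2] for w in words)
    y2 = max(w[1][3] for w in words)
    conf_det = sum(w[2] for w in words) / len(words)   # averaged detection confidence
    conf_cls = sum(w[3] for w in words) / len(words)   # averaged classification confidence
    return (x1, y1, x2, y2, conf_det), (text, conf_cls)

box, (text, conf) = group_plain([("Tổng", [10, 5, 60, 25], 1.0, 0.5),
                                 ("cộng", [65, 6, 110, 24], 0.5, 1.0)])
print(box, text, conf)  # (10, 5, 110, 25, 0.75) Tổng cộng 0.75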
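
stitch_boxes_into_lines_tesseract and update_word_group_running_y_shift keep a per-row "running y-shift": a clamped exponential moving average of how far each newly appended word's vertical centre sits from its row's centre, which is what lets slightly skewed pages still be stitched row by row (and is later averaged into page_skew_dist). A minimal standalone sketch of that update rule; the gradient and clamp values are illustrative, not project defaults:

# Standalone sketch of the running y-shift update used above: keep a fraction
# `gradient` of the old estimate, blend in the new centre offset, and clamp from
# above with min(), mirroring the one-sided clamp in the repository code.
def update_running_y_shift(running: float, new_shift: float,
                           gradient: float, max_shift: float) -> float:
    return min(gradient * running + (1 - gradient) * new_shift, max_shift)

shift = 0.0
for observed in [1.0, 2.0, 1.5, 30.0]:   # per-word centre offsets; the last one is an outlier
    shift = update_running_y_shift(shift, observed, gradient=0.6, max_shift=5.0)
print(shift)  # 5.0 -- the outlier is clamped instead of dragging the row away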
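
is_on_same_line_mmocr_tesseract places two boxes on the same line when their vertical overlap covers at least min_y_overlap_ratio of either box's height (is_on_same_line applies the same idea to 8-point quad boxes). An equivalent standalone check, assuming axis-aligned [x1, y1, x2, y2] boxes, y growing downwards, and a positive ratio:

# Standalone restatement of the same-line test; equivalent to the variant above
# whenever min_y_overlap_ratio is positive.
def on_same_line(box_a, box_b, min_y_overlap_ratio=0.8):
    overlap = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1])  # shared vertical extent
    if overlap <= 0:
        return False
    return (overlap >= (box_a[3] - box_a[1]) * min_y_overlap_ratio
            or overlap >= (box_b[3] - box_b[1]) * min_y_overlap_ratio)

print(on_same_line([0, 10, 50, 30], [55, 12, 90, 29]))  # True: heights overlap almost fully
print(on_same_line([0, 10, 50, 30], [55, 40, 90, 60]))  # False: no vertical overlap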