150 lines
4.3 KiB
Python
Executable File
150 lines
4.3 KiB
Python
Executable File
from PIL import Image
|
|
import cv2
|
|
|
|
import numpy as np
|
|
from transformers import (
|
|
LayoutXLMTokenizer,
|
|
LayoutLMv2FeatureExtractor,
|
|
LayoutXLMProcessor,
|
|
LayoutLMv2ForTokenClassification,
|
|
)
|
|
|
|
from common.utils.word_formation import *
|
|
|
|
from common.utils.global_variables import *
|
|
from common.utils.process_label import *
|
|
import ssl
|
|
|
|
ssl._create_default_https_context = ssl._create_unverified_context
|
|
os.environ["CURL_CA_BUNDLE"] = ""
|
|
from PIL import ImageFile
|
|
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
# config
import torch  # needed here only to probe CUDA availability when choosing DEVICE

# Label name for tokens that do not belong to any extracted field.
IGNORE_KIE_LABEL = "others"

# Class names, indexed by the class id predicted by the model
# (infer_driving_license looks names up as KIE_LABELS[predicted_id]).
# NOTE(review): the order presumably has to match the label order the
# checkpoint was trained with -- do not re-sort or insert entries without
# confirming against the training config.
KIE_LABELS = [
    "Number",
    "Name",
    "Birthday",
    "Home Town",
    "Address",
    "Sex",
    "Nationality",
    "Expiry Date",
    "Nation",
    "Religion",
    "Date Range",
    "Issued By",
    IGNORE_KIE_LABEL,
    "Rank",
]

# Prefer the first GPU, but fall back to the CPU so that importing this module
# does not fail on hosts without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# The processor is not built in this module; infer_driving_license() receives
# it as an argument. The previous in-module construction is kept below for
# reference only.
# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
# tokenizer = LayoutXLMTokenizer.from_pretrained(
# "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
# )
# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)

# Token-classification model, loaded once at import time from a local
# checkpoint directory (local_files_only=True, so no download is attempted)
# and moved to DEVICE.
# TODO FIX this hard code (checkpoint path)
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "Kie_AHung/model/driver_license",
    num_labels=len(KIE_LABELS),
    local_files_only=True,
).to(DEVICE)
|
|
|
|
|
|
def load_ocr_labels(list_lines):
    """Flatten OCR lines into parallel lists of words, boxes and labels.

    Walks ``line.list_word_groups`` -> ``word_group.list_words`` for every
    element of *list_lines* and collects, for each word with non-blank text:

    * ``word.text``
    * the first four entries of ``word.boundingbox`` converted to ``float``
      (read as ``xmin, ymin, xmax, ymax``)
    * a constant placeholder label

    Words whose text is empty or consists only of whitespace are skipped.
    The original check only rejected a single space, which let other blank
    strings through to the tokenizer.

    Args:
        list_lines: iterable of line objects exposing ``list_word_groups``.

    Returns:
        tuple ``(words, boxes, labels)`` of three equally long lists.
    """
    # Placeholder only: every word gets the same label because real labels
    # are not available here.
    placeholder_label = "seller_name_value"  # TODO ??? fix this

    words, boxes, labels = [], [], []
    for line in list_lines:
        for word_group in line.list_word_groups:
            for word in word_group.list_words:
                text = word.text
                if not text or not text.strip():
                    continue  # blank OCR result: nothing to classify
                bbox = word.boundingbox
                words.append(text)
                boxes.append([float(bbox[k]) for k in range(4)])
                labels.append(placeholder_label)
    return words, boxes, labels
|
|
|
|
|
|
def _normalize_box(box, width, height):
    """Rescale a pixel box to integer coordinates on a 0-1000 grid.

    ``box[0]`` and ``box[2]`` are divided by *width*, ``box[1]`` and
    ``box[3]`` by *height*; each ratio is multiplied by 1000 and truncated
    with ``int``. The result is then clamped to ``[0, 1000]`` so that a box
    reaching slightly outside the image cannot produce an out-of-range
    coordinate for the layout model.

    Args:
        box: indexable with at least four numeric entries
            (x-min, y-min, x-max, y-max in pixels).
        width: image width in pixels (must be non-zero).
        height: image height in pixels (must be non-zero).

    Returns:
        list of four ints, each within ``0..1000``.

    Raises:
        ZeroDivisionError: if *width* or *height* is zero.
    """

    def _scale(value, extent):
        # Same arithmetic as before (ratio first, then * 1000), plus the clamp.
        return min(1000, max(0, int(1000 * (value / extent))))

    return [
        _scale(box[0], width),
        _scale(box[1], height),
        _scale(box[2], width),
        _scale(box[3], height),
    ]
|
|
|
|
|
|
def infer_driving_license(image_crop, list_lines, max_n_words, processor):
    """Predict a KIE label for every OCR word of a cropped licence image.

    The OCR words are processed in chunks of at most *max_n_words*; each
    chunk is encoded together with the full image, run through the
    module-level ``model`` and decoded back to one prediction per word.

    Args:
        image_crop: image array as used by OpenCV; it is converted with
            ``cv2.COLOR_BGR2RGB``, so BGR channel order is expected.
        list_lines: OCR line objects accepted by ``load_ocr_labels``.
        max_n_words: maximum number of words per model call (chunk size).
        processor: callable that builds the model inputs; it is invoked with
            the image, ``text``, ``boxes`` and ``word_labels`` and must return
            a mapping of tensors that contains a ``"labels"`` entry.

    Returns:
        tuple ``(batch_words, batch_preds, batch_true_boxes, list_words)``:
        all OCR words, a NumPy array of predicted class ids, a NumPy array of
        the words' original (un-normalised) boxes, and a list of ``Word``
        objects carrying text, integer box and KIE label name.

    Raises:
        IndexError: if the model yields fewer word-level predictions than
            there are words in a chunk (e.g. the chunk was truncated at 512
            tokens).
    """
    # Local import: torch is not imported in the module header, and it is only
    # needed here to disable gradient tracking.
    import torch

    # Load inputs
    image = Image.fromarray(cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB))
    img_width, img_height = image.size  # PIL reports (width, height)

    batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
    batch_preds, batch_true_boxes = [], []
    list_words = []

    for start in range(0, len(batch_words), max_n_words):
        words = batch_words[start : start + max_n_words]
        boxes = batch_boxes[start : start + max_n_words]
        boxes_norm = [_normalize_box(bbox, img_width, img_height) for bbox in boxes]

        # Preprocess. Dummy labels are supplied only so that the processor
        # emits a "labels" tensor, which is used below as a mask.
        encoding = processor(
            image,
            text=words,
            boxes=boxes_norm,
            word_labels=[0] * len(words),
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        # Positions labelled -100 carry no word-level label (per the Hugging
        # Face convention: special tokens, padding, non-first sub-tokens).
        # Read the mask before moving tensors, so no device round-trip occurs.
        is_subword = (encoding["labels"][0] == -100).tolist()

        # Run model. Inference only, so skip building the autograd graph.
        model_inputs = {key: tensor.to(DEVICE) for key, tensor in encoding.items()}
        with torch.no_grad():
            outputs = model(**model_inputs)
        predictions = outputs.logits.argmax(-1)[0].tolist()

        # Postprocess: keep one prediction per word.
        true_predictions = [
            pred for pred, skip in zip(predictions, is_subword) if not skip
        ]
        # TODO check assumption that layoutlm does not change box order
        true_boxes = boxes

        # Index-based on purpose: a shortfall in predictions must surface as
        # an IndexError rather than silently dropping words.
        for word_idx, word in enumerate(words):
            bndbox = [int(coord) for coord in true_boxes[word_idx]]
            list_words.append(
                Word(
                    text=word,
                    bndbox=bndbox,
                    kie_label=KIE_LABELS[true_predictions[word_idx]],
                )
            )

        batch_preds.extend(true_predictions)
        batch_true_boxes.extend(true_boxes)

    batch_preds = np.array(batch_preds)
    batch_true_boxes = np.array(batch_true_boxes)
    return batch_words, batch_preds, batch_true_boxes, list_words
|