sbt-idp/cope2n-ai-fi/api/Kie_AHung/prediction.py

from PIL import Image
import cv2
import numpy as np
from transformers import (
    LayoutXLMTokenizer,
    LayoutLMv2FeatureExtractor,
    LayoutXLMProcessor,
    LayoutLMv2ForTokenClassification,
)
from common.utils.word_formation import *
from common.utils.global_variables import *
from common.utils.process_label import *
import os
import ssl

ssl._create_default_https_context = ssl._create_unverified_context
os.environ["CURL_CA_BUNDLE"] = ""

from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

# config
IGNORE_KIE_LABEL = "others"
KIE_LABELS = [
    "Number",
    "Name",
    "Birthday",
    "Home Town",
    "Address",
    "Sex",
    "Nationality",
    "Expiry Date",
    "Nation",
    "Religion",
    "Date Range",
    "Issued By",
    IGNORE_KIE_LABEL,
    "Rank",
]
DEVICE = "cuda:0"
# MAX_SEQ_LENGTH = 512 # TODO Fix this hard code
# tokenizer = LayoutXLMTokenizer.from_pretrained(
#     "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=MAX_SEQ_LENGTH
# )
# feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
# processor = LayoutXLMProcessor(feature_extractor, tokenizer)
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "Kie_AHung/model/driver_license", num_labels=len(KIE_LABELS), local_files_only=True
).to(
    DEVICE
)  # TODO FIX this hard code
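
# NOTE: infer_driving_license() below does not build its own processor; the
# caller is expected to pass in a LayoutXLM processor configured with
# apply_ocr=False, as in the commented-out block above.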


def load_ocr_labels(list_lines):
    """Flatten OCR lines into parallel lists of words, boxes and placeholder labels."""
    words, boxes, labels = [], [], []
    for line in list_lines:
        for word_group in line.list_word_groups:
            for word in word_group.list_words:
                xmin, ymin, xmax, ymax = (
                    word.boundingbox[0],
                    word.boundingbox[1],
                    word.boundingbox[2],
                    word.boundingbox[3],
                )
                text = word.text
                label = "seller_name_value"  # TODO: placeholder label (ignored by infer_driving_license), fix this
                x1, y1, x2, y2 = float(xmin), float(ymin), float(xmax), float(ymax)
                if text != " ":
                    words.append(text)
                    boxes.append([x1, y1, x2, y2])
                    labels.append(label)
    return words, boxes, labels


def _normalize_box(box, width, height):
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]
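
# Worked example for _normalize_box above (illustrative): a box (50, 20, 150, 80)
# on a 1000x500 px image maps to [50, 40, 150, 160] on the 0-1000 coordinate grid
# used by LayoutLM-style models.
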
def infer_driving_license(image_crop, list_lines, max_n_words, processor):
    # Load inputs
    # image = Image.open(image_path)
    image = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)

    batch_words, batch_boxes, _ = load_ocr_labels(list_lines)
    batch_preds, batch_true_boxes = [], []
    list_words = []
    for i in range(0, len(batch_words), max_n_words):
        words = batch_words[i : i + max_n_words]
        boxes = batch_boxes[i : i + max_n_words]
        boxes_norm = [
            _normalize_box(bbox, image.size[0], image.size[1]) for bbox in boxes
        ]

        # Preprocess
        dummy_word_labels = [0] * len(words)
        encoding = processor(
            image,
            text=words,
            boxes=boxes_norm,
            word_labels=dummy_word_labels,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        # Run model
        for k, v in encoding.items():
            encoding[k] = v.to(DEVICE)
        outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1).squeeze().tolist()

        # Postprocess: positions labeled -100 by the processor (padding, special
        # tokens, non-first subword tokens) are dropped so one prediction remains per word
        is_subword = (encoding["labels"] == -100).detach().cpu().numpy()[0]
        true_predictions = [
            pred for idx, pred in enumerate(predictions) if not is_subword[idx]
        ]
        true_boxes = boxes  # TODO: check assumption that LayoutLM does not change box order

        for w_idx, word in enumerate(words):
            bndbox = [int(j) for j in true_boxes[w_idx]]
            list_words.append(
                Word(
                    text=word,
                    bndbox=bndbox,
                    kie_label=KIE_LABELS[true_predictions[w_idx]],
                )
            )
        batch_preds.extend(true_predictions)
        batch_true_boxes.extend(true_boxes)

    batch_preds = np.array(batch_preds)
    batch_true_boxes = np.array(batch_true_boxes)
    return batch_words, batch_preds, batch_true_boxes, list_words
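

# --- Usage sketch (illustrative only, not part of the original pipeline) ------
# A minimal example of how this module might be driven, assuming the
# commented-out processor block above. The tokenizer path, image path, and
# max_n_words value are placeholders, and `list_lines` must come from the
# project's OCR pipeline (not shown here).
#
# if __name__ == "__main__":
#     tokenizer = LayoutXLMTokenizer.from_pretrained(
#         "Kie_AHung/model/pretrained/layoutxlm-base/tokenizer", model_max_length=512
#     )
#     processor = LayoutXLMProcessor(LayoutLMv2FeatureExtractor(apply_ocr=False), tokenizer)
#
#     image_crop = cv2.imread("driver_license_sample.jpg")  # BGR image, as cvtColor above expects
#     list_lines = ...  # OCR line objects exposing .list_word_groups / .list_words
#
#     words, preds, boxes, list_words = infer_driving_license(
#         image_crop, list_lines, max_n_words=100, processor=processor
#     )
#     for word, pred in zip(words, preds):
#         print(word, KIE_LABELS[pred])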