import os
import cv2
import copy
import time
import torch
import math
import numpy as np
from typing import Callable

from sdsvkvu.utils.post_processing import get_string, get_string_by_deduplicate_bbox, get_string_with_word2line


class Timer:
    """Context manager that prints the wall-clock time of the enclosed block."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        print(f"[INFO]: {self.name} took: {self.elapsed_time:.6f} seconds")
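
# Illustrative usage of Timer (a sketch; nothing in this module calls it this way,
# and `model`/`inputs` are hypothetical):
#
#     with Timer("inference"):
#         outputs = model(inputs)
#
# On exit this prints, e.g.: [INFO]: inference took: 0.123456 seconds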


def get_colormap():
    """Map each entity class name to the BGR color used for visualization."""
    return {
        "others": (0, 0, 255),  # red
        # "others": (187, 125, 250),  # alternative: pink
        "title": (0, 255, 255),  # yellow
        "key": (255, 0, 0),  # blue
        "value": (0, 255, 0),  # green
        "header": (233, 197, 15),
        "group": (0, 128, 128),
        "relation": (0, 0, 255),  # red; alternative: (128, 128, 128)

        # classes used by export_sbt_outputs
        "seller": (183, 50, 255),
        "date_key": (128, 51, 115),
        "date_value": (55, 250, 250),
        "product_name": (245, 61, 61),
        "product_code": (233, 197, 17),
        "quantity": (102, 255, 102),
        "sn_key": (179, 134, 89),
        "sn_value": (51, 153, 204),
        "invoice_number_key": (40, 90, 144),
        "invoice_number_value": (162, 239, 204),
        "sold_key": (74, 180, 150),
        "sold_value": (14, 184, 53),
        "voucher": (39, 86, 103),
        "website": (207, 19, 85),
        "hotline": (153, 224, 56),
    }


def convert_image(image):
    """Convert a PIL image to an OpenCV BGR array, undoing EXIF rotation if present."""
    exif = image._getexif()
    orientation = None
    if exif is not None:
        orientation = exif.get(0x0112)  # EXIF orientation tag

    # Convert the PIL image to OpenCV (BGR) format
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Rotate the image in OpenCV if necessary
    if orientation == 3:
        image = cv2.rotate(image, cv2.ROTATE_180)
    elif orientation == 6:
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    elif orientation == 8:
        image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
    else:
        image = np.asarray(image)

    # Promote grayscale images to 3 channels
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    assert len(image.shape) == 3

    return image, orientation
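
# Illustrative call (the file name is hypothetical). Note that PIL exposes
# _getexif() only for formats that carry EXIF data (e.g. JPEG), so callers
# should make sure the input supports it:
#
#     from PIL import Image
#     cv_image, exif_orientation = convert_image(Image.open("receipt.jpg"))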


def visualize(image, bbox, pr_class_words, pr_relations, color_map, labels=['others', 'title', 'key', 'value', 'header'], thickness=1):
    # Optionally normalize the page first, e.g. image, orientation = convert_image(image)

    # Draw each word group: a rectangle per word, with lines joining word centers within the group
    if len(pr_class_words) > 0:
        id2label = {k: labels[k] for k in range(len(labels))}
        for lb, groups in enumerate(pr_class_words):
            if lb == 0:  # skip the 'others' class
                continue
            for group_id, group in enumerate(groups):
                for i, word_id in enumerate(group):
                    # x0, y0, x1, y1 = revert_scale_bbox(bbox[word_id], width, height)
                    x0, y0, x1, y1 = bbox[word_id]
                    cv2.rectangle(image, (x0, y0), (x1, y1), color=color_map[id2label[lb]], thickness=thickness)

                    if i == 0:
                        x_center0, y_center0 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                    else:
                        x_center1, y_center1 = int((x0 + x1) / 2), int((y0 + y1) / 2)
                        cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['group'], thickness=thickness)
                        x_center0, y_center0 = x_center1, y_center1

    # Draw relations as lines between the centers of linked boxes
    if len(pr_relations) > 0:
        for pair in pr_relations:
            # xyxy0 = revert_scale_bbox(bbox[pair[0]], width, height)
            # xyxy1 = revert_scale_bbox(bbox[pair[1]], width, height)
            xyxy0 = bbox[pair[0]]
            xyxy1 = bbox[pair[1]]

            x_center0, y_center0 = int((xyxy0[0] + xyxy0[2]) / 2), int((xyxy0[1] + xyxy0[3]) / 2)
            x_center1, y_center1 = int((xyxy1[0] + xyxy1[2]) / 2), int((xyxy1[1] + xyxy1[3]) / 2)

            cv2.line(image, (x_center0, y_center0), (x_center1, y_center1), color=color_map['relation'], thickness=thickness)

    return image


def revert_scale_bbox(box, width, height):
    """Map a box normalized to the 0-1000 range back to pixel coordinates."""
    return [
        int((box[0] / 1000) * width),
        int((box[1] / 1000) * height),
        int((box[2] / 1000) * width),
        int((box[3] / 1000) * height),
    ]
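
# Worked example: a box normalized to the 0-1000 range mapped back onto a
# 1920x1080 page:
#
#     revert_scale_bbox([100, 200, 300, 400], width=1920, height=1080)
#     # -> [192, 216, 576, 432]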


def draw_kvu_outputs(image: np.ndarray, bbox: list, pr_class_words: list, pr_relations: list, class_names: list = ['others', 'title', 'key', 'value', 'header'], thickness: int = 1):
    color_map = get_colormap()
    image = visualize(image, bbox, pr_class_words, pr_relations, color_map, class_names, thickness)
    if image.shape[2] == 2:  # 2-channel (BGR565) input: expand to 3 channels
        image = cv2.cvtColor(image, cv2.COLOR_BGR5652BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
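
# Illustrative usage (a sketch; all inputs and the output path are hypothetical).
# `bbox` holds pixel-space [x0, y0, x1, y1] boxes per token, `pr_class_words`
# groups token ids per class, and `pr_relations` pairs linked token ids:
#
#     vis = draw_kvu_outputs(image, bbox, pr_class_words, pr_relations)
#     cv2.imwrite("debug_kvu.jpg", vis)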


def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    """Split `elements` into overlapping windows of `window_size`, stepping by `slice_interval`."""
    element_windows = []

    if len(elements) > window_size:
        max_step = math.ceil((len(elements) - window_size) / slice_interval)

        for i in range(0, max_step + 1):
            start = i * slice_interval
            if start + window_size >= len(elements):
                _window = copy.deepcopy(elements[start:])
            else:
                _window = copy.deepcopy(elements[start: start + window_size])
            element_windows.append(_window)
        return element_windows
    else:
        return [elements]
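
# Worked example of the windowing:
#
#     sliding_windows(list(range(10)), window_size=4, slice_interval=2)
#     # -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
#
# Full consecutive windows overlap by window_size - slice_interval elements,
# which is presumably what the `loverlaps` argument of merged_token_embeddings
# below encodes.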


def merged_token_embeddings(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Stitch per-window token embeddings back into one sequence.

    Overlapping positions between consecutive windows are either averaged
    (average=True) or taken from the later window (average=False).
    """
    start_pos = 1  # position 0 is the CLS token
    end_pos = start_pos + lvalids[0]
    embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
    cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
    sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]

        overlap_gap = copy.deepcopy(loverlaps[i - 1])
        window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])

        if overlap_gap != 0:
            prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
            curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"

            if average:
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            embedding_tokens = torch.cat([embedding_tokens, window], dim=1)
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
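
# Shape sketch (illustrative numbers): with two windows of hidden states shaped
# [batch, seq_len, dim] = [1, 512, 768], lvalids = [510, 300] and
# loverlaps = [255], the merged sequence is
# 1 (CLS) + 510 + (300 - 255) + 1 (SEP) = 557 positions long. The 255
# overlapping positions are averaged when average=True; otherwise the later
# window's values win.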


def parse_initial_words(itc_label, box_first_token_mask, class_names):
    """Collect, per class, the indices of first tokens predicted as that class."""
    itc_label_np = itc_label.cpu().numpy()
    box_first_token_mask_np = box_first_token_mask.cpu().numpy()

    outputs = [[] for _ in range(len(class_names))]

    for token_idx, label in enumerate(itc_label_np):
        if box_first_token_mask_np[token_idx] and label != 0:
            outputs[label].append(token_idx)

    return outputs


def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    """Follow next-token links from each initial token to build token groups."""
    max_connections = 50  # safety bound against cyclic links

    valid_stc_label = stc_label * attention_mask.bool()
    valid_stc_label = valid_stc_label.cpu().numpy()
    stc_label_np = stc_label.cpu().numpy()

    valid_token_indices = np.where(
        (valid_stc_label != dummy_idx) * (valid_stc_label != 0)
    )

    # Map each predecessor token index to its successor
    next_token_idx_dict = {}
    for token_idx in valid_token_indices[0]:
        next_token_idx_dict[stc_label_np[token_idx]] = token_idx

    outputs = []
    for init_token_indices in init_words:
        sub_outputs = []
        for init_token_idx in init_token_indices:
            cur_token_indices = [init_token_idx]
            for _ in range(max_connections):
                if cur_token_indices[-1] in next_token_idx_dict:
                    if next_token_idx_dict[cur_token_indices[-1]] not in init_token_indices:
                        cur_token_indices.append(next_token_idx_dict[cur_token_indices[-1]])
                    else:
                        break
                else:
                    break
            sub_outputs.append(tuple(cur_token_indices))

        outputs.append(sub_outputs)

    return outputs


def parse_relations(el_label, box_first_token_mask, dummy_idx):
    """Extract (from_token, to_token) link pairs from entity-linking predictions."""
    valid_el_labels = el_label * box_first_token_mask
    valid_el_labels = valid_el_labels.cpu().numpy()
    el_label_np = el_label.cpu().numpy()

    valid_token_indices = np.where(
        (valid_el_labels != dummy_idx) * (valid_el_labels != 0)
    )

    link_map_tuples = []
    for token_idx in valid_token_indices[0]:
        link_map_tuples.append((el_label_np[token_idx], token_idx))

    return set(link_map_tuples)


def get_pairs(json: list, rel_from: str, rel_to: str) -> dict:
    """Collect linked (rel_from, rel_to) pairs, keyed by the rel_to group id."""
    outputs = {}
    for pair in json:
        is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
        for element in pair:
            if element['class'] in (rel_from, rel_to):
                is_rel[element['class']]['status'] = 1
                is_rel[element['class']]['value'] = element
        if all([v['status'] == 1 for _, v in is_rel.items()]):
            outputs[is_rel[rel_to]['value']['group_id']] = [is_rel[rel_from]['value']['group_id'], is_rel[rel_to]['value']['group_id']]
    return outputs
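
# Illustrative input/output (hypothetical group ids): given a linking pair
# [{'class': 'key', 'group_id': 3, ...}, {'class': 'value', 'group_id': 7, ...}],
# get_pairs(pairs, rel_from='key', rel_to='value') returns {7: [3, 7]},
# i.e. the result is keyed by the rel_to group id.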


def get_table_relations(json: list, header_key_pairs: dict, rel_from="key", rel_to="value") -> dict:
    """For each key that sits under a header, collect its linked value group ids."""
    list_keys = list(header_key_pairs.keys())
    relations = {k: [] for k in list_keys}
    for pair in json:
        is_rel = {rel_from: {'status': 0}, rel_to: {'status': 0}}
        for element in pair:
            if element['class'] == rel_from and element['group_id'] in list_keys:
                is_rel[rel_from]['status'] = 1
                is_rel[rel_from]['value'] = element
            if element['class'] == rel_to:
                is_rel[rel_to]['status'] = 1
                is_rel[rel_to]['value'] = element
        if all([v['status'] == 1 for _, v in is_rel.items()]):
            relations[is_rel[rel_from]['value']['group_id']].append(is_rel[rel_to]['value']['group_id'])
    return relations


def get_key2values_relations(key_value_pairs: dict):
    """Invert {value_id: [key_id, value_id]} into {key_id: [value_id, ...]}."""
    triple_linkings = {}
    for value_group_id, key_value_pair in key_value_pairs.items():
        key_group_id = key_value_pair[0]
        if key_group_id not in triple_linkings:
            triple_linkings[key_group_id] = []
        triple_linkings[key_group_id].append(value_group_id)
    return triple_linkings
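
# Worked example: inverting {value_id: [key_id, value_id]} pairs into a
# one-key-to-many-values map:
#
#     get_key2values_relations({7: [3, 7], 8: [3, 8], 9: [5, 9]})
#     # -> {3: [7, 8], 5: [9]}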


def get_wordgroup_bbox(lbbox: list, lword_ids: list) -> list:
    """Return the union bounding box of the word boxes at `lword_ids`."""
    points = [lbbox[i] for i in lword_ids]
    x_min, y_min = min(points, key=lambda x: x[0])[0], min(points, key=lambda x: x[1])[1]
    x_max, y_max = max(points, key=lambda x: x[2])[2], max(points, key=lambda x: x[3])[3]
    return [x_min, y_min, x_max, y_max]
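
# Worked example: the union box of two word boxes:
#
#     get_wordgroup_bbox([[10, 12, 50, 30], [55, 10, 90, 28]], [0, 1])
#     # -> [10, 10, 90, 30]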


def merged_token_to_wordgroup(class_words: list, lwords: list, lbbox: list, labels: list) -> dict:
    """Merge grouped token ids into word groups with text, class and union bbox."""
    word_groups = {}
    id2class = {i: labels[i] for i in range(len(labels))}
    for class_id, lwgroups_in_class in enumerate(class_words):
        for ltokens_in_wgroup in lwgroups_in_class:
            group_id = ltokens_in_wgroup[0]  # a group is identified by its first token
            ltokens_to_ltexts = [lwords[token] for token in ltokens_in_wgroup]
            ltokens_to_lbboxes = [lbbox[token] for token in ltokens_in_wgroup]
            # text_string = get_string(ltokens_to_ltexts)
            text_string = get_string_by_deduplicate_bbox(ltokens_to_ltexts, ltokens_to_lbboxes)
            # text_string = get_string_with_word2line(ltokens_to_ltexts, ltokens_to_lbboxes)
            group_bbox = get_wordgroup_bbox(lbbox, ltokens_in_wgroup)
            word_groups[group_id] = {
                'group_id': group_id,
                'text': text_string,
                'class': id2class[class_id],
                'tokens': ltokens_in_wgroup,
                'bbox': group_bbox,
            }
    return word_groups


def verify_linking_id(word_groups: dict, linking_id: int) -> int:
    """Resolve a token-level linking id to the id of the word group containing it."""
    if linking_id not in list(word_groups):
        for wg_id, _word_group in word_groups.items():
            if linking_id in _word_group['tokens']:
                return wg_id
    return linking_id


def matched_wordgroup_relations(word_groups: dict, lrelations: list) -> list:
    """Resolve token-level relation pairs into pairs of word-group dicts."""
    outputs = []
    for pair in lrelations:
        wg_from = verify_linking_id(word_groups, pair[0])
        wg_to = verify_linking_id(word_groups, pair[1])
        try:
            outputs.append([word_groups[wg_from], word_groups[wg_to]])
        except KeyError:
            print('Not valid pair:', wg_from, wg_to)
    return outputs


def get_single_entity(word_groups: dict, lrelations: list, labels: list) -> dict:
    """Collect, per class, the word groups that take part in no relation."""
    single_entity = {lb: [] for lb in labels}
    list_linked_ids = []
    for pair in lrelations:
        list_linked_ids.extend(pair)

    for word_group_id, word_group in word_groups.items():
        if word_group_id not in list_linked_ids:
            single_entity[word_group['class']].append(word_group)
    return single_entity


def export_kvu_outputs(lwords, lbbox, class_words, lrelations, labels=['others', 'title', 'key', 'value', 'header']):
    word_groups = merged_token_to_wordgroup(class_words, lwords, lbbox, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)

    header_key = get_pairs(linking_pairs, rel_from='header', rel_to='key')      # => {key_group_id: [header_group_id, key_group_id]}
    header_value = get_pairs(linking_pairs, rel_from='header', rel_to='value')  # => {value_group_id: [header_group_id, value_group_id]}
    key_value = get_pairs(linking_pairs, rel_from='key', rel_to='value')        # => {value_group_id: [key_group_id, value_group_id]}

    single_entity = get_single_entity(word_groups, lrelations, labels=labels)

    # table_relations = get_table_relations(linking_pairs, header_key)  # => {key_group_id: [value_group_id1, value_group_id2, ...]}
    key2values_relations = get_key2values_relations(key_value)  # => {key_group_id: [value_group_id1, value_group_id2, ...]}

    triplet_pairs = []
    single_pairs = []
    table = []
    for key_group_id, list_value_group_ids in key2values_relations.items():
        if len(list_value_group_ids) == 0:
            continue
        elif len(list_value_group_ids) == 1:
            # One key, one value: a plain key-value pair
            value_group_id = list_value_group_ids[0]
            single_pairs.append({word_groups[key_group_id]['text']: {
                'id': value_group_id,
                'class': "value",
                'text': word_groups[value_group_id]['text'],
                'bbox': word_groups[value_group_id]['bbox'],
                "key_bbox": word_groups[key_group_id]["bbox"],
            }})
        else:
            # One key, several values: collect them with their headers (if any)
            item = []
            for value_group_id in list_value_group_ids:
                if value_group_id not in header_value.keys():
                    header_group_id = -1  # was unbound here, which raised NameError
                    header_name_for_value = "non-header"
                else:
                    header_group_id = header_value[value_group_id][0]
                    header_name_for_value = word_groups[header_group_id]['text']
                item.append({
                    'id': value_group_id,
                    'class': 'value',
                    'header': header_name_for_value,
                    'text': word_groups[value_group_id]['text'],
                    'bbox': word_groups[value_group_id]['bbox'],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "header_bbox": word_groups[header_group_id]["bbox"]
                    if header_group_id != -1 else [0, 0, 0, 0],
                })
            if key_group_id not in list(header_key.keys()):
                # Key without a header: a triplet (key -> list of values)
                triplet_pairs.append({
                    word_groups[key_group_id]['text']: item
                })
            else:
                # Key under a header: a table column
                header_group_id = header_key[key_group_id][0]
                header_name_for_key = word_groups[header_group_id]['text']
                item.append({
                    'id': key_group_id,
                    'class': 'key',
                    'header': header_name_for_key,
                    'text': word_groups[key_group_id]['text'],
                    'bbox': word_groups[key_group_id]['bbox'],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "header_bbox": word_groups[header_group_id]["bbox"],
                })
                table.append({key_group_id: item})

    # Add entities without linking
    single_entity_dict = {}
    for class_name, single_items in single_entity.items():
        single_entity_dict[class_name] = []
        for single_item in single_items:
            single_entity_dict[class_name].append({
                'text': single_item['text'],
                'id': single_item['group_id'],
                'class': class_name,
                'bbox': single_item['bbox'],
            })

    if len(table) > 0:
        table = sorted(table, key=lambda x: list(x.keys())[0])
        table = [v for item in table for k, v in item.items()]

    outputs = {}
    outputs['title'] = sorted(single_entity_dict["title"], key=lambda x: x["id"])
    outputs['key'] = sorted(single_entity_dict["key"], key=lambda x: x["id"])
    outputs['value'] = sorted(single_entity_dict["value"], key=lambda x: x["id"])
    outputs['single'] = sorted(single_pairs, key=lambda x: int(float(list(x.values())[0]['id'])))
    outputs['triplet'] = triplet_pairs
    outputs['table'] = table
    return outputs
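
# Output shape sketch for export_kvu_outputs (field values are hypothetical,
# and the per-bucket descriptions are this editor's reading of the code above):
#
#     {
#         "title":   [{"text": ..., "id": ..., "class": "title", "bbox": [...]}, ...],
#         "key":     [...],                        # keys with no linked value
#         "value":   [...],                        # values with no linked key
#         "single":  [{"<key text>": {...}}, ...], # one-key/one-value pairs
#         "triplet": [{"<key text>": [...]}, ...], # one key, several values, no header
#         "table":   [[...], ...],                 # key under a header, several values
#     }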


def export_sbt_outputs(
    lwords,
    lbboxes,
    class_words,
    lrelations,
    labels,
):
    word_groups = merged_token_to_wordgroup(class_words, lwords, lbboxes, labels)
    linking_pairs = matched_wordgroup_relations(word_groups, lrelations)

    date_key_value_pairs = get_pairs(
        linking_pairs, rel_from="date_key", rel_to="date_value"
    )  # => {date_value_group_id: [date_key_group_id, date_value_group_id]}
    # product_name_code_pairs = get_pairs(
    #     linking_pairs, rel_to="product_name", rel_from="product_code"
    # )  # => {product_name_group_id: [product_code_group_id, product_name_group_id]}
    # product_name_quantity_pairs = get_pairs(
    #     linking_pairs, rel_to="product_name", rel_from="quantity"
    # )  # => {product_name_group_id: [quantity_group_id, product_name_group_id]}
    serial_key_value_pairs = get_pairs(
        linking_pairs, rel_from="sn_key", rel_to="sn_value"
    )  # => {sn_value_group_id: [sn_key_group_id, sn_value_group_id]}

    sold_key_value_pairs = get_pairs(
        linking_pairs, rel_from="sold_key", rel_to="sold_value"
    )  # => {sold_value_group_id: [sold_key_group_id, sold_value_group_id]}

    single_entity = get_single_entity(word_groups, lrelations, labels=labels)

    date_value = []
    sold_value = []
    serial_imei = []
    table = []

    date_relations = get_key2values_relations(date_key_value_pairs)
    for key_group_id, list_value_group_id in date_relations.items():
        for value_group_id in list_value_group_id:
            date_value.append(
                {
                    "text": word_groups[value_group_id]["text"],
                    "id": value_group_id,
                    "class": "date_value",
                    "bbox": word_groups[value_group_id]["bbox"],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "raw_key_name": word_groups[key_group_id]["text"],
                }
            )

    sold_relations = get_key2values_relations(sold_key_value_pairs)
    for key_group_id, list_value_group_id in sold_relations.items():
        for value_group_id in list_value_group_id:
            sold_value.append(
                {
                    "text": word_groups[value_group_id]["text"],
                    "id": value_group_id,
                    "class": "sold_value",
                    "bbox": word_groups[value_group_id]["bbox"],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "raw_key_name": word_groups[key_group_id]["text"],
                }
            )

    serial_relations = get_key2values_relations(serial_key_value_pairs)
    for key_group_id, list_value_group_id in serial_relations.items():
        for value_group_id in list_value_group_id:
            serial_imei.append(
                {
                    "text": word_groups[value_group_id]["text"],
                    "id": value_group_id,
                    "class": "sn_value",
                    "bbox": word_groups[value_group_id]["bbox"],
                    "key_bbox": word_groups[key_group_id]["bbox"],
                    "raw_key_name": word_groups[key_group_id]["text"],
                }
            )

    # Add entities without linking
    single_entity_dict = {}
    for class_name, single_items in single_entity.items():
        single_entity_dict[class_name] = []
        for single_item in single_items:
            single_entity_dict[class_name].append(
                {
                    "text": single_item["text"],
                    "id": single_item["group_id"],
                    "class": class_name,
                    "bbox": single_item["bbox"],
                }
            )

    # Product-table extraction is currently disabled:
    # list_product_name_group_ids = set(
    #     list(product_name_code_pairs.keys())
    #     + list(product_name_quantity_pairs.keys())
    #     + [x["id"] for x in single_entity_dict["product_name"]]
    # )
    # for product_name_group_id in list_product_name_group_ids:
    #     item = {"productname": [], "modelnumber": [], "qty": []}
    #     item["productname"].append(
    #         {
    #             "text": word_groups[product_name_group_id]["text"],
    #             "id": product_name_group_id,
    #             "class": "product_name",
    #             "bbox": word_groups[product_name_group_id]["bbox"],
    #         }
    #     )
    #     if product_name_group_id in product_name_code_pairs:
    #         product_code_group_id = product_name_code_pairs[product_name_group_id][0]
    #         item["modelnumber"].append(
    #             {
    #                 "text": word_groups[product_code_group_id]["text"],
    #                 "id": product_code_group_id,
    #                 "class": "product_code",
    #                 "bbox": word_groups[product_code_group_id]["bbox"],
    #             }
    #         )
    #     if product_name_group_id in product_name_quantity_pairs:
    #         quantity_group_id = product_name_quantity_pairs[product_name_group_id][0]
    #         item["qty"].append(
    #             {
    #                 "text": word_groups[quantity_group_id]["text"],
    #                 "id": quantity_group_id,
    #                 "class": "quantity",
    #                 "bbox": word_groups[quantity_group_id]["bbox"],
    #             }
    #         )
    #     table.append(item)
    # if len(table) > 0:
    #     table = sorted(table, key=lambda x: x["productname"][0]["id"])

    if len(serial_imei) > 0:
        serial_imei = sorted(serial_imei, key=lambda x: x["id"])

    outputs = {}
    outputs["seller"] = single_entity_dict["seller"]
    outputs["voucher"] = single_entity_dict["voucher"]
    outputs["website"] = single_entity_dict["website"]
    outputs["hotline"] = single_entity_dict["hotline"]
    outputs["sold_value"] = sold_value + single_entity_dict["sold_key"] + single_entity_dict["sold_value"]
    outputs["date_value"] = date_value + single_entity_dict["date_value"] + single_entity_dict["date_key"]
    outputs["serial_imei"] = serial_imei + single_entity_dict["sn_value"] + single_entity_dict["sn_key"]
    # outputs["table"] = table
    return outputs