2023-11-30 11:22:16 +00:00
from omegaconf import OmegaConf
import os
import cv2
import torch
# from functions import get_colormap, visualize
import sys
sys . path . append ( ' /mnt/ssd1T/tuanlv/02.KeyValueUnderstanding/ ' ) # TODO: ???????
from lightning_modules . classifier_module import parse_initial_words , parse_subsequent_words , parse_relations
from model import get_model
from utils import load_model_weight
2024-07-05 13:14:47 +00:00
import logging
import logging . config
from utils . logging . logging import LOGGER_CONFIG
2023-11-30 11:22:16 +00:00
2024-07-05 13:14:47 +00:00
# Install the shared project logging setup (LOGGER_CONFIG) at import time,
# then create this module's logger under its dotted module name.
logging.config.dictConfig(config=LOGGER_CONFIG)
logger = logging.getLogger(name=__name__)
2023-11-30 11:22:16 +00:00
class KVUPredictor:
    """Run a Key-Value Understanding model on preprocessed samples and decode its outputs.

    The network's four heads (``itc_outputs``, ``stc_outputs``, ``el_outputs``,
    ``el_outputs_from_key``) are arg-maxed and turned into word-level classes
    and relations by the project helpers ``parse_initial_words``,
    ``parse_subsequent_words`` and ``parse_relations``. The expected sample
    layout depends on ``mode`` (see :meth:`predict`).
    """

    def __init__(self, configs, class_names, dummy_idx, mode=0, device='cuda'):
        """Load the model described by ``configs``.

        Args:
            configs: mapping with ``'cfg'`` (path to an OmegaConf config file)
                and ``'ckpt'`` (path to the checkpoint to load).
            class_names: class names forwarded to ``parse_initial_words``.
            dummy_idx: index forwarded to the parse helpers; in mode 3 it is
                multiplied by ``cfg.train.max_window_count``.
            mode: input layout / inference stage (0-3); also written to
                ``cfg.stage`` before the model is built.
            device: torch device for the model and batched inputs
                (default ``'cuda'``).
        """
        cfg_path = configs['cfg']
        ckpt_path = configs['ckpt']
        self.class_names = class_names
        self.dummy_idx = dummy_idx
        self.mode = mode
        # Must be set before _load_model, which moves the network to this device.
        self.device = device

        logger.info('Loading Key-Value Understanding model ...')
        self.net, cfg, self.backbone_type = self._load_model(cfg_path, ckpt_path)
        logger.info("Loaded model")

        if mode == 3:
            self.max_window_count = cfg.train.max_window_count
            self.window_size = cfg.train.window_size
            self.slice_interval = 0
            # Document-level mode: the dummy index is scaled by the window count.
            self.dummy_idx = dummy_idx * self.max_window_count
        else:
            self.slice_interval = cfg.train.slice_interval
            self.window_size = cfg.train.max_num_words

    def _load_model(self, cfg_path, ckpt_path):
        """Build the network from ``cfg_path`` and load the weights in ``ckpt_path``.

        Returns:
            ``(net, cfg, backbone_type)`` with ``net`` on ``self.device`` in eval mode.
        """
        cfg = OmegaConf.load(cfg_path)
        cfg.stage = self.mode
        backbone_type = cfg.model.backbone

        # The %s placeholder is required: passing an extra argument without one
        # makes logging report a formatting error instead of the message.
        logger.info('Checkpoint: %s', ckpt_path)

        net = get_model(cfg)
        load_model_weight(net, ckpt_path)
        net.to(self.device)
        net.eval()
        return net, cfg, backbone_type

    def predict(self, input_sample):
        """Predict on one sample laid out according to ``self.mode``.

        * mode 0: a single window (``input_sample['words']`` plus tensors),
          handled by :meth:`combined_predict`.
        * mode 1: ``{'documents': ..., 'windows': [...]}``, handled by
          :meth:`cat_predict`.
        * mode 2: ``{'windows': [...]}``; each window goes through
          :meth:`combined_predict` independently.
        * mode 3: ``{'documents': ..., 'windows': [...]}``, handled by
          :meth:`doc_predict`.

        Returns:
            Four lists ``(bboxes, words, class_words, relations)``. They hold one
            entry per window in mode 2 and a single entry otherwise. All four are
            empty when the sample has no words.

        Raises:
            ValueError: if ``self.mode`` is not one of 0-3.
        """
        if self.mode == 0:
            if len(input_sample['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.combined_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 1:
            if len(input_sample['documents']['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.cat_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        elif self.mode == 2:
            windows = input_sample['windows']
            # An empty window list has nothing to predict; indexing [0] would raise.
            if not windows or len(windows[0]['words']) == 0:
                return [], [], [], []
            bboxes, lwords, pr_class_words, pr_relations = [], [], [], []
            for window in windows:
                _bbox, _lwords, _pr_class_words, _pr_relations = self.combined_predict(window)
                bboxes.append(_bbox)
                lwords.append(_lwords)
                pr_class_words.append(_pr_class_words)
                pr_relations.append(_pr_relations)
            return bboxes, lwords, pr_class_words, pr_relations
        elif self.mode == 3:
            if len(input_sample["documents"]['words']) == 0:
                return [], [], [], []
            bbox, lwords, pr_class_words, pr_relations = self.doc_predict(input_sample)
            return [bbox], [lwords], [pr_class_words], [pr_relations]
        else:
            raise ValueError(
                f"Not supported mode: {self.mode}"
            )

    def doc_predict(self, input_sample):
        """Run document-level inference (mode 3) over every window of a document.

        Each window is batched onto ``self.device`` without its ``'words'`` and
        ``'n_empty_windows'`` entries. The outputs are decoded against the
        tensors in ``input_sample['documents']``. ``input_sample`` is not modified.

        Returns:
            ``(bbox, words, class_words, relations)`` for the whole document.
        """
        documents = input_sample['documents']
        lwords = documents['words']
        # Build a new top-level dict rather than overwriting input_sample['windows'],
        # so predicting twice on the same sample does not add a second batch dim.
        model_input = dict(input_sample)
        model_input['windows'] = [
            self._batch_to_device(window, exclude=('words', 'n_empty_windows'))
            for window in input_sample['windows']
        ]
        with torch.no_grad():
            head_outputs, _ = self.net(model_input)

        bbox = documents['bbox'].squeeze(0)
        pr_class_words, pr_relations = self._decode_head_outputs(
            head_outputs,
            documents['are_box_first_tokens'].squeeze(0),
            documents['attention_mask'].squeeze(0),
            self.dummy_idx,
        )
        return bbox, lwords, pr_class_words, pr_relations

    def combined_predict(self, input_sample):
        """Run inference on a single window (used by modes 0 and 2).

        All entries except ``'words'`` and ``'img_path'`` are batched onto
        ``self.device`` and fed to the network.

        Returns:
            ``(bbox, words, class_words, relations)`` for the window; ``bbox``
            is a CPU tensor.
        """
        lwords = input_sample['words']
        model_input = self._batch_to_device(input_sample, exclude=('words', 'img_path'))
        with torch.no_grad():
            head_outputs, _ = self.net(model_input)

        # unsqueeze(0) followed by squeeze(0) leaves the shape unchanged, so the
        # CPU originals can be used directly instead of copying every input back.
        box_first_token_mask = input_sample['are_box_first_tokens'].detach().cpu()
        attention_mask = input_sample['attention_mask_layoutxlm'].detach().cpu()
        bbox = input_sample['bbox'].detach().cpu()
        pr_class_words, pr_relations = self._decode_head_outputs(
            head_outputs, box_first_token_mask, attention_mask, self.dummy_idx
        )
        return bbox, lwords, pr_class_words, pr_relations

    def cat_predict(self, input_sample):
        """Run inference on windows and decode against document tensors (mode 1).

        Each window is batched onto ``self.device`` without its ``'words'`` and
        ``'img_path'`` entries. Document tensors are used as-is, without
        squeezing. The dummy index is the document's ``bbox.shape[0]``, not
        ``self.dummy_idx``. ``input_sample`` is not modified.

        Returns:
            ``(bbox, words, class_words, relations)`` for the document.
        """
        documents = input_sample['documents']
        lwords = documents['words']
        # New top-level dict so the caller's sample keeps its original windows.
        model_input = dict(input_sample)
        model_input['windows'] = [
            self._batch_to_device(window, exclude=('words', 'img_path'))
            for window in input_sample['windows']
        ]
        with torch.no_grad():
            head_outputs, _ = self.net(model_input)

        bbox = documents['bbox']
        dummy_idx = bbox.shape[0]
        pr_class_words, pr_relations = self._decode_head_outputs(
            head_outputs,
            documents['are_box_first_tokens'],
            documents['attention_mask_layoutxlm'],
            dummy_idx,
        )
        return bbox, lwords, pr_class_words, pr_relations

    def get_ground_truth_label(self, ground_truth):
        """Decode the ground-truth labels of a sample the same way predictions are decoded.

        Returns:
            ``(bbox, words, class_words, relations)`` built from the ``*_labels``
            tensors. Each tensor has its leading dim dropped with ``squeeze(0)``
            (e.g. ``[1, 512] -> [512]``).
        """
        gt_class_words, gt_relations = self._decode_labels(
            ground_truth['itc_labels'].squeeze(0),
            ground_truth['stc_labels'].squeeze(0),
            ground_truth['el_labels'].squeeze(0),
            ground_truth['el_labels_from_key'].squeeze(0),
            ground_truth['are_box_first_tokens'].squeeze(0),
            ground_truth['attention_mask'].squeeze(0),
            self.dummy_idx,
        )
        bbox = ground_truth['bbox'].squeeze(0)
        lwords = ground_truth["words"]
        return bbox, lwords, gt_class_words, gt_relations

    def _batch_to_device(self, sample, exclude):
        """Return a new dict of ``sample``'s tensors with a leading batch dim on ``self.device``.

        Keys listed in ``exclude`` are dropped; ``sample`` itself is not modified.
        """
        return {
            key: value.unsqueeze(0).to(self.device)
            for key, value in sample.items()
            if key not in exclude
        }

    def _decode_head_outputs(self, head_outputs, box_first_token_mask, attention_mask, dummy_idx):
        """Arg-max the four network heads on the CPU and decode them via :meth:`_decode_labels`.

        Only the four heads used here are moved to the CPU; any other outputs
        (e.g. ``embedding_tokens``) are ignored.
        """
        def to_labels(key):
            # argmax over the last axis, then drop the leading (batch) dim.
            return torch.argmax(head_outputs[key].detach().cpu(), -1).squeeze(0)

        return self._decode_labels(
            to_labels("itc_outputs"),
            to_labels("stc_outputs"),
            to_labels("el_outputs"),
            to_labels("el_outputs_from_key"),
            box_first_token_mask,
            attention_mask,
            dummy_idx,
        )

    def _decode_labels(self, itc_labels, stc_labels, el_labels, el_labels_from_key,
                       box_first_token_mask, attention_mask, dummy_idx):
        """Turn per-token label tensors into ``(class_words, relations)``.

        ``relations`` is the ``|`` combination of the relations parsed from
        ``el_labels`` and from ``el_labels_from_key``. This presumably means
        ``parse_relations`` returns sets (verify against its definition).
        """
        init_words = parse_initial_words(itc_labels, box_first_token_mask, self.class_names)
        class_words = parse_subsequent_words(
            stc_labels, attention_mask, init_words, dummy_idx
        )
        relations_from_header = parse_relations(el_labels, box_first_token_mask, dummy_idx)
        relations_from_key = parse_relations(el_labels_from_key, box_first_token_mask, dummy_idx)
        return class_words, relations_from_header | relations_from_key