import numpy as np
import torch
import torch.utils.data
from overrides import overrides

from lightning_modules.classifier import ClassifierModule
from utils import get_class_names


class KVUClassifierModule(ClassifierModule):
    def __init__(self, cfg):
        super().__init__(cfg)
        class_names = get_class_names(self.cfg.dataset_root_path)

        self.window_size = cfg.train.max_num_words
        self.slice_interval = cfg.train.slice_interval
        self.eval_kwargs = {
            "class_names": class_names,
            "dummy_idx": self.cfg.train.max_seq_length,  # dummy_idx is overridden per batch in validation_step (stage 2)
        }
        self.stage = cfg.stage

    @overrides
    def training_step(self, batch, batch_idx, *args):
        # Stage 1 runs the network on pre-sliced windows; stage 2 runs it on the full batch.
        if self.stage == 1:
            _, loss = self.net(batch['windows'])
        elif self.stage == 2:
            _, loss = self.net(batch)
        else:
            raise ValueError(f"Not supported stage: {self.stage}")

        log_dict_input = {"train_loss": loss}
        self.log_dict(log_dict_input, sync_dist=True)
        return loss

    @torch.no_grad()
    @overrides
    def validation_step(self, batch, batch_idx, *args):
        if self.stage == 1:
            # Accumulate per-window losses and gt/pred/correct counts into a single
            # step output with the same structure as do_eval_step's result.
            step_out_total = {
                "loss": 0,
                "ee": {
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
                "el": {
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
                "el_from_key": {
                    "n_batch_gt": 0,
                    "n_batch_pr": 0,
                    "n_batch_correct": 0,
                },
            }
            for window in batch['windows']:
                head_outputs, loss = self.net(window)
                step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
                for key in step_out_total:
                    if key == 'loss':
                        step_out_total[key] += step_out[key]
                    else:
                        for subkey in step_out_total[key]:
                            step_out_total[key][subkey] += step_out[key][subkey]
            return step_out_total
        elif self.stage == 2:
            head_outputs, loss = self.net(batch)
            # self.eval_kwargs['dummy_idx'] = batch['itc_labels'].shape[1]
            # step_out = do_eval_step(batch, head_outputs, loss, self.eval_kwargs)
            self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1]
            step_out = do_eval_step(
                batch['documents'], head_outputs, loss, self.eval_kwargs
            )
            return step_out

    @torch.no_grad()
    @overrides
    def validation_epoch_end(self, validation_step_outputs):
        scores = do_eval_epoch_end(validation_step_outputs)
        self.print(
            f"[EE] Precision: {scores['ee']['precision']:.4f}, "
            f"Recall: {scores['ee']['recall']:.4f}, "
            f"F1-score: {scores['ee']['f1']:.4f}"
        )
        self.print(
            f"[EL] Precision: {scores['el']['precision']:.4f}, "
            f"Recall: {scores['el']['recall']:.4f}, "
            f"F1-score: {scores['el']['f1']:.4f}"
        )
        self.print(
            f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, "
            f"Recall: {scores['el_from_key']['recall']:.4f}, "
            f"F1-score: {scores['el_from_key']['f1']:.4f}"
        )
        self.log(
            'val_f1',
            (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.,
        )
        tensorboard_logs = {
            'val_precision_ee': scores['ee']['precision'],
            'val_recall_ee': scores['ee']['recall'],
            'val_f1_ee': scores['ee']['f1'],
            'val_precision_el': scores['el']['precision'],
            'val_recall_el': scores['el']['recall'],
            'val_f1_el': scores['el']['f1'],
            'val_precision_el_from_key': scores['el_from_key']['precision'],
            'val_recall_el_from_key': scores['el_from_key']['recall'],
            'val_f1_el_from_key': scores['el_from_key']['f1'],
        }
        return {'log': tensorboard_logs}


def do_eval_step(batch, head_outputs, loss, eval_kwargs):
    class_names = eval_kwargs["class_names"]
    dummy_idx = eval_kwargs["dummy_idx"]

    itc_outputs = head_outputs["itc_outputs"]
    stc_outputs = head_outputs["stc_outputs"]
    el_outputs = head_outputs["el_outputs"]
    el_outputs_from_key = head_outputs["el_outputs_from_key"]

    pr_itc_labels = torch.argmax(itc_outputs, -1)
    pr_stc_labels = torch.argmax(stc_outputs, -1)
    pr_el_labels = torch.argmax(el_outputs, -1)
    pr_el_labels_from_key = torch.argmax(el_outputs_from_key, -1)

    (
        n_batch_gt_classes,
        n_batch_pr_classes,
        n_batch_correct_classes,
    ) = eval_ee_spade_batch(
        pr_itc_labels,
        batch["itc_labels"],
        batch["are_box_first_tokens"],
        pr_stc_labels,
        batch["stc_labels"],
        batch["attention_mask_layoutxlm"],
        class_names,
        dummy_idx,
    )

    n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch(
        pr_el_labels,
        batch["el_labels"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    (
        n_batch_gt_rel_from_key,
        n_batch_pr_rel_from_key,
        n_batch_correct_rel_from_key,
    ) = eval_el_spade_batch(
        pr_el_labels_from_key,
        batch["el_labels_from_key"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    step_out = {
        "loss": loss,
        "ee": {
            "n_batch_gt": n_batch_gt_classes,
            "n_batch_pr": n_batch_pr_classes,
            "n_batch_correct": n_batch_correct_classes,
        },
        "el": {
            "n_batch_gt": n_batch_gt_rel,
            "n_batch_pr": n_batch_pr_rel,
            "n_batch_correct": n_batch_correct_rel,
        },
        "el_from_key": {
            "n_batch_gt": n_batch_gt_rel_from_key,
            "n_batch_pr": n_batch_pr_rel_from_key,
            "n_batch_correct": n_batch_correct_rel_from_key,
        },
    }

    return step_out


def eval_ee_spade_batch(
    pr_itc_labels,
    gt_itc_labels,
    are_box_first_tokens,
    pr_stc_labels,
    gt_stc_labels,
    attention_mask,
    class_names,
    dummy_idx,
):
    n_batch_gt_classes, n_batch_pr_classes, n_batch_correct_classes = 0, 0, 0

    bsz = pr_itc_labels.shape[0]
    for example_idx in range(bsz):
        n_gt_classes, n_pr_classes, n_correct_classes = eval_ee_spade_example(
            pr_itc_labels[example_idx],
            gt_itc_labels[example_idx],
            are_box_first_tokens[example_idx],
            pr_stc_labels[example_idx],
            gt_stc_labels[example_idx],
            attention_mask[example_idx],
            class_names,
            dummy_idx,
        )

        n_batch_gt_classes += n_gt_classes
        n_batch_pr_classes += n_pr_classes
        n_batch_correct_classes += n_correct_classes

    return (
        n_batch_gt_classes,
        n_batch_pr_classes,
        n_batch_correct_classes,
    )


def eval_ee_spade_example(
    pr_itc_label,
    gt_itc_label,
    box_first_token_mask,
    pr_stc_label,
    gt_stc_label,
    attention_mask,
    class_names,
    dummy_idx,
):
    gt_first_words = parse_initial_words(
        gt_itc_label, box_first_token_mask, class_names
    )
    gt_class_words = parse_subsequent_words(
        gt_stc_label, attention_mask, gt_first_words, dummy_idx
    )

    pr_init_words = parse_initial_words(pr_itc_label, box_first_token_mask, class_names)
    pr_class_words = parse_subsequent_words(
        pr_stc_label, attention_mask, pr_init_words, dummy_idx
    )

    n_gt_classes, n_pr_classes, n_correct_classes = 0, 0, 0
    for class_idx in range(len(class_names)):
        # Evaluate by ID
        gt_parse = set(gt_class_words[class_idx])
        pr_parse = set(pr_class_words[class_idx])

        n_gt_classes += len(gt_parse)
        n_pr_classes += len(pr_parse)
        n_correct_classes += len(gt_parse & pr_parse)

    return n_gt_classes, n_pr_classes, n_correct_classes


def parse_initial_words(itc_label, box_first_token_mask, class_names):
    # Group the first tokens of entities by their predicted/gold class id.
    itc_label_np = itc_label.cpu().numpy()
    box_first_token_mask_np = box_first_token_mask.cpu().numpy()

    outputs = [[] for _ in range(len(class_names))]

    for token_idx, label in enumerate(itc_label_np):
        if box_first_token_mask_np[token_idx] and label != 0:
            outputs[label].append(token_idx)

    return outputs


def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    # Follow the next-token links from each initial word for at most
    # max_connections hops to reconstruct full entity spans.
    max_connections = 50

    valid_stc_label = stc_label * attention_mask.bool()
    valid_stc_label = valid_stc_label.cpu().numpy()
    stc_label_np = stc_label.cpu().numpy()

    valid_token_indices = np.where(
        (valid_stc_label != dummy_idx) * (valid_stc_label != 0)
    )

    next_token_idx_dict = {}
    for token_idx in valid_token_indices[0]:
        next_token_idx_dict[stc_label_np[token_idx]] = token_idx

    outputs = []
    for init_token_indices in init_words:
        sub_outputs = []
        for init_token_idx in init_token_indices:
            cur_token_indices = [init_token_idx]
            for _ in range(max_connections):
                if cur_token_indices[-1] in next_token_idx_dict:
                    if (
                        next_token_idx_dict[cur_token_indices[-1]]
                        not in init_token_indices
                    ):
                        cur_token_indices.append(
                            next_token_idx_dict[cur_token_indices[-1]]
                        )
                    else:
                        break
                else:
                    break
            sub_outputs.append(tuple(cur_token_indices))
        outputs.append(sub_outputs)

    return outputs


def eval_el_spade_batch(
    pr_el_labels,
    gt_el_labels,
    are_box_first_tokens,
    dummy_idx,
):
    n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = 0, 0, 0

    bsz = pr_el_labels.shape[0]
    for example_idx in range(bsz):
        n_gt_rel, n_pr_rel, n_correct_rel = eval_el_spade_example(
            pr_el_labels[example_idx],
            gt_el_labels[example_idx],
            are_box_first_tokens[example_idx],
            dummy_idx,
        )

        n_batch_gt_rel += n_gt_rel
        n_batch_pr_rel += n_pr_rel
        n_batch_correct_rel += n_correct_rel

    return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel


def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
    gt_relations = parse_relations(gt_el_label, box_first_token_mask, dummy_idx)
    pr_relations = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)

    gt_relations = set(gt_relations)
    pr_relations = set(pr_relations)

    n_gt_rel = len(gt_relations)
    n_pr_rel = len(pr_relations)
    n_correct_rel = len(gt_relations & pr_relations)

    return n_gt_rel, n_pr_rel, n_correct_rel


def parse_relations(el_label, box_first_token_mask, dummy_idx):
    # Collect the (linked-token index, token index) pairs that form the
    # predicted/gold links for valid first tokens.
    valid_el_labels = el_label * box_first_token_mask
    valid_el_labels = valid_el_labels.cpu().numpy()

    el_label_np = el_label.cpu().numpy()

    max_token = box_first_token_mask.shape[0] - 1
    valid_token_indices = np.where(
        (valid_el_labels != dummy_idx) * (valid_el_labels != 0)
    )

    link_map_tuples = []
    for token_idx in valid_token_indices[0]:
        link_map_tuples.append((el_label_np[token_idx], token_idx))

    return set(link_map_tuples)


def do_eval_epoch_end(step_outputs):
    # Micro-average precision/recall/F1 over all validation steps for each task.
    scores = {}
    for task in ['ee', 'el', 'el_from_key']:
        n_total_gt_classes, n_total_pr_classes, n_total_correct_classes = 0, 0, 0

        for step_out in step_outputs:
            n_total_gt_classes += step_out[task]["n_batch_gt"]
            n_total_pr_classes += step_out[task]["n_batch_pr"]
            n_total_correct_classes += step_out[task]["n_batch_correct"]

        precision = (
            0.0
            if n_total_pr_classes == 0
            else n_total_correct_classes / n_total_pr_classes
        )
        recall = (
            0.0
            if n_total_gt_classes == 0
            else n_total_correct_classes / n_total_gt_classes
        )
        f1 = (
            0.0
            if recall * precision == 0
            else 2.0 * recall * precision / (recall + precision)
        )

        scores[task] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    return scores


def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
    class_names = get_class_names(dataset_root_path)
    dummy_idx = max_seq_length
    eval_kwargs = {"class_names": class_names, "dummy_idx": dummy_idx}

    return eval_kwargs


def get_eval_kwargs_spade_rel(max_seq_length):
    dummy_idx = max_seq_length
    eval_kwargs = {"dummy_idx": dummy_idx}

    return eval_kwargs