sbt-idp/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier_module.py
2023-11-30 18:22:16 +07:00

390 lines
13 KiB
Python
Executable File

import numpy as np
import torch
import torch.utils.data
from overrides import overrides
from lightning_modules.classifier import ClassifierModule
from utils import get_class_names
class KVUClassifierModule(ClassifierModule):
    """Classifier module whose train/validation behaviour depends on ``cfg.stage``.

    * Stage 1 feeds ``batch['windows']`` to ``self.net``; validation runs the
      network once per window and sums the per-window evaluation counts.
    * Stage 2 feeds the whole batch to ``self.net`` and evaluates the outputs
      against ``batch['documents']``.

    Any other stage raises ``ValueError`` in both training and validation.
    """

    # Task keys and counter keys shared by every validation step output.
    _EVAL_TASKS = ("ee", "el", "el_from_key")
    _COUNT_KEYS = ("n_batch_gt", "n_batch_pr", "n_batch_correct")

    def __init__(self, cfg):
        """Store stage/window settings from ``cfg`` and prepare evaluation kwargs."""
        super().__init__(cfg)
        class_names = get_class_names(self.cfg.dataset_root_path)
        self.window_size = cfg.train.max_num_words
        self.slice_interval = cfg.train.slice_interval
        self.eval_kwargs = {
            "class_names": class_names,
            # Initial value only: stage-2 validation overwrites it per batch.
            "dummy_idx": self.cfg.train.max_seq_length,
        }
        self.stage = cfg.stage

    @overrides
    def training_step(self, batch, batch_idx, *args):
        """Run the network for the configured stage, log and return the loss.

        Raises:
            ValueError: if ``self.stage`` is neither 1 nor 2.
        """
        if self.stage == 1:
            _, loss = self.net(batch['windows'])
        elif self.stage == 2:
            _, loss = self.net(batch)
        else:
            raise ValueError(
                f"Not supported stage: {self.stage}"
            )
        log_dict_input = {"train_loss": loss}
        self.log_dict(log_dict_input, sync_dist=True)
        return loss

    @torch.no_grad()
    @overrides
    def validation_step(self, batch, batch_idx, *args):
        """Evaluate one validation batch and return its loss and match counts.

        Returns:
            A dict with key ``"loss"`` plus one entry per task (``"ee"``,
            ``"el"``, ``"el_from_key"``), each holding ``n_batch_gt``,
            ``n_batch_pr`` and ``n_batch_correct``.

        Raises:
            ValueError: if ``self.stage`` is neither 1 nor 2. Previously this
                case returned ``None`` and only failed later, at epoch end.
        """
        if self.stage == 1:
            step_out_total = {"loss": 0}
            for task in self._EVAL_TASKS:
                step_out_total[task] = dict.fromkeys(self._COUNT_KEYS, 0)

            # Each window is evaluated on its own; losses and counts are summed.
            for window in batch['windows']:
                head_outputs, loss = self.net(window)
                step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
                step_out_total["loss"] += step_out["loss"]
                for task in self._EVAL_TASKS:
                    for count_key in self._COUNT_KEYS:
                        step_out_total[task][count_key] += step_out[task][count_key]
            return step_out_total

        if self.stage == 2:
            head_outputs, loss = self.net(batch)
            # The dummy index follows the second dimension of this batch's
            # document-level itc labels instead of the configured max length.
            self.eval_kwargs['dummy_idx'] = batch['documents']['itc_labels'].shape[1]
            return do_eval_step(batch['documents'], head_outputs, loss, self.eval_kwargs)

        # Fail here, consistently with training_step, instead of returning None.
        raise ValueError(
            f"Not supported stage: {self.stage}"
        )

    @torch.no_grad()
    @overrides
    def validation_epoch_end(self, validation_step_outputs):
        """Aggregate step outputs into precision/recall/F1, print and log them.

        ``val_f1`` is logged as the unweighted mean of the three task F1 scores.

        Returns:
            ``{'log': {...}}`` holding the per-task precision, recall and F1.
        """
        scores = do_eval_epoch_end(validation_step_outputs)
        for tag, task in (("EE", "ee"), ("EL", "el"), ("ELK", "el_from_key")):
            self.print(
                f"[{tag}] Precision: {scores[task]['precision']:.4f}, Recall: {scores[task]['recall']:.4f}, F1-score: {scores[task]['f1']:.4f}"
            )
        self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.)
        tensorboard_logs = {
            f"val_{metric}_{task}": scores[task][metric]
            for task in self._EVAL_TASKS
            for metric in ("precision", "recall", "f1")
        }
        # NOTE(review): returning a 'log' dict is the legacy Lightning logging
        # convention; confirm the installed Lightning version still consumes it.
        return {'log': tensorboard_logs}
def do_eval_step(batch, head_outputs, loss, eval_kwargs):
    """Turn one batch of head outputs into ground-truth/predicted/correct counts.

    Predictions are the argmax over the last dimension of each head output.
    The ``ee`` task is scored with ``eval_ee_spade_batch``; the ``el`` and
    ``el_from_key`` tasks are scored with ``eval_el_spade_batch``.

    Args:
        batch: mapping holding the label tensors and masks read below.
        head_outputs: mapping with ``itc_outputs``, ``stc_outputs``,
            ``el_outputs`` and ``el_outputs_from_key``.
        loss: loss value, passed through unchanged.
        eval_kwargs: mapping with ``class_names`` and ``dummy_idx``.

    Returns:
        Dict with ``loss`` and one counter dict per task.
    """
    class_names = eval_kwargs["class_names"]
    dummy_idx = eval_kwargs["dummy_idx"]

    # Resolve every head prediction up front, in the original lookup order.
    predicted = {
        head: torch.argmax(head_outputs[head], -1)
        for head in (
            "itc_outputs",
            "stc_outputs",
            "el_outputs",
            "el_outputs_from_key",
        )
    }

    def _as_counts(triple):
        """Pack a (gt, pr, correct) triple into the counter-dict layout."""
        n_gt, n_pr, n_correct = triple
        return {
            "n_batch_gt": n_gt,
            "n_batch_pr": n_pr,
            "n_batch_correct": n_correct,
        }

    ee_counts = eval_ee_spade_batch(
        predicted["itc_outputs"],
        batch["itc_labels"],
        batch["are_box_first_tokens"],
        predicted["stc_outputs"],
        batch["stc_labels"],
        batch["attention_mask_layoutxlm"],
        class_names,
        dummy_idx,
    )
    el_counts = eval_el_spade_batch(
        predicted["el_outputs"],
        batch["el_labels"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )
    el_from_key_counts = eval_el_spade_batch(
        predicted["el_outputs_from_key"],
        batch["el_labels_from_key"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    return {
        "loss": loss,
        "ee": _as_counts(ee_counts),
        "el": _as_counts(el_counts),
        "el_from_key": _as_counts(el_from_key_counts),
    }
def eval_ee_spade_batch(
    pr_itc_labels,
    gt_itc_labels,
    are_box_first_tokens,
    pr_stc_labels,
    gt_stc_labels,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Sum the per-example ``ee`` counts over a batch.

    The batch size is taken from the first dimension of ``pr_itc_labels``;
    every per-example argument is indexed with the same example index.

    Returns:
        Tuple ``(n_gt, n_pr, n_correct)`` accumulated over all examples.
    """
    total_gt = total_pr = total_correct = 0
    for idx in range(pr_itc_labels.shape[0]):
        example_gt, example_pr, example_correct = eval_ee_spade_example(
            pr_itc_labels[idx],
            gt_itc_labels[idx],
            are_box_first_tokens[idx],
            pr_stc_labels[idx],
            gt_stc_labels[idx],
            attention_mask[idx],
            class_names,
            dummy_idx,
        )
        total_gt += example_gt
        total_pr += example_pr
        total_correct += example_correct
    return (total_gt, total_pr, total_correct)
def eval_ee_spade_example(
    pr_itc_label,
    gt_itc_label,
    box_first_token_mask,
    pr_stc_label,
    gt_stc_label,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Count ground-truth, predicted and matching word chains for one example.

    Ground truth and prediction are both parsed into per-class collections of
    token-index tuples. Within each class the tuples are compared as sets, so
    a prediction counts as correct only when the whole tuple matches.

    Returns:
        Tuple ``(n_gt, n_pr, n_correct)`` summed over all classes.
    """
    gt_class_words = parse_subsequent_words(
        gt_stc_label,
        attention_mask,
        parse_initial_words(gt_itc_label, box_first_token_mask, class_names),
        dummy_idx,
    )
    pr_class_words = parse_subsequent_words(
        pr_stc_label,
        attention_mask,
        parse_initial_words(pr_itc_label, box_first_token_mask, class_names),
        dummy_idx,
    )

    # Compare by token-index tuples, one class at a time.
    per_class = [
        (set(gt_class_words[class_idx]), set(pr_class_words[class_idx]))
        for class_idx in range(len(class_names))
    ]
    n_gt = sum(len(gt_set) for gt_set, _ in per_class)
    n_pr = sum(len(pr_set) for _, pr_set in per_class)
    n_correct = sum(len(gt_set & pr_set) for gt_set, pr_set in per_class)
    return n_gt, n_pr, n_correct
def parse_initial_words(itc_label, box_first_token_mask, class_names):
    """Group box-first token positions by their non-zero class label.

    Args:
        itc_label: 1-D tensor of per-token class labels.
        box_first_token_mask: per-token mask; only truthy positions are kept.
        class_names: sequence whose length fixes the number of output buckets.

    Returns:
        A list with one list per class; bucket ``c`` holds the indices of the
        masked-in tokens labelled ``c``. Label 0 is never collected.
    """
    labels = itc_label.cpu().numpy()
    is_box_first = box_first_token_mask.cpu().numpy()

    buckets = [[] for _ in class_names]
    for position, label in enumerate(labels):
        if not is_box_first[position]:
            continue
        if label == 0:
            continue
        buckets[label].append(position)
    return buckets
def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    """Extend each initial token into a chain by following ``stc_label`` links.

    A position ``p`` whose attention-masked label is neither ``dummy_idx`` nor
    0 registers ``p`` as the successor of token ``stc_label[p]``. If several
    positions name the same predecessor, the last one wins. Starting from each
    initial token, successors are appended until there is no successor, the
    successor is itself an initial token of the same class, or 50 links have
    been followed.

    Args:
        stc_label: 1-D tensor of per-token link labels.
        attention_mask: per-token mask; masked-out positions are ignored.
        init_words: per-class lists of starting token indices.
        dummy_idx: label value meaning "no link".

    Returns:
        A list parallel to ``init_words``; each entry is a list of tuples of
        token indices, one tuple per starting token.
    """
    max_connections = 50

    raw_labels = stc_label.cpu().numpy()
    masked_labels = (stc_label * attention_mask.bool()).cpu().numpy()

    linked_positions = np.where(
        (masked_labels != dummy_idx) & (masked_labels != 0)
    )[0]
    # predecessor token -> successor token (later positions overwrite earlier).
    successor_of = {raw_labels[pos]: pos for pos in linked_positions}

    def _follow_chain(start_idx, class_starts):
        """Walk successor links from ``start_idx`` and return the chain."""
        chain = [start_idx]
        for _ in range(max_connections):
            if chain[-1] not in successor_of:
                break
            candidate = successor_of[chain[-1]]
            # Reaching another start of the same class ends this chain.
            if candidate in class_starts:
                break
            chain.append(candidate)
        return tuple(chain)

    return [
        [_follow_chain(start_idx, class_starts) for start_idx in class_starts]
        for class_starts in init_words
    ]
def eval_el_spade_batch(
    pr_el_labels,
    gt_el_labels,
    are_box_first_tokens,
    dummy_idx,
):
    """Sum the per-example relation counts over a batch.

    The batch size is taken from the first dimension of ``pr_el_labels``.

    Returns:
        Tuple ``(n_gt_rel, n_pr_rel, n_correct_rel)`` over all examples.
    """
    total_gt = total_pr = total_correct = 0
    for idx in range(pr_el_labels.shape[0]):
        example_gt, example_pr, example_correct = eval_el_spade_example(
            pr_el_labels[idx],
            gt_el_labels[idx],
            are_box_first_tokens[idx],
            dummy_idx,
        )
        total_gt += example_gt
        total_pr += example_pr
        total_correct += example_correct
    return total_gt, total_pr, total_correct
def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
    """Count ground-truth, predicted and matching relations for one example.

    Both label tensors are parsed with ``parse_relations`` and the results are
    compared as sets.

    Returns:
        Tuple ``(n_gt_rel, n_pr_rel, n_correct_rel)``.
    """
    gt_set = set(parse_relations(gt_el_label, box_first_token_mask, dummy_idx))
    pr_set = set(parse_relations(pr_el_label, box_first_token_mask, dummy_idx))
    return len(gt_set), len(pr_set), len(gt_set & pr_set)
def parse_relations(el_label, box_first_token_mask, dummy_idx):
    """Collect ``(label, token_index)`` pairs for box-first tokens with a link.

    A token contributes a pair when its label, after being multiplied by
    ``box_first_token_mask``, is neither ``dummy_idx`` nor 0.

    Args:
        el_label: 1-D tensor of per-token link labels.
        box_first_token_mask: per-token mask; masked-out tokens are dropped
            because their label becomes 0.
        dummy_idx: label value meaning "no link".

    Returns:
        A set of ``(el_label[token_index], token_index)`` tuples.
    """
    el_label_np = el_label.cpu().numpy()
    valid_el_labels = (el_label * box_first_token_mask).cpu().numpy()

    linked = (valid_el_labels != dummy_idx) & (valid_el_labels != 0)
    # Pairs use the unmasked label; it equals the masked one wherever the
    # mask value is 1/True.
    return {
        (el_label_np[token_idx], token_idx) for token_idx in np.where(linked)[0]
    }
def do_eval_epoch_end(step_outputs):
    """Aggregate step counts into precision, recall and F1 per task.

    Args:
        step_outputs: iterable of step dicts; each maps ``ee``, ``el`` and
            ``el_from_key`` to a dict with ``n_batch_gt``, ``n_batch_pr`` and
            ``n_batch_correct``.

    Returns:
        Dict mapping each task to ``{"precision", "recall", "f1"}``. A ratio
        with a zero denominator is reported as 0.0, and F1 is 0.0 whenever
        precision times recall is zero.
    """

    def _ratio(numerator, denominator):
        """Divide, reporting 0.0 when the denominator is zero."""
        if denominator == 0:
            return 0.0
        return numerator / denominator

    scores = {}
    for task in ['ee', 'el', 'el_from_key']:
        gt_total = pr_total = correct_total = 0
        for step_out in step_outputs:
            counts = step_out[task]
            gt_total += counts["n_batch_gt"]
            pr_total += counts["n_batch_pr"]
            correct_total += counts["n_batch_correct"]

        precision = _ratio(correct_total, pr_total)
        recall = _ratio(correct_total, gt_total)
        if recall * precision == 0:
            f1 = 0.0
        else:
            f1 = 2.0 * recall * precision / (recall + precision)

        scores[task] = {"precision": precision, "recall": recall, "f1": f1}
    return scores
def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
    """Build evaluation kwargs holding class names and the dummy index.

    Class names come from ``get_class_names(dataset_root_path)``; the dummy
    index is set to ``max_seq_length``.
    """
    return {
        "class_names": get_class_names(dataset_root_path),
        "dummy_idx": max_seq_length,
    }
def get_eval_kwargs_spade_rel(max_seq_length):
    """Build evaluation kwargs holding only the dummy index.

    The dummy index is set to ``max_seq_length``.
    """
    return {"dummy_idx": max_seq_length}