390 lines
13 KiB
Python
Executable File
390 lines
13 KiB
Python
Executable File
import numpy as np
|
|
import torch
|
|
import torch.utils.data
|
|
from overrides import overrides
|
|
|
|
from lightning_modules.classifier import ClassifierModule
|
|
from utils import get_class_names
|
|
|
|
|
|
class KVUClassifierModule(ClassifierModule):
    """Training/validation module that drives ``self.net`` in one of two stages.

    ``cfg.stage`` selects how a batch is fed to the network:

    * stage ``1`` -- the batch carries a ``'windows'`` entry; training passes
      it to the network as a whole, validation runs the network on each window
      and sums the per-window counts.
    * stage ``2`` -- the whole batch is passed to the network and evaluation
      labels are read from ``batch['documents']``.

    NOTE(review): ``self.net``, ``self.cfg``, ``self.log``, ``self.log_dict``
    and ``self.print`` are not defined here; presumably they are provided by
    ``ClassifierModule`` / the Lightning base class -- confirm.
    """

    def __init__(self, cfg):
        """Store the windowing options, evaluation kwargs and stage from *cfg*."""
        super().__init__(cfg)

        class_names = get_class_names(self.cfg.dataset_root_path)

        self.window_size = cfg.train.max_num_words
        self.slice_interval = cfg.train.slice_interval
        self.eval_kwargs = {
            "class_names": class_names,
            # Initial value only: stage-2 validation overwrites it per batch.
            "dummy_idx": self.cfg.train.max_seq_length,
        }
        self.stage = cfg.stage

    @overrides
    def training_step(self, batch, batch_idx, *args):
        """Run the network on *batch*, log ``train_loss`` and return the loss.

        Raises:
            ValueError: if ``self.stage`` is neither ``1`` nor ``2``.
        """
        if self.stage == 1:
            _, loss = self.net(batch['windows'])
        elif self.stage == 2:
            _, loss = self.net(batch)
        else:
            raise ValueError(
                f"Not supported stage: {self.stage}"
            )

        log_dict_input = {"train_loss": loss}
        self.log_dict(log_dict_input, sync_dist=True)
        return loss

    @torch.no_grad()
    @overrides
    def validation_step(self, batch, batch_idx, *args):
        """Evaluate one batch and return its loss plus per-task match counts.

        Returns:
            A dict with ``"loss"`` and, for each of ``"ee"``, ``"el"`` and
            ``"el_from_key"``, a dict of ``n_batch_gt`` / ``n_batch_pr`` /
            ``n_batch_correct``.

        Raises:
            ValueError: if ``self.stage`` is neither ``1`` nor ``2``. This
                mirrors ``training_step`` instead of silently returning
                ``None``.
        """
        if self.stage == 1:
            tasks = ("ee", "el", "el_from_key")
            count_keys = ("n_batch_gt", "n_batch_pr", "n_batch_correct")
            step_out_total = {"loss": 0}
            for task in tasks:
                step_out_total[task] = {count_key: 0 for count_key in count_keys}

            # Each window is evaluated on its own; losses and counts are summed.
            for window in batch['windows']:
                head_outputs, loss = self.net(window)
                step_out = do_eval_step(window, head_outputs, loss, self.eval_kwargs)
                step_out_total["loss"] += step_out["loss"]
                for task in tasks:
                    for count_key in count_keys:
                        step_out_total[task][count_key] += step_out[task][count_key]
            return step_out_total

        if self.stage == 2:
            head_outputs, loss = self.net(batch)
            # Stage-2 batches keep the label tensors under 'documents'; the
            # dummy index follows the second dimension of the ITC labels.
            documents = batch['documents']
            self.eval_kwargs['dummy_idx'] = documents['itc_labels'].shape[1]
            step_out = do_eval_step(documents, head_outputs, loss, self.eval_kwargs)
            return step_out

        raise ValueError(
            f"Not supported stage: {self.stage}"
        )

    @torch.no_grad()
    @overrides
    def validation_epoch_end(self, validation_step_outputs):
        """Aggregate step outputs into precision/recall/F1, print and log them.

        ``val_f1`` is logged as the unweighted mean of the three task F1
        scores. The per-task metrics are returned under the ``'log'`` key.
        """
        scores = do_eval_epoch_end(validation_step_outputs)
        self.print(
            f"[EE] Precision: {scores['ee']['precision']:.4f}, Recall: {scores['ee']['recall']:.4f}, F1-score: {scores['ee']['f1']:.4f}"
        )
        self.print(
            f"[EL] Precision: {scores['el']['precision']:.4f}, Recall: {scores['el']['recall']:.4f}, F1-score: {scores['el']['f1']:.4f}"
        )
        self.print(
            f"[ELK] Precision: {scores['el_from_key']['precision']:.4f}, Recall: {scores['el_from_key']['recall']:.4f}, F1-score: {scores['el_from_key']['f1']:.4f}"
        )
        self.log('val_f1', (scores['ee']['f1'] + scores['el']['f1'] + scores['el_from_key']['f1']) / 3.)
        tensorboard_logs = {
            'val_precision_ee': scores['ee']['precision'],
            'val_recall_ee': scores['ee']['recall'],
            'val_f1_ee': scores['ee']['f1'],
            'val_precision_el': scores['el']['precision'],
            'val_recall_el': scores['el']['recall'],
            'val_f1_el': scores['el']['f1'],
            'val_precision_el_from_key': scores['el_from_key']['precision'],
            'val_recall_el_from_key': scores['el_from_key']['recall'],
            'val_f1_el_from_key': scores['el_from_key']['f1'],
        }
        return {'log': tensorboard_logs}
|
|
|
|
|
|
def do_eval_step(batch, head_outputs, loss, eval_kwargs):
    """Score one batch of head outputs against the labels stored in *batch*.

    Args:
        batch: mapping that provides ``itc_labels``, ``stc_labels``,
            ``el_labels``, ``el_labels_from_key``, ``are_box_first_tokens``
            and ``attention_mask_layoutxlm``.
        head_outputs: mapping with ``itc_outputs``, ``stc_outputs``,
            ``el_outputs`` and ``el_outputs_from_key``; each is reduced with
            an argmax over its last dimension.
        loss: value stored unchanged under ``"loss"`` in the result.
        eval_kwargs: mapping with ``class_names`` and ``dummy_idx``.

    Returns:
        Dict with ``"loss"`` and, for ``"ee"``, ``"el"`` and ``"el_from_key"``,
        a dict of ``n_batch_gt`` / ``n_batch_pr`` / ``n_batch_correct``.
    """
    class_names = eval_kwargs["class_names"]
    dummy_idx = eval_kwargs["dummy_idx"]

    # Predicted label per position: index of the best score on the last axis.
    predicted = {
        head: torch.argmax(head_outputs[head], -1)
        for head in ("itc_outputs", "stc_outputs", "el_outputs", "el_outputs_from_key")
    }

    ee_counts = eval_ee_spade_batch(
        predicted["itc_outputs"],
        batch["itc_labels"],
        batch["are_box_first_tokens"],
        predicted["stc_outputs"],
        batch["stc_labels"],
        batch["attention_mask_layoutxlm"],
        class_names,
        dummy_idx,
    )
    el_counts = eval_el_spade_batch(
        predicted["el_outputs"],
        batch["el_labels"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )
    el_from_key_counts = eval_el_spade_batch(
        predicted["el_outputs_from_key"],
        batch["el_labels_from_key"],
        batch["are_box_first_tokens"],
        dummy_idx,
    )

    count_names = ("n_batch_gt", "n_batch_pr", "n_batch_correct")
    return {
        "loss": loss,
        "ee": dict(zip(count_names, ee_counts)),
        "el": dict(zip(count_names, el_counts)),
        "el_from_key": dict(zip(count_names, el_from_key_counts)),
    }
|
|
|
|
|
|
def eval_ee_spade_batch(
    pr_itc_labels,
    gt_itc_labels,
    are_box_first_tokens,
    pr_stc_labels,
    gt_stc_labels,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Sum the per-example entity counts over every example in the batch.

    The batch size is taken from the first dimension of ``pr_itc_labels``;
    every batched argument is indexed with the same example index.

    Returns:
        Tuple ``(n_gt, n_pr, n_correct)`` accumulated over the batch.
    """
    totals = [0, 0, 0]
    for row in range(pr_itc_labels.shape[0]):
        per_example = eval_ee_spade_example(
            pr_itc_labels[row],
            gt_itc_labels[row],
            are_box_first_tokens[row],
            pr_stc_labels[row],
            gt_stc_labels[row],
            attention_mask[row],
            class_names,
            dummy_idx,
        )
        totals = [running + new for running, new in zip(totals, per_example)]

    n_batch_gt_classes, n_batch_pr_classes, n_batch_correct_classes = totals
    return (
        n_batch_gt_classes,
        n_batch_pr_classes,
        n_batch_correct_classes,
    )
|
|
|
|
|
|
def eval_ee_spade_example(
    pr_itc_label,
    gt_itc_label,
    box_first_token_mask,
    pr_stc_label,
    gt_stc_label,
    attention_mask,
    class_names,
    dummy_idx,
):
    """Count ground-truth, predicted and matching token chains for one example.

    Ground truth and prediction are each parsed into per-class lists of token
    chains; a prediction counts as correct only when the identical chain
    appears for the same class in the ground truth.

    Returns:
        Tuple ``(n_gt_classes, n_pr_classes, n_correct_classes)``.
    """
    # Ground-truth chains.
    gt_heads = parse_initial_words(gt_itc_label, box_first_token_mask, class_names)
    gt_class_words = parse_subsequent_words(
        gt_stc_label, attention_mask, gt_heads, dummy_idx
    )

    # Predicted chains, built from the predicted labels with the same masks.
    pr_heads = parse_initial_words(pr_itc_label, box_first_token_mask, class_names)
    pr_class_words = parse_subsequent_words(
        pr_stc_label, attention_mask, pr_heads, dummy_idx
    )

    n_gt_classes = n_pr_classes = n_correct_classes = 0
    for class_idx, _ in enumerate(class_names):
        # Chains are compared as sets of token-index tuples.
        truth = set(gt_class_words[class_idx])
        guess = set(pr_class_words[class_idx])

        n_gt_classes += len(truth)
        n_pr_classes += len(guess)
        n_correct_classes += len(truth & guess)

    return n_gt_classes, n_pr_classes, n_correct_classes
|
|
|
|
|
|
def parse_initial_words(itc_label, box_first_token_mask, class_names):
    """Group the positions of labelled, mask-selected tokens by their label.

    Args:
        itc_label: tensor of integer labels, iterated position by position.
        box_first_token_mask: tensor indexed by the same positions; a position
            is only considered when its mask entry is truthy.
        class_names: sequence whose length fixes the number of output buckets.

    Returns:
        A list with one list per class; bucket ``k`` holds the positions whose
        label equals ``k``. Label ``0`` is never recorded.
    """
    labels = itc_label.cpu().numpy()
    selected = box_first_token_mask.cpu().numpy()

    buckets = [[] for _ in class_names]
    for position, label in enumerate(labels):
        if not selected[position]:
            continue
        if label == 0:
            continue
        buckets[label].append(position)

    return buckets
|
|
|
|
|
|
def parse_subsequent_words(stc_label, attention_mask, init_words, dummy_idx):
    """Extend every initial token into a chain by following the STC links.

    A position ``t`` whose label is ``p`` (neither ``0`` nor ``dummy_idx``
    after attention masking) is treated as the successor of token ``p``.

    Args:
        stc_label: tensor of per-position link labels.
        attention_mask: tensor; positions where it is falsy are ignored.
        init_words: list (one entry per class) of lists of initial token
            indices.
        dummy_idx: label value meaning "no link".

    Returns:
        A list parallel to ``init_words``; each entry is a list of tuples, one
        chain of token indices per initial token.
    """
    max_connections = 50  # upper bound on the number of links followed per chain

    # Masked-out positions are forced to 0 so they are filtered out below.
    masked_labels = (stc_label * attention_mask.bool()).cpu().numpy()
    raw_labels = stc_label.cpu().numpy()

    linked = (masked_labels != dummy_idx) & (masked_labels != 0)
    linked_positions = np.nonzero(linked)[0]

    # predecessor index -> successor index; a later position overwrites an earlier one.
    successor_of = {raw_labels[position]: position for position in linked_positions}

    chains_per_class = []
    for seeds in init_words:
        chains = []
        for seed in seeds:
            chain = [seed]
            for _ in range(max_connections):
                successor = successor_of.get(chain[-1])
                # Stop at a dead end, or when the next token starts another chain
                # of the same class.
                if successor is None or successor in seeds:
                    break
                chain.append(successor)
            chains.append(tuple(chain))
        chains_per_class.append(chains)

    return chains_per_class
|
|
|
|
|
|
def eval_el_spade_batch(
    pr_el_labels,
    gt_el_labels,
    are_box_first_tokens,
    dummy_idx,
):
    """Sum the per-example relation counts over every example in the batch.

    The batch size is taken from the first dimension of ``pr_el_labels``.

    Returns:
        Tuple ``(n_gt_rel, n_pr_rel, n_correct_rel)`` accumulated over the batch.
    """
    totals = [0, 0, 0]
    for row in range(pr_el_labels.shape[0]):
        per_example = eval_el_spade_example(
            pr_el_labels[row],
            gt_el_labels[row],
            are_box_first_tokens[row],
            dummy_idx,
        )
        totals = [running + new for running, new in zip(totals, per_example)]

    n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = totals
    return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel
|
|
|
|
|
|
def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
    """Count ground-truth, predicted and matching relations for one example.

    Both label rows are parsed with the same mask and dummy index; a predicted
    relation is correct when the identical pair exists in the ground truth.

    Returns:
        Tuple ``(n_gt_rel, n_pr_rel, n_correct_rel)``.
    """
    truth = set(parse_relations(gt_el_label, box_first_token_mask, dummy_idx))
    guess = set(parse_relations(pr_el_label, box_first_token_mask, dummy_idx))

    matched = truth & guess
    return len(truth), len(guess), len(matched)
|
|
|
|
|
|
def parse_relations(el_label, box_first_token_mask, dummy_idx):
    """Collect the ``(label, token_idx)`` link pairs encoded in one label row.

    Args:
        el_label: tensor of per-position link labels.
        box_first_token_mask: tensor multiplied element-wise with ``el_label``;
            positions where it is zero/False are ignored.
        dummy_idx: label value meaning "no link".

    Returns:
        A set of ``(el_label[token_idx], token_idx)`` tuples for every kept
        position whose label is neither ``0`` nor ``dummy_idx``.
    """
    # Zero out the positions the mask does not select so the filter drops them.
    valid_el_labels = (el_label * box_first_token_mask).cpu().numpy()
    el_label_np = el_label.cpu().numpy()

    is_link = (valid_el_labels != dummy_idx) & (valid_el_labels != 0)
    valid_token_indices = np.where(is_link)

    return {
        (el_label_np[token_idx], token_idx) for token_idx in valid_token_indices[0]
    }
|
|
|
|
def do_eval_epoch_end(step_outputs):
    """Turn accumulated per-step counts into precision, recall and F1 per task.

    Args:
        step_outputs: iterable of step results; each maps ``'ee'``, ``'el'``
            and ``'el_from_key'`` to a dict with ``n_batch_gt``,
            ``n_batch_pr`` and ``n_batch_correct``.

    Returns:
        Dict mapping each task name to ``{"precision", "recall", "f1"}``.
        A ratio whose denominator is zero is reported as ``0.0``.
    """
    scores = {}
    for task in ['ee', 'el', 'el_from_key']:
        gt_total = pr_total = hit_total = 0
        for step_out in step_outputs:
            counts = step_out[task]
            gt_total += counts["n_batch_gt"]
            pr_total += counts["n_batch_pr"]
            hit_total += counts["n_batch_correct"]

        if pr_total == 0:
            precision = 0.0
        else:
            precision = hit_total / pr_total

        if gt_total == 0:
            recall = 0.0
        else:
            recall = hit_total / gt_total

        # The harmonic mean is undefined when either term is zero.
        if recall * precision == 0:
            f1 = 0.0
        else:
            f1 = 2.0 * recall * precision / (recall + precision)

        scores[task] = {"precision": precision, "recall": recall, "f1": f1}

    return scores
|
|
|
|
|
|
def get_eval_kwargs_spade(dataset_root_path, max_seq_length):
    """Build the evaluation kwargs for a dataset.

    Returns:
        Dict with ``"class_names"`` (from ``get_class_names(dataset_root_path)``)
        and ``"dummy_idx"`` (set to ``max_seq_length``).
    """
    return {
        "class_names": get_class_names(dataset_root_path),
        "dummy_idx": max_seq_length,
    }
|
|
|
|
|
|
def get_eval_kwargs_spade_rel(max_seq_length):
    """Build the evaluation kwargs that only carry the dummy index.

    Returns:
        Dict with a single ``"dummy_idx"`` entry set to ``max_seq_length``.
    """
    return {"dummy_idx": max_seq_length}