import os
import copy
import math

import numpy as np
import torch
import torch.nn as nn


def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    """Split `elements` into overlapping windows of `window_size`, advancing by `slice_interval`."""
    element_windows = []

    if len(elements) > window_size:
        max_step = math.ceil((len(elements) - window_size) / slice_interval)

        for i in range(0, max_step + 1):
            # The last window is clipped to the end of the list so no element is dropped.
            if (i * slice_interval + window_size) >= len(elements):
                _window = copy.deepcopy(elements[i * slice_interval:])
            else:
                _window = copy.deepcopy(elements[i * slice_interval: i * slice_interval + window_size])
            element_windows.append(_window)
        return element_windows
    else:
        return [elements]
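
# Illustrative usage sketch (not part of the original file), assuming a plain
# list of integers:
#   sliding_windows(list(range(10)), window_size=4, slice_interval=3)
#   -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]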


def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> tuple:
    """Slice a word list into overlapping windows and re-index the entity and
    relation annotations so word ids are relative to each window."""
    word_windows = []
    parse_class_windows = []
    parse_relation_windows = []

    if len(lwords) > window_size:
        max_step = math.ceil((len(lwords) - window_size) / slice_interval)
        for i in range(0, max_step + 1):
            # The last window is clipped to the end of the word list.
            if (i * slice_interval + window_size) >= len(lwords):
                _word_window = copy.deepcopy(lwords[i * slice_interval:])
            else:
                _word_window = copy.deepcopy(lwords[i * slice_interval: i * slice_interval + window_size])

            if len(_word_window) < 2:
                continue

            first_word_id = _word_window[0]['word_id']
            last_word_id = _word_window[-1]['word_id']

            # Word list: shift word ids so the window starts at 0.
            for _word in _word_window:
                _word['word_id'] -= first_word_id

            # Entity extraction
            _class_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id)

            # Entity linking
            _relation_window = entity_linking_by_words(parse_relation, first_word_id, last_word_id)

            word_windows.append(_word_window)
            parse_class_windows.append(_class_window)
            parse_relation_windows.append(_relation_window)

        return word_windows, parse_class_windows, parse_relation_windows
    else:
        return [lwords], [parse_class], [parse_relation]
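
# Illustrative usage sketch (not part of the original file), with hypothetical
# word dicts carrying a 'word_id' key:
#   lwords = [{"word_id": i, "text": f"w{i}"} for i in range(6)]
#   parse_class = {"header": [[0, 1], [4, 5]]}
#   parse_relation = [[1, 4]]
#   words, classes, relations = sliding_windows_by_words(
#       lwords, parse_class, parse_relation, window_size=4, slice_interval=2)
# Each returned window holds word ids, class groups and relation pairs
# re-indexed relative to the first word of that window.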


def entity_extraction_by_words(parse_class, first_word_id, last_word_id):
    """Re-index entity groups to the window [first_word_id, last_word_id], dropping ids outside it."""
    _class_window = {k: [] for k in list(parse_class.keys())}
    for class_name, _parse_class in parse_class.items():
        for group in _parse_class:
            tmp = []
            for idw in group:
                idw -= first_word_id
                if 0 <= idw <= (last_word_id - first_word_id):
                    tmp.append(idw)
            _class_window[class_name].append(tmp)
    return _class_window
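
# Illustrative usage sketch (not part of the original file):
#   entity_extraction_by_words({"header": [[2, 3], [7]]}, first_word_id=2, last_word_id=5)
#   -> {"header": [[0, 1], []]}
# Word ids outside the window are dropped; surviving ids are shifted to start at 0.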


def entity_linking_by_words(parse_relation, first_word_id, last_word_id):
    """Keep only relation pairs whose word ids both fall inside the window, re-indexed to the window."""
    _relation_window = []
    for pair in parse_relation:
        if all(0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair):
            _relation_window.append([idw - first_word_id for idw in pair])
    return _relation_window
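
# Illustrative usage sketch (not part of the original file):
#   entity_linking_by_words([[2, 4], [3, 9]], first_word_id=2, last_word_id=5)
#   -> [[0, 2]]
# The pair [3, 9] is discarded because 9 lies outside the window.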


def merged_token_embeddings(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Merge per-window token embeddings back into a single sequence.

    The first `lvalids[i]` tokens after position 0 of each patch are treated as
    valid; consecutive windows overlap by `loverlaps[i-1]` tokens, which are
    either averaged (`average=True`) or taken from the current window. The CLS
    and SEP embeddings of the first patch wrap the merged sequence.
    """
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
    cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
    sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]

        overlap_gap = copy.deepcopy(loverlaps[i - 1])
        window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])

        if overlap_gap != 0:
            prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
            curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"

            if average:
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            embedding_tokens = torch.cat(
                [embedding_tokens, window], dim=1
            )
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
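
# Illustrative usage sketch (not part of the original file): merging two patches
# of shape (batch, seq_len, dim) with 4 valid tokens each and an overlap of 2,
# assuming position 0 holds the CLS embedding and the last position the SEP embedding.
#   p0 = torch.randn(1, 6, 8)   # [CLS] + 4 valid tokens + [SEP]
#   p1 = torch.randn(1, 6, 8)
#   merged = merged_token_embeddings([p0, p1], loverlaps=[2], lvalids=[4, 4], average=True)
#   merged.shape -> torch.Size([1, 8, 8])   # [CLS] + (4 + 4 - 2) tokens + [SEP]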


def merged_token_embeddings2(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Variant of `merged_token_embeddings` that slices the input tensors directly
    (no deepcopy) and, when `average` is False, keeps the previous window's overlap
    instead of the current one."""
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = lpatches[0][:, start_pos:end_pos, ...]
    cls_token = lpatches[0][:, :1, ...]
    sep_token = lpatches[0][:, -1:, ...]

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]

        overlap_gap = loverlaps[i - 1]
        window = lpatches[i][:, start_pos:end_pos, ...]

        if overlap_gap != 0:
            prev_overlap = embedding_tokens[:, -overlap_gap:, ...]
            curr_overlap = window[:, :overlap_gap, ...]
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"

            if average:
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            embedding_tokens = torch.cat(
                [embedding_tokens, window], dim=1
            )
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
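
# Illustrative usage sketch (not part of the original file): same call pattern
# as merged_token_embeddings above, e.g.
#   merged = merged_token_embeddings2([p0, p1], loverlaps=[2], lvalids=[4, 4], average=False)
# With average=False this variant reuses the already-merged (previous) overlap
# embeddings rather than the current window's.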