"""Utilities for splitting long word sequences into overlapping windows and for
merging the per-window token embeddings back into a single sequence."""
import copy
import math

import torch


def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    """Split `elements` into overlapping windows of `window_size`, advancing the
    start of each window by `slice_interval` elements."""
    element_windows = []
    if len(elements) > window_size:
        max_step = math.ceil((len(elements) - window_size) / slice_interval)
        for i in range(0, max_step + 1):
            if (i * slice_interval + window_size) >= len(elements):
                # Last window: take everything that is left.
                _window = copy.deepcopy(elements[i * slice_interval:])
            else:
                _window = copy.deepcopy(elements[i * slice_interval: i * slice_interval + window_size])
            element_windows.append(_window)
        return element_windows
    else:
        return [elements]


def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list,
                             window_size: int, slice_interval: int) -> tuple:
    """Split the word list and its aligned entity/relation annotations into
    overlapping windows; word ids inside each window are re-indexed from 0."""
    word_windows = []
    parse_class_windows = []
    parse_relation_windows = []
    if len(lwords) > window_size:
        max_step = math.ceil((len(lwords) - window_size) / slice_interval)
        for i in range(0, max_step + 1):
            if (i * slice_interval + window_size) >= len(lwords):
                _word_window = copy.deepcopy(lwords[i * slice_interval:])
            else:
                _word_window = copy.deepcopy(lwords[i * slice_interval: i * slice_interval + window_size])
            if len(_word_window) < 2:
                continue
            first_word_id = _word_window[0]['word_id']
            last_word_id = _word_window[-1]['word_id']

            # Re-index word ids so that each window starts at 0.
            for _word in _word_window:
                _word['word_id'] -= first_word_id

            # Entity extraction: keep the entity word ids that fall inside this window.
            _class_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id)
            # Entity linking: keep the relation pairs fully contained in this window.
            _relation_window = entity_linking_by_words(parse_relation, first_word_id, last_word_id)

            word_windows.append(_word_window)
            parse_class_windows.append(_class_window)
            parse_relation_windows.append(_relation_window)
        return word_windows, parse_class_windows, parse_relation_windows
    else:
        return [lwords], [parse_class], [parse_relation]


def entity_extraction_by_words(parse_class, first_word_id, last_word_id):
    """For every entity class, re-index its word-id groups into window-local ids,
    keeping only the ids that fall inside [first_word_id, last_word_id]."""
    _class_window = {k: [] for k in parse_class.keys()}
    for class_name, _parse_class in parse_class.items():
        for group in _parse_class:
            tmp = []
            for idw in group:
                idw -= first_word_id
                if 0 <= idw <= (last_word_id - first_word_id):
                    tmp.append(idw)
            _class_window[class_name].append(tmp)
    return _class_window


def entity_linking_by_words(parse_relation, first_word_id, last_word_id):
    """Keep only the relation pairs whose word ids all fall inside the window,
    re-indexed into window-local ids."""
    _relation_window = []
    for pair in parse_relation:
        if all(0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair):
            _relation_window.append([idw - first_word_id for idw in pair])
    return _relation_window


def merged_token_embeddings(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Merge per-window token embeddings back into a single sequence.

    `lpatches[i]` holds the embeddings of window i ([CLS] ... tokens ... [SEP]),
    `lvalids[i]` the number of valid tokens in window i, and `loverlaps[i]` the
    number of tokens that window i+1 shares with window i. Overlapping tokens are
    either averaged or taken from the current window.
    """
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
    cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
    sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]
        overlap_gap = loverlaps[i - 1]
        window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])
        if overlap_gap != 0:
            prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
            curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
            assert prev_overlap.shape == curr_overlap.shape, \
                f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
            if average:
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1)
            else:
                # Keep the overlap tokens from the current window.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1)
        else:
            embedding_tokens = torch.cat([embedding_tokens, window], dim=1)
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)


def merged_token_embeddings2(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Variant of `merged_token_embeddings` that operates on tensor views (no deep
    copies) and, when not averaging, keeps the previous window's overlap tokens
    instead of the current window's."""
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = lpatches[0][:, start_pos:end_pos, ...]
    cls_token = lpatches[0][:, :1, ...]
    sep_token = lpatches[0][:, -1:, ...]

    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]
        overlap_gap = loverlaps[i - 1]
        window = lpatches[i][:, start_pos:end_pos, ...]
        if overlap_gap != 0:
            prev_overlap = embedding_tokens[:, -overlap_gap:, ...]
            curr_overlap = window[:, :overlap_gap, ...]
            assert prev_overlap.shape == curr_overlap.shape, \
                f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
            if average:
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1)
            else:
                # Keep the overlap tokens from the previous window.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1)
        else:
            embedding_tokens = torch.cat([embedding_tokens, window], dim=1)
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
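# Minimal usage sketch (not part of the original module): shows how the windowing
# and merging helpers are expected to fit together. The window/overlap sizes and
# the random dummy embeddings below are illustrative assumptions, not values
# taken from any real pipeline.
if __name__ == "__main__":
    tokens = list(range(10))
    windows = sliding_windows(tokens, window_size=6, slice_interval=4)
    # -> [[0..5], [4..9]]; consecutive windows share window_size - slice_interval = 2 tokens.
    print(windows)

    # Re-indexing entity/relation annotations into the second window (word ids 4..9).
    parse_class = {"header": [[0, 1], [5, 6]]}
    parse_relation = [[0, 5], [5, 6]]
    print(entity_extraction_by_words(parse_class, 4, 9))  # {'header': [[], [1, 2]]}
    print(entity_linking_by_words(parse_relation, 4, 9))  # [[1, 2]]

    # Fake per-window embeddings shaped [batch, 1 (CLS) + valid tokens + 1 (SEP), hidden].
    hidden = 8
    patches = [torch.randn(1, 1 + len(w) + 1, hidden) for w in windows]
    valids = [len(w) for w in windows]
    overlaps = [2]  # tokens shared between window 0 and window 1
    merged = merged_token_embeddings(patches, overlaps, valids, average=True)
    print(merged.shape)  # expected: [1, 1 + 10 + 1, hidden]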