# sbt-idp/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/utils.py
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import math


def sliding_windows(elements: list, window_size: int, slice_interval: int) -> list:
    """Split ``elements`` into overlapping windows of ``window_size``, stepping forward
    by ``slice_interval``; the last window simply takes whatever remains."""
    element_windows = []
    if len(elements) > window_size:
        max_step = math.ceil((len(elements) - window_size) / slice_interval)
        for i in range(0, max_step + 1):
            # element_windows.append(copy.deepcopy(elements[min(i, len(elements) - window_size): min(i+window_size, len(elements))]))
            if (i * slice_interval + window_size) >= len(elements):
                _window = copy.deepcopy(elements[i * slice_interval:])
            else:
                _window = copy.deepcopy(elements[i * slice_interval: i * slice_interval + window_size])
            element_windows.append(_window)
        return element_windows
    else:
        return [elements]
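# Worked example (illustrative values only, not from the original source):
#   sliding_windows(list(range(7)), window_size=4, slice_interval=2)
# steps through start offsets 0, 2, 4 and returns
#   [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6]]
# i.e. consecutive windows share window_size - slice_interval = 2 elements and the
# final window takes whatever is left.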


def sliding_windows_by_words(lwords: list, parse_class: dict, parse_relation: list, window_size: int, slice_interval: int) -> tuple:
    """Split word-level annotations into overlapping windows.

    Returns three parallel lists: the word windows (with word ids re-indexed to start
    at 0), the per-window entity annotations and the per-window relation pairs.
    Windows containing fewer than two words are skipped.
    """
    word_windows = []
    parse_class_windows = []
    parse_relation_windows = []
    if len(lwords) > window_size:
        max_step = math.ceil((len(lwords) - window_size) / slice_interval)
        for i in range(0, max_step + 1):
            # _word_window = copy.deepcopy(lwords[min(i*slice_interval, len(lwords) - window_size): min(i*slice_interval+window_size, len(lwords))])
            if (i * slice_interval + window_size) >= len(lwords):
                _word_window = copy.deepcopy(lwords[i * slice_interval:])
            else:
                _word_window = copy.deepcopy(lwords[i * slice_interval: i * slice_interval + window_size])
            if len(_word_window) < 2:
                continue
            first_word_id = _word_window[0]['word_id']
            last_word_id = _word_window[-1]['word_id']
            # assert (last_word_id - first_word_id == window_size - 1) or (first_word_id == 0 and last_word_id == len(lwords) - 1), [v['word_id'] for v in _word_window] #(last_word_id,first_word_id,len(lwords))
            # Word list: shift word ids into window coordinates
            for _word in _word_window:
                _word['word_id'] -= first_word_id
            # Entity extraction
            _class_window = entity_extraction_by_words(parse_class, first_word_id, last_word_id)
            # Entity linking
            _relation_window = entity_linking_by_words(parse_relation, first_word_id, last_word_id)
            word_windows.append(_word_window)
            parse_class_windows.append(_class_window)
            parse_relation_windows.append(_relation_window)
        return word_windows, parse_class_windows, parse_relation_windows
    else:
        return [lwords], [parse_class], [parse_relation]
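# Input layout, as far as it can be inferred from the code above (the field names
# 'header'/'answer' are hypothetical, used purely for illustration):
#   lwords         = [{'word_id': 0, ...}, {'word_id': 1, ...}, ...]
#   parse_class    = {'header': [[0, 1], [7]], 'answer': [[2, 3, 4]]}   # word-id groups per class
#   parse_relation = [[0, 2], [1, 3], ...]                              # (head, tail) word-id pairs
# Every returned window re-indexes word ids so that the window's first word has id 0.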


def entity_extraction_by_words(parse_class, first_word_id, last_word_id):
    """For every class, shift each annotated word group into window coordinates and keep
    only the ids that fall inside [first_word_id, last_word_id]."""
    _class_window = {k: [] for k in list(parse_class.keys())}
    for class_name, _parse_class in parse_class.items():
        for group in _parse_class:
            tmp = []
            for idw in group:
                idw -= first_word_id
                if 0 <= idw <= (last_word_id - first_word_id):
                    tmp.append(idw)
            _class_window[class_name].append(tmp)
    return _class_window
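# Worked example (illustrative values only): with first_word_id=2 and last_word_id=6,
#   entity_extraction_by_words({'total': [[3, 4], [10]]}, 2, 6) == {'total': [[1, 2], []]}
# Ids are shifted into window coordinates; a group lying entirely outside the window is
# still appended, as an empty list.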


def entity_linking_by_words(parse_relation, first_word_id, last_word_id):
    """Keep only the relation pairs whose word ids all fall inside the window, shifted
    into window coordinates."""
    _relation_window = []
    for pair in parse_relation:
        if all([0 <= idw - first_word_id <= (last_word_id - first_word_id) for idw in pair]):
            _relation_window.append([idw - first_word_id for idw in pair])
    return _relation_window
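# Worked example (illustrative values only): with first_word_id=2 and last_word_id=6,
#   entity_linking_by_words([[3, 4], [5, 9]], 2, 6) == [[1, 2]]
# A pair survives only if both of its word ids fall inside the window.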


def merged_token_embeddings(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Stitch per-window token embeddings back into a single sequence.

    For window i, only the first ``lvalids[i]`` positions after [CLS] are used.
    Consecutive windows overlap by ``loverlaps[i-1]`` positions; the overlap is either
    averaged or taken from the current window, depending on ``average``. The first and
    last positions of the first window are reused as the [CLS]/[SEP] embeddings.
    """
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = copy.deepcopy(lpatches[0][:, start_pos:end_pos, ...])
    cls_token = copy.deepcopy(lpatches[0][:, :1, ...])
    sep_token = copy.deepcopy(lpatches[0][:, -1:, ...])
    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]
        overlap_gap = copy.deepcopy(loverlaps[i - 1])
        window = copy.deepcopy(lpatches[i][:, start_pos:end_pos, ...])
        if overlap_gap != 0:
            prev_overlap = copy.deepcopy(embedding_tokens[:, -overlap_gap:, ...])
            curr_overlap = copy.deepcopy(window[:, :overlap_gap, ...])
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
            if average:
                # Average the embeddings of the overlapping positions.
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                # Keep the current window's embeddings for the overlapping positions.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], curr_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            embedding_tokens = torch.cat([embedding_tokens, window], dim=1)
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
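# Shape sketch (assumed, illustrative values only): with two windows of shape
# (1, 512, hidden), lvalids=[510, 200] and loverlaps=[100], the merged body has
# 510 + (200 - 100) = 610 token embeddings, so the returned tensor (with the first
# window's [CLS] and [SEP] positions re-attached) has shape (1, 612, hidden).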


def merged_token_embeddings2(lpatches: list, loverlaps: list, lvalids: list, average: bool) -> torch.Tensor:
    """Variant of merged_token_embeddings that operates on views instead of deep copies
    and, when ``average`` is False, keeps the previous window's embeddings for the
    overlapping positions rather than the current window's."""
    start_pos = 1
    end_pos = start_pos + lvalids[0]
    embedding_tokens = lpatches[0][:, start_pos:end_pos, ...]
    cls_token = lpatches[0][:, :1, ...]
    sep_token = lpatches[0][:, -1:, ...]
    for i in range(1, len(lpatches)):
        start_pos = 1
        end_pos = start_pos + lvalids[i]
        overlap_gap = loverlaps[i - 1]
        window = lpatches[i][:, start_pos:end_pos, ...]
        if overlap_gap != 0:
            prev_overlap = embedding_tokens[:, -overlap_gap:, ...]
            curr_overlap = window[:, :overlap_gap, ...]
            assert prev_overlap.shape == curr_overlap.shape, f"{prev_overlap.shape} # {curr_overlap.shape} with overlap: {overlap_gap}"
            if average:
                # Average the embeddings of the overlapping positions.
                avg_overlap = (prev_overlap + curr_overlap) / 2.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], avg_overlap, window[:, overlap_gap:, ...]], dim=1
                )
            else:
                # Keep the previous window's embeddings for the overlapping positions.
                embedding_tokens = torch.cat(
                    [embedding_tokens[:, :-overlap_gap, ...], prev_overlap, window[:, overlap_gap:, ...]], dim=1
                )
        else:
            embedding_tokens = torch.cat([embedding_tokens, window], dim=1)
    return torch.cat([cls_token, embedding_tokens, sep_token], dim=1)
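# Presumed usage of this module (a sketch, not taken from the original source): long
# documents are split with sliding_windows / sliding_windows_by_words, each window is
# encoded separately by the backbone, and the per-window token embeddings are stitched
# back together with merged_token_embeddings or merged_token_embeddings2 before
# decoding. loverlaps[i] is then the number of token positions shared by windows
# i and i + 1.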