sbt-idp/cope2n-ai-fi/common/AnyKey_Value/model/relation_extractor.py
2023-11-30 18:22:16 +07:00

49 lines
1.4 KiB
Python
Executable File

import torch
from torch import nn
class RelationExtractor(nn.Module):
    """Score pairwise relations between a set of queries and a set of keys.

    Each of ``n_relations`` relation types has its own query/key projection
    head of width ``head_hidden_size``. The score of a (query, key) pair under
    a relation is the dot product of their projections for that relation.

    A learnable "dummy" key vector is appended after the real keys, so every
    query is also scored against one extra key position (the last key index).
    NOTE(review): presumably this acts as a "no relation" target; confirm
    against the loss/decoding code that consumes the scores.

    Args:
        n_relations: Number of relation types (independent scoring heads).
        backbone_hidden_size: Feature size of the incoming ``h_q`` / ``h_k``.
        head_hidden_size: Projection width of each relation head.
        head_p_dropout: Dropout probability applied to the inputs (including
            the dummy key) before projection.
    """

    def __init__(
        self,
        n_relations,
        backbone_hidden_size,
        head_hidden_size,
        head_p_dropout=0.1,
    ):
        super().__init__()

        self.n_relations = n_relations
        self.backbone_hidden_size = backbone_hidden_size
        self.head_hidden_size = head_hidden_size
        self.head_p_dropout = head_p_dropout

        self.drop = nn.Dropout(head_p_dropout)
        # One Linear per side produces all relation heads at once; the output
        # is split into (n_relations, head_hidden_size) in forward().
        self.q_net = nn.Linear(
            self.backbone_hidden_size, self.n_relations * self.head_hidden_size
        )
        self.k_net = nn.Linear(
            self.backbone_hidden_size, self.n_relations * self.head_hidden_size
        )

        # torch.empty replaces the legacy torch.Tensor(*sizes) constructor; the
        # uninitialised values are immediately overwritten by normal_ below.
        self.dummy_node = nn.Parameter(torch.empty(1, self.backbone_hidden_size))
        nn.init.normal_(self.dummy_node)

    def forward(self, h_q, h_k):
        """Compute relation scores for every (query, key) pair.

        Args:
            h_q: Rank-3 query features indexed as (i, b, hidden) with
                ``hidden == backbone_hidden_size``. Dim 1 must match ``h_k``'s
                dim 1 (presumably the batch axis, i.e. sequence-first layout;
                TODO confirm with callers).
            h_k: Rank-3 key features indexed as (j, b, hidden).

        Returns:
            Tensor of shape ``(n_relations, b, h_q.size(0), h_k.size(0) + 1)``
            where ``score[n, b, i, j]`` is the dot product of the relation-``n``
            projections of query ``i`` and key ``j``. The last key index,
            ``j == h_k.size(0)``, corresponds to the learned dummy node.
        """
        h_q = self.q_net(self.drop(h_q))

        # Append the dummy node as one extra key position for every element of
        # dim 1. expand() is a zero-copy view; torch.cat performs the single
        # copy, and gradients still accumulate into dummy_node across dim 1.
        dummy_vec = self.dummy_node.unsqueeze(0).expand(1, h_k.size(1), -1)
        h_k = torch.cat([h_k, dummy_vec], dim=0)
        h_k = self.k_net(self.drop(h_k))

        # Split the fused projections into per-relation heads:
        # (len, b, n_relations * head_hidden_size) -> (len, b, n_relations, head_hidden_size)
        head_q = h_q.view(
            h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size
        )
        head_k = h_k.view(
            h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size
        )

        # score[n, b, i, j] = <head_q[i, b, n, :], head_k[j, b, n, :]>
        relation_score = torch.einsum("ibnd,jbnd->nbij", head_q, head_k)
        return relation_score