import torch
from torch import nn


class RelationExtractor(nn.Module):
    """Bilinear relation scorer: projects backbone features into per-relation
    query/key heads and scores every query-key pair."""

    def __init__(
        self,
        n_relations,
        backbone_hidden_size,
        head_hidden_size,
        head_p_dropout=0.1,
    ):
        super().__init__()

        self.n_relations = n_relations
        self.backbone_hidden_size = backbone_hidden_size
        self.head_hidden_size = head_hidden_size
        self.head_p_dropout = head_p_dropout

        self.drop = nn.Dropout(head_p_dropout)
        # Project backbone features into one query/key vector per relation type.
        self.q_net = nn.Linear(
            self.backbone_hidden_size, self.n_relations * self.head_hidden_size
        )
        self.k_net = nn.Linear(
            self.backbone_hidden_size, self.n_relations * self.head_hidden_size
        )

        # Learnable dummy node appended to the key sequence, adding one extra
        # candidate position per batch element.
        self.dummy_node = nn.Parameter(torch.Tensor(1, self.backbone_hidden_size))
        nn.init.normal_(self.dummy_node)

    def forward(self, h_q, h_k):
        # h_q, h_k: (seq_len, batch, backbone_hidden_size)
        h_q = self.q_net(self.drop(h_q))

        # Broadcast the dummy node across the batch and append it as an extra
        # key position: (seq_k + 1, batch, backbone_hidden_size).
        dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, h_k.size(1), 1)
        h_k = torch.cat([h_k, dummy_vec], dim=0)
        h_k = self.k_net(self.drop(h_k))

        # Split the projections into per-relation heads.
        head_q = h_q.view(
            h_q.size(0), h_q.size(1), self.n_relations, self.head_hidden_size
        )
        head_k = h_k.view(
            h_k.size(0), h_k.size(1), self.n_relations, self.head_hidden_size
        )

        # Dot product of every query with every key for each relation:
        # output shape (n_relations, batch, seq_q, seq_k + 1).
        relation_score = torch.einsum("ibnd,jbnd->nbij", (head_q, head_k))

        return relation_score
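

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. Input shapes
    # are assumed from the einsum "ibnd,jbnd->nbij", i.e. sequence-first
    # tensors of shape (seq_len, batch, backbone_hidden_size); the sizes below
    # are arbitrary.
    batch_size, seq_len, hidden = 2, 8, 768
    extractor = RelationExtractor(
        n_relations=1,
        backbone_hidden_size=hidden,
        head_hidden_size=64,
    )
    h = torch.randn(seq_len, batch_size, hidden)
    scores = extractor(h, h)
    # The extra key position comes from the appended dummy node.
    print(scores.shape)  # torch.Size([1, 2, 8, 9])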