diff --git a/dicee/models/dualE.py b/dicee/models/dualE.py
index 0d654b6c..f40a8a0e 100644
--- a/dicee/models/dualE.py
+++ b/dicee/models/dualE.py
@@ -1,69 +1,5 @@
 import torch
-import torch.autograd as autograd
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torch.autograd import Variable
 from .base_model import BaseKGE
-import numpy as np
-from numpy.random import RandomState
-
-
-
-# class OMult(BaseKGE):
-#     def __init__(self, args):
-#         super().__init__(args)
-#         self.name = 'OMult'
-
-#     @staticmethod
-#     def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6,
-#                             emb_rel_e7):
-#         denominator = torch.sqrt(
-#             emb_rel_e0 ** 2 + emb_rel_e1 ** 2 + emb_rel_e2 ** 2 + emb_rel_e3 ** 2 + emb_rel_e4 ** 2
-#             + emb_rel_e5 ** 2 + emb_rel_e6 ** 2 + emb_rel_e7 ** 2)
-#         y0 = emb_rel_e0 / denominator
-#         y1 = emb_rel_e1 / denominator
-#         y2 = emb_rel_e2 / denominator
-#         y3 = emb_rel_e3 / denominator
-#         y4 = emb_rel_e4 / denominator
-#         y5 = emb_rel_e5 / denominator
-#         y6 = emb_rel_e6 / denominator
-#         y7 = emb_rel_e7 / denominator
-#         return y0, y1, y2, y3, y4, y5, y6, y7
-
-#     def score(self, head_ent_emb: torch.FloatTensor, rel_ent_emb: torch.FloatTensor, tail_ent_emb: torch.FloatTensor):
-#         # (2) Split (1) into real and imaginary parts.
-#         emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7 = torch.hsplit(
-#             head_ent_emb, 8)
-#         emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7 = torch.hsplit(
-#             rel_ent_emb,
-#             8)
-#         if isinstance(self.normalize_relation_embeddings, IdentityClass):
-#             (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4,
-#              emb_rel_e5, emb_rel_e6, emb_rel_e7) = self.octonion_normalizer(emb_rel_e0,
-#                                                                            emb_rel_e1, emb_rel_e2, emb_rel_e3,
-#                                                                            emb_rel_e4, emb_rel_e5, emb_rel_e6,
-#                                                                            emb_rel_e7)
-
-#         emb_tail_e0, emb_tail_e1, emb_tail_e2, emb_tail_e3, emb_tail_e4, emb_tail_e5, emb_tail_e6, emb_tail_e7 = torch.hsplit(
-#             tail_ent_emb, 8)
-#         # (3) Octonion Multiplication
-#         e0, e1, e2, e3, e4, e5, e6, e7 = octonion_mul(
-#             O_1=(
-#                 emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7),
-#             O_2=(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7))
-#         # (4)
-#         # (4.3) Inner product
-#         e0_score = (e0 * emb_tail_e0).sum(dim=1)
-#         e1_score = (e1 * emb_tail_e1).sum(dim=1)
-#         e2_score = (e2 * emb_tail_e2).sum(dim=1)
-#         e3_score = (e3 * emb_tail_e3).sum(dim=1)
-#         e4_score = (e4 * emb_tail_e4).sum(dim=1)
-#         e5_score = (e5 * emb_tail_e5).sum(dim=1)
-#         e6_score = (e6 * emb_tail_e6).sum(dim=1)
-#         e7_score = (e7 * emb_tail_e7).sum(dim=1)
-
-#         return e0_score + e1_score + e2_score + e3_score + e4_score + e5_score + e6_score + e7_score
 
 
 
@@ -71,83 +7,9 @@ class DualE(BaseKGE):
     def __init__(self, args):
         super().__init__(args)
         self.name = 'DualE'
-        self.lmbda = 0.0
-
         self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
         self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
-
-        # self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
-        # self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
-
-        # self.emb_1 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_2 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_3 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_4 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_5 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_6 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_7 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.emb_8 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
-        # self.rel_1 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_2 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_3 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_4 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_5 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_6 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_7 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_8 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.rel_w = nn.Embedding(self.config.relTotal, self.config.hidden_size)
-        # self.criterion = nn.Softplus()
-        # self.fc = nn.Linear(100, 50, bias=False)
-        # self.ent_dropout = torch.nn.Dropout(self.config.ent_dropout)
-        # self.rel_dropout = torch.nn.Dropout(self.config.rel_dropout)
-        # self.bn = torch.nn.BatchNorm1d(self.config.hidden_size)
-
-        # self.init_weights()
-
-    def init_weights(self):
-        if True:
-            r, i, j, k,r_1,i_1,j_1,k_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size)
-            r, i, j, k,r_1,i_1,j_1,k_1 = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k),\
-                torch.from_numpy(r_1), torch.from_numpy(i_1), torch.from_numpy(j_1), torch.from_numpy(k_1)
-            self.emb_1.weight.data = r.type_as(self.emb_1.weight.data)
-            self.emb_2.weight.data = i.type_as(self.emb_2.weight.data)
-            self.emb_3.weight.data = j.type_as(self.emb_3.weight.data)
-            self.emb_4.weight.data = k.type_as(self.emb_4.weight.data)
-            self.emb_5.weight.data = r_1.type_as(self.emb_5.weight.data)
-            self.emb_6.weight.data = i_1.type_as(self.emb_6.weight.data)
-            self.emb_7.weight.data = j_1.type_as(self.emb_7.weight.data)
-            self.emb_8.weight.data = k_1.type_as(self.emb_8.weight.data)
-
-            s, x, y, z,s_1,x_1,y_1,z_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size)
-            s, x, y, z,s_1,x_1,y_1,z_1 = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z), \
-                torch.from_numpy(s_1), torch.from_numpy(x_1), torch.from_numpy(y_1), torch.from_numpy(z_1)
-            self.rel_1.weight.data = s.type_as(self.rel_1.weight.data)
-            self.rel_2.weight.data = x.type_as(self.rel_2.weight.data)
-            self.rel_3.weight.data = y.type_as(self.rel_3.weight.data)
-            self.rel_4.weight.data = z.type_as(self.rel_4.weight.data)
-            self.rel_5.weight.data = s_1.type_as(self.rel_5.weight.data)
-            self.rel_6.weight.data = x_1.type_as(self.rel_6.weight.data)
-            self.rel_7.weight.data = y_1.type_as(self.rel_7.weight.data)
-            self.rel_8.weight.data = z_1.type_as(self.rel_8.weight.data)
-            nn.init.xavier_uniform_(self.rel_w.weight.data)
-        else:
-            nn.init.xavier_uniform_(self.emb_1.weight.data)
-            nn.init.xavier_uniform_(self.emb_2.weight.data)
-            nn.init.xavier_uniform_(self.emb_3.weight.data)
-            nn.init.xavier_uniform_(self.emb_4.weight.data)
-            nn.init.xavier_uniform_(self.emb_5.weight.data)
-            nn.init.xavier_uniform_(self.emb_6.weight.data)
-            nn.init.xavier_uniform_(self.emb_7.weight.data)
-            nn.init.xavier_uniform_(self.emb_8.weight.data)
-            nn.init.xavier_uniform_(self.rel_1.weight.data)
-            nn.init.xavier_uniform_(self.rel_2.weight.data)
-            nn.init.xavier_uniform_(self.rel_3.weight.data)
-            nn.init.xavier_uniform_(self.rel_4.weight.data)
-            nn.init.xavier_uniform_(self.rel_5.weight.data)
-            nn.init.xavier_uniform_(self.rel_6.weight.data)
-            nn.init.xavier_uniform_(self.rel_7.weight.data)
-            nn.init.xavier_uniform_(self.rel_8.weight.data)
-
+        self.num_ent = self.num_entities
 
 
     #Calculate the Dual Hamiltonian product
@@ -201,165 +63,68 @@ def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
                    + o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)
         return -torch.sum(score_r, -1)
-
+    def kvsall_score(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
+                  e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
+                  r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ):
+
+        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
+
+        o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
+                                                             r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
+
-    # def loss(self, score, regul, regul2):
-    #     return (
-    #             torch.mean(self.criterion(score * self.batch_y)) + self.lmbda * regul + self.lmbda * regul2
-    #     )
+        score_r = torch.mm(o_1, e_1_t) + torch.mm(o_2 ,e_2_t) + torch.mm(o_3, e_3_t) + torch.mm(o_4, e_4_t)\
+                  + torch.mm(o_5, e_5_t) + torch.mm(o_6, e_6_t) + torch.mm(o_7, e_7_t) +torch.mm( o_8 , e_8_t)
+
+        return -score_r
+
     def forward_triples(self, idx_triple):
         head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple)
+        e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)
+        e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8)
+        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8)
-
-
         score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
                            e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
                            r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
-        regul = (torch.mean(torch.abs(e_1_h) ** 2)
-                 + torch.mean(torch.abs(e_2_h) ** 2)
-                 + torch.mean(torch.abs(e_3_h) ** 2)
-                 + torch.mean(torch.abs(e_4_h) ** 2)
-                 + torch.mean(torch.abs(e_5_h) ** 2)
-                 + torch.mean(torch.abs(e_6_h) ** 2)
-                 + torch.mean(torch.abs(e_7_h) ** 2)
-                 + torch.mean(torch.abs(e_8_h) ** 2)
-                 + torch.mean(torch.abs(e_1_t) ** 2)
-                 + torch.mean(torch.abs(e_2_t) ** 2)
-                 + torch.mean(torch.abs(e_3_t) ** 2)
-                 + torch.mean(torch.abs(e_4_t) ** 2)
-                 + torch.mean(torch.abs(e_5_t) ** 2)
-                 + torch.mean(torch.abs(e_6_t) ** 2)
-                 + torch.mean(torch.abs(e_7_t) ** 2)
-                 + torch.mean(torch.abs(e_8_t) ** 2)
-                 )
-        regul2 = (torch.mean(torch.abs(r_1) ** 2)
-                  + torch.mean(torch.abs(r_2) ** 2)
-                  + torch.mean(torch.abs(r_3) ** 2)
-                  + torch.mean(torch.abs(r_4) ** 2)
-                  + torch.mean(torch.abs(r_5) ** 2)
-                  + torch.mean(torch.abs(r_6) ** 2)
-                  + torch.mean(torch.abs(r_7) ** 2)
-                  + torch.mean(torch.abs(r_8) ** 2))
-
-        return score #self.loss(score, regul, regul2)
-
-    def predict(self):
-        e_1_h = self.emb_1(self.batch_h)
-        e_2_h = self.emb_2(self.batch_h)
-        e_3_h = self.emb_3(self.batch_h)
-        e_4_h = self.emb_4(self.batch_h)
-        e_5_h = self.emb_5(self.batch_h)
-        e_6_h = self.emb_6(self.batch_h)
-        e_7_h = self.emb_7(self.batch_h)
-        e_8_h = self.emb_8(self.batch_h)
-
-        e_1_t = self.emb_1(self.batch_t)
-        e_2_t = self.emb_2(self.batch_t)
-        e_3_t = self.emb_3(self.batch_t)
-        e_4_t = self.emb_4(self.batch_t)
-        e_5_t = self.emb_5(self.batch_t)
-        e_6_t = self.emb_6(self.batch_t)
-        e_7_t = self.emb_7(self.batch_t)
-        e_8_t = self.emb_8(self.batch_t)
-
-        r_1 = self.rel_1(self.batch_r)
-        r_2 = self.rel_2(self.batch_r)
-        r_3 = self.rel_3(self.batch_r)
-        r_4 = self.rel_4(self.batch_r)
-        r_5 = self.rel_5(self.batch_r)
-        r_6 = self.rel_6(self.batch_r)
-        r_7 = self.rel_7(self.batch_r)
-        r_8 = self.rel_8(self.batch_r)
-
-        score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
-                           e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
-                           r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
-        return score.cpu().data.numpy()
-
-
-
-
-    def quaternion_init(self, in_features, out_features, criterion='he'):
-
-        fan_in = in_features
-        fan_out = out_features
-
-        if criterion == 'glorot':
-            s = 1. / np.sqrt(2 * (fan_in + fan_out))
-        elif criterion == 'he':
-            s = 1. / np.sqrt(2 * fan_in)
-        else:
-            raise ValueError('Invalid criterion: ', criterion)
-        rng = RandomState(2020)
-
-        # Generating randoms and purely imaginary quaternions :
-        kernel_shape = (in_features, out_features)
-
-        number_of_weights = np.prod(kernel_shape)
-        v_i = np.random.uniform(0.0, 1.0, number_of_weights)
-        v_j = np.random.uniform(0.0, 1.0, number_of_weights)
-        v_k = np.random.uniform(0.0, 1.0, number_of_weights)
+
+        return score
+
-        # Purely imaginary quaternions unitary
-        for i in range(0, number_of_weights):
-            norm = np.sqrt(v_i[i] ** 2 + v_j[i] ** 2 + v_k[i] ** 2) + 0.0001
-            v_i[i] /= norm
-            v_j[i] /= norm
-            v_k[i] /= norm
-        v_i = v_i.reshape(kernel_shape)
-        v_j = v_j.reshape(kernel_shape)
-        v_k = v_k.reshape(kernel_shape)
-        modulus = rng.uniform(low=-s, high=s, size=kernel_shape)
+    def forward_k_vs_all(self,x):
+        # (1) Retrieve embeddings & Apply Dropout & Normalization.
+        head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)
+
+        e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)
-        # Calculate the three parts about t
-        kernel_shape1 = (in_features, out_features)
-        number_of_weights1 = np.prod(kernel_shape1)
-        t_i = np.random.uniform(0.0, 1.0, number_of_weights1)
-        t_j = np.random.uniform(0.0, 1.0, number_of_weights1)
-        t_k = np.random.uniform(0.0, 1.0, number_of_weights1)
+        r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_ent_emb, 8)
-        # Purely imaginary quaternions unitary
-        for i in range(0, number_of_weights1):
-            norm1 = np.sqrt(t_i[i] ** 2 + t_j[i] ** 2 + t_k[i] ** 2) + 0.0001
-            t_i[i] /= norm1
-            t_j[i] /= norm1
-            t_k[i] /= norm1
-        t_i = t_i.reshape(kernel_shape1)
-        t_j = t_j.reshape(kernel_shape1)
-        t_k = t_k.reshape(kernel_shape1)
-        tmp_t = rng.uniform(low=-s, high=s, size=kernel_shape1)
+        e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(self.entity_embeddings.weight, 8)
+        e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = self.T(e_1_t), self.T(e_2_t), self.T(e_3_t),\
+            self.T(e_4_t), self.T(e_5_t), self.T(e_6_t), self.T(e_7_t), self.T(e_8_t)
-        phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)
-        phase1 = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape1)
+        score = self.kvsall_score(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
+                                  e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
+                                  r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
-        weight_r = modulus * np.cos(phase)
-        weight_i = modulus * v_i * np.sin(phase)
-        weight_j = modulus * v_j * np.sin(phase)
-        weight_k = modulus * v_k * np.sin(phase)
-        wt_i = tmp_t * t_i * np.sin(phase1)
-        wt_j = tmp_t * t_j * np.sin(phase1)
-        wt_k = tmp_t * t_k * np.sin(phase1)
+        return score
+
+    def T(self, x):
+
+        return x.transpose(1, 0)
-        i_0=weight_r
-        i_1=weight_i
-        i_2=weight_j
-        i_3=weight_k
-        i_4=(-wt_i*weight_i-wt_j*weight_j-wt_k*weight_k)/2
-        i_5=(wt_i*weight_r+wt_j*weight_k-wt_k*weight_j)/2
-        i_6=(-wt_i*weight_k+wt_j*weight_r+wt_k*weight_i)/2
-        i_7=(wt_i*weight_j-wt_j*weight_i+wt_k*weight_r)/2
+
-        return (i_0,i_1,i_2,i_3,i_4,i_5,i_6,i_7)
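
For reference, a minimal sketch of the KvsAll scoring pattern that the new kvsall_score/forward_k_vs_all/T methods rely on: split head, relation, and the full entity table into 8 components, transpose each entity component, and score every entity with one matrix product per component. The tensor sizes are illustrative, and the elementwise product below is only a stand-in for the model's actual Dual Hamiltonian product (self._omult); it is a sketch, not the implementation.

import torch

batch_size, dim, num_entities = 4, 32, 100         # dim must be divisible by 8
head = torch.randn(batch_size, dim)                # (B, d) head embeddings
rel = torch.randn(batch_size, dim)                 # (B, d) relation embeddings
all_entities = torch.randn(num_entities, dim)      # (E, d) full entity embedding table

h = torch.hsplit(head, 8)                          # 8 components of shape (B, d/8)
r = torch.hsplit(rel, 8)
t = [c.transpose(1, 0) for c in torch.hsplit(all_entities, 8)]   # 8 x (d/8, E), as self.T does

# Stand-in for self._omult (Dual Hamiltonian product); an elementwise product keeps
# the sketch short but does NOT mix the 8 components the way the real model does.
o = [h[i] * r[i] for i in range(8)]

# KvsAll: one matmul per component scores every candidate tail at once -> (B, E)
scores = sum(torch.mm(o[i], t[i]) for i in range(8))
print(scores.shape)                                # torch.Size([4, 100])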