Skip to content

Commit

Permalink
WIP: DualE with NegSampling implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Mozart committed Mar 26, 2024
1 parent 2fa2dcc commit cd8935b
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 2 deletions.
1 change: 1 addition & 0 deletions dicee/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .clifford import Keci, KeciBase, CMult, DeCaL # noqa
from .pykeen_models import * # noqa
from .function_space import * # noqa
from .dualE import DualE
365 changes: 365 additions & 0 deletions dicee/models/dualE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from .base_model import BaseKGE
import numpy as np
from numpy.random import RandomState



# class OMult(BaseKGE):
# def __init__(self, args):
# super().__init__(args)
# self.name = 'OMult'

# @staticmethod
# def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6,
# emb_rel_e7):
# denominator = torch.sqrt(
# emb_rel_e0 ** 2 + emb_rel_e1 ** 2 + emb_rel_e2 ** 2 + emb_rel_e3 ** 2 + emb_rel_e4 ** 2
# + emb_rel_e5 ** 2 + emb_rel_e6 ** 2 + emb_rel_e7 ** 2)
# y0 = emb_rel_e0 / denominator
# y1 = emb_rel_e1 / denominator
# y2 = emb_rel_e2 / denominator
# y3 = emb_rel_e3 / denominator
# y4 = emb_rel_e4 / denominator
# y5 = emb_rel_e5 / denominator
# y6 = emb_rel_e6 / denominator
# y7 = emb_rel_e7 / denominator
# return y0, y1, y2, y3, y4, y5, y6, y7

# def score(self, head_ent_emb: torch.FloatTensor, rel_ent_emb: torch.FloatTensor, tail_ent_emb: torch.FloatTensor):
# # (2) Split (1) into real and imaginary parts.
# emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7 = torch.hsplit(
# head_ent_emb, 8)
# emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7 = torch.hsplit(
# rel_ent_emb,
# 8)
# if isinstance(self.normalize_relation_embeddings, IdentityClass):
# (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4,
# emb_rel_e5, emb_rel_e6, emb_rel_e7) = self.octonion_normalizer(emb_rel_e0,
# emb_rel_e1, emb_rel_e2, emb_rel_e3,
# emb_rel_e4, emb_rel_e5, emb_rel_e6,
# emb_rel_e7)

# emb_tail_e0, emb_tail_e1, emb_tail_e2, emb_tail_e3, emb_tail_e4, emb_tail_e5, emb_tail_e6, emb_tail_e7 = torch.hsplit(
# tail_ent_emb, 8)
# # (3) Octonion Multiplication
# e0, e1, e2, e3, e4, e5, e6, e7 = octonion_mul(
# O_1=(
# emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7),
# O_2=(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7))
# # (4)
# # (4.3) Inner product
# e0_score = (e0 * emb_tail_e0).sum(dim=1)
# e1_score = (e1 * emb_tail_e1).sum(dim=1)
# e2_score = (e2 * emb_tail_e2).sum(dim=1)
# e3_score = (e3 * emb_tail_e3).sum(dim=1)
# e4_score = (e4 * emb_tail_e4).sum(dim=1)
# e5_score = (e5 * emb_tail_e5).sum(dim=1)
# e6_score = (e6 * emb_tail_e6).sum(dim=1)
# e7_score = (e7 * emb_tail_e7).sum(dim=1)

# return e0_score + e1_score + e2_score + e3_score + e4_score + e5_score + e6_score + e7_score



class DualE(BaseKGE):
def __init__(self, args):
super().__init__(args)
self.name = 'DualE'
self.lmbda = 0.0

self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)

# self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
# self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)

# self.emb_1 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_2 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_3 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_4 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_5 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_6 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_7 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.emb_8 = nn.Embedding(self.config.entTotal, self.config.hidden_size)
# self.rel_1 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_2 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_3 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_4 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_5 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_6 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_7 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_8 = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.rel_w = nn.Embedding(self.config.relTotal, self.config.hidden_size)
# self.criterion = nn.Softplus()
# self.fc = nn.Linear(100, 50, bias=False)
# self.ent_dropout = torch.nn.Dropout(self.config.ent_dropout)
# self.rel_dropout = torch.nn.Dropout(self.config.rel_dropout)
# self.bn = torch.nn.BatchNorm1d(self.config.hidden_size)

# self.init_weights()

def init_weights(self):
if True:
r, i, j, k,r_1,i_1,j_1,k_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size)
r, i, j, k,r_1,i_1,j_1,k_1 = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k),\
torch.from_numpy(r_1), torch.from_numpy(i_1), torch.from_numpy(j_1), torch.from_numpy(k_1)
self.emb_1.weight.data = r.type_as(self.emb_1.weight.data)
self.emb_2.weight.data = i.type_as(self.emb_2.weight.data)
self.emb_3.weight.data = j.type_as(self.emb_3.weight.data)
self.emb_4.weight.data = k.type_as(self.emb_4.weight.data)
self.emb_5.weight.data = r_1.type_as(self.emb_5.weight.data)
self.emb_6.weight.data = i_1.type_as(self.emb_6.weight.data)
self.emb_7.weight.data = j_1.type_as(self.emb_7.weight.data)
self.emb_8.weight.data = k_1.type_as(self.emb_8.weight.data)

s, x, y, z,s_1,x_1,y_1,z_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size)
s, x, y, z,s_1,x_1,y_1,z_1 = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z), \
torch.from_numpy(s_1), torch.from_numpy(x_1), torch.from_numpy(y_1), torch.from_numpy(z_1)
self.rel_1.weight.data = s.type_as(self.rel_1.weight.data)
self.rel_2.weight.data = x.type_as(self.rel_2.weight.data)
self.rel_3.weight.data = y.type_as(self.rel_3.weight.data)
self.rel_4.weight.data = z.type_as(self.rel_4.weight.data)
self.rel_5.weight.data = s_1.type_as(self.rel_5.weight.data)
self.rel_6.weight.data = x_1.type_as(self.rel_6.weight.data)
self.rel_7.weight.data = y_1.type_as(self.rel_7.weight.data)
self.rel_8.weight.data = z_1.type_as(self.rel_8.weight.data)
nn.init.xavier_uniform_(self.rel_w.weight.data)
else:
nn.init.xavier_uniform_(self.emb_1.weight.data)
nn.init.xavier_uniform_(self.emb_2.weight.data)
nn.init.xavier_uniform_(self.emb_3.weight.data)
nn.init.xavier_uniform_(self.emb_4.weight.data)
nn.init.xavier_uniform_(self.emb_5.weight.data)
nn.init.xavier_uniform_(self.emb_6.weight.data)
nn.init.xavier_uniform_(self.emb_7.weight.data)
nn.init.xavier_uniform_(self.emb_8.weight.data)
nn.init.xavier_uniform_(self.rel_1.weight.data)
nn.init.xavier_uniform_(self.rel_2.weight.data)
nn.init.xavier_uniform_(self.rel_3.weight.data)
nn.init.xavier_uniform_(self.rel_4.weight.data)
nn.init.xavier_uniform_(self.rel_5.weight.data)
nn.init.xavier_uniform_(self.rel_6.weight.data)
nn.init.xavier_uniform_(self.rel_7.weight.data)
nn.init.xavier_uniform_(self.rel_8.weight.data)



#Calculate the Dual Hamiltonian product
def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3):

h_0=a_0*c_0-a_1*c_1-a_2*c_2-a_3*c_3
h1_0=a_0*d_0+b_0*c_0-a_1*d_1-b_1*c_1-a_2*d_2-b_2*c_2-a_3*d_3-b_3*c_3
h_1=a_0*c_1+a_1*c_0+a_2*c_3-a_3*c_2
h1_1=a_0*d_1+b_0*c_1+a_1*d_0+b_1*c_0+a_2*d_3+b_2*c_3-a_3*d_2-b_3*c_2
h_2=a_0*c_2-a_1*c_3+a_2*c_0+a_3*c_1
h1_2=a_0*d_2+b_0*c_2-a_1*d_3-b_1*c_3+a_2*d_0+b_2*c_0+a_3*d_1+b_3*c_1
h_3=a_0*c_3+a_1*c_2-a_2*c_1+a_3*c_0
h1_3=a_0*d_3+b_0*c_3+a_1*d_2+b_1*c_2-a_2*d_1-b_2*c_1+a_3*d_0+b_3*c_0

return (h_0,h_1,h_2,h_3,h1_0,h1_1,h1_2,h1_3)

#Normalization of relationship embedding
def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8):
denominator_0 = r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2
denominator_1 = torch.sqrt(denominator_0)
#denominator_2 = torch.sqrt(r_5 ** 2 + r_6 ** 2 + r_7 ** 2 + r_8 ** 2)
deno_cross = r_5 * r_1 + r_6 * r_2 + r_7 * r_3 + r_8 * r_4

r_5 = r_5 - deno_cross / denominator_0 * r_1
r_6 = r_6 - deno_cross / denominator_0 * r_2
r_7 = r_7 - deno_cross / denominator_0 * r_3
r_8 = r_8 - deno_cross / denominator_0 * r_4

r_1 = r_1 / denominator_1
r_2 = r_2 / denominator_1
r_3 = r_3 / denominator_1
r_4 = r_4 / denominator_1
#r_5 = r_5 / denominator_2
#r_6 = r_6 / denominator_2
#r_7 = r_7 / denominator_2
#r_8 = r_8 / denominator_2
return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8

#Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity
def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ):

r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )

o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)


score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t
+ o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t)

return -torch.sum(score_r, -1)



# def loss(self, score, regul, regul2):
# return (
# torch.mean(self.criterion(score * self.batch_y)) + self.lmbda * regul + self.lmbda * regul2
# )

def forward_triples(self, idx_triple):

head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple)

e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8)
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8)
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8)



score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )

regul = (torch.mean(torch.abs(e_1_h) ** 2)
+ torch.mean(torch.abs(e_2_h) ** 2)
+ torch.mean(torch.abs(e_3_h) ** 2)
+ torch.mean(torch.abs(e_4_h) ** 2)
+ torch.mean(torch.abs(e_5_h) ** 2)
+ torch.mean(torch.abs(e_6_h) ** 2)
+ torch.mean(torch.abs(e_7_h) ** 2)
+ torch.mean(torch.abs(e_8_h) ** 2)
+ torch.mean(torch.abs(e_1_t) ** 2)
+ torch.mean(torch.abs(e_2_t) ** 2)
+ torch.mean(torch.abs(e_3_t) ** 2)
+ torch.mean(torch.abs(e_4_t) ** 2)
+ torch.mean(torch.abs(e_5_t) ** 2)
+ torch.mean(torch.abs(e_6_t) ** 2)
+ torch.mean(torch.abs(e_7_t) ** 2)
+ torch.mean(torch.abs(e_8_t) ** 2)
)
regul2 = (torch.mean(torch.abs(r_1) ** 2)
+ torch.mean(torch.abs(r_2) ** 2)
+ torch.mean(torch.abs(r_3) ** 2)
+ torch.mean(torch.abs(r_4) ** 2)
+ torch.mean(torch.abs(r_5) ** 2)
+ torch.mean(torch.abs(r_6) ** 2)
+ torch.mean(torch.abs(r_7) ** 2)
+ torch.mean(torch.abs(r_8) ** 2))

return score #self.loss(score, regul, regul2)

def predict(self):
e_1_h = self.emb_1(self.batch_h)
e_2_h = self.emb_2(self.batch_h)
e_3_h = self.emb_3(self.batch_h)
e_4_h = self.emb_4(self.batch_h)
e_5_h = self.emb_5(self.batch_h)
e_6_h = self.emb_6(self.batch_h)
e_7_h = self.emb_7(self.batch_h)
e_8_h = self.emb_8(self.batch_h)

e_1_t = self.emb_1(self.batch_t)
e_2_t = self.emb_2(self.batch_t)
e_3_t = self.emb_3(self.batch_t)
e_4_t = self.emb_4(self.batch_t)
e_5_t = self.emb_5(self.batch_t)
e_6_t = self.emb_6(self.batch_t)
e_7_t = self.emb_7(self.batch_t)
e_8_t = self.emb_8(self.batch_t)

r_1 = self.rel_1(self.batch_r)
r_2 = self.rel_2(self.batch_r)
r_3 = self.rel_3(self.batch_r)
r_4 = self.rel_4(self.batch_r)
r_5 = self.rel_5(self.batch_r)
r_6 = self.rel_6(self.batch_r)
r_7 = self.rel_7(self.batch_r)
r_8 = self.rel_8(self.batch_r)

score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h,
e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t,
r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )
return score.cpu().data.numpy()




def quaternion_init(self, in_features, out_features, criterion='he'):

fan_in = in_features
fan_out = out_features

if criterion == 'glorot':
s = 1. / np.sqrt(2 * (fan_in + fan_out))
elif criterion == 'he':
s = 1. / np.sqrt(2 * fan_in)
else:
raise ValueError('Invalid criterion: ', criterion)
rng = RandomState(2020)

# Generating randoms and purely imaginary quaternions :
kernel_shape = (in_features, out_features)

number_of_weights = np.prod(kernel_shape)
v_i = np.random.uniform(0.0, 1.0, number_of_weights)
v_j = np.random.uniform(0.0, 1.0, number_of_weights)
v_k = np.random.uniform(0.0, 1.0, number_of_weights)

# Purely imaginary quaternions unitary
for i in range(0, number_of_weights):
norm = np.sqrt(v_i[i] ** 2 + v_j[i] ** 2 + v_k[i] ** 2) + 0.0001
v_i[i] /= norm
v_j[i] /= norm
v_k[i] /= norm
v_i = v_i.reshape(kernel_shape)
v_j = v_j.reshape(kernel_shape)
v_k = v_k.reshape(kernel_shape)

modulus = rng.uniform(low=-s, high=s, size=kernel_shape)


# Calculate the three parts about t
kernel_shape1 = (in_features, out_features)
number_of_weights1 = np.prod(kernel_shape1)
t_i = np.random.uniform(0.0, 1.0, number_of_weights1)
t_j = np.random.uniform(0.0, 1.0, number_of_weights1)
t_k = np.random.uniform(0.0, 1.0, number_of_weights1)

# Purely imaginary quaternions unitary
for i in range(0, number_of_weights1):
norm1 = np.sqrt(t_i[i] ** 2 + t_j[i] ** 2 + t_k[i] ** 2) + 0.0001
t_i[i] /= norm1
t_j[i] /= norm1
t_k[i] /= norm1
t_i = t_i.reshape(kernel_shape1)
t_j = t_j.reshape(kernel_shape1)
t_k = t_k.reshape(kernel_shape1)
tmp_t = rng.uniform(low=-s, high=s, size=kernel_shape1)


phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)
phase1 = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape1)

weight_r = modulus * np.cos(phase)
weight_i = modulus * v_i * np.sin(phase)
weight_j = modulus * v_j * np.sin(phase)
weight_k = modulus * v_k * np.sin(phase)

wt_i = tmp_t * t_i * np.sin(phase1)
wt_j = tmp_t * t_j * np.sin(phase1)
wt_k = tmp_t * t_k * np.sin(phase1)

i_0=weight_r
i_1=weight_i
i_2=weight_j
i_3=weight_k
i_4=(-wt_i*weight_i-wt_j*weight_j-wt_k*weight_k)/2
i_5=(wt_i*weight_r+wt_j*weight_k-wt_k*weight_j)/2
i_6=(-wt_i*weight_k+wt_j*weight_r+wt_k*weight_i)/2
i_7=(wt_i*weight_j-wt_j*weight_i+wt_k*weight_r)/2


return (i_0,i_1,i_2,i_3,i_4,i_5,i_6,i_7)
2 changes: 1 addition & 1 deletion dicee/scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_default_arguments(description=None):
parser.add_argument("--model", type=str,
default="Keci",
choices=["ComplEx", "Keci", "ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult",
"OMult", "Shallom", "DistMult", "TransE", "DeCaL",
"OMult", "Shallom", "DistMult", "TransE", "DualE",
"BytE",
"Pykeen_MuRE", "Pykeen_QuatE", "Pykeen_DistMult", "Pykeen_BoxE", "Pykeen_CP",
"Pykeen_HolE", "Pykeen_ProjE", "Pykeen_RotatE",
Expand Down
5 changes: 4 additions & 1 deletion dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import datetime
from typing import Tuple, List
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL,\
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL, DualE,\
ComplEx, AConEx, AConvO, AConvQ, ConvQ, ConvO, ConEx, QMult, OMult, Shallom, LFMult
from .models.pykeen_models import PykeenKGE
from .models.transformers import BytE
Expand Down Expand Up @@ -421,6 +421,9 @@ def intialize_model(args: dict,verbose=0) -> Tuple[object, str]:
elif model_name == 'DeCaL':
model =DeCaL(args=args)
form_of_labelling = 'EntityPrediction'
elif model_name == 'DualE':
model =DualE(args=args)
form_of_labelling = 'EntityPrediction'
else:
raise ValueError(f"--model_name: {model_name} is not found.")
return model, form_of_labelling
Expand Down

0 comments on commit cd8935b

Please sign in to comment.