CMult removed
Demirrr committed Oct 14, 2024
1 parent adf77d4 commit 9c4ccd4
Showing 2 changed files with 4 additions and 224 deletions.
226 changes: 3 additions & 223 deletions dicee/models/clifford.py
@@ -2,227 +2,6 @@
import torch


class CMult(BaseKGE):
"""
Cl_(0,0) => real numbers.
Cl_(0,1) => complex numbers.
A multivector \mathbf{a} = a_0 + a_1 e_1
A multivector \mathbf{b} = b_0 + b_1 e_1
With e_1^2 = -1, their multiplication is isomorphic to the product of two complex numbers:
\mathbf{a} \times \mathbf{b} = a_0 b_0 + a_0 b_1 e_1 + a_1 b_0 e_1 + a_1 b_1 e_1 e_1
= (a_0 b_0 - a_1 b_1) + (a_0 b_1 + a_1 b_0) e_1
Cl_(2,0) =>
A multivector \mathbf{a} = a_0 + a_1 e_1 + a_2 e_2 + a_{12} e_1 e_2
A multivector \mathbf{b} = b_0 + b_1 e_1 + b_2 e_2 + b_{12} e_1 e_2
With e_1^2 = e_2^2 = +1:
\mathbf{a} \times \mathbf{b} = (a_0 b_0 + a_1 b_1 + a_2 b_2 - a_{12} b_{12})
+ (a_0 b_1 + a_1 b_0 - a_2 b_{12} + a_{12} b_2) e_1
+ (a_0 b_2 + a_1 b_{12} + a_2 b_0 - a_{12} b_1) e_2
+ (a_0 b_{12} + a_1 b_2 - a_2 b_1 + a_{12} b_0) e_1 e_2
Cl_(0,2) => quaternions.
"""

def __init__(self, args):
super().__init__(args)
self.name = 'CMult'
self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim)
self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim)
self.param_init(self.entity_embeddings.weight.data)
self.param_init(self.relation_embeddings.weight.data)
self.p = self.args['p'] if self.args['p'] is not None else 0
self.q = self.args['q'] if self.args['q'] is not None else 0
print(f'\tp:{self.p}\tq:{self.q}')
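# Note: clifford_mul partitions each d-dimensional embedding into 2^(p+q)
# equal blade components via torch.hsplit, so embedding_dim must be divisible
# by 2^(p+q) (i.e. by 2, 4 or 8 for the cases implemented below).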

def clifford_mul(self, x: torch.FloatTensor, y: torch.FloatTensor, p: int, q: int) -> tuple:
"""
Clifford multiplication in Cl_{p,q}(\mathbb{R}):
e_i^2 = +1 for 1 <= i <= p
e_j^2 = -1 for p < j <= p + q
e_i e_j = -e_j e_i for i \neq j
Parameters
----------
x: torch.FloatTensor with (n, d) shape
y: torch.FloatTensor with (n, d) shape
p: a non-negative integer, p >= 0
q: a non-negative integer, q >= 0
Returns
-------
A tuple holding the components of the product multivector
(a single torch.FloatTensor when p == q == 0).
"""

if p == q == 0:
return x * y
elif (p == 1 and q == 0) or (p == 0 and q == 1):
# Basis {1, e1}: e_1^2 = +1 if (p, q) == (1, 0), e_1^2 = -1 if (p, q) == (0, 1)
a0, a1 = torch.hsplit(x, 2)
b0, b1 = torch.hsplit(y, 2)
if p == 1 and q == 0:
ab0 = a0 * b0 + a1 * b1
ab1 = a0 * b1 + a1 * b0
else:
ab0 = a0 * b0 - a1 * b1
ab1 = a0 * b1 + a1 * b0
return ab0, ab1
elif (p == 2 and q == 0) or (p == 0 and q == 2):
a0, a1, a2, a12 = torch.hsplit(x, 4)
b0, b1, b2, b12 = torch.hsplit(y, 4)
if p == 2 and q == 0:
ab0 = a0 * b0 + a1 * b1 + a2 * b2 - a12 * b12
ab1 = a0 * b1 + a1 * b0 - a2 * b12 + a12 * b2
ab2 = a0 * b2 + a1 * b12 + a2 * b0 - a12 * b1
ab12 = a0 * b12 + a1 * b2 - a2 * b1 + a12 * b0
else:
ab0 = a0 * b0 - a1 * b1 - a2 * b2 - a12 * b12
ab1 = a0 * b1 + a1 * b0 + a2 * b12 - a12 * b2
ab2 = a0 * b2 - a1 * b12 + a2 * b0 + a12 * b1
ab12 = a0 * b12 + a1 * b2 - a2 * b1 + a12 * b0
return ab0, ab1, ab2, ab12
elif p == 1 and q == 1:
a0, a1, a2, a12 = torch.hsplit(x, 4)
b0, b1, b2, b12 = torch.hsplit(y, 4)

ab0 = a0 * b0 + a1 * b1 - a2 * b2 + a12 * b12
ab1 = a0 * b1 + a1 * b0 + a2 * b12 - a12 * b2
ab2 = a0 * b2 + a1 * b12 + a2 * b0 - a12 * b1
ab12 = a0 * b12 + a1 * b2 - a2 * b1 + a12 * b0
return ab0, ab1, ab2, ab12
elif p == 3 and q == 0:
# Cl_{3,0}; Cl_{0,3} is not implemented
a0, a1, a2, a3, a12, a13, a23, a123 = torch.hsplit(x, 8)
b0, b1, b2, b3, b12, b13, b23, b123 = torch.hsplit(y, 8)

ab0 = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 - a12 * b12 - a13 * b13 - a23 * b23 - a123 * b123
ab1 = a0 * b1 + a1 * b0 - a2 * b12 - a3 * b13 + a12 * b2 + a13 * b3 - a23 * b123 - a123 * b23
ab2 = a0 * b2 + a1 * b12 + a2 * b0 - a3 * b23 - a12 * b1 + a13 * b123 + a23 * b3 + a123 * b13
ab3 = a0 * b3 + a1 * b13 + a2 * b23 + a3 * b0 - a12 * b123 - a13 * b1 - a23 * b2 - a123 * b12
ab12 = a0 * b12 + a1 * b2 - a2 * b1 + a3 * b123 + a12 * b0 - a13 * b23 + a23 * b13 + a123 * b3
ab13 = a0 * b13 + a1 * b3 - a2 * b123 - a3 * b1 + a12 * b23 + a13 * b0 - a23 * b12 - a123 * b2
ab23 = a0 * b23 + a1 * b123 + a2 * b3 - a3 * b2 - a12 * b13 - a13 * b12 + a23 * b0 + a123 * b1
ab123 = a0 * b123 + a1 * b23 - a2 * b13 + a3 * b12 + a12 * b3 - a13 * b2 + a23 * b1 + a123 * b0
return ab0, ab1, ab2, ab3, ab12, ab13, ab23, ab123
else:
raise NotImplementedError
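# A minimal numeric check for the Cl_(2,0) branch (a sketch; the tensors are
# illustrative, and since clifford_mul never reads self it can be called unbound):
#
#   x = torch.tensor([[1., 2., 3., 4.]])  # 1 + 2 e1 + 3 e2 + 4 e1e2
#   y = torch.tensor([[5., 6., 7., 8.]])  # 5 + 6 e1 + 7 e2 + 8 e1e2
#   ab0, ab1, ab2, ab12 = CMult.clifford_mul(None, x, y, p=2, q=0)
#   # ab0 = 1*5 + 2*6 + 3*7 - 4*8 = 6, ab12 = 1*8 + 2*7 - 3*6 + 4*5 = 24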

def score(self, head_ent_emb, rel_ent_emb, tail_ent_emb):
ab = self.clifford_mul(x=head_ent_emb, y=rel_ent_emb, p=self.p, q=self.q)

if self.p == self.q == 0:
return torch.einsum('bd,bd->b', ab, tail_ent_emb)
elif (self.p == 1 and self.q == 0) or (self.p == 0 and self.q == 1):
ab0, ab1 = ab
c0, c1 = torch.hsplit(tail_ent_emb, 2)
return torch.einsum('bd,bd->b', ab0, c0) + torch.einsum('bd,bd->b', ab1, c1)
elif (self.p == 2 and self.q == 0) or (self.p == 0 and self.q == 2):
ab0, ab1, ab2, ab12 = ab
c0, c1, c2, c12 = torch.hsplit(tail_ent_emb, 4)
return torch.einsum('bd,bd->b', ab0, c0) \
+ torch.einsum('bd,bd->b', ab1, c1) \
+ torch.einsum('bd,bd->b', ab2, c2) \
+ torch.einsum('bd,bd->b', ab12, c12)
else:
raise NotImplementedError

def forward_triples(self, x: torch.LongTensor) -> torch.FloatTensor:
"""
Compute batch triple scores
Parameter
---------
x: torch.LongTensor with shape n by 3
Returns
-------
torch.LongTensor with shape n
"""

# (1) Retrieve real-valued embedding vectors.
head_ent_emb, rel_ent_emb, tail_ent_emb = self.get_triple_representation(x)
ab = self.clifford_mul(x=head_ent_emb, y=rel_ent_emb, p=self.p, q=self.q)

if self.p == self.q == 0:
return torch.einsum('bd,bd->b', ab, tail_ent_emb)
elif (self.p == 1 and self.q == 0) or (self.p == 0 and self.q == 1):
ab0, ab1 = ab
c0, c1 = torch.hsplit(tail_ent_emb, 2)
return torch.einsum('bd,bd->b', ab0, c0) + torch.einsum('bd,bd->b', ab1, c1)
elif (self.p == 2 and self.q == 0) or (self.p == 0 and self.q == 2):
ab0, ab1, ab2, ab12 = ab
c0, c1, c2, c12 = torch.hsplit(tail_ent_emb, 4)
return torch.einsum('bd,bd->b', ab0, c0) \
+ torch.einsum('bd,bd->b', ab1, c1) \
+ torch.einsum('bd,bd->b', ab2, c2) \
+ torch.einsum('bd,bd->b', ab12, c12)
else:
raise NotImplementedError
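# Usage sketch (assuming entity/relation ids already produced by the dataset;
# 'model' is an illustrative CMult instance):
#
#   triples = torch.LongTensor([[0, 1, 2], [3, 1, 4]])  # (head, relation, tail) ids
#   scores = model.forward_triples(triples)  # shape (2,)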

def forward_k_vs_all(self, x: torch.Tensor) -> torch.FloatTensor:
"""
Compute batch KvsAll triple scores
Parameter
---------
x: torch.LongTensor with shape n by 3
Returns
-------
torch.LongTensor with shape n
"""
# (1) Retrieve embedding vectors of heads and relations.
head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x)
# (2) CL multiply (1).
ab = self.clifford_mul(x=head_ent_emb, y=rel_ent_emb, p=self.p, q=self.q)
# (3) Inner product of (2) and all entity embeddings.
if self.p == self.q == 0:
return torch.mm(ab, self.entity_embeddings.weight.transpose(1, 0))
elif (self.p == 1 and self.q == 0) or (self.p == 0 and self.q == 1):
ab0, ab1 = ab
c0, c1 = torch.hsplit(self.entity_embeddings.weight, 2)
return torch.mm(ab0, c0.transpose(1, 0)) + torch.mm(ab1, c1.transpose(1, 0))
elif (self.p == 2 and self.q == 0) or (self.p == 0 and self.q == 2):
ab0, ab1, ab2, ab12 = ab
c0, c1, c2, c12 = torch.hsplit(self.entity_embeddings.weight, 4)
return torch.mm(ab0, c0.transpose(1, 0)) \
+ torch.mm(ab1, c1.transpose(1, 0)) \
+ torch.mm(ab2, c2.transpose(1, 0)) \
+ torch.mm(ab12, c12.transpose(1, 0))
elif self.p == 3 and self.q == 0:

ab0, ab1, ab2, ab3, ab12, ab13, ab23, ab123 = ab
c0, c1, c2, c3, c12, c13, c23, c123 = torch.hsplit(self.entity_embeddings.weight, 8)

return torch.mm(ab0, c0.transpose(1, 0)) \
+ torch.mm(ab1, c1.transpose(1, 0)) \
+ torch.mm(ab2, c2.transpose(1, 0)) \
+ torch.mm(ab3, c3.transpose(1, 0)) \
+ torch.mm(ab12, c12.transpose(1, 0)) \
+ torch.mm(ab13, c13.transpose(1, 0)) \
+ torch.mm(ab23, c23.transpose(1, 0)) \
+ torch.mm(ab123, c123.transpose(1, 0))
elif self.p == 1 and self.q == 1:
ab0, ab1, ab2, ab12 = ab
c0, c1, c2, c12 = torch.hsplit(self.entity_embeddings.weight, 4)
return torch.mm(ab0, c0.transpose(1, 0)) \
+ torch.mm(ab1, c1.transpose(1, 0)) \
+ torch.mm(ab2, c2.transpose(1, 0)) \
+ torch.mm(ab12, c12.transpose(1, 0))

else:
raise NotImplementedError
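# Shape sketch for KvsAll: with batch size n, embedding dimension d, k = 2^(p+q)
# blade components and |E| entities, every torch.mm term above multiplies an
# (n, d/k) component with a (d/k, |E|) transposed entity slice; summing the k
# terms yields one (n, |E|) matrix scoring all candidate tails at once.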


class Keci(BaseKGE):
def __init__(self, args):
super().__init__(args)
@@ -242,6 +21,7 @@ def __init__(self, args):
self.r = int(self.r)
self.requires_grad_for_interactions = True
# Initialize parameters for dimension scaling
# TODO: Do we need coefficients for the real part?
if self.p > 0:
self.p_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.p)
torch.nn.init.zeros_(self.p_coefficients.weight)
@@ -747,7 +527,7 @@ def score(self, h, r, t):
if self.q > 0:
self.q_coefficients = self.q_coefficients.to(h0.device, non_blocking=True)

-h0, hp, hq, h0, rp, rq = self.apply_coefficients(h0, hp, hq, h0, rp, rq)
+hp, hq, rp, rq = self.apply_coefficients(hp, hq, rp, rq)
# (4) Compute a triple score based on interactions described by the basis 1. Eq. 20
h0r0t0 = torch.einsum('br, br -> b', h0 * r0, t0)
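# ('br, br -> b' contracts over the r components: a batched three-way
# product sum_i h0[b, i] * r0[b, i] * t0[b, i])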

@@ -808,7 +588,7 @@ def forward_triples(self, x: torch.Tensor) -> torch.FloatTensor:
h0, hp, hq = self.construct_cl_multivector(head_ent_emb, r=self.r, p=self.p, q=self.q)
r0, rp, rq = self.construct_cl_multivector(rel_ent_emb, r=self.r, p=self.p, q=self.q)
t0, tp, tq = self.construct_cl_multivector(tail_ent_emb, r=self.r, p=self.p, q=self.q)
-h0, hp, hq, h0, rp, rq = self.apply_coefficients(h0, hp, hq, h0, rp, rq)
+hp, hq, rp, rq = self.apply_coefficients(hp, hq, rp, rq)
# (4) Compute a triple score based on interactions described by the basis 1. Eq. 20
h0r0t0 = torch.einsum('br, br -> b', h0 * r0, t0)

2 changes: 1 addition & 1 deletion tests/test_online_learning.py
@@ -24,7 +24,7 @@ def test_umls(self):
args.read_only_few = None
args.sample_triples_ratio = None
args.num_folds_for_cv = None
-args.backend = 'polars'
+args.backend = 'pandas'
args.trainer = 'torchCPUTrainer'
result = Execute(args).start()
assert os.path.isdir(result['path_experiment_folder'])
