-
Notifications
You must be signed in to change notification settings - Fork 1
/
pairing.py
74 lines (58 loc) · 1.66 KB
/
pairing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import torch
def szudzik_encode(x, y):
    """Pair two non-negative integers element-wise with Szudzik's function.

    ``encode(x, y) = x*x + x + y`` if ``x >= y``, else ``y*y + x``.
    The result is inverted by ``szudzik_decode``.

    Args:
        x, y: tensors, or anything ``torch.as_tensor`` accepts (Python ints,
            lists); they are broadcast against each other.

    Returns:
        Tensor with the broadcast shape of ``x`` and ``y``.
    """
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    # Select per element rather than blending with mask arithmetic
    # (``a * mask + b * ~mask``), which produces NaN for float inputs when
    # the discarded branch is infinite.
    return torch.where(x >= y, x * x + x + y, x + y * y)
def szudzik_decode(z):
    """Invert Szudzik's pairing: recover ``(x, y)`` from the code ``z``.

    Args:
        z: tensor, or anything ``torch.as_tensor`` accepts, holding
            non-negative integer codes. Float values are truncated to int64.

    Returns:
        ``int32`` tensor of shape ``(2, *z.shape)``: row 0 holds ``x`` and
        row 1 holds ``y``.
    """
    z = torch.as_tensor(z).long()
    # Integer square root. ``torch.sqrt`` on an integer tensor computes in the
    # default float dtype (float32), whose 24-bit mantissa breaks decoding once
    # z exceeds 2**24; take a float64 estimate and correct it by at most one
    # using exact integer comparisons.
    s = torch.floor(torch.sqrt(z.double())).long()
    s = s + ((s + 1) * (s + 1) <= z).long()
    s = s - (s * s > z).long()
    diff = z - s * s
    # diff < s  <=>  z came from the ``y*y + x`` branch (x < y).
    low = diff < s
    x = torch.where(low, diff, s)
    y = torch.where(low, s, diff - s)
    # int32 matches the historical return dtype of this function.
    return torch.stack([x, y]).int()
class SzudzikPair:
    """
    Szudzik's pairing function (NumPy implementation).

    Maps a pair of non-negative integers to a unique non-negative integer in a
    reversible way: ``encode(x, y) = x*x + x + y`` if ``x >= y`` else
    ``y*y + x``; ``decode`` inverts it. Both methods operate element-wise.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def encode(x, y):
        """Pair ``x`` and ``y`` element-wise.

        Args:
            x, y: integers or integer arrays; broadcast against each other.

        Returns:
            A NumPy scalar for scalar inputs, otherwise an array with the
            broadcast shape of ``x`` and ``y``.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        # Compare directly and select with np.where so x and y may broadcast;
        # stacking them for ``np.max([x, y], axis=0)`` requires equal shapes.
        out = np.where(x >= y, x * x + x + y, x + y * y)
        # ``[()]`` unwraps a 0-d result to a NumPy scalar and is a no-op
        # (whole-array view) for n-d results.
        return out[()]

    @staticmethod
    def decode(z):
        """Invert ``encode``: recover ``(x, y)`` from ``z``.

        Args:
            z: non-negative integer code(s); float values are truncated.

        Returns:
            Integer array of shape ``(2, *np.shape(z))``: row 0 is ``x``,
            row 1 is ``y``.
        """
        z = np.asarray(z).astype(np.int64)
        # Integer square root: the float64 estimate can be off by one once z
        # exceeds ~2**52, so correct it with exact integer comparisons.
        s = np.floor(np.sqrt(z)).astype(np.int64)
        s = s + ((s + 1) * (s + 1) <= z)
        s = s - (s * s > z)
        diff = z - s * s
        # diff < s  <=>  z came from the ``y*y + x`` branch (x < y).
        low = diff < s
        x = np.where(low, diff, s)
        y = np.where(low, s, diff - s)
        return np.stack([x, y]).astype(int)
class OrdinalPair:
    """
    Dense ordinal encoding of integer pairs.

    Each distinct pair ``(x[i], y[i])`` is mapped to its rank among the
    distinct pairs present in the input (lexicographic order), so codes are
    consecutive integers starting at 0. Codes depend on the input set; no
    decoder is provided.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def encode(x, y):
        """Return the ordinal code of every ``(x, y)`` pair.

        Args:
            x, y: integer arrays (or scalars); broadcast against each other.

        Returns:
            Integer array with the broadcast shape of ``x`` and ``y``.
        """
        x, y = np.broadcast_arrays(np.asarray(x), np.asarray(y))
        # One row per pair, so ``np.unique(axis=0)`` deduplicates pairs.
        # Stacking as ``np.stack([x, y])`` would make x and y the two rows and
        # deduplicate the whole vectors instead.
        pairs = np.stack([x.ravel(), y.ravel()], axis=1)
        _, idx = np.unique(pairs, axis=0, return_inverse=True)
        # The shape of the inverse with ``axis`` given has differed between
        # NumPy releases; reshaping normalizes it to the input shape.
        return idx.reshape(x.shape)
class LinearIndexing:
    """
    Row-major linear indexing of pairs on an ``(n, m)`` grid.

    ``encode`` flattens ``(x, y)`` with ``np.ravel_multi_index`` and
    ``decode`` recovers it with ``np.unravel_index``. When ``m`` is omitted
    the grid is square (``m = n``).
    """

    def __init__(self, n, m=None):
        self.n = n
        self.m = n if m is None else m

    def encode(self, x, y):
        """Return the flat index of ``(x, y)`` in the ``(n, m)`` grid."""
        grid_shape = (self.n, self.m)
        coords = (x, y)
        return np.ravel_multi_index(coords, grid_shape)

    def decode(self, z):
        """Return the coordinate tuple ``(x, y)`` of flat index ``z``."""
        grid_shape = (self.n, self.m)
        return np.unravel_index(z, grid_shape)