forked from haihabi/MD-GAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simplex_generator.py
74 lines (61 loc) · 2.4 KB
/
simplex_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import torch
from dataclasses import dataclass
def simplex_coordinates(m):
    """Return the vertices of a regular simplex centred at the origin.

    Adapted from John Burkardt's Simplex Coordinates library:
    https://people.sc.fsu.edu/~jburkardt/py_src/simplex_coordinates/simplex_coordinates.html

    Args:
        m: Spatial dimension; must be a positive integer.

    Returns:
        A float64 array of shape ``(m, m + 1)`` whose columns are the
        ``m + 1`` vertices. Every column has unit Euclidean norm, the columns
        sum to zero, and any two distinct columns have inner product ``-1/m``.

    Raises:
        ValueError: If ``m < 1`` (the construction divides by ``m``).
    """
    if m < 1:
        raise ValueError(f"simplex dimension must be >= 1, got {m}")
    x = np.zeros((m, m + 1))
    # Start from the m unit basis vectors e_1..e_m ...
    x[:, :m] = np.eye(m)
    # ... and place the last vertex at (a, ..., a), with a chosen so that it is
    # at distance sqrt(2) from every e_i (the same as e_i from each other).
    x[:, m] = (1.0 - np.sqrt(1.0 + m)) / m
    # Move the centroid of the vertices to the origin.
    x -= x.mean(axis=1, keepdims=True)
    # After centring all vertices are equidistant from the origin, so dividing
    # by the first vertex's norm puts every vertex on the unit sphere.
    x /= np.linalg.norm(x[:, 0])
    return x
def var2cov(bot_dim, ngmm, variance=0.25):
    """Return ``ngmm`` copies of an isotropic ``bot_dim x bot_dim`` covariance.

    Each copy is ``variance * I`` (``variance`` defaults to 0.25, the value
    previously hard-coded).

    Args:
        bot_dim: Size of each (square) covariance matrix.
        ngmm: Number of stacked copies (one per mixture component).
        variance: Diagonal value of every covariance matrix.

    Returns:
        A float32 array of shape ``(ngmm, bot_dim, bot_dim)`` with
        ``np.squeeze`` applied. This means that singleton axes are dropped:
        ``ngmm == 1`` yields ``(bot_dim, bot_dim)`` and ``bot_dim == 1`` yields
        ``(ngmm,)``. Callers needing a fixed rank should reshape.
    """
    cov = np.eye(bot_dim, dtype=np.float32) * np.float32(variance)
    stacked = np.repeat(cov[np.newaxis], ngmm, axis=0)
    # squeeze() is kept for backward compatibility with existing callers.
    return stacked.squeeze()
@dataclass
class Simplex:
    """Container for Gaussian-mixture parameters placed on simplex vertices.

    Instances are constructed by ``simplex_params``, which fills every field
    with a float32 tensor on the requested device.
    """

    # Component means; simplex_params stores one simplex vertex per row.
    mu: torch.Tensor
    # Covariance matrix of each component.
    sigma: torch.Tensor
    # Mixture weights (uniform as built by simplex_params).
    w: torch.Tensor
    # (2 * pi * det(sigma_k)) ** -0.5 per component, as computed by simplex_params.
    sigma_det_rsqrt: torch.Tensor
    # Inverse of each component's covariance matrix.
    sigma_inv: torch.Tensor
def simplex_params(bot_dim: int, input_working_device: torch.device) -> Simplex:
    """Build mixture parameters whose component means are simplex vertices.

    The mixture has ``ngmm = bot_dim + 1`` equally weighted components.
    Component ``k`` is centred on the ``k``-th vertex returned by
    ``simplex_coordinates(bot_dim)`` and uses the covariance from ``var2cov``.

    Args:
        bot_dim: Dimension of the space; must be >= 1.
        input_working_device: Device on which all returned tensors are allocated.

    Returns:
        A ``Simplex`` of float32 tensors with shapes:
        ``mu`` ``(ngmm, bot_dim)``, ``sigma`` ``(ngmm, bot_dim, bot_dim)``,
        ``w`` ``(ngmm,)``, ``sigma_det_rsqrt`` ``(ngmm,)`` and
        ``sigma_inv`` ``(ngmm, bot_dim, bot_dim)``.

    Raises:
        ValueError: If ``bot_dim < 1``.
    """
    if bot_dim < 1:
        raise ValueError(f"bot_dim must be >= 1, got {bot_dim}")
    ngmm = bot_dim + 1
    # simplex_coordinates returns one vertex per column; transpose to one per row.
    mu_real = np.asarray(simplex_coordinates(bot_dim).T, dtype=np.float32)
    # var2cov squeezes singleton axes. For bot_dim == 1 that turns the
    # (2, 1, 1) stack into shape (2,), which made np.linalg.det below raise.
    # Restore the explicit stacked-matrix shape; this is a no-op otherwise.
    sigma_real = np.asarray(var2cov(bot_dim, ngmm), dtype=np.float32).reshape(
        ngmm, bot_dim, bot_dim
    )
    w_real = np.full(ngmm, 1.0 / ngmm, dtype=np.float32)
    # det/inv operate on each matrix of the (ngmm, d, d) stack independently.
    # NOTE(review): this is (2*pi*det(Sigma))**-0.5. The d-dimensional Gaussian
    # normalizer is det(2*pi*Sigma)**-0.5 = (2*pi)**(-d/2) * det(Sigma)**-0.5.
    # Confirm the d-independent 2*pi factor is intentional before relying on it
    # as an absolute density.
    sigma_det_rsqrt = np.power(2 * np.pi * np.linalg.det(sigma_real), -0.5)
    sigma_inv = np.linalg.inv(sigma_real)

    def _to_device(array):
        # Copy a NumPy array into a float32 tensor on the working device.
        return torch.tensor(array, device=input_working_device, dtype=torch.float32)

    return Simplex(
        _to_device(mu_real),
        _to_device(sigma_real),
        _to_device(w_real),
        _to_device(sigma_det_rsqrt),
        _to_device(sigma_inv),
    )