Change torch.potrf usage to torch.cholesky (#1529)
neerajprad authored and fritzo committed Nov 13, 2018
1 parent 86e3cf3 commit a599587
Showing 8 changed files with 13 additions and 13 deletions.
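
For context: torch.potrf was deprecated in favor of torch.cholesky around PyTorch 1.0. potrf returned the upper-triangular factor by default, while cholesky returns the lower-triangular factor by default, so every K.potrf(upper=False) call below maps directly to K.cholesky(). A minimal sketch of the equivalence (illustrative only, not part of this commit; assumes a PyTorch version where both APIs are still available):

    import torch

    A = torch.randn(5, 5)
    K = A.matmul(A.t()) + torch.eye(5) * 1e-6  # symmetric positive-definite

    L_old = K.potrf(upper=False)  # deprecated spelling
    L_new = K.cholesky()          # replacement; lower-triangular by default

    assert torch.allclose(L_old, L_new)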
pyro/contrib/autoguide/__init__.py (1 addition, 1 deletion)
@@ -639,7 +639,7 @@ def laplace_approximation(self, *args, **kwargs):
         loc = pyro.param("{}_loc".format(self.prefix))
         H = hessian(loss, loc.unconstrained())
         cov = H.inverse()
-        scale_tril = cov.potrf(upper=False)
+        scale_tril = cov.cholesky()

         # calculate scale_tril from self.guide()
         scale_tril_name = "{}_scale_tril".format(self.prefix)
pyro/contrib/gp/models/gpr.py (3 additions, 3 deletions)
@@ -81,7 +81,7 @@ def model(self):
         N = self.X.shape[0]
         Kff = self.kernel(self.X)
         Kff.view(-1)[::N + 1] += noise  # add noise to diagonal
-        Lff = Kff.potrf(upper=False)
+        Lff = Kff.cholesky()

         zero_loc = self.X.new_zeros(self.X.shape[0])
         f_loc = zero_loc + self.mean_function(self.X)
@@ -129,7 +129,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
         N = self.X.shape[0]
         Kff = self.kernel(self.X).contiguous()
         Kff.view(-1)[::N + 1] += noise  # add noise to the diagonal
-        Lff = Kff.potrf(upper=False)
+        Lff = Kff.cholesky()

         y_residual = self.y - self.mean_function(self.X)
         loc, cov = conditional(Xnew, self.X, self.kernel, y_residual, None, Lff,
@@ -185,7 +185,7 @@ def sample_next(xnew, outside_vars):
             X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"]

             # Compute Cholesky decomposition of kernel matrix
-            Lff = Kff.potrf(upper=False)
+            Lff = Kff.cholesky()
             y_residual = y - self.mean_function(X)

             # Compute conditional mean and variance
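
A note on the recurring Kff.view(-1)[::N + 1] += noise idiom in these hunks: flattening an N x N contiguous tensor and striding by N + 1 visits exactly the diagonal entries, so the in-place add is equivalent to adding noise * I without allocating an identity matrix. A small sketch of the equivalence (illustrative only, not part of this commit):

    import torch

    N = 4
    jitter = 1e-6
    K = torch.rand(N, N)
    K = K + K.t()  # symmetric toy matrix

    K1 = K + torch.eye(N) * jitter  # explicit identity matrix
    K2 = K.clone()
    K2.view(-1)[::N + 1] += jitter  # strided in-place diagonal update

    assert torch.allclose(K1, K2)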
pyro/contrib/gp/models/sgpr.py (3 additions, 3 deletions)
@@ -130,7 +130,7 @@ def model(self):
         M = Xu.shape[0]
         Kuu = self.kernel(Xu).contiguous()
         Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
-        Luu = Kuu.potrf(upper=False)
+        Luu = Kuu.cholesky()
         Kuf = self.kernel(Xu, self.X)
         W = Kuf.trtrs(Luu, upper=False)[0].t()

@@ -210,7 +210,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):

         Kuu = self.kernel(Xu).contiguous()
         Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
-        Luu = Kuu.potrf(upper=False)
+        Luu = Kuu.cholesky()
         Kus = self.kernel(Xu, Xnew)
         Kuf = self.kernel(Xu, self.X)

@@ -225,7 +225,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
         W_Dinv = W / D
         K = W_Dinv.matmul(W.t()).contiguous()
         K.view(-1)[::M + 1] += 1  # add identity matrix to K
-        L = K.potrf(upper=False)
+        L = K.cholesky()

         # get y_residual and convert it into 2D tensor for packing
         y_residual = self.y - self.mean_function(self.X)
pyro/contrib/gp/models/vgp.py (1 addition, 1 deletion)
@@ -91,7 +91,7 @@ def model(self):
         N = self.X.shape[0]
         Kff = self.kernel(self.X).contiguous()
         Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
-        Lff = Kff.potrf(upper=False)
+        Lff = Kff.cholesky()

         zero_loc = self.X.new_zeros(f_loc.shape)
         f_name = param_with_module_name(self.name, "f")
pyro/contrib/gp/models/vsgp.py (1 addition, 1 deletion)
@@ -116,7 +116,7 @@ def model(self):
         M = Xu.shape[0]
         Kuu = self.kernel(Xu).contiguous()
         Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
-        Luu = Kuu.potrf(upper=False)
+        Luu = Kuu.cholesky()

         zero_loc = Xu.new_zeros(u_loc.shape)
         u_name = param_with_module_name(self.name, "u")
pyro/contrib/gp/util.py (1 addition, 1 deletion)
@@ -212,7 +212,7 @@ def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=Fa
     if Lff is None:
         Kff = kernel(X).contiguous()
         Kff.view(-1)[::N + 1] += jitter  # add jitter to diagonal
-        Lff = Kff.potrf(upper=False)
+        Lff = Kff.cholesky()
     Kfs = kernel(X, Xnew)

     # convert f_loc_shape from latent_shape x N to N x latent_shape
tests/contrib/gp/test_conditional.py (2 additions, 2 deletions)
@@ -17,7 +17,7 @@
 X = torch.tensor([[1., 5.], [2., 1.], [3., 2.]])
 kernel = Matern52(input_dim=2)
 Kff = kernel(X) + torch.eye(3) * 1e-6
-Lff = Kff.potrf(upper=False)
+Lff = Kff.cholesky()
 pyro.set_rng_seed(123)
 f_loc = torch.rand(3)
 f_scale_tril = torch.rand(3, 3).tril(-1) + torch.rand(3).exp().diag()
@@ -75,7 +75,7 @@ def test_conditional_whiten(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov):
     loc0, cov0 = conditional(Xnew, X, kernel, f_loc, f_scale_tril, full_cov=True,
                              whiten=False)
     Kff = kernel(X) + torch.eye(3) * 1e-6
-    Lff = Kff.potrf(upper=False)
+    Lff = Kff.cholesky()
     whiten_f_loc = Lff.inverse().matmul(f_loc)
     whiten_f_scale_tril = Lff.inverse().matmul(f_scale_tril)
     loc1, cov1 = conditional(Xnew, X, kernel, whiten_f_loc, whiten_f_scale_tril,
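
For context on the whitening exercised by test_conditional_whiten: if f has prior covariance Kff = Lff Lff^T, then u = Lff^{-1} f has identity covariance, which is why the test maps f_loc and f_scale_tril through Lff.inverse(). A minimal sketch of that identity (illustrative only; a triangular solve would be preferred over forming the inverse in production code):

    import torch

    K = torch.eye(3) + 0.1 * torch.ones(3, 3)  # toy SPD matrix
    Lff = K.cholesky()

    # Whitening: Lff^{-1} K Lff^{-T} recovers the identity.
    Linv = Lff.inverse()
    I_approx = Linv.matmul(K).matmul(Linv.t())
    assert torch.allclose(I_approx, torch.eye(3), atol=1e-5)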
tests/perf/test_benchmark.py (1 addition, 1 deletion)
@@ -104,7 +104,7 @@ def svgp_multiclass(num_steps, whiten):
     pyro.set_rng_seed(0)
     X = torch.rand(100, 1)
     K = (-0.5 * (X - X.t()).pow(2) / 0.01).exp() + torch.eye(100) * 1e-6
-    f = K.potrf(upper=False).matmul(torch.randn(100, 3))
+    f = K.cholesky().matmul(torch.randn(100, 3))
     y = f.argmax(dim=-1)

     kernel = gp.kernels.Matern32(1).add(