batched cg (#466)
agrawalraj authored Dec 22, 2023
1 parent 538cef8 commit 6bba70b
Showing 4 changed files with 52 additions and 44 deletions.
43 changes: 27 additions & 16 deletions chirho/robust/internals/linearize.py
@@ -25,7 +25,7 @@ def _flat_conjugate_gradient_solve(
     b: torch.Tensor,
     *,
     cg_iters: Optional[int] = None,
-    residual_tol: float = 1e-10,
+    residual_tol: float = 1e-3,
 ) -> torch.Tensor:
     r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312.
@@ -42,31 +42,41 @@ def _flat_conjugate_gradient_solve(
     Notes: This code is adapted from
     https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py
     """
+    assert len(b.shape), "b must be a 2D matrix"
+
     if cg_iters is None:
-        cg_iters = b.numel()
+        cg_iters = b.shape[1]
+    else:
+        cg_iters = min(cg_iters, b.shape[1])
+
+    def _batched_dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        return (x1 * x2).sum(axis=-1)  # type: ignore
+
+    def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+        return a.unsqueeze(0).t() * B
+
     p = b.clone()
     r = b.clone()
     x = torch.zeros_like(b)
     z = f_Ax(p)
-    rdotr = torch.dot(r, r)
-    v = rdotr / torch.dot(p, z)
+    rdotr = _batched_dot(r, r)
+    v = rdotr / _batched_dot(p, z)
     newrdotr = rdotr
     mu = newrdotr / rdotr
 
     zeros_xr = torch.zeros_like(x)
 
     for _ in range(cg_iters):
         not_converged = rdotr > residual_tol
-        z = torch.where(not_converged, f_Ax(p), z)
-        v = torch.where(not_converged, rdotr / torch.dot(p, z), v)
-        x += torch.where(not_converged, v * p, zeros_xr)
-        r -= torch.where(not_converged, v * z, zeros_xr)
-        newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr)
+        not_converged_broadcasted = not_converged.unsqueeze(0).t()
+        z = torch.where(not_converged_broadcasted, f_Ax(p), z)
+        v = torch.where(not_converged, rdotr / _batched_dot(p, z), v)
+        x += torch.where(not_converged_broadcasted, _batched_product(v, p), zeros_xr)
+        r -= torch.where(not_converged_broadcasted, _batched_product(v, z), zeros_xr)
+        newrdotr = torch.where(not_converged, _batched_dot(r, r), newrdotr)
         mu = torch.where(not_converged, newrdotr / rdotr, mu)
-        p = torch.where(not_converged, r + mu * p, p)
+        p = torch.where(not_converged_broadcasted, r + _batched_product(mu, p), p)
         rdotr = torch.where(not_converged, newrdotr, rdotr)
 
         if torch.all(~not_converged):
             return x
     return x
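For reference, a minimal self-contained sketch of the batched update that the helpers above implement: each row of b is an independent right-hand side, inner products become per-row sums, and the step sizes broadcast over columns. The names below (batched_cg, f_Ax) and the shared SPD matrix are illustrative, not the library function itself.

import torch

def batched_cg(f_Ax, b, iters=50, tol=1e-3):
    # Solve A x_i = b_i for every row b_i of b; f_Ax applies A row-wise.
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = (r * r).sum(dim=-1)                    # per-row squared residuals, shape (batch,)
    for _ in range(iters):
        z = f_Ax(p)
        v = rdotr / ((p * z).sum(dim=-1) + 1e-12)  # per-row step size
        x = x + v.unsqueeze(-1) * p                # broadcast the step over columns
        r = r - v.unsqueeze(-1) * z
        new_rdotr = (r * r).sum(dim=-1)
        p = r + (new_rdotr / (rdotr + 1e-12)).unsqueeze(-1) * p
        rdotr = new_rdotr
        if bool((rdotr < tol).all()):
            break
    return x

# Usage: one SPD matrix shared across a batch of four right-hand sides.
A = torch.eye(3) + 0.1 * torch.ones(3, 3)
b = torch.randn(4, 3)
x = batched_cg(lambda p: p @ A.T, b, tol=1e-8)
assert torch.allclose(x @ A.T, b, atol=1e-3)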


@@ -162,6 +172,9 @@ def _fn(
             func_log_prob, log_prob_params, data, *args, **kwargs
         )
         pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp)
+        pinned_fvp_batched = torch.func.vmap(
+            lambda v: pinned_fvp(v), randomness="different"
+        )
         batched_func_log_prob = torch.vmap(
             lambda p, data: func_log_prob(p, data, *args, **kwargs),
             in_dims=(None, 0),
@@ -179,8 +192,6 @@ def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor:
         N_pts = points[next(iter(points))].shape[0]  # type: ignore
         point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0]
         point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()}
-        return torch.func.vmap(
-            lambda v: cg_solver(pinned_fvp, v), randomness="different"
-        )(point_scores)
+        return cg_solver(pinned_fvp_batched, point_scores)
 
     return _fn
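The pattern above, vmap applied once to the Fisher vector product rather than to the whole CG solve, can be sketched in isolation. The matrix stand-in and names below are illustrative, not the chirho API.

import torch

# Stand-in for pinned_fvp: multiply one flat vector by a fixed SPD matrix.
A = torch.eye(3) + 0.1 * torch.ones(3, 3)

def fvp(v: torch.Tensor) -> torch.Tensor:  # v has shape (dim,)
    return A @ v

# vmap lifts the per-vector product to a (batch, dim) input, which is the shape the
# batched CG solver consumes; the library also passes randomness="different" because
# its fvp draws Monte Carlo samples, which this deterministic stand-in does not.
fvp_batched = torch.func.vmap(fvp)
vs = torch.randn(5, 3)
assert torch.allclose(fvp_batched(vs), vs @ A.T, atol=1e-6)
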
15 changes: 12 additions & 3 deletions chirho/robust/internals/utils.py
@@ -23,11 +23,13 @@ def make_flatten_unflatten(
 
 @make_flatten_unflatten.register(torch.Tensor)
 def _make_flatten_unflatten_tensor(v: torch.Tensor):
+    batch_size = v.shape[0]
+
     def flatten(v: torch.Tensor) -> torch.Tensor:
         r"""
         Flatten a tensor into a single vector.
         """
-        return v.flatten()
+        return v.reshape((batch_size, -1))
 
     def unflatten(x: torch.Tensor) -> torch.Tensor:
         r"""
@@ -40,11 +42,13 @@ def unflatten(x: torch.Tensor) -> torch.Tensor:
 
 @make_flatten_unflatten.register(dict)
 def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]):
+    batch_size = next(iter(d.values())).shape[0]
+
     def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor:
         r"""
         Flatten a dictionary of tensors into a single vector.
         """
-        return torch.cat([v.flatten() for k, v in d.items()])
+        return torch.hstack([v.reshape((batch_size, -1)) for k, v in d.items()])
 
     def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]:
         r"""
@@ -56,7 +60,12 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]:
                 [
                     v_flat.reshape(v.shape)
                     for v, v_flat in zip(
-                        d.values(), torch.split(x, [v.numel() for k, v in d.items()])
+                        d.values(),
+                        torch.split(
+                            x,
+                            [int(v.numel() / batch_size) for k, v in d.items()],
+                            dim=1,
+                        ),
                     )
                 ],
             )
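A round trip through the batched flatten/unflatten logic, written out as a standalone sketch; the dict contents are arbitrary, the only assumption is that every tensor shares a leading batch dimension.

import torch

d = {"a": torch.randn(4, 2, 3), "b": torch.randn(4, 5)}
batch_size = next(iter(d.values())).shape[0]

# Flatten: each parameter becomes a (batch, numel-per-sample) block, stacked column-wise.
flat = torch.hstack([v.reshape((batch_size, -1)) for v in d.values()])
assert flat.shape == (4, 2 * 3 + 5)

# Unflatten: split the columns back by per-sample size and restore the original shapes.
sizes = [int(v.numel() / batch_size) for v in d.values()]  # [6, 5]
chunks = torch.split(flat, sizes, dim=1)
restored = {k: c.reshape(v.shape) for (k, v), c in zip(d.items(), chunks)}
assert all(torch.equal(d[k], restored[k]) for k in d)
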
9 changes: 7 additions & 2 deletions tests/robust/test_internals_compositions.py
@@ -34,10 +34,15 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition():
 
     with torch.no_grad():
         data = func_predictive(predictive_params)
-    fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data)
+
+    fvp = torch.func.vmap(
+        make_empirical_fisher_vp(func_log_prob, log_prob_params, data)
+    )
 
     v = {
-        k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v)
+        k: torch.ones_like(v).unsqueeze(0)
+        if k != "guide.loc_a"
+        else torch.zeros_like(v).unsqueeze(0)
         for k, v in log_prob_params.items()
     }
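The unsqueeze(0) calls give every probe tensor a leading batch dimension of one, which is what the vmapped Fisher vector product expects. A toy version of that contract, where fvp is a stand-in rather than the chirho function:

import torch

params = {"loc": torch.randn(3), "scale": torch.randn(3)}

def fvp(v):  # toy stand-in: scales every entry, preserving the dict structure
    return {k: 2.0 * t for k, t in v.items()}

# A single probe vector becomes a batch of size one before going through vmap.
v_single = {k: torch.ones_like(p) for k, p in params.items()}
v_batched = {k: t.unsqueeze(0) for k, t in v_single.items()}  # shapes (1, 3)

out = torch.func.vmap(fvp)(v_batched)
assert out["loc"].shape == (1, 3)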

29 changes: 6 additions & 23 deletions tests/robust/test_internals_linearize.py
@@ -34,22 +34,6 @@
 T = TypeVar("T")
 
 
-@pytest.mark.parametrize("ndim", [1, 2, 3, 10])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_cg_solve(ndim: int, dtype: torch.dtype):
-    cg_iters = None
-    residual_tol = 1e-10
-    U = torch.rand(ndim, ndim, dtype=dtype)
-    A = torch.eye(ndim, dtype=dtype) + 0.1 * U.mm(U.t())
-    expected_x = torch.randn(ndim, dtype=dtype)
-    b = A @ expected_x
-
-    actual_x = conjugate_gradient_solve(
-        lambda v: A @ v, b, cg_iters=cg_iters, residual_tol=residual_tol
-    )
-    assert torch.sum((actual_x - expected_x) ** 2) < 1e-4
-
-
 @pytest.mark.parametrize("ndim", [1, 2, 3, 10])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
 @pytest.mark.parametrize("num_particles", [1, 4])
@@ -63,14 +47,13 @@ def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int):
     b = torch.einsum("ij,nj->ni", A, expected_x)
     assert b.shape == (num_particles, ndim)
 
-    batch_solve = torch.vmap(
-        functools.partial(
-            conjugate_gradient_solve,
-            lambda v: A @ v,
-            cg_iters=cg_iters,
-            residual_tol=residual_tol,
-        ),
+    batch_solve = functools.partial(
+        conjugate_gradient_solve,
+        lambda v: torch.einsum("ij,nj->ni", A, v),
+        cg_iters=cg_iters,
+        residual_tol=residual_tol,
     )
 
     actual_x = batch_solve(b)
 
     assert torch.all(torch.sum((actual_x - expected_x) ** 2, dim=1) < 1e-4)
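As an informal cross-check of the batched operator used in this test (not part of the test file), the same batch of systems can also be solved densely for every particle:

import torch

ndim, num_particles = 3, 4
U = torch.rand(ndim, ndim)
A = torch.eye(ndim) + 0.1 * U @ U.T
expected_x = torch.randn(num_particles, ndim)
b = torch.einsum("ij,nj->ni", A, expected_x)

# torch.linalg.solve handles all particles at once: the columns of b.T are the right-hand sides.
dense_x = torch.linalg.solve(A, b.T).T
assert torch.all(torch.sum((dense_x - expected_x) ** 2, dim=1) < 1e-4)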
