Batching in linearize and influence (#465)
* Batching in linearize and influence
* Addressing Eli's review
* Added optimization for pointwise false case
* Fixing lint error
agrawalraj authored Dec 22, 2023
1 parent 2e01b7b commit 538cef8
Showing 4 changed files with 57 additions and 19 deletions.
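In short: linearize and influence_fn now batch over test points internally, so callers pass a whole batch of points directly instead of wrapping the returned function in torch.vmap themselves. A minimal before/after sketch of the calling convention, assuming a fitted (model, guide) pair and a test_data batch with a leading batch dimension, e.g. drawn via pyro.infer.Predictive (these variable names are illustrative, not from the diff):

from chirho.robust.internals.linearize import linearize

# Hypothetical fitted model/guide and batched test points are assumed here.
param_eif = linearize(
    model,
    guide,
    num_samples_outer=1000,
    num_samples_inner=1,
)

# Before this commit: batching was the caller's job.
# batch_param_eif = torch.vmap(param_eif, randomness="different")
# test_data_eif = batch_param_eif(test_data)

# After this commit: the returned function accepts the batch directly.
test_data_eif = param_eif(test_data)  # {param name: [batch, *param shape]}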
35 changes: 27 additions & 8 deletions chirho/robust/internals/linearize.py
@@ -122,6 +122,7 @@ def linearize(
     max_plate_nesting: Optional[int] = None,
     cg_iters: Optional[int] = None,
     residual_tol: float = 1e-10,
+    pointwise_influence: bool = True,
 ) -> Callable[Concatenate[Point[T], P], ParamDict]:
     assert isinstance(model, torch.nn.Module)
     assert isinstance(guide, torch.nn.Module)
@@ -140,8 +141,6 @@
         model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
     )
     log_prob_params, func_log_prob = make_functional_call(log_prob)
-    score_fn = torch.func.grad(func_log_prob)
-
     log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
     if cg_iters is None:
         cg_iters = log_prob_params_numel
@@ -151,17 +150,37 @@
         conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol
     )
 
-    @functools.wraps(score_fn)
-    def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict:
+    def _fn(
+        points: Point[T],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> ParamDict:
         with torch.no_grad():
             data: Point[T] = func_predictive(predictive_params, *args, **kwargs)
-        data = {k: data[k] for k in point.keys()}
+        data = {k: data[k] for k in points.keys()}
         fvp = make_empirical_fisher_vp(
             func_log_prob, log_prob_params, data, *args, **kwargs
         )
-
         pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp)
-        point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs)
-        return cg_solver(pinned_fvp, point_score)
+        batched_func_log_prob = torch.vmap(
+            lambda p, data: func_log_prob(p, data, *args, **kwargs),
+            in_dims=(None, 0),
+            randomness="different",
+        )
+
+        def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor:
+            return batched_func_log_prob(p, points)
+
+        if pointwise_influence:
+            score_fn = torch.func.jacrev(bound_batched_func_log_prob)
+            point_scores = score_fn(log_prob_params)
+        else:
+            score_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1]
+            N_pts = points[next(iter(points))].shape[0]  # type: ignore
+            point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0]
+            point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()}
+        return torch.func.vmap(
+            lambda v: cg_solver(pinned_fvp, v), randomness="different"
+        )(point_scores)
 
     return _fn
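The new pointwise_influence flag selects between two score computations. With pointwise_influence=True, torch.func.jacrev of the batched log-probability yields one score vector per data point; with pointwise_influence=False (the "optimization for pointwise false case" from the commit message), a single torch.func.vjp against the uniform weight vector 1/N recovers only the batch-averaged score, costing one backward pass instead of N. A self-contained toy illustration of that equivalence (the function batched_log_prob below is made up for the example, standing in for bound_batched_func_log_prob):

import torch

params = torch.randn(4)
points = torch.randn(8, 4)  # batch of N = 8 "data points"

def batched_log_prob(p):
    # Toy stand-in: one scalar log-probability per data point.
    return points @ p - 0.5 * (points * points).sum(-1)

# pointwise_influence=True: full Jacobian, one score row per data point.
per_point_scores = torch.func.jacrev(batched_log_prob)(params)  # [8, 4]

# pointwise_influence=False: one VJP against uniform weights 1/N gives the
# batch-averaged score without materializing the per-point Jacobian.
N = points.shape[0]
vjp_fn = torch.func.vjp(batched_log_prob, params)[1]
avg_score = vjp_fn(torch.ones(N) / N)[0]  # [4]

assert torch.allclose(avg_score, per_point_scores.mean(dim=0), atol=1e-6)

This also explains why the False branch unsqueezes a leading batch dimension of size 1 before the vmapped CG solve: downstream code then sees the same [batch, ...] layout in either mode.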
14 changes: 9 additions & 5 deletions chirho/robust/ops.py
@@ -39,10 +39,14 @@ def influence_fn(
     target_params, func_target = make_functional_call(target)
 
     @functools.wraps(target)
-    def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
-        param_eif = linearized(point, *args, **kwargs)
-        return torch.func.jvp(
-            lambda p: func_target(p, *args, **kwargs), (target_params,), (param_eif,)
-        )[1]
+    def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
+        param_eif = linearized(points, *args, **kwargs)
+        return torch.vmap(
+            lambda d: torch.func.jvp(
+                lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
+            )[1],
+            in_dims=0,
+            randomness="different",
+        )(param_eif)
 
     return _fn
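On the influence_fn side, param_eif now carries a leading batch dimension, so the forward-mode JVP of the target functional is mapped over it: each row d is pushed through the linearization of func_target at target_params. A standalone sketch of this vmap-over-jvp pattern, using a made-up quadratic target rather than chirho's actual functional:

import torch

target_params = torch.randn(3)

def func_target(p):
    # Made-up scalar target functional; its gradient is 2 * p.
    return (p * p).sum()

# One parameter-space influence direction per data point, as returned by
# the batched linearize above.
param_eif = torch.randn(5, 3)

# One JVP per direction: d/dt func_target(target_params + t * d) at t = 0.
pointwise = torch.vmap(
    lambda d: torch.func.jvp(func_target, (target_params,), (d,))[1],
    in_dims=0,
)(param_eif)  # shape [5]

assert torch.allclose(pointwise, param_eif @ (2 * target_params), atol=1e-6)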
24 changes: 20 additions & 4 deletions tests/robust/test_internals_linearize.py
@@ -175,8 +175,7 @@ def test_nmc_param_influence_vmap_smoke(
         model, num_samples=4, return_sites=obs_names, parallel=True
     )()
 
-    batch_param_eif = torch.vmap(param_eif, randomness="different")
-    test_data_eif: Mapping[str, torch.Tensor] = batch_param_eif(test_data)
+    test_data_eif: Mapping[str, torch.Tensor] = param_eif(test_data)
     assert len(test_data_eif) > 0
     for k, v in test_data_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
@@ -349,10 +348,10 @@ def link(mu):
         num_samples_outer=10000,
         num_samples_inner=1,
         cg_iters=4,  # dimension of params = 4
+        pointwise_influence=True,
     )
 
-    batch_param_eif = torch.vmap(param_eif, randomness="different")
-    test_data_eif = batch_param_eif(D_test)
+    test_data_eif = param_eif(D_test)
     median_abs_error = torch.abs(
         test_data_eif["guide.treatment_weight_param"] - analytic_eif_at_test_pts
     ).median()
@@ -361,3 +360,20 @@
         assert median_abs_error / median_scale < 0.5
     else:
         assert median_abs_error < 0.5
+
+    # Test w/ pointwise_influence=False
+    param_eif = linearize(
+        model,
+        mle_guide,
+        num_samples_outer=10000,
+        num_samples_inner=1,
+        cg_iters=4,  # dimension of params = 4
+        pointwise_influence=False,
+    )
+
+    test_data_eif = param_eif(D_test)
+    assert torch.allclose(
+        test_data_eif["guide.treatment_weight_param"][0],
+        analytic_eif_at_test_pts.mean(),
+        atol=1e-1,
+    )
3 changes: 1 addition & 2 deletions tests/robust/test_ops.py
@@ -120,8 +120,7 @@ def test_nmc_predictive_influence_vmap_smoke(
         model, num_samples=4, return_sites=obs_names, parallel=True
     )()
 
-    batch_predictive_eif = torch.vmap(predictive_eif, randomness="different")
-    test_data_eif: Mapping[str, torch.Tensor] = batch_predictive_eif(test_data)
+    test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data)
     assert len(test_data_eif) > 0
     for k, v in test_data_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"