Vectorize ELBO computation over num particles #1176
Conversation
Force-pushed: 5c3686d to cf65af6
Looks good so far!
pyro/infer/elbo.py
Outdated
```python
@@ -32,7 +45,63 @@ class ELBO(object):
    def __init__(self,
                 num_particles=1,
                 max_iarange_nesting=float('inf'),
                 auto_vectorize=False,
```
nit: maybe call this `vectorize_particles` in case we want to vectorize some other aspect?
```python
            with pyro.iarange("num_particles_vectorized", self.num_particles, dim=-self.max_iarange_nesting):
                return fn(*args, **kwargs)

        return poutine.broadcast(wrapped_fn)
```
Is it worth checking whether `fn` is already wrapped in `poutine.broadcast` and avoiding wrapping a second time? Alternatively, this avoidance check could be done inside `poutine.broadcast`.
Good idea. I'll check it in `poutine.broadcast` itself.
I had to add `fn.__broadcasted__` as an attribute to mark a function that has already been broadcasted.
Shouldn't `broadcast` be idempotent?
Yes, `broadcast` should be idempotent; I'm asking only whether it is worth avoiding double wrapping for efficiency. If it costs more than 5 lines of code, it's probably not worth avoiding the double wrap.
Agreed, I'll remove this.
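For illustration only, the guard discussed in this thread can be sketched in plain Python. Here `broadcast` is a toy stand-in for `poutine.broadcast` (the real version applies a `BroadcastMessenger`); the `__broadcasted__` attribute marks a function that is already wrapped so a second application becomes a no-op:

```python
import functools


def broadcast(fn):
    """Toy stand-in for poutine.broadcast: wrap fn once, then mark it so
    that applying broadcast() again returns the same wrapper (idempotent)."""
    if getattr(fn, "__broadcasted__", False):
        return fn  # already wrapped; skip the second layer

    @functools.wraps(fn)
    def wrapped_fn(*args, **kwargs):
        # The real messenger would adjust sample-site batch shapes here.
        return fn(*args, **kwargs)

    wrapped_fn.__broadcasted__ = True
    return wrapped_fn


def model():
    return "trace"


once = broadcast(model)
twice = broadcast(once)
assert once is twice      # double wrapping avoided
assert once() == "trace"  # behavior unchanged
```

The wrapper stays idempotent either way; the attribute check only saves the cost of the extra call layer.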
pyro/infer/elbo.py
Outdated
```python
                 strict_enumeration_warning=True):
        self.num_particles = num_particles
        self.max_iarange_nesting = max_iarange_nesting
        self.auto_vectorize = auto_vectorize
        if self.auto_vectorize:
```
nit: `if self.vectorize_particles and self.num_particles > 1: ...`
pyro/infer/elbo.py
Outdated
```python
        Runs the guide and runs the model against the guide with
        the result packaged as a trace generator.
        """
        if self.auto_vectorize:
```
This may error if `self.num_particles == 1`, since `max_iarange_nesting` will not have been incremented. Instead gate on `if self.vectorize_particles and self.num_particles > 1`, as used elsewhere, or simplify the gate above.
The `_vectorized_num_particles` method does a special check for `num_particles=1`, so it should work fine. Note that `vectorize_particles` also calls `poutine.broadcast` (the two effects are hard to decouple), so even when the user has `num_particles=1` with `vectorize_particles=True`, we would like to wrap the model and guide inside a `poutine.broadcast` so that the behavior is consistent.
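As a rough sketch of that behavior (toy code with hypothetical names, not Pyro's actual implementation), wrapping happens whenever vectorization is requested, while the outer particle dimension only comes into play for `num_particles > 1`:

```python
def broadcast(fn):
    # Toy stand-in for poutine.broadcast: just tag the function.
    fn.__broadcasted__ = True
    return fn


def vectorize(fn, num_particles, vectorize_particles):
    """Hypothetical sketch of the gating discussed above."""
    if not vectorize_particles:
        return fn
    if num_particles == 1:
        # No outer particle dimension is needed, but still broadcast so
        # behavior is consistent with the num_particles > 1 case.
        return broadcast(fn)

    def wrapped_fn(*args, **kwargs):
        # Stand-in for a single vectorized pass; the real code would enter
        # pyro.iarange("num_particles_vectorized", num_particles, ...) and
        # draw all particles at once rather than looping.
        return [fn(*args, **kwargs) for _ in range(num_particles)]

    return broadcast(wrapped_fn)


single = vectorize(lambda: 0.5, 1, True)
assert single.__broadcasted__  # wrapped even for one particle
batched = vectorize(lambda: 0.5, 10, True)
assert len(batched()) == 10
```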
looks good! do you have any basic speed comparisons, at least on the test examples?
pyro/infer/elbo.py
Outdated
```python
@@ -15,6 +23,11 @@ class ELBO(object):
    :func:`pyro.iarange` contexts. This is only required to enumerate over
    sample sites in parallel, e.g. if a site sets
    ``infer={"enumerate": "parallel"}``.
    :param bool auto_vectorize: Vectorize ELBO computation over `num_particles`.
```
Can you add what it defaults to?
```python
                check_site_shape(site, self.max_iarange_nesting)

        yield model_trace, guide_trace
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
```
is this method identical sans the for loop over num_particles? the diff shows the whole method has changed but from a cursory glance it looks the same
they are identical. have you tried the old "cross your eyes at the side-by-side diff" trick?
i do the "cmd-f a randomly sampled substring and see if it appears on both sides". 90% of the time, it works every time.
Yup, that's right! Mostly `Trace_ELBO`, `TraceGraph_ELBO` and `TraceEnum_ELBO` will only differ in how they generate a single trace. This is reused when we want to generate traces vectorized over `num_particles`.
@jpchen - This will be many orders of magnitude faster than sequentially sampling traces `num_particles` times. I intend to do some comparisons w.r.t. hand parallelization, which is what we should be comparing against. I did not observe any significant difference for most models, but there are cases where we end up adding up to 20% more time compared to hand parallelization. It is still very fast though (all in milliseconds vs. seconds or minutes).
pyro/poutine/handlers.py
Outdated
```diff
@@ -248,7 +248,14 @@ def broadcast(fn=None):
     ... return sample
     """
     msngr = BroadcastMessenger()
-    return msngr(fn) if fn is not None else msngr
+    if fn is not None:
```
I guess we can omit this logic until we see a real performance cost of double wrapping.
this is super awesome!
lgtm
question: i guess as the code stands we expect the user to know when auto-vectorization will break (e.g. in cases with stochastic control flow etc.)?
Force-pushed: b9c04e6 to 6e1ac43
Yeah, we'd have to do a bit of static analysis for automatic validation.
That's true. I think if an iarange does not specify a
Thanks for simplifying the enumeration tests!
tests/infer/test_enum.py
Outdated
```diff
-    pyro.sample("x3", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate3})
+    pyro.sample("x1", dist.Bernoulli(q), infer={"enumerate": enumerate1})
+    pyro.sample("x2", dist.Bernoulli(q), infer={"enumerate": enumerate2})
+    pyro.sample("x3", dist.Bernoulli(q), infer={"enumerate": enumerate3})

     kl = sum(kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.1, 0.2, 0.3])
     expected_loss = kl.item()
     expected_grad = grad(kl, [q])[0]

     elbo = TraceEnum_ELBO(max_iarange_nesting=1,
```
You can now set `max_iarange_nesting=0`, since it is incremented internally.
Done.
tests/infer/test_enum.py
Outdated
```diff
-    pyro.sample("x1", dist.Bernoulli(0.1).expand_by([num_particles]))
-    pyro.sample("x2", dist.Bernoulli(0.2).expand_by([num_particles]))
-    pyro.sample("x3", dist.Bernoulli(0.3).expand_by([num_particles]))
+    pyro.sample("x1", dist.Bernoulli(0.1))
```
nit: outdent
My bad, fixed.
tests/infer/test_enum.py
Outdated
```diff
-        pyro.sample("y", d.expand_by([inner_dim, 1, num_particles]))
-    with context1, context2:
-        pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles]))
+    context1 = pyro.iarange("outer", outer_dim, dim=-2)
```
Can we change these dims to `-1, -2` now?
Done.
tests/infer/test_gradient.py
Outdated
```python
@@ -26,35 +26,35 @@ def test_subsample_gradient(Elbo, reparameterized, subsample):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    num_particles = 50000
    precision = 0.06
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model(subsample):
```
I think you should wrap these model and guide with `@poutine.broadcast` and avoid relying on the implicit `@poutine.broadcast` implied by `num_particles>1` and `vectorize_particles=True`. While it is correct without the explicit decorator, it is better style to be explicit, and users may look to our tests for idiomatic usage.
`vectorize_particles` with `num_particles=1` would still use `@poutine.broadcast`, so as not to have surprising special-case behavior for `num_particles=1`. That was part of the reason I thought of calling it `auto_vectorize`: to emphasize that it's a bit more general than adding an outer iarange, and could be used as a replacement for `@poutine.broadcast` with or without `num_particles`. That said, `vectorize_particles` is a clearer arg name, but it makes the broadcasting side effect seem unexpected. I will change to using `@poutine.broadcast`, unless there is a better name we can use for this arg.
tests/infer/test_valid_models.py
Outdated
```python
@pytest.mark.parametrize('enumerate_', [None, "sequential", "parallel"])
def test_enum_discrete_vectorized_num_particles(enumerate_):
    num_particles = 50
```
Can you `@pytest.mark.parametrize('num_particles', [1, 10])` to test the single-particle edge case?
Good idea, will do.
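For reference, stacked `@pytest.mark.parametrize` decorators take the Cartesian product of their values, so adding the suggested `num_particles` parametrization multiplies the test matrix. A small sketch using the values from this thread:

```python
import itertools

# Values from the existing test plus the suggested extra parametrization.
enumerate_values = [None, "sequential", "parallel"]
num_particles_values = [1, 10]  # includes the single-particle edge case

# Stacked @pytest.mark.parametrize decorators run the test once per combination.
cases = list(itertools.product(num_particles_values, enumerate_values))
print(len(cases))  # 6 invocations instead of 3
```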
Force-pushed: 31371c5 to 76fa97c
```diff
         elbo.loss_and_grads(self.model, self.guide, reparameterized=reparameterized)

         for i in range(1, N + 1):
             for param_prefix in ["loc_q_%d", "log_sig_q_%d", "kappa_q_%d"]:
                 if i == N and param_prefix == 'kappa_q_%d':
                     continue
                 actual_grad = pyro.param(param_prefix % i).grad
-                assert_equal(actual_grad, 0.0 * actual_grad, prec=0.20, msg="".join([
+                assert_equal(actual_grad, 0.0 * actual_grad, prec=0.10, msg="".join([
```
Nice!
pyro/infer/elbo.py
Outdated
```python
@@ -15,6 +23,11 @@ class ELBO(object):
    :func:`pyro.iarange` contexts. This is only required to enumerate over
    sample sites in parallel, e.g. if a site sets
    ``infer={"enumerate": "parallel"}``.
    :param bool vectorize_particles: Vectorize ELBO computation over `num_particles`.
```
Consider adding more info: "Whether to vectorize the ELBO computation over `num_particles`. Defaults to False. This requires static structure in model and guide."
Thanks, will update.
The failure is due to the
lgtm
Resolves #791.
My initial plan was to write a separate SVI class for this, but it turned out that modifying the base `ELBO` class was the simplest way to percolate this change to all the other loss classes. I have modified only a few tests that were hand-parallelizing over num particles, but we can do a broader sweep later after this PR is merged. It is not strictly necessary though, since those tests are fast anyway. Also modified the `conjugate_gradient` test that was not parallelized.