Implicitly broadcast sample sites using iarange dim and size information #1125
Conversation
Is this idempotent? In the sense that, if I wrap my models in `poutine.broadcast` but then don't nest/use broadcasting in my iaranges, should everything still work?
pyro/util.py
@@ -218,6 +218,7 @@ def check_site_shape(site, max_iarange_nesting):
    if max_iarange_nesting < len(actual_shape):
        actual_shape = actual_shape[len(actual_shape) - max_iarange_nesting:]

    expected_shape = broadcast_shape(expected_shape, actual_shape)
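For context, `broadcast_shape` combines shapes numpy-style. A minimal sketch of its behavior, assuming it is the helper imported from `pyro.distributions.util`:

from pyro.distributions.util import broadcast_shape

# Size-1 dims broadcast against larger dims, and shorter shapes are
# right-aligned first, as in numpy; the result is a torch.Size.
assert broadcast_shape((10, 1, 2), (100, 2)) == (10, 100, 2)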
We should do strict checking here; IIUC the flag toggles allowing reshaping.
It should work, but this is worth adding as a test.
Removed generic broadcasting, and added a test for idempotence.
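A minimal sketch of what such an idempotence test might look like; the model and assertion here are hypothetical, not copied from the PR:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    # No iarange nesting and no reliance on broadcasting.
    return pyro.sample("x", dist.Normal(torch.tensor(0.), torch.tensor(1.)))

# Wrapping in poutine.broadcast should be a no-op for this model:
# the sample site keeps its original (scalar) shape.
assert poutine.broadcast(model)().shape == model().shape == torch.Size([])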
Nice work, @neerajprad, this is surprisingly simple!
pyro/poutine/broadcast_messenger.py
`BroadcastMessenger` automatically broadcasts the batch shape of
the stochastic function at a sample site when inside a single
or nested iarange context. The existing `batch_shape` must be
broadcastable with the size of the :class::`pyro.iarange`
nit:
- :class::`pyro.iarange`
+ :class:`pyro.iarange`
pyro/poutine/broadcast_messenger.py
dist = msg["fn"]
actual_batch_shape = getattr(dist, "batch_shape", None)
if actual_batch_shape is not None:
    target_batch_shape = []
Nice! I think this could be a little stricter and a little simpler if you used -1 sizes, something like:

target_batch_shape = [-1 if size == 1 else size for size in actual_batch_shape]
for f in msg["cond_indep_stack"]:
    if f.dim is None:
        continue
    assert f.dim < 0
    if -f.dim > len(target_batch_shape):
        target_batch_shape = [-1] * (-f.dim - len(target_batch_shape)) + target_batch_shape
    elif target_batch_shape[f.dim] not in (-1, f.size):
        raise ValueError("... dim collision ...")
    target_batch_shape[f.dim] = f.size
msg["fn"] = msg["fn"].expand(target_batch_shape)
The problem that I was facing was with expanding smaller-sized latent sites when the iarange dims appear in a staggered fashion, which would cut off the broadcasting starting from the -1 index:

with pyro.iarange("num_particles", 10, dim=-3):
    with pyro.iarange("components", 2, dim=-1):
        # with .expand([10, -1, 2]), we get s.shape == (10,)
        # with .expand([10, 1, 2]), we get s.shape == (10, 1, 2)
        s = pyro.sample("sample", dist.Bernoulli(0.5))
    with pyro.iarange("data", 100, dim=-2):
        # Note that we need s to have shape (10, 1, 2) here to correctly
        # expand to (10, 100, 2)
        ...
Hmm, the snippet I suggested should be order invariant. Note that the final `.expand()` statement is outside of the loop. Am I missing something?
Just updated the snippet to clarify. The problem comes in the sample site in the second iarange, where the outermost iarange is at -3 and the inner `components` one is at -1 (the `data` iarange is yet to come). If we expand `s` as `dist.Bernoulli(0.5).expand([10, -1, 2])` (note the default -1), it will give us `s.shape == torch.Size((10,))`, whereas we want it to be of `torch.Size((10, 1, 2))` so that it can be correctly broadcast by the `data` iarange.
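For reference, a small sketch of the `torch.Tensor.expand` semantics this hinges on (assuming stock PyTorch behavior): new dims with explicit sizes are fine, and -1 preserves an existing dim's size, but -1 is rejected for a new dim:

import torch

s = torch.tensor(0.5)            # scalar, no batch dims
print(s.expand(10, 1, 2).shape)  # torch.Size([10, 1, 2]): new dims with explicit sizes
x = torch.zeros(3, 1)
print(x.expand(-1, 4).shape)     # torch.Size([3, 4]): -1 keeps the existing dim's size
try:
    s.expand(10, -1, 2)          # -1 in a new (non-existing) dimension
except RuntimeError as e:
    print(e)                     # expand() rejects -1 for new dimensions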
Thanks for explaining, I didn't know that -1 cannot be used for new dimensions in `torch.expand()`. Could we still try to catch expand errors early, while the `iarange.name` is still around? Maybe:

target_batch_shape = [None if size == 1 else size for size in actual_batch_shape]
for f in msg["cond_indep_stack"]:
    if f.dim is None:
        continue
    assert f.dim < 0
    target_batch_shape = [None] * (-f.dim - len(target_batch_shape)) + target_batch_shape
    if target_batch_shape[f.dim] not in (None, f.size):
        raise ValueError("Shape mismatch inside iarange('{}') at site {} dim {}, {} vs {}".format(
            f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
    target_batch_shape[f.dim] = f.size
# ... remainder of your code ...
Good point; will update.
tests/infer/test_gradient.py
pyro.sample("nuisance_b", Normal(2, 3)) | ||
pyro.sample("nuisance_a", Normal(0, 1)) | ||
|
||
optim = Adam({"lr": 0.1}) | ||
model, guide = poutine.broadcast(model), poutine.broadcast(guide) |
nit: I think it's a little clearer to decorate

@poutine.broadcast
def model():
    ...

@poutine.broadcast
def guide():
    ...

That way readers know as soon as they start reading the model that it should be read with broadcast semantics.
This is pretty neat!
This turned out nice - you should advertise it a little! Maybe add a usage example to the `broadcast` docstring and a note in the last section of our tensor shape tutorial?
def guide():
    with pyro.iarange("num_particles", 10, dim=-3):
        with pyro.iarange("components", 2, dim=-1):
            pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
Maybe put this or something similar in the `handlers.broadcast` docstring as a usage example?
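As a sketch of what that usage example could look like, assuming the broadcast semantics introduced in this PR:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

@poutine.broadcast
def guide():
    with pyro.iarange("num_particles", 10, dim=-3):
        with pyro.iarange("components", 2, dim=-1):
            # The Beta site is implicitly expanded to the enclosing
            # iarange dims, so p.shape == (10, 1, 2).
            p = pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
    return p

assert guide().shape == torch.Size([10, 1, 2])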
Will update.
Will add to the docstring. I wasn't sure if we should add this to the tutorial yet, because things may change or we may uncover some edge cases as we discuss and expand our broadcasting semantics to handle different use cases. Let me create a task to update our tutorial once things have stabilized.
LGTM
This should be good to merge, unless there are further comments.
This makes a small change to use the `expand` logic from #1119 to implement implicit broadcasting via a `broadcast` poutine. This should be safe to merge, as this logic will not be exercised until the user wraps their model/guides inside `poutine.broadcast`. We could later decide to reuse this to do the broadcasting behind the scenes for parallelizing over `num_particles`, for instance, or use our learnings to build more composable broadcasting effects, as @eb8680 mentioned in #1115.

e.g. Using @fritzo's example from #1119, we do not need to `expand` or `expand_by` if the model is wrapped inside `poutine.broadcast`, provided the `iarange`s are nested correctly via the `dim` arg (see the sketch below). Added tests to `test_valid_models`. Also modified one of the gradient tests to use the broadcasting logic.
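The #1119 example itself is not reproduced above. As a hypothetical illustration of the pattern, a model wrapped in `poutine.broadcast` can draw per-datapoint samples with no explicit `expand`/`expand_by` calls:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

data = torch.zeros(100)

@poutine.broadcast
def model():
    with pyro.iarange("data", 100, dim=-1):
        # No .expand_by([100]) needed: the Normal site is implicitly
        # broadcast to the iarange size, so z.shape == (100,).
        z = pyro.sample("z", dist.Normal(torch.tensor(0.), torch.tensor(1.)))
        pyro.sample("obs", dist.Normal(z, torch.tensor(1.)), obs=data)

model()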