
Implicitly broadcast sample sites using iarange dim and size information #1125

Merged: 8 commits into pyro-ppl:dev on May 9, 2018

Conversation

@neerajprad (Member) commented May 4, 2018:

This makes a small change to use the expand logic from #1119 to implement implicit broadcasting via a broadcast poutine.

This should be safe to merge, as this logic will not be exercised until the user wraps their model/guide inside poutine.broadcast. We could later decide to reuse this to do the broadcasting behind the scenes when parallelizing over num_particles, for instance, or use what we learn to build more composable broadcasting effects, as @eb8680 mentioned in #1115.

e.g., using @fritzo's example from #1119:

X, Y = 320, 200
x_axis = pyro.iarange("x_axis", X, dim=-2)
y_axis = pyro.iarange("y_axis", Y, dim=-1)
with x_axis:
    x_noise = pyro.sample("x", dist.Normal(0, 1).expand_by([X, 1])) # .expand_by() suffices
with y_axis:
    y_noise = pyro.sample("y", dist.Normal(0, 1).expand_by([Y]))    # .expand_by() suffices
with x_axis, y_axis:
    yx = pyro.sample("yx", dist.Normal(y_noise, 1).expand_by([X])) # .expand_by() suffices
    xy = pyro.sample("xy", dist.Normal(x_noise, 1).expand([X, Y])) # .expand() is needed
    ...

We do not need to call `.expand()` or `.expand_by()` if the model is wrapped inside `poutine.broadcast`:

x_axis = pyro.iarange("x_axis", X, dim=-2)
y_axis = pyro.iarange("y_axis", Y, dim=-1)
with x_axis:
    x_noise = pyro.sample("x", dist.Normal(0, 1))
    assert x_noise.shape == torch.Size((320, 1))
with y_axis:
    y_noise = pyro.sample("y", dist.Normal(0, 1))
    assert y_noise.shape == torch.Size((200,))
with x_axis, y_axis:
    yx = pyro.sample("yx", dist.Normal(y_noise, 1))
    assert yx.shape == torch.Size((320, 200))
    xy = pyro.sample("xy", dist.Normal(x_noise, 1))
    assert xy.shape == torch.Size((320, 200))
  • Note that this requires the user to line up the different iaranges correctly via the `dim` arg.
  • Tests: Added a couple of tests to `test_valid_models`. Also modified one of the gradient tests to use the broadcasting logic.
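For readers following along, the dim/size bookkeeping behind this behavior can be sketched in plain Python, without Pyro. Here `Frame` is a hypothetical stand-in for Pyro's `CondIndepStackFrame`, and `target_batch_shape` is an illustrative reconstruction of the expansion logic, not the actual implementation:

```python
from collections import namedtuple

# Hypothetical stand-in for Pyro's CondIndepStackFrame.
Frame = namedtuple("Frame", ["name", "dim", "size"])

def target_batch_shape(actual_batch_shape, cond_indep_stack):
    """Compute the batch shape a sample site is expanded to, given the
    enclosing iarange frames (dims are negative, counted from the right)."""
    shape = list(actual_batch_shape)
    for f in cond_indep_stack:
        if f.dim is None:
            continue
        assert f.dim < 0
        # Grow the shape on the left with singleton dims if needed.
        if -f.dim > len(shape):
            shape = [1] * (-f.dim - len(shape)) + shape
        shape[f.dim] = f.size
    return tuple(shape)

x_axis = Frame("x_axis", dim=-2, size=320)
y_axis = Frame("y_axis", dim=-1, size=200)
print(target_batch_shape((), [x_axis]))          # (320, 1)
print(target_batch_shape((), [y_axis]))          # (200,)
print(target_batch_shape((), [x_axis, y_axis]))  # (320, 200)
```

These shapes match the assertions in the example above.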

@jpchen (Member) left a comment:

Is this idempotent? In the sense that if I wrap my models in poutine.broadcast but then don't nest/use broadcasting in my iaranges, should everything still work?

pyro/util.py Outdated
@@ -218,6 +218,7 @@ def check_site_shape(site, max_iarange_nesting):
     if max_iarange_nesting < len(actual_shape):
         actual_shape = actual_shape[len(actual_shape) - max_iarange_nesting:]

+    expected_shape = broadcast_shape(expected_shape, actual_shape)

We should do strict checking here; IIUC, the flag toggles whether reshaping is allowed.

@neerajprad (Member Author):

> Is this idempotent? In the sense that if I wrap my models in poutine.broadcast but then don't nest/use broadcasting in my iaranges, should everything still work?

It should work, but this is worth adding as a test.

@neerajprad (Member Author):

Removed generic broadcasting, and added test for idempotence.

@fritzo (Member) left a comment:

Nice work, @neerajprad, this is surprisingly simple!

`BroadcastMessenger` automatically broadcasts the batch shape of
the stochastic function at a sample site when inside a single
or nested iarange context. The existing `batch_shape` must be
broadcastable with the size of the :class::`pyro.iarange`

nit:

- :class::`pyro.iarange`
+ :class:`pyro.iarange`

dist = msg["fn"]
actual_batch_shape = getattr(dist, "batch_shape", None)
if actual_batch_shape is not None:
    target_batch_shape = []
@fritzo (Member) commented May 6, 2018:

Nice! I think this could be a little stricter and a little simpler if you used -1 sizes, something like

target_batch_shape = [-1 if size == 1 else size for size in actual_batch_shape]
for f in msg["cond_indep_stack"]:
    if f.dim is None:
        continue
    assert f.dim < 0
    if -f.dim > len(target_batch_shape):
        target_batch_shape = [-1] * (-f.dim - len(target_batch_shape)) + target_batch_shape
    elif target_batch_shape[f.dim] not in (-1, f.size):
        raise ValueError("... dim collision ...")
    target_batch_shape[f.dim] = f.size
msg["fn"] = msg["fn"].expand(target_batch_shape)

@neerajprad (Member Author) commented May 7, 2018:

The problem I was facing was with expanding smaller-sized latent sites when the iarange dims appear in a staggered fashion, which cuts off the broadcasting starting from the -1 index:

with pyro.iarange("num_particles", 10, dim=-3):
    with pyro.iarange("components", 2, dim=-1):
         # with .expand([10, -1, 2]), we get s.shape == (10,)
         # with .expand([10, 1, 2]), we get s.shape == (10, 1, 2)
         s = pyro.sample("sample", dist.Bernoulli(0.5))    
    with pyro.iarange("data", 100, dim=-2):
         # Note that we need s to have shape (10, 1, 2) here to correctly 
         # expand to (10, 100, 2)
        ...
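The failure mode can be illustrated with a toy model of the shape rules `torch.Tensor.expand()` follows (to my understanding: -1 preserves an existing dim but is rejected in a new leading dim, and an existing dim may only grow if it is 1). `simulate_expand` below is a pure-Python sketch under that assumption, not PyTorch itself:

```python
def simulate_expand(batch_shape, target):
    """Toy model of torch.Tensor.expand() shape rules (an assumption, not
    PyTorch itself): -1 keeps an existing dim, but is not allowed for new
    leading dims; an existing dim may only grow if it is 1."""
    new = len(target) - len(batch_shape)
    if new < 0:
        raise ValueError("target has fewer dims than the tensor")
    out = []
    for i, size in enumerate(target):
        if i < new:  # leading, non-existing dimension
            if size == -1:
                raise ValueError(
                    "size -1 not allowed in a leading, non-existing dimension")
            out.append(size)
        else:
            existing = batch_shape[i - new]
            if size == -1:
                out.append(existing)
            elif existing in (1, size):
                out.append(size)
            else:
                raise ValueError(
                    "cannot expand dim of size {} to {}".format(existing, size))
    return tuple(out)

print(simulate_expand((), (10, 1, 2)))            # (10, 1, 2): explicit singleton works
print(simulate_expand((10, 1, 2), (10, 100, 2)))  # (10, 100, 2)
try:
    simulate_expand((), (10, -1, 2))              # -1 in a new dim is rejected
except ValueError as e:
    print("error:", e)
```

This is why the target shape needs explicit singleton dims (10, 1, 2) rather than (10, -1, 2) before the data iarange broadcasts over dim -2.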

@fritzo (Member) commented May 7, 2018:

Hmm, the snippet I suggested should be order-invariant. Note that the final .expand() statement is outside the loop. Am I missing something?

@neerajprad (Member Author):

Just updated the snippet to clarify. The problem arises at the sample site in the second iarange, where the outermost iarange is at dim -3 and the inner components one is at -1 (the data iarange is yet to come). If we expand s as `dist.Bernoulli(0.5).expand([10, -1, 2])` (note the default -1), it gives us `s.shape == torch.Size((10,))`, whereas we want `torch.Size((10, 1, 2))` so that it can be correctly broadcast by the data iarange.

@fritzo (Member) commented May 7, 2018:

Thanks for explaining; I didn't know that -1 cannot be used for new dimensions in torch.expand(). Could we still try to catch expand errors early, while the iarange name is still around? Maybe:

target_batch_shape = [None if size == 1 else size for size in actual_batch_shape]
for f in msg["cond_indep_stack"]:
    if f.dim is None:
        continue
    assert f.dim < 0
    target_batch_shape = [None] * (-f.dim - len(target_batch_shape)) + target_batch_shape
    if target_batch_shape[f.dim] not in (None, f.size):
        raise ValueError("Shape mismatch inside iarange('{}') at site {} dim {}, {} vs {}".format(
            f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
    target_batch_shape[f.dim] = f.size
# ... remainder of your code ...
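For reference, a runnable pure-Python version of this suggestion (with `Frame` as a hypothetical stand-in for Pyro's `CondIndepStackFrame`, and the trailing expansion filled in) raises early with the iarange name on a dim collision and still produces the desired shape on the staggered-dim example:

```python
from collections import namedtuple

# Hypothetical stand-in for Pyro's CondIndepStackFrame.
Frame = namedtuple("Frame", ["name", "dim", "size"])

def checked_target_shape(site_name, actual_batch_shape, cond_indep_stack):
    """Validate iarange dims against a site's batch shape, raising early
    while the iarange names are still available for the error message."""
    target = [None if size == 1 else size for size in actual_batch_shape]
    for f in cond_indep_stack:
        if f.dim is None:
            continue
        assert f.dim < 0
        target = [None] * (-f.dim - len(target)) + target
        if target[f.dim] not in (None, f.size):
            raise ValueError(
                "Shape mismatch inside iarange('{}') at site {} dim {}, {} vs {}"
                .format(f.name, site_name, f.dim, f.size, target[f.dim]))
        target[f.dim] = f.size
    # Remaining None entries become singleton dims, ready for broadcasting.
    return tuple(1 if s is None else s for s in target)

frames = [Frame("num_particles", -3, 10), Frame("components", -1, 2)]
print(checked_target_shape("sample", (), frames))  # (10, 1, 2)

try:
    checked_target_shape("x", (5,), [Frame("data", -1, 100)])
except ValueError as e:
    print(e)  # Shape mismatch inside iarange('data') at site x dim -1, 100 vs 5
```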

@neerajprad (Member Author):

Good point; will update.

pyro.sample("nuisance_b", Normal(2, 3))
pyro.sample("nuisance_a", Normal(0, 1))

optim = Adam({"lr": 0.1})
model, guide = poutine.broadcast(model), poutine.broadcast(guide)

nit: I think it's a little clearer to decorate

@poutine.broadcast
def model():
    ...

@poutine.broadcast
def guide():
    ...

That way readers know as soon as they start reading the model that it should be read with broadcast semantics.

@neerajprad (Member Author):

This is pretty neat!

@eb8680 (Member) left a comment:

This turned out nice - you should advertise it a little! Maybe add a usage example to the broadcast docstring and a note in the last section of our tensor shape tutorial?

def guide():
    with pyro.iarange("num_particles", 10, dim=-3):
        with pyro.iarange("components", 2, dim=-1):
            pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))

Maybe put this or something similar in the handlers.broadcast docstring as a usage example?

@neerajprad (Member Author):

Will update.

@neerajprad (Member Author):

> Maybe add a usage example to the broadcast docstring and a note in the last section of our tensor shape tutorial?

Will add to the docstring. I wasn't sure whether we should add this to the tutorial yet, because things may change or we may uncover edge cases as we discuss and expand our broadcasting semantics to handle different use cases. Let me create a task to update the tutorial once things have stabilized.

@eb8680 (Member) left a comment:

LGTM

@neerajprad neerajprad removed the WIP label May 8, 2018
@neerajprad (Member Author):

This should be good to merge, unless there are further comments.

@fritzo fritzo merged commit 413e2f2 into pyro-ppl:dev May 9, 2018
neerajprad added a commit to neerajprad/pyro that referenced this pull request May 17, 2018
4 participants