
RenyiELBO no longer compatible with sequential plates? [question] [documentation] #2361

Closed · sharrison5 opened this issue Mar 10, 2020 · 6 comments

sharrison5 (Contributor) commented Mar 10, 2020

Issue Description

In Pyro 1.3.0, the behaviour of RenyiELBO changed. In particular, it now seems to be incompatible with sequential plates. See MWE below.

Questions:

  • Is there a way to keep using sequential plates? The plate I'm using can't naturally be replaced by a vectorised version (unlike in the example).
  • If a model/guide contains plates, is it safe to use RenyiELBO in Pyro < 1.3.0 (i.e. before the changes in #2321, “Support plates in RenyiELBO”)?

Many thanks! 😄

Code Snippet

import sys, torch, pyro
print(sys.version, torch.__version__, pyro.__version__)

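# Raises under RenyiELBO (works with TraceGraph_ELBO): sequential plate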
def model(data):
    x = pyro.sample('x', pyro.distributions.Bernoulli(torch.tensor(0.5)))
    for i in pyro.plate('data_plate', len(data)):
        pyro.sample('data_{:d}'.format(i), pyro.distributions.Normal(x, scale=torch.tensor(0.1)), obs=data[i])

# Works as expected
# def model(data):
#     x = pyro.sample('x', pyro.distributions.Bernoulli(torch.tensor(0.5)))
#     with pyro.plate('data_plate', len(data)):
#         pyro.sample('data', pyro.distributions.Normal(x, scale=torch.tensor(0.1)), obs=data)

def guide(data):
    p = pyro.param('p', torch.tensor(0.5))
    x = pyro.sample('x', pyro.distributions.Bernoulli(p))

elbo = pyro.infer.TraceGraph_ELBO()
print(elbo.loss(model, guide, torch.randn([3])))

elbo = pyro.infer.RenyiELBO()
print(elbo.loss(model, guide, torch.randn([3])))

Output:

3.8.1 (default, Jan  8 2020, 16:15:59) 
[Clang 4.0.1 (tags/RELEASE_401/final)] 1.4.0 1.3.0
2.4235687255859375
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-6865281f40fd> in <module>
     15 
     16 elbo = pyro.infer.RenyiELBO()
---> 17 print(elbo.loss(model, guide, torch.randn([3])))

~/miniconda3/envs/pDCM-dev/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_no_grad(*args, **kwargs)
     47         def decorate_no_grad(*args, **kwargs):
     48             with self:
---> 49                 return func(*args, **kwargs)
     50         return decorate_no_grad
     51 

~/miniconda3/envs/pDCM-dev/lib/python3.8/site-packages/pyro/infer/renyi_elbo.py in loss(self, model, guide, *args, **kwargs)
     96         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
     97             elbo_particle = 0.
---> 98             sum_dims = get_dependent_plate_dims(model_trace.nodes.values())
     99 
    100             # compute elbo

~/miniconda3/envs/pDCM-dev/lib/python3.8/site-packages/pyro/infer/util.py in get_dependent_plate_dims(sites)
    114     common_plates = all_plates.intersection(*plate_sets)
    115     sum_plates = all_plates - common_plates
--> 116     sum_dims = list(sorted(f.dim for f in sum_plates))
    117     return sum_dims
    118 

TypeError: '<' not supported between instances of 'NoneType' and 'NoneType'
fritzo added the bug label Mar 10, 2020
iffsid (Contributor) commented Jun 26, 2020

@fritzo I believe this is a valid fix?

    sum_dims = list(sorted(f.dim for f in sum_plates if f.dim is not None))

https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/util.py#L116
The issue appears to be that the cond_indep_stack does not have dims when the plate is unrolled manually.
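For illustration, here is a minimal sketch (reusing the model from the MWE above) that traces the model and prints each sample site's plate frames; sequential plates record frames with dim=None, which is what makes the sorted call fail:

import torch, pyro
import pyro.poutine as poutine

def model(data):
    x = pyro.sample('x', pyro.distributions.Bernoulli(torch.tensor(0.5)))
    for i in pyro.plate('data_plate', len(data)):
        pyro.sample('data_{:d}'.format(i),
                    pyro.distributions.Normal(x, scale=torch.tensor(0.1)),
                    obs=data[i])

trace = poutine.trace(model).get_trace(torch.randn([3]))
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        # Sequential plates yield CondIndepStackFrame(..., dim=None), so
        # `sorted(f.dim for f in sum_plates)` ends up comparing None < None.
        print(name, [(f.name, f.dim) for f in site["cond_indep_stack"]])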

I've just got around to putting the plating fix (#2321) into rws, and this fix works for a simple sequential-plate situation like the one above.

Any reason why this might be incorrect?

fehiepsi (Member) commented:

@iffsid Good catch! Thanks for debugging this issue! Could you submit a PR so that we can see if all other tests pass with the change?

iffsid added a commit to iffsid/pyro that referenced this issue Jun 27, 2020
fritzo pushed a commit that referenced this issue Jun 28, 2020
* fix summation with plates when manually unrolling

issue: #2361

* add test for manual unrolling of plates

* drop unnecessary assertions and iterations

 - only need to ensure that the model runs without crashing
fritzo (Member) commented Jun 28, 2020

Fixed by #2541

fritzo closed this as completed Jun 28, 2020
sharrison5 (Contributor, Author) commented:

Thanks for addressing this in 1.4.0! However, I'm seeing another issue with RenyiELBO and non-nested plates (see MWE below). My hunch is that it arises because x_plate and data_plate_1 share the same dim, so sum_dims can contain repeated elements; that said, the error could well be coming from elsewhere.

Thanks again! 😃

Code snippet:

import sys, torch, pyro

print(sys.version, torch.__version__, pyro.__version__)

# Generative model: data = x @ weights + eps
def model(data, weights):
    loc = torch.tensor(1.0)
    scale = torch.tensor(0.1)

    # Sample latents (shares no dimensions with data)
    with pyro.plate('x_plate', weights.shape[0]):
        x = pyro.sample('x', pyro.distributions.Normal(loc, scale))

    # Combine with weights and sample
    with pyro.plate('data_plate_1', data.shape[-1]):
        with pyro.plate('data_plate_2', data.shape[-2]):
            pyro.sample('data', pyro.distributions.Normal(x @ weights, scale), obs=data)

    return

def guide(data, weights):
    loc = pyro.param('x_loc', torch.tensor(0.5))
    scale = torch.tensor(0.1)

    with pyro.plate('x_plate', weights.shape[0]):
        x = pyro.sample('x', pyro.distributions.Normal(loc, scale))

    return

# Works with other ELBO
loss = pyro.infer.TraceGraph_ELBO(num_particles=3, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), loss)
svi.step(torch.randn([5, 3]), torch.randn([2, 3]))

# Dim collision with RenyiELBO
loss = pyro.infer.RenyiELBO(num_particles=7, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), loss)
svi.step(torch.randn([5, 3]), torch.randn([2, 3]))

Output:

3.8.5 | packaged by conda-forge | (default, Jul 24 2020, 01:06:20) 
[Clang 10.0.1 ] 1.5.1 1.4.0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/Code/dPNM/pyro_test.py in <module>
     40 loss = pyro.infer.RenyiELBO(num_particles=7, vectorize_particles=True)
     41 svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), loss)
---> 42 svi.step(torch.randn([5, 3]), torch.randn([2, 3]))

~/miniconda3/envs/dPNM-dev/lib/python3.8/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    126         # get loss and compute gradients
    127         with poutine.trace(param_only=True) as param_capture:
--> 128             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    129 
    130         params = set(site["value"].unconstrained()

~/miniconda3/envs/dPNM-dev/lib/python3.8/site-packages/pyro/infer/renyi_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    207             surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles
    208             surrogate_loss = -surrogate_elbo
--> 209             surrogate_loss.backward()
    210         loss = -elbo
    211         warn_if_nan(loss, "loss")

~/miniconda3/envs/dPNM-dev/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

~/miniconda3/envs/dPNM-dev/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     96         retain_graph = create_graph
     97 
---> 98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
    100         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: dim 2 appears multiple times in the list of dims (dim_list_to_bitset at /Users/distiller/project/conda/conda-bld/pytorch_1591914893314/work/aten/src/ATen/WrapDimUtilsMulti.h:20)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x11996cab7 in libc10.dylib)
frame #1: at::dim_list_to_bitset(c10::ArrayRef<long long>, long long) + 492 (0x11e897bdc in libtorch_cpu.dylib)
frame #2: torch::autograd::generated::(anonymous namespace)::unsqueeze_multiple(at::Tensor const&, c10::ArrayRef<long long>, unsigned long) + 40 (0x11e8977f8 in libtorch_cpu.dylib)
frame #3: torch::autograd::generated::(anonymous namespace)::sum_backward(at::Tensor const&, c10::ArrayRef<long long>, c10::ArrayRef<long long>, bool) + 442 (0x11e7d236a in libtorch_cpu.dylib)
frame #4: torch::autograd::generated::SumBackward1::apply(std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> >&&) + 121 (0x11e7ff389 in libtorch_cpu.dylib)
frame #5: torch::autograd::Node::operator()(std::__1::vector<at::Tensor, std::__1::allocator<at::Tensor> >&&) + 658 (0x11efe37c2 in libtorch_cpu.dylib)
frame #6: torch::autograd::Engine::evaluate_function(std::__1::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 1408 (0x11efd9f90 in libtorch_cpu.dylib)
frame #7: torch::autograd::Engine::thread_main(std::__1::shared_ptr<torch::autograd::GraphTask> const&, bool) + 497 (0x11efd91a1 in libtorch_cpu.dylib)
frame #8: torch::autograd::Engine::thread_init(int) + 152 (0x11efd8f38 in libtorch_cpu.dylib)
frame #9: torch::autograd::python::PythonEngine::thread_init(int) + 52 (0x11bb26ee4 in libtorch_python.dylib)
frame #10: void* std::__1::__thread_proxy<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct> >, void (torch::autograd::Engine::*)(int), torch::autograd::Engine*, int> >(void*) + 66 (0x11efe8522 in libtorch_cpu.dylib)
frame #11: _pthread_body + 126 (0x7fff769bc2eb in libsystem_pthread.dylib)
frame #12: _pthread_start + 66 (0x7fff769bf249 in libsystem_pthread.dylib)
frame #13: thread_start + 13 (0x7fff769bb40d in libsystem_pthread.dylib)
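
The RuntimeError above can be reproduced in plain PyTorch (a minimal sketch; depending on the PyTorch version, either the forward sum or the backward pass rejects the duplicate dim):

import torch

x = torch.randn(7, 5, 3, requires_grad=True)
# A repeated entry in the dim list mirrors two plates sharing a dim:
y = x.sum(dim=[-1, -1])
y.sum().backward()  # RuntimeError: dim 2 appears multiple times ...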

fritzo reopened this Jul 31, 2020
sharrison5 added two commits to sharrison5/pyro that referenced this issue Aug 5, 2020
fritzo pushed a commit that referenced this issue Aug 6, 2020
* Fix summation over non-nested plates sharing dimensions

From issue #2361. Remove repeated dimensions in
`get_dependent_plate_dims()`, which ensures the output can be passed
to e.g. `torch.sum()`.
Minor change: `list(sorted())` simplified to `sorted()`, since `sorted`
already returns a `list`.

* Add regression test

Based on example in #2361, and follows format of similar test introduced
by #2541.

Co-authored-by: Sam Harrison <harrison@biomed.ee.ethz.ch>
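
Taken together, the two fixes amount to roughly the following version of the helper (a sketch reconstructed from the commit messages and the traceback above, not necessarily the exact upstream code):

def get_dependent_plate_dims(sites):
    """Return sorted dims of vectorized plates not shared by all sample sites."""
    plate_sets = [site["cond_indep_stack"]
                  for site in sites if site["type"] == "sample"]
    all_plates = set().union(*plate_sets)
    common_plates = all_plates.intersection(*plate_sets)
    sum_plates = all_plates - common_plates
    # Skip sequential plates (dim=None, the #2541 fix) and use a set to drop
    # dims repeated across non-nested plates (the Aug 6 fix); `sorted`
    # already returns a list, so the `list(...)` wrapper is unnecessary.
    return sorted({f.dim for f in sum_plates if f.dim is not None})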
fritzo (Member) commented Aug 6, 2020

Is this now resolved?

sharrison5 (Contributor, Author) commented:

Yes, I think so

fritzo closed this as completed Aug 7, 2020