Introducing PyroModuleList, because torch.nn.ModuleList reinitializes modules when slice-indexing #3339

Merged: 8 commits, Mar 17, 2024
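
For context, a minimal sketch of the failure mode this PR fixes (hypothetical repro, not part of the diff; layer sizes are arbitrary). Slice-indexing a nested PyroModule[torch.nn.ModuleList] used to call self.__class__ on the slice, which re-ran PyroModule.__init__ outside the parent module context and clobbered the children's _pyro_name:

import torch
from pyro.nn import PyroModule

# two levels of nesting: a ModuleList of ModuleLists of Linear layers
inner = PyroModule[torch.nn.ModuleList](
    [PyroModule[torch.nn.Linear](2, 2) for _ in range(3)]
)
outer = PyroModule[torch.nn.ModuleList]([inner])

# before this PR, the slice below constructed self.__class__(...), which
# re-initialized the sliced submodules' names and could make sample site
# names collide during inference
head = outer[0][:-1]
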
9 changes: 8 additions & 1 deletion pyro/nn/__init__.py
@@ -9,7 +9,13 @@
MaskedLinear,
)
from pyro.nn.dense_nn import ConditionalDenseNN, DenseNN
from pyro.nn.module import PyroModule, PyroParam, PyroSample, pyro_method
from pyro.nn.module import (
PyroModule,
PyroModuleList,
PyroParam,
PyroSample,
pyro_method,
)

__all__ = [
"AutoRegressiveNN",
@@ -21,4 +27,5 @@
"PyroParam",
"PyroSample",
"pyro_method",
"PyroModuleList",
]
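
With this export, the new class can be imported straight from pyro.nn (a small usage sketch, not part of the diff; layer sizes are arbitrary):

import torch
from pyro.nn import PyroModule, PyroModuleList

layers = PyroModuleList(
    [PyroModule[torch.nn.Linear](4, 4) for _ in range(2)]
)
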
43 changes: 43 additions & 0 deletions pyro/nn/module.py
@@ -14,7 +14,21 @@
"""
import functools
import inspect
import warnings
import weakref

try:
from torch._jit_internal import _copy_to_script_wrapper
except ImportError:
warnings.warn(
"Cannot find torch._jit_internal._copy_to_script_wrapper", ImportWarning
)

# Fall back to trivial decorator.
def _copy_to_script_wrapper(fn):
return fn


from collections import OrderedDict
from dataclasses import dataclass
from types import TracebackType
@@ -902,3 +916,32 @@ def __set__(self, obj: object, value: Any) -> None:


PyroModule[torch.nn.RNNBase]._flat_weights = _FlatWeightsDescriptor() # type: ignore[attr-defined]


# Using pyro.nn.PyroModule[torch.nn.ModuleList] can cause issues when
# slice-indexing nested PyroModuleLists, so we define a separate
# PyroModuleList class that overrides __getitem__ to return a plain
# torch.nn.ModuleList instead of self.__class__. Calling self.__class__
# would invoke PyroModule.__init__ without the parent module context,
# losing the parent module's _pyro_name and eventually raising errors
# during sampling because parameter names are no longer unique.
# The scenario is rare but has happened. The fix could not be applied
# in torch directly, which is why we handle it here; see
# https://github.com/pytorch/pytorch/issues/121008
class PyroModuleList(torch.nn.ModuleList, PyroModule):
def __init__(self, modules):
super().__init__(modules)

@_copy_to_script_wrapper
def __getitem__(
self, idx: Union[int, slice]
) -> Union[torch.nn.Module, "PyroModuleList"]:
if isinstance(idx, slice):
# return self.__class__(list(self._modules.values())[idx])
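# returning a plain torch.nn.ModuleList here avoids re-running
# PyroModule.__init__ outside the parent context (see note above)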
return torch.nn.ModuleList(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]


_PyroModuleMeta._pyro_mixin_cache[torch.nn.ModuleList] = PyroModuleList
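
Since the mixin cache now maps torch.nn.ModuleList to PyroModuleList, existing code that spells PyroModule[torch.nn.ModuleList] picks up the fix automatically. A quick sketch of the resulting behavior (assumed usage, not part of the diff):

import torch
from pyro.nn import PyroModule, PyroModuleList

mlist = PyroModule[torch.nn.ModuleList](
    [PyroModule[torch.nn.Linear](2, 2) for _ in range(3)]
)
assert type(mlist) is PyroModuleList  # resolved via the mixin cache
assert type(mlist[:2]) is torch.nn.ModuleList  # slices are plain ModuleLists
assert isinstance(mlist[0], torch.nn.Linear)  # integer indexing is unchanged
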
201 changes: 201 additions & 0 deletions tests/nn/test_module.py
@@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import io
import math
import warnings
from typing import Callable, Iterable, Sequence

import pytest
import torch
@@ -13,6 +15,7 @@
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide.guides import AutoDiagonalNormal
from pyro.nn.module import PyroModule, PyroParam, PyroSample, clear, to_pyro_module_
from pyro.optim import Adam
from tests.common import assert_equal, xfail_param
@@ -844,3 +847,201 @@ def forward(self, x, y):
grad_params_func[k], torch.zeros_like(grad_params_func[k])
), k
assert torch.allclose(grad_params_autograd[k], grad_params_func[k]), k


class BNN(PyroModule):
# a vanilla Bayesian neural network implementation; nothing new or exciting here
def __init__(
self,
input_size: int,
hidden_layer_sizes: Sequence[int],
output_size: int,
use_new_module_list_type: bool,
) -> None:
super().__init__()

layer_sizes = (
[(input_size, hidden_layer_sizes[0])]
+ list(zip(hidden_layer_sizes[:-1], hidden_layer_sizes[1:]))
+ [(hidden_layer_sizes[-1], output_size)]
)

layers = [
pyro.nn.module.PyroModule[torch.nn.Linear](in_size, out_size)
for in_size, out_size in layer_sizes
]
if use_new_module_list_type:
self.layers = pyro.nn.module.PyroModuleList(layers)
else:
self.layers = pyro.nn.module.PyroModule[torch.nn.ModuleList](layers)

# make the layers Bayesian
for layer_idx, layer in enumerate(self.layers):
layer.weight = pyro.nn.module.PyroSample(
dist.Normal(0.0, 5.0 * math.sqrt(2 / layer_sizes[layer_idx][0]))
.expand(
[
layer_sizes[layer_idx][1],
layer_sizes[layer_idx][0],
]
)
.to_event(2)
)
layer.bias = pyro.nn.module.PyroSample(
dist.Normal(0.0, 5.0).expand([layer_sizes[layer_idx][1]]).to_event(1)
)

self.activation = torch.nn.Tanh()
self.output_size = output_size

def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
mean = self.layers[-1](x)

if obs is not None:
with pyro.plate("data", x.shape[0]):
pyro.sample(
"obs", dist.Normal(mean, 0.1).to_event(self.output_size), obs=obs
)

return mean


class SliceIndexingModuleListBNN(BNN):
# I claim that it makes a difference whether slice-indexing or position-indexing
# is used when sub-PyroModules are wrapped in a PyroModule[torch.nn.ModuleList]
def __init__(
self,
input_size: int,
hidden_layer_sizes: Sequence[int],
output_size: int,
use_new_module_list_type: bool,
) -> None:
super().__init__(
input_size, hidden_layer_sizes, output_size, use_new_module_list_type
)

def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
for layer in self.layers[:-1]:
x = layer(x)
x = self.activation(x)

return super().forward(x, obs=obs)


class PositionIndexingModuleListBNN(BNN):
# I claim that it makes a difference whether slice-indexing or position-indexing
# is used when sub-PyroModules are wrapped in a PyroModule[torch.nn.ModuleList]
def __init__(
self,
input_size: int,
hidden_layer_sizes: Sequence[int],
output_size: int,
use_new_module_list_type: bool,
) -> None:
super().__init__(
input_size, hidden_layer_sizes, output_size, use_new_module_list_type
)

def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
for i in range(len(self.layers) - 1):
x = self.layers[i](x)
x = self.activation(x)

return super().forward(x, obs=obs)


class NestedBNN(pyro.nn.module.PyroModule):
# finally, the issue occurs only at the second "layer of nesting",
# i.e. when a PyroModule[ModuleList] is wrapped in another PyroModule[ModuleList]
def __init__(self, bnns: Iterable[BNN], use_new_module_list_type: bool) -> None:
super().__init__()
if use_new_module_list_type:
self.bnns = pyro.nn.module.PyroModuleList(bnns)
else:
self.bnns = pyro.nn.module.PyroModule[torch.nn.ModuleList](bnns)

def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
mean = sum([bnn(x) for bnn in self.bnns]) / len(self.bnns)

with pyro.plate("data", x.shape[0]):
pyro.sample("obs", dist.Normal(mean, 0.1).to_event(1), obs=obs)

return mean


def train_bnn(model: BNN, input_size: int) -> None:
pyro.clear_param_store()

# small numbers for demo purposes
num_points = 20
num_svi_iterations = 100

x = torch.linspace(0, 1, num_points).reshape((-1, input_size))
y = torch.sin(2 * math.pi * x) + torch.randn(x.size()) * 0.1

guide = AutoDiagonalNormal(model)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

for _ in range(num_svi_iterations):
svi.step(x, y)


class ModuleListTester:
def setup(self, use_new_module_list_type: bool) -> None:
self.input_size = 1
self.output_size = 1
self.hidden_size = 3
self.num_hidden_layers = 3
self.use_new_module_list_type = use_new_module_list_type

def get_position_indexing_modulelist_bnn(self) -> PositionIndexingModuleListBNN:
return PositionIndexingModuleListBNN(
self.input_size,
[self.hidden_size] * self.num_hidden_layers,
self.output_size,
self.use_new_module_list_type,
)

def get_slice_indexing_modulelist_bnn(self) -> SliceIndexingModuleListBNN:
return SliceIndexingModuleListBNN(
self.input_size,
[self.hidden_size] * self.num_hidden_layers,
self.output_size,
self.use_new_module_list_type,
)

def train_nested_bnn(self, module_getter: Callable[[], BNN]) -> None:
train_bnn(
NestedBNN(
[module_getter() for _ in range(2)],
use_new_module_list_type=self.use_new_module_list_type,
),
self.input_size,
)


class TestTorchModuleList(ModuleListTester):
def test_with_position_indexing(self) -> None:
self.setup(False)
self.train_nested_bnn(self.get_position_indexing_modulelist_bnn)

def test_with_slice_indexing(self) -> None:
self.setup(False)
# with pytest.raises(RuntimeError):
# the RuntimeError is no longer raised, since PyroModule[torch.nn.ModuleList]
# now resolves to PyroModuleList
self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn)


class TestPyroModuleList(ModuleListTester):
def test_with_position_indexing(self) -> None:
self.setup(True)
self.train_nested_bnn(self.get_position_indexing_modulelist_bnn)

def test_with_slice_indexing(self) -> None:
self.setup(True)
self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn)


def test_module_list() -> None:
assert PyroModule[torch.nn.ModuleList] is pyro.nn.PyroModuleList