[REF] Bump torch version to 1.9.0 #226

Merged: 12 commits, Oct 8, 2021
7 changes: 1 addition & 6 deletions .github/workflows/test.yaml
@@ -20,12 +20,7 @@ jobs:
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
pytorch-version: [1.6.0, 1.7.1, 1.8.0, 1.9.0]
exclude:
- pytorch-version: 1.6.0
python-version: 3.9
- pytorch-version: 1.7.1
python-version: 3.9
pytorch-version: [1.9.0, 1.9.1]
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
22 changes: 3 additions & 19 deletions backpack/__init__.py
@@ -4,20 +4,16 @@
from typing import Callable, Optional, Tuple, Type, Union

from torch import Tensor, is_grad_enabled
from torch.fx import GraphModule
from torch.nn import Module

from backpack import extensions
from backpack.context import CTX
from backpack.custom_module.graph_utils import convert_module_to_backpack
from backpack.extensions.backprop_extension import BackpropExtension
from backpack.utils import CONVERTER_AVAILABLE, FULL_BACKWARD_HOOK
from backpack.utils.hooks import no_op
from backpack.utils.module_classification import is_no_op

if CONVERTER_AVAILABLE:
from torch.fx import GraphModule

from backpack.custom_module.graph_utils import convert_module_to_backpack


class backpack:
"""Context manager to activate BackPACK extensions."""
@@ -244,17 +240,11 @@ def extend(module: Module, debug: bool = False, use_converter: bool = False) ->

Returns:
Extended module.

Raises:
RuntimeError: if trying to use converter without torch>=1.9.0
"""
if debug:
print("[DEBUG] Extending", module)

if use_converter:
if not CONVERTER_AVAILABLE:
raise RuntimeError("use_converter=True is only available for torch>=1.9.0.")

module: GraphModule = convert_module_to_backpack(module, debug)
return extend(module)

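For context, a minimal usage sketch of the converter path above (not part of the diff; the model and shapes are made up): with the minimum torch version bumped to 1.9.0, `use_converter=True` no longer needs a version guard and always goes through `torch.fx`.

```python
from torch.nn import Linear, ReLU, Sequential

from backpack import extend

# Hypothetical model; any container traceable by torch.fx would do here.
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))

# With torch >= 1.9.0 the converter is always importable, so this call
# converts the model to a GraphModule and then extends it, without the
# former RuntimeError guard for older torch versions.
model = extend(model, use_converter=True)
print(type(model))  # a torch.fx GraphModule after conversion
```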
@@ -277,10 +267,4 @@ def _register_hooks(module: Module) -> None:
module: module that is going to be extended
"""
CTX.add_hook_handle(module.register_forward_hook(hook_store_io))

if FULL_BACKWARD_HOOK:
register_backward_hook_fn = module.register_full_backward_hook
else:
register_backward_hook_fn = module.register_backward_hook

CTX.add_hook_handle(register_backward_hook_fn(hook_run_extensions))
CTX.add_hook_handle(module.register_full_backward_hook(hook_run_extensions))
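The hook registration above now relies unconditionally on `register_full_backward_hook`, available since torch 1.8, instead of falling back to the deprecated `register_backward_hook`. A standalone sketch of that API (independent of BackPACK, with a made-up layer and shapes):

```python
import torch
from torch.nn import Linear

layer = Linear(3, 2)

def print_grad_shapes(module, grad_input, grad_output):
    # Full backward hooks receive well-defined grad_input/grad_output tuples.
    print([g.shape for g in grad_output])

handle = layer.register_full_backward_hook(print_grad_shapes)
layer(torch.rand(4, 3)).sum().backward()  # triggers the hook once
handle.remove()
```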
12 changes: 1 addition & 11 deletions backpack/core/derivatives/batchnorm_nd.py
@@ -5,7 +5,6 @@
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0
from backpack.utils.subsampling import subsample


@@ -148,16 +147,7 @@ def _weight_jac_t_mat_prod(
x_hat, _ = self._get_normalized_input_and_var(module)
x_hat = subsample(x_hat, subsampling=subsampling)

if TORCH_VERSION_AT_LEAST_1_9_0:
equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c"
# TODO Remove else-branch after deprecating torch<1.9.0
else:
N: int = self._get_n_axis(module)
spatial_dims = "xyz"[:N]
equation = (
f"vnc{spatial_dims},nc{spatial_dims}->v{'' if sum_batch else 'n'}c"
)

equation = f"vnc...,nc...->v{'' if sum_batch else 'n'}c"
return einsum(equation, mat, x_hat)

def _bias_jac_mat_prod(
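The single ellipsis equation above replaces the spelled-out spatial subscripts of the old branch. A small numerical check of that equivalence (made-up shapes for the BatchNorm2d case; relies on torch >= 1.9.0 summing over ellipsis dimensions that are omitted from the output, which is the behavior this hunk depends on):

```python
import torch

V, N, C, H, W = 2, 3, 4, 5, 6
mat = torch.rand(V, N, C, H, W)    # stacked vectors, one per leading index v
x_hat = torch.rand(N, C, H, W)     # normalized input of the BatchNorm layer

old = torch.einsum("vncxy,ncxy->vc", mat, x_hat)    # explicit spatial dims (2d only)
new = torch.einsum("vnc...,nc...->vc", mat, x_hat)  # dimension-agnostic, sum_batch=True

assert torch.allclose(old, new)
```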
31 changes: 0 additions & 31 deletions backpack/core/derivatives/elementwise.py
@@ -69,8 +69,6 @@ def hessian_diagonal(self, module, g_inp, g_out):
g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
"""
self._no_inplace(module)

return self.d2f(module, g_inp, g_out) * g_out[0]

def hessian_is_diagonal(self, module):
@@ -89,46 +87,17 @@ def _jac_t_mat_prod(
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
self._no_inplace(module)

df_elementwise = self.df(module, g_inp, g_out, subsampling=subsampling)
return einsum("...,v...->v...", df_elementwise, mat)

def _jac_mat_prod(self, module, g_inp, g_out, mat):
self._no_inplace(module)

return self.jac_t_mat_prod(module, g_inp, g_out, mat)

def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
self._no_inplace(module)

N = module.input0.size(0)
df_flat = self.df(module, g_inp, g_out).reshape(N, -1)
return einsum("ni,nj,ij->ij", df_flat, df_flat, mat) / N

def _residual_mat_prod(self, module, g_inp, g_out, mat):
residual = self.d2f(module, g_inp, g_out) * g_out[0]
return einsum("...,v...->v...", residual, mat)

# TODO Deprecate after supporting torch >= 1.8.0 and full_backward_hook
@staticmethod
def _no_inplace(module: Module):
"""Do not support inplace modification.

Jacobians/Hessians might be computed using the modified input instead
of the original.

Args:
module: Elementwise activation module.

Raises:
NotImplementedError: If `module` has inplace option enabled.

Todo:
- Write tests to investigate what happens with `inplace=True`.
"""
has_inplace_option = hasattr(module, "inplace")

if has_inplace_option:
if module.inplace is True:
raise NotImplementedError("Inplace not supported in {}.".format(module))
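The `_jac_t_mat_prod` above is just elementwise scaling of `mat` by the activation's first derivative, which is what the einsum in the hunk expresses. A standalone numerical sketch (made-up activation and shapes, not BackPACK code):

```python
import torch

x = torch.rand(3, 4)
mat = torch.rand(5, 3, 4)            # V = 5 stacked vectors

sig = torch.sigmoid(x)
df = sig * (1 - sig)                 # elementwise derivative of sigmoid at x

jac_t_mat = torch.einsum("...,v...->v...", df, mat)
assert jac_t_mat.shape == mat.shape            # same shape, scaled elementwise
assert torch.allclose(jac_t_mat, df * mat)     # einsum equals plain broadcasting
```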
10 changes: 1 addition & 9 deletions backpack/core/derivatives/embedding.py
@@ -5,7 +5,6 @@
from torch.nn import Embedding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0
from backpack.utils.subsampling import subsample


@@ -50,14 +49,7 @@ def _weight_jac_t_mat_prod(
delta = zeros(module.num_embeddings, *input0.shape, device=mat.device)
for s in range(module.num_embeddings):
delta[s] = input0 == s
if TORCH_VERSION_AT_LEAST_1_9_0:
equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh"
elif delta.dim() == 2:
equation = f"sn,vnh->v{'' if sum_batch else 'n'}sh"
else:
equation = f"snx,vnxh->v{'' if sum_batch else 'n'}sh"
delta = delta.flatten(start_dim=2, end_dim=-1)
mat = mat.flatten(start_dim=2, end_dim=-2)
equation = f"sn...,vn...h->v{'' if sum_batch else 'n'}sh"
return einsum(equation, delta, mat)

def _check_parameters(self, module: Embedding) -> None:
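The embedding hunk builds a one-hot indicator `delta` over the vocabulary and contracts it with `mat` through a single ellipsis equation, so the explicit flattening of the old branches is no longer needed. A shape-only sketch (made-up sizes, `sum_batch=True`; requires torch >= 1.9.0 for the omitted-ellipsis summation):

```python
import torch

num_embeddings, N, L, H, V = 6, 2, 3, 4, 5
input0 = torch.randint(num_embeddings, (N, L))   # token ids
mat = torch.rand(V, N, L, H)

delta = torch.zeros(num_embeddings, N, L)
for s in range(num_embeddings):
    delta[s] = input0 == s                       # indicator of token s

result = torch.einsum("sn...,vn...h->vsh", delta, mat)
print(result.shape)  # torch.Size([5, 6, 4]) = [V, num_embeddings, embedding_dim]
```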
29 changes: 1 addition & 28 deletions backpack/core/derivatives/linear.py
@@ -5,7 +5,6 @@
from torch.nn import Linear

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0
from backpack.utils.subsampling import subsample


@@ -152,17 +151,7 @@ def _weight_jac_t_mat_prod(
"""
d_weight = subsample(module.input0, subsampling=subsampling)

if TORCH_VERSION_AT_LEAST_1_9_0:
equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi"
# TODO Remove else-branch after deprecating torch<1.9.0
else:
if self._has_additional_dims(module):
d_weight = d_weight.flatten(start_dim=1, end_dim=-2)
mat = mat.flatten(start_dim=2, end_dim=-2)
equation = f"vnao,nai->v{'' if sum_batch else 'n'}oi"
else:
equation = f"vno,ni->v{'' if sum_batch else 'n'}oi"

equation = f"vn...o,n...i->v{'' if sum_batch else 'n'}oi"
return einsum(equation, mat, d_weight)

def _bias_jac_mat_prod(
@@ -223,22 +212,6 @@ def _bias_jac_t_mat_prod(
equation = f"vn...o->v{'' if sum_batch else 'n'}o"
return einsum(equation, mat)

# TODO Remove after deprecating torch<1.9.0
@classmethod
def _has_additional_dims(cls, module: Linear) -> bool:
"""Return whether the input to a linear layer has additional (>1) dimensions.

The input to a linear layer may have shape ``[N, *, out_features]``.
It has additional dimensions if ``*`` is non-empty.

Args:
module: Linear layer.

Returns:
Whether the input has hidden dimensions.
"""
return len(cls._get_additional_dims(module)) != 0

@staticmethod
def _get_additional_dims(module: Linear) -> Size:
"""Return the shape of additional dimensions in the input to a linear layer.
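Similarly for the linear layer, one ellipsis equation covers both a flat input `[N, in_features]` and inputs with additional dimensions `[N, *, in_features]`, which is why the `_has_additional_dims` helper removed above is no longer needed. A quick shape sketch (made-up sizes, `sum_batch=True`):

```python
import torch

V, N, A, I, O = 2, 3, 4, 5, 6
mat = torch.rand(V, N, A, O)      # backpropagated vectors for an input of shape [N, A, I]
d_weight = torch.rand(N, A, I)    # (sub)sampled layer input

grad_weight = torch.einsum("vn...o,n...i->voi", mat, d_weight)
print(grad_weight.shape)  # torch.Size([2, 6, 5]) = [V, out_features, in_features]

# The same equation also handles a flat input [N, I] with mat of shape [V, N, O].
flat = torch.einsum("vn...o,n...i->voi", torch.rand(V, N, O), torch.rand(N, I))
print(flat.shape)  # torch.Size([2, 6, 5])
```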
6 changes: 2 additions & 4 deletions backpack/core/derivatives/lstm.py
@@ -5,7 +5,6 @@
from torch.nn import LSTM

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION_AT_LEAST_1_8_0
from backpack.utils.subsampling import subsample


@@ -51,9 +50,8 @@ def _check_parameters(module: LSTM) -> None:
raise NotImplementedError("only dropout = 0 is supported")
if module.bidirectional is not False:
raise NotImplementedError("only bidirectional = False is supported")
if TORCH_VERSION_AT_LEAST_1_8_0:
if module.proj_size != 0:
raise NotImplementedError("only proj_size = 0 is supported")
if module.proj_size != 0:
raise NotImplementedError("only proj_size = 0 is supported")

@staticmethod
def _forward_pass(
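The `proj_size` check above can now run unconditionally because the attribute exists on `torch.nn.LSTM` for every supported torch version (it was added in 1.8). A tiny sketch of the configuration the check rejects (made-up sizes):

```python
from torch.nn import LSTM

lstm = LSTM(input_size=4, hidden_size=3, proj_size=2)  # projected LSTM
print(lstm.proj_size)  # 2 -> BackPACK's _check_parameters would raise NotImplementedError
```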
4 changes: 2 additions & 2 deletions backpack/core/derivatives/scale_module.py
@@ -1,4 +1,4 @@
"""Derivatives of ScaleModule (implies ActiveIdentity and Identity)."""
"""Derivatives of ScaleModule (implies Identity)."""
from typing import List, Tuple, Union

from torch import Tensor
@@ -9,7 +9,7 @@


class ScaleModuleDerivatives(BaseDerivatives):
"""Derivatives of ScaleModule (implies ActiveIdentity and Identity)."""
"""Derivatives of ScaleModule (implies Identity)."""

def _jac_t_mat_prod(
self,
10 changes: 0 additions & 10 deletions backpack/custom_module/branching.py
@@ -4,16 +4,6 @@
from torch import Tensor
from torch.nn import Module

from backpack.custom_module.scale_module import ScaleModule


class ActiveIdentity(ScaleModule):
"""Like ``torch.nn.Identity``, but creates a new node in the computation graph."""

def __init__(self):
"""Initialization with weight=1.0."""
super().__init__(weight=1.0)


class _Branch(Module):
"""Module used by BackPACK to handle branching in the computation graph.
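`ActiveIdentity` existed only to force an extra node into the autograd graph so that the old (non-full) backward hooks would fire on identity-like modules. A standalone sketch of the premise behind removing it, assuming full backward hooks also fire for modules that do not add a new node to the graph (made-up shapes, not BackPACK code):

```python
import torch
from torch.nn import Identity

identity = Identity()
identity.register_full_backward_hook(
    lambda module, g_in, g_out: print("hook fired:", g_out[0].shape)
)

x = torch.rand(2, 3, requires_grad=True)
identity(x).sum().backward()  # expected to print: hook fired: torch.Size([2, 3])
```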
15 changes: 2 additions & 13 deletions backpack/custom_module/graph_utils.py
@@ -6,11 +6,10 @@
from torch.fx import Graph, GraphModule, Node, Tracer
from torch.nn import LSTM, RNN, Dropout, Flatten, Module, Sequential

from backpack.custom_module.branching import ActiveIdentity, SumModule, _Branch
from backpack.custom_module.branching import SumModule, _Branch
from backpack.custom_module.permute import Permute
from backpack.custom_module.reduce_tuple import ReduceTuple
from backpack.custom_module.scale_module import ScaleModule
from backpack.utils import TORCH_VERSION_AT_LEAST_1_9_0


class BackpackTracer(Tracer):
@@ -19,9 +18,7 @@ class BackpackTracer:
def is_leaf_module(
self, m: Module, module_qualified_name: str
) -> bool: # noqa: D102
if isinstance(
m, (ScaleModule, SumModule, _Branch, ActiveIdentity, ReduceTuple, Permute)
):
if isinstance(m, (ScaleModule, SumModule, _Branch, ReduceTuple, Permute)):
return True
else:
return super().is_leaf_module(m, module_qualified_name)
@@ -49,15 +46,7 @@ def convert_module_to_backpack(module: Module, debug: bool) -> GraphModule:

Returns:
BackPACK-compatible module

Raises:
NotImplementedError: if not torch >= 1.9.0
"""
if TORCH_VERSION_AT_LEAST_1_9_0 is False:
raise NotImplementedError(
"Conversion is only possible for torch >= 1.9.0. This is because these "
"functions use functionality such as torch.nn.Module.get_submodule"
)
if debug:
print("\nMake module BackPACK-compatible...")
module_new = _transform_mul_to_scale_module(module, debug)
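For context on the `BackpackTracer` change: `is_leaf_module` tells `torch.fx` which modules to keep as single `call_module` nodes instead of tracing into their `forward`. A self-contained sketch with a hypothetical custom module (not BackPACK's own classes):

```python
from torch.fx import GraphModule, Tracer
from torch.nn import Linear, Module, Sequential


class ScaleByTwo(Module):
    """Hypothetical custom module that should stay opaque to the tracer."""

    def forward(self, x):
        return 2.0 * x


class MyTracer(Tracer):
    def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool:
        # Treat the custom module as a leaf; defer to the default rule otherwise.
        if isinstance(m, ScaleByTwo):
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = Sequential(Linear(3, 3), ScaleByTwo())
graph = MyTracer().trace(model)
print(GraphModule(model, graph).graph)  # ScaleByTwo appears as one call_module node
```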
24 changes: 5 additions & 19 deletions backpack/extensions/module_extension.py
@@ -5,9 +5,8 @@
from warnings import warn

from torch import Tensor
from torch.nn import Flatten, Module
from torch.nn import Module

from backpack.utils import FULL_BACKWARD_HOOK
from backpack.utils.module_classification import is_loss

if TYPE_CHECKING:
@@ -104,23 +103,10 @@ def __call__(
and bp_quantity is None
and not is_loss(module)
):
if not FULL_BACKWARD_HOOK and isinstance(module, Flatten):
# Flatten layers whose input is already flat do not add a node to the
# graph. This leads to unintuitive order of backward hook execution:
# https://discuss.pytorch.org/t/backward-hooks-changing-order-of-execution-in-nn-sequential/12447/4. # noqa: B950
# Skip everything below if this scenario is encountered.
no_op = module.input0.shape == module.output.shape
if not no_op:
raise AssertionError(
"Expected no op Flatten module. Got "
+ f"{module.input0.shape} -> {module.output.shape}"
)
return
else:
raise AssertionError(
"BackPACK extension expects a backpropagation quantity but it is None. "
f"Module: {module}, Extension: {extension}."
)
raise AssertionError(
"BackPACK extension expects a backpropagation quantity but it is None. "
f"Module: {module}, Extension: {extension}."
)

for param in self.__params:
if self.__param_exists_and_requires_grad(module, param):
4 changes: 1 addition & 3 deletions backpack/extensions/saved_quantities.py
@@ -38,9 +38,7 @@ def retrieve_quantity(self, key: int, delete_old: bool) -> Union[Tensor, None]:
"""Returns the saved quantity.

Args:
key: data_ptr() of reference tensor.
For torch>=1.9.0 the reference tensor is grad_output[0].
For older versions the reference tensor is module.output.
key: data_ptr() of module.output.
delete_old: whether to delete the old quantity

Returns:
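The docstring above keys saved quantities by `Tensor.data_ptr()` of the reference tensor. A minimal sketch of that lookup pattern (a toy dictionary, not BackPACK's class itself):

```python
import torch

saved = {}

output = torch.rand(2, 3)
saved[output.data_ptr()] = output.sum()   # store under the tensor's storage pointer

retrieved = saved.pop(output.data_ptr())  # analogous to retrieve_quantity with delete_old=True
print(retrieved)
```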
4 changes: 1 addition & 3 deletions backpack/extensions/secondorder/diag_ggn/__init__.py
@@ -47,7 +47,7 @@
ZeroPad2d,
)

from backpack.custom_module.branching import ActiveIdentity, SumModule
from backpack.custom_module.branching import SumModule
from backpack.custom_module.permute import Permute
from backpack.custom_module.scale_module import ScaleModule
from backpack.extensions.secondorder.base import SecondOrderBackpropExtension
@@ -131,7 +131,6 @@ def __init__(self, loss_hessian_strategy: str, savefield: str):
ELU: activations.DiagGGNELU(),
SELU: activations.DiagGGNSELU(),
Identity: custom_module.DiagGGNScaleModule(),
ActiveIdentity: custom_module.DiagGGNScaleModule(),
ScaleModule: custom_module.DiagGGNScaleModule(),
SumModule: custom_module.DiagGGNSumModule(),
RNN: rnn.DiagGGNRNN(),
@@ -255,7 +254,6 @@ def __init__(self, loss_hessian_strategy: str, savefield: str):
ELU: activations.DiagGGNELU(),
SELU: activations.DiagGGNSELU(),
Identity: custom_module.DiagGGNScaleModule(),
ActiveIdentity: custom_module.DiagGGNScaleModule(),
ScaleModule: custom_module.DiagGGNScaleModule(),
SumModule: custom_module.DiagGGNSumModule(),
RNN: rnn.BatchDiagGGNRNN(),