Skip to content

Commit

Permalink
Update SAC to work with latest PyTorch (facebookresearch#1138)
Browse files Browse the repository at this point in the history
* Update SAC to work with latest PyTorch

This will break for older PyTorch though

* Lint
  • Loading branch information
fmassa authored Jun 24, 2024
1 parent 5d7c0de commit 44b8dd9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
10 changes: 5 additions & 5 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _all_policy(func, *args, **kwargs):
return True


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.skipif(torch.__version__ < "2.4", reason="Only new PyTorch supported")
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", _devices)
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_checkpoint(policy_fn, input_requires_grad, device, autocast):
assert torch.allclose(p.grad, p_copy.grad)


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.skipif(torch.__version__ < "2.4", reason="Only new PyTorch supported")
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("grad_mode", [True, False])
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode):
assert torch.allclose(out, out_copy)


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.skipif(torch.__version__ < "2.4", reason="Only new PyTorch supported")
@cuda_only
@pytest.mark.parametrize("policy_fn", [None, [], _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
Expand Down Expand Up @@ -287,7 +287,7 @@ def forward(self, x):
return x


@pytest.mark.skipif(torch.__version__ < "2.2", reason="Only new PyTorch supported")
@pytest.mark.skipif(torch.__version__ < "2.4", reason="Only new PyTorch supported")
@cuda_only
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("memory_budget", [0, 0.03, 0.05, 0.1, 0.3, 0.5, 0.8, 1.0])
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_optimal_checkpoint_policy(
torch.testing.assert_close(p.grad, p_ref.grad)


@pytest.mark.skipif(torch.__version__ < "2.3", reason="Only new PyTorch supported")
@pytest.mark.skipif(torch.__version__ < "2.4", reason="Only new PyTorch supported")
@pytest.mark.skipif(True, reason="TODO[fmassa]: Broken on nightly")
@cuda_only
@pytest.mark.parametrize("no_grad", [False, True])
Expand Down
24 changes: 16 additions & 8 deletions xformers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,23 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
ActivationWrapper,
)
from torch.utils.checkpoint import SAC_IGNORED_OPS as _ignored_ops # type: ignore
from torch.utils.checkpoint import ( # type: ignore
CheckpointPolicy,
_CachedTorchDispatchMode,
_CachingTorchDispatchMode,
_ignored_ops,
)
except ImportError:
ActivationWrapper = torch.nn.Module # type: ignore

class _NotAvailable:
def __init__(self, *args, **kwargs):
raise RuntimeError("Need PyTorch >= 2.2")
raise RuntimeError("Need PyTorch > 2.4")

_CachedTorchDispatchMode = _NotAvailable # type: ignore
_CachingTorchDispatchMode = _NotAvailable # type: ignore
_ignored_ops = set() # type: ignore
CheckpointPolicy = None # type: ignore


_additional_ignored_ops = {
Expand Down Expand Up @@ -86,8 +88,11 @@ def _get_default_policy(allow_list=None):
if allow_list is None:
allow_list = _default_allow_list

def _default_policy(mode, func, *args, **kwargs):
return str(func) in allow_list
def _default_policy(ctx, func, *args, **kwargs):
store = str(func) in allow_list
return (
CheckpointPolicy.MUST_SAVE if store else CheckpointPolicy.PREFER_RECOMPUTE
)

return _default_policy

Expand Down Expand Up @@ -156,7 +161,7 @@ def selective_checkpoint_context_fn(policy_fn=None):
caching_mode = _CachingTorchDispatchMode(deepcopy(policy_fn), temp_storage)
else:
caching_mode = NullTorchDispatchMode()
cached_mode = CachedTorchDispatchMode(deepcopy(policy_fn), temp_storage)
cached_mode = CachedTorchDispatchMode(deepcopy(policy_fn), temp_storage, True)

return caching_mode, cached_mode

Expand Down Expand Up @@ -452,13 +457,16 @@ def __init__(self, optim_output: torch.Tensor):
self.counter = 0
self.optim_output = optim_output.tolist()

def __call__(self, mode, func, *args, **kwargs) -> bool:
def __call__(self, ctx, func, *args, **kwargs) -> bool:
# returning False means recompute, True means store in memory
if func in OPS_TO_ALWAYS_SKIP:
return False
return CheckpointPolicy.PREFER_RECOMPUTE
count = self.counter
self.counter += 1
return self.optim_output[count] == 1
store = self.optim_output[count] == 1
return (
CheckpointPolicy.MUST_SAVE if store else CheckpointPolicy.PREFER_RECOMPUTE
)


class SelectiveCheckpointWrapper(ActivationWrapper):
Expand Down

0 comments on commit 44b8dd9

Please sign in to comment.