diff --git a/torchcfm/models/unet/unet.py b/torchcfm/models/unet/unet.py index e29df92..205ecab 100644 --- a/torchcfm/models/unet/unet.py +++ b/torchcfm/models/unet/unet.py @@ -270,7 +270,7 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) def _forward(self, x): b, c, *spatial = x.shape