Skip to content

Commit

Permalink
Allow DataLoaderAdapter subclasses to be pickled by implementing `__r…
Browse files Browse the repository at this point in the history
…educe__` (#3074)

* initial fix for breaking accelerator pickling

* cleanup

* skip_first_batches should be used on raw dls

* multigpu sanity test

* bugs

* does this work with iterable dsets?

* fix typo

* ignore these commits, i'm just syncing the origin so i can test on my cloud workstation

* comment out failing tests, unsure if those are existing bugs or a recent regression

* torch 2.4.0?

* pickling generator issues

* test_pickle_accelerator

* test_pickle_accelerator should work now)

* base.__len__() -> len(base)

* undo reduce

* undo super().__reduce__() again

* pass args through superclass

* remove prints

* doc changes + make style && make quality
  • Loading branch information
byi8220 authored and muellerzr committed Sep 5, 2024
1 parent 73a1531 commit e13bef2
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 20 deletions.
63 changes: 43 additions & 20 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,25 +416,6 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, *
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

# Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
# In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
# StatefulDataLoader
#
# The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
# StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
# dispatching scattered throughout various functions and files.
#
# This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
# transparently.
#
# A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
# but this would not be backwards compatible with existing code which assumes
# DataLoaderShard/DataLoaderDispatcher are DataLoaders.
base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})

if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()

Expand All @@ -451,6 +432,18 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)

@property
def __class__(self):
"""
In order to maintain backwards compatability with other code, we need to ensure `isinstance(obj, DataLoader)`
returs true. This is because some downstream code assumes that the `DataLoader` is the base class of the
object.
"""
return self.base_dataloader.__class__

def __len__(self):
return len(self.base_dataloader)

def adjust_state_dict_for_prefetch(self):
"""
Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
Expand Down Expand Up @@ -580,6 +573,15 @@ def __iter__(self):
self.iteration += 1
self.end()

def __reduce__(self):
"""
Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
`__class__` member.
"""
args = super().__reduce__()
return (DataLoaderShard, *args[1:])

def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
Expand Down Expand Up @@ -865,14 +867,23 @@ def set_epoch(self, epoch: int):
self.dataset.set_epoch(epoch)

def __len__(self):
whole_length = super().__len__()
whole_length = len(self.base_dataloader)
if self.split_batches:
return whole_length
elif self._drop_last:
return whole_length // self.state.num_processes
else:
return math.ceil(whole_length / self.state.num_processes)

def __reduce__(self):
"""
Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
`__class__` member.
"""
args = super().__reduce__()
return (DataLoaderDispatcher, *args[1:])

@property
def total_batch_size(self):
return (
Expand Down Expand Up @@ -1211,6 +1222,18 @@ def __iter__(self):
yield batch
self.end()

def __len__(self):
return len(self.base_dataloader) - self.skip_batches

def __reduce__(self):
"""
Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
`__class__` member.
"""
args = super().__reduce__()
return (SkipDataLoader, *args[1:])


def skip_first_batches(dataloader, num_batches=0):
"""
Expand Down
14 changes: 14 additions & 0 deletions src/accelerate/test_utils/scripts/test_distributed_data_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import tempfile
import warnings
from typing import List
Expand Down Expand Up @@ -247,6 +248,16 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
assert "only supported for map-style datasets" in str(w[-1].message)


def test_pickle_accelerator():
accelerator = create_accelerator()
data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
_ = accelerator.prepare(data_loader)
pickled_accelerator = pickle.dumps(accelerator)
unpickled_accelerator = pickle.loads(pickled_accelerator)
# TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__


def test_data_loader(data_loader, accelerator):
# Prepare the DataLoader
data_loader = accelerator.prepare(data_loader)
Expand Down Expand Up @@ -368,6 +379,9 @@ def main():
test_join_raises_warning_for_non_ddp_distributed(accelerator)
accelerator.state.distributed_type = original_state

accelerator.print("Test pickling an accelerator")
test_pickle_accelerator()

dataset = DummyDataset()
# Conventional Dataloader with shuffle=False
loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
Expand Down
47 changes: 47 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import (
require_bnb,
Expand Down Expand Up @@ -647,6 +648,52 @@ def test_can_unwrap_model(self):
model_loaded = pickle.loads(pickle.dumps(model))
model_loaded(inputs)

@parameterized.expand([True, False])
def test_can_pickle_dataloader(self, dispatch_batches):
"""
Test that pickling a prepared dataloader works.
"""
data = torch.arange(10).to(torch_device)
ds = torch.utils.data.TensorDataset(data)
dl = torch.utils.data.DataLoader(ds)
skip_dl = skip_first_batches(dl, 2)

# Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
# TODO: Add support for pickling StatefulDataLoader
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
accelerator = Accelerator(dataloader_config=dataloader_config)

original_dl, _ = accelerator.prepare(dl, skip_dl)
if dispatch_batches:
assert isinstance(original_dl, DataLoaderDispatcher)
else:
assert isinstance(original_dl, DataLoaderShard)

prepared_model_dumps = pickle.dumps(accelerator)

model_loaded = pickle.loads(prepared_model_dumps)
assert len(model_loaded._dataloaders) == 2

# Assert equality of recovered and original dataloader
loaded_dl = model_loaded._dataloaders[0]
assert isinstance(loaded_dl, DataLoader)
if dispatch_batches:
assert isinstance(loaded_dl, DataLoaderDispatcher)
else:
assert isinstance(loaded_dl, DataLoaderShard)
assert len(loaded_dl) == len(original_dl)
assert [i for i in loaded_dl] == [i for i in original_dl]

# Test skip dataloader works as expected as well
loaded_skip_dl = model_loaded._dataloaders[1]
assert isinstance(loaded_skip_dl, DataLoader)
if dispatch_batches:
assert isinstance(loaded_dl, DataLoaderDispatcher)
else:
assert isinstance(loaded_dl, DataLoaderShard)
assert len(loaded_skip_dl) == len(original_dl) - 2
assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]

# Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
@require_torchdata_stateful_dataloader
def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,14 @@ def test_dataloader_inheritance(self):
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
dl_shard = DataLoaderShard(range(16), batch_size=4)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)

# Test dataloaders are instances of instantiated classes
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
assert isinstance(skip_dl, SkipDataLoader)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(dl_dispatcher, DataLoaderDispatcher)

# Test dataloaders are instances of base classes
assert isinstance(skip_dl, DataLoader)
assert isinstance(dl_shard, DataLoader)
assert isinstance(dl_dispatcher, DataLoader)
Expand Down Expand Up @@ -556,6 +564,13 @@ def test_dataloader_inheritance(self):
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)

# Test dataloaders are instances of instantiated classes
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
assert isinstance(skip_dl, SkipDataLoader)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(dl_dispatcher, DataLoaderDispatcher)

assert isinstance(skip_dl, StatefulDataLoader)
assert isinstance(dl_shard, StatefulDataLoader)
assert isinstance(dl_dispatcher, StatefulDataLoader)
Expand Down

0 comments on commit e13bef2

Please sign in to comment.