Add primitive list support #1906

Merged: 4 commits, Jan 25, 2023
40 changes: 37 additions & 3 deletions composer/core/data_spec.py
@@ -48,8 +48,13 @@ def _num_microbatches_split_mapping(m, num_microbatches: int):
    for k, v in m.items():
        if isinstance(v, torch.Tensor):
            chunked[k] = _num_microbatches_split_tensor(v, num_microbatches)
        elif isinstance(v, (List, Tuple)):
            chunked[k] = _num_microbatches_split_list(v, num_microbatches)
        elif isinstance(v, (int, float, str, bool)):
            # Broadcast primitives to all chunks
            chunked[k] = [v] * num_microbatches
        else:
            raise ValueError(f'Unsupported batch type: {type(v)}.')
    num_chunks = len(list(chunked.values())[0])
    return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]
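
As an illustrative sketch (the helpers are private, and this assumes the tensor helper chunks along dim 0), a primitive value in a mapping batch is now repeated into every microbatch instead of raising:

    import torch

    from composer.core.data_spec import _num_microbatches_split_mapping

    batch = {'image': torch.randn(4, 8), 'run_id': 'exp-7'}  # 'run_id' is a hypothetical str key
    chunks = _num_microbatches_split_mapping(batch, num_microbatches=2)
    assert len(chunks) == 2
    assert all(c['run_id'] == 'exp-7' for c in chunks)  # primitive broadcast to each chunk
    assert chunks[0]['image'].shape[0] == 2             # tensor still split along dim 0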

@@ -74,6 +79,9 @@ def _num_microbatches_split_batch(batch: Any, num_microbatches: int) -> Sequence
    if isinstance(batch, Mapping):  # check for dictionary (hf style)
        return _num_microbatches_split_mapping(batch, num_microbatches)

    if isinstance(batch, (Tuple, list)) and _check_list_is_primitives(batch):  # check for list of primitives
        return _num_microbatches_split_list(batch, num_microbatches)

    if isinstance(batch, (Tuple, List)):  # check for batch on 2nd dimension
        result = []
        for item in batch:
@@ -117,12 +125,36 @@ def _split_mapping(m, microbatch_size: int):
    for k, v in m.items():
        if isinstance(v, torch.Tensor):
            chunked[k] = _split_tensor(v, microbatch_size)
        elif isinstance(v, (List, Tuple)):
            chunked[k] = _split_list(v, microbatch_size)
        elif isinstance(v, (int, float, str, bool)):
            # Defer broadcasting primitives until we know num_chunks
            pass
        else:
            raise ValueError(f'Unsupported batch type: {type(v)}.')
    num_chunks = 1  # Default to 1 chunk if there are no tensors or everything is primitive
    if len(chunked.keys()) != 0:
        num_chunks = len(list(chunked.values())[0])
    # Broadcast primitives to all chunks
    for k, v in m.items():
        if isinstance(v, (int, float, str, bool)):
            chunked[k] = [v] * num_chunks
    return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]
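
The broadcast is deferred because num_chunks is only known after the tensor and list entries are split; if the mapping holds nothing but primitives, num_chunks falls back to 1. A minimal sketch of that edge case (private helper, illustrative only):

    from composer.core.data_spec import _split_mapping

    all_primitive = {'lr': 0.1, 'tag': 'baseline'}
    chunks = _split_mapping(all_primitive, microbatch_size=3)
    assert chunks == [{'lr': 0.1, 'tag': 'baseline'}]  # a single chunk carrying both primitives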


def _check_list_is_primitives(l):
    """Checks if all elements in a list are the same primitive type."""
    if len(l) == 0:
        return True
    first_type = type(l[0])
    if not isinstance(l[0], (int, float, str, bool)):
        return False
    for item in l:
        if type(item) != first_type:
            return False
    return True
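
For reference, a few cases the guard accepts and rejects; note bool is a subclass of int, so isinstance alone would admit mixed [1, True], but the type(item) != first_type check still rejects it:

    import torch

    from composer.core.data_spec import _check_list_is_primitives

    assert _check_list_is_primitives(['a', 'b', 'c'])       # homogeneous strings pass
    assert _check_list_is_primitives([])                    # empty list passes by definition
    assert not _check_list_is_primitives([1, 'a'])          # mixed types rejected
    assert not _check_list_is_primitives([torch.zeros(1)])  # tensors fall through to the nested-batch path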


def _default_split_batch(batch: Any, microbatch_size: int) -> Sequence:
    """Splits batch into chunks of size `microbatch_size` for gradient accumulation.

@@ -136,6 +168,8 @@ def _default_split_batch(batch: Any, microbatch_size: int) -> Sequence:
        return _split_tensor(batch, microbatch_size)
    elif isinstance(batch, Mapping):  # check for dictionary (hf style)
        return _split_mapping(batch, microbatch_size)
    elif isinstance(batch, (Tuple, list)) and _check_list_is_primitives(batch):  # check for list of primitives
        return _split_list(batch, microbatch_size)
    elif isinstance(batch, (Tuple, List)):  # check for batch on 2nd dimension
        result = []
        for item in batch:
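
With this, a bare list of primitives routes through _split_list rather than the nested-batch path. A sketch of the expected result, assuming _split_list chunks the list sequentially into ceil(len / microbatch_size) pieces:

    from composer.core.data_spec import _default_split_batch

    labels = ['a', 'b', 'c', 'd', 'e']
    micro = _default_split_batch(labels, microbatch_size=2)
    assert micro == [['a', 'b'], ['c', 'd'], ['e']]  # last microbatch carries the remainder
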
42 changes: 40 additions & 2 deletions tests/test_split_batch.py
@@ -14,6 +14,10 @@ def dummy_tensor_batch(batch_size=12) -> torch.Tensor:
    return torch.randn(size=(batch_size, 3, 32, 32))


def dummy_list_str(batch_size=12) -> List[str]:
    return [str(x) for x in range(batch_size)]


def dummy_tuple_batch(batch_size=12) -> List[torch.Tensor]:
    # pytorch default collate converts tuples to lists
    # https://github.com/pytorch/pytorch/blob/e451259a609acdcd83105177ddba73fc41cfa9b4/torch/utils/data/_utils/collate.py#L67
@@ -37,14 +41,23 @@ def dummy_dict_batch(batch_size=12) -> Dict[str, torch.Tensor]:


def dummy_dict_batch_with_metadata(batch_size=12) -> Dict[str, Union[List, torch.Tensor, str]]:
    # sometimes metadata is included with a batch that isn't taken by the model.
    image = torch.randn(size=(batch_size, 3, 32, 32))
    target = torch.randint(size=(batch_size,), high=10)
    meta = ['hi im a tag' for _ in range(batch_size)]
    index = [[1, 2, 3] for _ in range(batch_size)]
    return {'image': image, 'target': target, 'meta': meta, 'index': index}


def dummy_dict_batch_with_common_metadata(batch_size=12) -> Dict[str, Union[List, torch.Tensor, str]]:
    # sometimes metadata is included with a batch that isn't taken by the model.
    image = torch.randn(size=(batch_size, 3, 32, 32))
    target = torch.randint(size=(batch_size,), high=10)
    meta = 'this is a string'
    index = [[1, 2, 3] for _ in range(batch_size)]
    return {'image': image, 'target': target, 'meta': meta, 'index': index}


def dummy_maskrcnn_batch(batch_size=12,
                         image_height=12,
                         image_width=12,
@@ -76,10 +89,12 @@ def generate_maskrcnn_sample(num_detections,
def dummy_batches(batch_size=12):
    return [
        dummy_tensor_batch(batch_size=batch_size),
        dummy_list_str(batch_size=batch_size),
        dummy_tuple_batch(batch_size=batch_size),
        dummy_tuple_batch_long(batch_size=batch_size),
        dummy_dict_batch(batch_size=batch_size),
        dummy_dict_batch_with_metadata(batch_size=batch_size),
        dummy_dict_batch_with_common_metadata(batch_size=batch_size),
    ]


@@ -112,6 +127,21 @@ def test_split_tuple_long(batch):
    assert len(microbatches[0]) == 4


@pytest.mark.parametrize('batch', dummy_batches(6))
def test_batch_sizes(batch):
    microbatches = _default_split_batch(batch, microbatch_size=2)
    # should split into [len(2), len(2), len(2)]
    assert len(microbatches) == 3
    for microbatch in microbatches:
        if isinstance(microbatch, Mapping):
            assert len(microbatch['image']) == 2
            assert len(microbatch['target']) == 2
        if isinstance(microbatch, tuple):
            assert len(microbatch[0]) == 2
        if isinstance(microbatch, list):
            assert len(microbatch) == 2


@pytest.mark.parametrize('batch', dummy_batches(5))
def test_odd_batch_sizes(batch):
    microbatches = _default_split_batch(batch, microbatch_size=2)
@@ -140,6 +170,14 @@ def test_microbatch_size_split_maskrcnn(batch):
    assert len(microbatches) == 3


@pytest.mark.parametrize('batch', [dummy_dict_batch_with_common_metadata(12)])
def test_primitive_broadcast(batch):
    microbatches = _default_split_batch(batch, microbatch_size=3)
    assert len(microbatches) == 4
    for mb in microbatches:
        assert mb['meta'] == 'this is a string'


## Older tests for deprecated codepath. To be removed in 0.13

