diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index c2d24e9c11..a1e1054249 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -48,8 +48,13 @@ def _num_microbatches_split_mapping(m, num_microbatches: int):
     for k, v in m.items():
         if isinstance(v, torch.Tensor):
             chunked[k] = _num_microbatches_split_tensor(v, num_microbatches)
-        if isinstance(v, (List, Tuple)):
+        elif isinstance(v, (List, Tuple)):
             chunked[k] = _num_microbatches_split_list(v, num_microbatches)
+        elif isinstance(v, (int, float, str, bool)):
+            # Broadcast primitives to all chunks
+            chunked[k] = [v] * num_microbatches
+        else:
+            raise ValueError(f'Unsupported batch type: {type(v)}.')
     num_chunks = len(list(chunked.values())[0])
     return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]
@@ -74,6 +79,9 @@ def _num_microbatches_split_batch(batch: Any, num_microbatches: int) -> Sequence
     if isinstance(batch, Mapping):  # check for dictionary (hf style)
         return _num_microbatches_split_mapping(batch, num_microbatches)

+    if isinstance(batch, (Tuple, list)) and _check_list_is_primitives(batch):  # check for list of primitives
+        return _num_microbatches_split_list(batch, num_microbatches)
+
     if isinstance(batch, (Tuple, List)):  # check for batch on 2nd dimension
         result = []
         for item in batch:
@@ -117,12 +125,36 @@ def _split_mapping(m, microbatch_size: int):
     for k, v in m.items():
         if isinstance(v, torch.Tensor):
             chunked[k] = _split_tensor(v, microbatch_size)
-        if isinstance(v, (List, Tuple)):
+        elif isinstance(v, (List, Tuple)):
             chunked[k] = _split_list(v, microbatch_size)
-    num_chunks = len(list(chunked.values())[0])
+        elif isinstance(v, (int, float, str, bool)):
+            # Defer broadcasting primitives until we know num_chunks
+            pass
+        else:
+            raise ValueError(f'Unsupported batch type: {type(v)}.')
+    num_chunks = 1  # Default to 1 chunk if there are no tensors or lists (everything is primitive)
+    if len(chunked.keys()) != 0:
+        num_chunks = len(list(chunked.values())[0])
+    # Broadcast primitives to all chunks
+    for k, v in m.items():
+        if isinstance(v, (int, float, str, bool)):
+            chunked[k] = [v] * num_chunks
     return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]


+def _check_list_is_primitives(l):
+    """Checks if all elements in a list are the same primitive type."""
+    if len(l) == 0:
+        return True
+    first_type = type(l[0])
+    if not isinstance(l[0], (int, float, str, bool)):
+        return False
+    for item in l:
+        if type(item) != first_type:
+            return False
+    return True
+
+
 def _default_split_batch(batch: Any, microbatch_size: int) -> Sequence:
     """Splits batch into chunks of size `microbatch_size` for gradient accumulation.
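
The `_split_mapping` rewrite above makes the broadcast a second pass: primitives are skipped while tensors and lists are chunked, and only once `num_chunks` is known does each scalar field get copied into every chunk. A minimal sketch of the resulting behavior (not part of the patch), assuming the patched `composer.core.data_spec` is importable; dict batches reach `_split_mapping` via `_default_split_batch`, as the next hunk shows:

```python
import torch

from composer.core.data_spec import _default_split_batch

batch = {
    'image': torch.randn(6, 3, 32, 32),
    'target': torch.randint(high=10, size=(6,)),
    'meta': 'this is a string',  # a scalar primitive shared by the whole batch
}
microbatches = _default_split_batch(batch, microbatch_size=2)

# Tensors are chunked along dim 0; the primitive is copied into every chunk.
assert len(microbatches) == 3
assert all(mb['image'].shape[0] == 2 for mb in microbatches)
assert all(mb['meta'] == 'this is a string' for mb in microbatches)
```

Deferring the broadcast keeps `num_chunks` driven by the actual tensor/list split, so a scalar field can never change how many microbatches are produced.
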
@@ -136,6 +168,8 @@ def _default_split_batch(batch: Any, microbatch_size: int) -> Sequence:
         return _split_tensor(batch, microbatch_size)
     elif isinstance(batch, Mapping):  # check for dictionary (hf style)
         return _split_mapping(batch, microbatch_size)
+    elif isinstance(batch, (Tuple, list)) and _check_list_is_primitives(batch):  # check for list of primitives
+        return _split_list(batch, microbatch_size)
     elif isinstance(batch, (Tuple, List)):  # check for batch on 2nd dimension
         result = []
         for item in batch:
diff --git a/tests/test_split_batch.py b/tests/test_split_batch.py
index 904a491281..8776fec047 100644
--- a/tests/test_split_batch.py
+++ b/tests/test_split_batch.py
@@ -14,6 +14,10 @@ def dummy_tensor_batch(batch_size=12) -> torch.Tensor:
     return torch.randn(size=(batch_size, 3, 32, 32))


+def dummy_list_str(batch_size=12) -> List[str]:
+    return [str(x) for x in range(batch_size)]
+
+
 def dummy_tuple_batch(batch_size=12) -> List[torch.Tensor]:
     # pytorch default collate converts tuples to lists
     # https://github.com/pytorch/pytorch/blob/e451259a609acdcd83105177ddba73fc41cfa9b4/torch/utils/data/_utils/collate.py#L67
@@ -37,7 +41,7 @@ def dummy_dict_batch(batch_size=12) -> Dict[str, torch.Tensor]:


 def dummy_dict_batch_with_metadata(batch_size=12) -> Dict[str, Union[List, torch.Tensor, str]]:
-    # sometimes metadata is included with a batch that isnt taken by the model.
+    # sometimes metadata is included with a batch that isn't taken by the model.
     image = torch.randn(size=(batch_size, 3, 32, 32))
     target = torch.randint(size=(batch_size,), high=10)
     meta = ['hi im a tag' for _ in range(batch_size)]
@@ -45,6 +49,15 @@ def dummy_dict_batch_with_metadata(batch_size=12) -> Dict[str, Union[List, torch
     return {'image': image, 'target': target, 'meta': meta, 'index': index}


+def dummy_dict_batch_with_common_metadata(batch_size=12) -> Dict[str, Union[List, torch.Tensor, str]]:
+    # sometimes metadata common to the whole batch is included and isn't taken by the model.
+    image = torch.randn(size=(batch_size, 3, 32, 32))
+    target = torch.randint(size=(batch_size,), high=10)
+    meta = 'this is a string'
+    index = [[1, 2, 3] for _ in range(batch_size)]
+    return {'image': image, 'target': target, 'meta': meta, 'index': index}
+
+
 def dummy_maskrcnn_batch(batch_size=12,
                          image_height=12,
                          image_width=12,
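
The new `dummy_list_str` fixture exercises the other new codepath: a bare list whose elements are all the same primitive type (e.g. raw text samples) is now chunked along its length by `_split_list` instead of falling through to the tuple-style "batch on 2nd dimension" handling. A sketch of the dispatch, assuming the patched module and that `_split_list` chunks by `microbatch_size` like `torch.split` (which the tests below rely on):

```python
from composer.core.data_spec import _check_list_is_primitives, _default_split_batch

# A list of same-typed primitives is split along its length...
text_batch = [str(x) for x in range(6)]
assert _check_list_is_primitives(text_batch)
microbatches = _default_split_batch(text_batch, microbatch_size=2)
assert len(microbatches) == 3
assert all(len(mb) == 2 for mb in microbatches)

# ...while a mixed-type list fails the guard and keeps the old tuple handling.
assert not _check_list_is_primitives(['a', 1])
```
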
@@ -76,10 +89,12 @@ def generate_maskrcnn_sample(num_detections,
 def dummy_batches(batch_size=12):
     return [
         dummy_tensor_batch(batch_size=batch_size),
+        dummy_list_str(batch_size=batch_size),
         dummy_tuple_batch(batch_size=batch_size),
         dummy_tuple_batch_long(batch_size=batch_size),
         dummy_dict_batch(batch_size=batch_size),
-        dummy_dict_batch_with_metadata(batch_size=batch_size)
+        dummy_dict_batch_with_metadata(batch_size=batch_size),
+        dummy_dict_batch_with_common_metadata(batch_size=batch_size),
     ]


@@ -112,6 +127,21 @@ def test_split_tuple_long(batch):
     assert len(microbatches[0]) == 4


+@pytest.mark.parametrize('batch', dummy_batches(6))
+def test_batch_sizes(batch):
+    microbatches = _default_split_batch(batch, microbatch_size=2)
+    # a batch of 6 should split evenly into [len(2), len(2), len(2)]
+    assert len(microbatches) == 3
+    for microbatch in microbatches:
+        if isinstance(microbatch, Mapping):
+            assert len(microbatch['image']) == 2
+            assert len(microbatch['target']) == 2
+        if isinstance(microbatch, tuple):
+            assert len(microbatch[0]) == 2
+        if isinstance(microbatch, list):
+            assert len(microbatch) == 2
+
+
 @pytest.mark.parametrize('batch', dummy_batches(5))
 def test_odd_batch_sizes(batch):
     microbatches = _default_split_batch(batch, microbatch_size=2)
@@ -140,6 +170,14 @@ def test_microbatch_size_split_maskrcnn(batch):
     assert len(microbatches) == 3


+@pytest.mark.parametrize('batch', [dummy_dict_batch_with_common_metadata(12)])
+def test_primitive_broadcast(batch):
+    microbatches = _default_split_batch(batch, microbatch_size=3)
+    assert len(microbatches) == 4
+    for mb in microbatches:
+        assert mb['meta'] == 'this is a string'
+
+
 ## Older tests for deprecated codepath. To be removed in 0.13
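
The deprecated count-based codepath (slated for removal in 0.13) receives the parallel change: `_num_microbatches_split_mapping` broadcasts primitives to all `num_microbatches` chunks, and both mapping splitters now raise a `ValueError` on unsupported value types instead of silently dropping the key. A sketch of both behaviors, importing the private helper just as the tests do:

```python
import torch

from composer.core.data_spec import _num_microbatches_split_batch

# Split by chunk count rather than chunk size; the primitive is broadcast.
batch = {'image': torch.randn(12, 3, 32, 32), 'meta': 'this is a string'}
microbatches = _num_microbatches_split_batch(batch, num_microbatches=4)
assert len(microbatches) == 4
assert all(mb['meta'] == 'this is a string' for mb in microbatches)

# An unsupported value type (e.g. a set) is now rejected explicitly.
try:
    _num_microbatches_split_batch({'bad': {1, 2, 3}}, num_microbatches=2)
except ValueError as e:
    print(e)  # Unsupported batch type: <class 'set'>.
```
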