Move dataloader initialize_object to factory methods #1510

Merged
merged 37 commits into dev from hanlin/dl_factories on Sep 27, 2022
Changes from all commits (37 commits)
9a202ba
update ade20k
hanlint Sep 9, 2022
93f9b54
skip shuffle since set in sampler
hanlint Sep 9, 2022
6c8d4ea
imagenet
hanlint Sep 9, 2022
b8af28e
add docs
hanlint Sep 9, 2022
6ba07e2
add mnist
hanlint Sep 9, 2022
e6d4838
move cifar10 initialize_object logic to factory funcs
dblalock Sep 13, 2022
e6ca5cb
add simple test for new cifar10 factory funcs
dblalock Sep 13, 2022
88d4258
fix pyright errors
dblalock Sep 13, 2022
5823ca9
have docstring type match fixed type
dblalock Sep 13, 2022
f0aac56
un-delete cifar_hparams.py and add mnist tests as one commit to appea…
dblalock Sep 13, 2022
a869cfb
de-yahpify coco for ssd
A-Jacobson Sep 13, 2022
fe67dfd
Update composer/datasets/coco.py
A-Jacobson Sep 13, 2022
0150dda
fix docstring
A-Jacobson Sep 13, 2022
5c1ee3e
add input_size to docstring
A-Jacobson Sep 13, 2022
cd4ff28
fix circular import
A-Jacobson Sep 14, 2022
80c0648
Deyahpify BRaTS
coryMosaicML Sep 14, 2022
a123e39
Fix types and creation of dataset
coryMosaicML Sep 14, 2022
7f1c1df
deyahpify, first attempt
growlix Sep 14, 2022
4dd3c37
lint and pyright ignore
growlix Sep 15, 2022
44cfe7a
fixed datasets and transformers conditional import scope error
growlix Sep 15, 2022
2fd4fb4
removed some type: ignores
growlix Sep 15, 2022
11600a3
Merge branch 'dev' into hanlin/dl_factories
hanlint Sep 15, 2022
40e55e6
Merge branch 'dev' into hanlin/dl_factories
hanlint Sep 15, 2022
eaddcf9
isort
growlix Sep 15, 2022
1bd8f49
move functions and cleanup unused args
hanlint Sep 15, 2022
73acd3a
fix isort
hanlint Sep 16, 2022
a5f21e0
Merge branch 'dev' into hanlin/dl_factories
hanlint Sep 16, 2022
b053162
fix isort again
hanlint Sep 16, 2022
95c3b7f
fix doctest
hanlint Sep 16, 2022
e783e51
fix
hanlint Sep 16, 2022
bcb3b21
lint
growlix Sep 16, 2022
8767915
Merge branch 'hanlin/dl_factories' of github.com:hanlint/composer int…
growlix Sep 16, 2022
d489c72
fix tests
hanlint Sep 17, 2022
7a6b9c3
Merge branch 'dev' into hanlin/dl_factories
hanlint Sep 19, 2022
3b8a022
lm_dataset_hparams fix, import fixes
growlix Sep 19, 2022
43c9cbc
Merge branch 'dev' into hanlin/dl_factories
hanlint Sep 26, 2022
ac3b32e
Merge branch 'dev' into hanlin/dl_factories
hanlint Sep 27, 2022
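
Taken together, these commits move dataset construction out of each hparams class's initialize_object and into standalone build_*_dataloader factory functions. A rough sketch of the new call pattern, using the ADE20k builder defined in the diffs below (the datadir value is a placeholder):

from composer.datasets import build_ade20k_dataloader

train_dataspec = build_ade20k_dataloader(
    batch_size=16,
    datadir='/path/to/ADEChallengeData2016',  # placeholder path
    split='train',
    num_workers=8,  # forwarded through **dataloader_kwargs
)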
37 changes: 31 additions & 6 deletions composer/datasets/__init__.py
@@ -3,17 +3,42 @@

"""Natively supported datasets."""

from composer.datasets.ade20k import ADE20k, StreamingADE20k
from composer.datasets.ade20k import ADE20k, StreamingADE20k, build_ade20k_dataloader, build_synthetic_ade20k_dataloader
from composer.datasets.brats import PytTrain, PytVal
from composer.datasets.c4 import C4Dataset, StreamingC4
from composer.datasets.cifar import StreamingCIFAR10
from composer.datasets.cifar import StreamingCIFAR10, build_cifar10_dataloader, build_synthetic_cifar10_dataloader
from composer.datasets.coco import COCODetection, StreamingCOCO
from composer.datasets.imagenet import StreamingImageNet1k
from composer.datasets.imagenet import (StreamingImageNet1k, build_ffcv_imagenet_dataloader, build_imagenet_dataloader,
build_synthetic_imagenet_dataloader)
from composer.datasets.lm_dataset import build_lm_dataloader, build_synthetic_lm_dataloader
from composer.datasets.mnist import build_mnist_dataloader, build_synthetic_mnist_dataloader
from composer.datasets.synthetic import (SyntheticBatchPairDataset, SyntheticDataLabelType, SyntheticDataType,
SyntheticPILDataset)

__all__ = [
'ADE20k', 'StreamingADE20k', 'PytTrain', 'PytVal', 'C4Dataset', 'StreamingC4', 'StreamingCIFAR10', 'COCODetection',
'StreamingCOCO', 'StreamingImageNet1k', 'SyntheticBatchPairDataset', 'SyntheticDataLabelType', 'SyntheticDataType',
'SyntheticPILDataset'
'ADE20k',
'StreamingADE20k',
'PytTrain',
'PytVal',
'C4Dataset',
'StreamingC4',
'StreamingCIFAR10',
'COCODetection',
'StreamingCOCO',
'StreamingImageNet1k',
'SyntheticBatchPairDataset',
'SyntheticDataLabelType',
'SyntheticDataType',
'SyntheticPILDataset',
'build_ade20k_dataloader',
'build_cifar10_dataloader',
'build_synthetic_ade20k_dataloader',
'build_synthetic_cifar10_dataloader',
'build_ffcv_imagenet_dataloader',
'build_imagenet_dataloader',
'build_synthetic_imagenet_dataloader',
'build_mnist_dataloader',
'build_synthetic_mnist_dataloader',
'build_lm_dataloader',
'build_synthetic_lm_dataloader',
]
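
Because every builder is re-exported here, the package root and the defining module offer the same function; a quick sanity check:

from composer.datasets import build_cifar10_dataloader as from_root
from composer.datasets.cifar import build_cifar10_dataloader as from_module

assert from_root is from_module  # both names resolve to one function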
140 changes: 139 additions & 1 deletion composer/datasets/ade20k.py
@@ -16,17 +16,155 @@
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from composer.core.data_spec import DataSpec
from composer.core.types import MemoryFormat
from composer.datasets.streaming import StreamingDataset
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.datasets.utils import NormalizationFn, pil_image_collate
from composer.utils import dist

__all__ = ['ADE20k', 'StreamingADE20k']

IMAGENET_CHANNEL_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
IMAGENET_CHANNEL_STD = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def build_ade20k_dataloader(
batch_size: int,
datadir: str,
*,
split: str = 'train',
drop_last: bool = True,
shuffle: bool = True,
base_size: int = 512,
min_resize_scale: float = 0.5,
max_resize_scale: float = 2.0,
final_size: int = 512,
ignore_background: bool = True,
**dataloader_kwargs,
):
"""Builds an ADE20k dataloader.

Args:
batch_size (int): Batch size per device.
datadir (str): Path to the location of the dataset.
split (str): The dataset split to use, either 'train', 'val', or 'test'. Default: ``'train'``.
drop_last (bool): Whether to drop the last incomplete batch. Default: ``True``.
shuffle (bool): Whether to shuffle the dataset. Default: ``True``.
base_size (int): Initial size of the image and target before other augmentations. Default: ``512``.
min_resize_scale (float): Minimum scale factor by which the samples can be rescaled. Default: ``0.5``.
max_resize_scale (float): Maximum scale factor by which the samples can be rescaled. Default: ``2.0``.
final_size (int): Final size of the image and target. Default: ``512``.
ignore_background (bool): If ``True``, ignore the background class when calculating the training loss.
Default: ``True``.
**dataloader_kwargs (Dict[str, Any]): Additional settings for the dataloader (e.g. num_workers, etc.).
"""
if split == 'train':
both_transforms = torch.nn.Sequential(
RandomResizePair(
min_scale=min_resize_scale,
max_scale=max_resize_scale,
base_size=(base_size, base_size),
),
RandomCropPair(
crop_size=(final_size, final_size),
class_max_percent=0.75,
num_retry=10,
),
RandomHFlipPair(),
)

# Photometric distortion values come from mmsegmentation:
# https://github.com/open-mmlab/mmsegmentation/blob/aa50358c71fe9c4cccdd2abe42433bdf702e757b/mmseg/datasets/pipelines/transforms.py#L861
r_mean, g_mean, b_mean = IMAGENET_CHANNEL_MEAN
image_transforms = torch.nn.Sequential(
PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
PadToSize(size=(final_size, final_size), fill=(int(r_mean), int(g_mean), int(b_mean))))

target_transforms = PadToSize(size=(final_size, final_size), fill=0)
else:
both_transforms = None
image_transforms = transforms.Resize(size=(final_size, final_size), interpolation=TF.InterpolationMode.BILINEAR)
target_transforms = transforms.Resize(size=(final_size, final_size), interpolation=TF.InterpolationMode.NEAREST)

dataset = ADE20k(datadir=datadir,
split=split,
both_transforms=both_transforms,
image_transforms=image_transforms,
target_transforms=target_transforms)

sampler = dist.get_sampler(dataset, drop_last=drop_last, shuffle=shuffle)
device_transform_fn = NormalizationFn(mean=IMAGENET_CHANNEL_MEAN,
std=IMAGENET_CHANNEL_STD,
ignore_background=ignore_background)

return DataSpec(
dataloader=DataLoader(dataset=dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=drop_last,
collate_fn=pil_image_collate,
**dataloader_kwargs),
device_transforms=device_transform_fn,
)
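
The builder returns a DataSpec rather than a bare DataLoader so that the normalization step rides along as a device transform. A minimal consumption sketch, assuming `model` is a ComposerModel and `dataspec` is the value returned above:

from composer import Trainer

trainer = Trainer(model=model, train_dataloader=dataspec, max_duration='1ep')
trainer.fit()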


def build_synthetic_ade20k_dataloader(
batch_size: int,
*,
split: str = 'train',
drop_last: bool = True,
shuffle: bool = True,
final_size: int = 512,
num_unique_samples: int = 100,
device: str = 'cpu',
memory_format: MemoryFormat = MemoryFormat.CONTIGUOUS_FORMAT,
**dataloader_kwargs,
):
"""Builds a synthetic ADE20k dataloader.

Args:
batch_size (int): Batch size per device.
split (str): The dataset split to use, either 'train', 'val', or 'test'. Default: ``'train'``.
drop_last (bool): Whether to drop the last incomplete batch. Default: ``True``.
shuffle (bool): Whether to shuffle the dataset. Default: ``True``.
final_size (int): Final size of the image and target. Default: ``512``.
num_unique_samples (int): Number of unique samples in the synthetic dataset. Default: ``100``.
device (str): Device on which to create the synthetic data. Default: ``'cpu'``.
memory_format (MemoryFormat): Memory format of the tensors. Default: ``CONTIGUOUS_FORMAT``.
**dataloader_kwargs (Dict[str, Any]): Additional settings for the dataloader (e.g. num_workers, etc.).
"""
if split == 'train':
total_dataset_size = 20_206
elif split == 'val':
total_dataset_size = 2_000
else:
total_dataset_size = 3_352

dataset = SyntheticBatchPairDataset(
total_dataset_size=total_dataset_size,
data_shape=[3, final_size, final_size],
label_shape=[final_size, final_size],
num_classes=150,
num_unique_samples_to_create=num_unique_samples,
device=device,
memory_format=memory_format,
)
sampler = dist.get_sampler(dataset, drop_last=drop_last, shuffle=shuffle)

return DataSpec(
DataLoader(
dataset=dataset,
sampler=sampler,
batch_size=batch_size,
drop_last=drop_last,
**dataloader_kwargs,
))
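
The synthetic variant needs no datadir, which makes it convenient for smoke tests; a sketch using only parameters from the signature above:

from composer.datasets import build_synthetic_ade20k_dataloader

smoke_dataspec = build_synthetic_ade20k_dataloader(
    batch_size=8,
    split='val',
    final_size=256,
    num_unique_samples=16,
    num_workers=0,  # forwarded through **dataloader_kwargs
)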


class RandomResizePair(torch.nn.Module):
"""Resize the image and target to ``base_size`` scaled by a randomly sampled value.

96 changes: 25 additions & 71 deletions composer/datasets/ade20k_hparams.py
@@ -7,7 +7,7 @@
dataset.
"""

from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Optional

import torch
@@ -16,14 +16,13 @@
from torchvision import transforms

from composer.core import DataSpec
from composer.datasets.ade20k import (IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD, ADE20k, PadToSize,
PhotometricDistoration, RandomCropPair, RandomHFlipPair, RandomResizePair,
StreamingADE20k)
from composer.datasets.ade20k import (IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD, PadToSize, PhotometricDistoration,
RandomCropPair, RandomHFlipPair, RandomResizePair, StreamingADE20k,
build_ade20k_dataloader, build_synthetic_ade20k_dataloader)
from composer.datasets.dataset_hparams import DataLoaderHparams, DatasetHparams
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.datasets.synthetic_hparams import SyntheticHparamsMixin
from composer.datasets.utils import NormalizationFn, pil_image_collate
from composer.utils import dist, warn_streaming_dataset_deprecation
from composer.utils import warn_streaming_dataset_deprecation
from composer.utils.import_helpers import MissingConditionalImportError

__all__ = ['ADE20kDatasetHparams', 'StreamingADE20kHparams']
@@ -76,75 +75,30 @@ def initialize_object(self, batch_size, dataloader_hparams) -> DataSpec:
self.validate()

if self.use_synthetic:
if self.split == 'train':
total_dataset_size = 20_206
elif self.split == 'val':
total_dataset_size = 2_000
else:
total_dataset_size = 3_352

dataset = SyntheticBatchPairDataset(
total_dataset_size=total_dataset_size,
data_shape=[3, self.final_size, self.final_size],
label_shape=[self.final_size, self.final_size],
num_classes=150,
num_unique_samples_to_create=self.synthetic_num_unique_samples,
return build_synthetic_ade20k_dataloader(
batch_size=batch_size,
split=self.split,
drop_last=self.drop_last,
shuffle=self.shuffle,
final_size=self.final_size,
num_unique_samples=self.synthetic_num_unique_samples,
device=self.synthetic_device,
memory_format=self.synthetic_memory_format,
**asdict(dataloader_hparams),
)
collate_fn = None
device_transform_fn = None

else:
# Define data transformations based on data split
if self.split == 'train':
both_transforms = torch.nn.Sequential(
RandomResizePair(min_scale=self.min_resize_scale,
max_scale=self.max_resize_scale,
base_size=(self.base_size, self.base_size)),
RandomCropPair(
crop_size=(self.final_size, self.final_size),
class_max_percent=0.75,
num_retry=10,
),
RandomHFlipPair(),
)

# Photometric distoration values come from mmsegmentation:
# https://github.com/open-mmlab/mmsegmentation/blob/aa50358c71fe9c4cccdd2abe42433bdf702e757b/mmseg/datasets/pipelines/transforms.py#L861
r_mean, g_mean, b_mean = IMAGENET_CHANNEL_MEAN
image_transforms = torch.nn.Sequential(
PhotometricDistoration(brightness=32. / 255, contrast=0.5, saturation=0.5, hue=18. / 255),
PadToSize(size=(self.final_size, self.final_size), fill=(int(r_mean), int(g_mean), int(b_mean))))

target_transforms = PadToSize(size=(self.final_size, self.final_size), fill=0)
else:
both_transforms = None
image_transforms = transforms.Resize(size=(self.final_size, self.final_size),
interpolation=TF.InterpolationMode.BILINEAR)
target_transforms = transforms.Resize(size=(self.final_size, self.final_size),
interpolation=TF.InterpolationMode.NEAREST)
collate_fn = pil_image_collate
device_transform_fn = NormalizationFn(mean=IMAGENET_CHANNEL_MEAN,
std=IMAGENET_CHANNEL_STD,
ignore_background=self.ignore_background)

# Add check to avoid type ignore below
if self.datadir is None:
raise ValueError('datadir must specify the path to the ADE20k dataset.')

dataset = ADE20k(datadir=self.datadir,
split=self.split,
both_transforms=both_transforms,
image_transforms=image_transforms,
target_transforms=target_transforms)
sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)
return DataSpec(dataloader=dataloader_hparams.initialize_object(dataset=dataset,
batch_size=batch_size,
sampler=sampler,
collate_fn=collate_fn,
drop_last=self.drop_last),
device_transforms=device_transform_fn)
return build_ade20k_dataloader(
batch_size=batch_size,
split=self.split,
drop_last=self.drop_last,
shuffle=self.shuffle,
base_size=self.base_size,
min_resize_scale=self.min_resize_scale,
max_resize_scale=self.max_resize_scale,
final_size=self.final_size,
ignore_background=self.ignore_background,
**asdict(dataloader_hparams),
)
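
The **asdict(dataloader_hparams) spread is what keeps this method a thin shim: every field of the dataloader hparams dataclass becomes a keyword argument absorbed by the factory's **dataloader_kwargs. A self-contained sketch of the pattern, with a hypothetical stand-in dataclass:

from dataclasses import asdict, dataclass

@dataclass
class _ToyLoaderHparams:  # hypothetical stand-in for DataLoaderHparams
    num_workers: int = 8
    pin_memory: bool = True

def _toy_builder(batch_size: int, **dataloader_kwargs):
    # In the real builders these kwargs go straight to DataLoader.
    return batch_size, dataloader_kwargs

print(_toy_builder(batch_size=32, **asdict(_ToyLoaderHparams())))
# (32, {'num_workers': 8, 'pin_memory': True})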


@dataclass
34 changes: 34 additions & 0 deletions composer/datasets/brats.py
@@ -16,13 +16,47 @@
import torch.utils.data
import torchvision

from composer.utils import dist
from composer.utils.import_helpers import MissingConditionalImportError

PATCH_SIZE = [1, 192, 160]

__all__ = ['PytTrain', 'PytVal']


def build_brats_dataloader(datadir: str,
batch_size: int,
oversampling: float = 0.33,
is_train: bool = True,
drop_last: bool = True,
shuffle: bool = True,
**dataloader_kwargs):
"""Builds a BRaTS dataloader

Args:
**dataloader_kwargs (Dict[str, Any]): Additional settings for the dataloader (e.g. num_workers, etc.)
"""
x_train, y_train, x_val, y_val = get_data_split(datadir)
dataset = PytTrain(x_train, y_train, oversampling) if is_train else PytVal(x_val, y_val)
collate_fn = None if is_train else _my_collate
sampler = dist.get_sampler(dataset, drop_last=drop_last, shuffle=shuffle)

return torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=drop_last,
collate_fn=collate_fn,
**dataloader_kwargs)
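
For reference, a sketch of calling the new BRaTS builder with only the arguments defined above (the datadir is a placeholder and must contain whatever layout get_data_split expects):

from composer.datasets.brats import build_brats_dataloader

train_loader = build_brats_dataloader(
    datadir='/path/to/brats',  # placeholder path
    batch_size=2,
    oversampling=0.33,
    is_train=True,
    num_workers=4,  # forwarded through **dataloader_kwargs
)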


def _my_collate(batch):
"""Custom collate function to handle images with different depths."""
data = [item[0] for item in batch]
target = [item[1] for item in batch]

return [torch.Tensor(data), torch.Tensor(target)]


def _coin_flip(prob):
return random.random() < prob
