Batch support for Transform? #157
I don't think there are any plans on extending the transforms to support batched inputs at the moment.
Closing this for now as there currently are no plans to extend transforms to work on batched images.
Just to follow up on this: right now, to apply a transformation after getting a batch from a DataLoader, you have to apply it to each image in the batch individually. If you're open to a PR on this, I'd be happy to help if you can give me some pointers.
@Coolnesss usually you do the transformations at the Dataset level (inside __getitem__), so each sample is transformed as it is loaded. Perhaps you can share some details of what your goal is and we can see if it falls outside of the current paradigm.
Thank you for your reply @alykhantejani! I'm trying to create derivative datasets of e.g. MNIST by applying some category of random transformations to each set. Currently, I'm doing something like

```python
d_transforms = [
    transforms.RandomHorizontalFlip(),
    # Some other transforms...
]

loaders = []
for i in range(len(d_transforms)):
    dataset = datasets.MNIST('./data',
                             train=train,
                             download=True,
                             transform=d_transforms[i])
    loaders.append(
        DataLoader(dataset,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=1)
    )
```

Here, I get the desired outcome of having multiple DataLoaders that each provide samples from the transformed datasets. However, this is really slow, presumably because each worker tries to access the same files stored in ./data. I sample from the loaders with

```python
x, y = next(iter(train_loaders[i]))
```

I can think of two ways to solve this.
Sorry for the lengthy post, and thanks for your help.
@Coolnesss would this work for you:

```python
from torch.utils.data import Dataset

class MultiTransformDataset(Dataset):
    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, idx):
        input, target = self.dataset[idx]
        # Apply every transform to the same sample and return all variants plus the target.
        return tuple(t(input) for t in self.transforms) + (target,)

    def __len__(self):
        # Delegate the length to the wrapped dataset so a DataLoader can be used.
        return len(self.dataset)
```
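A minimal sketch of how this wrapper could be used (the MNIST setup and the specific transforms here are illustrative assumptions, not taken from the thread):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

d_transforms = [
    transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor()]),
    transforms.Compose([transforms.RandomVerticalFlip(p=1.0), transforms.ToTensor()]),
]

# The base dataset is loaded once, without a transform; the wrapper applies every
# transform to each sample, so a single DataLoader (and its workers) serves all variants.
base = datasets.MNIST('./data', train=True, download=True)
dataset = MultiTransformDataset(base, d_transforms)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1)

# Each batch yields one tensor per transform, followed by the targets.
flipped_h, flipped_v, targets = next(iter(loader))
```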
Thanks for the workaround @alykhantejani. It's a much nicer solution, and somewhat faster too. Unfortunately it's still not as fast as I had hoped; perhaps the transforms themselves just take too much time. In any case, thanks for your help.
@Coolnesss np. Let me know if you have any other questions.
Note that you can also design a custom collate function that does the necessary transformations on your batch after collating it, e.g.

```python
def get_collate(batch_transform=None):
    def mycollate(batch):
        # Build the batch the usual way, then apply the optional batch-level transform.
        collated = torch.utils.data.dataloader.default_collate(batch)
        if batch_transform is not None:
            collated = batch_transform(collated)
        return collated
    return mycollate
```

I find this strategy useful for adding information to the batch (such as batch statistics, or complementary images from the dataset) and for making the workers do the necessary computation.
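As an illustration of plugging this into a DataLoader (a minimal sketch; normalize_batch is a hypothetical batch-level transform, not something proposed in the thread):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def normalize_batch(collated):
    # Hypothetical batch transform: standardize the images using the batch's own statistics.
    images, targets = collated
    images = (images - images.mean()) / (images.std() + 1e-8)
    return images, targets

dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=64,
                    collate_fn=get_collate(normalize_batch), num_workers=1)

images, targets = next(iter(loader))  # images arrive already transformed as a batch
```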
Hello, I am working on video tasks where each video is 32 frames of images. Right now I resize and crop the 32 images in a loop; a batched operation might be helpful (and faster?).
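For what it's worth, in recent torchvision versions the functional transforms accept tensors with extra leading dimensions, so a whole clip can be processed at once. A minimal sketch (the frame count and sizes are just examples):

```python
import torch
import torchvision.transforms.functional as F

frames = torch.rand(32, 3, 240, 320)        # (T, C, H, W): 32 RGB frames of one video

# resize and center_crop operate on (..., H, W) tensors, so all frames are handled together.
frames = F.resize(frames, [128, 171])
frames = F.center_crop(frames, [112, 112])
```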
If this can help anyone, I implemented a few batch Transforms:
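A rough sketch of what such a batch transform might look like (this is an assumption for illustration, not the implementation referenced above):

```python
import torch

class BatchRandomHorizontalFlip:
    """Hypothetical batch transform: flips each image of a (N, C, H, W) tensor with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, batch):
        # Draw one Bernoulli decision per image in the batch.
        flip_mask = torch.rand(batch.shape[0], device=batch.device) < self.p
        batch = batch.clone()
        batch[flip_mask] = batch[flip_mask].flip(-1)  # flip the selected images along the width axis
        return batch
```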
Come on, let's have an official implementation of batch transforms, it's 2023 already!!
This would be great for online batch inference. I'm currently looking for a solution for this use case.
torchvision.transforms.Lambda may help.
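To illustrate that suggestion (a minimal sketch; the per-batch standardization is just an example, and it assumes the input is already a stacked tensor):

```python
import torch
from torchvision import transforms

# Lambda wraps an arbitrary callable, so the callable can operate on a whole batch.
batch_transform = transforms.Lambda(
    lambda batch: (batch - batch.mean(dim=(2, 3), keepdim=True))
                  / (batch.std(dim=(2, 3), keepdim=True) + 1e-8)
)

images = torch.rand(16, 3, 224, 224)   # (N, C, H, W) batch of 16 images
normalized = batch_transform(images)
```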
Any plans for updating Transform to support batch inputs instead of just single images?
This would be useful for applying transforms outside of a DataLoader, which applies them one image at a time.
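For reference, without batch support the usual workaround is to apply the transform per image and re-stack, along the lines of this sketch (the transform choice is arbitrary):

```python
import torch
from torchvision import transforms

transform = transforms.RandomHorizontalFlip()

batch = torch.rand(8, 3, 224, 224)                             # (N, C, H, W) batch of images
transformed = torch.stack([transform(img) for img in batch])   # one image at a time
```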