Batch support for Transform? #157

Closed
arunmallya opened this issue Apr 27, 2017 · 14 comments

Comments

@arunmallya

Any plans for updating Transform to support batch inputs instead of just single images?
This is useful for applying transforms outside of a DataLoader (which does it on one image at a time).

@fmassa
Member

fmassa commented Apr 30, 2017

I don't think there are any plans on extending transforms to work on batched images.
Indeed, I think transforms are supposed to be applied only in Datasets, so only single instances are required.
Another point is that implementing batched transforms efficiently would require dedicated implementations, and it would also raise the question of whether or not it would be interesting to have them on GPUs as well.
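
For a sense of what such a dedicated implementation could look like, a batched random horizontal flip can be written directly against tensors and would run on either CPU or GPU (a rough sketch; the function name and signature are hypothetical, not part of torchvision):

import torch

def batch_random_hflip(images, p=0.5):
    # images: (N, C, H, W) tensor on CPU or GPU.
    # Draw one Bernoulli sample per image and flip the chosen images
    # along the width axis in a single vectorized operation.
    flip = torch.rand(images.size(0), device=images.device) < p
    flipped = torch.flip(images, dims=[-1])
    return torch.where(flip[:, None, None, None], flipped, images)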

@alykhantejani
Contributor

Closing this for now as there currently are no plans to extend transforms to work on batched images.

@Coolnesss
Contributor

Coolnesss commented Dec 22, 2017

Just to follow up on this: right now, to apply a transformation after getting a batch from the DataLoader, I have to iterate over the batch, transform each tensor back to a PIL image, apply any additional transformations, and then convert it back to a tensor again. It's doable, but it's fairly slow (unless I'm doing something wrong).
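
For reference, the per-image round trip looks roughly like this (a sketch; the exact transform pipeline is just an example):

import torch
from torchvision import transforms

to_pil = transforms.ToPILImage()
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # example transform
    transforms.ToTensor(),
])

def transform_batch(batch):
    # batch: (N, C, H, W) tensor; convert each image back to PIL,
    # apply the transforms, then re-stack the results into a tensor.
    return torch.stack([augment(to_pil(img)) for img in batch])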

If you're open to a PR on this, I'd be happy to help if you can give me some pointers.

@alykhantejani
Contributor

@Coolnesss usually you do the transformations at the Dataset level. The DataLoader has multiple worker processes that read from the Dataset, which effectively applies your transformations in parallel.

Perhaps you can share some details of what your goal is and we can see if it falls outside of the current paradigm
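
For context, the usual pattern is roughly this (a minimal sketch):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Transforms are attached to the Dataset, so each DataLoader worker
# applies them to its own samples in parallel.
dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                         ]))
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)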

@Coolnesss
Contributor

Thank you for your reply @alykhantejani !

I'm trying to create derivative datasets of e.g. MNIST by applying some category of random transformations to each set. Currently, I'm doing something like

d_transforms = [
    transforms.RandomHorizontalFlip(),
    # Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
    dataset = datasets.MNIST('./data',
            train=train,
            download=True,
            transform=d_transforms[i])
    loaders.append(
        DataLoader(dataset,
            shuffle=True,
            pin_memory=True,
            num_workers=1)
        )

Here, I get the desired outcome of having multiple DataLoaders that each provide samples from the transformed datasets. However, this is really slow, presumably because each worker tries to access the same files stored in ./data, and they can only be accessed by one worker at a time (?). After profiling my program, I found that nearly all of the time is spent on calls like

x, y = next(iter(train_loaders[i]))

I can think of two ways to solve this

  1. Apply transformations after getting the batch from the loader - but this requires batched transformations, otherwise it's slow
  2. Make n copies of MNIST on disk and let the workers each have their own copy, e.g. dataset = datasets.MNIST('./data1', ...) etc.

Sorry for the lengthy post, and thanks for your help.

@alykhantejani
Contributor

@Coolnesss would this work for you:

from torch.utils.data import Dataset

class MultiTransformDataset(Dataset):
    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, idx):
        input, target = self.dataset[idx]
        return tuple(t(input) for t in self.transforms) + (target, )

    def __len__(self):
        return len(self.dataset)
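
Usage could then look roughly like this (a sketch assuming the class above; each pipeline ends in ToTensor so the default collate can stack the outputs):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

base = datasets.MNIST('./data', train=True, download=True)  # yields PIL images
pipelines = [
    transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()]),
    transforms.Compose([transforms.RandomRotation(10), transforms.ToTensor()]),
]
dataset = MultiTransformDataset(base, pipelines)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for flipped, rotated, target in loader:
    ...  # one pass yields every transformed view of the same images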

@Coolnesss
Contributor

Thanks for the workaround @alykhantejani

It's a much nicer solution, and somewhat faster too. Unfortunately, it's still not as fast as I had hoped; perhaps the transforms themselves just take too much time. In any case, thanks for your help.

@alykhantejani
Contributor

@Coolnesss np. Let me know if you have any other questions

@bermanmaxim

bermanmaxim commented Oct 19, 2018

Note that you can also design a custom collate function that does the necessary transformations on your batch after collating it, e.g.

import torch.utils.data.dataloader

def get_collate(batch_transform=None):
    def mycollate(batch):
        # Collate the samples as usual, then apply the transform to the whole batch.
        collated = torch.utils.data.dataloader.default_collate(batch)
        if batch_transform is not None:
            collated = batch_transform(collated)
        return collated
    return mycollate

I find this strategy useful for adding information to the batch (such as batch statistics, or complementary images from the dataset) and for making the workers do the necessary computation.
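
For example (a sketch; the batch-level normalization and the dataset variable are only illustrative):

from torch.utils.data import DataLoader

def normalize_batch(batch):
    # The batch_transform receives the already-collated batch.
    images, targets = batch
    images = (images - images.mean()) / (images.std() + 1e-7)
    return images, targets

loader = DataLoader(dataset, batch_size=64, num_workers=4,
                    collate_fn=get_collate(normalize_batch))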

@hukkai

hukkai commented Dec 28, 2019

Hello, I am working on video tasks where each video is 32 frames of images. I currently resize and crop the 32 images in a loop. A batch operation might be helpful (and faster?).
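
If the 32 frames are already stacked as a (T, C, H, W) tensor, for example, the resize can be done in one call by treating time as the batch dimension, and the crop with plain slicing (a sketch using torch.nn.functional.interpolate):

import torch
import torch.nn.functional as F

frames = torch.rand(32, 3, 360, 640)  # 32 frames as a (T, C, H, W) tensor

# Resize every frame at once by treating the time dimension as the batch.
resized = F.interpolate(frames, size=(256, 256), mode='bilinear',
                        align_corners=False)

# Center-crop all frames to 224x224 with plain tensor slicing.
top = (256 - 224) // 2
cropped = resized[:, :, top:top + 224, top:top + 224]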

@GabPrato

If this can help anyone, I implemented a few batch Transforms:
https://github.com/pratogab/batch-transforms

@shivam13juna

Come on, let's have an official implementation of batch transforms, it's 2023 already!!

@AnthonyArmour

This would be great for online batch inference. I'm currently looking for a solution for my use case.

@chenzhike110

torchvision.transforms.Lambda may help
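
Since Lambda just wraps an arbitrary callable, a function that operates on a whole batch tensor can be composed like any other transform (a minimal sketch; the batched flip is illustrative):

import torch
from torchvision import transforms

# A transform whose callable operates on an entire (N, C, H, W) batch.
batch_hflip = transforms.Lambda(lambda batch: torch.flip(batch, dims=[-1]))

batch = torch.rand(8, 3, 32, 32)
flipped = batch_hflip(batch)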

rajveerb pushed a commit to rajveerb/vision that referenced this issue Nov 30, 2023