
DataModules: run all data augmentation on the GPU #992

Merged: 108 commits into main on Jan 23, 2023
Conversation

@adamjstewart (Collaborator) commented Dec 30, 2022:

This PR overhauls all of our data modules to improve uniformity. This includes the following changes:

  • Add GeoDataModule and NonGeoDataModule base classes to reduce code duplication
  • Only instantiate the datasets that are needed for a particular stage
  • Replace torchvision with kornia (better support for MSI, GPU, inverse)
  • Replace dataset transforms with on_after_batch_transfer (CPU -> GPU, sample -> batch, faster)
  • Remove instance methods for preprocessing (fixes #886: "Trainers: num_workers > 0 results in pickling error on macOS/Windows")
  • Fix bug where train/val/test split would differ for each stage and images would leak between sets
  • Deprecate torchgeo.transforms.AugmentationSequential (use kornia.augmentation.AugmentationSequential instead)
  • Deprecate torchgeo.datamodules.utils.dataset_split (use torch.utils.data.random_split instead)

In a future PR, I'm planning on extending this to the rest of our transforms:

  • Rewrite all index transforms to be compatible with Kornia (Convert all index transforms to Kornia #999)
  • Update tutorials to use Kornia with our transforms
  • Upstream and remove our custom transforms and AugmentationSequential hacks

Fixes #619
Fixes #337
Fixes #336
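[Editor's note] The core idea behind moving augmentation into on_after_batch_transfer can be sketched in plain Python. All names below are illustrative stand-ins, not the actual torchgeo/Lightning classes:

```python
# Minimal structural sketch of the on_after_batch_transfer pattern adopted in
# this PR: augmentations run once per *batch*, after Lightning has moved the
# batch to the device, instead of once per *sample* in Dataset.__getitem__.
class DataModuleSketch:
    def __init__(self, train_aug=None, aug=None):
        self.train_aug = train_aug  # stage-specific augmentation
        self.aug = aug              # shared fallback for val/test/predict
        self.training = False       # stand-in for self.trainer.training

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # `batch` already lives on the target device here, so a Kornia
        # AugmentationSequential applied at this point would run on the GPU.
        aug = self.train_aug if self.training else self.aug
        return aug(batch) if aug is not None else batch


dm = DataModuleSketch(
    train_aug=lambda b: {**b, "image": [x * 2 for x in b["image"]]},
)
dm.training = True
out = dm.on_after_batch_transfer({"image": [1, 2, 3]}, 0)
# out["image"] == [2, 4, 6]
```

Because the hook fires after the batch reaches the device, the same augmentation object serves CPU and GPU runs without any dataset-level changes.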

@adamjstewart added this to the 0.4.0 milestone on Dec 30, 2022
@github-actions bot added the datamodules, datasets, transforms, testing, and documentation labels on Dec 30, 2022
@adamjstewart added the backwards-incompatible label on Jan 2, 2023
@github-actions bot added the trainers label on Jan 2, 2023
torchgeo/datamodules/geo.py:

    """
    dataset = self.val_dataset or self.dataset
    if dataset is not None:
        if hasattr(dataset, "plot"):

@adamjstewart (Collaborator, author):

If we enforce that all datasets must have a plot method, we could remove this check. Currently the only ones lacking one are VHR-10 (WIP) and our point datasets (GBIF, iNaturalist, EDDMapS).

    self.train_aug: Optional[Transform] = None
    self.val_aug: Optional[Transform] = None
    self.test_aug: Optional[Transform] = None
    self.predict_aug: Optional[Transform] = None

@adamjstewart (Collaborator, author):

The idea for all of these is that you can either define a single attribute (self.foo) or a different attribute for each stage (self.train_foo, self.val_foo, etc.).
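[Editor's note] This stage-specific fallback can be sketched with a hypothetical helper (not torchgeo code):

```python
def aug_for_split(module, split):
    """Hypothetical helper: prefer `module.{split}_aug` when it is set,
    otherwise fall back to the shared `module.aug` attribute."""
    return getattr(module, f"{split}_aug", None) or getattr(module, "aug", None)


class Module:
    aug = "shared-default"        # used by every stage without its own aug
    train_aug = "train-specific"  # overrides the default for training only


m = Module()
# aug_for_split(m, "train") -> "train-specific"
# aug_for_split(m, "val")   -> "shared-default"
```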

    stage: Either 'fit', 'validate', 'test', or 'predict'.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(  # type: ignore[call-arg]

@adamjstewart (Collaborator, author):

We need to ignore the type warning here because not all datasets accept a split argument.

    MisconfigurationException: If :meth:`setup` does not define a
        'train_dataset'.
    """
    dataset = self.train_dataset or self.dataset

@adamjstewart (Collaborator, author):

Note that datasets/samplers with length 0 also evaluate to False, so this may lead to red herrings.
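[Editor's note] The red herring in question: a length-0 dataset is falsy, so `or` skips it even though it is not None. A tiny illustration:

```python
class EmptyDataset:
    def __len__(self) -> int:
        return 0


empty = EmptyDataset()
fallback = "some-other-dataset"

# `empty or fallback` picks the fallback even though `empty` is not None,
# because an object with __len__() == 0 evaluates to False...
chosen = empty or fallback
# ...whereas an explicit None check keeps the empty dataset.
explicit = empty if empty is not None else fallback
```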

"""
# Non-Tensor values cannot be moved to a device
del batch["crs"]
del batch["bbox"]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may decide to remove these completely but didn't do that for this PR since it requires a ton of unrelated changes. Will reconsider when working on #985.

@adamjstewart adamjstewart marked this pull request as ready for review January 4, 2023 16:40
@adamjstewart (Collaborator, author):

I'm having trouble understanding the failing BYOL tests. The same datamodules work fine with SemanticSegmentationTask. If anyone can figure out how to fix these, we can see how much coverage we lack and fix that.

@adamjstewart (Collaborator, author):

The last remaining test failure is due to discrepancies between Chesapeake and all other datasets. All of our datasets return a mask of shape "h w" but Chesapeake returns "c h w". This breaks everything. Can't wait until PEP 646 is supported in PyTorch...

How should we handle this? It doesn't seem like Chesapeake can be changed to "h w" since the prior labels are c=4. We could change every other dataset to be "c h w" like Kornia expects, but that sounds like a lot more work. We could also write a custom AugmentationSequential just for Chesapeake that handles things properly. Our AugmentationSequential only works for batches, not samples.
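[Editor's note] One way to reconcile the two conventions, sketched as a shape-only helper (hypothetical; the real fix would have to live in an AugmentationSequential wrapper and operate on tensors):

```python
def to_channelled_mask_shape(shape: tuple) -> tuple:
    """Hypothetical sketch: Kornia expects masks with an explicit channel
    dimension, so an "h w" mask gains a singleton channel while a
    Chesapeake-style "c h w" mask (c=4 prior labels) passes through."""
    if len(shape) == 2:        # "h w"   -> "1 h w"
        return (1, *shape)
    if len(shape) == 3:        # already "c h w"
        return shape
    raise ValueError(f"unexpected mask shape: {shape}")
```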

@adamjstewart adamjstewart force-pushed the datamodules/gpu branch 2 times, most recently from bf2bbb7 to 3d27e5d Compare January 15, 2023 18:49
@adamjstewart (Collaborator, author):

@ashnair1 @nilsleh how would you propose to solve the ExtractTensorPatches OOM issue?

Easiest fix is to use RandomNCrop instead (this is what we do during train). I'm happy to do this. However, it doesn't really make sense during val/test, and isn't useful at all during predict.

Alternative fix would be to introduce a GridNonGeoSampler that works similarly to GridGeoSampler and allows a single image to span multiple batches. However, this would require all NonGeoDatasets to change from __getitem__(idx: int) to __getitem__(idx: int, rows: slice, cols: slice) which would be a ton of work.

Any other ideas for how to handle this?

@ashnair1 (Collaborator):

> @ashnair1 @nilsleh how would you propose to solve the ExtractTensorPatches OOM issue? […]


A trivial solution to the OOM issue would be to run inference iteratively whenever we use _ExtractTensorPatches.

For example, in the validation_step of torchgeo/trainers/segmentation.py, we iterate over the batch instead of passing in the entire batch:

 def validation_step(self, *args: Any, **kwargs: Any) -> None:
     """Compute validation loss and log example predictions.

     Args:
         batch: the output of your DataLoader
         batch_idx: the index of this batch
     """
     batch = args[0]
     batch_idx = args[1]
     x = batch["image"]
     y = batch["mask"]
-    y_hat = self(x)
+    from einops import rearrange
+
+    # Run inference one image at a time to stay within memory, then restack.
+    y_hat_list = []
+    for i in range(x.shape[0]):
+        out = rearrange(x[i], "c h w -> 1 c h w")
+        out = self(out)
+        y_hat_list.append(out)
+    # einops accepts a list of tensors, treating the list as the leading axis.
+    y_hat = rearrange(y_hat_list, "b 1 c h w -> b c h w")
     y_hat_hard = y_hat.argmax(dim=1)

     loss = self.loss(y_hat, y)

@adamjstewart (Collaborator, author):

> run inference in an iterative manner whenever we use _ExtractTensorPatches

I don't think there's an easy way to tell whether or not we are using _ExtractTensorPatches. We could say "if batch size > X run iteratively" but I'm not sure what X should be.

@ashnair1 (Collaborator):

How about making this the default behaviour for the val, test, and predict steps?

X depends on the hardware's available memory, which can only be determined by the user, so I don't think we can go down that route.

@adamjstewart (Collaborator, author):

I think switching every val/test/predict step from batched processing to a for-loop over the batch would be extremely bad for speed.

@nilsleh (Collaborator) commented Jan 17, 2023:

> We could say "if batch size > X run iteratively" but I'm not sure what X should be.

I don't know if it is a desired approach, but could you wrap y_hat = self(x) in a try/except block checking for OOM, and do the looping in the except block?

@adamjstewart (Collaborator, author):

If I recall correctly, it isn't possible to check for OOM; the program crashes without raising an error. But correct me if I'm wrong. Also, it would be difficult to get test coverage of that branch.

@nilsleh (Collaborator) commented Jan 17, 2023:

The PyTorch docs seem to suggest a possible way: OOM is raised as a RuntimeError that can be caught. But, as you said, that branch would be difficult to cover in tests.
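[Editor's note] If the RuntimeError route were pursued, the fallback might look roughly like this. The model and batch below are plain-Python stand-ins, not torchgeo code; PyTorch reports CUDA OOM as a RuntimeError whose message contains "out of memory":

```python
def forward_with_oom_fallback(model, batch):
    """Sketch: try full-batch inference, and only on an out-of-memory
    RuntimeError fall back to looping over individual samples."""
    try:
        return model(batch)
    except RuntimeError as err:
        if "out of memory" not in str(err):
            raise  # unrelated errors still propagate
        # Per-sample fallback: much slower, but fits in memory.
        return [model([sample])[0] for sample in batch]


def tiny_model(batch):
    # Stand-in model that "OOMs" on batches larger than one sample.
    if len(batch) > 1:
        raise RuntimeError("CUDA out of memory")
    return [x + 1 for x in batch]


preds = forward_with_oom_fallback(tiny_model, [1, 2, 3])
# preds == [2, 3, 4]
```

As noted above, exercising the except branch in CI (where there is no real GPU to exhaust) is the hard part.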

@ashnair1 (Collaborator) left a comment:

Currently, training fails during the sanity-checking (validation) step at fig = datamodule.plot(sample): fig is None because self.val_dataset (of type torch.utils.data.Subset) does not have a plot method.
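[Editor's note] One possible fix is to unwrap Subset before looking up plot. A sketch with a stand-in Subset class (not the actual patch):

```python
class Subset:
    """Stand-in for torch.utils.data.Subset: wraps a dataset plus indices."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices


class PlottableDataset:
    def plot(self, sample):
        return f"figure for {sample}"


def resolve_plot(dataset):
    """Walk through Subset wrappers to the underlying dataset's plot method,
    returning None when no wrapped dataset defines one."""
    while isinstance(dataset, Subset):
        dataset = dataset.dataset
    return getattr(dataset, "plot", None)


plot = resolve_plot(Subset(PlottableDataset(), indices=[0, 1]))
# plot is not None, and plot("x") == "figure for x"
```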

"""
output: Dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@calebrob6 starting a thread here so we can discuss how to handle NASA Marine Debris.

I think the solution in this collate_fn is actually correct, we want a list of tensors, not a single tensor: kornia/kornia#1497

Note that the name is wrong, we should be used bbox_xyxy or bbox_xywh instead of bbox (boxes is translated to bbox internally in our AugmentationSequential wrapper) but we can fix that when working on #985.
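[Editor's note] Why a list rather than a stack: samples can contain different numbers of boxes, so the box tensors cannot be stacked into one tensor. A sketch with plain lists standing in for tensors:

```python
def collate_detection(batch):
    """Sketch of the collate_fn quoted above: images share a shape and could
    be stacked (torch.stack in the real code), but boxes stay a per-sample
    list because each sample may contain a different number of boxes."""
    return {
        "image": [sample["image"] for sample in batch],
        "boxes": [sample["boxes"] for sample in batch],
    }


out = collate_detection([
    {"image": "img0", "boxes": [[0, 0, 1, 1]]},
    {"image": "img1", "boxes": [[0, 0, 2, 2], [1, 1, 3, 3], [2, 2, 4, 4]]},
])
# len(out["boxes"][0]) == 1 and len(out["boxes"][1]) == 3
```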

@adamjstewart (Collaborator, author):

Okay, see if the last commit fixes this.

@adamjstewart (Collaborator, author):

@calebrob6 I believe I've addressed all review comments, let me know if you find anything else to fix.

@calebrob6 (Member):

Yep that did it, nice job!

@adamjstewart adamjstewart merged commit 55f74da into main Jan 23, 2023
@adamjstewart adamjstewart deleted the datamodules/gpu branch January 23, 2023 22:08
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* DataModules: run all data augmentation on the GPU

* Passing tests

* Update BigEarthNet

* Break ChesapeakeCVPR

* Update COWC

* Update Cyclone

* Update ETCI2021

* mypy fixes

* Update FAIR1M

* Update Inria

* Update LandCoverAI

* Update LoveDA

* Update NAIP

* Update NASA

* Update OSCD

* Update RESISC45

* Update SEN12MS

* Update So2Sat

* Update SpaceNet

* Update UCMerced

* Update USAVars

* Update xview

* Remove seed

* mypy fixes

* OSCD hacks

* Add NonGeoDataModule base class

* Fixes

* Add base class to docs

* mypy fixes

* Fix several tests

* Fix Normalize

* Syntax error

* Fix bigearthnet

* Fix dtype

* Consistent kornia import

* Get regression datasets working

* Fix detection tests

* Fix some chesapeake bugs

* Fix several segmentation issues

* isort fixes

* Undo breaking change

* Remove more code duplication, standardize docstrings

* mypy fixes

* Add default augmentation

* Augmentations can be any callable

* Fix datasets tests

* Fix datamodule tests

* Fix more datamodules

* Typo fix

* Set up val_dataset even when fit

* Fix classification tests

* Fix ETCI2021

* Fix SEN12MS

* Add GeoDataModule base class

* Fix several chesapeake bugs

* Fix dtype and shape

* Fix crs/bbox issue

* Fix test dtype

* Fix unequal size stacking error

* flake8 fix

* Better checks on sampler

* Fix bug introduced in NAIP dm

* Fix chesapeake dimensions

* Add one to mask

* Fix missing imports

* Fix batch size

* Simplify augmentations

* Don't run test or predict without datasets

* Fix tests

* Allow shared dataset

* One more try

* Fix typo

* Fix another typo

* Fix Chesapeake dimensions

* Apply augmentations during sanity check too

* Don't reuse fixtures

* Increase coverage

* Fix ETCI tests

* Test predict_step

* Test all loss methods

* Simplify validation plotting

* Document new classes

* Fix plotting

* Plotting should be robust in case dataset does not contain RGB bands

* Fix flake8

* 100% coverage of trainers

* Add lightning-lite dependency

* Revert "Add lightning-lite dependency"

This reverts commit 1df7291.

* Define our own MisconfigurationException

* Properly test new data module base classes

* Fix mistake in setup call

* ExtractTensorPatches runs into OOM errors

* Test both fast_dev_run True and False

* Fix plot methods

* Fix OSCD tests

* Fix bug with inconsistent train/val/test splits between stages

* Fix issues with images of different sizes

* Fix OSCD tests

* Fix OSCD tests

* Bad rebase

* No trainer for OSCD so no need for config

* Bad rebase

* plot: only works during validation

* Fix collation of NASA Marine Debris dataset

* flake8 fix

* Quick test

* Revert "Quick test"

This reverts commit f465efc.

* 56 workers is a bit excessive

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>