Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datamodules: Switch to kornia AugmentationSequential #2147

Open
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

ashnair1
Copy link
Collaborator

@ashnair1 ashnair1 commented Jul 1, 2024

Switch to kornia's AugmentationSequential in datamodules submodule

Closes #1432

TODO (once kornia 0.7.4 is released):

  • Remove manual setting of keepdim attribute
  • Bump min version of kornia

@github-actions github-actions bot added testing Continuous integration testing datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets labels Jul 1, 2024
torchgeo/datasets/cyclone.py Outdated Show resolved Hide resolved
@adamjstewart adamjstewart added this to the 0.6.0 milestone Jul 1, 2024
@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Jul 1, 2024
@ashnair1
Copy link
Collaborator Author

ashnair1 commented Jul 2, 2024

TODO: Tutorials need to be updated once we fully switch to kornia's AugmentationSequential

@ashnair1 ashnair1 marked this pull request as ready for review July 2, 2024 07:23
.pre-commit-config.yaml Outdated Show resolved Hide resolved
torchgeo/datamodules/inria.py Outdated Show resolved Hide resolved
@github-actions github-actions bot removed the trainers PyTorch Lightning trainers label Jul 14, 2024
@ashnair1 ashnair1 force-pushed the aug-remove-datamodules branch 4 times, most recently from 5fb90fb to c7ce4f7 Compare July 14, 2024 19:55
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still confused by some of the dataset changes, but all datamodule changes now look good. I'll see if I can figure out why the tests are failing.


if self.transforms is not None:
sample = self.transforms(sample)
sample.update({x: features[x] for x in features if x != 'label'})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bit confused by this. Relative time and ocean are not returned if transforms is None? This PR only touches datamodule augmentations, not dataset transforms, right? Why is this needed? Same question for all other datasets.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason for this mentioned here- #2147 (comment)

These type of fixes won't be required once kornia/kornia#2971 makes it to a release. It allows non-kornia compatible keys to be passed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this kind of seems like a hack. It's an implementation detail that TropicalCycloneDataModule doesn't use TropicalCyclone(..., transform=...), not by design. I definitely don't think we should return different things depending on whether transforms is used or not.

I have a better idea. In BaseDataModule in transfer_batch_to_device, drop all keys that Kornia doesn't recognize. This will allow us to drop most changes to datasets and the transfer_batch_to_device in GeoDataModule. You can loop over all keys and drop any that don't correspond to a valid Kornia key (there should be a way to get a list of valid keys from Kornia).

@@ -217,7 +217,10 @@ def plot(
show_predictions = 'prediction' in sample

if show_mask:
mask = sample['mask'].numpy()
mask = sample['mask']
if mask.ndim == 3 and mask.shape[0] == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the mask occasionally change dimensions? Maybe this can be made uniform in getitem?

@adamjstewart
Copy link
Collaborator

But tests are failing for most GeoDataModules.

I wish it were this simple... Note that some failing tests are for NonGeoDataModules (ssl4eo_l_benchmark_*).

Also note that some tests fail depending on batch size. For example, gid15 passes with batch_size=1, but fails if batch_size=2. But there are other datasets for which both 1 and 2+ pass. For the case of gid15, it seems like this is because the image is 1x1, so this may be a red herring.

Most of the failing tests seem to be because the ground truth mask has a channel dimension, which cross-entropy loss does not like. Let me see if I can fix a few.

@adamjstewart adamjstewart added the backwards-incompatible Changes that are not backwards compatible label Jul 18, 2024
@adamjstewart
Copy link
Collaborator

The ChesapeakeCVPR dataset is royally screwed up and I'm not touching it, so I'll let you or @calebrob6 solve that one. The problem is that it doesn't use RasterDataset and the mask is sometimes single channel, sometimes multichannel, so you have to be very careful with all dimensions. I'm not even sure our trainer supports multichannel masks.

@ashnair1 ashnair1 force-pushed the aug-remove-datamodules branch 2 times, most recently from 7e14613 to d2b5b01 Compare July 20, 2024 13:16
@ashnair1 ashnair1 force-pushed the aug-remove-datamodules branch 3 times, most recently from b19be82 to 919bbb2 Compare August 9, 2024 15:20
@adamjstewart
Copy link
Collaborator

Is this waiting on Kornia 0.7.4? Trying to decide if this should be in TorchGeo 0.6.0 next week or if we should bump it to a later release.

@ashnair1 ashnair1 mentioned this pull request Aug 21, 2024
@ashnair1
Copy link
Collaborator Author

Is this waiting on Kornia 0.7.4? Trying to decide if this should be in TorchGeo 0.6.0 next week or if we should bump it to a later release.

Later release for sure. This does depend on Kornia 0.7.4.

@adamjstewart adamjstewart modified the milestones: 0.6.0, 0.7.0 Aug 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backwards-incompatible Changes that are not backwards compatible datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets testing Continuous integration testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

USAVars Augmentation maps to 0
2 participants