DataModules: run all data augmentation on the GPU #992
""" | ||
dataset = self.val_dataset or self.dataset | ||
if dataset is not None: | ||
if hasattr(dataset, "plot"): |
If we enforce that all datasets must have a plot method, we could remove this check. Currently the only ones lacking one are VHR-10 (WIP) and our point datasets (GBIF, iNaturalist, EDDMapS).
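For illustration, one way such enforcement could look is an abstract method on a dataset base class; this is a sketch under assumptions (the class name is hypothetical, not torchgeo's actual base class):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

from matplotlib.figure import Figure


class PlottableDataset(ABC):  # hypothetical base class, for illustration only
    @abstractmethod
    def plot(self, sample: Dict[str, Any]) -> Figure:
        """Every subclass must implement plotting for a single sample."""
```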
```python
self.train_aug: Optional[Transform] = None
self.val_aug: Optional[Transform] = None
self.test_aug: Optional[Transform] = None
self.predict_aug: Optional[Transform] = None
```
The idea for all of these is that you can either define a single attribute (`self.foo`) or a different attribute for each stage (`self.train_foo`, `self.val_foo`, etc.).
```python
    stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ["fit"]:
    self.train_dataset = self.dataset_class(  # type: ignore[call-arg]
```
Need to ignore type warnings because not all datasets accept a `split` argument.
```python
    MisconfigurationException: If :meth:`setup` does not define a
        'train_dataset'.
"""
dataset = self.train_dataset or self.dataset
```
Note that datasets/samplers with length 0 also evaluate to `False`, so this check may produce red herrings.
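A quick illustration of the pitfall, using a throwaway `TensorDataset`:

```python
import torch
from torch.utils.data import TensorDataset

empty = TensorDataset(torch.empty(0, 3))   # a dataset with length 0
fallback = TensorDataset(torch.rand(10, 3))

print(bool(empty))          # False: truthiness comes from __len__
chosen = empty or fallback  # silently picks the fallback dataset
print(len(chosen))          # 10
```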
""" | ||
# Non-Tensor values cannot be moved to a device | ||
del batch["crs"] | ||
del batch["bbox"] |
I may decide to remove these completely but didn't do that for this PR since it requires a ton of unrelated changes. Will reconsider when working on #985.
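For context, a minimal sketch of why the deletion is needed (the string CRS value and the manual transfer loop are illustrative assumptions, not the PR's actual code — in torchgeo these are CRS/BoundingBox objects):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"image": torch.rand(2, 3, 64, 64), "crs": "EPSG:32611"}

# batch["crs"].to(device) would raise AttributeError: strings have no .to(),
# so non-Tensor entries must be removed (or skipped) before the transfer.
moved = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
```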
I'm having trouble understanding the failing BYOL tests. The same datamodules work fine with SemanticSegmentationTask. If anyone can figure out how to fix these, we can see how much coverage we lack and fix that.
The last remaining test failure is due to discrepancies between Chesapeake and all other datasets. All of our datasets return a mask of shape "h w", but Chesapeake returns "c h w". This breaks everything. Can't wait until PEP 646 is supported in PyTorch...

How should we handle this? It doesn't seem like Chesapeake can be changed to "h w", since the prior labels have c=4. We could change every other dataset to be "c h w" like Kornia expects, but that sounds like a lot more work. We could also write a custom AugmentationSequential just for Chesapeake that handles things properly. Our AugmentationSequential only works for batches, not samples.
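To make the mismatch concrete (shapes as described above; the `unsqueeze` at the end is one possible normalization, not what this PR does):

```python
import torch

mask_hw = torch.zeros(256, 256)      # most datasets: "h w"
mask_chw = torch.zeros(4, 256, 256)  # Chesapeake prior labels: "c h w", c=4

# Kornia's mask augmentations expect a channel dimension, so the common
# "h w" case would need an explicit channel axis to share one pipeline:
mask_1hw = mask_hw.unsqueeze(0)      # "1 h w"
```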
@ashnair1 @nilsleh how would you propose to solve the ExtractTensorPatches OOM issue? The easiest fix is to use RandomNCrop instead (this is what we do during train), and I'm happy to do that. However, it doesn't really make sense during val/test, and isn't useful at all during predict. An alternative fix would be to introduce a […]. Any other ideas for how to handle this?
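For scale, here is why `ExtractTensorPatches` can OOM: every patch is materialized at once (the image size, window size, and stride below are illustrative):

```python
import torch
from kornia.contrib import extract_tensor_patches

x = torch.rand(1, 3, 1024, 1024)
# 256x256 windows with stride 128 -> 7x7 = 49 overlapping patches,
# roughly 3x the memory of the input image, all allocated simultaneously.
patches = extract_tensor_patches(x, window_size=256, stride=128)
print(patches.shape)  # torch.Size([1, 49, 3, 256, 256])
```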
A trivial solution to OOM would be to run inference in an iterative manner whenever we use it. For example, in the `validation_step`:

```diff
 def validation_step(self, *args: Any, **kwargs: Any) -> None:
     """Compute validation loss and log example predictions.

     Args:
         batch: the output of your DataLoader
         batch_idx: the index of this batch
     """
     batch = args[0]
     batch_idx = args[1]
     x = batch["image"]
     y = batch["mask"]
-    y_hat = self(x)
+    from einops import rearrange
+    y_hat_list = []
+    for i in range(x.shape[0]):
+        # process one sample at a time to bound peak GPU memory
+        out = rearrange(x[i], 'c h w -> 1 c h w')
+        out = self(out)
+        y_hat_list.append(out)
+    # einops treats the list as the leading "b" axis when stacking
+    y_hat = rearrange(y_hat_list, 'b 1 c h w -> b c h w')
     y_hat_hard = y_hat.argmax(dim=1)
     loss = self.loss(y_hat, y)
```
I don't think there's an easy way to tell whether or not we are using it.
How about making this the default behaviour for val, test and predict steps?

X is dependent on the hardware RAM, which can only be determined by the user, so I don't think we can go down that route.
I think switching all val/test/predict steps from batch processing to a for-loop over the batch size would be extremely bad for speed.
Don't know if it is a desired approach, but could you wrap the inference in a try/except block to catch the OOM error?
If I recall correctly, it isn't possible to check for OOM; the program crashes without raising an error. But correct me if I'm wrong. Also, it would be difficult to get test coverage of that branch.
The PyTorch docs seem to suggest a possible way of catching it as a RuntimeError, but as you said, it would be difficult to get test coverage.
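For reference, a sketch of that pattern with the caveats discussed above (hard to trigger in tests; the helper name is hypothetical):

```python
import torch


def forward_with_fallback(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Try batched inference; fall back to per-sample inference on CUDA OOM."""
    try:
        return model(x)
    except RuntimeError as err:
        if "out of memory" not in str(err):
            raise  # some other failure, don't swallow it
        torch.cuda.empty_cache()  # release cached blocks before retrying
        return torch.cat([model(sample.unsqueeze(0)) for sample in x])
```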
Currently training fails during the sanity checking (validation) step at `fig = datamodule.plot(sample)`, as `fig` will be of type `None` because `self.val_dataset` (of type `torch.utils.data.Subset`) does not have a `plot` method.
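One possible fix, sketched here as an assumption rather than the PR's final change, is to unwrap `Subset` before looking for `plot`:

```python
from typing import Any

from torch.utils.data import Subset


def plot(self: Any, *args: Any, **kwargs: Any) -> Any:
    dataset = self.val_dataset or self.dataset
    if isinstance(dataset, Subset):
        dataset = dataset.dataset  # reach the wrapped dataset's plot method
    if dataset is not None and hasattr(dataset, "plot"):
        return dataset.plot(*args, **kwargs)
```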
""" | ||
output: Dict[str, Any] = {} | ||
output["image"] = torch.stack([sample["image"] for sample in batch]) | ||
output["boxes"] = [sample["boxes"] for sample in batch] |
@calebrob6 starting a thread here so we can discuss how to handle NASA Marine Debris. I think the solution in this `collate_fn` is actually correct; we want a list of tensors, not a single tensor: kornia/kornia#1497. Note that the name is wrong, we should use `bbox_xyxy` or `bbox_xywh` instead of `bbox` (`boxes` is translated to `bbox` internally in our `AugmentationSequential` wrapper), but we can fix that when working on #985.
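To illustrate why a list is correct here (the box counts below are made up for demonstration):

```python
import torch

batch = [
    {"image": torch.rand(3, 64, 64), "boxes": torch.rand(2, 4)},
    {"image": torch.rand(3, 64, 64), "boxes": torch.rand(5, 4)},
]
# Images share a shape and can be stacked; the (N_i, 4) box tensors have a
# different N_i per image, so they must remain a list of tensors.
output = {
    "image": torch.stack([s["image"] for s in batch]),
    "boxes": [s["boxes"] for s in batch],
}
```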
Okay, see if the last commit fixes this.
@calebrob6 I believe I've addressed all review comments, let me know if you find anything else to fix.

Yep that did it, nice job!
- DataModules: run all data augmentation on the GPU
- Passing tests
- Update BigEarthNet
- Break ChesapeakeCVPR
- Update COWC
- Update Cyclone
- Update ETCI2021
- mypy fixes
- Update FAIR1M
- Update Inria
- Update LandCoverAI
- Update LoveDA
- Update NAIP
- Update NASA
- Update OSCD
- Update RESISC45
- Update SEN12MS
- Update So2Sat
- Update SpaceNet
- Update UCMerced
- Update USAVars
- Update xview
- Remove seed
- mypy fixes
- OSCD hacks
- Add NonGeoDataModule base class
- Fixes
- Add base class to docs
- mypy fixes
- Fix several tests
- Fix Normalize
- Syntax error
- Fix bigearthnet
- Fix dtype
- Consistent kornia import
- Get regression datasets working
- Fix detection tests
- Fix some chesapeake bugs
- Fix several segmentation issues
- isort fixes
- Undo breaking change
- Remove more code duplication, standardize docstrings
- mypy fixes
- Add default augmentation
- Augmentations can be any callable
- Fix datasets tests
- Fix datamodule tests
- Fix more datamodules
- Typo fix
- Set up val_dataset even when fit
- Fix classification tests
- Fix ETCI2021
- Fix SEN12MS
- Add GeoDataModule base class
- Fix several chesapeake bugs
- Fix dtype and shape
- Fix crs/bbox issue
- Fix test dtype
- Fix unequal size stacking error
- flake8 fix
- Better checks on sampler
- Fix bug introduced in NAIP dm
- Fix chesapeake dimensions
- Add one to mask
- Fix missing imports
- Fix batch size
- Simplify augmentations
- Don't run test or predict without datasets
- Fix tests
- Allow shared dataset
- One more try
- Fix typo
- Fix another typo
- Fix Chesapeake dimensions
- Apply augmentations during sanity check too
- Don't reuse fixtures
- Increase coverage
- Fix ETCI tests
- Test predict_step
- Test all loss methods
- Simplify validation plotting
- Document new classes
- Fix plotting
- Plotting should be robust in case dataset does not contain RGB bands
- Fix flake8
- 100% coverage of trainers
- Add lightning-lite dependency
- Revert "Add lightning-lite dependency" (reverts commit 1df7291)
- Define our own MisconfigurationException
- Properly test new data module base classes
- Fix mistake in setup call
- ExtractTensorPatches runs into OOM errors
- Test both fast_dev_run True and False
- Fix plot methods
- Fix OSCD tests
- Fix bug with inconsistent train/val/test splits between stages
- Fix issues with images of different sizes
- Fix OSCD tests
- Fix OSCD tests
- Bad rebase
- No trainer for OSCD so no need for config
- Bad rebase
- plot: only works during validation
- Fix collation of NASA Marine Debris dataset
- flake8 fix
- Quick test
- Revert "Quick test" (reverts commit f465efc)
- 56 workers is a bit excessive

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This PR overhauls all of our data modules to improve uniformity. This includes the following changes:
- `torchgeo.transforms.AugmentationSequential` (use `kornia.augmentation.AugmentationSequential` instead)
- `torchgeo.datamodules.utils.dataset_split` (use `torch.utils.data.random_split` instead)

In a future PR, I'm planning on extending this to the rest of our transforms.
Fixes #619
Fixes #337
Fixes #336
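For readers migrating, a minimal sketch of the two suggested replacements (the transform choice, dataset, and split sizes are illustrative assumptions):

```python
import kornia.augmentation as K
import torch
from torch.utils.data import TensorDataset, random_split

# kornia's container applies the same transform to all listed data keys
# ("input" is kornia's key for the image tensor).
aug = K.AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    data_keys=["input", "mask"],
)

# random_split replaces dataset_split; the lengths must sum to len(dataset).
dataset = TensorDataset(torch.rand(100, 3, 64, 64))
n_val = len(dataset) // 5
train, val = random_split(dataset, [len(dataset) - n_val, n_val])
```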