Skip to content

Commit

Permalink
Set keepdim=True
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 14, 2024
1 parent 36442fd commit 5fb90fb
Show file tree
Hide file tree
Showing 27 changed files with 131 additions and 29 deletions.
3 changes: 3 additions & 0 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ def __init__(
self.layers = ['naip-new', 'lc']

self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/fire_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def __init__(
K.RandomErasing(p=0.1),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=None,
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def __init__(
# Data augmentation
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
self.aug: Transform = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]
self.train_aug: Transform | None = None
self.val_aug: Transform | None = None
self.test_aug: Transform | None = None
Expand Down
6 changes: 6 additions & 0 deletions torchgeo/datamodules/gid15.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,19 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
self.predict_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.predict_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
8 changes: 8 additions & 0 deletions torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,26 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
self.predict_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
self.predict_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/l8biome.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datamodules/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def __init__(
K.RandomSharpness(p=0.5),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=None,
keepdim=True,
)
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]
20 changes: 16 additions & 4 deletions torchgeo/datamodules/levircd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,20 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
self.val_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
self.test_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.val_aug.keepdim = True # type: ignore[attr-defined]
self.test_aug.keepdim = True # type: ignore[attr-defined]


class LEVIRCDPlusDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the LEVIR-CD+ dataset.
Expand Down Expand Up @@ -92,14 +98,20 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
self.val_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
self.test_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.val_aug.keepdim = True # type: ignore[attr-defined]
self.test_aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datamodules/naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def __init__(
)

self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/potsdam.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def __init__(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=None,
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/quakeset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,7 @@ def __init__(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image'],
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
3 changes: 3 additions & 0 deletions torchgeo/datamodules/resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,7 @@ def __init__(
K.RandomErasing(p=0.1),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=None,
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
3 changes: 3 additions & 0 deletions torchgeo/datamodules/seco.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def __init__(
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=None,
keepdim=True,
)
# https://github.com/kornia/kornia/issues/2848
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datamodules/sentinel2_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,20 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)

self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
14 changes: 9 additions & 5 deletions torchgeo/datamodules/sentinel2_eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


Expand Down Expand Up @@ -64,21 +63,26 @@ def __init__(
**self.eurocrops_kwargs,
)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datamodules/sentinel2_nccm.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,20 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)

self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datamodules/sentinel2_south_america_soybean.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,20 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)

self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Expand Down
14 changes: 9 additions & 5 deletions torchgeo/datamodules/southafricacroptype.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ..datasets import SouthAfricaCropType, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


Expand Down Expand Up @@ -49,21 +48,26 @@ def __init__(
**kwargs,
)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
extra_args={
DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None}
},
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

# https://github.com/kornia/kornia/issues/2848
self.train_aug.keepdim = True # type: ignore[attr-defined]
self.aug.keepdim = True # type: ignore[attr-defined]

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
Loading

0 comments on commit 5fb90fb

Please sign in to comment.