From b5291a654cbb8964615630a8dd38a64fe98610cb Mon Sep 17 00:00:00 2001 From: Anton Eriksson Date: Tue, 8 Oct 2024 19:05:52 +0200 Subject: [PATCH] ensure no breaking change in kwargs --- torchgeo/datamodules/geo.py | 8 ++++---- torchgeo/datamodules/utils.py | 23 ++++++++++++++++++----- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 507940d7c1..cd21d8ac54 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -21,7 +21,7 @@ RandomBatchGeoSampler, ) from ..transforms import AugmentationSequential -from .utils import MisconfigurationException, get_prefixed_kwargs +from .utils import MisconfigurationException, split_kwargs class BaseDataModule(LightningDataModule): @@ -53,7 +53,7 @@ def __init__( self.dataset_class = dataset_class self.batch_size = batch_size self.num_workers = num_workers - self.kwargs = kwargs + self.dataloader_kwargs, self.kwargs = split_kwargs('dataloader_', **kwargs) # Datasets self.dataset: Dataset[dict[str, Tensor]] | None = None @@ -287,7 +287,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, - **get_prefixed_kwargs('dataloader_', **self.kwargs), + **self.dataloader_kwargs, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -432,7 +432,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, - **get_prefixed_kwargs('dataloader_', **self.kwargs), + **self.dataloader_kwargs, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 6a8070c627..5ec9519cc1 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -171,14 +171,27 @@ def group_shuffle_split( return train_idxs, test_idxs -def get_prefixed_kwargs(prefix: str, **kwargs: Any) -> dict[str, Any]: - """Get kwargs with a specific prefix. +def split_kwargs(*prefixes: str, **kwargs: Any) -> tuple[dict[str, Any], ...]: + """Split kwargs into prefixed and other kwargs. Args: - prefix: Prefix to filter kwargs by. + *prefixes: Prefixes to filter kwargs by. **kwargs: Keyword arguments to filter. Returns: - Dictionary of kwargs with the specified prefix. + Tuple of prefixed kwargs and other kwargs. """ - return {k.replace(prefix, ''): v for k, v in kwargs.items() if k.startswith(prefix)} + prefixed_kwargs: list[dict[str, Any]] = [{} for _ in prefixes] + other_kwargs: dict[str, Any] = {} + + for key, value in kwargs.items(): + matched = False + for i, prefix in enumerate(prefixes): + if key.startswith(prefix): + prefixed_kwargs[i][key[len(prefix) :]] = value + matched = True + break + if not matched: + other_kwargs[key] = value + + return *prefixed_kwargs, other_kwargs