From 3c3a306a9f245c1afa5720683a8634276601b919 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 2 Feb 2023 10:12:01 +0400 Subject: [PATCH 01/32] Add VHR10 datamodule --- conf/vhr10.yaml | 28 ++++++++ tests/conf/vhr10.yaml | 14 ++++ tests/trainers/test_detection.py | 8 ++- torchgeo/datamodules/__init__.py | 2 + torchgeo/datamodules/vhr10.py | 110 +++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 conf/vhr10.yaml create mode 100644 tests/conf/vhr10.yaml create mode 100644 torchgeo/datamodules/vhr10.py diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml new file mode 100644 index 00000000000..c3eac82fbb0 --- /dev/null +++ b/conf/vhr10.yaml @@ -0,0 +1,28 @@ +program: + seed: 0 + overwrite: True + +trainer: + accelerator: gpu + devices: 1 + min_epochs: 5 + max_epochs: 100 + auto_lr_find: False + benchmark: True + +experiment: + task: "vhr10" + name: "vhr10_test" + module: + model: "faster-rcnn" + backbone: "resnet18" + pretrained: True + num_classes: 11 + learning_rate: 1.3e-5 + learning_rate_schedule_patience: 6 + verbose: false + datamodule: + root: "data/vhr10" + batch_size: 3 + num_workers: 12 + patch_size: 512 diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml new file mode 100644 index 00000000000..a9c325eb2f2 --- /dev/null +++ b/tests/conf/vhr10.yaml @@ -0,0 +1,14 @@ +experiment: + task: "vhr10" + module: + detection_model: "faster-rcnn" + backbone: "resnet18" + num_classes: 11 + learning_rate: 1e-4 + learning_rate_schedule_patience: 6 + verbose: false + datamodule: + root: "tests/data/vhr10" + batch_size: 1 + num_workers: 0 + patch_size: 4 \ No newline at end of file diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 51b4ac5c400..681b0056340 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -12,7 +12,11 @@ from pytest import MonkeyPatch from torch.nn.modules import Module -from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule +from torchgeo.datamodules import ( + MisconfigurationException, + NASAMarineDebrisDataModule, + VHR10DataModule, +) from torchgeo.datasets import NASAMarineDebris, RGBBandsMissingError from torchgeo.main import main from torchgeo.trainers import ObjectDetectionTask @@ -67,7 +71,7 @@ def plot(*args: Any, **kwargs: Any) -> None: class TestObjectDetectionTask: - @pytest.mark.parametrize("name", ["nasa_marine_debris"]) + @pytest.mark.parametrize("name", ["nasa_marine_debris", "vhr10"]) @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"]) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 620d8334e44..1f05f31b3b4 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -38,6 +38,7 @@ from .usavars import USAVarsDataModule from .utils import MisconfigurationException from .vaihingen import Vaihingen2DDataModule +from .vhr10 import VHR10DataModule from .xview import XView2DataModule __all__ = ( @@ -79,6 +80,7 @@ "UCMercedDataModule", "USAVarsDataModule", "Vaihingen2DDataModule", + "VHR10DataModule", "XView2DataModule", # Base classes "BaseDataModule", diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py new file mode 100644 index 00000000000..72d8f35f6c8 --- /dev/null +++ b/torchgeo/datamodules/vhr10.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NWPU VHR-10 datamodule.""" + +from typing import Any, Dict, List + +import torch +import torchvision +from torch import Tensor + +from ..datasets import VHR10 +from ..samplers.utils import _to_tuple +from .geo import NonGeoDataModule +from .utils import dataset_split + + +def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: + """Custom object detection collate fn to handle variable boxes. + + Args: + batch: list of sample dicts return by dataset + + Returns: + batch dict output + """ + output: Dict[str, Any] = {} + output["image"] = torch.stack([sample["image"] for sample in batch]) + output["boxes"] = [sample["boxes"] for sample in batch] + output["labels"] = [sample["labels"] for sample in batch] + if "masks" in batch[0]: + output["masks"] = [sample["masks"] for sample in batch] + return output + + +class VHR10DataModule(NonGeoDataModule): + """LightningDataModule implementation for the VHR10 dataset. + + .. versionadded:: 0.5 + """ + + def __init__( + self, + batch_size: int = 64, + patch_size: int = 512, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a new VHR10DataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.VHR10`. + """ + super().__init__(VHR10, batch_size, num_workers, **kwargs) + + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + self.patch_size = _to_tuple(patch_size) + + self.collate_fn = collate_fn + self.kwargs["transforms"] = self.preprocess + self.kwargs["download"] = True + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + + _, h, w = sample["image"].shape + sample["image"] = torchvision.transforms.functional.resize( + sample["image"], size=self.patch_size + ) + box_scale = (self.patch_size[1] / w, self.patch_size[0] / h) + sample["boxes"][:, 0] *= box_scale[0] + sample["boxes"][:, 1] *= box_scale[1] + sample["boxes"][:, 2] *= box_scale[0] + sample["boxes"][:, 3] *= box_scale[1] + sample["boxes"] = torch.round(sample["boxes"]) + + if "masks" in sample: + sample["masks"] = torchvision.transforms.functional.resize( + sample["masks"], size=self.patch_size + ) + + return sample + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + self.dataset = VHR10(**self.kwargs) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) From 41cf32d5e279f2546a8bac1f560fea1567744817 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 2 Feb 2023 10:15:49 +0400 Subject: [PATCH 02/32] Add newline --- tests/conf/vhr10.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index a9c325eb2f2..e7262e6ce25 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -11,4 +11,4 @@ experiment: root: "tests/data/vhr10" batch_size: 1 num_workers: 0 - patch_size: 4 \ No newline at end of file + patch_size: 4 From 1efc331972190b90dba9240bdaedfc65a2c1690d Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 2 Feb 2023 13:32:24 +0400 Subject: [PATCH 03/32] patch_size accepts int and tuple of ints --- torchgeo/datamodules/vhr10.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 72d8f35f6c8..f5b13128842 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -3,7 +3,7 @@ """NWPU VHR-10 datamodule.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union import torch import torchvision @@ -42,7 +42,7 @@ class VHR10DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: int = 512, + patch_size: Union[Tuple[int, int], int] = 512, num_workers: int = 0, val_split_pct: float = 0.2, test_split_pct: float = 0.2, From 74c4bd59ae52cd4c194ef36f1b0f9b1e6ce783fb Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 2 Feb 2023 13:39:46 +0400 Subject: [PATCH 04/32] Update conf --- conf/vhr10.yaml | 5 ++--- tests/conf/vhr10.yaml | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml index c3eac82fbb0..d4bdf2ac1c7 100644 --- a/conf/vhr10.yaml +++ b/conf/vhr10.yaml @@ -15,10 +15,9 @@ experiment: name: "vhr10_test" module: model: "faster-rcnn" - backbone: "resnet18" - pretrained: True + backbone: "resnet50" num_classes: 11 - learning_rate: 1.3e-5 + learning_rate: 1.6e-5 learning_rate_schedule_patience: 6 verbose: false datamodule: diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index e7262e6ce25..f500d21b9e4 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -1,7 +1,7 @@ experiment: task: "vhr10" module: - detection_model: "faster-rcnn" + model: "faster-rcnn" backbone: "resnet18" num_classes: 11 learning_rate: 1e-4 From b4511eab45ed370a207ca3f0b1a72b01ab0949f0 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 30 Mar 2023 15:24:14 +0400 Subject: [PATCH 05/32] VHR10 Datamodule v2 --- pyproject.toml | 1 + tests/data/vhr10/NWPU VHR-10 dataset.rar | Bin 6529 -> 6527 bytes tests/data/vhr10/annotations.json | 2 +- tests/data/vhr10/data.py | 10 +-- tests/datasets/test_vhr10.py | 4 +- tests/transforms/test_transforms.py | 16 ++-- torchgeo/datamodules/geo.py | 2 +- torchgeo/datamodules/vhr10.py | 93 ++++++++++++++++++++--- torchgeo/datasets/vhr10.py | 7 +- torchgeo/trainers/detection.py | 3 + torchgeo/transforms/transforms.py | 19 +++++ 11 files changed, 124 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9fce633e79b..87df577370b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -266,6 +266,7 @@ filterwarnings = [ # Unexpected warnings, worth investigating # Lightning is having trouble inferring the batch size for ChesapeakeCVPRDataModule and CycloneDataModule for some reason "ignore:Trying to infer the `batch_size` from an ambiguous collection:UserWarning", + "ignore: Tight layout not applied. The bottom and top margins cannot be made large enough:UserWarning", # https://github.com/pytest-dev/pytest/issues/11461 "ignore::pytest.PytestUnraisableExceptionWarning", ] diff --git a/tests/data/vhr10/NWPU VHR-10 dataset.rar b/tests/data/vhr10/NWPU VHR-10 dataset.rar index 6a1b98fed27cc543237a34d1bce597a23af720f1..0f836ac8e17ba85cd1aaf9aa680d2a7f06366da3 100644 GIT binary patch literal 6527 zcmeI1XHZjV8-@dfju8=+COx4^M*&%g1f@#vf=d7i1V|{MgY=Gw)PPD=SUOTwl%fa- z8ma{iDphJi@9hh&yW_Zgv+jTL^UUO(GtXS-zOVBf18a94Hfl1!(s1YuJsBl6fQpPP zI5-&a?2MbwWpWDocpwlYam6YJVX6<2(vGWjjb-^G# zkY0A&NR+j`9rykv35l~}j;{7}6vvkpz}GVC!Bj6ZXvrl!RSkKU_0+Z=HbMw{SvuF>Ui09a%1vD0I^t6Iu$kF5$H-{@jb9&7 zafNf!pB$fAPvZ=|)m=STpmDdmI=dPr=aq979=R#^0NKF{dl@LIO+5Fh4I??UuTxCoke&;x%Ph@5l)z zefT`vx?L=f7KmQiS~!-!s#8d-;gWdty;qifs%EXk$s?CWXg&RRLu#=bfq4QVl@@Zt zOM)%3=4J3Uh`7yXsdMsO-AJ&2P;L4kF%$h2QLy%eusuf+rU-p8($nkqX#1wFY-abQ z;^cDub##+$^+&>*LgUe+aqSvoiyKdOvrQ{L5Z3&6xwx|XXSB3exB7RXL7mxiTYJW) z=2q6WcJ@do=S||`>gFB=HT}`!>|A^TBI44r3QCMQApUs4;H70(Du9M+b5GGM|50Ew ze`{YCF2|?gG+C|~g>Ea#Qpe_ngDIfJUe?X0CvvDvU0QYb;a&mrx*F5Ak&iq*ns z6_-swW`q2(3Xw`nr@e{;BTW#sEJWi5Fi!ClT`skMBoYshcFFnoy8N170f4sbTWAi#WaprwaWXs}-k8^e!mgVo2cEp7>%iSMz zRun%LHpxOxKsn+Uy2Z`fy{? zM_mG;a%LQ(P2u|uld>+sLeIF2AD(dX378*LKO6tpEqMOXa6vk5`x&*sTJ z13zm+@V7I_X`bVSCnxejVIPK*LCV}DuE4ZB+Enz)@Tq8yBy}E^RcN1mE?l65VdaoL>($L@r2_f0{?Y>;h3PSkA7wEwmj!=$>nMg9YgHv zuS@wGJD>!^l6Nz!3X0`lyc@pEYx@NJ4u#(9pI+HgtgimN7u6ORHoml}ztgyv&^@!Z zz1tTwg)EvDb}XKmJ>y(T5Ji;{#n9z*;w}~Q5(|=Um5WmDRZG&<%Q7A{D`!1xhv)#q zZ6Bk$73=|U;4KRRT|o}FD})Rzj1jhMsdZjuFz2Swy-48DLOgf=jV!Xzx~?nu+zwhf zqrJ9a=RqkMn?kQ-^O%@K!D69U(7QhXnNfC#a-XA`YAE#j2&Of`HltWz0w_I}%3$ynM`fqc zQnlu*wnOAEt9bJYYK#KTDz$~G;PvN9SMG-k&CqE2BWLN+UXwfzc2$(R8#5pnrLMqT zYzD+$K1*79)gE%WANSkD!b^B=te}u84o_np&YcpV!e9WC$cqz9rl-+Nc}Za`a8}CJ z+DItG_<5v73#il_)Fzs8e!YN9OKu>&T*lTk@16ekqtOk06kFEb1WWU421tocruj^o znBwxJY4j*CftuAPi9=#o&^*Rb?Bb0LdQ7CcJTR=;Y3Mj_#AQZN%CP&~?i$?T=nefw zUJw{I*+HJt`6{guslOZKJ=n%1Q^)B3UT2_Y)_WpMHjJkHBpY9J$6ehM6r_}&NPZxAZ6zyyUJGj(xpOyjVMCEK+h~szPdT}3+e$K4;-rLH z-`;Hs>L#U=nzpXKAt|CPu6~Osl#3h2+t>ekP)Jz#_`7pTsyZgWBhX$q)3OT6>l)j- z`WG&6$l)KXOrYTmfFyeVT%*xt5Gk$lnV9ll;e@Q}Y>F7!e8BKN^Fhu}Jq(hL(|niK z@#9TY_E{!NYoT|edq2~HsCh7yD$x_l_&XZ)ngW6LhRo&XS?80bP+e{c?njGusn~4? z^`y5h)W2sS1Azt91)nC%YkbDq4#7<6;An!Qeut zC+=$6k6O$`Y(piZV`IwN=!=>`w5Q4Wq6~u#sghV`k3f9MS(iCOygg@nRIE4X2;Noq z%T1lK1F@-KQ+YEx)3=tXBDyk5**7&7O>LrXW2DCRBkPY5QyBKI$V!E%u=!OR3!ekq zd`jFhD`1CDuPmVW*;=a-5X@~I?P@PO_yF*?Z7cOmNk>G z_U+J&0<)=EZ>5mX5;BenMPH1kq&JP!`-awVDkQ>4j)Yd>H?)YWhWlvQvI-fb)%#co z_4g3gud7O@74E!HsC?;w$0}Z6eWI8*nO=x{Vm3f{u2433$D6(5`gTheH5I}sqGiO* zXE0DrAOkmn**wkqxC^UWxEe&tuQfsc2I9`C(1g~}Mc;tWewtQlB+v@}_|}A(x_u=x zt6RHJXg9%XYmda3)qV}4(3>M)|1QE)f7T#WMNN~`gd8Y`;HBFACZufcDGp?!CWkW_ zbeLXxntA1wUdwneay|c!z3(x$G;WDr_Lr~IRa{M8u;v6=?7gRZeC%EmGY1Ivg?6sf zd5f0Nsf1;jofZPiObAwH_4W2IX6db-!50-tWh}=%G$bd0T^awl@3&7^%=gr~(23ao zPjpHBL>KNBC5f)HhEn@gc>SjNWZaaZmZ%;?ScAMcZi*Jg$X%62Sx=Bc(oqW1s2jBi z3OfkJUPTC#GaZW!DSajsrVnFDk}(+X$4oz7-mu^gye<4he2mM=@3wA21qrTr$VPjF z>{5WX;&1(3f$|C$iIiMQo4)grk~Co5;}4mKZth#wDC~srdAVg1&NA~wEmLIT8;#00 zQDzd?KAJy)B$dIi$3b?URm2en3!F2;U_k4On;imx_lYuG;uDuqV5TDYNv4& zQ~zmB_6hM6C?Q6pb<}Mmpvy&mM;4mfKke!FDewhd)_a%7pXQ9cw>i1Q?{rTtzujyI z?DlRNrm7xY+}Kh2;$!t^LGJxkjOBmTqHE>fy=Y4Z7*WeT#u*t1Xr>h$xP1b8BNUJh zOs_)+e-t*xN#DFQc|_&JSx=QHo=i)MBbkGpHmdBgH95G8=ZANrjUOaHSYUuL(ObpH zil@d|O4@uYy3oNGCQjuqZ=E>iIUj|1S4XJnnsjE>0jK$N-eiN>UVYE4e^q$2oxSy; z!XGO9uPZ#t=2hB0UB9C62hw$@@P`Wjv%(7mmpJT0^(zX0AXJA6f2i;WC_JLn+3R~1 z{;R?t$kn03A1eI+Rd}89?v8!BensIAr0Y=O4;B7Eg~z&`A%~I3>LCM0(2!OSi)Z@O zNH2%{2W{c4(P$SBlEk~9F=CDw7xWLi_f`WOw=HSc1UDLalr@l_9m_Ikl5+o{wsTvR z&A=%e>{I))#1C;{Zr($8kb_qcWd^ggJsq7VZ5nSroLss60pk{VWZ)7VfRdVw z3IOo;_a_rt3zXC*r=SZ10+T4qgUHwD#*6(a0jG6M^^IXB8iu0c2$&rTgYvM)h@o9Q z95If!?O~42C$_0~p}%jsub-Ze6_0yRNRbQZj;XKCvAmp!{D4ZdDj*1%Jan8>fMUrZfp?-@A#YXa;lP|J zh)o1j2&`8D(jVs6AsB zRxy=ouyjSiEt0;m*{OVC@FBrw`dxR(^=O2~XSd{VcngR*BrGX=Mx(p8O2#fIM&mqH z-NG3`qq@M}D{{)zd~q~tD-~Q*u&K+ZgTqc~W4v$781pP|>!vgdL=#q9yYH~o^@wC% z`|NdzXq730&opb2p81l6%KGMp+z;}qkF;|&DjZtcJSSgF^sIvuKYRIFDoQ5f`c5uv z`0O1pKIoVe78Tpr(eBdOJK#xeUzz*5cYw>x%6^plI4}Q6K_R~6Y3Z}F@`}o;n%d`e z^$m?p&8;unUbVmO=Vg|Q@Po{BG6}3uATWgl*2&3G$3Fr~ z@&~Zkx03jK5-7+K_EJV?)-R~7B*b4ic3ZP zbu&(T^S9jRXPwj3#bfPQE8=5f-)qVQi#eVWn%V#B+A?GKTBP+gDi#w#{imCBDLf#I zu&P_2dQjvGO%Mc~Afj5w&Z~wEDQn?_1}eRmT8PgMIj=8D8T1I|t_cfv+3Kt31@UVg zJ7tro)1CD4uGzkxRqso~gY9`jN92(6_T8$0|I|iasgzdq?ZadD+NeVWW{pC{UryJ| zxL>?Z;GmUQRS%K>!=S}lGPd}@;!AM}2`!LX7nhq#XLSmigqGndmw{H;ceLa@;8F6L z1*}}mPb~Ads&mJB7Y9otbH>yV(Kf6VaZ%ARv94suej{fA?2vpJ2JSlz;RpyGjJS+f z@)?|0%HVk)jnG89r5qDKR92K=GT)3wuYS3gT?xw%JbPHY=h=J>r{x|ldE3453SDP`W82K zJLJD7B+Hr?jG>qhIbCMbo302j7`lgtWdW04%fG*W*3*gY;GS=P_uLK{B=@T zOyaB@I1>bQ7W7{$M^lk$Qa1J$xHD6eTQeIzG3D>7HEFrf^1JE$>ebZ)6?P`JQQy9~2U3KJMU7?wG*q@ts z9~*Rg#C;q6z@r)&m8nc+%tGP|3j`*U`0C5Ito$Rs#DC(eBbNNoTci!~&|5!jh^v$% z&0zT@aHR;0&WW?BJu56)CXJG9Fg8~8;>G@G%@YjOp=N2yDhg!^@d`N#xC^y48}sm| zyCJu0M-5GdD|#BVUdI)t7GFxtOY`n&`|DmbB-0++P5`&iD4p_(20sc$U1T#5 zU;}5!#!%r6okFOZ3QV7yH3VBVf{Kw%kF-qavpGUUb5K*f+2eKW6>hE3is8&b9>m>$ z&?>lGBp)$*vOa^e4{9Cve)G% z@_`eHyrwUdUKH4}!wd;#?xQ*)ie1aQkZ!ez8rrMVve_qeKfb@V;@M$62}h+>hQ6h(YIqAGlTtnMjYHPVJC(4_ zDR8tN75rr&A zGO8+^`8ln~dr>qK6Z!OzM&`=2CWF-hAy{cYZw1GRV^;Flx7Afvid0LmxtuwZ`M!*T z8g~}u*EcDv&SAhKesG)vQ9E81=b-4P)fq35-ouf)BM~p)IQ!6JYvxmYEo&6Z@QSK< zn39lJ#;bgwqkW2*S_-(x>=`wz{KI{a9pPT-FdH(9QoGq6z?15s*l+^Qsviat?9*&l7 zM&5FXTQ##Hk8+VUNy#Jp&qlj+x}623*Tp;;aeXeH<~0_xU+*?dq8GLjkDs@cvO1aA z5zvXBZ93c|Efn%hi!A=;t$XE8!}J$#-w8kUyuqc-;T=5u&OQ%M;)5Zc@fG{6-Tej} z@#}v7(AdP(KfQeJCy3<97#^`BiCaFu3+n zBmtj`q9t@`L%g7j#1mh!pc}I$3e*@uhI44l_ChF$t>0;Q7kdZPQNtfK{BJdU7f^zX z1lI2~{4c@!vEfOF!=r{jYWTy3=Q&@9BysgS4gX87jvD@`;r~y=Lm3q^NqqfI!~c@6 zqlQ0f_yYrHY4md z$_*`u}3rj#~c}oHO3$?KfYr=Szf+k$D7@li{uQv2%m9Gw@fjHWn-alPuW8= zp7SA*!r(_E!!t8G2vfDubsteEK}K;N21eNg(6AxES&shL*HIipm6D%~*den|!4tl< zy2GYv%4VMVlTh}zK~3o~as{cg3|KT#^(%EjeQheX|6T0%XH>{qUJnP3{hbtEejbX* str: ) ann = 0 - for i, img in enumerate(ANNOTATION_FILE["images"]): + for _, img in enumerate(ANNOTATION_FILE["images"]): annot = { "id": ann, "image_id": img["id"], @@ -57,12 +56,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str: "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0, } - if i != 0: - ANNOTATION_FILE["annotations"].append(annot) - else: - noseg_annot = deepcopy(annot) - del noseg_annot["segmentation"] - ANNOTATION_FILE["annotations"].append(noseg_annot) + ANNOTATION_FILE["annotations"].append(annot) ann += 1 with open(ann_file, "w") as j: diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 805b84a3117..5480acb4ef4 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -35,11 +35,11 @@ def dataset( monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar") monkeypatch.setitem(VHR10.image_meta, "url", url) - md5 = "5fddb0dfd56a80638831df9f90cbf37a" + md5 = "92769845cae6a4e8c74bfa1a0d1d4a80" monkeypatch.setitem(VHR10.image_meta, "md5", md5) url = os.path.join("tests", "data", "vhr10", "annotations.json") monkeypatch.setitem(VHR10.target_meta, "url", url) - md5 = "833899cce369168e0d4ee420dac326dc" + md5 = "567c4cd8c12624864ff04865de504c58" monkeypatch.setitem(VHR10.target_meta, "md5", md5) root = str(tmp_path) split = request.param diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index a9ea200982a..64412c2bc6c 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]: return { "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]: dtype=torch.float, ), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]: dtype=torch.float, ), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -79,7 +79,7 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: expected = { "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -102,7 +102,7 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: dtype=torch.float, ), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -129,7 +129,7 @@ def test_augmentation_sequential_multispectral( dtype=torch.float, ), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -156,7 +156,7 @@ def test_augmentation_sequential_image_only( dtype=torch.float, ), "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), + "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } augs = transforms.AugmentationSequential( @@ -188,7 +188,7 @@ def test_sequential_transforms_augmentations( dtype=torch.float, ), "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), + "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } train_transforms = transforms.AugmentationSequential( diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 96da7683ba2..7dba392c6a2 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -30,7 +30,7 @@ class BaseDataModule(LightningDataModule): """ mean = torch.tensor(0) - std = torch.tensor(255) + std = torch.tensor(1) def __init__( self, diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index f5b13128842..0dfb5b9a009 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -3,16 +3,69 @@ """NWPU VHR-10 datamodule.""" -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union +import kornia.augmentation as K import torch import torchvision -from torch import Tensor +from einops import rearrange +from torch import Generator, Tensor +from torch.nn.modules import Module +from torch.utils.data import random_split from ..datasets import VHR10 from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import dataset_split + +# kornia DataKeys must accept MASKS + + +class _AugPipe(Module): + """Pipeline for applying augmentations sequentially on select data keys.""" + + def __init__( + self, augs: Callable[[Dict[str, Any]], Dict[str, Any]], batch_size: int + ) -> None: + """Initialize a new _AugPipe instance. + + Args: + augs: Augmentations to apply. + batch_size: Batch size + """ + super().__init__() + self.augs = augs + self.batch_size = batch_size + + def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: + """Apply the augmentation. + + Args: + batch: Input batch. + + Returns: + Augmented batch. + """ + batch_len = len(batch["image"]) + for bs in range(batch_len): + batch_dict = { + "image": batch["image"][bs], + "masks": batch["masks"][bs], + "labels": batch["labels"][bs], + "boxes": batch["boxes"][bs], + } + + batch_dict = self.augs(batch_dict) + + batch["image"][bs] = batch_dict["image"] + batch["masks"][bs] = batch_dict["masks"] + batch["labels"][bs] = batch_dict["labels"] + batch["boxes"][bs] = batch_dict["boxes"] + + # Stack images + batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") + + return batch def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: @@ -25,11 +78,10 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: batch dict output """ output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) + output["image"] = [sample["image"] for sample in batch] output["boxes"] = [sample["boxes"] for sample in batch] output["labels"] = [sample["labels"] for sample in batch] - if "masks" in batch[0]: - output["masks"] = [sample["masks"] for sample in batch] + output["masks"] = [sample["masks"] for sample in batch] return output @@ -66,7 +118,27 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.collate_fn = collate_fn - self.kwargs["transforms"] = self.preprocess + + self.train_aug = _AugPipe( + AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(self.patch_size), + K.RandomHorizontalFlip(), + K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7), + K.RandomVerticalFlip(), + data_keys=["image", "boxes", "masks"], + ), + batch_size, + ) + self.aug = _AugPipe( + AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(self.patch_size), + data_keys=["image", "boxes", "masks"], + ), + batch_size, + ) + self.kwargs["download"] = True def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -105,6 +177,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ self.dataset = VHR10(**self.kwargs) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + train_split_pct = 1.0 - (self.val_split_pct + self.test_split_pct) + self.train_dataset, self.val_dataset, self.test_dataset = random_split( + self.dataset, + [train_split_pct, self.val_split_pct, self.test_split_pct], + Generator().manual_seed(torch.initial_seed()), ) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index db0807ee930..fc551e409af 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -91,8 +91,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: if "segmentation" in anno[0]: segmentations = [obj["segmentation"] for obj in anno] - else: - segmentations = [] + masks = convert_coco_poly_to_mask(segmentations, h, w) keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) @@ -258,8 +257,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: sample = self.coco_convert(sample) sample["labels"] = sample["label"]["labels"] sample["boxes"] = sample["label"]["boxes"] - if "masks" in sample["label"]: - sample["masks"] = sample["label"]["masks"] + sample["masks"] = sample["label"]["masks"] del sample["label"] if self.transforms is not None: @@ -298,6 +296,7 @@ def _load_image(self, id_: int) -> Tensor: tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) + tensor = tensor.float() / tensor.max() return tensor def _load_target(self, id_: int) -> dict[str, Any]: diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 0a0d3b6ac8e..5b78af93a10 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -22,6 +22,9 @@ from ..datasets.utils import RGBBandsMissingError, unbind_samples from .base import BaseTask +plt.switch_backend("agg") + + BACKBONE_LAT_DIM_MAP = { "resnet18": 512, "resnet34": 512, diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index f09a52050ff..a4a4218d759 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -10,6 +10,7 @@ from einops import rearrange from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices +from kornia.geometry.boxes import Boxes from torch import Tensor from torch.nn.modules import Module @@ -67,10 +68,20 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: dtype[key] = batch[key].dtype batch[key] = batch[key].float() + # Convert shape of boxes from [N, 4] to [N, 4, 2] + if "boxes" in batch and ( + isinstance(batch["boxes"], list) or batch["boxes"].ndim == 2 + ): + # batch["boxes"] = batch["boxes"].unsqueeze(0) + batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data + # Kornia requires masks to have a channel dimension if "mask" in batch and len(batch["mask"].shape) == 3: batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") + if "masks" in batch and len(batch["masks"].shape) == 3: + batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w") + inputs = [batch[k] for k in self.data_keys] outputs_list: Union[Tensor, list[Tensor]] = self.augs(*inputs) outputs_list = ( @@ -85,9 +96,17 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: for key in self.data_keys: batch[key] = batch[key].to(dtype[key]) + # Convert boxes to default [N, 4] + if "boxes" in batch: + batch["boxes"] = Boxes(batch["boxes"]).to_tensor( + mode="xyxy" + ) # type:ignore[assignment] + # Torchmetrics does not support masks with a channel dimension if "mask" in batch and batch["mask"].shape[1] == 1: batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") + if "masks" in batch and len(batch["masks"].shape) == 4: + batch["masks"] = rearrange(batch["masks"], "() c h w -> c h w") return batch From f273f03340f3c050c75bbeb24b868b05d6f4fb97 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 10 Apr 2023 10:58:35 +0400 Subject: [PATCH 06/32] Remove auto_lr_find --- conf/vhr10.yaml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml index d4bdf2ac1c7..1da780e19d3 100644 --- a/conf/vhr10.yaml +++ b/conf/vhr10.yaml @@ -7,7 +7,6 @@ trainer: devices: 1 min_epochs: 5 max_epochs: 100 - auto_lr_find: False benchmark: True experiment: @@ -22,6 +21,6 @@ experiment: verbose: false datamodule: root: "data/vhr10" - batch_size: 3 - num_workers: 12 + batch_size: 2 + num_workers: 56 patch_size: 512 From f4a96a8799484d8b3c2c9dd6593ad2810cd4a9a1 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 11 Apr 2023 15:53:19 +0400 Subject: [PATCH 07/32] Remove preprocess --- torchgeo/datamodules/vhr10.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 0dfb5b9a009..1db0154e1d1 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -7,10 +7,9 @@ import kornia.augmentation as K import torch -import torchvision from einops import rearrange from torch import Generator, Tensor -from torch.nn.modules import Module +from torch.nn import Module from torch.utils.data import random_split from ..datasets import VHR10 @@ -139,37 +138,6 @@ def __init__( batch_size, ) - self.kwargs["download"] = True - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - - _, h, w = sample["image"].shape - sample["image"] = torchvision.transforms.functional.resize( - sample["image"], size=self.patch_size - ) - box_scale = (self.patch_size[1] / w, self.patch_size[0] / h) - sample["boxes"][:, 0] *= box_scale[0] - sample["boxes"][:, 1] *= box_scale[1] - sample["boxes"][:, 2] *= box_scale[0] - sample["boxes"][:, 3] *= box_scale[1] - sample["boxes"] = torch.round(sample["boxes"]) - - if "masks" in sample: - sample["masks"] = torchvision.transforms.functional.resize( - sample["masks"], size=self.patch_size - ) - - return sample - def setup(self, stage: str) -> None: """Set up datasets. From fec3449dd6f6869566fca6b33b23a742ef83071f Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 4 May 2023 15:46:35 +0400 Subject: [PATCH 08/32] Update config --- conf/vhr10.yaml | 40 +++++++++++++++++++--------------------- tests/conf/vhr10.yaml | 30 ++++++++++++++++-------------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml index 1da780e19d3..f6cab82d594 100644 --- a/conf/vhr10.yaml +++ b/conf/vhr10.yaml @@ -1,26 +1,24 @@ -program: - seed: 0 - overwrite: True +module: + _target_: torchgeo.trainers.ObjectDetectionTask + model: "faster-rcnn" + backbone: "resnet50" + num_classes: 11 + learning_rate: 1.6e-5 + learning_rate_schedule_patience: 6 + verbose: false + +datamodule: + _target_: torchgeo.datamodules.VHR10DataModule + root: "data/vhr10" + batch_size: 4 + num_workers: 56 + patch_size: 512 + val_split_pct: 0.2 trainer: + _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 5 - max_epochs: 100 - benchmark: True + min_epochs: 15 + max_epochs: 40 -experiment: - task: "vhr10" - name: "vhr10_test" - module: - model: "faster-rcnn" - backbone: "resnet50" - num_classes: 11 - learning_rate: 1.6e-5 - learning_rate_schedule_patience: 6 - verbose: false - datamodule: - root: "data/vhr10" - batch_size: 2 - num_workers: 56 - patch_size: 512 diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index f500d21b9e4..8d4ef78c6ba 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -1,14 +1,16 @@ -experiment: - task: "vhr10" - module: - model: "faster-rcnn" - backbone: "resnet18" - num_classes: 11 - learning_rate: 1e-4 - learning_rate_schedule_patience: 6 - verbose: false - datamodule: - root: "tests/data/vhr10" - batch_size: 1 - num_workers: 0 - patch_size: 4 +module: + _target_: torchgeo.trainers.ObjectDetectionTask + model: "faster-rcnn" + backbone: "resnet18" + num_classes: 2 + learning_rate: 1.2e-4 + learning_rate_schedule_patience: 6 + verbose: false + +datamodule: + _target_: torchgeo.datamodules.VHR10DataModule + root: "tests/data/vhr10" + download: true + batch_size: 1 + num_workers: 0 + patch_size: 4 From 738d996358e6b1db422ffa44f453848113faf797 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 4 May 2023 15:58:28 +0400 Subject: [PATCH 09/32] Remove setting of matplotlib backend --- torchgeo/trainers/detection.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 5b78af93a10..0a0d3b6ac8e 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -22,9 +22,6 @@ from ..datasets.utils import RGBBandsMissingError, unbind_samples from .base import BaseTask -plt.switch_backend("agg") - - BACKBONE_LAT_DIM_MAP = { "resnet18": 512, "resnet34": 512, From 97bb2347340036e7aac35df29b4a198727dbcac6 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 10 May 2023 16:17:17 +0400 Subject: [PATCH 10/32] Remove import --- tests/trainers/test_detection.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 681b0056340..e4151ac0d29 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -12,11 +12,7 @@ from pytest import MonkeyPatch from torch.nn.modules import Module -from torchgeo.datamodules import ( - MisconfigurationException, - NASAMarineDebrisDataModule, - VHR10DataModule, -) +from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris, RGBBandsMissingError from torchgeo.main import main from torchgeo.trainers import ObjectDetectionTask From 54a31164486c1ce332bf3a31e4c6736017787339 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 10 May 2023 16:25:53 +0400 Subject: [PATCH 11/32] Typing update --- torchgeo/datamodules/vhr10.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 1db0154e1d1..67a59b7ba80 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -3,7 +3,7 @@ """NWPU VHR-10 datamodule.""" -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Union import kornia.augmentation as K import torch @@ -24,7 +24,7 @@ class _AugPipe(Module): """Pipeline for applying augmentations sequentially on select data keys.""" def __init__( - self, augs: Callable[[Dict[str, Any]], Dict[str, Any]], batch_size: int + self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int ) -> None: """Initialize a new _AugPipe instance. @@ -36,7 +36,7 @@ def __init__( self.augs = augs self.batch_size = batch_size - def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Apply the augmentation. Args: @@ -67,7 +67,7 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: return batch -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: +def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: """Custom object detection collate fn to handle variable boxes. Args: @@ -76,7 +76,7 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: Returns: batch dict output """ - output: Dict[str, Any] = {} + output: dict[str, Any] = {} output["image"] = [sample["image"] for sample in batch] output["boxes"] = [sample["boxes"] for sample in batch] output["labels"] = [sample["labels"] for sample in batch] @@ -93,7 +93,7 @@ class VHR10DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 512, + patch_size: Union[tuple[int, int], int] = 512, num_workers: int = 0, val_split_pct: float = 0.2, test_split_pct: float = 0.2, From 5bdb57a8602b56191b81cec6755a58510ae00a45 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 29 Jun 2023 12:48:05 +0400 Subject: [PATCH 12/32] Key fix --- torchgeo/transforms/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index a4a4218d759..b49539795a9 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -48,6 +48,8 @@ def __init__( keys.append("input") elif key == "boxes": keys.append("bbox") + elif key == "masks": + keys.append("mask") else: keys.append(key) From bb4fd46a29d48d500f09bceae51d613f583ba8f4 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 29 Jun 2023 15:24:41 +0400 Subject: [PATCH 13/32] Coverage fix --- torchgeo/datamodules/vhr10.py | 2 -- torchgeo/datasets/vhr10.py | 5 +---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 67a59b7ba80..abc2957b8ed 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -17,8 +17,6 @@ from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -# kornia DataKeys must accept MASKS - class _AugPipe(Module): """Pipeline for applying augmentations sequentially on select data keys.""" diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index fc551e409af..5a58f3c4828 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -45,10 +45,7 @@ def convert_coco_poly_to_mask( mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) - if masks: - masks_tensor = torch.stack(masks, dim=0) - else: - masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8) + masks_tensor = torch.stack(masks, dim=0) return masks_tensor From f31c82f83f59b87d5aedd7987597159f2d674e1c Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 24 Oct 2023 14:59:53 +0400 Subject: [PATCH 14/32] Update conf --- conf/vhr10.yaml | 41 +++++++++++++++++++++-------------------- tests/conf/vhr10.yaml | 33 ++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml index f6cab82d594..6ae4af3e10b 100644 --- a/conf/vhr10.yaml +++ b/conf/vhr10.yaml @@ -1,24 +1,25 @@ -module: - _target_: torchgeo.trainers.ObjectDetectionTask - model: "faster-rcnn" - backbone: "resnet50" - num_classes: 11 - learning_rate: 1.6e-5 - learning_rate_schedule_patience: 6 - verbose: false - -datamodule: - _target_: torchgeo.datamodules.VHR10DataModule - root: "data/vhr10" - batch_size: 4 - num_workers: 56 - patch_size: 512 - val_split_pct: 0.2 - trainer: - _target_: lightning.pytorch.Trainer accelerator: gpu devices: 1 - min_epochs: 15 - max_epochs: 40 + min_epochs: 15 #2 + max_epochs: 100 #4 + default_root_dir: output/vhr10 + +model: + class_path: ObjectDetectionTask + init_args: + model: "faster-rcnn" + backbone: "resnet50" + num_classes: 11 + lr: 3.0e-5 + patience: 20 +data: + class_path: VHR10DataModule + init_args: + batch_size: 4 + num_workers: 56 #0 + patch_size: 512 + val_split_pct: 0.2 + dict_kwargs: + root: "data/vhr10" \ No newline at end of file diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index 8d4ef78c6ba..f630b26999b 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -1,16 +1,19 @@ -module: - _target_: torchgeo.trainers.ObjectDetectionTask - model: "faster-rcnn" - backbone: "resnet18" - num_classes: 2 - learning_rate: 1.2e-4 - learning_rate_schedule_patience: 6 - verbose: false +model: + class_path: ObjectDetectionTask + init_args: + model: "faster-rcnn" + backbone: "resnet50" + num_classes: 11 + lr: 2.5e-5 + patience: 10 + +data: + class_path: VHR10DataModule + init_args: + batch_size: 1 + num_workers: 0 + patch_size: 4 + dict_kwargs: + root: "tests/data/vhr10" + -datamodule: - _target_: torchgeo.datamodules.VHR10DataModule - root: "tests/data/vhr10" - download: true - batch_size: 1 - num_workers: 0 - patch_size: 4 From bf69bac84b80e71cb420c47a556e4acbc9de38cf Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 24 Oct 2023 15:08:42 +0400 Subject: [PATCH 15/32] Update conf --- conf/vhr10.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml index 6ae4af3e10b..c1e0eba0330 100644 --- a/conf/vhr10.yaml +++ b/conf/vhr10.yaml @@ -1,8 +1,8 @@ trainer: accelerator: gpu devices: 1 - min_epochs: 15 #2 - max_epochs: 100 #4 + min_epochs: 15 + max_epochs: 100 default_root_dir: output/vhr10 model: @@ -18,8 +18,8 @@ data: class_path: VHR10DataModule init_args: batch_size: 4 - num_workers: 56 #0 + num_workers: 56 patch_size: 512 val_split_pct: 0.2 dict_kwargs: - root: "data/vhr10" \ No newline at end of file + root: "data/vhr10" From 4cd45689d301df19b2f1da02c08f58a2ff25ce7c Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 24 Oct 2023 16:10:34 +0400 Subject: [PATCH 16/32] Dowload=True --- tests/conf/vhr10.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index f630b26999b..860e4939534 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -15,5 +15,6 @@ data: patch_size: 4 dict_kwargs: root: "tests/data/vhr10" + download: true From 9d48f1ffe19f44319c0d25c7294ee5bcf8380df2 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 25 Oct 2023 16:38:10 +0400 Subject: [PATCH 17/32] Use weights --- conf/vhr10.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml index c1e0eba0330..be836a4a7e0 100644 --- a/conf/vhr10.yaml +++ b/conf/vhr10.yaml @@ -10,6 +10,7 @@ model: init_args: model: "faster-rcnn" backbone: "resnet50" + weights: True num_classes: 11 lr: 3.0e-5 patience: 20 From 2bdbf0fe6744d367b224fb1ca4890886bc62716b Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 30 Oct 2023 10:09:12 +0400 Subject: [PATCH 18/32] Empty commit From 3efa9beeff77cda50a40f6f7811c18b2571379b5 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 6 Nov 2023 14:43:34 +0400 Subject: [PATCH 19/32] Switch to ndim --- torchgeo/transforms/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index b49539795a9..12c707ff984 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -78,10 +78,10 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data # Kornia requires masks to have a channel dimension - if "mask" in batch and len(batch["mask"].shape) == 3: + if "mask" in batch and batch["mask"].ndim == 3: batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - if "masks" in batch and len(batch["masks"].shape) == 3: + if "masks" in batch and batch["masks"].ndim == 3: batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w") inputs = [batch[k] for k in self.data_keys] @@ -107,7 +107,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Torchmetrics does not support masks with a channel dimension if "mask" in batch and batch["mask"].shape[1] == 1: batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - if "masks" in batch and len(batch["masks"].shape) == 4: + if "masks" in batch and batch["masks"].ndim == 4: batch["masks"] = rearrange(batch["masks"], "() c h w -> c h w") return batch From e3417e821376977be137549ceca46f0c826f04bf Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 9 Nov 2023 17:34:58 +0400 Subject: [PATCH 20/32] Remove conf, tight_layout and spacing --- conf/vhr10.yaml | 26 -------------------------- pyproject.toml | 1 - tests/conf/vhr10.yaml | 2 -- torchgeo/datasets/vhr10.py | 6 +++--- 4 files changed, 3 insertions(+), 32 deletions(-) delete mode 100644 conf/vhr10.yaml diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml deleted file mode 100644 index be836a4a7e0..00000000000 --- a/conf/vhr10.yaml +++ /dev/null @@ -1,26 +0,0 @@ -trainer: - accelerator: gpu - devices: 1 - min_epochs: 15 - max_epochs: 100 - default_root_dir: output/vhr10 - -model: - class_path: ObjectDetectionTask - init_args: - model: "faster-rcnn" - backbone: "resnet50" - weights: True - num_classes: 11 - lr: 3.0e-5 - patience: 20 - -data: - class_path: VHR10DataModule - init_args: - batch_size: 4 - num_workers: 56 - patch_size: 512 - val_split_pct: 0.2 - dict_kwargs: - root: "data/vhr10" diff --git a/pyproject.toml b/pyproject.toml index 87df577370b..9fce633e79b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -266,7 +266,6 @@ filterwarnings = [ # Unexpected warnings, worth investigating # Lightning is having trouble inferring the batch size for ChesapeakeCVPRDataModule and CycloneDataModule for some reason "ignore:Trying to infer the `batch_size` from an ambiguous collection:UserWarning", - "ignore: Tight layout not applied. The bottom and top margins cannot be made large enough:UserWarning", # https://github.com/pytest-dev/pytest/issues/11461 "ignore::pytest.PytestUnraisableExceptionWarning", ] diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index 860e4939534..6e91f9d6ac5 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -16,5 +16,3 @@ data: dict_kwargs: root: "tests/data/vhr10" download: true - - diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 5a58f3c4828..ada4e3fdab9 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -435,7 +435,9 @@ def plot( ncols += 1 # Display image - fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10)) + fig, axs = plt.subplots( + ncols=ncols, squeeze=False, layout="constrained", figsize=(ncols * 10, 10) + ) axs[0, 0].imshow(image) axs[0, 0].axis("off") @@ -532,8 +534,6 @@ def plot( if show_titles: axs[0, 1].set_title("Prediction") - plt.tight_layout() - if suptitle is not None: plt.suptitle(suptitle) From d4700cc96031d12c1343b47ac63aa2e3b39e8698 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 9 Nov 2023 22:01:11 +0400 Subject: [PATCH 21/32] Set constrained layout via rcParams --- torchgeo/datasets/vhr10.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index ada4e3fdab9..b87627a2acd 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -22,6 +22,8 @@ download_url, ) +plt.rcParams["figure.constrained_layout.use"] = True + def convert_coco_poly_to_mask( segmentations: list[int], height: int, width: int @@ -86,8 +88,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: categories = [obj["category_id"] for obj in anno] classes = torch.tensor(categories, dtype=torch.int64) - if "segmentation" in anno[0]: - segmentations = [obj["segmentation"] for obj in anno] + segmentations = [obj["segmentation"] for obj in anno] masks = convert_coco_poly_to_mask(segmentations, h, w) @@ -435,9 +436,7 @@ def plot( ncols += 1 # Display image - fig, axs = plt.subplots( - ncols=ncols, squeeze=False, layout="constrained", figsize=(ncols * 10, 10) - ) + fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10)) axs[0, 0].imshow(image) axs[0, 0].axis("off") @@ -471,7 +470,7 @@ def plot( # Add masks if show_feats in {"masks", "both"} and "masks" in sample: mask = masks[i] - contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] + contours = find_contours(mask, 0.5) for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( @@ -523,7 +522,7 @@ def plot( # Add masks if show_pred_masks: mask = prediction_masks[i] - contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] + contours = find_contours(mask, 0.5) for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( From 756143f75e168db0278defcfcc9d09492f9c41d0 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 9 Nov 2023 22:14:14 +0400 Subject: [PATCH 22/32] Revert and bump min matplotlib version --- pyproject.toml | 4 ++-- requirements/min-reqs.old | 2 +- torchgeo/datasets/vhr10.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9fce633e79b..c243c42bf2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ dependencies = [ "lightly>=1.4.4", # lightning 2+ required for LightningCLI args + sys.argv support "lightning[pytorch-extra]>=2", - # matplotlib 3.3.3+ required for Python 3.9 wheels - "matplotlib>=3.3.3", + # matplotlib 3.5.0+ required for setting layout in figure + "matplotlib>=3.5.0", # numpy 1.19.3+ required by Python 3.9 wheels "numpy>=1.19.3", # pandas 1.1.3+ required for Python 3.9 wheels diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 34b22649ceb..9e9197f7c90 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -7,7 +7,7 @@ fiona==1.8.19 kornia==0.6.9 lightly==1.4.4 lightning[pytorch-extra]==2.0.0 -matplotlib==3.3.3 +matplotlib==3.5.0 numpy==1.19.3 pandas==1.1.3 pillow==8.0.0 diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index b87627a2acd..fb69b735b99 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -22,8 +22,6 @@ download_url, ) -plt.rcParams["figure.constrained_layout.use"] = True - def convert_coco_poly_to_mask( segmentations: list[int], height: int, width: int @@ -436,7 +434,9 @@ def plot( ncols += 1 # Display image - fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10)) + fig, axs = plt.subplots( + ncols=ncols, squeeze=False, layout="constrained", figsize=(ncols * 10, 10) + ) axs[0, 0].imshow(image) axs[0, 0].axis("off") @@ -470,7 +470,7 @@ def plot( # Add masks if show_feats in {"masks", "both"} and "masks" in sample: mask = masks[i] - contours = find_contours(mask, 0.5) + contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( @@ -522,7 +522,7 @@ def plot( # Add masks if show_pred_masks: mask = prediction_masks[i] - contours = find_contours(mask, 0.5) + contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( From f750cf42e1e1b42dc6527c35534f3001ca0bfb48 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 10 Nov 2023 17:27:36 +0400 Subject: [PATCH 23/32] Switch back to dataset_split --- requirements/min-reqs.old | 4 ++-- torchgeo/datamodules/vhr10.py | 12 ++++-------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 9e9197f7c90..c95ba5f0a72 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -17,9 +17,9 @@ rtree==1.0.0 segmentation-models-pytorch==0.2.0 shapely==1.7.1 timm==0.4.12 -torch==1.13.0 +torch==1.12.0 torchmetrics==0.10.0 -torchvision==0.14.0 +torchvision==0.13.0 # datasets h5py==3.0.0 diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index abc2957b8ed..f87f20ba174 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -6,16 +6,15 @@ from typing import Any, Callable, Union import kornia.augmentation as K -import torch from einops import rearrange -from torch import Generator, Tensor +from torch import Tensor from torch.nn import Module -from torch.utils.data import random_split from ..datasets import VHR10 from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from .geo import NonGeoDataModule +from .utils import dataset_split class _AugPipe(Module): @@ -143,9 +142,6 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ self.dataset = VHR10(**self.kwargs) - train_split_pct = 1.0 - (self.val_split_pct + self.test_split_pct) - self.train_dataset, self.val_dataset, self.test_dataset = random_split( - self.dataset, - [train_split_pct, self.val_split_pct, self.test_split_pct], - Generator().manual_seed(torch.initial_seed()), + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + self.dataset, self.val_split_pct, self.test_split_pct ) From b8f01660708a11b13f618332f38823a9ff1f75a6 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 15 Nov 2023 17:52:04 +0400 Subject: [PATCH 24/32] Separate out AugPipe --- torchgeo/datamodules/nasa_marine_debris.py | 15 +++- torchgeo/datamodules/utils.py | 57 ++++++++++++++- torchgeo/datamodules/vhr10.py | 85 +++++++++++----------- 3 files changed, 108 insertions(+), 49 deletions(-) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index f740df5eb76..a8e68dea50f 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -5,12 +5,14 @@ from typing import Any +import kornia.augmentation as K import torch from torch import Tensor from ..datasets import NASAMarineDebris +from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import dataset_split +from .utils import AugPipe, dataset_split def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: @@ -23,8 +25,8 @@ def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: batch dict output """ output: dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] + output["image"] = [sample["image"] for sample in batch] + output["boxes"] = [sample["boxes"].float() for sample in batch] output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch] return output @@ -58,6 +60,13 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct + self.aug = AugPipe( + AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "boxes"] + ), + batch_size, + ) + self.collate_fn = collate_fn def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index d0bb6af9934..106aea07c67 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -5,10 +5,12 @@ import math from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import numpy as np -from torch import Generator +from einops import rearrange +from torch import Generator, Tensor +from torch.nn import Module from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import NonGeoDataset @@ -19,6 +21,57 @@ class MisconfigurationException(Exception): """Exception used to inform users of misuse with Lightning.""" +class AugPipe(Module): + """Pipeline for applying augmentations sequentially on select data keys.""" + + def __init__( + self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int + ) -> None: + """Initialize a new AugPipe instance. + + Args: + augs: Augmentations to apply. + batch_size: Batch size + """ + super().__init__() + self.augs = augs + self.batch_size = batch_size + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply the augmentation. + + Args: + batch: Input batch. + + Returns: + Augmented batch. + """ + batch_len = len(batch["image"]) + for bs in range(batch_len): + batch_dict = { + "image": batch["image"][bs], + "labels": batch["labels"][bs], + "boxes": batch["boxes"][bs], + } + + if "masks" in batch: + batch_dict["masks"] = batch["masks"][bs] + + batch_dict = self.augs(batch_dict) + + batch["image"][bs] = batch_dict["image"] + batch["labels"][bs] = batch_dict["labels"] + batch["boxes"][bs] = batch_dict["boxes"] + + if "masks" in batch: + batch["masks"][bs] = batch_dict["masks"] + + # Stack images + batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") + + return batch + + def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float, diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index f87f20ba174..b0b2e6ab71e 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -3,65 +3,62 @@ """NWPU VHR-10 datamodule.""" -from typing import Any, Callable, Union +from typing import Any, Union import kornia.augmentation as K -from einops import rearrange from torch import Tensor -from torch.nn import Module from ..datasets import VHR10 from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import dataset_split +from .utils import AugPipe, dataset_split +# class _AugPipe(Module): +# """Pipeline for applying augmentations sequentially on select data keys.""" -class _AugPipe(Module): - """Pipeline for applying augmentations sequentially on select data keys.""" +# def __init__( +# self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int +# ) -> None: +# """Initialize a new _AugPipe instance. - def __init__( - self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int - ) -> None: - """Initialize a new _AugPipe instance. +# Args: +# augs: Augmentations to apply. +# batch_size: Batch size +# """ +# super().__init__() +# self.augs = augs +# self.batch_size = batch_size - Args: - augs: Augmentations to apply. - batch_size: Batch size - """ - super().__init__() - self.augs = augs - self.batch_size = batch_size +# def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: +# """Apply the augmentation. - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Apply the augmentation. +# Args: +# batch: Input batch. - Args: - batch: Input batch. - - Returns: - Augmented batch. - """ - batch_len = len(batch["image"]) - for bs in range(batch_len): - batch_dict = { - "image": batch["image"][bs], - "masks": batch["masks"][bs], - "labels": batch["labels"][bs], - "boxes": batch["boxes"][bs], - } +# Returns: +# Augmented batch. +# """ +# batch_len = len(batch["image"]) +# for bs in range(batch_len): +# batch_dict = { +# "image": batch["image"][bs], +# "masks": batch["masks"][bs], +# "labels": batch["labels"][bs], +# "boxes": batch["boxes"][bs], +# } - batch_dict = self.augs(batch_dict) +# batch_dict = self.augs(batch_dict) - batch["image"][bs] = batch_dict["image"] - batch["masks"][bs] = batch_dict["masks"] - batch["labels"][bs] = batch_dict["labels"] - batch["boxes"][bs] = batch_dict["boxes"] +# batch["image"][bs] = batch_dict["image"] +# batch["masks"][bs] = batch_dict["masks"] +# batch["labels"][bs] = batch_dict["labels"] +# batch["boxes"][bs] = batch_dict["boxes"] - # Stack images - batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") +# # Stack images +# batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") - return batch +# return batch def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: @@ -84,7 +81,7 @@ def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: class VHR10DataModule(NonGeoDataModule): """LightningDataModule implementation for the VHR10 dataset. - .. versionadded:: 0.5 + .. versionadded:: 0.6 """ def __init__( @@ -115,7 +112,7 @@ def __init__( self.collate_fn = collate_fn - self.train_aug = _AugPipe( + self.train_aug = AugPipe( AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(self.patch_size), @@ -126,7 +123,7 @@ def __init__( ), batch_size, ) - self.aug = _AugPipe( + self.aug = AugPipe( AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(self.patch_size), From 0f135145a5969ef1333c427a3ab449d053f34dcb Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 21 Nov 2023 12:03:14 +0400 Subject: [PATCH 25/32] Increase figsize & revert matplotlib --- pyproject.toml | 4 +-- requirements/min-reqs.old | 2 +- tests/conf/vhr10.yaml | 1 - torchgeo/datamodules/vhr10.py | 46 ----------------------------------- torchgeo/datasets/vhr10.py | 6 ++--- 5 files changed, 6 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c243c42bf2c..9fce633e79b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ dependencies = [ "lightly>=1.4.4", # lightning 2+ required for LightningCLI args + sys.argv support "lightning[pytorch-extra]>=2", - # matplotlib 3.5.0+ required for setting layout in figure - "matplotlib>=3.5.0", + # matplotlib 3.3.3+ required for Python 3.9 wheels + "matplotlib>=3.3.3", # numpy 1.19.3+ required by Python 3.9 wheels "numpy>=1.19.3", # pandas 1.1.3+ required for Python 3.9 wheels diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index c95ba5f0a72..c19e625314b 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -7,7 +7,7 @@ fiona==1.8.19 kornia==0.6.9 lightly==1.4.4 lightning[pytorch-extra]==2.0.0 -matplotlib==3.5.0 +matplotlib==3.3.3 numpy==1.19.3 pandas==1.1.3 pillow==8.0.0 diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index 6e91f9d6ac5..0ea0909b971 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -6,7 +6,6 @@ model: num_classes: 11 lr: 2.5e-5 patience: 10 - data: class_path: VHR10DataModule init_args: diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index b0b2e6ab71e..db46e519ce9 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -14,52 +14,6 @@ from .geo import NonGeoDataModule from .utils import AugPipe, dataset_split -# class _AugPipe(Module): -# """Pipeline for applying augmentations sequentially on select data keys.""" - -# def __init__( -# self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int -# ) -> None: -# """Initialize a new _AugPipe instance. - -# Args: -# augs: Augmentations to apply. -# batch_size: Batch size -# """ -# super().__init__() -# self.augs = augs -# self.batch_size = batch_size - -# def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: -# """Apply the augmentation. - -# Args: -# batch: Input batch. - -# Returns: -# Augmented batch. -# """ -# batch_len = len(batch["image"]) -# for bs in range(batch_len): -# batch_dict = { -# "image": batch["image"][bs], -# "masks": batch["masks"][bs], -# "labels": batch["labels"][bs], -# "boxes": batch["boxes"][bs], -# } - -# batch_dict = self.augs(batch_dict) - -# batch["image"][bs] = batch_dict["image"] -# batch["masks"][bs] = batch_dict["masks"] -# batch["labels"][bs] = batch_dict["labels"] -# batch["boxes"][bs] = batch_dict["boxes"] - -# # Stack images -# batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") - -# return batch - def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: """Custom object detection collate fn to handle variable boxes. diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index fb69b735b99..d99aa0711a9 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -434,9 +434,7 @@ def plot( ncols += 1 # Display image - fig, axs = plt.subplots( - ncols=ncols, squeeze=False, layout="constrained", figsize=(ncols * 10, 10) - ) + fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 13)) axs[0, 0].imshow(image) axs[0, 0].axis("off") @@ -536,4 +534,6 @@ def plot( if suptitle is not None: plt.suptitle(suptitle) + plt.tight_layout() + return fig From d976e5f10167a1d83b6f9c04dac7a8ec04a59c8f Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 21 Nov 2023 17:20:10 +0400 Subject: [PATCH 26/32] Common collate_fn --- torchgeo/datamodules/nasa_marine_debris.py | 23 ++++---------------- torchgeo/datamodules/utils.py | 25 ++++++++++++++++++++++ torchgeo/datamodules/vhr10.py | 25 +++++----------------- torchgeo/datasets/vhr10.py | 3 +-- 4 files changed, 35 insertions(+), 41 deletions(-) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index a8e68dea50f..9c4d5e157fe 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -7,28 +7,11 @@ import kornia.augmentation as K import torch -from torch import Tensor from ..datasets import NASAMarineDebris from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import AugPipe, dataset_split - - -def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: - """Custom object detection collate fn to handle variable boxes. - - Args: - batch: list of sample dicts return by dataset - - Returns: - batch dict output - """ - output: dict[str, Any] = {} - output["image"] = [sample["image"] for sample in batch] - output["boxes"] = [sample["boxes"].float() for sample in batch] - output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch] - return output +from .utils import AugPipe, collate_fn_detection, dataset_split class NASAMarineDebrisDataModule(NonGeoDataModule): @@ -57,6 +40,8 @@ def __init__( """ super().__init__(NASAMarineDebris, batch_size, num_workers, **kwargs) + self.std = torch.tensor(255) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct @@ -67,7 +52,7 @@ def __init__( batch_size, ) - self.collate_fn = collate_fn + self.collate_fn = collate_fn_detection def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 106aea07c67..9e665282409 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Optional, Union import numpy as np +import torch from einops import rearrange from torch import Generator, Tensor from torch.nn import Module @@ -72,6 +73,30 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: return batch +def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: + """Custom collate fn for object detection & instance segmentation. + + Args: + batch: list of sample dicts return by dataset + + Returns: + batch dict output + """ + output: dict[str, Any] = {} + output["image"] = [sample["image"] for sample in batch] + output["boxes"] = [sample["boxes"].float() for sample in batch] + if "labels" in batch[0]: + output["labels"] = [sample["labels"] for sample in batch] + else: + output["labels"] = [ + torch.tensor([1] * len(sample["boxes"])) for sample in batch + ] + + if "masks" in batch[0]: + output["masks"] = [sample["masks"] for sample in batch] + return output + + def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float, diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index db46e519ce9..01461e9c876 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -6,30 +6,13 @@ from typing import Any, Union import kornia.augmentation as K -from torch import Tensor +import torch from ..datasets import VHR10 from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import AugPipe, dataset_split - - -def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: - """Custom object detection collate fn to handle variable boxes. - - Args: - batch: list of sample dicts return by dataset - - Returns: - batch dict output - """ - output: dict[str, Any] = {} - output["image"] = [sample["image"] for sample in batch] - output["boxes"] = [sample["boxes"] for sample in batch] - output["labels"] = [sample["labels"] for sample in batch] - output["masks"] = [sample["masks"] for sample in batch] - return output +from .utils import AugPipe, collate_fn_detection, dataset_split class VHR10DataModule(NonGeoDataModule): @@ -60,11 +43,13 @@ def __init__( """ super().__init__(VHR10, batch_size, num_workers, **kwargs) + self.std = torch.tensor(255) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct self.patch_size = _to_tuple(patch_size) - self.collate_fn = collate_fn + self.collate_fn = collate_fn_detection self.train_aug = AugPipe( AugmentationSequential( diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index d99aa0711a9..9c3e701c9b7 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -289,10 +289,9 @@ def _load_image(self, id_: int) -> Tensor: ) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) - tensor = tensor.float() / tensor.max() return tensor def _load_target(self, id_: int) -> dict[str, Any]: From 77d3feb0d37a3489478cf76b07786693f831b102 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 21 Nov 2023 17:28:44 +0400 Subject: [PATCH 27/32] Class var std --- torchgeo/datamodules/nasa_marine_debris.py | 4 ++-- torchgeo/datamodules/vhr10.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 9c4d5e157fe..76848bc4e4b 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -20,6 +20,8 @@ class NASAMarineDebrisDataModule(NonGeoDataModule): .. versionadded:: 0.2 """ + std = torch.tensor(255) + def __init__( self, batch_size: int = 64, @@ -40,8 +42,6 @@ def __init__( """ super().__init__(NASAMarineDebris, batch_size, num_workers, **kwargs) - self.std = torch.tensor(255) - self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 01461e9c876..0059d6c71f2 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -21,6 +21,8 @@ class VHR10DataModule(NonGeoDataModule): .. versionadded:: 0.6 """ + std = torch.tensor(255) + def __init__( self, batch_size: int = 64, @@ -43,8 +45,6 @@ def __init__( """ super().__init__(VHR10, batch_size, num_workers, **kwargs) - self.std = torch.tensor(255) - self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct self.patch_size = _to_tuple(patch_size) From 9f21d06129073da15ba497e8341c49cec60d8227 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 23 Nov 2023 16:54:36 +0400 Subject: [PATCH 28/32] Undo std change in BaseDataModule --- torchgeo/datamodules/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 7dba392c6a2..96da7683ba2 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -30,7 +30,7 @@ class BaseDataModule(LightningDataModule): """ mean = torch.tensor(0) - std = torch.tensor(1) + std = torch.tensor(255) def __init__( self, From 61a1d467a444fb0d6d9ce831888b57f464485eec Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 30 Nov 2023 09:53:24 +0400 Subject: [PATCH 29/32] Undo req changes --- requirements/min-reqs.old | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index c19e625314b..34b22649ceb 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -17,9 +17,9 @@ rtree==1.0.0 segmentation-models-pytorch==0.2.0 shapely==1.7.1 timm==0.4.12 -torch==1.12.0 +torch==1.13.0 torchmetrics==0.10.0 -torchvision==0.13.0 +torchvision==0.14.0 # datasets h5py==3.0.0 From 7e88d02828435ed615a89a2a0357da53b038c3a3 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 6 Dec 2023 14:15:59 +0400 Subject: [PATCH 30/32] Remove unused line --- torchgeo/transforms/transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 12c707ff984..c3c3d54f56b 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -74,7 +74,6 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: if "boxes" in batch and ( isinstance(batch["boxes"], list) or batch["boxes"].ndim == 2 ): - # batch["boxes"] = batch["boxes"].unsqueeze(0) batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data # Kornia requires masks to have a channel dimension From 30b02d6ef237b97584f0a23f44f8381dc617ab8b Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 12 Dec 2023 12:08:18 +0400 Subject: [PATCH 31/32] Add version strings --- torchgeo/datamodules/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 9e665282409..ae098c9930c 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -23,7 +23,10 @@ class MisconfigurationException(Exception): class AugPipe(Module): - """Pipeline for applying augmentations sequentially on select data keys.""" + """Pipeline for applying augmentations sequentially on select data keys. + + .. versionadded:: 0.6 + """ def __init__( self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int @@ -74,13 +77,15 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: - """Custom collate fn for object detection & instance segmentation. + """Custom collate fn for object detection and instance segmentation. Args: batch: list of sample dicts return by dataset Returns: batch dict output + + .. versionadded:: 0.6 """ output: dict[str, Any] = {} output["image"] = [sample["image"] for sample in batch] From 1f74d9b646bd6eb62195305da24aec0203bc880c Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Tue, 9 Jan 2024 16:20:45 +0400 Subject: [PATCH 32/32] mypy fix --- torchgeo/datasets/vhr10.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 9c3e701c9b7..43756df71a5 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -289,7 +289,8 @@ def _load_image(self, id_: int) -> Tensor: ) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array).float() + tensor = torch.from_numpy(array) + tensor = tensor.float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor