diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b773130cd03..459d3a94557 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,5 +34,5 @@ repos: hooks: - id: mypy args: [--strict, --ignore-missing-imports, --show-error-codes] - additional_dependencies: [torch>=1.11, torchmetrics>=0.7, pytorch-lightning>=1.7, pytest>=6, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0] + additional_dependencies: [torch>=1.11, torchmetrics>=0.10, pytorch-lightning>=1.7, pytest>=6, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0] exclude: (build|data|dist|logo|logs|output)/ diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index e4dd0c304eb..9e8a86ec9cd 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -3382,7 +3382,7 @@ "\n", " # Reorder and rescale the image\n", " image = sample[\"image\"][rgb_indices].permute(1, 2, 0)\n", - " image = torch.clamp(image / 6000, min=0, max=1)\n", + " image = torch.clamp(image / 6000, min=0, max=1).numpy()\n", "\n", " # Plot the image\n", " fig, ax = plt.subplots()\n", diff --git a/environment.yml b/environment.yml index 084f2edf3c1..c82849aaa35 100644 --- a/environment.yml +++ b/environment.yml @@ -47,5 +47,5 @@ dependencies: - setuptools>=42 - sphinx>=4 - timm>=0.4.12 - - torchmetrics>=0.7 + - torchmetrics>=0.10 - zipfile-deflate64>=0.2 diff --git a/requirements/required.txt b/requirements/required.txt index 9cdcb93ce95..9b4e4aa8f74 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -4,20 +4,20 @@ setuptools==65.4.0 # install einops==0.5.0 fiona==1.9a2 -kornia==0.6.7 +kornia==0.6.8 matplotlib==3.6.1 -numpy==1.23.3;python_version>='3.8' +numpy==1.23.4;python_version>='3.8' omegaconf==2.2.3 packaging==21.3 pillow==9.2.0 pyproj==3.4.0;python_version>='3.8' pytorch-lightning==1.7.7 rasterio==1.3.2;python_version>='3.8' -rtree==1.0.0 +rtree==1.0.1 scikit-image>=0.15.0; scikit-learn==1.1.2;python_version>='3.8' segmentation-models-pytorch==0.3.0 -shapely==1.8.4 +shapely==1.8.5.post1 timm==0.4.12 torch==1.12.1 torchmetrics==0.10.0 diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index db529646836..454c52f980b 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -6,7 +6,7 @@ import abc import os import sys -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, cast import fiona import matplotlib.pyplot as plt @@ -634,7 +634,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = [hit.object for hit in hits] + filepaths = cast(List[Dict[str, str]], [hit.object for hit in hits]) sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query} diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 07737f6fc4c..7431bebab2d 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -5,7 +5,7 @@ import os import sys -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, cast import fiona import matplotlib.pyplot as plt @@ -343,7 +343,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = [hit.object for hit in hits] + filepaths = cast(List[Dict[str, str]], [hit.object for hit in hits]) sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query} diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index ff75814c75f..f21c7901a97 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -180,10 +180,7 @@ def __str__(self) -> str: def __getstate__( self, - ) -> Tuple[ - Dict[Any, Any], - List[Tuple[int, Tuple[float, float, float, float, float, float], str]], - ]: + ) -> Tuple[Dict[str, Any], List[Tuple[Any, Any, Optional[Any]]]]: """Define how instances are pickled. Returns: @@ -407,7 +404,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = [hit.object for hit in hits] + filepaths = cast(List[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 270a72a3734..4dd6b8aca07 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -5,7 +5,7 @@ import glob import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, cast import matplotlib.pyplot as plt import torch @@ -176,7 +176,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = [hit.object for hit in hits] + filepaths = cast(List[str], [hit.object for hit in hits]) if not filepaths: raise IndexError( diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 8ccd760e577..f8e377a9c95 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -7,7 +7,7 @@ import json import os import sys -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast import fiona import fiona.transform @@ -307,7 +307,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: IndexError: if query is not found in the index """ hits = self.index.intersection(tuple(query), objects=True) - filepaths = [hit.object for hit in hits] + filepaths = cast(List[str], [hit.object for hit in hits]) if not filepaths: raise IndexError(