Skip to content

Commit

Permalink
Merge branch 'main' into vhr10-plot
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 authored Oct 14, 2022
2 parents e76b764 + 48f516f commit 35f34cb
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ repos:
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports, --show-error-codes]
additional_dependencies: [torch>=1.11, torchmetrics>=0.7, pytorch-lightning>=1.7, pytest>=6, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0]
additional_dependencies: [torch>=1.11, torchmetrics>=0.10, pytorch-lightning>=1.7, pytest>=6, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0]
exclude: (build|data|dist|logo|logs|output)/
2 changes: 1 addition & 1 deletion docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3382,7 +3382,7 @@
"\n",
" # Reorder and rescale the image\n",
" image = sample[\"image\"][rgb_indices].permute(1, 2, 0)\n",
" image = torch.clamp(image / 6000, min=0, max=1)\n",
" image = torch.clamp(image / 6000, min=0, max=1).numpy()\n",
"\n",
" # Plot the image\n",
" fig, ax = plt.subplots()\n",
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ dependencies:
- setuptools>=42
- sphinx>=4
- timm>=0.4.12
- torchmetrics>=0.7
- torchmetrics>=0.10
- zipfile-deflate64>=0.2
8 changes: 4 additions & 4 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@ setuptools==65.4.0
# install
einops==0.5.0
fiona==1.9a2
kornia==0.6.7
kornia==0.6.8
matplotlib==3.6.1
numpy==1.23.3;python_version>='3.8'
numpy==1.23.4;python_version>='3.8'
omegaconf==2.2.3
packaging==21.3
pillow==9.2.0
pyproj==3.4.0;python_version>='3.8'
pytorch-lightning==1.7.7
rasterio==1.3.2;python_version>='3.8'
rtree==1.0.0
rtree==1.0.1
scikit-image>=0.15.0;
scikit-learn==1.1.2;python_version>='3.8'
segmentation-models-pytorch==0.3.0
shapely==1.8.4
shapely==1.8.5.post1
timm==0.4.12
torch==1.12.1
torchmetrics==0.10.0
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import abc
import os
import sys
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, cast

import fiona
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -634,7 +634,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
filepaths = cast(List[Dict[str, str]], [hit.object for hit in hits])

sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import os
import sys
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, cast

import fiona
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -343,7 +343,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
filepaths = cast(List[Dict[str, str]], [hit.object for hit in hits])

sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}

Expand Down
7 changes: 2 additions & 5 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ def __str__(self) -> str:

def __getstate__(
self,
) -> Tuple[
Dict[Any, Any],
List[Tuple[int, Tuple[float, float, float, float, float, float], str]],
]:
) -> Tuple[Dict[str, Any], List[Tuple[Any, Any, Optional[Any]]]]:
"""Define how instances are pickled.
Returns:
Expand Down Expand Up @@ -407,7 +404,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
filepaths = cast(List[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/globbiomass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, cast

import matplotlib.pyplot as plt
import torch
Expand Down Expand Up @@ -176,7 +176,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
filepaths = cast(List[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/openbuildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import os
import sys
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast

import fiona
import fiona.transform
Expand Down Expand Up @@ -307,7 +307,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
filepaths = cast(List[str], [hit.object for hit in hits])

if not filepaths:
raise IndexError(
Expand Down

0 comments on commit 35f34cb

Please sign in to comment.