Skip to content

Commit

Permalink
TYP: eliminate some typing issues (#59)
Browse files Browse the repository at this point in the history
* eliminate some mypy errors

* REF: eliminate some typing issues

* avoid xarray internals

* mypy in ci

* more types

* try different mypy command

* fix errors

* coverage fix
  • Loading branch information
martinfleis authored Jan 11, 2024
1 parent 17a1308 commit cd33d6d
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 93 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ jobs:
id: status
run: pytest -v . --cov=xvec --cov-append --cov-report term-missing --cov-report xml --color=yes --report-log pytest-log.jsonl

- name: run mypy
if: contains(matrix.environment-file, 'ci/312.yaml') && contains(matrix.os, 'ubuntu')
run: mypy xvec/ --install-types --ignore-missing-imports --non-interactive

- uses: codecov/codecov-action@v3

- name: Generate and publish the report
Expand Down
1 change: 1 addition & 0 deletions ci/312.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ dependencies:
- geopandas-base
- geodatasets
- pyogrio
- mypy

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ omit = ["xvec/tests/*"]
exclude_lines = [
"except ImportError",
"except PackageNotFoundError",
"if TYPE_CHECKING:"
]

[tool.ruff]
Expand Down
4 changes: 2 additions & 2 deletions xvec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from importlib.metadata import PackageNotFoundError, version

from .accessor import XvecAccessor # noqa
from .index import GeometryIndex # noqa
from .accessor import XvecAccessor # noqa: F401
from .index import GeometryIndex # noqa: F401

try:
__version__ = version("xvec")
Expand Down
125 changes: 69 additions & 56 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from collections.abc import Hashable, Mapping, Sequence
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, cast

import numpy as np
import pandas as pd
Expand All @@ -13,6 +13,9 @@
from .index import GeometryIndex
from .zonal import _zonal_stats_iterative, _zonal_stats_rasterize

if TYPE_CHECKING:
from geopandas import GeoDataFrame


@xr.register_dataarray_accessor("xvec")
@xr.register_dataset_accessor("xvec")
Expand All @@ -22,7 +25,7 @@ class XvecAccessor:
Currently works on coordinates with :class:`xvec.GeometryIndex`.
"""

def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray) -> None:
"""xvec init, nothing to be done here."""
self._obj = xarray_obj
self._geom_coords_all = [
Expand All @@ -36,7 +39,9 @@ def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
if self.is_geom_variable(name, has_index=True)
]

def is_geom_variable(self, name: Hashable, has_index: bool = True):
def is_geom_variable(
self, name: Hashable, has_index: bool = True
) -> bool | np.bool_:
"""Check if coordinate variable is composed of :class:`shapely.Geometry`.
Can return all such variables or only those using :class:`~xvec.GeometryIndex`.
Expand Down Expand Up @@ -208,7 +213,7 @@ def to_crs(
self,
variable_crs: Mapping[Any, Any] | None = None,
**variable_crs_kwargs: Any,
):
) -> xr.DataArray | xr.Dataset:
"""
Transform :class:`shapely.Geometry` objects of a variable to a new coordinate
reference system.
Expand Down Expand Up @@ -313,20 +318,15 @@ def to_crs(
currently wraps :meth:`Dataset.assign_coords <xarray.Dataset.assign_coords>`
or :meth:`DataArray.assign_coords <xarray.DataArray.assign_coords>`.
"""
if variable_crs and variable_crs_kwargs:
raise ValueError(
"Cannot specify both keyword and positional arguments to "
"'.xvec.to_crs'."
)
variable_crs_solved = _resolve_input(
variable_crs, variable_crs_kwargs, "to_crs"
)

_obj = self._obj.copy(deep=False)

if variable_crs_kwargs:
variable_crs = variable_crs_kwargs

transformed = {}

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if not isinstance(self._obj.xindexes[key], GeometryIndex):
raise ValueError(
f"The index '{key}' is not an xvec.GeometryIndex. "
Expand All @@ -335,7 +335,7 @@ def to_crs(
)

data = _obj[key]
data_crs = self._obj.xindexes[key].crs
data_crs = self._obj.xindexes[key].crs # type: ignore

# transformation code taken from geopandas (BSD 3-clause license)
if data_crs is None:
Expand Down Expand Up @@ -374,21 +374,21 @@ def to_crs(
for key, (result, _crs) in transformed.items():
_obj = _obj.assign_coords({key: result})

_obj = _obj.drop_indexes(variable_crs.keys())
_obj = _obj.drop_indexes(variable_crs_solved.keys())

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if crs:
_obj[key].attrs["crs"] = CRS.from_user_input(crs)
_obj = _obj.set_xindex(key, GeometryIndex, crs=crs)
_obj = _obj.set_xindex([key], GeometryIndex, crs=crs)

return _obj

def set_crs(
self,
variable_crs: Mapping[Any, Any] | None = None,
allow_override=False,
allow_override: bool = False,
**variable_crs_kwargs: Any,
):
) -> xr.DataArray | xr.Dataset:
"""Set the Coordinate Reference System (CRS) of coordinates backed by
:class:`~xvec.GeometryIndex`.
Expand Down Expand Up @@ -480,27 +480,21 @@ def set_crs(
transform the geometries to a new CRS, use the :meth:`to_crs`
method.
"""

if variable_crs and variable_crs_kwargs:
raise ValueError(
"Cannot specify both keyword and positional arguments to "
".xvec.set_crs."
)
variable_crs_solved = _resolve_input(
variable_crs, variable_crs_kwargs, "set_crs"
)

_obj = self._obj.copy(deep=False)

if variable_crs_kwargs:
variable_crs = variable_crs_kwargs

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if not isinstance(self._obj.xindexes[key], GeometryIndex):
raise ValueError(
f"The index '{key}' is not an xvec.GeometryIndex. "
"Set the xvec.GeometryIndex using '.xvec.set_geom_indexes' before "
"handling projection information."
)

data_crs = self._obj.xindexes[key].crs
data_crs = self._obj.xindexes[key].crs # type: ignore

if not allow_override and data_crs is not None and not data_crs == crs:
raise ValueError(
Expand All @@ -510,23 +504,23 @@ def set_crs(
"want to transform the geometries, use '.xvec.to_crs' instead."
)

_obj = _obj.drop_indexes(variable_crs.keys())
_obj = _obj.drop_indexes(variable_crs_solved.keys())

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if crs:
_obj[key].attrs["crs"] = CRS.from_user_input(crs)
_obj = _obj.set_xindex(key, GeometryIndex, crs=crs)
_obj = _obj.set_xindex([key], GeometryIndex, crs=crs)

return _obj

def query(
self,
coord_name: str,
geometry: shapely.Geometry | Sequence[shapely.Geometry],
predicate: str = None,
distance: float | Sequence[float] = None,
unique=False,
):
predicate: str | None = None,
distance: float | Sequence[float] | None = None,
unique: bool = False,
) -> xr.DataArray | xr.Dataset:
"""Return a subset of a DataArray/Dataset filtered using a spatial query on
:class:`~xvec.GeometryIndex`.
Expand Down Expand Up @@ -619,12 +613,12 @@ def query(
"""
if isinstance(geometry, shapely.Geometry):
ilocs = self._obj.xindexes[coord_name].sindex.query(
ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore
geometry, predicate=predicate, distance=distance
)

else:
_, ilocs = self._obj.xindexes[coord_name].sindex.query(
_, ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore
geometry, predicate=predicate, distance=distance
)
if unique:
Expand All @@ -634,11 +628,11 @@ def query(

def set_geom_indexes(
self,
coord_names: str | Sequence[Hashable],
coord_names: str | Sequence[str],
crs: Any = None,
allow_override: bool = False,
**kwargs,
):
**kwargs: dict[str, Any],
) -> xr.DataArray | xr.Dataset:
"""Set a new :class:`~xvec.GeometryIndex` for one or more existing
coordinate(s). One :class:`~xvec.GeometryIndex` is set per coordinate. Only
1-dimensional coordinates are supported.
Expand Down Expand Up @@ -691,7 +685,7 @@ def set_geom_indexes(

for coord in coord_names:
if isinstance(self._obj.xindexes[coord], GeometryIndex):
data_crs = self._obj.xindexes[coord].crs
data_crs = self._obj.xindexes[coord].crs # type: ignore

if not allow_override and data_crs is not None and not data_crs == crs:
raise ValueError(
Expand All @@ -710,7 +704,7 @@ def set_geom_indexes(

return _obj

def to_geopandas(self):
def to_geopandas(self) -> GeoDataFrame | pd.DataFrame:
"""Convert this array into a GeoPandas :class:`~geopandas.GeoDataFrame`
Returns a :class:`~geopandas.GeoDataFrame` with coordinates based on a
Expand Down Expand Up @@ -762,11 +756,11 @@ def to_geopandas(self):
if len(self._geom_indexes):
if self._obj.ndim == 1:
gdf = self._obj.to_pandas()
elif self._obj.ndim == 2:
else:
gdf = self._obj.to_pandas()
if gdf.columns.name == self._geom_indexes[0]:
gdf = gdf.T
return gdf.reset_index().set_geometry(
return gdf.reset_index().set_geometry( # type: ignore
self._geom_indexes[0],
crs=self._obj.xindexes[self._geom_indexes[0]].crs,
)
Expand All @@ -790,7 +784,7 @@ def to_geopandas(self):
if index_name in self._geom_coords_all:
return gdf.reset_index().set_geometry(
index_name, crs=self._obj[index_name].attrs.get("crs", None)
)
) # type: ignore

warnings.warn(
"No active geometry column to be set. The resulting object "
Expand All @@ -810,7 +804,7 @@ def to_geodataframe(
dim_order: Sequence[Hashable] | None = None,
geometry: Hashable | None = None,
long: bool = True,
):
) -> GeoDataFrame | pd.DataFrame:
"""Convert this array and its coordinates into a tidy geopandas.GeoDataFrame.
The GeoDataFrame is indexed by the Cartesian product of index coordinates
Expand Down Expand Up @@ -884,7 +878,7 @@ def to_geodataframe(
level
for level in df.index.names
if level not in self._geom_coords_all
]
] # type: ignore
)

if isinstance(df.index, pd.MultiIndex):
Expand All @@ -907,7 +901,7 @@ def to_geodataframe(
if geometry is not None:
return df.set_geometry(
geometry, crs=self._obj[geometry].attrs.get("crs", None)
)
) # type: ignore

warnings.warn(
"No active geometry column to be set. The resulting object "
Expand All @@ -926,12 +920,12 @@ def zonal_stats(
y_coords: Hashable,
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: Hashable = "geometry",
index: bool = None,
index: bool | None = None,
method: str = "rasterize",
all_touched: bool = False,
n_jobs: int = -1,
**kwargs,
):
**kwargs: dict[str, Any],
) -> xr.DataArray | xr.Dataset:
"""Extract the values from a dataset indexed by a set of geometries
Given an object indexed by x and y coordinates (or latitude and longitude), such
Expand Down Expand Up @@ -1121,9 +1115,9 @@ def extract_points(
y_coords: Hashable,
tolerance: float | None = None,
name: str = "geometry",
crs: Any = None,
index: bool = None,
):
crs: Any | None = None,
index: bool | None = None,
) -> xr.DataArray | xr.Dataset:
"""Extract points from a DataArray or a Dataset indexed by spatial coordinates
Given an object indexed by x and y coordinates (or latitude and longitude), such
Expand Down Expand Up @@ -1263,3 +1257,22 @@ def extract_points(
}
)
return result


def _resolve_input(
positional: Mapping[Any, Any] | None,
keyword: Mapping[str, Any],
func_name: str,
) -> Mapping[Hashable, Any]:
"""Resolve combination of positional and keyword arguments.
Based on xarray's ``either_dict_or_kwargs``.
"""
if positional and keyword:
raise ValueError(
"Cannot specify both keyword and positional arguments to "
f"'.xvec.{func_name}'."
)
if positional is None or positional == {}:
return cast(Mapping[Hashable, Any], keyword)
return positional
Loading

0 comments on commit cd33d6d

Please sign in to comment.