Skip to content

Commit

Permalink
♻️ Refactor DatashaderRasterizer to be up front about datapipe lengths (
Browse files Browse the repository at this point in the history
#39)

Check during initialization of DatashaderRasterizerIterDataPipe on whether the input canvas and vector datapipes have compatible lengths. This is better than finding out that the zip function doesn't work when the datapipe is being iterated over. Added a unit test to cover the 3:2 ratio case and documented why the ValueError is raised on unmatched lengths. Also renamed the previous ValueError on unsupported geometry types to NotImplementedError to avoid confusion.
  • Loading branch information
weiji14 authored Aug 19, 2022
1 parent ce0f4da commit 23c6ac4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
27 changes: 21 additions & 6 deletions zen3geo/datapipes/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class DatashaderRasterizerIterDataPipe(IterDataPipe):
If ``spatialpandas`` is not installed. Please install it (e.g. via
``pip install spatialpandas``) before using this class.
ValueError
If either the length of the ``vector_datapipe`` is not 1, or if the
length of the ``vector_datapipe`` is not equal to the length of the
``source_datapipe``. I.e. the ratio of vector:canvas must be 1:N or
be exactly N:N.
AttributeError
If either the canvas in ``source_datapipe`` or vector geometry in
``vector_datapipe`` is missing a ``.crs`` attribute. Please set the
Expand All @@ -84,7 +90,7 @@ class DatashaderRasterizerIterDataPipe(IterDataPipe):
:py:class:`geopandas.GeoSeries` or :py:class:`geopandas.GeoDataFrame`
input) before passing them into the datapipe.
ValueError
NotImplementedError
If the input vector geometry type to ``vector_datapipe`` is not
supported, typically when a
:py:class:`shapely.geometry.GeometryCollection` is used. Supported
Expand Down Expand Up @@ -167,14 +173,23 @@ def __init__(
self.agg: Optional = agg # Datashader Aggregation/Reduction function
self.kwargs = kwargs

len_vector_datapipe: int = len(self.vector_datapipe)
len_canvas_datapipe: int = len(self.source_datapipe)
if len_vector_datapipe != 1 or len_vector_datapipe != len_canvas_datapipe:
raise ValueError(
f"Unmatched lengths for the canvas datapipe ({self.source_datapipe}) "
f"and vector datapipe ({self.vector_datapipe}). \n"
f"The vector datapipe's length ({len_vector_datapipe}) should either "
f"be (1) to allow for broadcasting, or match the canvas datapipe's "
f"length of ({len_canvas_datapipe})."
)

def __iter__(self) -> Iterator[xr.DataArray]:
# Broadcast vector iterator to match length of raster iterator
fill_value: Optional = (
list(self.vector_datapipe).pop() if len(self.vector_datapipe) == 1 else None
)
for canvas, vector in self.source_datapipe.zip_longest(
self.vector_datapipe, fill_value=fill_value
self.vector_datapipe, fill_value=list(self.vector_datapipe).pop()
):
# print(canvas, vector)
# If canvas has no CRS attribute, set one to prevent AttributeError
canvas.crs = getattr(canvas, "crs", None)
if canvas.crs is None:
Expand Down Expand Up @@ -202,7 +217,7 @@ def __iter__(self) -> Iterator[xr.DataArray]:
columns = ["geometry"] if not hasattr(vector, "columns") else None
_vector = spatialpandas.GeoDataFrame(data=vector, columns=columns)
except ValueError as e:
raise ValueError(
raise NotImplementedError(
f"Unsupported geometry type(s) {set(vector.geom_type)} detected, "
"only point, line or polygon vector geometry types are supported."
) from e
Expand Down
19 changes: 16 additions & 3 deletions zen3geo/tests/test_datapipes_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,23 @@ def test_datashader_rasterize_vector_missing_crs(canvas, geodataframe):
raster = next(it)


def test_datashader_rasterize_unmatched_lengths(canvas, geodataframe):
"""
Ensure that DatashaderRasterizer raises a ValueError when the length of the
canvas datapipe is unmatched with the length of the vector datapipe.
"""
# Canvas:Vector ratio of 3:2
dp_canvas = IterableWrapper(iterable=[canvas, canvas, canvas])
dp_vector = IterableWrapper(iterable=[geodataframe, geodataframe])

with pytest.raises(ValueError, match="Unmatched lengths for the"):
dp_datashader = dp_canvas.rasterize_with_datashader(vector_datapipe=dp_vector)


def test_datashader_rasterize_vector_geometrycollection(canvas, geodataframe):
"""
Ensure that DatashaderRasterizer raises a ValueError when an unsupported
vector type like GeometryCollection is used.
Ensure that DatashaderRasterizer raises a NotImplementedError when an
unsupported vector type like GeometryCollection is used.
"""
gpd = pytest.importorskip("geopandas")

Expand All @@ -156,5 +169,5 @@ def test_datashader_rasterize_vector_geometrycollection(canvas, geodataframe):

assert len(dp_datashader) == 1
it = iter(dp_datashader)
with pytest.raises(ValueError, match="Unsupported geometry type"):
with pytest.raises(NotImplementedError, match="Unsupported geometry type"):
raster = next(it)

0 comments on commit 23c6ac4

Please sign in to comment.