Skip to content

Commit

Permalink
REF: Use dst_path memory file for loading data
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Oct 31, 2024
1 parent a6d4a9c commit cd0021d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 53 deletions.
77 changes: 28 additions & 49 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset

from rioxarray._io import open_rasterio
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords


Expand All @@ -35,9 +36,10 @@ def __init__(self, xds: DataArray):
res = xds.rio.resolution(recalc=True)
self.res = (abs(res[0]), abs(res[1]))
self.transform = xds.rio.transform(recalc=True)
# profile is only used for writing to a file.
# This never happens with rioxarray merge.
self.profile: dict = {}
self.profile: dict = {
"crs": self.crs,
"nodata": self.nodatavals[0],
}

def colormap(self, *args, **kwargs) -> None:
"""
Expand All @@ -52,21 +54,8 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
This method is meant to be used by the rasterio.merge.merge function.
"""
with MemoryFile() as memfile:
with memfile.open(
driver="GTiff",
height=int(self._xds.rio.height),
width=int(self._xds.rio.width),
count=self.count,
dtype=self.dtypes[0],
crs=self.crs,
transform=self.transform,
nodata=self.nodatavals[0],
) as dataset:
data = self._xds.values
if data.ndim == 2:
dataset.write(data, 1)
else:
dataset.write(data)
self._xds.rio.to_raster(memfile.name)
with memfile.open() as dataset:
return dataset.read(*args, **kwargs)


Expand Down Expand Up @@ -145,40 +134,30 @@ def merge_arrays(
rioduckarrays.append(RasterioDatasetDuck(dataarray))

# use rasterio to merge
merged_data, merged_transform = _rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
)
# generate merged data array
representative_array = rioduckarrays[0]._xds
if parse_coordinates:
coords = _make_coords(
src_data_array=representative_array,
dst_affine=merged_transform,
dst_width=merged_data.shape[-1],
dst_height=merged_data.shape[-2],
with MemoryFile() as memfile:
_rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
dst_path=memfile.name,
)
mask_and_scale = bool(
set(representative_array.encoding.keys())
& {"_FillValue", "_Unsigned", "scale_factor", "add_offset"}
)
xda = open_rasterio(
memfile.name,
parse_coordinates=parse_coordinates,
mask_and_scale=mask_and_scale,
).load()
xda.coords.update(
{
coord: value
for coord, value in _get_nonspatial_coords(representative_array).items()
if coord not in xda.coords
}
)
else:
coords = _get_nonspatial_coords(representative_array)

# make sure the output merged data shape is 2D if the
# original data was 2D. this can happen if the
# xarray datasarray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze()

xda = DataArray(
name=representative_array.name,
data=merged_data,
coords=coords,
dims=tuple(representative_array.dims),
attrs=representative_array.attrs,
)
xda.rio.write_nodata(
nodata if nodata is not None else representative_array.rio.nodata, inplace=True
)
xda.rio.write_crs(representative_array.rio.crs, inplace=True)
xda.rio.write_transform(merged_transform, inplace=True)
return xda


Expand Down
15 changes: 11 additions & 4 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def test_merge_arrays(squeeze):
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.attrs == {
"AREA_OR_POINT": "Area",
"scale_factor": 1.0,
"add_offset": 0.0,
**rds.attrs,
}
assert merged.encoding["grid_mapping"] == "spatial_ref"


Expand Down Expand Up @@ -106,6 +111,7 @@ def test_merge__different_crs(dataset):
assert merged.rio.crs == rds.rio.crs
if not dataset:
assert merged.attrs == {
"AREA_OR_POINT": "Area",
"_FillValue": -28672,
"add_offset": 0.0,
"scale_factor": 1.0,
Expand Down Expand Up @@ -151,9 +157,10 @@ def test_merge_arrays__res():
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert_almost_equal(merged.attrs.pop("_FillValue"), rds.attrs.pop("_FillValue"))
compare_attrs = dict(rds.attrs)
assert merged.attrs == compare_attrs
assert_almost_equal(
merged.encoding.pop("_FillValue"), rds.attrs.pop("_FillValue")
)
assert merged.attrs == {"AREA_OR_POINT": "Area", **rds.attrs}
assert merged.encoding["grid_mapping"] == "spatial_ref"
assert_almost_equal(nansum(merged), 13760565)

Expand Down

0 comments on commit cd0021d

Please sign in to comment.