Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding overloading functionality for add, sub and neg #173

Merged
merged 7 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions geoutils/georaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import collections
from numbers import Number
import os
import warnings
from typing import Optional, Union
Expand Down Expand Up @@ -348,6 +349,81 @@ def __eq__(self, other) -> bool:
def __ne__(self, other) -> bool:
return not self.__eq__(other)

def __add__(self, other: Union[Raster, np.ndarray, Number]) -> Raster:
"""
Sum up the data of two rasters or a raster and a numpy array, or a raster and single number.
If other is a Raster, it must have the same data.shape, transform and crs as self.
If other is a np.ndarray, it must have the same shape.
Otherwise, other must be a single number.
"""
# Check that other is of correct type
if not isinstance(other, (Raster, np.ndarray, Number)):
raise ValueError("Addition possible only with a Raster, np.ndarray or single number.")

# Case 1 - other is a Raster
if isinstance(other, Raster):
# Check that both data are loaded
if not (self.is_loaded & other.is_loaded):
raise ValueError("Raster's data must be loaded with self.load().")

# Check that both rasters have the same shape and georeferences
if (self.data.shape == other.data.shape) & (self.transform == other.transform) & (self.crs == other.crs):
pass
else:
raise ValueError("Both rasters must have the same shape, transform and CRS.")

other_data = other.data

# Case 2 - other is a numpy array
elif isinstance(other, np.ndarray):
# Check that both array have the same shape
if (self.data.shape == other.shape):
pass
else:
raise ValueError("Both rasters must have the same shape.")

other_data = other

# Case 3 - other is a single number
else:
other_data = other

# Calculate the sum of arrays
data = self.data + other_data

# Save as a new Raster
out_rst = self.from_array(data, self.transform, self.crs, nodata=self.nodata)

return out_rst

def __neg__(self) -> Raster:
"""Return self with self.data set to -self.data"""
out_rst = self.copy()
out_rst.data = -out_rst.data
return out_rst

def __sub__(self, other: Union[Raster, np.ndarray, Number]) -> Raster:
"""
Subtract two rasters. Both rasters must have the same data.shape, transform and crs.
"""
if isinstance(other, Raster):
# Need to convert both rasters to a common type before doing the negation
ctype = np.find_common_type([*self.dtypes, *other.dtypes], [])
other = other.astype(ctype)

return self + -other

def astype(self, dtype: Union[type, str]) -> Raster:
"""
Converts the data type of a Raster object.

:param dtype: Any numpy dtype or string accepted by numpy.astype

:returns: the output Raster with dtype changed.
"""
out_data = self.data.astype(dtype)
return self.from_array(out_data, self.transform, self.crs)

def _get_rio_attrs(self) -> list[str]:
"""Get the attributes that have the same name in rio.DatasetReader and Raster."""
rio_attrs: list[str] = []
Expand Down
62 changes: 62 additions & 0 deletions tests/test_georaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,68 @@ def test_downsampling(self):
# One pixel right and down
assert r.xy2ij(r.bounds.left + r.res[0], r.bounds.top - r.res[1]) == (1, 1)

def test_add_sub(self):
"""
Test addition, subtraction and negation on a Raster object.
"""
# Create fake rasters with random values in 0-255 and dtype uint8
width = height = 5
transform = rio.transform.from_bounds(0, 0, 1, 1, width, height)
r1 = gr.Raster.from_array(np.random.randint(0, 255, (height, width), dtype='uint8'),
transform=transform, crs=None)
r2 = gr.Raster.from_array(np.random.randint(0, 255, (height, width), dtype='uint8'),
transform=transform, crs=None)

# Test negation
r3 = -r1
assert np.all(r3.data == -r1.data)
assert r3.dtypes == ('uint8',)

# Test addition
r3 = r1 + r2
assert np.all(r3.data == r1.data + r2.data)
assert r3.dtypes == ('uint8',)

# Test subtraction
r3 = r1 - r2
assert np.all(r3.data == r1.data - r2.data)
assert r3.dtypes == ('uint8',)

# Test with dtype Float32
r1 = gr.Raster.from_array(np.random.randint(0, 255, (height, width)).astype('float32'),
transform=transform, crs=None)
r3 = -r1
assert np.all(r3.data == -r1.data)
assert r3.dtypes == ('float32',)

r3 = r1 + r2
assert np.all(r3.data == r1.data + r2.data)
assert r3.dtypes == ('float32',)

r3 = r1 - r2
assert np.all(r3.data == r1.data - r2.data)
assert r3.dtypes == ('float32',)

# Check that errors are properly raised
# different shapes
r1 = gr.Raster.from_array(np.random.randint(0, 255, (height + 1, width)).astype('float32'),
transform=transform, crs=None)
pytest.raises(ValueError, r1.__add__, r2)
pytest.raises(ValueError, r1.__sub__, r2)

# different CRS
r1 = gr.Raster.from_array(np.random.randint(0, 255, (height, width)).astype('float32'),
transform=transform, crs=rio.crs.CRS.from_epsg(4326))
pytest.raises(ValueError, r1.__add__, r2)
pytest.raises(ValueError, r1.__sub__, r2)

# different transform
transform2 = rio.transform.from_bounds(0, 0, 2, 2, width, height)
r1 = gr.Raster.from_array(np.random.randint(0, 255, (height, width)).astype('float32'),
transform=transform2, crs=None)
pytest.raises(ValueError, r1.__add__, r2)
pytest.raises(ValueError, r1.__sub__, r2)

def test_copy(self):
"""
Test that the copy method works as expected for Raster. In particular
Expand Down
23 changes: 23 additions & 0 deletions tests/test_satimg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import geoutils
from io import StringIO
import numpy as np
import rasterio as rio

DO_PLOT = False

Expand Down Expand Up @@ -76,6 +77,28 @@ def __exit__(self, *args):
# check nothing outputs to console
assert len(output2) == 0

def test_add_sub(self):
"""
Test that overloading of addition, subtraction and negation works for child classes as well.
"""
# Create fake rasters with random values in 0-255 and dtype uint8
width = height = 5
transform = rio.transform.from_bounds(0, 0, 1, 1, width, height)
satimg1 = si.SatelliteImage.from_array(np.random.randint(0, 255, (height, width), dtype='uint8'),
transform=transform, crs=None)
satimg2 = si.SatelliteImage.from_array(np.random.randint(0, 255, (height, width), dtype='uint8'),
transform=transform, crs=None)

# Check that output type is same - other tests are in test_georaster.py
sat_out = -satimg1
assert isinstance(sat_out, si.SatelliteImage)

sat_out = satimg1 + satimg2
assert isinstance(sat_out, si.SatelliteImage)

sat_out = satimg1 - satimg2
assert isinstance(sat_out, si.SatelliteImage)

def test_copy(self):
"""
Test that the copy method works as expected for SatelliteImage. In particular
Expand Down