Skip to content

Commit

Permalink
Add test for OverrideBandDataSource fix
Browse files Browse the repository at this point in the history
Closes #224
  • Loading branch information
andrewdhicks authored and omad committed May 12, 2017
1 parent 58a8fcf commit f2e201f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
14 changes: 14 additions & 0 deletions datacube/storage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def copyto_fuser(dest, src):


class BandDataSource(object):
"""Wrapper for a rasterio.Band object
:param source: rasterio.Band
"""
def __init__(self, source, nodata=None):
self.source = source
if nodata is None:
Expand All @@ -234,6 +238,8 @@ def shape(self):
return self.source.shape

def read(self, window=None, out_shape=None):
"""Read data in the native format, returning a native array
"""
return self.source.ds.read(indexes=self.source.bidx, window=window, out_shape=out_shape)

def reproject(self, dest, dst_transform, dst_crs, dst_nodata, resampling, **kwargs):
Expand Down Expand Up @@ -312,6 +318,12 @@ def reproject(self, dest, dst_transform, dst_crs, dst_nodata, resampling, **kwar


class OverrideBandDataSource(object):
"""Wrapper for a rasterio.Band object that overrides nodata, crs and transform
This is useful for files with malformed or missing properties
:param source: rasterio.Band
"""
def __init__(self, source, nodata, crs, transform):
self.source = source
self.nodata = nodata
Expand All @@ -327,6 +339,8 @@ def shape(self):
return self.source.shape

def read(self, window=None, out_shape=None):
"""Read data in the native format, returning a native array
"""
return self.source.ds.read(indexes=self.source.bidx, window=window, out_shape=out_shape)

def reproject(self, dest, dst_transform, dst_crs, dst_nodata, resampling, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@


@pytest.fixture
def example_gdal_path(request):
def example_gdal_path(data_folder):
"""Return the pathname of a sample geotiff file
Use this fixture by specifiying an argument named 'example_gdal_path' in your
test method.
"""
return str(request.fspath.dirpath('data/sample_tile_151_-29.tif'))
return str(os.path.join(data_folder, 'sample_tile_151_-29.tif'))


@pytest.fixture
def data_folder(request):
def data_folder():
return os.path.join(os.path.split(os.path.realpath(__file__))[0], 'data')


Expand Down
27 changes: 26 additions & 1 deletion tests/storage/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import datacube
from datacube.utils import geometry
from datacube.storage.storage import write_dataset_to_netcdf, reproject_and_fuse, read_from_source, Resampling
from datacube.storage.storage import NetCDFDataSource
from datacube.storage.storage import NetCDFDataSource, OverrideBandDataSource

GEO_PROJ = 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],' \
'AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433],' \
Expand Down Expand Up @@ -399,3 +399,28 @@ def fake_open():
resampling=Resampling.cubic)

# TODO: crs change


def test_read_raster_with_custom_crs_and_transform(example_gdal_path):
import numpy as np

with rasterio.open(example_gdal_path) as src:
band = rasterio.band(src, 1)
crs = geometry.CRS('EPSG:3577')
nodata = -999
transform = Affine(25.0, 0.0, 1000000.0,
0.0, -25.0, -900000.0)

# Read all raw data from source file
band_data_source = OverrideBandDataSource(band, nodata, crs, transform)
dest1 = band_data_source.read()
assert dest1.shape

# Attempt to read with the same transform parameters
dest2 = np.full(shape=(4000, 4000), fill_value=nodata, dtype=np.float32)
dst_transform = transform
dst_crs = crs
dst_nodata = nodata
resampling = datacube.storage.storage.RESAMPLING_METHODS['nearest']
band_data_source.reproject(dest2, dst_transform, dst_crs, dst_nodata, resampling)
assert (dest1 == dest2).all()

0 comments on commit f2e201f

Please sign in to comment.