From 3cd2337d8035a324cb38d6793eaf33818066f25c Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 23 Jan 2018 11:33:27 -0500 Subject: [PATCH] fix rasterio chunking with s3 datasets (#1817) * fixes #1816 * new and refactored rasterio tests --- doc/whats-new.rst | 3 + xarray/backends/rasterio_.py | 6 +- xarray/tests/test_backends.py | 291 +++++++++------------------------- 3 files changed, 87 insertions(+), 213 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0c6a590980f..d4822e3d675 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -75,6 +75,9 @@ Bug fixes :py:meth:`~Dataset.to_netcdf` (:issue:`1763`). By `Mike Neish `_. +- Fixed chunking with non-file-based rasterio datasets (:issue:`1816`) and + refactored rasterio test suite. + By `Ryan Abernathey `_ - Bug fix in open_dataset(engine='pydap') (:issue:`1775`) By `Keisuke Fujii `_. diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index d94bf018857..3acc5551173 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -222,7 +222,11 @@ def open_rasterio(filename, chunks=None, cache=None, lock=None): if chunks is not None: from dask.base import tokenize # augment the token with the file modification time - mtime = os.path.getmtime(filename) + try: + mtime = os.path.getmtime(filename) + except OSError: + # the filename is probably an s3 bucket rather than a regular file + mtime = None token = tokenize(filename, mtime, chunks) name_prefix = 'open_rasterio-%s' % token if lock is None: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 65bf0b5e8d8..013131a1e7b 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2149,30 +2149,52 @@ class TestPyNioAutocloseTrue(TestPyNio): autoclose = True +@requires_rasterio +@contextlib.contextmanager +def create_tmp_geotiff(nx=4, ny=3, nz=3, + transform_args=[5000, 80000, 1000, 2000.], + crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', + 'proj': 'utm', 'zone': 18}): + # yields a temporary geotiff file and a corresponding expected DataArray + import rasterio + from rasterio.transform import from_origin + with create_tmp_file(suffix='.tif') as tmp_file: + # allow 2d or 3d shapes + if nz == 1: + data_shape = ny, nx + write_kwargs = {'indexes': 1} + else: + data_shape = nz, ny, nx + write_kwargs = {} + data = np.arange(nz*ny*nx, + dtype=rasterio.float32).reshape(*data_shape) + transform = from_origin(*transform_args) + with rasterio.open( + tmp_file, 'w', + driver='GTiff', height=ny, width=nx, count=nz, + crs=crs, + transform=transform, + dtype=rasterio.float32) as s: + s.write(data, **write_kwargs) + dx, dy = s.res[0], -s.res[1] + + a, b, c, d = transform_args + data = data[np.newaxis, ...] if nz == 1 else data + expected = DataArray(data, dims=('band', 'y', 'x'), + coords={ + 'band': np.arange(nz)+1, + 'y': -np.arange(ny) * d + b + dy/2, + 'x': np.arange(nx) * c + a + dx/2, + }) + yield tmp_file, expected + + @requires_rasterio class TestRasterio(TestCase): @requires_scipy_or_netCDF4 def test_serialization(self): - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in utm proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx * ny * nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(5000, 80000, 1000, 2000.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', - 'proj': 'utm', 'zone': 18}, - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - + with create_tmp_geotiff() as (tmp_file, expected): # Write it to a netcdf and read again (roundtrip) with xr.open_rasterio(tmp_file) as rioda: with create_tmp_file(suffix='.nc') as tmp_nc_file: @@ -2180,95 +2202,8 @@ def test_serialization(self): with xr.open_dataarray(tmp_nc_file) as ncds: assert_identical(rioda, ncds) - @requires_scipy_or_netCDF4 - def test_nodata(self): - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in utm proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx*ny*nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(5000, 80000, 1000, 2000.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', - 'proj': 'utm', 'zone': 18}, - transform=transform, - nodata=-9765, - dtype=rasterio.float32) as s: - s.write(data) - expected_nodatavals = [-9765, -9765, -9765] - with xr.open_rasterio(tmp_file) as rioda: - np.testing.assert_array_equal(rioda.attrs['nodatavals'], - expected_nodatavals) - with create_tmp_file(suffix='.nc') as tmp_nc_file: - rioda.to_netcdf(tmp_nc_file) - with xr.open_dataarray(tmp_nc_file) as ncds: - np.testing.assert_array_equal(ncds.attrs['nodatavals'], - expected_nodatavals) - - @requires_scipy_or_netCDF4 - def test_nodata_missing(self): - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in utm proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx*ny*nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(5000, 80000, 1000, 2000.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', - 'proj': 'utm', 'zone': 18}, - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - - expected_nodatavals = [np.nan, np.nan, np.nan] - with xr.open_rasterio(tmp_file) as rioda: - np.testing.assert_array_equal(rioda.attrs['nodatavals'], - expected_nodatavals) - with create_tmp_file(suffix='.nc') as tmp_nc_file: - rioda.to_netcdf(tmp_nc_file) - with xr.open_dataarray(tmp_nc_file) as ncds: - np.testing.assert_array_equal(ncds.attrs['nodatavals'], - expected_nodatavals) - def test_utm(self): - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in utm proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx * ny * nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(5000, 80000, 1000, 2000.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84', - 'proj': 'utm', 'zone': 18}, - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - dx, dy = s.res[0], -s.res[1] - - # Tests - expected = DataArray(data, dims=('band', 'y', 'x'), coords={ - 'band': [1, 2, 3], - 'y': -np.arange(ny) * 2000 + 80000 + dy / 2, - 'x': np.arange(nx) * 1000 + 5000 + dx / 2, - }) + with create_tmp_geotiff() as (tmp_file, expected): with xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) assert 'crs' in rioda.attrs @@ -2281,32 +2216,9 @@ def test_utm(self): assert isinstance(rioda.attrs['transform'], tuple) def test_platecarree(self): - - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in latlong proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny = 8, 10 - data = np.arange(80, dtype=rasterio.float32).reshape(ny, nx) - transform = from_origin(1, 2, 0.5, 2.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=1, - crs='+proj=latlong', - transform=transform, - dtype=rasterio.float32) as s: - s.write(data, indexes=1) - dx, dy = s.res[0], -s.res[1] - - # Tests - expected = DataArray(data[np.newaxis, ...], - dims=('band', 'y', 'x'), - coords={'band': [1], - 'y': -np.arange(ny) * 2 + 2 + dy / 2, - 'x': np.arange(nx) * 0.5 + 1 + dx / 2, - }) + with create_tmp_geotiff(8, 10, 1, transform_args=[1, 2, 0.5, 2.], + crs='+proj=latlong') \ + as (tmp_file, expected): with xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) assert 'crs' in rioda.attrs @@ -2319,32 +2231,8 @@ def test_platecarree(self): assert isinstance(rioda.attrs['transform'], tuple) def test_indexing(self): - - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in latlong proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 8, 10, 3 - data = np.arange(nx * ny * nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(1, 2, 0.5, 2.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs='+proj=latlong', - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - dx, dy = s.res[0], -s.res[1] - - # ref - expected = DataArray(data, dims=('band', 'y', 'x'), coords={ - 'x': (np.arange(nx) * 0.5 + 1) + dx / 2, - 'y': (-np.arange(ny) * 2 + 2) + dy / 2, - 'band': [1, 2, 3]}) - + with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.], + crs='+proj=latlong') as (tmp_file, expected): with xr.open_rasterio(tmp_file, cache=False) as actual: # tests @@ -2411,33 +2299,8 @@ def test_indexing(self): assert_allclose(ac, ex) def test_caching(self): - - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in latlong proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 8, 10, 3 - data = np.arange(nx * ny * nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(1, 2, 0.5, 2.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs='+proj=latlong', - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - dx, dy = s.res[0], -s.res[1] - - # ref - expected = DataArray( - data, dims=('band', 'y', 'x'), coords={ - 'x': (np.arange(nx) * 0.5 + 1) + dx / 2, - 'y': (-np.arange(ny) * 2 + 2) + dy / 2, - 'band': [1, 2, 3]}) - + with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.], + crs='+proj=latlong') as (tmp_file, expected): # Cache is the default with xr.open_rasterio(tmp_file) as actual: @@ -2456,26 +2319,8 @@ def test_caching(self): @requires_dask def test_chunks(self): - - import rasterio - from rasterio.transform import from_origin - - # Create a geotiff file in latlong proj - with create_tmp_file(suffix='.tif') as tmp_file: - # data - nx, ny, nz = 8, 10, 3 - data = np.arange(nx * ny * nz, - dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(1, 2, 0.5, 2.) - with rasterio.open( - tmp_file, 'w', - driver='GTiff', height=ny, width=nx, count=nz, - crs='+proj=latlong', - transform=transform, - dtype=rasterio.float32) as s: - s.write(data) - dx, dy = s.res[0], -s.res[1] - + with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.], + crs='+proj=latlong') as (tmp_file, expected): # Chunk at open time with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: @@ -2483,12 +2328,6 @@ def test_chunks(self): self.assertIsInstance(actual.data, da.Array) assert 'open_rasterio' in actual.data.name - # ref - expected = DataArray(data, dims=('band', 'y', 'x'), coords={ - 'x': np.arange(nx) * 0.5 + 1 + dx / 2, - 'y': -np.arange(ny) * 2 + 2 + dy / 2, - 'band': [1, 2, 3]}) - # do some arithmetic ac = actual.mean() ex = expected.mean() @@ -2503,6 +2342,7 @@ def test_ENVI_tags(self): from rasterio.transform import from_origin # Create an ENVI file with some tags in the ENVI namespace + # this test uses a custom driver, so we can't use create_tmp_geotiff with create_tmp_file(suffix='.dat') as tmp_file: # data nx, ny, nz = 4, 3, 3 @@ -2545,6 +2385,33 @@ def test_ENVI_tags(self): assert isinstance(rioda.attrs['map_info'], basestring) assert isinstance(rioda.attrs['samples'], basestring) + def test_no_mftime(self): + # rasterio can accept "filename" urguments that are actually urls, + # including paths to remote files. + # In issue #1816, we found that these caused dask to break, because + # the modification time was used to determine the dask token. This + # tests ensure we can still chunk such files when reading with + # rasterio. + with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.], + crs='+proj=latlong') as (tmp_file, expected): + with mock.patch('os.path.getmtime', side_effect=OSError): + with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: + import dask.array as da + self.assertIsInstance(actual.data, da.Array) + assert_allclose(actual, expected) + + @network + def test_http_url(self): + # more examples urls here + # http://download.osgeo.org/geotiff/samples/ + url = 'http://download.osgeo.org/geotiff/samples/made_up/ntf_nord.tif' + with xr.open_rasterio(url) as actual: + assert actual.shape == (1, 512, 512) + # make sure chunking works + with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual: + import dask.array as da + self.assertIsInstance(actual.data, da.Array) + class TestEncodingInvalid(TestCase):