Skip to content

Commit

Permalink
fix rasterio chunking with s3 datasets (#1817)
Browse files Browse the repository at this point in the history
* fixes #1816

* new and refactored rasterio tests
  • Loading branch information
rabernat authored and shoyer committed Jan 23, 2018
1 parent e31cf43 commit 3cd2337
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 213 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ Bug fixes
:py:meth:`~Dataset.to_netcdf` (:issue:`1763`).
By `Mike Neish <https://github.com/neishm>`_.

- Fixed chunking with non-file-based rasterio datasets (:issue:`1816`) and
refactored rasterio test suite.
By `Ryan Abernathey <https://github.com/rabernat>`_
- Bug fix in open_dataset(engine='pydap') (:issue:`1775`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

Expand Down
6 changes: 5 additions & 1 deletion xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ def open_rasterio(filename, chunks=None, cache=None, lock=None):
if chunks is not None:
from dask.base import tokenize
# augment the token with the file modification time
mtime = os.path.getmtime(filename)
try:
mtime = os.path.getmtime(filename)
except OSError:
# the filename is probably an s3 bucket rather than a regular file
mtime = None
token = tokenize(filename, mtime, chunks)
name_prefix = 'open_rasterio-%s' % token
if lock is None:
Expand Down
291 changes: 79 additions & 212 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2149,126 +2149,61 @@ class TestPyNioAutocloseTrue(TestPyNio):
autoclose = True


@requires_rasterio
@contextlib.contextmanager
def create_tmp_geotiff(nx=4, ny=3, nz=3,
transform_args=[5000, 80000, 1000, 2000.],
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
'proj': 'utm', 'zone': 18}):
# yields a temporary geotiff file and a corresponding expected DataArray
import rasterio
from rasterio.transform import from_origin
with create_tmp_file(suffix='.tif') as tmp_file:
# allow 2d or 3d shapes
if nz == 1:
data_shape = ny, nx
write_kwargs = {'indexes': 1}
else:
data_shape = nz, ny, nx
write_kwargs = {}
data = np.arange(nz*ny*nx,
dtype=rasterio.float32).reshape(*data_shape)
transform = from_origin(*transform_args)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs=crs,
transform=transform,
dtype=rasterio.float32) as s:
s.write(data, **write_kwargs)
dx, dy = s.res[0], -s.res[1]

a, b, c, d = transform_args
data = data[np.newaxis, ...] if nz == 1 else data
expected = DataArray(data, dims=('band', 'y', 'x'),
coords={
'band': np.arange(nz)+1,
'y': -np.arange(ny) * d + b + dy/2,
'x': np.arange(nx) * c + a + dx/2,
})
yield tmp_file, expected


@requires_rasterio
class TestRasterio(TestCase):

@requires_scipy_or_netCDF4
def test_serialization(self):
import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in utm proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 4, 3, 3
data = np.arange(nx * ny * nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(5000, 80000, 1000, 2000.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
'proj': 'utm', 'zone': 18},
transform=transform,
dtype=rasterio.float32) as s:
s.write(data)

with create_tmp_geotiff() as (tmp_file, expected):
# Write it to a netcdf and read again (roundtrip)
with xr.open_rasterio(tmp_file) as rioda:
with create_tmp_file(suffix='.nc') as tmp_nc_file:
rioda.to_netcdf(tmp_nc_file)
with xr.open_dataarray(tmp_nc_file) as ncds:
assert_identical(rioda, ncds)

@requires_scipy_or_netCDF4
def test_nodata(self):
import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in utm proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 4, 3, 3
data = np.arange(nx*ny*nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(5000, 80000, 1000, 2000.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
'proj': 'utm', 'zone': 18},
transform=transform,
nodata=-9765,
dtype=rasterio.float32) as s:
s.write(data)
expected_nodatavals = [-9765, -9765, -9765]
with xr.open_rasterio(tmp_file) as rioda:
np.testing.assert_array_equal(rioda.attrs['nodatavals'],
expected_nodatavals)
with create_tmp_file(suffix='.nc') as tmp_nc_file:
rioda.to_netcdf(tmp_nc_file)
with xr.open_dataarray(tmp_nc_file) as ncds:
np.testing.assert_array_equal(ncds.attrs['nodatavals'],
expected_nodatavals)

@requires_scipy_or_netCDF4
def test_nodata_missing(self):
import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in utm proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 4, 3, 3
data = np.arange(nx*ny*nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(5000, 80000, 1000, 2000.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
'proj': 'utm', 'zone': 18},
transform=transform,
dtype=rasterio.float32) as s:
s.write(data)

expected_nodatavals = [np.nan, np.nan, np.nan]
with xr.open_rasterio(tmp_file) as rioda:
np.testing.assert_array_equal(rioda.attrs['nodatavals'],
expected_nodatavals)
with create_tmp_file(suffix='.nc') as tmp_nc_file:
rioda.to_netcdf(tmp_nc_file)
with xr.open_dataarray(tmp_nc_file) as ncds:
np.testing.assert_array_equal(ncds.attrs['nodatavals'],
expected_nodatavals)

def test_utm(self):
import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in utm proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 4, 3, 3
data = np.arange(nx * ny * nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(5000, 80000, 1000, 2000.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
'proj': 'utm', 'zone': 18},
transform=transform,
dtype=rasterio.float32) as s:
s.write(data)
dx, dy = s.res[0], -s.res[1]

# Tests
expected = DataArray(data, dims=('band', 'y', 'x'), coords={
'band': [1, 2, 3],
'y': -np.arange(ny) * 2000 + 80000 + dy / 2,
'x': np.arange(nx) * 1000 + 5000 + dx / 2,
})
with create_tmp_geotiff() as (tmp_file, expected):
with xr.open_rasterio(tmp_file) as rioda:
assert_allclose(rioda, expected)
assert 'crs' in rioda.attrs
Expand All @@ -2281,32 +2216,9 @@ def test_utm(self):
assert isinstance(rioda.attrs['transform'], tuple)

def test_platecarree(self):

import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in latlong proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny = 8, 10
data = np.arange(80, dtype=rasterio.float32).reshape(ny, nx)
transform = from_origin(1, 2, 0.5, 2.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=1,
crs='+proj=latlong',
transform=transform,
dtype=rasterio.float32) as s:
s.write(data, indexes=1)
dx, dy = s.res[0], -s.res[1]

# Tests
expected = DataArray(data[np.newaxis, ...],
dims=('band', 'y', 'x'),
coords={'band': [1],
'y': -np.arange(ny) * 2 + 2 + dy / 2,
'x': np.arange(nx) * 0.5 + 1 + dx / 2,
})
with create_tmp_geotiff(8, 10, 1, transform_args=[1, 2, 0.5, 2.],
crs='+proj=latlong') \
as (tmp_file, expected):
with xr.open_rasterio(tmp_file) as rioda:
assert_allclose(rioda, expected)
assert 'crs' in rioda.attrs
Expand All @@ -2319,32 +2231,8 @@ def test_platecarree(self):
assert isinstance(rioda.attrs['transform'], tuple)

def test_indexing(self):

import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in latlong proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 8, 10, 3
data = np.arange(nx * ny * nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(1, 2, 0.5, 2.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs='+proj=latlong',
transform=transform,
dtype=rasterio.float32) as s:
s.write(data)
dx, dy = s.res[0], -s.res[1]

# ref
expected = DataArray(data, dims=('band', 'y', 'x'), coords={
'x': (np.arange(nx) * 0.5 + 1) + dx / 2,
'y': (-np.arange(ny) * 2 + 2) + dy / 2,
'band': [1, 2, 3]})

with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
crs='+proj=latlong') as (tmp_file, expected):
with xr.open_rasterio(tmp_file, cache=False) as actual:

# tests
Expand Down Expand Up @@ -2411,33 +2299,8 @@ def test_indexing(self):
assert_allclose(ac, ex)

def test_caching(self):

import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in latlong proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 8, 10, 3
data = np.arange(nx * ny * nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(1, 2, 0.5, 2.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs='+proj=latlong',
transform=transform,
dtype=rasterio.float32) as s:
s.write(data)
dx, dy = s.res[0], -s.res[1]

# ref
expected = DataArray(
data, dims=('band', 'y', 'x'), coords={
'x': (np.arange(nx) * 0.5 + 1) + dx / 2,
'y': (-np.arange(ny) * 2 + 2) + dy / 2,
'band': [1, 2, 3]})

with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
crs='+proj=latlong') as (tmp_file, expected):
# Cache is the default
with xr.open_rasterio(tmp_file) as actual:

Expand All @@ -2456,39 +2319,15 @@ def test_caching(self):

@requires_dask
def test_chunks(self):

import rasterio
from rasterio.transform import from_origin

# Create a geotiff file in latlong proj
with create_tmp_file(suffix='.tif') as tmp_file:
# data
nx, ny, nz = 8, 10, 3
data = np.arange(nx * ny * nz,
dtype=rasterio.float32).reshape(nz, ny, nx)
transform = from_origin(1, 2, 0.5, 2.)
with rasterio.open(
tmp_file, 'w',
driver='GTiff', height=ny, width=nx, count=nz,
crs='+proj=latlong',
transform=transform,
dtype=rasterio.float32) as s:
s.write(data)
dx, dy = s.res[0], -s.res[1]

with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
crs='+proj=latlong') as (tmp_file, expected):
# Chunk at open time
with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual:

import dask.array as da
self.assertIsInstance(actual.data, da.Array)
assert 'open_rasterio' in actual.data.name

# ref
expected = DataArray(data, dims=('band', 'y', 'x'), coords={
'x': np.arange(nx) * 0.5 + 1 + dx / 2,
'y': -np.arange(ny) * 2 + 2 + dy / 2,
'band': [1, 2, 3]})

# do some arithmetic
ac = actual.mean()
ex = expected.mean()
Expand All @@ -2503,6 +2342,7 @@ def test_ENVI_tags(self):
from rasterio.transform import from_origin

# Create an ENVI file with some tags in the ENVI namespace
# this test uses a custom driver, so we can't use create_tmp_geotiff
with create_tmp_file(suffix='.dat') as tmp_file:
# data
nx, ny, nz = 4, 3, 3
Expand Down Expand Up @@ -2545,6 +2385,33 @@ def test_ENVI_tags(self):
assert isinstance(rioda.attrs['map_info'], basestring)
assert isinstance(rioda.attrs['samples'], basestring)

def test_no_mftime(self):
# rasterio can accept "filename" urguments that are actually urls,
# including paths to remote files.
# In issue #1816, we found that these caused dask to break, because
# the modification time was used to determine the dask token. This
# tests ensure we can still chunk such files when reading with
# rasterio.
with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
crs='+proj=latlong') as (tmp_file, expected):
with mock.patch('os.path.getmtime', side_effect=OSError):
with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual:
import dask.array as da
self.assertIsInstance(actual.data, da.Array)
assert_allclose(actual, expected)

@network
def test_http_url(self):
# more examples urls here
# http://download.osgeo.org/geotiff/samples/
url = 'http://download.osgeo.org/geotiff/samples/made_up/ntf_nord.tif'
with xr.open_rasterio(url) as actual:
assert actual.shape == (1, 512, 512)
# make sure chunking works
with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual:
import dask.array as da
self.assertIsInstance(actual.data, da.Array)


class TestEncodingInvalid(TestCase):

Expand Down

0 comments on commit 3cd2337

Please sign in to comment.