Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fill_value for concat and auto_combine #2964

Merged
merged 5 commits into from
May 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd

from . import utils
from . import utils, dtypes
from .alignment import align
from .merge import merge
from .variable import IndexVariable, Variable, as_variable
Expand All @@ -14,7 +14,7 @@

def concat(objs, dim=None, data_vars='all', coords='different',
compat='equals', positions=None, indexers=None, mode=None,
concat_over=None):
concat_over=None, fill_value=dtypes.NA):
"""Concatenate xarray objects along a new or existing dimension.

Parameters
Expand Down Expand Up @@ -66,6 +66,8 @@ def concat(objs, dim=None, data_vars='all', coords='different',
List of integer arrays which specifies the integer positions to which
to assign each dataset along the concatenated dimension. If not
supplied, objects are concatenated in the provided order.
fill_value : scalar, optional
Value to use for newly missing values
indexers, mode, concat_over : deprecated

Returns
Expand Down Expand Up @@ -117,7 +119,7 @@ def concat(objs, dim=None, data_vars='all', coords='different',
else:
raise TypeError('can only concatenate xarray Dataset and DataArray '
'objects, got %s' % type(first_obj))
return f(objs, dim, data_vars, coords, compat, positions)
return f(objs, dim, data_vars, coords, compat, positions, fill_value)


def _calc_concat_dim_coord(dim):
Expand Down Expand Up @@ -212,7 +214,8 @@ def process_subset_opt(opt, subset):
return concat_over, equals


def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
def _dataset_concat(datasets, dim, data_vars, coords, compat, positions,
fill_value=dtypes.NA):
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
Expand All @@ -225,7 +228,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
dim, coord = _calc_concat_dim_coord(dim)
# Make sure we're working on a copy (we'll be loading variables)
datasets = [ds.copy() for ds in datasets]
datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
datasets = align(*datasets, join='outer', copy=False, exclude=[dim],
fill_value=fill_value)

concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)

Expand Down Expand Up @@ -317,7 +321,7 @@ def ensure_common_dims(vars):


def _dataarray_concat(arrays, dim, data_vars, coords, compat,
positions):
positions, fill_value=dtypes.NA):
arrays = list(arrays)

if data_vars != 'all':
Expand All @@ -336,14 +340,15 @@ def _dataarray_concat(arrays, dim, data_vars, coords, compat,
datasets.append(arr._to_temp_dataset())

ds = _dataset_concat(datasets, dim, data_vars, coords, compat,
positions)
positions, fill_value)
result = arrays[0]._from_temp_dataset(ds, name)

result.name = result_name(arrays)
return result


def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
def _auto_concat(datasets, dim=None, data_vars='all', coords='different',
fill_value=dtypes.NA):
if len(datasets) == 1 and dim is None:
# There is nothing more to combine, so kick out early.
return datasets[0]
Expand All @@ -366,7 +371,8 @@ def _auto_concat(datasets, dim=None, data_vars='all', coords='different'):
'supply the ``concat_dim`` argument '
'explicitly')
dim, = concat_dims
return concat(datasets, dim=dim, data_vars=data_vars, coords=coords)
return concat(datasets, dim=dim, data_vars=data_vars,
coords=coords, fill_value=fill_value)


_CONCAT_DIM_DEFAULT = utils.ReprObject('<inferred>')
Expand Down Expand Up @@ -442,7 +448,8 @@ def _check_shape_tile_ids(combined_tile_ids):


def _combine_nd(combined_ids, concat_dims, data_vars='all',
coords='different', compat='no_conflicts'):
coords='different', compat='no_conflicts',
fill_value=dtypes.NA):
"""
Concatenates and merges an N-dimensional structure of datasets.

Expand Down Expand Up @@ -472,13 +479,14 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all',
dim=concat_dim,
data_vars=data_vars,
coords=coords,
compat=compat)
compat=compat,
fill_value=fill_value)
combined_ds = list(combined_ids.values())[0]
return combined_ds


def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
coords, compat):
coords, compat, fill_value=dtypes.NA):
# Group into lines of datasets which must be combined along dim
# need to sort by _new_tile_id first for groupby to work
# TODO remove all these sorted OrderedDicts once python >= 3.6 only
Expand All @@ -490,7 +498,8 @@ def _auto_combine_all_along_first_dim(combined_ids, dim, data_vars,
combined_ids = OrderedDict(sorted(group))
datasets = combined_ids.values()
new_combined_ids[new_id] = _auto_combine_1d(datasets, dim, compat,
data_vars, coords)
data_vars, coords,
fill_value)
return new_combined_ids


Expand All @@ -500,18 +509,20 @@ def vars_as_keys(ds):

def _auto_combine_1d(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
compat='no_conflicts',
data_vars='all', coords='different'):
data_vars='all', coords='different',
fill_value=dtypes.NA):
# This is just the old auto_combine function (which only worked along 1D)
if concat_dim is not None:
dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim
sorted_datasets = sorted(datasets, key=vars_as_keys)
grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys)
concatenated = [_auto_concat(list(ds_group), dim=dim,
data_vars=data_vars, coords=coords)
data_vars=data_vars, coords=coords,
fill_value=fill_value)
for id, ds_group in grouped_by_vars]
else:
concatenated = datasets
merged = merge(concatenated, compat=compat)
merged = merge(concatenated, compat=compat, fill_value=fill_value)
return merged


Expand All @@ -521,7 +532,7 @@ def _new_tile_id(single_id_ds_pair):


def _auto_combine(datasets, concat_dims, compat, data_vars, coords,
infer_order_from_coords, ids):
infer_order_from_coords, ids, fill_value=dtypes.NA):
"""
Calls logic to decide concatenation order before concatenating.
"""
Expand Down Expand Up @@ -550,12 +561,14 @@ def _auto_combine(datasets, concat_dims, compat, data_vars, coords,

# Repeatedly concatenate then merge along each dimension
combined = _combine_nd(combined_ids, concat_dims, compat=compat,
data_vars=data_vars, coords=coords)
data_vars=data_vars, coords=coords,
fill_value=fill_value)
return combined


def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
compat='no_conflicts', data_vars='all', coords='different'):
compat='no_conflicts', data_vars='all', coords='different',
fill_value=dtypes.NA):
"""Attempt to auto-magically combine the given datasets into one.
This method attempts to combine a list of datasets into a single entity by
inspecting metadata and using a combination of concat and merge.
Expand Down Expand Up @@ -596,6 +609,8 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
Details are in the documentation of concat
coords : {'minimal', 'different', 'all' or list of str}, optional
Details are in the documentation of concat
fill_value : scalar, optional
Value to use for newly missing values

Returns
-------
Expand All @@ -622,4 +637,4 @@ def auto_combine(datasets, concat_dim=_CONCAT_DIM_DEFAULT,
return _auto_combine(datasets, concat_dims=concat_dims, compat=compat,
data_vars=data_vars, coords=coords,
infer_order_from_coords=infer_order_from_coords,
ids=False)
ids=False, fill_value=fill_value)
42 changes: 42 additions & 0 deletions xarray/tests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from xarray import DataArray, Dataset, Variable, auto_combine, concat
from xarray.core import dtypes
from xarray.core.combine import (
_auto_combine, _auto_combine_1d, _auto_combine_all_along_first_dim,
_check_shape_tile_ids, _combine_nd, _infer_concat_order_from_positions,
Expand Down Expand Up @@ -237,6 +238,20 @@ def test_concat_multiindex(self):
assert expected.equals(actual)
assert isinstance(actual.x.to_index(), pd.MultiIndex)

@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_concat_fill_value(self, fill_value):
    """Concatenating partially-overlapping Datasets honors ``fill_value``."""
    first = Dataset({'a': ('x', [2, 3]), 'x': [1, 2]})
    second = Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
    if fill_value == dtypes.NA:
        # The default sentinel should materialize as the missing value
        # for a float array.
        fill_value = np.nan
    padded = [[fill_value, 2, 3], [1, 2, fill_value]]
    expected = Dataset({'a': (('t', 'x'), padded)}, {'x': [0, 1, 2]})
    actual = concat([first, second], dim='t', fill_value=fill_value)
    assert_identical(actual, expected)


class TestConcatDataArray:
def test_concat(self):
Expand Down Expand Up @@ -306,6 +321,19 @@ def test_concat_lazy(self):
assert combined.shape == (2, 3, 3)
assert combined.dims == ('z', 'x', 'y')

@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_concat_fill_value(self, fill_value):
    """concat of misaligned DataArrays fills the gaps with ``fill_value``."""
    left = DataArray([1, 2], coords=[('x', [1, 2])])
    right = DataArray([1, 2], coords=[('x', [1, 3])])
    if fill_value == dtypes.NA:
        # Supplying the default sentinel yields the missing value for a
        # float array.
        fill_value = np.nan
    values = [[1, 2, fill_value], [1, fill_value, 2]]
    expected = DataArray(values, dims=['y', 'x'], coords={'x': [1, 2, 3]})
    actual = concat((left, right), dim='y', fill_value=fill_value)
    assert_identical(actual, expected)


class TestAutoCombine:

Expand Down Expand Up @@ -417,6 +445,20 @@ def test_auto_combine_no_concat(self):
{'baz': [100]})
assert_identical(expected, actual)

@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_auto_combine_fill_value(self, fill_value):
    """auto_combine forwards ``fill_value`` when aligning its inputs."""
    inputs = [Dataset({'a': ('x', [2, 3]), 'x': [1, 2]}),
              Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})]
    if fill_value == dtypes.NA:
        # The default sentinel should materialize as the missing value
        # for a float array.
        fill_value = np.nan
    padded = [[fill_value, 2, 3], [1, 2, fill_value]]
    expected = Dataset({'a': (('t', 'x'), padded)}, {'x': [0, 1, 2]})
    actual = auto_combine(inputs, concat_dim='t', fill_value=fill_value)
    assert_identical(expected, actual)


def assert_combined_tile_ids_equal(dict1, dict2):
assert len(dict1) == len(dict2)
Expand Down