Skip to content

Commit

Permalink
Fix maybe_promote (#1953)
Browse files Browse the repository at this point in the history
* Fix maybe_promote

With tests for every possible dtype:

(numpy docs say `biufcmMOSUV` only)

```
for letter in string.ascii_letters:
    try:
        print(letter, np.dtype(letter))
    except TypeError as exc:
        pass
```

* Check issubdtype of floating before timedelta64

In order to hit this branch more often

* Improve maybe_promote test
  • Loading branch information
NotSqrt authored and shoyer committed Aug 20, 2018
1 parent 8378d3a commit 69086b3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
7 changes: 5 additions & 2 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def maybe_promote(dtype):
# N.B. these casting rules should match pandas
if np.issubdtype(dtype, np.floating):
fill_value = np.nan
elif np.issubdtype(dtype, np.timedelta64):
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64('NaT')
elif np.issubdtype(dtype, np.integer):
if dtype.itemsize <= 2:
dtype = np.float32
Expand All @@ -90,8 +95,6 @@ def maybe_promote(dtype):
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
fill_value = np.datetime64('NaT')
elif np.issubdtype(dtype, np.timedelta64):
fill_value = np.timedelta64('NaT')
else:
dtype = object
fill_value = np.nan
Expand Down
36 changes: 36 additions & 0 deletions xarray/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,39 @@ def error():
def test_inf(obj):
assert dtypes.INF > obj
assert dtypes.NINF < obj


@pytest.mark.parametrize("kind, expected", [
('a', (np.dtype('O'), 'nan')), # dtype('S')
('b', (np.float32, 'nan')), # dtype('int8')
('B', (np.float32, 'nan')), # dtype('uint8')
('c', (np.dtype('O'), 'nan')), # dtype('S1')
('D', (np.complex128, '(nan+nanj)')), # dtype('complex128')
('d', (np.float64, 'nan')), # dtype('float64')
('e', (np.float16, 'nan')), # dtype('float16')
('F', (np.complex64, '(nan+nanj)')), # dtype('complex64')
('f', (np.float32, 'nan')), # dtype('float32')
('h', (np.float32, 'nan')), # dtype('int16')
('H', (np.float32, 'nan')), # dtype('uint16')
('i', (np.float64, 'nan')), # dtype('int32')
('I', (np.float64, 'nan')), # dtype('uint32')
('l', (np.float64, 'nan')), # dtype('int64')
('L', (np.float64, 'nan')), # dtype('uint64')
('m', (np.timedelta64, 'NaT')), # dtype('<m8')
('M', (np.datetime64, 'NaT')), # dtype('<M8')
('O', (np.dtype('O'), 'nan')), # dtype('O')
('p', (np.float64, 'nan')), # dtype('int64')
('P', (np.float64, 'nan')), # dtype('uint64')
('q', (np.float64, 'nan')), # dtype('int64')
('Q', (np.float64, 'nan')), # dtype('uint64')
('S', (np.dtype('O'), 'nan')), # dtype('S')
('U', (np.dtype('O'), 'nan')), # dtype('<U')
('V', (np.dtype('O'), 'nan')), # dtype('V')
])
def test_maybe_promote(kind, expected):
# 'g': np.float128 is not tested : not available on all platforms
# 'G': np.complex256 is not tested : not available on all platforms

actual = dtypes.maybe_promote(np.dtype(kind))
assert actual[0] == expected[0]
assert str(actual[1]) == expected[1]

0 comments on commit 69086b3

Please sign in to comment.