Rolling window with as_strided
#1837
Changes from 22 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,13 @@ Documentation | |
|
||
Enhancements | ||
~~~~~~~~~~~~ | ||
- reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype`` | ||
- Improve :py:func:`~xarray.DataArray.rolling` logic to speed it up. | ||
:py:func:`~xarray.DataArrayRolling` objects now support a ``to_dataarray`` | ||
method that returns a view of the DataArray object with the rolling-window | ||
dimension added to the last position. This enables more flexible operations, | ||
such as strided rolling, windowed rolling, ND-rolling, and convolution. | ||
(:issue:`1831`, :issue:`1142`, :issue:`819`) | ||
- reduce methods such as :py:func:`DataArray.sum()` now accept ``dtype`` | ||
arguments. (:issue:`1838`) | ||
By `Keisuke Fujii <https://github.com/fujiisoup>`_. | ||
- Added nodatavals attribute to DataArray when using :py:func:`~xarray.open_rasterio`. (:issue:`1736`). | ||
Add a bug fix note for the aggregations of the last element with |
||
|
@@ -68,6 +74,10 @@ Enhancements | |
|
||
Bug fixes | ||
~~~~~~~~~ | ||
- Rolling aggregation with the ``center=True`` option now gives the same result | ||
as pandas, including for the last element (:issue:`1046`). | ||
By `Keisuke Fujii <https://github.com/fujiisoup>`_. | ||
|
||
- Added warning in api.py of a netCDF4 bug that occurs when | ||
the filepath has 88 characters (:issue:`1745`). | ||
By `Liam Brannigan <https://github.com/braaannigan>`_. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
import warnings | ||
from . import npcompat | ||
|
||
|
||
def _validate_axis(data, axis): | ||
|
@@ -133,3 +134,52 @@ def __setitem__(self, key, value): | |
mixed_positions, vindex_positions = _advanced_indexer_subspaces(key) | ||
self._array[key] = np.moveaxis(value, vindex_positions, | ||
mixed_positions) | ||
|
||
|
||
def rolling_window(a, axis, window): | ||
This is a small point, but can you swap the arguments for this function? That would let you set a default axis. Bottleneck uses default arguments like |
||
""" | ||
Make an ndarray with a rolling window along axis. | ||
|
||
Parameters | ||
---------- | ||
a : array_like | ||
Array to add rolling window to | ||
axis : int | ||
Axis position along which the rolling window will be applied. | ||
window : int | ||
Size of rolling window | ||
|
||
Returns | ||
------- | ||
Array that is a view of the original array with an added dimension | ||
of size ``window``. | ||
|
||
Examples | ||
-------- | ||
>>> x = np.arange(10).reshape((2, 5)) | ||
>>> rolling_window(x, axis=-1, window=3) | ||
array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]], | ||
[[5, 6, 7], [6, 7, 8], [7, 8, 9]]]) | ||
|
||
Calculate rolling mean of last dimension: | ||
>>> np.mean(rolling_window(x, axis=-1, window=3), -1) | ||
array([[ 1., 2., 3.], | ||
[ 6., 7., 8.]]) | ||
|
||
This function is taken from https://github.com/numpy/numpy/pull/31 | ||
but slightly modified to accept axis option. | ||
""" | ||
axis = _validate_axis(a, axis) | ||
a = np.swapaxes(a, axis, -1) | ||
|
||
if window < 1: | ||
raise ValueError( | ||
"`window` must be at least 1. Given : {}".format(window)) | ||
if window > a.shape[-1]: | ||
raise ValueError("`window` is too long. Given : {}".format(window)) | ||
|
||
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) | ||
strides = a.strides + (a.strides[-1],) | ||
rolling = npcompat.as_strided(a, shape=shape, strides=strides, | ||
writeable=False) | ||
return np.swapaxes(rolling, -2, axis) |
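
The strided-view approach in `rolling_window` can be tried outside the PR with plain NumPy. The following is a minimal sketch, not the PR's code: it assumes `numpy.lib.stride_tricks.as_strided` in place of the `npcompat.as_strided` wrapper, uses a modulo as a stand-in for `_validate_axis`, and keeps the `(a, axis, window)` argument order shown in the diff. The name `rolling_window_sketch` is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def rolling_window_sketch(a, axis, window):
    """Return a read-only view of `a` with a trailing window dimension."""
    a = np.asarray(a)
    axis = axis % a.ndim  # normalize negative axes (stand-in for _validate_axis)
    a = np.swapaxes(a, axis, -1)  # move the rolled axis to the end
    if window < 1 or window > a.shape[-1]:
        raise ValueError("invalid window: {}".format(window))
    # One extra dimension of length `window`, reusing the last axis stride,
    # so consecutive windows overlap without copying any data.
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    view = as_strided(a, shape=shape, strides=strides, writeable=False)
    # Put the (now shorter) rolled axis back where it came from.
    return np.swapaxes(view, -2, axis)


x = np.arange(10).reshape((2, 5))
windows = rolling_window_sketch(x, axis=-1, window=3)
rolling_mean = windows.mean(axis=-1)
print(windows.shape)
print(rolling_mean)
```

Because the result is a view, reductions such as `mean` over the last axis give rolling aggregates without materializing every window; `writeable=False` guards against accidental writes through the overlapping memory.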
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -227,20 +227,8 @@ def func(self, *args, **kwargs): | |
|
||
def rolling_count(rolling): | ||
|
||
not_null = rolling.obj.notnull() | ||
instance_attr_dict = {'center': rolling.center, | ||
'min_periods': rolling.min_periods, | ||
rolling.dim: rolling.window} | ||
rolling_count = not_null.rolling(**instance_attr_dict).sum() | ||
|
||
if rolling.min_periods is None: | ||
return rolling_count | ||
|
||
# otherwise we need to filter out points where there aren't enough periods | ||
# but not_null is False, and so the NaNs don't flow through | ||
# array with points where there are enough values given min_periods | ||
enough_periods = rolling_count >= rolling.min_periods | ||
|
||
rolling_count = rolling._counts() | ||
enough_periods = rolling_count > rolling._min_periods - 0.5 | ||
0.5 is a little strange to see when the other veggies are integers.
Other variables, not veggies (I blame autocorrect!)
Oops. I will fix. |
||
return rolling_count.where(enough_periods) | ||
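
The `min_periods` masking discussed in this review thread can be illustrated with NumPy alone. This is a simplified sketch under stated assumptions: it counts only full windows (no edge padding), and the variable names are illustrative rather than taken from xarray. It also shows why `count > min_periods - 0.5` and `count >= min_periods` agree when the counts are integers.

```python
import numpy as np

values = np.array([1.0, np.nan, 3.0, np.nan, np.nan, 6.0])
window = 3
min_periods = 2

# Count valid (non-NaN) entries in each full window.
not_null = (~np.isnan(values)).astype(int)
counts = np.array([not_null[i:i + window].sum()
                   for i in range(len(values) - window + 1)])

# For integer counts the two comparisons are equivalent.
enough_direct = counts >= min_periods
enough_offset = counts > min_periods - 0.5

# Mask windows with too few valid values, analogous to `.where(...)`.
masked = np.where(enough_direct, counts.astype(float), np.nan)
print(counts)
print(masked)
```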
|
||
|
||
|
Can we show the example by calling
construct('window_dim', stride=2)
instead?
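
To make the idea of a strided window construction concrete, here is a NumPy-only sketch of what taking every second window looks like. It is an illustration under assumptions, not the xarray API: it builds unpadded windows with `numpy.lib.stride_tricks.as_strided` and then slices them, so edge handling may differ from what the method in this PR produces.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(8)
window, stride = 3, 2

# All overlapping windows of length `window` as a read-only view.
n_windows = x.shape[0] - window + 1
all_windows = as_strided(x, shape=(n_windows, window),
                         strides=(x.strides[0], x.strides[0]),
                         writeable=False)

# Keep every `stride`-th window along the rolled axis.
strided = all_windows[::stride]
print(strided)
```

Slicing the view keeps the operation copy-free; a reduction such as `strided.mean(axis=-1)` then yields a rolling mean sampled at every second position.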