-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rolling window with as_strided
#1837
Changes from 70 commits
789134c
fa4e857
52915f3
b622007
36a1fe9
71fed0f
3960134
4bd38f3
af8362e
76db6b5
87f53af
c23cedb
9547c57
1f71cff
724776f
73862eb
859bb5c
d5fc24e
05c72f0
d55e498
9393eb2
9c71a50
54975b4
e907fdf
6482536
b8def4f
ff31589
6c011cb
684145a
3a7526e
a0968d6
ac4f00e
fbfc262
c757986
8fd5fa3
ade5ba2
2d6897f
6461f84
aece1c4
d5ad4a0
4189d71
081c928
75c1d7d
452b219
c5490c4
ab91394
9fa0812
0c1d49a
9463937
dce4e37
b3050cb
22f6d4a
19e0fca
d3b1e2b
2d06ec9
734da93
1a000b8
27ff67c
a2c7141
35dee9d
137709f
cc82cdc
b246411
b3a2105
b80fbfd
3c010ae
ab82f75
b9f10cd
cc9c3d6
52cc48d
2954cdf
f19e531
a074df3
f6f78a5
0ec8aba
0261cfe
a91c27f
c83d588
3bb4668
d0d89ce
eaba563
aeabdf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
|
||
from . import parameterized, randn, requires_dask | ||
|
||
nx = 3000 | ||
ny = 2000 | ||
nt = 1000 | ||
window = 20 | ||
|
||
|
||
class Rolling(object): | ||
def setup(self, *args, **kwargs): | ||
self.ds = xr.Dataset( | ||
{'var1': (('x', 'y'), randn((nx, ny), frac_nan=0.1)), | ||
'var2': (('x', 't'), randn((nx, nt))), | ||
'var3': (('t', ), randn(nt))}, | ||
coords={'x': np.arange(nx), | ||
'y': np.linspace(0, 1, ny), | ||
't': pd.date_range('1970-01-01', periods=nt, freq='D'), | ||
'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) | ||
|
||
@parameterized(['func', 'center'], | ||
(['mean', 'count'], [True, False])) | ||
def time_rolling(self, func, center): | ||
getattr(self.ds.rolling(x=window, center=center), func)() | ||
|
||
@parameterized(['window_', 'min_periods'], | ||
([20, 40], [5, None])) | ||
def time_rolling_np(self, window_, min_periods): | ||
self.ds.rolling(x=window_, center=False, | ||
min_periods=min_periods).reduce(getattr(np, 'nanmean')) | ||
|
||
@parameterized(['center', 'stride'], | ||
([True, False], [1, 200])) | ||
def time_rolling_to_dataset(self, center, stride): | ||
self.ds.rolling(x=window, center=center).to_dataset( | ||
'window_dim', stride=stride).mean(dim='window_dim') | ||
|
||
|
||
class RollingDask(Rolling): | ||
def setup(self, *args, **kwargs): | ||
requires_dask() | ||
super(RollingDask, self).setup(**kwargs) | ||
self.ds = self.ds.chunk({'x': 100, 'y': 50, 't': 50}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
"""Define core operations for xarray objects. | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
import numpy as np | ||
from . import nputils | ||
|
||
try: | ||
import dask.array as da | ||
|
@@ -24,3 +26,44 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): | |
# trim array | ||
result = da.ghost.trim_internal(out, depth) | ||
return result | ||
|
||
|
||
def rolling_window(a, window, axis=-1): | ||
""" Dask's equivalence to np.utils.rolling_window """ | ||
# inputs for ghost | ||
if axis < 0: | ||
axis = a.ndim + axis | ||
depth = {d: 0 for d in range(a.ndim)} | ||
if window % 2 == 0: | ||
depth[axis] = int((window - 1) / 2 + 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. W1619 division w/o future statement |
||
offset = 1 | ||
else: | ||
depth[axis] = int((window - 1) / 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. W1619 division w/o future statement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should indeed add these future imports:
|
||
offset = 0 | ||
|
||
if depth[axis] > min(a.chunks[axis]): | ||
raise ValueError( | ||
"The window size %d is larger than your\n" | ||
"smallest chunk size %d + 1. Rechunk your array\n" | ||
"with a larger chunk size or a chunk size that\n" | ||
"more evenly divides the shape of your array." % | ||
(window, min(a.chunks[axis]))) | ||
|
||
boundary = {d: np.nan for d in range(a.ndim)} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use this |
||
# create ghosted arrays | ||
ag = da.ghost.ghost(a, depth=depth, boundary=boundary) | ||
|
||
# apply rolling func | ||
def func(x, window, axis=-1): | ||
x = np.asarray(x) | ||
rolling = nputils.rolling_window(x, window, axis) | ||
return rolling[(slice(None), ) * axis + (slice(offset, None), )] | ||
|
||
chunks = list(a.chunks) | ||
chunks.append(window) | ||
out = ag.map_blocks(func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, | ||
window=window, axis=axis) | ||
# crop the edge points | ||
index = (slice(None),) * axis + (slice(depth[axis] - offset, | ||
- depth[axis]),) | ||
return out[index] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
from .computation import apply_ufunc | ||
from .utils import is_scalar | ||
from .npcompat import flip | ||
from . import rolling | ||
|
||
|
||
class BaseInterpolator(object): | ||
|
@@ -329,4 +330,8 @@ def _get_valid_fill_mask(arr, dim, limit): | |
'''helper function to determine values that can be filled when limit is not | ||
None''' | ||
kw = {dim: limit + 1} | ||
return arr.isnull().rolling(min_periods=1, **kw).sum() <= limit | ||
# we explicitly use construct method to avoid copy. | ||
new_dim = rolling._get_new_dimname(arr.dims, '_window') | ||
return (arr.isnull().rolling(min_periods=1, **kw) | ||
.construct(new_dim, fill_value=False) | ||
.sum(new_dim, skipna=False)) <= limit | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my environment, this is faster than the original one with bottlneck (especially for small arrays) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very nice!