From dd930d6f424ef59199a715dd990a7f1b43a44d68 Mon Sep 17 00:00:00 2001
From: Tom Augspurger
Date: Sun, 4 Jun 2017 05:39:31 -0500
Subject: [PATCH] PERF: vectorize _interp_limit (#16592)

* PERF: vectorize _interp_limit

* CLN: remove old implementation

* fixup! CLN: remove old implementation
---
NOTE(review): this copy of the patch was rebuilt from a whitespace-collapsed
dump. Line breaks are consistent with the hunk headers and the diffstat
(67 insertions, 10 deletions); leading indentation is inferred from Python
syntax and should be compared against the upstream commit before applying.

 pandas/core/missing.py | 77 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 67 insertions(+), 10 deletions(-)

diff --git a/pandas/core/missing.py b/pandas/core/missing.py
index 51778684d68f5c..5aabc9d8730dd5 100644
--- a/pandas/core/missing.py
+++ b/pandas/core/missing.py
@@ -143,12 +143,6 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
                              'DatetimeIndex')
         method = 'values'
 
-    def _interp_limit(invalid, fw_limit, bw_limit):
-        "Get idx of values that won't be filled b/c they exceed the limits."
-        for x in np.where(invalid)[0]:
-            if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
-                yield x
-
     valid_limit_directions = ['forward', 'backward', 'both']
     limit_direction = limit_direction.lower()
     if limit_direction not in valid_limit_directions:
@@ -180,21 +174,29 @@ def _interp_limit(invalid, fw_limit, bw_limit):
 
     # default limit is unlimited GH #16282
     if limit is None:
-        limit = len(xvalues)
+        # limit = len(xvalues)
+        pass
     elif not is_integer(limit):
         raise ValueError('Limit must be an integer')
     elif limit < 1:
         raise ValueError('Limit must be greater than 0')
 
     # each possible limit_direction
-    if limit_direction == 'forward':
+    # TODO: do we need sorted?
+    if limit_direction == 'forward' and limit is not None:
         violate_limit = sorted(start_nans |
                                set(_interp_limit(invalid, limit, 0)))
-    elif limit_direction == 'backward':
+    elif limit_direction == 'forward':
+        violate_limit = sorted(start_nans)
+    elif limit_direction == 'backward' and limit is not None:
         violate_limit = sorted(end_nans |
                                set(_interp_limit(invalid, 0, limit)))
-    elif limit_direction == 'both':
+    elif limit_direction == 'backward':
+        violate_limit = sorted(end_nans)
+    elif limit_direction == 'both' and limit is not None:
         violate_limit = sorted(_interp_limit(invalid, limit, limit))
+    else:
+        violate_limit = []
 
     xvalues = getattr(xvalues, 'values', xvalues)
     yvalues = getattr(yvalues, 'values', yvalues)
@@ -630,3 +632,58 @@ def fill_zeros(result, x, y, name, fill):
         result = result.reshape(shape)
 
     return result
+
+
+def _interp_limit(invalid, fw_limit, bw_limit):
+    """Get idx of values that won't be filled b/c they exceed the limits.
+
+    This is equivalent to the more readable, but slower
+
+    .. code-block:: python
+
+       for x in np.where(invalid)[0]:
+           if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
+               yield x
+    """
+    # handle forward first; the backward direction is the same except
+    # 1. operate on the reversed array
+    # 2. subtract the returned indicies from N - 1
+    N = len(invalid)
+
+    def inner(invalid, limit):
+        limit = min(limit, N)
+        windowed = _rolling_window(invalid, limit + 1).all(1)
+        idx = (set(np.where(windowed)[0] + limit) |
+               set(np.where((~invalid[:limit + 1]).cumsum() == 0)[0]))
+        return idx
+
+    if fw_limit == 0:
+        f_idx = set(np.where(invalid)[0])
+    else:
+        f_idx = inner(invalid, fw_limit)
+
+    if bw_limit == 0:
+        # then we don't even need to care about backwards, just use forwards
+        return f_idx
+    else:
+        b_idx = set(N - 1 - np.asarray(list(inner(invalid[::-1], bw_limit))))
+        if fw_limit == 0:
+            return b_idx
+    return f_idx & b_idx
+
+
+def _rolling_window(a, window):
+    """
+    [True, True, False, True, False], 2 ->
+
+    [
+        [True, True],
+        [True, False],
+        [False, True],
+        [True, False],
+    ]
+    """
+    # https://stackoverflow.com/a/6811241
+    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
+    strides = a.strides + (a.strides[-1],)
+    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)