Skip to content

Commit

Permalink
Add new where reduction (#1155)
Browse files Browse the repository at this point in the history
* Basics of where reduction

* Where with min or max

* Add bool returns to all append functions

* Tidy up

* where cannot use same column as its contained reduction

* Dask combine for cpu where reduction

* More sensible arg names

* Add new tests

* Dask where.combine for multiple-stage combines

* Support antialiasing, including dask cpu

* cuda tests

* Temporarily disable test_dask.py::test_line_antialias_where

* Fix antialiased where reduction tests

* Exclude elements that may vary from dask line antialias where tests

* where docstring

* Exclude docstring example from doctest

* Improved exception message
  • Loading branch information
ianthomas23 authored Jan 16, 2023
1 parent 645ae07 commit 2e0f8e0
Show file tree
Hide file tree
Showing 5 changed files with 387 additions and 27 deletions.
42 changes: 34 additions & 8 deletions datashader/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import xarray as xr

from .reductions import by, category_codes, summary
from .reductions import by, category_codes, summary, where
from .utils import ngjit


Expand Down Expand Up @@ -76,18 +76,20 @@ def compile_components(agg, schema, glyph, *, antialias=False, cuda=False):
self_intersect = False
antialias_stage_2 = False

# List of tuples of (append, base, input columns, temps)
# List of tuples of (append, base, input columns, temps, combine temps)
calls = [_get_call_tuples(b, d, schema, cuda, antialias, self_intersect)
for (b, d) in zip(bases, dshapes)]

# List of unique column names needed
cols = list(unique(concat(pluck(2, calls))))
# List of temps needed
temps = list(pluck(3, calls))
combine_temps = list(pluck(4, calls))

create = make_create(bases, dshapes, cuda)
info = make_info(cols)
append = make_append(bases, cols, calls, glyph, isinstance(agg, by), antialias)
combine = make_combine(bases, dshapes, temps, antialias)
combine = make_combine(bases, dshapes, temps, combine_temps, antialias)
finalize = make_finalize(bases, agg, schema, cuda)

return create, info, append, combine, finalize, antialias_stage_2
Expand All @@ -110,6 +112,7 @@ def _get_call_tuples(base, dshape, schema, cuda, antialias, self_intersect):
(base,), # bases
base.inputs, # cols
base._build_temps(cuda), # temps
base._build_combine_temps(cuda), # combine temps
)


Expand Down Expand Up @@ -141,7 +144,7 @@ def make_append(bases, cols, calls, glyph, categorical, antialias):
else:
subscript = None

for func, bases, cols, temps in calls:
for func, bases, cols, temps, _ in calls:
local_lk.update(zip(temps, (next(names) for i in temps)))
func_name = next(names)
namespace[func_name] = func
Expand All @@ -163,6 +166,13 @@ def make_append(bases, cols, calls, glyph, categorical, antialias):

body.append('{0}(x, y, {1})'.format(func_name, ', '.join(args)))

where_reduction = len(bases) == 1 and isinstance(bases[0], where)
if where_reduction:
# where reduction needs access to the return of the contained
# reduction, which is the preceding one here.
body[-2] = 'if ' + body[-2] + ':'
body[-1] = ' ' + body[-1]

body = ['{0} = {1}[y, x]'.format(name, arg_lk[agg])
for agg, name in local_lk.items()] + body

Expand All @@ -187,14 +197,30 @@ def make_append(bases, cols, calls, glyph, categorical, antialias):
return ngjit(namespace['append'])


def make_combine(bases, dshapes, temps, antialias):
def make_combine(bases, dshapes, temps, combine_temps, antialias):
arg_lk = dict((k, v) for (v, k) in enumerate(bases))
calls = [(b._build_combine(d, antialias), [arg_lk[i] for i in (b,) + t])
for (b, d, t) in zip(bases, dshapes, temps)]

# where._combine() deals with combine of preceding reduction so exclude
# it from explicit combine calls.
base_is_where = [isinstance(b, where) for b in bases]
next_base_is_where = base_is_where[1:] + [False]
calls = [(None if n else b._build_combine(d, antialias), [arg_lk[i] for i in (b,) + t + ct])
for (b, d, t, ct, n) in zip(bases, dshapes, temps, combine_temps, next_base_is_where)]

def combine(base_tuples):
bases = tuple(np.stack(bs) for bs in zip(*base_tuples))
return tuple(f(*get(inds, bases)) for (f, inds) in calls)
ret = []
for is_where, (func, inds) in zip(base_is_where, calls):
if func is None:
continue
call = func(*get(inds, bases))
if is_where:
# Separate the aggs of the where reduction and its selector;
# the selector's agg goes first to match the order of bases.
ret.extend(call[::-1])
else:
ret.append(call)
return tuple(ret)

return combine

Expand Down
3 changes: 2 additions & 1 deletion datashader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def line(self, source, x=None, y=None, agg=None, axis=0, geometry=None,

if not isinstance(non_cat_agg, (
rd.any, rd.count, rd.max, rd.min, rd.sum, rd.summary, rd._sum_zero,
rd.first, rd.last, rd.mean
rd.first, rd.last, rd.mean, rd.where
)):
raise NotImplementedError(
f"{type(non_cat_agg)} reduction not implemented for antialiased lines")
Expand Down Expand Up @@ -1276,6 +1276,7 @@ def _bypixel_sanitise(source, glyph, agg):
# by only retaining the necessary columns:
# https://github.com/bokeh/datashader/issues/396
# Preserve column ordering without duplicates

cols_to_keep = _cols_to_keep(source.columns, glyph, agg)
if len(cols_to_keep) < len(source.columns):
# If _sindex is set, ensure it is not dropped
Expand Down
Loading

0 comments on commit 2e0f8e0

Please sign in to comment.