Skip to content

Commit

Permalink
Allow and handle non-parameters in rasterize.instance() (#5811)
Browse files Browse the repository at this point in the history
Co-authored-by: Simon Høxbro Hansen <simon.hansen@me.com>
  • Loading branch information
ahuang11 and hoxbro authored Jul 21, 2023
1 parent c717c6d commit 8d6676c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
15 changes: 10 additions & 5 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,20 +1375,25 @@ class rasterize(AggregationOperation):
]

__instance_params = set()
__instance_kwargs = {}

@bothmethod
def instance(self_or_cls, **params):
inst = super().instance(**params)
inst.__instance_params = set(params)
kwargs = set(params) - set(self_or_cls.param)
inst_params = {k: v for k, v in params.items() if k in self_or_cls.param}
inst = super().instance(**inst_params)
inst.__instance_params = set(inst_params)
inst.__instance_kwargs = {k: v for k, v in params.items() if k in kwargs}
return inst

def _process(self, element, key=None):
# Potentially needs traverse to find element types first?
all_allowed_kws = set()
all_supplied_kws = set()
instance_params = {
k: getattr(self, k) for k in self.__instance_params
}
instance_params = dict(
self.__instance_kwargs,
**{k: getattr(self, k) for k in self.__instance_params}
)
for predicate, transform in self._transforms:
merged_param_values = dict(instance_params, **self.p)

Expand Down
24 changes: 23 additions & 1 deletion holoviews/tests/operation/test_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from holoviews import (
Dimension, Curve, Points, Image, Dataset, RGB, Path, Graph, TriMesh,
QuadMesh, NdOverlay, Contours, Spikes, Spread, Area, Rectangles,
Segments, Polygons, Nodes
Segments, Polygons, Nodes, DynamicMap, Overlay
)
from holoviews.util import render
from holoviews.streams import Tap
from holoviews.element.comparison import ComparisonTestCase
from numpy import nan
from holoviews.operation import apply_when
from packaging.version import Version

try:
Expand Down Expand Up @@ -1169,6 +1171,26 @@ def test_rasterize_image_expand_default(self):
output = img.data["z"].to_numpy()
assert np.isnan(output).any()

def test_rasterize_apply_when_instance_with_line_width(self):
df = pd.DataFrame(
np.random.multivariate_normal(
(0, 0), [[0.1, 0.1], [0.1, 1.0]], (100,))
)
df.columns = ["a", "b"]

curve = Curve(df, kdims=["a"], vdims=["b"])
# line_width is not a parameter
custom_rasterize = rasterize.instance(line_width=2)
assert {'line_width': 2} == custom_rasterize._rasterize__instance_kwargs
output = apply_when(
curve, operation=custom_rasterize, predicate=lambda x: len(x) > 10
)
render(output, "bokeh")
assert isinstance(output, DynamicMap)
overlay = output.items()[0][1]
assert isinstance(overlay, Overlay)
assert len(overlay) == 2

class DatashaderSpreadTests(ComparisonTestCase):

def test_spread_rgb_1px(self):
Expand Down

0 comments on commit 8d6676c

Please sign in to comment.