From 8d6676ca36b70e6fd008bbec91b83bebc013930b Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Fri, 21 Jul 2023 02:44:14 -0400 Subject: [PATCH] Allow and handle non-parameters in rasterize.instance() (#5811) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Simon Høxbro Hansen --- holoviews/operation/datashader.py | 15 ++++++++---- holoviews/tests/operation/test_datashader.py | 24 +++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/holoviews/operation/datashader.py b/holoviews/operation/datashader.py index 79eaf46b1a..647efddfc6 100644 --- a/holoviews/operation/datashader.py +++ b/holoviews/operation/datashader.py @@ -1375,20 +1375,25 @@ class rasterize(AggregationOperation): ] __instance_params = set() + __instance_kwargs = {} @bothmethod def instance(self_or_cls, **params): - inst = super().instance(**params) - inst.__instance_params = set(params) + kwargs = set(params) - set(self_or_cls.param) + inst_params = {k: v for k, v in params.items() if k in self_or_cls.param} + inst = super().instance(**inst_params) + inst.__instance_params = set(inst_params) + inst.__instance_kwargs = {k: v for k, v in params.items() if k in kwargs} return inst def _process(self, element, key=None): # Potentially needs traverse to find element types first? all_allowed_kws = set() all_supplied_kws = set() - instance_params = { - k: getattr(self, k) for k in self.__instance_params - } + instance_params = dict( + self.__instance_kwargs, + **{k: getattr(self, k) for k in self.__instance_params} + ) for predicate, transform in self._transforms: merged_param_values = dict(instance_params, **self.p) diff --git a/holoviews/tests/operation/test_datashader.py b/holoviews/tests/operation/test_datashader.py index 795c6f8035..7d85920ebf 100644 --- a/holoviews/tests/operation/test_datashader.py +++ b/holoviews/tests/operation/test_datashader.py @@ -9,11 +9,13 @@ from holoviews import ( Dimension, Curve, Points, Image, Dataset, RGB, Path, Graph, TriMesh, QuadMesh, NdOverlay, Contours, Spikes, Spread, Area, Rectangles, - Segments, Polygons, Nodes + Segments, Polygons, Nodes, DynamicMap, Overlay ) +from holoviews.util import render from holoviews.streams import Tap from holoviews.element.comparison import ComparisonTestCase from numpy import nan +from holoviews.operation import apply_when from packaging.version import Version try: @@ -1169,6 +1171,26 @@ def test_rasterize_image_expand_default(self): output = img.data["z"].to_numpy() assert np.isnan(output).any() + def test_rasterize_apply_when_instance_with_line_width(self): + df = pd.DataFrame( + np.random.multivariate_normal( + (0, 0), [[0.1, 0.1], [0.1, 1.0]], (100,)) + ) + df.columns = ["a", "b"] + + curve = Curve(df, kdims=["a"], vdims=["b"]) + # line_width is not a parameter + custom_rasterize = rasterize.instance(line_width=2) + assert {'line_width': 2} == custom_rasterize._rasterize__instance_kwargs + output = apply_when( + curve, operation=custom_rasterize, predicate=lambda x: len(x) > 10 + ) + render(output, "bokeh") + assert isinstance(output, DynamicMap) + overlay = output.items()[0][1] + assert isinstance(overlay, Overlay) + assert len(overlay) == 2 + class DatashaderSpreadTests(ComparisonTestCase): def test_spread_rgb_1px(self):