From 1623601f7218ae2c6ff55f1dd492adfd435383a7 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Tue, 3 Jul 2018 17:52:50 +0200 Subject: [PATCH] ENH: expose FRB to enable buffer filtering within a PlotWindow The name of the callback has in general no reason to start by `annotate_`. This is to make space for a smooth callback --- doc/source/visualizing/callbacks.rst | 17 +++++ yt/visualization/_commons.py | 55 +++++++++++++++++ yt/visualization/fixed_resolution.py | 53 +++++++++++----- yt/visualization/fixed_resolution_filters.py | 26 +++----- yt/visualization/plot_container.py | 65 +++----------------- yt/visualization/plot_window.py | 34 +++++++--- yt/visualization/tests/test_filters.py | 42 ++++++++++++- 7 files changed, 193 insertions(+), 99 deletions(-) diff --git a/doc/source/visualizing/callbacks.rst b/doc/source/visualizing/callbacks.rst index 4d657c848d8..b0241c8a2a0 100644 --- a/doc/source/visualizing/callbacks.rst +++ b/doc/source/visualizing/callbacks.rst @@ -889,3 +889,20 @@ Overplot the Path of a Ray p.annotate_ray(oray) p.annotate_ray(ray) p.save() + + +Applying filters on the final image +----------------------------------- + +It is also possible to operate on the plotted image directly by using +one of the fixed resolution buffer filter as described in +:ref:`frb-filters`. + +.. python-script:: + + import yt + + ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030') + p = yt.SlicePlot(ds, 'z', 'density') + p.frb.apply_gauss_beam(sigma=30) + p.save() diff --git a/yt/visualization/_commons.py b/yt/visualization/_commons.py index a890ad642e7..40fc6eb7d61 100644 --- a/yt/visualization/_commons.py +++ b/yt/visualization/_commons.py @@ -1,6 +1,7 @@ import os import sys import warnings +from functools import wraps from typing import Optional, Type import matplotlib @@ -93,3 +94,57 @@ def get_canvas(figure, filename): f"without an extension." ) return get_canvas_class(suffix)(figure) + + +def invalidate_plot(f): + @wraps(f) + def newfunc(self, *args, **kwargs): + retv = f(self, *args, **kwargs) + self._plot_valid = False + return retv + + return newfunc + + +def invalidate_data(f): + @wraps(f) + def newfunc(self, *args, **kwargs): + retv = f(self, *args, **kwargs) + self._data_valid = False + self._plot_valid = False + return retv + + return newfunc + + +def invalidate_figure(f): + @wraps(f) + def newfunc(self, *args, **kwargs): + retv = f(self, *args, **kwargs) + for field in self.plots.keys(): + self.plots[field].figure = None + self.plots[field].axes = None + self.plots[field].cax = None + self._setup_plots() + return retv + + return newfunc + + +def validate_plot(f): + @wraps(f) + def newfunc(self, *args, **kwargs): + # TODO: _profile_valid and _data_valid seem to play very similar roles, + # there's probably room to abstract these into a common operation + if hasattr(self, "_data_valid") and not self._data_valid: + self._recreate_frb() + if hasattr(self, "_profile_valid") and not self._profile_valid: + self._recreate_profile() + if not self._plot_valid: + # it is the responsibility of _setup_plots to + # call plot.run_callbacks() + self._setup_plots() + retv = f(self, *args, **kwargs) + return retv + + return newfunc diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py index db365e340f9..8dadaf8a999 100644 --- a/yt/visualization/fixed_resolution.py +++ b/yt/visualization/fixed_resolution.py @@ -1,5 +1,6 @@ import weakref from functools import wraps +from typing import Dict, List, Optional import numpy as np @@ -14,7 +15,11 @@ from yt.utilities.lib.pixelization_routines import pixelize_cylinder from yt.utilities.on_demand_imports import _h5py as h5py -from .fixed_resolution_filters import apply_filter, filter_registry +from .fixed_resolution_filters import ( + FixedResolutionBufferFilter, + apply_filter, + filter_registry, +) from .volume_rendering.api import off_axis_projection @@ -53,6 +58,8 @@ class FixedResolutionBuffer: This can be true or false, and governs whether the pixelization will span the domain boundaries. + filters : list of FixedResolutionBufferFilter objects (optional) + Examples -------- To make a projection and then several images, you can generate a @@ -90,16 +97,27 @@ class FixedResolutionBuffer: ("index", "dtheta"), ) - def __init__(self, data_source, bounds, buff_size, antialias=True, periodic=False): + def __init__( + self, + data_source, + bounds, + buff_size, + antialias=True, + periodic=False, + *, + filters: Optional[List[FixedResolutionBufferFilter]] = None, + ): self.data_source = data_source self.ds = data_source.ds self.bounds = bounds self.buff_size = (int(buff_size[0]), int(buff_size[1])) self.antialias = antialias - self.data = {} + self.data: Dict[str, np.ndarray] = {} self._filters = [] self.axis = data_source.axis self.periodic = periodic + self._data_valid = False + self._filters = filters if filters is not None else [] ds = getattr(data_source, "ds", None) if ds is not None: @@ -125,7 +143,7 @@ def __delitem__(self, item): del self.data[item] def __getitem__(self, item): - if item in self.data: + if item in self.data and self._data_valid: return self.data[item] mylog.info( "Making a fixed resolution buffer of (%s) %d by %d", @@ -165,6 +183,7 @@ def __getitem__(self, item): ia = ImageArray(buff, units=units, info=self._get_info(item)) self.data[item] = ia + self._data_valid = True return self.data[item] def __setitem__(self, item, val): @@ -542,14 +561,14 @@ class CylindricalFixedResolutionBuffer(FixedResolutionBuffer): that supports non-aligned input data objects, primarily cutting planes. """ - def __init__(self, data_source, radius, buff_size, antialias=True): - + def __init__(self, data_source, radius, buff_size, antialias=True, *, filters=None): self.data_source = data_source self.ds = data_source.ds self.radius = radius self.buff_size = buff_size self.antialias = antialias self.data = {} + self._filters = filters if filters is not None else [] ds = getattr(data_source, "ds", None) if ds is not None: @@ -579,12 +598,6 @@ class OffAxisProjectionFixedResolutionBuffer(FixedResolutionBuffer): that supports off axis projections. This calls the volume renderer. """ - def __init__(self, data_source, bounds, buff_size, antialias=True, periodic=False): - self.data = {} - FixedResolutionBuffer.__init__( - self, data_source, bounds, buff_size, antialias, periodic - ) - def __getitem__(self, item): if item in self.data: return self.data[item] @@ -631,10 +644,18 @@ class ParticleImageBuffer(FixedResolutionBuffer): """ - def __init__(self, data_source, bounds, buff_size, antialias=True, periodic=False): - self.data = {} - FixedResolutionBuffer.__init__( - self, data_source, bounds, buff_size, antialias, periodic + def __init__( + self, + data_source, + bounds, + buff_size, + antialias=True, + periodic=False, + *, + filters=None, + ): + super().__init__( + data_source, bounds, buff_size, antialias, periodic, filters=filters ) # set up the axis field names diff --git a/yt/visualization/fixed_resolution_filters.py b/yt/visualization/fixed_resolution_filters.py index d69dc8fafdd..e532b4aed45 100644 --- a/yt/visualization/fixed_resolution_filters.py +++ b/yt/visualization/fixed_resolution_filters.py @@ -7,9 +7,11 @@ def apply_filter(f): @wraps(f) - def newfunc(*args, **kwargs): - args[0]._filters.append((f.__name__, (args, kwargs))) - return args[0] + def newfunc(frb, *args, **kwargs): + frb._filters.append((f.__name__, (args, kwargs))) + # Invalidate the data of the frb to force its regeneration + frb._data_valid = False + return frb return newfunc @@ -50,20 +52,12 @@ def __init__(self, nbeam=30, sigma=2.0): def apply(self, buff): from yt.utilities.on_demand_imports import _scipy - hnbeam = self.nbeam // 2 - sigma = self.sigma - - l = np.linspace(-hnbeam, hnbeam, num=self.nbeam + 1) - x, y = np.meshgrid(l, l) - g2d = (1.0 / (sigma * np.sqrt(2.0 * np.pi))) * np.exp( - -((x / sigma) ** 2 + (y / sigma) ** 2) / (2 * sigma ** 2) + spl = _scipy.ndimage.gaussian_filter( + buff, + self.sigma, + truncate=self.nbeam / self.sigma, ) - g2d /= g2d.max() - - npm, nqm = np.shape(buff) - spl = _scipy.signal.convolve(buff, g2d) - - return spl[hnbeam : npm + hnbeam, hnbeam : nqm + hnbeam] + return spl class FixedResolutionBufferWhiteNoiseFilter(FixedResolutionBufferFilter): diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 4536afac031..6f36e09a4e3 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -21,67 +21,20 @@ from yt.utilities.definitions import formatted_length_unit_names from yt.utilities.exceptions import YTNotInsideNotebook -from ._commons import DEFAULT_FONT_PROPERTIES, validate_image_name +from ._commons import ( + DEFAULT_FONT_PROPERTIES, + invalidate_data, + invalidate_figure, + invalidate_plot, + validate_image_name, + validate_plot, +) latex_prefixes = { "u": r"\mu", } -def invalidate_data(f): - @wraps(f) - def newfunc(*args, **kwargs): - rv = f(*args, **kwargs) - args[0]._data_valid = False - args[0]._plot_valid = False - return rv - - return newfunc - - -def invalidate_figure(f): - @wraps(f) - def newfunc(*args, **kwargs): - rv = f(*args, **kwargs) - for field in args[0].plots.keys(): - args[0].plots[field].figure = None - args[0].plots[field].axes = None - args[0].plots[field].cax = None - args[0]._setup_plots() - return rv - - return newfunc - - -def invalidate_plot(f): - @wraps(f) - def newfunc(*args, **kwargs): - rv = f(*args, **kwargs) - args[0]._plot_valid = False - return rv - - return newfunc - - -def validate_plot(f): - @wraps(f) - def newfunc(*args, **kwargs): - if hasattr(args[0], "_data_valid"): - if not args[0]._data_valid: - args[0]._recreate_frb() - if hasattr(args[0], "_profile_valid"): - if not args[0]._profile_valid: - args[0]._recreate_profile() - if not args[0]._plot_valid: - # it is the responsibility of _setup_plots to - # call args[0].run_callbacks() - args[0]._setup_plots() - rv = f(*args, **kwargs) - return rv - - return newfunc - - def apply_callback(f): @wraps(f) def newfunc(*args, **kwargs): @@ -373,6 +326,7 @@ def _initialize_dataset(self, ts): ts = DatasetSeries(ts) return ts + @invalidate_data def _switch_ds(self, new_ds, data_source=None): old_object = self.data_source name = old_object._type_name @@ -399,7 +353,6 @@ def _switch_ds(self, new_ds, data_source=None): new_object = getattr(new_ds, name)(**kwargs) self.data_source = new_object - self._data_valid = self._plot_valid = False for d in "xyz": lim_name = d + "lim" diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 315c094a9dc..44b54b432bc 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -273,7 +273,11 @@ def piter(self, *args, **kwargs): @property def frb(self): - if self._frb is None or not self._data_valid: + # Force the regeneration of the fixed resolution buffer + # * if there's none + # * if the data has been invalidated + # * if the frb has been inalidated + if not self._data_valid: self._recreate_frb() return self._frb @@ -286,15 +290,15 @@ def frb(self, value): def frb(self): del self._frb self._frb = None - self._data_valid = False def _recreate_frb(self): old_fields = None + old_filters = [] # If we are regenerating an frb, we want to know what fields we had before if self._frb is not None: - old_fields = list(self._frb.keys()) - old_units = [str(self._frb[of].units) for of in old_fields] - + old_fields = list(self._frb.data.keys()) + old_units = [_.units for _ in self._frb.data.values()] + old_filters = self._frb._filters # Set the bounds if hasattr(self, "zlim"): bounds = self.xlim + self.ylim + self.zlim @@ -308,6 +312,7 @@ def _recreate_frb(self): self.buff_size, self.antialias, periodic=self._periodic, + filters=old_filters, ) # At this point the frb has the valid bounds, size, aliasing, etc. @@ -474,9 +479,6 @@ def set_unit(self, field, new_unit, equivalency=None, equivalency_kwargs=None): Keyword arguments to be passed to the equivalency. Only used if ``equivalency`` is set. """ - if equivalency_kwargs is None: - equivalency_kwargs = {} - field = self.data_source._determine_fields(field)[0] for f, u in zip_equal(iter_fields(field), always_iterable(new_unit)): self.frb.set_unit(f, u, equivalency, equivalency_kwargs) self._equivalencies[f] = (equivalency, equivalency_kwargs) @@ -848,7 +850,6 @@ class PWViewerMPL(PlotWindow): _current_field = None _frb_generator: Optional[Type[FixedResolutionBuffer]] = None _plot_type: Optional[str] = None - _data_valid = False def __init__(self, *args, **kwargs): if self._frb_generator is None: @@ -856,8 +857,22 @@ def __init__(self, *args, **kwargs): if self._plot_type is None: self._plot_type = kwargs.pop("plot_type") self._splat_color = kwargs.pop("splat_color", None) + self._frb: Optional[FixedResolutionBuffer] = None PlotWindow.__init__(self, *args, **kwargs) + @property + def _data_valid(self) -> bool: + return self._frb is not None and self._frb._data_valid + + @_data_valid.setter + def _data_valid(self, value): + if self._frb is None: + # we delegate the (in)validation responsability to the FRB + # if we don't have one yet, we can exit without doing anything + return + else: + self._frb._data_valid = value + def _setup_origin(self): origin = self.origin axis_index = self.data_source.axis @@ -978,7 +993,6 @@ def _setup_plots(self): return if not self._data_valid: self._recreate_frb() - self._data_valid = True self._colorbar_valid = True for f in list(set(self.data_source._determine_fields(self.fields))): axis_index = self.data_source.axis diff --git a/yt/visualization/tests/test_filters.py b/yt/visualization/tests/test_filters.py index e256ad96e89..10b6099fce1 100644 --- a/yt/visualization/tests/test_filters.py +++ b/yt/visualization/tests/test_filters.py @@ -3,6 +3,9 @@ """ +import numpy as np + +import yt from yt.testing import fake_amr_ds, requires_module @@ -21,5 +24,42 @@ def test_gauss_beam_filter(): ds = fake_amr_ds(fields=("density",), units=("g/cm**3",)) p = ds.proj(("gas", "density"), "z") frb = p.to_frb((1, "unitary"), 64) - frb.apply_gauss_beam(nbeam=15, sigma=1.0) + frb.apply_gauss_beam(sigma=1.0) frb[("gas", "density")] + + +@requires_module("scipy") +def test_filter_wiring(): + from scipy.ndimage import gaussian_filter + + ds = fake_amr_ds(fields=[("gas", "density")], units=["g/cm**3"]) + p = yt.SlicePlot(ds, "x", "density") + + # Note: frb is a FixedResolutionBuffer object + frb1 = p.frb + data_orig = frb1["density"].value + + sigma = 2 + nbeam = 30 + p.frb.apply_gauss_beam(nbeam=nbeam, sigma=sigma) + frb2 = p.frb + data_gauss = frb2["density"].value + + p.frb.apply_white_noise() + frb3 = p.frb + data_white = frb3["density"].value + + # We check the frb objects are different + assert frb1 is not frb2 + assert frb1 is not frb3 + assert frb2 is not frb3 + + # We check the resulting image are different each time + assert not np.allclose(data_orig, data_gauss) + assert not np.allclose(data_orig, data_white) + assert not np.allclose(data_gauss, data_white) + + # Check the gaussian filtering is ok + assert np.allclose( + gaussian_filter(data_orig, sigma, truncate=nbeam / sigma), data_gauss + )