Skip to content

Commit

Permalink
ENH: expose FRB to enable buffer filtering within a PlotWindow
Browse files Browse the repository at this point in the history
The name of the callback has in general no reason to start by `annotate_`. This is
to make space for a smooth callback
  • Loading branch information
cphyc authored and neutrinoceros committed Jan 23, 2022
1 parent 762afa8 commit 1623601
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 99 deletions.
17 changes: 17 additions & 0 deletions doc/source/visualizing/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -889,3 +889,20 @@ Overplot the Path of a Ray
p.annotate_ray(oray)
p.annotate_ray(ray)
p.save()


Applying filters on the final image
-----------------------------------

It is also possible to operate on the plotted image directly by using
one of the fixed resolution buffer filter as described in
:ref:`frb-filters`.

.. python-script::

import yt

ds = yt.load('IsolatedGalaxy/galaxy0030/galaxy0030')
p = yt.SlicePlot(ds, 'z', 'density')
p.frb.apply_gauss_beam(sigma=30)
p.save()
55 changes: 55 additions & 0 deletions yt/visualization/_commons.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import warnings
from functools import wraps
from typing import Optional, Type

import matplotlib
Expand Down Expand Up @@ -93,3 +94,57 @@ def get_canvas(figure, filename):
f"without an extension."
)
return get_canvas_class(suffix)(figure)


def invalidate_plot(f):
@wraps(f)
def newfunc(self, *args, **kwargs):
retv = f(self, *args, **kwargs)
self._plot_valid = False
return retv

return newfunc


def invalidate_data(f):
@wraps(f)
def newfunc(self, *args, **kwargs):
retv = f(self, *args, **kwargs)
self._data_valid = False
self._plot_valid = False
return retv

return newfunc


def invalidate_figure(f):
@wraps(f)
def newfunc(self, *args, **kwargs):
retv = f(self, *args, **kwargs)
for field in self.plots.keys():
self.plots[field].figure = None
self.plots[field].axes = None
self.plots[field].cax = None
self._setup_plots()
return retv

return newfunc


def validate_plot(f):
@wraps(f)
def newfunc(self, *args, **kwargs):
# TODO: _profile_valid and _data_valid seem to play very similar roles,
# there's probably room to abstract these into a common operation
if hasattr(self, "_data_valid") and not self._data_valid:
self._recreate_frb()
if hasattr(self, "_profile_valid") and not self._profile_valid:
self._recreate_profile()
if not self._plot_valid:
# it is the responsibility of _setup_plots to
# call plot.run_callbacks()
self._setup_plots()
retv = f(self, *args, **kwargs)
return retv

return newfunc
53 changes: 37 additions & 16 deletions yt/visualization/fixed_resolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import weakref
from functools import wraps
from typing import Dict, List, Optional

import numpy as np

Expand All @@ -14,7 +15,11 @@
from yt.utilities.lib.pixelization_routines import pixelize_cylinder
from yt.utilities.on_demand_imports import _h5py as h5py

from .fixed_resolution_filters import apply_filter, filter_registry
from .fixed_resolution_filters import (
FixedResolutionBufferFilter,
apply_filter,
filter_registry,
)
from .volume_rendering.api import off_axis_projection


Expand Down Expand Up @@ -53,6 +58,8 @@ class FixedResolutionBuffer:
This can be true or false, and governs whether the pixelization
will span the domain boundaries.
filters : list of FixedResolutionBufferFilter objects (optional)
Examples
--------
To make a projection and then several images, you can generate a
Expand Down Expand Up @@ -90,16 +97,27 @@ class FixedResolutionBuffer:
("index", "dtheta"),
)

def __init__(self, data_source, bounds, buff_size, antialias=True, periodic=False):
def __init__(
self,
data_source,
bounds,
buff_size,
antialias=True,
periodic=False,
*,
filters: Optional[List[FixedResolutionBufferFilter]] = None,
):
self.data_source = data_source
self.ds = data_source.ds
self.bounds = bounds
self.buff_size = (int(buff_size[0]), int(buff_size[1]))
self.antialias = antialias
self.data = {}
self.data: Dict[str, np.ndarray] = {}
self._filters = []
self.axis = data_source.axis
self.periodic = periodic
self._data_valid = False
self._filters = filters if filters is not None else []

ds = getattr(data_source, "ds", None)
if ds is not None:
Expand All @@ -125,7 +143,7 @@ def __delitem__(self, item):
del self.data[item]

def __getitem__(self, item):
if item in self.data:
if item in self.data and self._data_valid:
return self.data[item]
mylog.info(
"Making a fixed resolution buffer of (%s) %d by %d",
Expand Down Expand Up @@ -165,6 +183,7 @@ def __getitem__(self, item):

ia = ImageArray(buff, units=units, info=self._get_info(item))
self.data[item] = ia
self._data_valid = True
return self.data[item]

def __setitem__(self, item, val):
Expand Down Expand Up @@ -542,14 +561,14 @@ class CylindricalFixedResolutionBuffer(FixedResolutionBuffer):
that supports non-aligned input data objects, primarily cutting planes.
"""

def __init__(self, data_source, radius, buff_size, antialias=True):

def __init__(self, data_source, radius, buff_size, antialias=True, *, filters=None):
self.data_source = data_source
self.ds = data_source.ds
self.radius = radius
self.buff_size = buff_size
self.antialias = antialias
self.data = {}
self._filters = filters if filters is not None else []

ds = getattr(data_source, "ds", None)
if ds is not None:
Expand Down Expand Up @@ -579,12 +598,6 @@ class OffAxisProjectionFixedResolutionBuffer(FixedResolutionBuffer):
that supports off axis projections. This calls the volume renderer.
"""

def __init__(self, data_source, bounds, buff_size, antialias=True, periodic=False):
self.data = {}
FixedResolutionBuffer.__init__(
self, data_source, bounds, buff_size, antialias, periodic
)

def __getitem__(self, item):
if item in self.data:
return self.data[item]
Expand Down Expand Up @@ -631,10 +644,18 @@ class ParticleImageBuffer(FixedResolutionBuffer):
"""

def __init__(self, data_source, bounds, buff_size, antialias=True, periodic=False):
self.data = {}
FixedResolutionBuffer.__init__(
self, data_source, bounds, buff_size, antialias, periodic
def __init__(
self,
data_source,
bounds,
buff_size,
antialias=True,
periodic=False,
*,
filters=None,
):
super().__init__(
data_source, bounds, buff_size, antialias, periodic, filters=filters
)

# set up the axis field names
Expand Down
26 changes: 10 additions & 16 deletions yt/visualization/fixed_resolution_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

def apply_filter(f):
@wraps(f)
def newfunc(*args, **kwargs):
args[0]._filters.append((f.__name__, (args, kwargs)))
return args[0]
def newfunc(frb, *args, **kwargs):
frb._filters.append((f.__name__, (args, kwargs)))
# Invalidate the data of the frb to force its regeneration
frb._data_valid = False
return frb

return newfunc

Expand Down Expand Up @@ -50,20 +52,12 @@ def __init__(self, nbeam=30, sigma=2.0):
def apply(self, buff):
from yt.utilities.on_demand_imports import _scipy

hnbeam = self.nbeam // 2
sigma = self.sigma

l = np.linspace(-hnbeam, hnbeam, num=self.nbeam + 1)
x, y = np.meshgrid(l, l)
g2d = (1.0 / (sigma * np.sqrt(2.0 * np.pi))) * np.exp(
-((x / sigma) ** 2 + (y / sigma) ** 2) / (2 * sigma ** 2)
spl = _scipy.ndimage.gaussian_filter(
buff,
self.sigma,
truncate=self.nbeam / self.sigma,
)
g2d /= g2d.max()

npm, nqm = np.shape(buff)
spl = _scipy.signal.convolve(buff, g2d)

return spl[hnbeam : npm + hnbeam, hnbeam : nqm + hnbeam]
return spl


class FixedResolutionBufferWhiteNoiseFilter(FixedResolutionBufferFilter):
Expand Down
65 changes: 9 additions & 56 deletions yt/visualization/plot_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,67 +21,20 @@
from yt.utilities.definitions import formatted_length_unit_names
from yt.utilities.exceptions import YTNotInsideNotebook

from ._commons import DEFAULT_FONT_PROPERTIES, validate_image_name
from ._commons import (
DEFAULT_FONT_PROPERTIES,
invalidate_data,
invalidate_figure,
invalidate_plot,
validate_image_name,
validate_plot,
)

latex_prefixes = {
"u": r"\mu",
}


def invalidate_data(f):
@wraps(f)
def newfunc(*args, **kwargs):
rv = f(*args, **kwargs)
args[0]._data_valid = False
args[0]._plot_valid = False
return rv

return newfunc


def invalidate_figure(f):
@wraps(f)
def newfunc(*args, **kwargs):
rv = f(*args, **kwargs)
for field in args[0].plots.keys():
args[0].plots[field].figure = None
args[0].plots[field].axes = None
args[0].plots[field].cax = None
args[0]._setup_plots()
return rv

return newfunc


def invalidate_plot(f):
@wraps(f)
def newfunc(*args, **kwargs):
rv = f(*args, **kwargs)
args[0]._plot_valid = False
return rv

return newfunc


def validate_plot(f):
@wraps(f)
def newfunc(*args, **kwargs):
if hasattr(args[0], "_data_valid"):
if not args[0]._data_valid:
args[0]._recreate_frb()
if hasattr(args[0], "_profile_valid"):
if not args[0]._profile_valid:
args[0]._recreate_profile()
if not args[0]._plot_valid:
# it is the responsibility of _setup_plots to
# call args[0].run_callbacks()
args[0]._setup_plots()
rv = f(*args, **kwargs)
return rv

return newfunc


def apply_callback(f):
@wraps(f)
def newfunc(*args, **kwargs):
Expand Down Expand Up @@ -373,6 +326,7 @@ def _initialize_dataset(self, ts):
ts = DatasetSeries(ts)
return ts

@invalidate_data
def _switch_ds(self, new_ds, data_source=None):
old_object = self.data_source
name = old_object._type_name
Expand All @@ -399,7 +353,6 @@ def _switch_ds(self, new_ds, data_source=None):
new_object = getattr(new_ds, name)(**kwargs)

self.data_source = new_object
self._data_valid = self._plot_valid = False

for d in "xyz":
lim_name = d + "lim"
Expand Down
Loading

0 comments on commit 1623601

Please sign in to comment.