Skip to content

Commit

Permalink
Merge pull request #3827 from neutrinoceros/improve_center_sanitizing
Browse files Browse the repository at this point in the history
ENH: uniformize API for "center" argument for plot windows
  • Loading branch information
neutrinoceros authored Feb 20, 2023
2 parents 29fcdad + 3494048 commit a6dc72c
Show file tree
Hide file tree
Showing 19 changed files with 584 additions and 236 deletions.
5 changes: 5 additions & 0 deletions doc/source/visualizing/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,18 @@ If supplied without units, the center is assumed by in code units. There are al
the following alternative options for the ``center`` keyword:

* ``"center"``, ``"c"``: the domain center
* ``"left"``, ``"l"``, ``"right"`` ``"r"``: the domain's left/right edge along the normal direction
(``SlicePlot``'s second argument). Remaining axes use their respective domain center values.
* ``"min"``: the position of the minimum density
* ``"max"``, ``"m"``: the position of the maximum density
* ``"min/max_<field name>"``: the position of the minimum/maximum in the first field matching field name
* ``("min", field)``: the position of the minimum of ``field``
* ``("max", field)``: the position of the maximum of ``field``

where for the last two objects any spatial field, such as ``"density"``,
``"velocity_z"``,
etc., may be used, e.g. ``center=("min", ("gas", "temperature"))``.
``"left"`` and ``"right"`` are not allowed for off-axis slices.

The effective resolution of the plot (i.e. the number of resolution elements
in the image itself) can be controlled with the ``buff_size`` argument:
Expand Down
2 changes: 1 addition & 1 deletion nose_unit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ nologcapture=1
verbosity=2
where=yt
with-timer=1
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_set_zlim\.py|test_add_field\.py|test_glue\.py|test_geometries\.py|test_firefly\.py|test_callable_grids\.py|test_external_frontends\.py|test_stream_stretched\.py)
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_set_zlim\.py|test_add_field\.py|test_glue\.py|test_geometries\.py|test_firefly\.py|test_callable_grids\.py|test_external_frontends\.py|test_stream_stretched\.py|test_sanitize_center\.py)
exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF
1 change: 1 addition & 0 deletions tests/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ other_tests:
- "--ignore-file=test_callable_grids\\.py"
- "--ignore-file=test_external_frontends\\.py"
- "--ignore-file=test_stream_stretched\\.py"
- "--ignore-files=test_sanitize_center\\.py"
- "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF"
- "--exclude-test=yt.frontends.adaptahop.tests.test_outputs"
- "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles"
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _sq_field(field, data, fname: FieldKey):
self.ds.field_info.pop(field)
self.tree = tree

def to_pw(self, fields=None, center="c", width=None, origin="center-window"):
def to_pw(self, fields=None, center="center", width=None, origin="center-window"):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.
Expand Down
42 changes: 4 additions & 38 deletions yt/data_objects/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from yt.data_objects.profiles import create_profile
from yt.fields.field_exceptions import NeedsGridType
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, iter_fields, mylog
from yt.funcs import get_output_filename, iter_fields, mylog, parse_center_array
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.amr_kdtree.api import AMRKDTree
from yt.utilities.exceptions import (
YTCouldNotGenerateField,
Expand Down Expand Up @@ -177,43 +176,10 @@ def _set_center(self, center):
if center is None:
self.center = None
return
elif isinstance(center, YTArray):
self.center = self.ds.arr(center.astype("float64"))
self.center.convert_to_units("code_length")
elif isinstance(center, (list, tuple, np.ndarray)):
if isinstance(center[0], YTQuantity):
self.center = self.ds.arr([c.copy() for c in center], dtype="float64")
self.center.convert_to_units("code_length")
else:
self.center = self.ds.arr(center, "code_length", dtype="float64")
elif isinstance(center, str):
if center.lower() in ("c", "center"):
self.center = self.ds.domain_center
# is this dangerous for race conditions?
elif center.lower() in ("max", "m"):
self.center = self.ds.find_max(("gas", "density"))[1]
elif center.startswith("max_"):
field = self._first_matching_field(center[4:])
self.center = self.ds.find_max(field)[1]
elif center.lower() == "min":
self.center = self.ds.find_min(("gas", "density"))[1]
elif center.startswith("min_"):
field = self._first_matching_field(center[4:])
self.center = self.ds.find_min(field)[1]
else:
self.center = self.ds.arr(center, "code_length", dtype="float64")

if self.center.ndim > 1:
mylog.debug("Removing singleton dimensions from 'center'.")
self.center = np.squeeze(self.center)
if self.center.ndim > 1:
msg = (
"center array must be 1 dimensional, supplied center has "
f"{self.center.ndim} dimensions with shape {self.center.shape}."
)
raise YTException(msg)

self.set_field_parameter("center", self.center)
axis = getattr(self, "axis", None)
self.center = parse_center_array(center, ds=self.ds, axis=axis)
self.set_field_parameter("center", self.center)

def get_field_parameter(self, name, default=None):
"""
Expand Down
6 changes: 3 additions & 3 deletions yt/data_objects/selection_objects/slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _generate_container_field(self, field):
def _mrep(self):
return MinimalSliceData(self)

def to_pw(self, fields=None, center="c", width=None, origin="center-window"):
def to_pw(self, fields=None, center="center", width=None, origin="center-window"):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.
Expand Down Expand Up @@ -212,7 +212,7 @@ def __init__(
validate_object(ds, Dataset)
validate_object(field_parameters, dict)
validate_object(data_source, YTSelectionContainer)
YTSelectionContainer2D.__init__(self, 4, ds, field_parameters, data_source)
YTSelectionContainer2D.__init__(self, None, ds, field_parameters, data_source)
self._set_center(center)
self.set_field_parameter("center", center)
# Let's set up our plane equation
Expand Down Expand Up @@ -275,7 +275,7 @@ def _generate_container_field(self, field):
else:
raise KeyError(field)

def to_pw(self, fields=None, center="c", width=None, axes_unit=None):
def to_pw(self, fields=None, center="center", width=None, axes_unit=None):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.
Expand Down
13 changes: 13 additions & 0 deletions yt/data_objects/static_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

import numpy as np
import unyt as un
from more_itertools import unzip
from sympy import Symbol
from unyt import Unit, UnitSystem, unyt_quantity
Expand Down Expand Up @@ -2037,6 +2038,18 @@ def define_unit(self, symbol, value, tex_repr=None, offset=None, prefixable=Fals
registry=self.unit_registry,
)

def _is_within_domain(self, point) -> bool:
assert len(point) == len(self.domain_left_edge)
assert point.units.dimensions == un.dimensions.length
for i, x in enumerate(point):
if self.periodicity[i]:
continue
if x < self.domain_left_edge[i]:
return False
if x > self.domain_right_edge[i]:
return False
return True


def _reconstruct_ds(*args, **kwargs):
datasets = ParameterFileStore()
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/tests/test_cutting_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_cutting_plane():
fi = ds._get_field_info(cut_field)
data = frb[cut_field]
assert_equal(data.info["data_source"], cut.__str__())
assert_equal(data.info["axis"], 4)
assert_equal(data.info["axis"], None)
assert_equal(data.info["field"], str(cut_field))
assert_equal(data.units, Unit(fi.units))
assert_equal(data.info["xlim"], frb.bounds[:2])
Expand Down
138 changes: 136 additions & 2 deletions yt/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
import traceback
import urllib
from collections import UserDict
from copy import deepcopy
from functools import lru_cache, wraps
from numbers import Number as numeric_type
from typing import Any, Callable, Type
from typing import Any, Callable, Optional, Type

import numpy as np
from more_itertools import always_iterable, collapse, first
Expand All @@ -26,7 +27,7 @@
from yt._maintenance.deprecation import issue_deprecation_warning
from yt.config import ytcfg
from yt.units import YTArray, YTQuantity
from yt.utilities.exceptions import YTInvalidWidthError
from yt.utilities.exceptions import YTFieldNotFound, YTInvalidWidthError
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _requests as requests

Expand Down Expand Up @@ -1178,6 +1179,15 @@ def validate_field_key(key):
)


def is_valid_field_key(key):
try:
validate_field_key(key)
except TypeError:
return False
else:
return True


def validate_object(obj, data_type):
if obj is not None and not isinstance(obj, data_type):
raise TypeError(
Expand Down Expand Up @@ -1218,6 +1228,130 @@ def validate_center(center):
)


def parse_center_array(center, ds, axis: Optional[int] = None):
known_shortnames = {"m": "max", "c": "center", "l": "left", "r": "right"}
valid_single_str_values = ("center", "left", "right")
valid_field_loc_str_values = ("min", "max")
valid_str_values = valid_single_str_values + valid_field_loc_str_values
default_error_message = (
"Expected any of the following\n"
"- 'c', 'center', 'l', 'left', 'r', 'right', 'm', 'max', or 'min'\n"
"- a 2 element tuple with 'min' or 'max' as the first element, followed by a field identifier\n"
"- a 3 element array-like: for a unyt_array, expects length dimensions, otherwise code_lenght is assumed"
)
# store an unmodified copy of user input to be inserted in error messages
center_input = deepcopy(center)

if isinstance(center, str):
centerl = center.lower()
if centerl in known_shortnames:
centerl = known_shortnames[centerl]

match = re.match(r"^(?P<extremum>(min|max))(_(?P<field>[\w_]+))?", centerl)
if match is not None:
if match["field"] is not None:
for ftype, fname in ds.derived_field_list: # noqa: B007
if fname == match["field"]:
break
else:
raise YTFieldNotFound(match["field"], ds)
else:
ftype, fname = ("gas", "density")

center = (match["extremum"], (ftype, fname))

elif centerl in ("center", "left", "right"):
# domain_left_edge and domain_right_edge might not be
# initialized until we create the index, so create it
ds.index
center = ds.domain_center.copy()
if centerl in ("left", "right") and axis is None:
raise ValueError(f"center={center!r} is not valid with axis=None")
if centerl == "left":
center = ds.domain_center.copy()
center[axis] = ds.domain_left_edge[axis]
elif centerl == "right":
# note that the right edge of a grid is excluded by slice selector
# which is why we offset the region center by the smallest distance possible
center = ds.domain_center.copy()
center[axis] = (
ds.domain_right_edge[axis] - center.uq * np.finfo(center.dtype).eps
)

elif centerl not in valid_str_values:
raise ValueError(
f"Received unknown center single string value {center!r}. "
+ default_error_message
)

if is_sequence(center):
if (
len(center) == 2
and isinstance(center[0], str)
and (is_valid_field_key(center[1]) or isinstance(center[1], str))
):
center0l = center[0].lower()

if center0l not in valid_str_values:
raise ValueError(
f"Received unknown string value {center[0]!r}. "
f"Expected one of {valid_field_loc_str_values} (case insensitive)"
)
field_key = center[1]
if center0l == "min":
v, center = ds.find_min(field_key)
else:
assert center0l == "max"
v, center = ds.find_max(field_key)
center = ds.arr(center, "code_length")
elif len(center) == 2 and is_sequence(center[0]) and isinstance(center[1], str):
center = ds.arr(center[0], center[1])
elif len(center) == 3 and all(isinstance(_, YTQuantity) for _ in center):
center = ds.arr([c.copy() for c in center], dtype="float64")
elif len(center) == 3:
center = ds.arr(center, "code_length")

if isinstance(center, np.ndarray) and center.ndim > 1:
mylog.debug("Removing singleton dimensions from 'center'.")
center = np.squeeze(center)

if not isinstance(center, YTArray):
raise TypeError(
f"Received {center_input!r}, but failed to transform to a unyt_array (obtained {center!r}).\n"
+ default_error_message
+ "\n"
"If you supplied an expected type, consider filing a bug report"
)

if center.shape != (3,):
raise TypeError(
f"Received {center_input!r} and obtained {center!r} after sanitizing.\n"
+ default_error_message
+ "\n"
"If you supplied an expected type, consider filing a bug report"
)

# make sure the return value shares all
# unit symbols with ds.unit_registry
center = ds.arr(center)
# we rely on unyt to invalidate unit dimensionality here
center.convert_to_units("code_length")

if not ds._is_within_domain(center):
mylog.warning(
"Requested center at %s is outside of data domain with "
"left edge = %s, "
"right edge = %s, "
"periodicity = %s",
center,
ds.domain_left_edge,
ds.domain_right_edge,
ds.periodicity,
)

return center.astype("float64")


def sglob(pattern):
"""
Return the results of a glob through the sorted() function.
Expand Down
2 changes: 1 addition & 1 deletion yt/geometry/coordinates/cartesian_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def pixelize(
# re-order the array and squeeze out the dummy dim
return np.squeeze(np.transpose(img, (yax, xax, ax)))

elif self.axis_id.get(dimension, dimension) < 3:
elif self.axis_id.get(dimension, dimension) is not None:
return self._ortho_pixelize(
data_source, field, bounds, size, antialias, dimension, periodic
)
Expand Down
31 changes: 2 additions & 29 deletions yt/geometry/coordinates/coordinate_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from yt._typing import AxisOrder
from yt.funcs import fix_unitary, is_sequence, validate_width_tuple
from yt.funcs import fix_unitary, is_sequence, parse_center_array, validate_width_tuple
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.exceptions import YTCoordinateNotImplemented, YTInvalidWidthError

Expand Down Expand Up @@ -325,34 +325,7 @@ def sanitize_width(self, axis, width, depth):
return width

def sanitize_center(self, center, axis):
if isinstance(center, str):
if center.lower() == "m" or center.lower() == "max":
v, center = self.ds.find_max(("gas", "density"))
center = self.ds.arr(center, "code_length")
elif center.lower() == "c" or center.lower() == "center":
# domain_left_edge and domain_right_edge might not be
# initialized until we create the index, so create it
self.ds.index
center = (self.ds.domain_left_edge + self.ds.domain_right_edge) / 2
else:
raise RuntimeError(f'center keyword "{center}" not recognized')
elif isinstance(center, YTArray):
return self.ds.arr(center), self.convert_to_cartesian(center)
elif is_sequence(center):
if isinstance(center[0], str) and isinstance(center[1], str):
if center[0].lower() == "min":
v, center = self.ds.find_min(center[1])
elif center[0].lower() == "max":
v, center = self.ds.find_max(center[1])
else:
raise RuntimeError(f'center keyword "{center}" not recognized')
center = self.ds.arr(center, "code_length")
elif is_sequence(center[0]) and isinstance(center[1], str):
center = self.ds.arr(center[0], center[1])
else:
center = self.ds.arr(center, "code_length")
else:
raise RuntimeError(f'center keyword "{center}" not recognized')
center = parse_center_array(center, ds=self.ds, axis=axis)
# This has to return both a center and a display_center
display_center = self.convert_to_cartesian(center)
return center, display_center
Expand Down
Loading

0 comments on commit a6dc72c

Please sign in to comment.