diff --git a/yt/data_objects/construction_data_containers.py b/yt/data_objects/construction_data_containers.py index e70217f1a2b..661c3fdbbba 100644 --- a/yt/data_objects/construction_data_containers.py +++ b/yt/data_objects/construction_data_containers.py @@ -6,6 +6,7 @@ from functools import partial, wraps from re import finditer from tempfile import NamedTemporaryFile, TemporaryFile +from typing import Optional import numpy as np from more_itertools import always_iterable @@ -30,6 +31,7 @@ validate_moment, ) from yt.geometry import particle_deposit as particle_deposit +from yt.geometry.coordinates._axes_transforms import parse_axes_transform from yt.geometry.coordinates.cartesian_coordinates import all_data from yt.loaders import load_uniform_grid from yt.units._numpy_wrapper_functions import uconcatenate @@ -350,7 +352,15 @@ def _sq_field(field, data, fname: FieldKey): self.ds.field_info.pop(field) self.tree = tree - def to_pw(self, fields=None, center="center", width=None, origin="center-window"): + def to_pw( + self, + fields=None, + center="center", + width=None, + origin="center-window", + *, + axes_transform: Optional[str] = None, + ): r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this object. @@ -358,7 +368,10 @@ def to_pw(self, fields=None, center="center", width=None, origin="center-window" object, which can then be moved around, zoomed, and on and on. All behavior of the plot window is relegated to that routine. """ - pw = self._get_pw(fields, center, width, origin, "Projection") + _axt = parse_axes_transform(axes_transform) + pw = self._get_pw( + fields, center, width, origin, "Projection", axes_transform=_axt + ) return pw def plot(self, fields=None): diff --git a/yt/data_objects/selection_objects/data_selection_objects.py b/yt/data_objects/selection_objects/data_selection_objects.py index c965c34762d..b5dca97b559 100644 --- a/yt/data_objects/selection_objects/data_selection_objects.py +++ b/yt/data_objects/selection_objects/data_selection_objects.py @@ -18,6 +18,7 @@ from yt.fields.field_exceptions import NeedsGridType from yt.funcs import fix_axis, is_sequence, iter_fields, validate_width_tuple from yt.geometry.api import Geometry +from yt.geometry.coordinates._axes_transforms import AxesTransform from yt.geometry.selection_routines import compose_selector from yt.units import YTArray from yt.utilities.exceptions import ( @@ -530,20 +531,28 @@ def __init__(self, axis, ds, field_parameters=None, data_source=None): def _convert_field_name(self, field): return field - def _get_pw(self, fields, center, width, origin, plot_type): + def _get_pw( + self, fields, center, width, origin, plot_type, *, axes_transform: AxesTransform + ): from yt.visualization.fixed_resolution import FixedResolutionBuffer as frb from yt.visualization.plot_window import PWViewerMPL, get_window_parameters axis = self.axis skip = self._key_fields - skip += list(set(frb._exclude_fields).difference(set(self._key_fields))) + # this line works, but mypy incorrectly flags it, so turning it off locally + skip += list(set(frb._exclude_fields).difference(set(self._key_fields))) # type: ignore [arg-type] self.fields = [k for k in self.field_data if k not in skip] if fields is not None: self.fields = list(iter_fields(fields)) + self.fields if len(self.fields) == 0: raise ValueError("No fields found to plot in get_pw") + (bounds, center, display_center) = get_window_parameters( - axis, center, width, self.ds + axis, + center, + width, + self.ds, + axes_transform=axes_transform, ) pw = PWViewerMPL( self, diff --git a/yt/data_objects/selection_objects/slices.py b/yt/data_objects/selection_objects/slices.py index 69547f1c295..dc790115399 100644 --- a/yt/data_objects/selection_objects/slices.py +++ b/yt/data_objects/selection_objects/slices.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from yt.data_objects.selection_objects.data_selection_objects import ( @@ -15,6 +17,7 @@ validate_object, validate_width_tuple, ) +from yt.geometry.coordinates._axes_transforms import parse_axes_transform from yt.utilities.exceptions import YTNotInsideNotebook from yt.utilities.minimal_representation import MinimalSliceData from yt.utilities.orientation import Orientation @@ -104,7 +107,15 @@ def _generate_container_field(self, field): def _mrep(self): return MinimalSliceData(self) - def to_pw(self, fields=None, center="center", width=None, origin="center-window"): + def to_pw( + self, + fields=None, + center="center", + width=None, + origin="center-window", + *, + axes_transform: Optional[str] = None, + ): r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this object. @@ -112,7 +123,8 @@ def to_pw(self, fields=None, center="center", width=None, origin="center-window" object, which can then be moved around, zoomed, and on and on. All behavior of the plot window is relegated to that routine. """ - pw = self._get_pw(fields, center, width, origin, "Slice") + _axt = parse_axes_transform(axes_transform) + pw = self._get_pw(fields, center, width, origin, "Slice", axes_transform=_axt) return pw def plot(self, fields=None): diff --git a/yt/frontends/nc4_cm1/data_structures.py b/yt/frontends/nc4_cm1/data_structures.py index 2036f2c2e2d..92c83b8e767 100644 --- a/yt/frontends/nc4_cm1/data_structures.py +++ b/yt/frontends/nc4_cm1/data_structures.py @@ -97,6 +97,7 @@ def _setup_coordinate_handler(self, axis_order: Optional[AxisOrder]) -> None: # type checking is deactivated in the following two lines because changing them is not # within the scope of the PR that _enabled_ typechecking here (#4244), but it'd be worth # having a careful look at *why* these warnings appear, as they may point to rotten code + # TODO(4179): refactor this out self.coordinates._x_pairs = (("x", "y"), ("y", "x"), ("z", "x")) # type: ignore [union-attr] self.coordinates._y_pairs = (("x", "z"), ("y", "z"), ("z", "y")) # type: ignore [union-attr] diff --git a/yt/geometry/coordinates/_axes_transforms.py b/yt/geometry/coordinates/_axes_transforms.py new file mode 100644 index 00000000000..219f6ca4f81 --- /dev/null +++ b/yt/geometry/coordinates/_axes_transforms.py @@ -0,0 +1,22 @@ +import sys +from enum import auto +from typing import Optional + +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from yt._maintenance.backports import StrEnum + + +class AxesTransform(StrEnum): + DEFAULT = auto() + GEOMETRY_NATIVE = auto() + POLAR = auto() + AITOFF_HAMMER = auto() + + +def parse_axes_transform(axes_transform: Optional[str]) -> AxesTransform: + if axes_transform is None: + # pass the responsability to ds.coordinates + axes_transform = "default" + return AxesTransform(axes_transform) diff --git a/yt/geometry/coordinates/cartesian_coordinates.py b/yt/geometry/coordinates/cartesian_coordinates.py index 992e9addefe..8dd0048247c 100644 --- a/yt/geometry/coordinates/cartesian_coordinates.py +++ b/yt/geometry/coordinates/cartesian_coordinates.py @@ -18,8 +18,10 @@ from yt.utilities.math_utils import compute_stddev_image from yt.utilities.nodal_data_utils import get_nodal_data +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _get_vert_fields, cartesian_to_cylindrical, @@ -161,13 +163,29 @@ def _check_fields(self, registry): ) def pixelize( - self, dimension, data_source, field, bounds, size, antialias=True, periodic=True + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + periodic=True, + *, + axes_transform=AxesTransform.DEFAULT, ): """ Method for pixelizing datasets in preparation for two-dimensional image plots. Relies on several sampling routines written in cython """ + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"cartesian coordinates don't implement {axes_transform} yet" + ) index = data_source.ds.index if hasattr(index, "meshes") and not isinstance( index.meshes[0], SemiStructuredMesh @@ -624,3 +642,45 @@ def convert_from_spherical(self, coord): @property def period(self): return self.ds.domain_width + + def _get_plot_axes_default_properties( + self, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"cartesian coordinates don't implement {axes_transform} yet" + ) + + if normal_axis_name == "x": + return { + "x_axis_label": "y", + "y_axis_label": "z", + "x_axis_units": None, + "y_axis_units": None, + } + elif normal_axis_name == "y": + return { + "x_axis_label": "z", + "y_axis_label": "x", + "x_axis_units": None, + "y_axis_units": None, + } + elif normal_axis_name == "z": + return { + "x_axis_label": "x", + "y_axis_label": "y", + "x_axis_units": None, + "y_axis_units": None, + } + elif normal_axis_name == "oblique": + return { + "x_axis_label": "Image x", + "y_axis_label": "Image y", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise ValueError(f"Unknown axis {normal_axis_name!r}") diff --git a/yt/geometry/coordinates/coordinate_handler.py b/yt/geometry/coordinates/coordinate_handler.py index d579a6209e8..e97a8b059e1 100644 --- a/yt/geometry/coordinates/coordinate_handler.py +++ b/yt/geometry/coordinates/coordinate_handler.py @@ -2,14 +2,27 @@ import weakref from functools import cached_property from numbers import Number -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple, TypedDict import numpy as np +from yt._maintenance.deprecation import issue_deprecation_warning from yt._typing import AxisOrder from yt.funcs import fix_unitary, is_sequence, parse_center_array, validate_width_tuple from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.exceptions import YTCoordinateNotImplemented, YTInvalidWidthError +from yt.utilities.lib.pixelization_routines import pixelize_cartesian + +from ._axes_transforms import AxesTransform + + +class DefaultProperties(TypedDict): + x_axis_label: str + y_axis_label: str + # note that an empty string maps to "dimensionless", + # while None means "figure it out yourself" + x_axis_units: Optional[str] + y_axis_units: Optional[str] def _unknown_coord(field, data): @@ -133,6 +146,7 @@ def validate_sequence_width(width, ds, unit=None): class CoordinateHandler(abc.ABC): name: str _default_axis_order: AxisOrder + _default_axes_transforms: Dict[str, AxesTransform] def __init__(self, ds, ordering: Optional[AxisOrder] = None): self.ds = weakref.proxy(ds) @@ -147,7 +161,17 @@ def setup_fields(self): pass @abc.abstractmethod - def pixelize(self, dimension, data_source, field, bounds, size, antialias=True): + def pixelize( + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + *, + axes_transform=AxesTransform.DEFAULT, + ): # This should *actually* be a pixelize call, not just returning the # pixelizer pass @@ -185,8 +209,15 @@ def convert_to_spherical(self, coord): def convert_from_spherical(self, coord): pass + @abc.abstractmethod + def _get_plot_axes_default_properties( + self, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + ... + @cached_property def data_projection(self): + # see https://github.com/yt-project/yt/issues/4182 return {ax: None for ax in self.axis_order} @cached_property @@ -211,6 +242,12 @@ def axis_id(self): @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + stacklevel=3, + ) rv = {} for i in range(3): rv[i] = (self.axis_name[self.x_axis[i]], self.axis_name[self.y_axis[i]]) @@ -253,7 +290,16 @@ def sanitize_depth(self, depth): raise YTInvalidWidthError(depth) return depth - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"generic coordinate handler doesn't implement {axes_transform}" + ) if width is None: # initialize the index if it is not already initialized self.ds.index @@ -284,12 +330,46 @@ def sanitize_width(self, axis, width, depth): return width + depth return width - def sanitize_center(self, center, axis): + def _get_display_center(self, center, axes_transform: AxesTransform): + # default implementation + return self.convert_to_cartesian(center) + + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"generic coordinate handler doesn't implement {axes_transform}" + ) center = parse_center_array(center, ds=self.ds, axis=axis) # This has to return both a center and a display_center - display_center = self.convert_to_cartesian(center) + display_center = self._get_display_center(center, axes_transform) return center, display_center + def _ortho_pixelize( + self, data_source, field, bounds, size, antialias, dim, periodic + ): + period = self.period[:2].copy() # dummy here + period[0] = self.period[self.x_axis[dim]] + period[1] = self.period[self.y_axis[dim]] + if hasattr(period, "in_units"): + period = period.in_units("code_length").d + buff = np.full(size, np.nan, dtype="float64") + pixelize_cartesian( + buff, + data_source["px"], + data_source["py"], + data_source["pdx"], + data_source["pdy"], + data_source[field], + bounds, + int(antialias), + period, + int(periodic), + ) + return buff + def cartesian_to_cylindrical(coord, center=(0, 0, 0)): c2 = np.zeros_like(coord) diff --git a/yt/geometry/coordinates/cylindrical_coordinates.py b/yt/geometry/coordinates/cylindrical_coordinates.py index 32fd0ffff93..fe1e2576ce7 100644 --- a/yt/geometry/coordinates/cylindrical_coordinates.py +++ b/yt/geometry/coordinates/cylindrical_coordinates.py @@ -1,11 +1,15 @@ from functools import cached_property +from typing import Dict import numpy as np -from yt.utilities.lib.pixelization_routines import pixelize_cartesian, pixelize_cylinder +from yt._maintenance.deprecation import issue_deprecation_warning +from yt.utilities.lib.pixelization_routines import pixelize_cylinder +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _get_polar_bounds, _setup_dummy_cartesian_coords_and_widths, @@ -22,13 +26,14 @@ class CylindricalCoordinateHandler(CoordinateHandler): name = "cylindrical" _default_axis_order = ("r", "z", "theta") + _default_axes_transforms: Dict[str, AxesTransform] = { + "r": AxesTransform.GEOMETRY_NATIVE, + "theta": AxesTransform.GEOMETRY_NATIVE, + "z": AxesTransform.POLAR, + } def __init__(self, ds, ordering=None): super().__init__(ds, ordering) - self.image_units = {} - self.image_units[self.axis_id["r"]] = ("rad", None) - self.image_units[self.axis_id["theta"]] = (None, None) - self.image_units[self.axis_id["z"]] = (None, None) def setup_fields(self, registry): # Missing implementation for x and y coordinates. @@ -87,17 +92,35 @@ def pixelize( size, antialias=True, periodic=False, + *, + axes_transform=AxesTransform.DEFAULT, ): # Note that above, we set periodic by default to be *false*. This is # because our pixelizers, at present, do not handle periodicity # correctly, and if you change the "width" of a cylindrical plot, it # double-counts in the edge buffers. See, for instance, issue 1669. - ax_name = self.axis_name[dimension] - if ax_name in ("r", "theta"): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: return self._ortho_pixelize( data_source, field, bounds, size, antialias, dimension, periodic ) - elif ax_name == "z": + + name = self.axis_name[dimension] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + + if name in ("r", "theta"): + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) + return self._ortho_pixelize( + data_source, field, bounds, size, antialias, dimension, periodic + ) + elif name == "z": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) # This is admittedly a very hacky way to resolve a bug # it's very likely that the *right* fix would have to # be applied upstream of this function, *but* this case @@ -113,29 +136,6 @@ def pixelize( def pixelize_line(self, field, start_point, end_point, npoints): raise NotImplementedError - def _ortho_pixelize( - self, data_source, field, bounds, size, antialias, dim, periodic - ): - period = self.period[:2].copy() # dummy here - period[0] = self.period[self.x_axis[dim]] - period[1] = self.period[self.y_axis[dim]] - if hasattr(period, "in_units"): - period = period.in_units("code_length").d - buff = np.full(size, np.nan, dtype="float64") - pixelize_cartesian( - buff, - data_source["px"], - data_source["py"], - data_source["pdx"], - data_source["pdy"], - data_source[field], - bounds, - int(antialias), - period, - int(periodic), - ) - return buff - def _cyl_pixelize(self, data_source, field, bounds, size, antialias): buff = np.full((size[1], size[0]), np.nan, dtype="f8") pixelize_cylinder( @@ -149,13 +149,18 @@ def _cyl_pixelize(self, data_source, field, bounds, size, antialias): ) return buff - _x_pairs = (("r", "theta"), ("z", "r"), ("theta", "r")) - _y_pairs = (("r", "z"), ("z", "theta"), ("theta", "z")) + _x_pairs = (("r", "theta"), ("z", "r"), ("theta", "r")) # deprecated + _y_pairs = (("r", "z"), ("z", "theta"), ("theta", "z")) # deprecated - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) if self._image_axis_name is not None: return self._image_axis_name # This is the x and y axes labels that get displayed. For @@ -198,47 +203,144 @@ def period(self): def _polar_bounds(self): return _get_polar_bounds(self, axes=("r", "theta")) - def sanitize_center(self, center, axis): - center, display_center = super().sanitize_center(center, axis) + def _get_display_center(self, center, axes_transform): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return center + else: + return super()._get_display_center(center, axes_transform) + + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_center(center, axis, axes_transform=axes_transform) + + name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + center, display_center = super().sanitize_center( + center, axis, axes_transform=AxesTransform.GEOMETRY_NATIVE + ) display_center = [ 0.0 * display_center[0], 0.0 * display_center[1], 0.0 * display_center[2], ] - ax_name = self.axis_name[axis] r_ax = self.axis_id["r"] theta_ax = self.axis_id["theta"] z_ax = self.axis_id["z"] - if ax_name == "r": + if name == "r": + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) display_center[theta_ax] = self.ds.domain_center[theta_ax] display_center[z_ax] = self.ds.domain_center[z_ax] - elif ax_name == "theta": + elif name == "theta": + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) # use existing center value for idx in (r_ax, z_ax): display_center[idx] = center[idx] - elif ax_name == "z": + elif name == "z": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xxmin, xxmax, yymin, yymax = self._polar_bounds xc = (xxmin + xxmax) / 2 yc = (yymin + yymax) / 2 display_center = (xc, yc, 0 * xc) + else: + RuntimeError(f"Unknown axis name {name!r} for cylindrical coordinates") return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + r_ax, theta_ax, z_ax = ( self.ds.coordinates.axis_id[ax] for ax in ("r", "theta", "z") ) if width is not None: - width = super().sanitize_width(axis, width, depth) + width = super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) # Note: regardless of axes, these are set up to give consistent plots # when plotted, which is not strictly a "right hand rule" for axes. elif name == "r": # soup can label + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) width = [self.ds.domain_width[theta_ax], self.ds.domain_width[z_ax]] elif name == "theta": + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) width = [self.ds.domain_width[r_ax], self.ds.domain_width[z_ax]] elif name == "z": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xxmin, xxmax, yymin, yymax = self._polar_bounds xw = xxmax - xxmin yw = yymax - yymin width = [xw, yw] return width + + def _get_plot_axes_default_properties( + self, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[normal_axis_name] + + if normal_axis_name == "r": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": r"\theta", + "y_axis_label": "z", + "x_axis_units": "rad", + "y_axis_units": None, + } + else: + raise NotImplementedError + elif normal_axis_name == "theta": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": "R", + "y_axis_label": "z", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise NotImplementedError + elif normal_axis_name == "z": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": "R", + "y_axis_label": r"\theta", + "x_axis_units": None, + "y_axis_units": "rad", + } + elif axes_transform is AxesTransform.POLAR: + return { + "x_axis_label": "x", + "y_axis_label": "y", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise NotImplementedError + else: + raise ValueError(f"Unknown axis name {normal_axis_name!r}") diff --git a/yt/geometry/coordinates/geographic_coordinates.py b/yt/geometry/coordinates/geographic_coordinates.py index e8c30389adc..ec029c72b70 100644 --- a/yt/geometry/coordinates/geographic_coordinates.py +++ b/yt/geometry/coordinates/geographic_coordinates.py @@ -1,9 +1,14 @@ +from typing import Dict + import numpy as np +from yt._maintenance.deprecation import issue_deprecation_warning from yt.utilities.lib.pixelization_routines import pixelize_cartesian, pixelize_cylinder +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _setup_dummy_cartesian_coords_and_widths, ) @@ -12,15 +17,16 @@ class GeographicCoordinateHandler(CoordinateHandler): radial_axis = "altitude" name = "geographic" + _default_axes_transforms: Dict[str, AxesTransform] = { + "latitude": AxesTransform.POLAR, + "longitude": AxesTransform.POLAR, + "altitude": AxesTransform.AITOFF_HAMMER, + } def __init__(self, ds, ordering=None): if ordering is None: ordering = ("latitude", "longitude", self.radial_axis) super().__init__(ds, ordering) - self.image_units = {} - self.image_units[self.axis_id["latitude"]] = (None, None) - self.image_units[self.axis_id["longitude"]] = (None, None) - self.image_units[self.axis_id[self.radial_axis]] = ("deg", "deg") def setup_fields(self, registry): # Missing implementation for x, y and z coordinates. @@ -214,13 +220,29 @@ def _retrieve_radial_offset(self, data_source=None): return surface_height, 1.0 def pixelize( - self, dimension, data_source, field, bounds, size, antialias=True, periodic=True + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + periodic=True, + *, + axes_transform=AxesTransform.DEFAULT, ): - if self.axis_name[dimension] in ("latitude", "longitude"): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super()._ortho_pixelize( + data_source, field, bounds, size, antialias, dimension, periodic + ) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + name = self.axis_name[dimension] + if name in ("latitude", "longitude"): return self._cyl_pixelize( data_source, field, bounds, size, antialias, dimension ) - elif self.axis_name[dimension] == self.radial_axis: + elif name == self.radial_axis: return self._ortho_pixelize( data_source, field, bounds, size, antialias, dimension, periodic ) @@ -326,10 +348,15 @@ def convert_to_spherical(self, coord): def convert_from_spherical(self, coord): raise NotImplementedError - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) if self._image_axis_name is not None: return self._image_axis_name # This is the x and y axes labels that get displayed. For @@ -349,13 +376,13 @@ def image_axis_name(self): self._image_axis_name = rv return rv - _x_pairs = ( + _x_pairs = ( # deprecated ("latitude", "longitude"), ("longitude", "latitude"), ("altitude", "longitude"), ) - _y_pairs = ( + _y_pairs = ( # deprecated ("latitude", "altitude"), ("longitude", "altitude"), ("altitude", "latitude"), @@ -365,6 +392,7 @@ def image_axis_name(self): @property def data_projection(self): + # see https://github.com/yt-project/yt/issues/4182 # this will control the default projection to use when displaying data if self._data_projection is not None: return self._data_projection @@ -397,7 +425,12 @@ def data_transform(self): def period(self): return self.ds.domain_width - def sanitize_center(self, center, axis): + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_center(center, axis, axes_transform=axes_transform) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + center, display_center = super().sanitize_center(center, axis) name = self.axis_name[axis] if name == self.radial_axis: @@ -419,7 +452,16 @@ def sanitize_center(self, center, axis): display_center[self.axis_id["latitude"]] = c return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + name = self.axis_name[axis] if width is not None: width = super().sanitize_width(axis, width, depth) @@ -439,10 +481,86 @@ def sanitize_width(self, axis, width, depth): width = [self.ds.domain_width[ri], 2.0 * self.ds.domain_width[ri]] return width + def _get_plot_axes_default_properties( + self, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[normal_axis_name] + + if normal_axis_name == "latitude": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": r"longitude", + "y_axis_label": r"R", + "x_axis_units": "deg", + "y_axis_units": None, + } + elif axes_transform is AxesTransform.POLAR: + return { + "x_axis_label": r"x / \sin(\mathrm{latitude})", + "y_axis_label": r"y / \sin(\mathrm{latitude})", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise NotImplementedError + elif normal_axis_name == "longitude": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": "latitude", + "y_axis_label": "R", + "x_axis_units": "deg", + "y_axis_units": None, + } + elif axes_transform is AxesTransform.POLAR: + return { + "x_axis_label": "R", + "y_axis_label": "z", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise NotImplementedError + elif normal_axis_name == self.radial_axis: + # TODO(4179): either clean this or refactor something elsewhere, + # because it is currently never used: + # this case is deleguated to cartopy, + # though we have everything needed to do it ourselves ! + # see https://github.com/yt-project/yt/issues/4182 + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": "longitude", + "y_axis_label": "latitude", + "x_axis_units": "deg", + "y_axis_units": "deg", + } + elif axes_transform is AxesTransform.AITOFF_HAMMER: + return { + "x_axis_label": ( + r"\frac{2\cos(\mathrm{\mathrm{latitude}})\sin(\mathrm{longitude}/2)}" + r"{\sqrt{1 + \cos(\mathrm{latitude}) \cos(\mathrm{longitude}/2)}}" + ), + "y_axis_label": ( + r"\frac{sin(\mathrm{latitude})}" + r"{\sqrt{1 + \cos(\mathrm{latitude}) \cos(\mathrm{longitude}/2)}}" + ), + "x_axis_units": "dimensionless", + "y_axis_units": "dimensionless", + } + else: + raise NotImplementedError + else: + raise ValueError(f"Unknown axis {normal_axis_name!r}") + class InternalGeographicCoordinateHandler(GeographicCoordinateHandler): radial_axis = "depth" name = "internal_geographic" + _default_axes_transforms: Dict[str, AxesTransform] = { + "latitude": AxesTransform.POLAR, + "longitude": AxesTransform.POLAR, + "depth": AxesTransform.AITOFF_HAMMER, + } def _setup_radial_fields(self, registry): # Altitude is the radius from the central zone minus the radius of the @@ -488,9 +606,18 @@ def _retrieve_radial_offset(self, data_source=None): ("depth", "longitude"), ) - _y_pairs = (("latitude", "depth"), ("longitude", "depth"), ("depth", "latitude")) + _y_pairs = ( + ("latitude", "depth"), + ("longitude", "depth"), + ("depth", "latitude"), + ) + + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_center(center, axis, axes_transform=axes_transform) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError - def sanitize_center(self, center, axis): center, display_center = super( GeographicCoordinateHandler, self ).sanitize_center(center, axis) @@ -515,12 +642,19 @@ def sanitize_center(self, center, axis): display_center[self.axis_id["latitude"]] = outermost / 2.0 return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + name = self.axis_name[axis] if width is not None: - width = super(GeographicCoordinateHandler, self).sanitize_width( - axis, width, depth - ) + width = super().sanitize_width(axis, width, depth) elif name == self.radial_axis: rax = self.radial_axis width = [ diff --git a/yt/geometry/coordinates/spec_cube_coordinates.py b/yt/geometry/coordinates/spec_cube_coordinates.py index 7275b60f736..b1eb2f197c7 100644 --- a/yt/geometry/coordinates/spec_cube_coordinates.py +++ b/yt/geometry/coordinates/spec_cube_coordinates.py @@ -1,9 +1,22 @@ -from .cartesian_coordinates import CartesianCoordinateHandler +from typing import Dict + +from yt._maintenance.deprecation import issue_deprecation_warning + +from ._axes_transforms import AxesTransform +from .cartesian_coordinates import ( + CartesianCoordinateHandler, + DefaultProperties, +) from .coordinate_handler import _get_coord_fields class SpectralCubeCoordinateHandler(CartesianCoordinateHandler): name = "spectral_cube" + _default_axes_transforms: Dict[str, AxesTransform] = { + "x": AxesTransform.GEOMETRY_NATIVE, + "y": AxesTransform.GEOMETRY_NATIVE, + "z": AxesTransform.GEOMETRY_NATIVE, + } def __init__(self, ds, ordering=None): if ordering is None: @@ -12,28 +25,6 @@ def __init__(self, ds, ordering=None): ) super().__init__(ds, ordering) - self.default_unit_label = {} - names = {} - if ds.lon_name != "X" or ds.lat_name != "Y": - names["x"] = r"Image\ x" - names["y"] = r"Image\ y" - # We can just use ds.lon_axis here - self.default_unit_label[ds.lon_axis] = "pixel" - self.default_unit_label[ds.lat_axis] = "pixel" - names["z"] = ds.spec_name - # Again, can use spec_axis here - self.default_unit_label[ds.spec_axis] = ds.spec_unit - - self._image_axis_name = ian = {} - for ax in "xyz": - axi = self.axis_id[ax] - xax = self.axis_name[self.x_axis[ax]] - yax = self.axis_name[self.y_axis[ax]] - ian[axi] = ian[ax] = ian[ax.upper()] = ( - names.get(xax, xax), - names.get(yax, yax), - ) - def _spec_axis(ax, x, y): p = (x, y)[ax] return [self.ds.pixel2spec(pp).v for pp in p] @@ -94,4 +85,73 @@ def convert_from_cylindrical(self, coord): @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) return self._image_axis_name + + def _get_plot_axes_default_properties( + self, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"spectral cube coordinates don't implement {axes_transform} yet" + ) + + if normal_axis_name == "x": + if self.ds.lon_name == "X" and self.ds.lat_name == "Y": + return { + "x_axis_label": "x", + "y_axis_label": self.ds.spec_name, + "x_axis_units": None, + "y_axis_units": None, + } + else: + return { + "x_axis_label": r"Image\ x", + "y_axis_label": self.ds.spec_name, + "x_axis_units": "pixel", + "y_axis_units": None, + } + elif normal_axis_name == "y": + { + "x_axis_label": "x", + "y_axis_label": self.ds.spec_name, + "x_axis_units": None, + "y_axis_units": None, + } + if self.ds.lon_name == "X" and self.ds.lat_name == "Y": + return { + "x_axis_label": "x", + "y_axis_label": self.ds.spec_name, + "x_axis_units": None, + "y_axis_units": None, + } + else: + return { + "x_axis_label": r"Image\ x", + "y_axis_label": self.ds.spec_name, + "x_axis_units": "pixel", + "y_axis_units": None, + } + elif normal_axis_name == "z": + if self.ds.lon_name == "X" and self.ds.lat_name == "Y": + return { + "x_axis_label": "x", + "y_axis_label": "y", + "x_axis_units": None, + "y_axis_units": None, + } + else: + return { + "x_axis_label": r"Image\ x", + "y_axis_label": r"Image\ y", + "x_axis_units": "pixel", + "y_axis_units": "pixel", + } + else: + raise ValueError(f"Unknown axis {normal_axis_name!r}") diff --git a/yt/geometry/coordinates/spherical_coordinates.py b/yt/geometry/coordinates/spherical_coordinates.py index ee98e48b2f6..2de5aedcbff 100644 --- a/yt/geometry/coordinates/spherical_coordinates.py +++ b/yt/geometry/coordinates/spherical_coordinates.py @@ -1,11 +1,15 @@ from functools import cached_property +from typing import Dict import numpy as np +from yt._maintenance.deprecation import issue_deprecation_warning from yt.utilities.lib.pixelization_routines import pixelize_aitoff, pixelize_cylinder +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _get_polar_bounds, _setup_dummy_cartesian_coords_and_widths, @@ -16,14 +20,14 @@ class SphericalCoordinateHandler(CoordinateHandler): name = "spherical" _default_axis_order = ("r", "theta", "phi") + _default_axes_transforms: Dict[str, AxesTransform] = { + "r": AxesTransform.AITOFF_HAMMER, + "theta": AxesTransform.POLAR, + "phi": AxesTransform.POLAR, + } def __init__(self, ds, ordering=None): super().__init__(ds, ordering) - # Generate - self.image_units = {} - self.image_units[self.axis_id["r"]] = (1, 1) - self.image_units[self.axis_id["theta"]] = (None, None) - self.image_units[self.axis_id["phi"]] = (None, None) def setup_fields(self, registry): # Missing implementation for x, y and z coordinates. @@ -82,15 +86,40 @@ def _path_phi(field, data): ) def pixelize( - self, dimension, data_source, field, bounds, size, antialias=True, periodic=True + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + periodic=True, + *, + axes_transform=AxesTransform.DEFAULT, ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return self._ortho_pixelize( + data_source, field, bounds, size, antialias, dimension, periodic + ) + self.period name = self.axis_name[dimension] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + if name == "r": - return self._ortho_pixelize( + if axes_transform is not AxesTransform.AITOFF_HAMMER: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) + return self._aitoff_hammer_pixelize( data_source, field, bounds, size, antialias, dimension, periodic ) elif name in ("theta", "phi"): + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) if name == "theta": # This is admittedly a very hacky way to resolve a bug # it's very likely that the *right* fix would have to @@ -108,7 +137,7 @@ def pixelize( def pixelize_line(self, field, start_point, end_point, npoints): raise NotImplementedError - def _ortho_pixelize( + def _aitoff_hammer_pixelize( self, data_source, field, bounds, size, antialias, dim, periodic ): # use Aitoff projection @@ -193,10 +222,15 @@ def convert_to_spherical(self, coord): def convert_from_spherical(self, coord): raise NotImplementedError - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) if self._image_axis_name is not None: return self._image_axis_name # This is the x and y axes labels that get displayed. For @@ -226,18 +260,73 @@ def image_axis_name(self): _x_pairs = (("r", "theta"), ("theta", "r"), ("phi", "r")) _y_pairs = (("r", "phi"), ("theta", "phi"), ("phi", "theta")) + def _get_plot_axes_default_properties( + self, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[normal_axis_name] + + if normal_axis_name == "r": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": r"\theta", + "y_axis_label": r"\phi", + "x_axis_units": "rad", + "y_axis_units": "rad", + } + elif axes_transform is AxesTransform.AITOFF_HAMMER: + return { + "x_axis_label": r"\frac{2\cos(\mathrm{\bar{\theta}})\sin(\lambda/2)}{\sqrt{1 + \cos(\bar{\theta}) \cos(\lambda/2)}}", + "y_axis_label": r"\frac{sin(\bar{\theta})}{\sqrt{1 + \cos(\bar{\theta}) \cos(\lambda/2)}}", + "x_axis_units": "dimensionless", + "y_axis_units": "dimensionless", + } + else: + raise NotImplementedError + elif normal_axis_name == "theta": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": "r", + "y_axis_label": r"\phi", + "x_axis_units": None, + "y_axis_units": "rad", + } + elif axes_transform is AxesTransform.POLAR: + return { + "x_axis_label": r"x / \sin(\theta)", + "y_axis_label": r"y / \sin(\theta)", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise NotImplementedError + elif normal_axis_name == "phi": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return { + "x_axis_label": "r", + "y_axis_label": r"\theta", + "x_axis_units": None, + "y_axis_units": "rad", + } + elif axes_transform is AxesTransform.POLAR: + return { + "x_axis_label": "R", + "y_axis_label": "z", + "x_axis_units": None, + "y_axis_units": None, + } + else: + raise NotImplementedError + else: + raise ValueError(f"Unknown axis name {normal_axis_name!r}") + @property def period(self): return self.ds.domain_width @cached_property def _poloidal_bounds(self): - ri = self.axis_id["r"] - ti = self.axis_id["theta"] - rmin = self.ds.domain_left_edge[ri] - rmax = self.ds.domain_right_edge[ri] - thetamin = self.ds.domain_left_edge[ti] - thetamax = self.ds.domain_right_edge[ti] + rmin, rmax, thetamin, thetamax, _, _ = self._r_theta_phi_bounds corners = [ (rmin, thetamin), (rmin, thetamax), @@ -273,6 +362,25 @@ def to_poloidal_plane(r, theta): def _conic_bounds(self): return _get_polar_bounds(self, axes=("r", "phi")) + @cached_property + def _r_theta_phi_bounds(self): + # radius + ri = self.axis_id["r"] + rmin = self.ds.domain_left_edge[ri] + rmax = self.ds.domain_right_edge[ri] + + # colatitude + ti = self.axis_id["theta"] + thetamin = self.ds.domain_left_edge[ti] + thetamax = self.ds.domain_right_edge[ti] + + # azimuth + pi = self.axis_id["phi"] + phimin = self.ds.domain_left_edge[pi] + phimax = self.ds.domain_right_edge[pi] + + return rmin, rmax, thetamin, thetamax, phimin, phimax + @cached_property def _aitoff_bounds(self): # at the time of writing this function, yt's support for curvilinear @@ -281,18 +389,12 @@ def _aitoff_bounds(self): # this is not needed but calls for a large refactor. ONE = self.ds.quan(1, "code_length") - # colatitude - ti = self.axis_id["theta"] - thetamin = self.ds.domain_left_edge[ti] - thetamax = self.ds.domain_right_edge[ti] + _, _, thetamin, thetamax, phimin, phimax = self._r_theta_phi_bounds + # latitude latmin = ONE * np.pi / 2 - thetamax latmax = ONE * np.pi / 2 - thetamin - # azimuth - pi = self.axis_id["phi"] - phimin = self.ds.domain_left_edge[pi] - phimax = self.ds.domain_right_edge[pi] # longitude lonmin = phimin - ONE * np.pi lonmax = phimax - ONE * np.pi @@ -341,36 +443,79 @@ def to_aitoff_plane(latitude, longitude): return xmin, xmax, ymin, ymax - def sanitize_center(self, center, axis): + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): center, display_center = super().sanitize_center(center, axis) name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + if name == "r": - xxmin, xxmax, yymin, yymax = self._aitoff_bounds + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + _, _, xxmin, xxmax, yymin, yymax = self._r_theta_phi_bounds + elif axes_transform is AxesTransform.AITOFF_HAMMER: + xxmin, xxmax, yymin, yymax = self._aitoff_bounds + else: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xc = (xxmin + xxmax) / 2 yc = (yymin + yymax) / 2 display_center = (0 * xc, xc, yc) + elif name == "theta": - xxmin, xxmax, yymin, yymax = self._conic_bounds + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + xxmin, xxmax, _, _, yymin, yymax = self._r_theta_phi_bounds + elif axes_transform is AxesTransform.POLAR: + xxmin, xxmax, yymin, yymax = self._conic_bounds + else: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xc = (xxmin + xxmax) / 2 yc = (yymin + yymax) / 2 display_center = (xc, 0 * xc, yc) elif name == "phi": - Rmin, Rmax, zmin, zmax = self._poloidal_bounds - xc = (Rmin + Rmax) / 2 - yc = (zmin + zmax) / 2 + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + xxmin, xxmax, yymin, yymax, _, _ = self._r_theta_phi_bounds + elif axes_transform is AxesTransform.POLAR: + xxmin, xxmax, yymin, yymax = self._poloidal_bounds + else: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) + xc = (xxmin + xxmax) / 2 + yc = (yymin + yymax) / 2 display_center = (xc, yc) return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + if width is not None: width = super().sanitize_width(axis, width, depth) elif name == "r": + if axes_transform is not AxesTransform.AITOFF_HAMMER: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xxmin, xxmax, yymin, yymax = self._aitoff_bounds xw = xxmax - xxmin yw = yymax - yymin width = [xw, yw] elif name == "theta": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) # Remember, in spherical coordinates when we cut in theta, # we create a conic section xxmin, xxmax, yymin, yymax = self._conic_bounds @@ -378,6 +523,10 @@ def sanitize_width(self, axis, width, depth): yw = yymax - yymin width = [xw, yw] elif name == "phi": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) Rmin, Rmax, zmin, zmax = self._poloidal_bounds xw = Rmax - Rmin yw = zmax - zmin diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py index 83326e3daf0..b07a4bc6e90 100644 --- a/yt/visualization/fixed_resolution.py +++ b/yt/visualization/fixed_resolution.py @@ -9,6 +9,7 @@ from yt.data_objects.image_array import ImageArray from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, iter_fields, mylog +from yt.geometry.coordinates._axes_transforms import AxesTransform from yt.loaders import load_uniform_grid from yt.utilities.lib.api import ( # type: ignore CICDeposit_2, @@ -107,6 +108,7 @@ def __init__( periodic=False, *, filters: Optional[List["FixedResolutionBufferFilter"]] = None, + axes_transform: AxesTransform = AxesTransform.DEFAULT, ): self.data_source = data_source self.ds = data_source.ds @@ -117,6 +119,7 @@ def __init__( self.axis = data_source.axis self.periodic = periodic self._data_valid = False + self._axes_transform = axes_transform # import type here to avoid import cycles # note that this import statement is actually crucial at runtime: @@ -174,6 +177,7 @@ def __getitem__(self, item): bounds, self.buff_size, int(self.antialias), + axes_transform=self._axes_transform, ) buff = self._apply_filters(buff) @@ -680,9 +684,16 @@ def __init__( periodic=False, *, filters=None, + axes_transform=AxesTransform.DEFAULT, ): super().__init__( - data_source, bounds, buff_size, antialias, periodic, filters=filters + data_source, + bounds, + buff_size, + antialias, + periodic, + filters=filters, + axes_transform=axes_transform, ) # set up the axis field names diff --git a/yt/visualization/line_plot.py b/yt/visualization/line_plot.py index 2d44176138c..97907234e95 100644 --- a/yt/visualization/line_plot.py +++ b/yt/visualization/line_plot.py @@ -334,12 +334,12 @@ def _setup_plots(self): # set x and y axis labels axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y) - if self._xlabel is not None: + if self._xlabel: x_label = self._xlabel else: x_label = r"$\rm{Path\ Length" + axes_unit_labels[0] + "}$" - if self._ylabel is not None: + if self._ylabel: y_label = self._ylabel else: finfo = self.ds.field_info[field] diff --git a/yt/visualization/particle_plots.py b/yt/visualization/particle_plots.py index dc2fee71c49..c137ec1c59c 100644 --- a/yt/visualization/particle_plots.py +++ b/yt/visualization/particle_plots.py @@ -5,6 +5,7 @@ from yt.data_objects.profiles import create_profile from yt.data_objects.static_output import Dataset from yt.funcs import fix_axis, iter_fields +from yt.geometry.coordinates._axes_transforms import AxesTransform from yt.units.yt_array import YTArray from yt.visualization.fixed_resolution import ParticleImageBuffer from yt.visualization.profile_plotter import PhasePlot @@ -250,7 +251,11 @@ def __init__( ds = self.ds = ts[0] axis = fix_axis(axis, ds) (bounds, center, display_center) = get_window_parameters( - axis, center, width, ds + axis, + center, + width, + ds, + axes_transform=AxesTransform.DEFAULT, ) if field_parameters is None: field_parameters = {} diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index f2220a7cb3f..1b0435129c5 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -705,22 +705,6 @@ def _get_axes_unit_labels(self, unit_x, unit_y): comoving = False hinv = False for i, un in enumerate((unit_x, unit_y)): - unn = None - if hasattr(self.data_source, "axis"): - if hasattr(self.ds.coordinates, "image_units"): - # This *forces* an override - unn = self.ds.coordinates.image_units[self.data_source.axis][i] - elif hasattr(self.ds.coordinates, "default_unit_label"): - axax = getattr(self.ds.coordinates, f"{'xy'[i]}_axis")[ - self.data_source.axis - ] - unn = self.ds.coordinates.default_unit_label.get(axax, None) - if unn in (1, "1", "dimensionless"): - axes_unit_labels[i] = "" - continue - if unn is not None: - axes_unit_labels[i] = _get_units_label(unn).strip("$") - continue # Use sympy to factor h out of the unit. In this context 'un' # is a string, so we call the Unit constructor. expr = Unit(un, registry=self.ds.unit_registry).expr @@ -741,7 +725,7 @@ def _get_axes_unit_labels(self, unit_x, unit_y): # It doesn't make sense to scale a position by anything # other than h**-1 raise RuntimeError - if un not in ["1", "u", "unitary"]: + if un not in ["dimensionless", "1", "u", "unitary"]: if un in formatted_length_unit_names: un = formatted_length_unit_names[un] else: diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 5c33be9c107..1da5e901b78 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -22,6 +22,7 @@ validate_moment, ) from yt.geometry.api import Geometry +from yt.geometry.coordinates._axes_transforms import AxesTransform, parse_axes_transform from yt.units.unit_object import Unit # type: ignore from yt.units.unit_registry import UnitParseError # type: ignore from yt.units.yt_array import YTArray, YTQuantity @@ -67,9 +68,15 @@ from yt._maintenance.backports import zip -def get_window_parameters(axis, center, width, ds): - width = ds.coordinates.sanitize_width(axis, width, None) - center, display_center = ds.coordinates.sanitize_center(center, axis) +def get_window_parameters( + axis, center, width, ds, axes_transform: AxesTransform = AxesTransform.DEFAULT +): + width = ds.coordinates.sanitize_width( + axis, width, None, axes_transform=axes_transform + ) + center, display_center = ds.coordinates.sanitize_center( + center, axis, axes_transform=axes_transform + ) xax = ds.coordinates.x_axis[axis] yax = ds.coordinates.y_axis[axis] bounds = ( @@ -102,6 +109,8 @@ def get_axes_unit(width, ds): r""" Infers the axes unit names from the input width specification """ + if width is None: + return None if ds.no_cgs_equiv_length: return ("code_length",) * 2 if is_sequence(width): @@ -191,6 +200,7 @@ def __init__( setup=False, *, geometry: Geometry = Geometry.CARTESIAN, + axes_transform: AxesTransform = AxesTransform.DEFAULT, ) -> None: # axis manipulation operations are callback-only: self._swap_axes_input = False @@ -204,8 +214,12 @@ def __init__( self.buff_size = buff_size self.antialias = antialias self._axes_unit_names = None + + # TODO(4179): handle compat with _transform and _projection ? + # see https://github.com/yt-project/yt/issues/4182 self._transform = None self._projection = None + self._axes_transform = axes_transform self.aspect = aspect skip = list(FixedResolutionBuffer._exclude_fields) + data_source._key_fields @@ -235,11 +249,12 @@ def __init__( self.origin = origin if self.data_source.center is not None and not oblique: + # see https://github.com/yt-project/yt/issues/4182 ax = self.data_source.axis xax = self.ds.coordinates.x_axis[ax] yax = self.ds.coordinates.y_axis[ax] center, display_center = self.ds.coordinates.sanitize_center( - self.data_source.center, ax + self.data_source.center, ax, axes_transform=self._axes_transform ) center = [display_center[xax], display_center[yax]] self.set_center(center) @@ -329,6 +344,7 @@ def _recreate_frb(self): self.antialias, periodic=self._periodic, filters=old_filters, + axes_transform=self._axes_transform, ) # At this point the frb has the valid bounds, size, aliasing, etc. @@ -691,8 +707,9 @@ def set_width(self, width, unit=None): axes_unit = get_axes_unit(width, self.ds) - width = self.ds.coordinates.sanitize_width(self.frb.axis, width, None) - + width = self.ds.coordinates.sanitize_width( + self.frb.axis, width, None, axes_transform=self._axes_transform + ) centerx = (self.xlim[1] + self.xlim[0]) / 2.0 centery = (self.ylim[1] + self.ylim[0]) / 2.0 @@ -866,7 +883,8 @@ def __init__(self, *args, **kwargs) -> None: if self._plot_type is None: self._plot_type = kwargs.pop("plot_type") self._splat_color = kwargs.pop("splat_color", None) - PlotWindow.__init__(self, *args, **kwargs) + self._frb: Optional[FixedResolutionBuffer] = None + super().__init__(*args, **kwargs) # import type here to avoid import cycles # note that this import statement is actually crucial at runtime: @@ -1012,62 +1030,73 @@ def _setup_plots(self): self._recreate_frb() self._colorbar_valid = True field_list = list(set(self.data_source._determine_fields(self.fields))) - for f in field_list: - axis_index = self.data_source.axis + coordinates = self.ds.coordinates + normal_axis_index = self.data_source.axis + + if self.oblique: + normal_axis_name = "oblique" + else: + normal_axis_name = coordinates.axis_name[normal_axis_index] + + default_plot_properties = coordinates._get_plot_axes_default_properties( + normal_axis_name, self._axes_transform + ) + + for f in field_list: xc, yc = self._setup_origin() - if self.ds._uses_code_length_unit: - # this should happen only if the dataset was initialized with - # argument unit_system="code" or if it's set to have no CGS - # equivalent. This only needs to happen here in the specific - # case that we're doing a computationally intense operation - # like using cartopy, but it prevents crashes in that case. - (unit_x, unit_y) = ("code_length", "code_length") - elif self._axes_unit_names is None: - unit = self.ds.get_smallest_appropriate_unit( - self.xlim[1] - self.xlim[0] - ) - unit_x = unit_y = unit - coords = self.ds.coordinates - if hasattr(coords, "image_units"): - # check for special cases defined in - # non cartesian CoordinateHandler subclasses - image_units = coords.image_units[coords.axis_id[axis_index]] - if image_units[0] in ("deg", "rad"): + if self._axes_unit_names is None: + unit_x = default_plot_properties["x_axis_units"] + unit_y = default_plot_properties["y_axis_units"] + if unit_x is None: + if self.ds._uses_code_length_unit: unit_x = "code_length" - elif image_units[0] == 1: - unit_x = "dimensionless" - if image_units[1] in ("deg", "rad"): + else: + unit_x = self.ds.get_smallest_appropriate_unit( + self.xlim[1] - self.xlim[0] + ) + if unit_y is None: + if self.ds._uses_code_length_unit: unit_y = "code_length" - elif image_units[1] == 1: - unit_y = "dimensionless" + else: + unit_y = self.ds.get_smallest_appropriate_unit( + self.ylim[1] - self.ylim[0] + ) else: (unit_x, unit_y) = self._axes_unit_names - # For some plots we may set aspect by hand, such as for spectral cube data. - # This will likely be replaced at some point by the coordinate handler - # setting plot aspect. - if self.aspect is None: - self.aspect = float( - (self.ds.quan(1.0, unit_y) / self.ds.quan(1.0, unit_x)).in_cgs() - ) extentx = (self.xlim - xc)[:2] extenty = (self.ylim - yc)[:2] # extentx/y arrays inherit units from xlim and ylim attributes # and these attributes are always length even for angular and # dimensionless axes so we need to strip out units for consistency - if unit_x == "dimensionless": + if unit_x in ("dimensionless", "rad", "deg"): extentx = extentx / extentx.units else: extentx.convert_to_units(unit_x) - if unit_y == "dimensionless": + if unit_y in ("dimensionless", "rad", "deg"): extenty = extenty / extenty.units else: extenty.convert_to_units(unit_y) extent = [*extentx, *extenty] + # For some plots we may set aspect by hand, such as for spectral cube data. + # This will likely be replaced at some point by the coordinate handler + # setting plot aspect. + if self.aspect is None: + ratio = (self.ds.quan(1.0, unit_y) / self.ds.quan(1.0, unit_x)).in_cgs() + if ratio.units.is_dimensionless: + self.aspect = float(ratio) + else: + # maybe we have length on the x axis and radians on the y axis + # in that case, it doesn't make much sense to impose a 1 to 1 ratio + # so instead we set the image to be a square by default + self.aspect = float( + ((extentx[1] - extentx[0]) / (extenty[1] - extenty[0])).value + ) + image = self.frb[f] font_size = self._font_properties.get_size() @@ -1134,30 +1163,31 @@ def _setup_plots(self): colorbar_handler=cbh, ) - axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y) + # construct default axes labels + x_unit_label, y_unit_label = self._get_axes_unit_labels(unit_x, unit_y) if self.oblique: labels = [ - r"$\rm{Image\ x" + axes_unit_labels[0] + "}$", - r"$\rm{Image\ y" + axes_unit_labels[1] + "}$", + r"$\rm{Image\ x" + x_unit_label + "}$", + r"$\rm{Image\ y" + y_unit_label + "}$", ] else: - coordinates = self.ds.coordinates - axis_names = coordinates.image_axis_name[axis_index] - xax = coordinates.x_axis[axis_index] - yax = coordinates.y_axis[axis_index] - - if hasattr(coordinates, "axis_default_unit_name"): - axes_unit_labels = [ - coordinates.axis_default_unit_name[xax], - coordinates.axis_default_unit_name[yax], - ] labels = [ - r"$\rm{" + axis_names[0] + axes_unit_labels[0] + r"}$", - r"$\rm{" + axis_names[1] + axes_unit_labels[1] + r"}$", + r"$\rm{" + + default_plot_properties["x_axis_label"] + + x_unit_label + + r"}$", + r"$\rm{" + + default_plot_properties["y_axis_label"] + + y_unit_label + + r"}$", ] if hasattr(coordinates, "axis_field"): + # this is exclusive to spectral_cube geometries + xax = coordinates.x_axis[normal_axis_index] + yax = coordinates.y_axis[normal_axis_index] + if xax in coordinates.axis_field: xmin, xmax = coordinates.axis_field[xax]( 0, self.xlim, self.ylim @@ -1178,9 +1208,9 @@ def _setup_plots(self): x_label, y_label, colorbar_label = self._get_axes_labels(f) - if x_label is not None: + if x_label: labels[0] = x_label - if y_label is not None: + if y_label: labels[1] = y_label if swap_axes: @@ -1193,7 +1223,7 @@ def _setup_plots(self): units = Unit(self.frb[f].units, registry=self.ds.unit_registry) units = units.latex_representation() - if colorbar_label is None: + if not colorbar_label: colorbar_label = image.info["label"] if getattr(self, "moment", 1) == 2: colorbar_label = "%s \\rm{Standard Deviation}" % colorbar_label @@ -1804,6 +1834,7 @@ def __init__( buff_size=(800, 800), *, north_vector=None, + axes_transform: Optional[str] = None, ): if north_vector is not None: # this kwarg exists only for symmetry reasons with OffAxisSlicePlot @@ -1814,10 +1845,17 @@ def __init__( del north_vector normal = self.sanitize_normal_vector(ds, normal) + + _axt = parse_axes_transform(axes_transform) + # this will handle time series data and controllers axis = fix_axis(normal, ds) (bounds, center, display_center) = get_window_parameters( - axis, center, width, ds + axis, + center, + width, + ds, + axes_transform=_axt, ) if field_parameters is None: field_parameters = {} @@ -1837,8 +1875,7 @@ def __init__( ) slc.get_data(fields) validate_mesh_fields(slc, fields) - PWViewerMPL.__init__( - self, + super().__init__( slc, bounds, origin=origin, @@ -1848,6 +1885,7 @@ def __init__( aspect=aspect, buff_size=buff_size, geometry=ds.geometry, + axes_transform=_axt, ) if axes_unit is None: axes_unit = get_axes_unit(width, ds) @@ -2033,6 +2071,7 @@ def __init__( aspect=None, *, moment=1, + axes_transform: Optional[str] = None, ): if method == "mip": issue_deprecation_warning( @@ -2043,12 +2082,18 @@ def __init__( method = "max" normal = self.sanitize_normal_vector(ds, normal) + _axt = parse_axes_transform(axes_transform) + axis = fix_axis(normal, ds) # If a non-weighted integral projection, assure field-label reflects that if weight_field is None and method == "integrate": self.projected = True (bounds, center, display_center) = get_window_parameters( - axis, center, width, ds + axis, + center, + width, + ds, + axes_transform=_axt, ) if field_parameters is None: field_parameters = {} @@ -2094,6 +2139,7 @@ def __init__( aspect=aspect, buff_size=buff_size, geometry=ds.geometry, + axes_transform=_axt, ) if axes_unit is None: axes_unit = get_axes_unit(width, ds) @@ -2560,6 +2606,8 @@ def plot_2d( window_size=8.0, aspect=None, data_source=None, + *, + axes_transform: Optional[str] = None, ) -> AxisAlignedSlicePlot: r"""Creates a plot of a 2D dataset @@ -2710,4 +2758,5 @@ def plot_2d( window_size=window_size, aspect=aspect, data_source=data_source, + axes_transform=axes_transform, )