diff --git a/yt/data_objects/construction_data_containers.py b/yt/data_objects/construction_data_containers.py index 016ca47f246..b693b182ef3 100644 --- a/yt/data_objects/construction_data_containers.py +++ b/yt/data_objects/construction_data_containers.py @@ -6,6 +6,7 @@ from functools import partial, wraps from re import finditer from tempfile import NamedTemporaryFile, TemporaryFile +from typing import Optional import numpy as np from more_itertools import always_iterable @@ -30,6 +31,7 @@ validate_moment, ) from yt.geometry import particle_deposit as particle_deposit +from yt.geometry.coordinates._axes_transforms import parse_axes_transform from yt.geometry.coordinates.cartesian_coordinates import all_data from yt.loaders import load_uniform_grid from yt.units._numpy_wrapper_functions import uconcatenate @@ -350,7 +352,15 @@ def _sq_field(field, data, fname: FieldKey): self.ds.field_info.pop(field) self.tree = tree - def to_pw(self, fields=None, center="center", width=None, origin="center-window"): + def to_pw( + self, + fields=None, + center="center", + width=None, + origin="center-window", + *, + axes_transform: Optional[str] = None, + ): r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this object. @@ -358,7 +368,10 @@ def to_pw(self, fields=None, center="center", width=None, origin="center-window" object, which can then be moved around, zoomed, and on and on. All behavior of the plot window is relegated to that routine. """ - pw = self._get_pw(fields, center, width, origin, "Projection") + _axt = parse_axes_transform(axes_transform) + pw = self._get_pw( + fields, center, width, origin, "Projection", axes_transform=_axt + ) return pw def plot(self, fields=None): diff --git a/yt/data_objects/selection_objects/data_selection_objects.py b/yt/data_objects/selection_objects/data_selection_objects.py index c965c34762d..b5dca97b559 100644 --- a/yt/data_objects/selection_objects/data_selection_objects.py +++ b/yt/data_objects/selection_objects/data_selection_objects.py @@ -18,6 +18,7 @@ from yt.fields.field_exceptions import NeedsGridType from yt.funcs import fix_axis, is_sequence, iter_fields, validate_width_tuple from yt.geometry.api import Geometry +from yt.geometry.coordinates._axes_transforms import AxesTransform from yt.geometry.selection_routines import compose_selector from yt.units import YTArray from yt.utilities.exceptions import ( @@ -530,20 +531,28 @@ def __init__(self, axis, ds, field_parameters=None, data_source=None): def _convert_field_name(self, field): return field - def _get_pw(self, fields, center, width, origin, plot_type): + def _get_pw( + self, fields, center, width, origin, plot_type, *, axes_transform: AxesTransform + ): from yt.visualization.fixed_resolution import FixedResolutionBuffer as frb from yt.visualization.plot_window import PWViewerMPL, get_window_parameters axis = self.axis skip = self._key_fields - skip += list(set(frb._exclude_fields).difference(set(self._key_fields))) + # this line works, but mypy incorrectly flags it, so turning it off locally + skip += list(set(frb._exclude_fields).difference(set(self._key_fields))) # type: ignore [arg-type] self.fields = [k for k in self.field_data if k not in skip] if fields is not None: self.fields = list(iter_fields(fields)) + self.fields if len(self.fields) == 0: raise ValueError("No fields found to plot in get_pw") + (bounds, center, display_center) = get_window_parameters( - axis, center, width, self.ds + axis, + center, + width, + self.ds, + axes_transform=axes_transform, ) pw = PWViewerMPL( self, diff --git a/yt/data_objects/selection_objects/slices.py b/yt/data_objects/selection_objects/slices.py index 69547f1c295..dc790115399 100644 --- a/yt/data_objects/selection_objects/slices.py +++ b/yt/data_objects/selection_objects/slices.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from yt.data_objects.selection_objects.data_selection_objects import ( @@ -15,6 +17,7 @@ validate_object, validate_width_tuple, ) +from yt.geometry.coordinates._axes_transforms import parse_axes_transform from yt.utilities.exceptions import YTNotInsideNotebook from yt.utilities.minimal_representation import MinimalSliceData from yt.utilities.orientation import Orientation @@ -104,7 +107,15 @@ def _generate_container_field(self, field): def _mrep(self): return MinimalSliceData(self) - def to_pw(self, fields=None, center="center", width=None, origin="center-window"): + def to_pw( + self, + fields=None, + center="center", + width=None, + origin="center-window", + *, + axes_transform: Optional[str] = None, + ): r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this object. @@ -112,7 +123,8 @@ def to_pw(self, fields=None, center="center", width=None, origin="center-window" object, which can then be moved around, zoomed, and on and on. All behavior of the plot window is relegated to that routine. """ - pw = self._get_pw(fields, center, width, origin, "Slice") + _axt = parse_axes_transform(axes_transform) + pw = self._get_pw(fields, center, width, origin, "Slice", axes_transform=_axt) return pw def plot(self, fields=None): diff --git a/yt/frontends/nc4_cm1/data_structures.py b/yt/frontends/nc4_cm1/data_structures.py index 2036f2c2e2d..9cb31c237d3 100644 --- a/yt/frontends/nc4_cm1/data_structures.py +++ b/yt/frontends/nc4_cm1/data_structures.py @@ -1,11 +1,9 @@ import os import weakref from collections import OrderedDict -from typing import Optional import numpy as np -from yt._typing import AxisOrder from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch from yt.data_objects.static_output import Dataset from yt.geometry.grid_geometry_handler import GridIndex @@ -89,16 +87,12 @@ def __init__( ) self.storage_filename = storage_filename - def _setup_coordinate_handler(self, axis_order: Optional[AxisOrder]) -> None: + def _setup_coordinate_handler(self): # ensure correct ordering of axes so plots aren't rotated (z should always be # on the vertical axis). - super()._setup_coordinate_handler(axis_order) - - # type checking is deactivated in the following two lines because changing them is not - # within the scope of the PR that _enabled_ typechecking here (#4244), but it'd be worth - # having a careful look at *why* these warnings appear, as they may point to rotten code - self.coordinates._x_pairs = (("x", "y"), ("y", "x"), ("z", "x")) # type: ignore [union-attr] - self.coordinates._y_pairs = (("x", "z"), ("y", "z"), ("z", "y")) # type: ignore [union-attr] + super()._setup_coordinate_handler() + self.coordinates._x_pairs = (("x", "y"), ("y", "x"), ("z", "x")) + self.coordinates._y_pairs = (("x", "z"), ("y", "z"), ("z", "y")) def _set_code_unit_attributes(self): # This is where quantities are created that represent the various diff --git a/yt/geometry/coordinates/_axes_transforms.py b/yt/geometry/coordinates/_axes_transforms.py new file mode 100644 index 00000000000..89c581d1e7e --- /dev/null +++ b/yt/geometry/coordinates/_axes_transforms.py @@ -0,0 +1,23 @@ +from enum import Enum, auto +from typing import Optional + + +class AxesTransform(Enum): + DEFAULT = auto() + GEOMETRY_NATIVE = auto() + POLAR = auto() + AITOFF_HAMMER = auto() + + +def parse_axes_transform(axes_transform: Optional[str]) -> AxesTransform: + if axes_transform is None: + # pass the responsability to ds.coordinates + return AxesTransform.DEFAULT + elif axes_transform == "geometry_native": + return AxesTransform.GEOMETRY_NATIVE + elif axes_transform == "polar": + return AxesTransform.POLAR + elif axes_transform == "aitoff_hammer": + return AxesTransform.AITOFF_HAMMER + else: + raise ValueError(f"Unknown axes transform {axes_transform!r}") diff --git a/yt/geometry/coordinates/cartesian_coordinates.py b/yt/geometry/coordinates/cartesian_coordinates.py index 992e9addefe..405b8b05fc8 100644 --- a/yt/geometry/coordinates/cartesian_coordinates.py +++ b/yt/geometry/coordinates/cartesian_coordinates.py @@ -18,8 +18,10 @@ from yt.utilities.math_utils import compute_stddev_image from yt.utilities.nodal_data_utils import get_nodal_data +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _get_vert_fields, cartesian_to_cylindrical, @@ -161,13 +163,29 @@ def _check_fields(self, registry): ) def pixelize( - self, dimension, data_source, field, bounds, size, antialias=True, periodic=True + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + periodic=True, + *, + axes_transform=AxesTransform.DEFAULT, ): """ Method for pixelizing datasets in preparation for two-dimensional image plots. Relies on several sampling routines written in cython """ + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"cartesian coordinates don't implement {axes_transform} yet" + ) index = data_source.ds.index if hasattr(index, "meshes") and not isinstance( index.meshes[0], SemiStructuredMesh @@ -624,3 +642,46 @@ def convert_from_spherical(self, coord): @property def period(self): return self.ds.domain_width + + @classmethod + def _get_plot_axes_default_properties( + cls, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"cartesian coordinates don't implement {axes_transform} yet" + ) + + if normal_axis_name == "x": + return dict( + x_axis_label="y", + y_axis_label="z", + x_axis_units=None, + y_axis_units=None, + ) + elif normal_axis_name == "y": + return dict( + x_axis_label="z", + y_axis_label="x", + x_axis_units=None, + y_axis_units=None, + ) + elif normal_axis_name == "z": + return dict( + x_axis_label="x", + y_axis_label="y", + x_axis_units=None, + y_axis_units=None, + ) + elif normal_axis_name == "oblique": + return dict( + x_axis_label="Image x", + y_axis_label="Image y", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise ValueError(f"Unknown axis {normal_axis_name!r}") diff --git a/yt/geometry/coordinates/coordinate_handler.py b/yt/geometry/coordinates/coordinate_handler.py index 17a0b42c21f..fe927e33f74 100644 --- a/yt/geometry/coordinates/coordinate_handler.py +++ b/yt/geometry/coordinates/coordinate_handler.py @@ -1,14 +1,27 @@ import abc import weakref from numbers import Number -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple, TypedDict import numpy as np +from yt._maintenance.deprecation import issue_deprecation_warning from yt._typing import AxisOrder from yt.funcs import fix_unitary, is_sequence, parse_center_array, validate_width_tuple from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.exceptions import YTCoordinateNotImplemented, YTInvalidWidthError +from yt.utilities.lib.pixelization_routines import pixelize_cartesian + +from ._axes_transforms import AxesTransform + + +class DefaultProperties(TypedDict): + x_axis_label: str + y_axis_label: str + # note that an empty string maps to "dimensionless", + # while None means "figure it out yourself" + x_axis_units: Optional[str] + y_axis_units: Optional[str] def _unknown_coord(field, data): @@ -132,6 +145,7 @@ def validate_sequence_width(width, ds, unit=None): class CoordinateHandler(abc.ABC): name: str _default_axis_order: AxisOrder + _default_axes_transforms: Dict[str, AxesTransform] def __init__(self, ds, ordering: Optional[AxisOrder] = None): self.ds = weakref.proxy(ds) @@ -146,7 +160,17 @@ def setup_fields(self): pass @abc.abstractmethod - def pixelize(self, dimension, data_source, field, bounds, size, antialias=True): + def pixelize( + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + *, + axes_transform=AxesTransform.DEFAULT, + ): # This should *actually* be a pixelize call, not just returning the # pixelizer pass @@ -184,10 +208,18 @@ def convert_to_spherical(self, coord): def convert_from_spherical(self, coord): pass + @classmethod + @abc.abstractmethod + def _get_plot_axes_default_properties( + cls, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + ... + _data_projection = None @property def data_projection(self): + # see https://github.com/yt-project/yt/issues/4182 if self._data_projection is not None: return self._data_projection dpj = {} @@ -234,10 +266,16 @@ def axis_id(self): self._axis_id = ai return ai - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + stacklevel=3, + ) # Default if self._image_axis_name is not None: return self._image_axis_name @@ -293,7 +331,16 @@ def sanitize_depth(self, depth): raise YTInvalidWidthError(depth) return depth - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"generic coordinate handler doesn't implement {axes_transform}" + ) if width is None: # initialize the index if it is not already initialized self.ds.index @@ -324,12 +371,46 @@ def sanitize_width(self, axis, width, depth): return width + depth return width - def sanitize_center(self, center, axis): + def _get_display_center(self, center, axes_transform: AxesTransform): + # default implementation + return self.convert_to_cartesian(center) + + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"generic coordinate handler doesn't implement {axes_transform}" + ) center = parse_center_array(center, ds=self.ds, axis=axis) # This has to return both a center and a display_center - display_center = self.convert_to_cartesian(center) + display_center = self._get_display_center(center, axes_transform) return center, display_center + def _ortho_pixelize( + self, data_source, field, bounds, size, antialias, dim, periodic + ): + period = self.period[:2].copy() # dummy here + period[0] = self.period[self.x_axis[dim]] + period[1] = self.period[self.y_axis[dim]] + if hasattr(period, "in_units"): + period = period.in_units("code_length").d + buff = np.full(size, np.nan, dtype="float64") + pixelize_cartesian( + buff, + data_source["px"], + data_source["py"], + data_source["pdx"], + data_source["pdy"], + data_source[field], + bounds, + int(antialias), + period, + int(periodic), + ) + return buff + def cartesian_to_cylindrical(coord, center=(0, 0, 0)): c2 = np.zeros_like(coord) diff --git a/yt/geometry/coordinates/cylindrical_coordinates.py b/yt/geometry/coordinates/cylindrical_coordinates.py index 32fd0ffff93..cc3928991d7 100644 --- a/yt/geometry/coordinates/cylindrical_coordinates.py +++ b/yt/geometry/coordinates/cylindrical_coordinates.py @@ -1,11 +1,15 @@ from functools import cached_property +from typing import Dict import numpy as np -from yt.utilities.lib.pixelization_routines import pixelize_cartesian, pixelize_cylinder +from yt._maintenance.deprecation import issue_deprecation_warning +from yt.utilities.lib.pixelization_routines import pixelize_cylinder +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _get_polar_bounds, _setup_dummy_cartesian_coords_and_widths, @@ -22,13 +26,14 @@ class CylindricalCoordinateHandler(CoordinateHandler): name = "cylindrical" _default_axis_order = ("r", "z", "theta") + _default_axes_transforms: Dict[str, AxesTransform] = { + "r": AxesTransform.GEOMETRY_NATIVE, + "theta": AxesTransform.GEOMETRY_NATIVE, + "z": AxesTransform.POLAR, + } def __init__(self, ds, ordering=None): super().__init__(ds, ordering) - self.image_units = {} - self.image_units[self.axis_id["r"]] = ("rad", None) - self.image_units[self.axis_id["theta"]] = (None, None) - self.image_units[self.axis_id["z"]] = (None, None) def setup_fields(self, registry): # Missing implementation for x and y coordinates. @@ -87,17 +92,35 @@ def pixelize( size, antialias=True, periodic=False, + *, + axes_transform=AxesTransform.DEFAULT, ): # Note that above, we set periodic by default to be *false*. This is # because our pixelizers, at present, do not handle periodicity # correctly, and if you change the "width" of a cylindrical plot, it # double-counts in the edge buffers. See, for instance, issue 1669. - ax_name = self.axis_name[dimension] - if ax_name in ("r", "theta"): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: return self._ortho_pixelize( data_source, field, bounds, size, antialias, dimension, periodic ) - elif ax_name == "z": + + name = self.axis_name[dimension] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + + if name in ("r", "theta"): + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) + return self._ortho_pixelize( + data_source, field, bounds, size, antialias, dimension, periodic + ) + elif name == "z": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) # This is admittedly a very hacky way to resolve a bug # it's very likely that the *right* fix would have to # be applied upstream of this function, *but* this case @@ -113,29 +136,6 @@ def pixelize( def pixelize_line(self, field, start_point, end_point, npoints): raise NotImplementedError - def _ortho_pixelize( - self, data_source, field, bounds, size, antialias, dim, periodic - ): - period = self.period[:2].copy() # dummy here - period[0] = self.period[self.x_axis[dim]] - period[1] = self.period[self.y_axis[dim]] - if hasattr(period, "in_units"): - period = period.in_units("code_length").d - buff = np.full(size, np.nan, dtype="float64") - pixelize_cartesian( - buff, - data_source["px"], - data_source["py"], - data_source["pdx"], - data_source["pdy"], - data_source[field], - bounds, - int(antialias), - period, - int(periodic), - ) - return buff - def _cyl_pixelize(self, data_source, field, bounds, size, antialias): buff = np.full((size[1], size[0]), np.nan, dtype="f8") pixelize_cylinder( @@ -149,13 +149,18 @@ def _cyl_pixelize(self, data_source, field, bounds, size, antialias): ) return buff - _x_pairs = (("r", "theta"), ("z", "r"), ("theta", "r")) - _y_pairs = (("r", "z"), ("z", "theta"), ("theta", "z")) + _x_pairs = (("r", "theta"), ("z", "r"), ("theta", "r")) # deprecated + _y_pairs = (("r", "z"), ("z", "theta"), ("theta", "z")) # deprecated - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) if self._image_axis_name is not None: return self._image_axis_name # This is the x and y axes labels that get displayed. For @@ -198,47 +203,145 @@ def period(self): def _polar_bounds(self): return _get_polar_bounds(self, axes=("r", "theta")) - def sanitize_center(self, center, axis): - center, display_center = super().sanitize_center(center, axis) + def _get_display_center(self, center, axes_transform): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return center + else: + return super()._get_display_center(center, axes_transform) + + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_center(center, axis, axes_transform=axes_transform) + + name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + center, display_center = super().sanitize_center( + center, axis, axes_transform=AxesTransform.GEOMETRY_NATIVE + ) display_center = [ 0.0 * display_center[0], 0.0 * display_center[1], 0.0 * display_center[2], ] - ax_name = self.axis_name[axis] r_ax = self.axis_id["r"] theta_ax = self.axis_id["theta"] z_ax = self.axis_id["z"] - if ax_name == "r": + if name == "r": + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) display_center[theta_ax] = self.ds.domain_center[theta_ax] display_center[z_ax] = self.ds.domain_center[z_ax] - elif ax_name == "theta": + elif name == "theta": + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) # use existing center value for idx in (r_ax, z_ax): display_center[idx] = center[idx] - elif ax_name == "z": + elif name == "z": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xxmin, xxmax, yymin, yymax = self._polar_bounds xc = (xxmin + xxmax) / 2 yc = (yymin + yymax) / 2 display_center = (xc, yc, 0 * xc) + else: + RuntimeError(f"Unknown axis name {name!r} for cylindrical coordinates") return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + r_ax, theta_ax, z_ax = ( self.ds.coordinates.axis_id[ax] for ax in ("r", "theta", "z") ) if width is not None: - width = super().sanitize_width(axis, width, depth) + width = super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) # Note: regardless of axes, these are set up to give consistent plots # when plotted, which is not strictly a "right hand rule" for axes. elif name == "r": # soup can label + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) width = [self.ds.domain_width[theta_ax], self.ds.domain_width[z_ax]] elif name == "theta": + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) width = [self.ds.domain_width[r_ax], self.ds.domain_width[z_ax]] elif name == "z": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xxmin, xxmax, yymin, yymax = self._polar_bounds xw = xxmax - xxmin yw = yymax - yymin width = [xw, yw] return width + + @classmethod + def _get_plot_axes_default_properties( + cls, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = cls._default_axes_transforms[normal_axis_name] + + if normal_axis_name == "r": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label=r"\theta", + y_axis_label="z", + x_axis_units="rad", + y_axis_units=None, + ) + else: + raise NotImplementedError + elif normal_axis_name == "theta": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label="R", + y_axis_label="z", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise NotImplementedError + elif normal_axis_name == "z": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label="R", + y_axis_label=r"\theta", + x_axis_units=None, + y_axis_units="rad", + ) + elif axes_transform is AxesTransform.POLAR: + return dict( + x_axis_label="x", + y_axis_label="y", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise NotImplementedError + else: + raise ValueError(f"Unknown axis name {normal_axis_name!r}") diff --git a/yt/geometry/coordinates/geographic_coordinates.py b/yt/geometry/coordinates/geographic_coordinates.py index e8c30389adc..d09640e6601 100644 --- a/yt/geometry/coordinates/geographic_coordinates.py +++ b/yt/geometry/coordinates/geographic_coordinates.py @@ -1,9 +1,14 @@ +from typing import Dict + import numpy as np +from yt._maintenance.deprecation import issue_deprecation_warning from yt.utilities.lib.pixelization_routines import pixelize_cartesian, pixelize_cylinder +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _setup_dummy_cartesian_coords_and_widths, ) @@ -12,15 +17,16 @@ class GeographicCoordinateHandler(CoordinateHandler): radial_axis = "altitude" name = "geographic" + _default_axes_transforms: Dict[str, AxesTransform] = { + "latitude": AxesTransform.POLAR, + "longitude": AxesTransform.POLAR, + "altitude": AxesTransform.AITOFF_HAMMER, + } def __init__(self, ds, ordering=None): if ordering is None: ordering = ("latitude", "longitude", self.radial_axis) super().__init__(ds, ordering) - self.image_units = {} - self.image_units[self.axis_id["latitude"]] = (None, None) - self.image_units[self.axis_id["longitude"]] = (None, None) - self.image_units[self.axis_id[self.radial_axis]] = ("deg", "deg") def setup_fields(self, registry): # Missing implementation for x, y and z coordinates. @@ -214,13 +220,29 @@ def _retrieve_radial_offset(self, data_source=None): return surface_height, 1.0 def pixelize( - self, dimension, data_source, field, bounds, size, antialias=True, periodic=True + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + periodic=True, + *, + axes_transform=AxesTransform.DEFAULT, ): - if self.axis_name[dimension] in ("latitude", "longitude"): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super()._ortho_pixelize( + data_source, field, bounds, size, antialias, dimension, periodic + ) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + name = self.axis_name[dimension] + if name in ("latitude", "longitude"): return self._cyl_pixelize( data_source, field, bounds, size, antialias, dimension ) - elif self.axis_name[dimension] == self.radial_axis: + elif name == self.radial_axis: return self._ortho_pixelize( data_source, field, bounds, size, antialias, dimension, periodic ) @@ -326,10 +348,15 @@ def convert_to_spherical(self, coord): def convert_from_spherical(self, coord): raise NotImplementedError - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) if self._image_axis_name is not None: return self._image_axis_name # This is the x and y axes labels that get displayed. For @@ -349,13 +376,13 @@ def image_axis_name(self): self._image_axis_name = rv return rv - _x_pairs = ( + _x_pairs = ( # deprecated ("latitude", "longitude"), ("longitude", "latitude"), ("altitude", "longitude"), ) - _y_pairs = ( + _y_pairs = ( # deprecated ("latitude", "altitude"), ("longitude", "altitude"), ("altitude", "latitude"), @@ -365,6 +392,7 @@ def image_axis_name(self): @property def data_projection(self): + # see https://github.com/yt-project/yt/issues/4182 # this will control the default projection to use when displaying data if self._data_projection is not None: return self._data_projection @@ -397,7 +425,12 @@ def data_transform(self): def period(self): return self.ds.domain_width - def sanitize_center(self, center, axis): + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_center(center, axis, axes_transform=axes_transform) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + center, display_center = super().sanitize_center(center, axis) name = self.axis_name[axis] if name == self.radial_axis: @@ -419,7 +452,16 @@ def sanitize_center(self, center, axis): display_center[self.axis_id["latitude"]] = c return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + name = self.axis_name[axis] if width is not None: width = super().sanitize_width(axis, width, depth) @@ -439,10 +481,81 @@ def sanitize_width(self, axis, width, depth): width = [self.ds.domain_width[ri], 2.0 * self.ds.domain_width[ri]] return width + @classmethod + def _get_plot_axes_default_properties( + cls, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = cls._default_axes_transforms[normal_axis_name] + + if normal_axis_name == "latitude": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label=r"longitude", + y_axis_label=r"R", + x_axis_units="deg", + y_axis_units=None, + ) + elif axes_transform is AxesTransform.POLAR: + return dict( + x_axis_label=r"x / \sin(\mathrm{latitude})", + y_axis_label=r"y / \sin(\mathrm{latitude})", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise NotImplementedError + elif normal_axis_name == "longitude": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label="latitude", + y_axis_label="R", + x_axis_units="deg", + y_axis_units=None, + ) + elif axes_transform is AxesTransform.POLAR: + return dict( + x_axis_label="R", + y_axis_label="z", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise NotImplementedError + elif normal_axis_name == cls.radial_axis: + # TODO(4179): either clean this or refactor something elsewhere, + # because it is currently never used: + # this case is deleguated to cartopy, + # though we have everything needed to do it ourselves ! + # see https://github.com/yt-project/yt/issues/4182 + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label="longitude", + y_axis_label="latitude", + x_axis_units="deg", + y_axis_units="deg", + ) + elif axes_transform is AxesTransform.AITOFF_HAMMER: + return dict( + x_axis_label=r"\frac{2\cos(\mathrm{\mathrm{latitude}})\sin(\mathrm{longitude}/2)}{\sqrt{1 + \cos(\mathrm{latitude}) \cos(\mathrm{longitude}/2)}}", + y_axis_label=r"\frac{sin(\mathrm{latitude})}{\sqrt{1 + \cos(\mathrm{latitude}) \cos(\mathrm{longitude}/2)}}", + x_axis_units="dimensionless", + y_axis_units="dimensionless", + ) + else: + raise NotImplementedError + else: + raise ValueError(f"Unknown axis {normal_axis_name!r}") + class InternalGeographicCoordinateHandler(GeographicCoordinateHandler): radial_axis = "depth" name = "internal_geographic" + _default_axes_transforms: Dict[str, AxesTransform] = { + "latitude": AxesTransform.POLAR, + "longitude": AxesTransform.POLAR, + "depth": AxesTransform.AITOFF_HAMMER, + } def _setup_radial_fields(self, registry): # Altitude is the radius from the central zone minus the radius of the @@ -488,9 +601,18 @@ def _retrieve_radial_offset(self, data_source=None): ("depth", "longitude"), ) - _y_pairs = (("latitude", "depth"), ("longitude", "depth"), ("depth", "latitude")) + _y_pairs = ( + ("latitude", "depth"), + ("longitude", "depth"), + ("depth", "latitude"), + ) + + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_center(center, axis, axes_transform=axes_transform) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError - def sanitize_center(self, center, axis): center, display_center = super( GeographicCoordinateHandler, self ).sanitize_center(center, axis) @@ -515,12 +637,19 @@ def sanitize_center(self, center, axis): display_center[self.axis_id["latitude"]] = outermost / 2.0 return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + if axes_transform is not AxesTransform.DEFAULT: + raise NotImplementedError + name = self.axis_name[axis] if width is not None: - width = super(GeographicCoordinateHandler, self).sanitize_width( - axis, width, depth - ) + width = super().sanitize_width(axis, width, depth) elif name == self.radial_axis: rax = self.radial_axis width = [ diff --git a/yt/geometry/coordinates/spec_cube_coordinates.py b/yt/geometry/coordinates/spec_cube_coordinates.py index 7275b60f736..46ab4906951 100644 --- a/yt/geometry/coordinates/spec_cube_coordinates.py +++ b/yt/geometry/coordinates/spec_cube_coordinates.py @@ -1,9 +1,22 @@ -from .cartesian_coordinates import CartesianCoordinateHandler +from typing import Dict + +from yt._maintenance.deprecation import issue_deprecation_warning + +from ._axes_transforms import AxesTransform +from .cartesian_coordinates import ( + CartesianCoordinateHandler, + DefaultProperties, +) from .coordinate_handler import _get_coord_fields class SpectralCubeCoordinateHandler(CartesianCoordinateHandler): name = "spectral_cube" + _default_axes_transforms: Dict[str, AxesTransform] = { + "x": AxesTransform.GEOMETRY_NATIVE, + "y": AxesTransform.GEOMETRY_NATIVE, + "z": AxesTransform.GEOMETRY_NATIVE, + } def __init__(self, ds, ordering=None): if ordering is None: @@ -12,7 +25,8 @@ def __init__(self, ds, ordering=None): ) super().__init__(ds, ordering) - self.default_unit_label = {} + # TODO(4179): migrate this + self.default_unit_label = {} # deprecated names = {} if ds.lon_name != "X" or ds.lat_name != "Y": names["x"] = r"Image\ x" @@ -24,7 +38,7 @@ def __init__(self, ds, ordering=None): # Again, can use spec_axis here self.default_unit_label[ds.spec_axis] = ds.spec_unit - self._image_axis_name = ian = {} + self._image_axis_name = ian = {} # deprecated for ax in "xyz": axi = self.axis_id[ax] xax = self.axis_name[self.x_axis[ax]] @@ -94,4 +108,46 @@ def convert_from_cylindrical(self, coord): @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) return self._image_axis_name + + @classmethod + def _get_plot_axes_default_properties( + cls, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = AxesTransform.GEOMETRY_NATIVE + if axes_transform is not AxesTransform.GEOMETRY_NATIVE: + raise NotImplementedError( + f"spectral cube coordinates don't implement {axes_transform} yet" + ) + + if normal_axis_name == "x": + return dict( + x_axis_label="y", + y_axis_label="z", + x_axis_units=None, + # https://github.com/yt-project/yt/issues/4350 + y_axis_units="dimensionless", + ) + elif normal_axis_name == "y": + return dict( + x_axis_label="x", + y_axis_label="z", + x_axis_units=None, + # https://github.com/yt-project/yt/issues/4350 + y_axis_units="dimensionless", + ) + elif normal_axis_name == "z": + return dict( + x_axis_label="x", + y_axis_label="y", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise ValueError(f"Unknown axis {normal_axis_name!r}") diff --git a/yt/geometry/coordinates/spherical_coordinates.py b/yt/geometry/coordinates/spherical_coordinates.py index ee98e48b2f6..085d7d80135 100644 --- a/yt/geometry/coordinates/spherical_coordinates.py +++ b/yt/geometry/coordinates/spherical_coordinates.py @@ -1,11 +1,15 @@ from functools import cached_property +from typing import Dict import numpy as np +from yt._maintenance.deprecation import issue_deprecation_warning from yt.utilities.lib.pixelization_routines import pixelize_aitoff, pixelize_cylinder +from ._axes_transforms import AxesTransform from .coordinate_handler import ( CoordinateHandler, + DefaultProperties, _get_coord_fields, _get_polar_bounds, _setup_dummy_cartesian_coords_and_widths, @@ -16,14 +20,14 @@ class SphericalCoordinateHandler(CoordinateHandler): name = "spherical" _default_axis_order = ("r", "theta", "phi") + _default_axes_transforms: Dict[str, AxesTransform] = { + "r": AxesTransform.AITOFF_HAMMER, + "theta": AxesTransform.POLAR, + "phi": AxesTransform.POLAR, + } def __init__(self, ds, ordering=None): super().__init__(ds, ordering) - # Generate - self.image_units = {} - self.image_units[self.axis_id["r"]] = (1, 1) - self.image_units[self.axis_id["theta"]] = (None, None) - self.image_units[self.axis_id["phi"]] = (None, None) def setup_fields(self, registry): # Missing implementation for x, y and z coordinates. @@ -82,15 +86,40 @@ def _path_phi(field, data): ) def pixelize( - self, dimension, data_source, field, bounds, size, antialias=True, periodic=True + self, + dimension, + data_source, + field, + bounds, + size, + antialias=True, + periodic=True, + *, + axes_transform=AxesTransform.DEFAULT, ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return self._ortho_pixelize( + data_source, field, bounds, size, antialias, dimension, periodic + ) + self.period name = self.axis_name[dimension] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + if name == "r": - return self._ortho_pixelize( + if axes_transform is not AxesTransform.AITOFF_HAMMER: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) + return self._aitoff_hammer_pixelize( data_source, field, bounds, size, antialias, dimension, periodic ) elif name in ("theta", "phi"): + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) if name == "theta": # This is admittedly a very hacky way to resolve a bug # it's very likely that the *right* fix would have to @@ -108,7 +137,7 @@ def pixelize( def pixelize_line(self, field, start_point, end_point, npoints): raise NotImplementedError - def _ortho_pixelize( + def _aitoff_hammer_pixelize( self, data_source, field, bounds, size, antialias, dim, periodic ): # use Aitoff projection @@ -193,10 +222,15 @@ def convert_to_spherical(self, coord): def convert_from_spherical(self, coord): raise NotImplementedError - _image_axis_name = None + _image_axis_name = None # deprecated @property def image_axis_name(self): + issue_deprecation_warning( + "The image_axis_name property isn't used " + "internally in yt anymore and is deprecated", + since="4.2.0", + ) if self._image_axis_name is not None: return self._image_axis_name # This is the x and y axes labels that get displayed. For @@ -226,18 +260,74 @@ def image_axis_name(self): _x_pairs = (("r", "theta"), ("theta", "r"), ("phi", "r")) _y_pairs = (("r", "phi"), ("theta", "phi"), ("phi", "theta")) + @classmethod + def _get_plot_axes_default_properties( + cls, normal_axis_name: str, axes_transform: AxesTransform + ) -> DefaultProperties: + if axes_transform is AxesTransform.DEFAULT: + axes_transform = cls._default_axes_transforms[normal_axis_name] + + if normal_axis_name == "r": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label=r"\theta", + y_axis_label=r"\phi", + x_axis_units="rad", + y_axis_units="rad", + ) + elif axes_transform is AxesTransform.AITOFF_HAMMER: + return dict( + x_axis_label=r"\frac{2\cos(\mathrm{\bar{\theta}})\sin(\lambda/2)}{\sqrt{1 + \cos(\bar{\theta}) \cos(\lambda/2)}}", + y_axis_label=r"\frac{sin(\bar{\theta})}{\sqrt{1 + \cos(\bar{\theta}) \cos(\lambda/2)}}", + x_axis_units="dimensionless", + y_axis_units="dimensionless", + ) + else: + raise NotImplementedError + elif normal_axis_name == "theta": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label="r", + y_axis_label=r"\phi", + x_axis_units=None, + y_axis_units="rad", + ) + elif axes_transform is AxesTransform.POLAR: + return dict( + x_axis_label=r"x / \sin(\theta)", + y_axis_label=r"y / \sin(\theta)", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise NotImplementedError + elif normal_axis_name == "phi": + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return dict( + x_axis_label="r", + y_axis_label=r"\theta", + x_axis_units=None, + y_axis_units="rad", + ) + elif axes_transform is AxesTransform.POLAR: + return dict( + x_axis_label="R", + y_axis_label="z", + x_axis_units=None, + y_axis_units=None, + ) + else: + raise NotImplementedError + else: + raise ValueError(f"Unknown axis name {normal_axis_name!r}") + @property def period(self): return self.ds.domain_width @cached_property def _poloidal_bounds(self): - ri = self.axis_id["r"] - ti = self.axis_id["theta"] - rmin = self.ds.domain_left_edge[ri] - rmax = self.ds.domain_right_edge[ri] - thetamin = self.ds.domain_left_edge[ti] - thetamax = self.ds.domain_right_edge[ti] + rmin, rmax, thetamin, thetamax, _, _ = self._r_theta_phi_bounds corners = [ (rmin, thetamin), (rmin, thetamax), @@ -273,6 +363,25 @@ def to_poloidal_plane(r, theta): def _conic_bounds(self): return _get_polar_bounds(self, axes=("r", "phi")) + @cached_property + def _r_theta_phi_bounds(self): + # radius + ri = self.axis_id["r"] + rmin = self.ds.domain_left_edge[ri] + rmax = self.ds.domain_right_edge[ri] + + # colatitude + ti = self.axis_id["theta"] + thetamin = self.ds.domain_left_edge[ti] + thetamax = self.ds.domain_right_edge[ti] + + # azimuth + pi = self.axis_id["phi"] + phimin = self.ds.domain_left_edge[pi] + phimax = self.ds.domain_right_edge[pi] + + return rmin, rmax, thetamin, thetamax, phimin, phimax + @cached_property def _aitoff_bounds(self): # at the time of writing this function, yt's support for curvilinear @@ -281,18 +390,12 @@ def _aitoff_bounds(self): # this is not needed but calls for a large refactor. ONE = self.ds.quan(1, "code_length") - # colatitude - ti = self.axis_id["theta"] - thetamin = self.ds.domain_left_edge[ti] - thetamax = self.ds.domain_right_edge[ti] + _, _, thetamin, thetamax, phimin, phimax = self._r_theta_phi_bounds + # latitude latmin = ONE * np.pi / 2 - thetamax latmax = ONE * np.pi / 2 - thetamin - # azimuth - pi = self.axis_id["phi"] - phimin = self.ds.domain_left_edge[pi] - phimax = self.ds.domain_right_edge[pi] # longitude lonmin = phimin - ONE * np.pi lonmax = phimax - ONE * np.pi @@ -341,36 +444,79 @@ def to_aitoff_plane(latitude, longitude): return xmin, xmax, ymin, ymax - def sanitize_center(self, center, axis): + def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT): center, display_center = super().sanitize_center(center, axis) name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + if name == "r": - xxmin, xxmax, yymin, yymax = self._aitoff_bounds + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + _, _, xxmin, xxmax, yymin, yymax = self._r_theta_phi_bounds + elif axes_transform is AxesTransform.AITOFF_HAMMER: + xxmin, xxmax, yymin, yymax = self._aitoff_bounds + else: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xc = (xxmin + xxmax) / 2 yc = (yymin + yymax) / 2 display_center = (0 * xc, xc, yc) + elif name == "theta": - xxmin, xxmax, yymin, yymax = self._conic_bounds + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + xxmin, xxmax, _, _, yymin, yymax = self._r_theta_phi_bounds + elif axes_transform is AxesTransform.POLAR: + xxmin, xxmax, yymin, yymax = self._conic_bounds + else: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xc = (xxmin + xxmax) / 2 yc = (yymin + yymax) / 2 display_center = (xc, 0 * xc, yc) elif name == "phi": - Rmin, Rmax, zmin, zmax = self._poloidal_bounds - xc = (Rmin + Rmax) / 2 - yc = (zmin + zmax) / 2 + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + xxmin, xxmax, yymin, yymax, _, _ = self._r_theta_phi_bounds + elif axes_transform is AxesTransform.POLAR: + xxmin, xxmax, yymin, yymax = self._poloidal_bounds + else: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) + xc = (xxmin + xxmax) / 2 + yc = (yymin + yymax) / 2 display_center = (xc, yc) return center, display_center - def sanitize_width(self, axis, width, depth): + def sanitize_width( + self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT + ): + if axes_transform is AxesTransform.GEOMETRY_NATIVE: + return super().sanitize_width( + axis, width, depth, axes_transform=axes_transform + ) + name = self.axis_name[axis] + if axes_transform is AxesTransform.DEFAULT: + axes_transform = self._default_axes_transforms[name] + if width is not None: width = super().sanitize_width(axis, width, depth) elif name == "r": + if axes_transform is not AxesTransform.AITOFF_HAMMER: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) xxmin, xxmax, yymin, yymax = self._aitoff_bounds xw = xxmax - xxmin yw = yymax - yymin width = [xw, yw] elif name == "theta": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) # Remember, in spherical coordinates when we cut in theta, # we create a conic section xxmin, xxmax, yymin, yymax = self._conic_bounds @@ -378,6 +524,10 @@ def sanitize_width(self, axis, width, depth): yw = yymax - yymin width = [xw, yw] elif name == "phi": + if axes_transform is not AxesTransform.POLAR: + raise NotImplementedError( + f"{axes_transform} is not implemented for normal axis {name!r}" + ) Rmin, Rmax, zmin, zmax = self._poloidal_bounds xw = Rmax - Rmin yw = zmax - zmin diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py index 131f07eec8c..8dd38e3e880 100644 --- a/yt/visualization/fixed_resolution.py +++ b/yt/visualization/fixed_resolution.py @@ -9,6 +9,7 @@ from yt.data_objects.image_array import ImageArray from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, iter_fields, mylog +from yt.geometry.coordinates._axes_transforms import AxesTransform from yt.loaders import load_uniform_grid from yt.utilities.lib.api import ( # type: ignore CICDeposit_2, @@ -107,6 +108,7 @@ def __init__( periodic=False, *, filters: Optional[List["FixedResolutionBufferFilter"]] = None, + axes_transform: AxesTransform = AxesTransform.DEFAULT, ): self.data_source = data_source self.ds = data_source.ds @@ -117,6 +119,7 @@ def __init__( self.axis = data_source.axis self.periodic = periodic self._data_valid = False + self._axes_transform = axes_transform # import type here to avoid import cycles # note that this import statement is actually crucial at runtime: @@ -174,6 +177,7 @@ def __getitem__(self, item): bounds, self.buff_size, int(self.antialias), + axes_transform=self._axes_transform, ) buff = self._apply_filters(buff) @@ -680,9 +684,16 @@ def __init__( periodic=False, *, filters=None, + axes_transform=AxesTransform.DEFAULT, ): super().__init__( - data_source, bounds, buff_size, antialias, periodic, filters=filters + data_source, + bounds, + buff_size, + antialias, + periodic, + filters=filters, + axes_transform=axes_transform, ) # set up the axis field names diff --git a/yt/visualization/line_plot.py b/yt/visualization/line_plot.py index 2d44176138c..97907234e95 100644 --- a/yt/visualization/line_plot.py +++ b/yt/visualization/line_plot.py @@ -334,12 +334,12 @@ def _setup_plots(self): # set x and y axis labels axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y) - if self._xlabel is not None: + if self._xlabel: x_label = self._xlabel else: x_label = r"$\rm{Path\ Length" + axes_unit_labels[0] + "}$" - if self._ylabel is not None: + if self._ylabel: y_label = self._ylabel else: finfo = self.ds.field_info[field] diff --git a/yt/visualization/particle_plots.py b/yt/visualization/particle_plots.py index dc2fee71c49..c137ec1c59c 100644 --- a/yt/visualization/particle_plots.py +++ b/yt/visualization/particle_plots.py @@ -5,6 +5,7 @@ from yt.data_objects.profiles import create_profile from yt.data_objects.static_output import Dataset from yt.funcs import fix_axis, iter_fields +from yt.geometry.coordinates._axes_transforms import AxesTransform from yt.units.yt_array import YTArray from yt.visualization.fixed_resolution import ParticleImageBuffer from yt.visualization.profile_plotter import PhasePlot @@ -250,7 +251,11 @@ def __init__( ds = self.ds = ts[0] axis = fix_axis(axis, ds) (bounds, center, display_center) = get_window_parameters( - axis, center, width, ds + axis, + center, + width, + ds, + axes_transform=AxesTransform.DEFAULT, ) if field_parameters is None: field_parameters = {} diff --git a/yt/visualization/plot_container.py b/yt/visualization/plot_container.py index 6b5198ba28a..a240e945e64 100644 --- a/yt/visualization/plot_container.py +++ b/yt/visualization/plot_container.py @@ -137,8 +137,8 @@ def __init__(self, data_source, figure_size=None, fontsize: Optional[float] = No self._font_properties = FontProperties(**font_dict) self._font_color = None - self._xlabel = None - self._ylabel = None + self._xlabel: str = "" + self._ylabel: str = "" self._minorticks: Dict[FieldKey, bool] = {} @accepts_all_fields @@ -665,7 +665,7 @@ def _repr_html_(self): return ret @invalidate_plot - def set_xlabel(self, label): + def set_xlabel(self, label: str): r""" Allow the user to modify the X-axis title Defaults to the global value. Fontsize defaults @@ -683,7 +683,7 @@ def set_xlabel(self, label): return self @invalidate_plot - def set_ylabel(self, label): + def set_ylabel(self, label: str): r""" Allow the user to modify the Y-axis title Defaults to the global value. @@ -699,26 +699,15 @@ def set_ylabel(self, label): self._ylabel = label return self - def _get_axes_unit_labels(self, unit_x, unit_y): + def _get_axes_unit_labels(self, unit_x: str, unit_y: str) -> Tuple[str, str]: axes_unit_labels = ["", ""] comoving = False hinv = False for i, un in enumerate((unit_x, unit_y)): - unn = None - if hasattr(self.data_source, "axis"): - if hasattr(self.ds.coordinates, "image_units"): - # This *forces* an override - unn = self.ds.coordinates.image_units[self.data_source.axis][i] - elif hasattr(self.ds.coordinates, "default_unit_label"): - axax = getattr(self.ds.coordinates, f"{'xy'[i]}_axis")[ - self.data_source.axis - ] - unn = self.ds.coordinates.default_unit_label.get(axax, None) - if unn in (1, "1", "dimensionless"): - axes_unit_labels[i] = "" + if un == "dimensionless": continue - if unn is not None: - axes_unit_labels[i] = r"\ \ \left(" + unn + r"\right)" + if un in ("rad", "deg"): + axes_unit_labels[i] = r"\ \ \left(" + un + r"\right)" continue # Use sympy to factor h out of the unit. In this context 'un' # is a string, so we call the Unit constructor. @@ -744,8 +733,7 @@ def _get_axes_unit_labels(self, unit_x, unit_y): if un in formatted_length_unit_names: un = formatted_length_unit_names[un] else: - un = Unit(un, registry=self.ds.unit_registry) - un = un.latex_representation() + un = Unit(un, registry=self.ds.unit_registry).latex_representation() if hinv: un = un + r"\,h^{-1}" if comoving: @@ -759,7 +747,7 @@ def _get_axes_unit_labels(self, unit_x, unit_y): axes_unit_labels[i] = r"\ \ \left(" + un + r"\right)" else: axes_unit_labels[i] = r"\ \ (" + un + r")" - return axes_unit_labels + return axes_unit_labels[0], axes_unit_labels[1] def hide_colorbar(self, field=None): """ @@ -895,7 +883,7 @@ class ImagePlotContainer(PlotContainer, abc.ABC): def __init__(self, data_source, figure_size, fontsize): super().__init__(data_source, figure_size, fontsize) self._callbacks = [] - self._colorbar_label = PlotDictionary(self.data_source, lambda: None) + self._colorbar_label = PlotDictionary(self.data_source, str) def _get_default_handlers( self, field, default_display_units: Unit @@ -1071,7 +1059,7 @@ def set_colorbar_minorticks(self, field, state): return self @invalidate_plot - def set_colorbar_label(self, field, label): + def set_colorbar_label(self, field: Union[str, Tuple[str, str]], label: str): r""" Sets the colorbar label. @@ -1091,7 +1079,7 @@ def set_colorbar_label(self, field, label): self._colorbar_label[field] = label return self - def _get_axes_labels(self, field): + def _get_axes_labels(self, field) -> Tuple[str, str, str]: return (self._xlabel, self._ylabel, self._colorbar_label[field]) diff --git a/yt/visualization/plot_window.py b/yt/visualization/plot_window.py index 78ec20e0454..88781457fad 100644 --- a/yt/visualization/plot_window.py +++ b/yt/visualization/plot_window.py @@ -22,6 +22,7 @@ validate_moment, ) from yt.geometry.api import Geometry +from yt.geometry.coordinates._axes_transforms import AxesTransform, parse_axes_transform from yt.units.unit_object import Unit # type: ignore from yt.units.unit_registry import UnitParseError # type: ignore from yt.units.yt_array import YTArray, YTQuantity @@ -63,9 +64,15 @@ from yt._maintenance.backports import zip -def get_window_parameters(axis, center, width, ds): - width = ds.coordinates.sanitize_width(axis, width, None) - center, display_center = ds.coordinates.sanitize_center(center, axis) +def get_window_parameters( + axis, center, width, ds, axes_transform: AxesTransform = AxesTransform.DEFAULT +): + width = ds.coordinates.sanitize_width( + axis, width, None, axes_transform=axes_transform + ) + center, display_center = ds.coordinates.sanitize_center( + center, axis, axes_transform=axes_transform + ) xax = ds.coordinates.x_axis[axis] yax = ds.coordinates.y_axis[axis] bounds = ( @@ -98,6 +105,8 @@ def get_axes_unit(width, ds): r""" Infers the axes unit names from the input width specification """ + if width is None: + return None if ds.no_cgs_equiv_length: return ("code_length",) * 2 if is_sequence(width): @@ -187,6 +196,7 @@ def __init__( setup=False, *, geometry: Geometry = Geometry.CARTESIAN, + axes_transform: AxesTransform = AxesTransform.DEFAULT, ) -> None: # axis manipulation operations are callback-only: self._swap_axes_input = False @@ -200,8 +210,12 @@ def __init__( self.buff_size = buff_size self.antialias = antialias self._axes_unit_names = None + + # TODO(4179): handle compat with _transform and _projection ? + # see https://github.com/yt-project/yt/issues/4182 self._transform = None self._projection = None + self._axes_transform = axes_transform self.aspect = aspect skip = list(FixedResolutionBuffer._exclude_fields) + data_source._key_fields @@ -231,11 +245,12 @@ def __init__( self.origin = origin if self.data_source.center is not None and not oblique: + # see https://github.com/yt-project/yt/issues/4182 ax = self.data_source.axis xax = self.ds.coordinates.x_axis[ax] yax = self.ds.coordinates.y_axis[ax] center, display_center = self.ds.coordinates.sanitize_center( - self.data_source.center, ax + self.data_source.center, ax, axes_transform=self._axes_transform ) center = [display_center[xax], display_center[yax]] self.set_center(center) @@ -325,6 +340,7 @@ def _recreate_frb(self): self.antialias, periodic=self._periodic, filters=old_filters, + axes_transform=self._axes_transform, ) # At this point the frb has the valid bounds, size, aliasing, etc. @@ -687,8 +703,9 @@ def set_width(self, width, unit=None): axes_unit = get_axes_unit(width, self.ds) - width = self.ds.coordinates.sanitize_width(self.frb.axis, width, None) - + width = self.ds.coordinates.sanitize_width( + self.frb.axis, width, None, axes_transform=self._axes_transform + ) centerx = (self.xlim[1] + self.xlim[0]) / 2.0 centery = (self.ylim[1] + self.ylim[0]) / 2.0 @@ -862,7 +879,8 @@ def __init__(self, *args, **kwargs) -> None: if self._plot_type is None: self._plot_type = kwargs.pop("plot_type") self._splat_color = kwargs.pop("splat_color", None) - PlotWindow.__init__(self, *args, **kwargs) + self._frb: Optional[FixedResolutionBuffer] = None + super().__init__(*args, **kwargs) # import type here to avoid import cycles # note that this import statement is actually crucial at runtime: @@ -1008,62 +1026,73 @@ def _setup_plots(self): self._recreate_frb() self._colorbar_valid = True field_list = list(set(self.data_source._determine_fields(self.fields))) - for f in field_list: - axis_index = self.data_source.axis + coordinates = self.ds.coordinates + normal_axis_index = self.data_source.axis + + if self.oblique: + normal_axis_name = "oblique" + else: + normal_axis_name = coordinates.axis_name[normal_axis_index] + + default_plot_properties = coordinates._get_plot_axes_default_properties( + normal_axis_name, self._axes_transform + ) + + for f in field_list: xc, yc = self._setup_origin() - if self.ds._uses_code_length_unit: - # this should happen only if the dataset was initialized with - # argument unit_system="code" or if it's set to have no CGS - # equivalent. This only needs to happen here in the specific - # case that we're doing a computationally intense operation - # like using cartopy, but it prevents crashes in that case. - (unit_x, unit_y) = ("code_length", "code_length") - elif self._axes_unit_names is None: - unit = self.ds.get_smallest_appropriate_unit( - self.xlim[1] - self.xlim[0] - ) - unit_x = unit_y = unit - coords = self.ds.coordinates - if hasattr(coords, "image_units"): - # check for special cases defined in - # non cartesian CoordinateHandler subclasses - image_units = coords.image_units[coords.axis_id[axis_index]] - if image_units[0] in ("deg", "rad"): + if self._axes_unit_names is None: + unit_x = default_plot_properties["x_axis_units"] + unit_y = default_plot_properties["y_axis_units"] + if unit_x is None: + if self.ds._uses_code_length_unit: unit_x = "code_length" - elif image_units[0] == 1: - unit_x = "dimensionless" - if image_units[1] in ("deg", "rad"): + else: + unit_x = self.ds.get_smallest_appropriate_unit( + self.xlim[1] - self.xlim[0] + ) + if unit_y is None: + if self.ds._uses_code_length_unit: unit_y = "code_length" - elif image_units[1] == 1: - unit_y = "dimensionless" + else: + unit_y = self.ds.get_smallest_appropriate_unit( + self.ylim[1] - self.ylim[0] + ) else: (unit_x, unit_y) = self._axes_unit_names - # For some plots we may set aspect by hand, such as for spectral cube data. - # This will likely be replaced at some point by the coordinate handler - # setting plot aspect. - if self.aspect is None: - self.aspect = float( - (self.ds.quan(1.0, unit_y) / self.ds.quan(1.0, unit_x)).in_cgs() - ) extentx = (self.xlim - xc)[:2] extenty = (self.ylim - yc)[:2] # extentx/y arrays inherit units from xlim and ylim attributes # and these attributes are always length even for angular and # dimensionless axes so we need to strip out units for consistency - if unit_x == "dimensionless": + if unit_x in ("dimensionless", "rad", "deg"): extentx = extentx / extentx.units else: extentx.convert_to_units(unit_x) - if unit_y == "dimensionless": + if unit_y in ("dimensionless", "rad", "deg"): extenty = extenty / extenty.units else: extenty.convert_to_units(unit_y) extent = [*extentx, *extenty] + # For some plots we may set aspect by hand, such as for spectral cube data. + # This will likely be replaced at some point by the coordinate handler + # setting plot aspect. + if self.aspect is None: + ratio = (self.ds.quan(1.0, unit_y) / self.ds.quan(1.0, unit_x)).in_cgs() + if ratio.units.is_dimensionless: + self.aspect = float(ratio) + else: + # maybe we have length on the x axis and radians on the y axis + # in that case, it doesn't make much sense to impose a 1 to 1 ratio + # so instead we set the image to be a square by default + self.aspect = float( + ((extentx[1] - extentx[0]) / (extenty[1] - extenty[0])).value + ) + image = self.frb[f] font_size = self._font_properties.get_size() @@ -1130,30 +1159,31 @@ def _setup_plots(self): colorbar_handler=cbh, ) - axes_unit_labels = self._get_axes_unit_labels(unit_x, unit_y) + # construct default axes labels + x_unit_label, y_unit_label = self._get_axes_unit_labels(unit_x, unit_y) if self.oblique: labels = [ - r"$\rm{Image\ x" + axes_unit_labels[0] + "}$", - r"$\rm{Image\ y" + axes_unit_labels[1] + "}$", + r"$\rm{Image\ x" + x_unit_label + "}$", + r"$\rm{Image\ y" + y_unit_label + "}$", ] else: - coordinates = self.ds.coordinates - axis_names = coordinates.image_axis_name[axis_index] - xax = coordinates.x_axis[axis_index] - yax = coordinates.y_axis[axis_index] - - if hasattr(coordinates, "axis_default_unit_name"): - axes_unit_labels = [ - coordinates.axis_default_unit_name[xax], - coordinates.axis_default_unit_name[yax], - ] labels = [ - r"$\rm{" + axis_names[0] + axes_unit_labels[0] + r"}$", - r"$\rm{" + axis_names[1] + axes_unit_labels[1] + r"}$", + r"$\rm{" + + default_plot_properties["x_axis_label"] + + x_unit_label + + r"}$", + r"$\rm{" + + default_plot_properties["y_axis_label"] + + y_unit_label + + r"}$", ] if hasattr(coordinates, "axis_field"): + # this is exclusive to spectral_cube geometries + xax = coordinates.x_axis[normal_axis_index] + yax = coordinates.y_axis[normal_axis_index] + if xax in coordinates.axis_field: xmin, xmax = coordinates.axis_field[xax]( 0, self.xlim, self.ylim @@ -1174,9 +1204,9 @@ def _setup_plots(self): x_label, y_label, colorbar_label = self._get_axes_labels(f) - if x_label is not None: + if x_label: labels[0] = x_label - if y_label is not None: + if y_label: labels[1] = y_label if swap_axes: @@ -1189,7 +1219,7 @@ def _setup_plots(self): units = Unit(self.frb[f].units, registry=self.ds.unit_registry) units = units.latex_representation() - if colorbar_label is None: + if not colorbar_label: colorbar_label = image.info["label"] if getattr(self, "moment", 1) == 2: colorbar_label = "%s \\rm{Standard Deviation}" % colorbar_label @@ -1802,6 +1832,7 @@ def __init__( buff_size=(800, 800), *, north_vector=None, + axes_transform: Optional[str] = None, ): if north_vector is not None: # this kwarg exists only for symmetry reasons with OffAxisSlicePlot @@ -1812,10 +1843,17 @@ def __init__( del north_vector normal = self.sanitize_normal_vector(ds, normal) + + _axt = parse_axes_transform(axes_transform) + # this will handle time series data and controllers axis = fix_axis(normal, ds) (bounds, center, display_center) = get_window_parameters( - axis, center, width, ds + axis, + center, + width, + ds, + axes_transform=_axt, ) if field_parameters is None: field_parameters = {} @@ -1835,8 +1873,7 @@ def __init__( ) slc.get_data(fields) validate_mesh_fields(slc, fields) - PWViewerMPL.__init__( - self, + super().__init__( slc, bounds, origin=origin, @@ -1846,6 +1883,7 @@ def __init__( aspect=aspect, buff_size=buff_size, geometry=ds.geometry, + axes_transform=_axt, ) if axes_unit is None: axes_unit = get_axes_unit(width, ds) @@ -2031,6 +2069,7 @@ def __init__( aspect=None, *, moment=1, + axes_transform: Optional[str] = None, ): if method == "mip": issue_deprecation_warning( @@ -2041,12 +2080,18 @@ def __init__( method = "max" normal = self.sanitize_normal_vector(ds, normal) + _axt = parse_axes_transform(axes_transform) + axis = fix_axis(normal, ds) # If a non-weighted integral projection, assure field-label reflects that if weight_field is None and method == "integrate": self.projected = True (bounds, center, display_center) = get_window_parameters( - axis, center, width, ds + axis, + center, + width, + ds, + axes_transform=_axt, ) if field_parameters is None: field_parameters = {} @@ -2092,6 +2137,7 @@ def __init__( aspect=aspect, buff_size=buff_size, geometry=ds.geometry, + axes_transform=_axt, ) if axes_unit is None: axes_unit = get_axes_unit(width, ds) @@ -2558,6 +2604,8 @@ def plot_2d( window_size=8.0, aspect=None, data_source=None, + *, + axes_transform: Optional[str] = None, ) -> AxisAlignedSlicePlot: r"""Creates a plot of a 2D dataset @@ -2708,4 +2756,5 @@ def plot_2d( window_size=window_size, aspect=aspect, data_source=data_source, + axes_transform=axes_transform, )