Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: enable plotting curvilinear geometries in native coordinates #4179

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial, wraps
from re import finditer
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Optional

import numpy as np
from more_itertools import always_iterable
Expand All @@ -30,6 +31,7 @@
validate_moment,
)
from yt.geometry import particle_deposit as particle_deposit
from yt.geometry.coordinates._axes_transforms import parse_axes_transform
from yt.geometry.coordinates.cartesian_coordinates import all_data
from yt.loaders import load_uniform_grid
from yt.units._numpy_wrapper_functions import uconcatenate
Expand Down Expand Up @@ -350,15 +352,26 @@ def _sq_field(field, data, fname: FieldKey):
self.ds.field_info.pop(field)
self.tree = tree

def to_pw(self, fields=None, center="center", width=None, origin="center-window"):
def to_pw(
self,
fields=None,
center="center",
width=None,
origin="center-window",
*,
axes_transform: Optional[str] = None,
):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.

This is a bare-bones mechanism of creating a plot window from this
object, which can then be moved around, zoomed, and on and on. All
behavior of the plot window is relegated to that routine.
"""
pw = self._get_pw(fields, center, width, origin, "Projection")
_axt = parse_axes_transform(axes_transform)
pw = self._get_pw(
fields, center, width, origin, "Projection", axes_transform=_axt
)
return pw

def plot(self, fields=None):
Expand Down
15 changes: 12 additions & 3 deletions yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from yt.fields.field_exceptions import NeedsGridType
from yt.funcs import fix_axis, is_sequence, iter_fields, validate_width_tuple
from yt.geometry.api import Geometry
from yt.geometry.coordinates._axes_transforms import AxesTransform
from yt.geometry.selection_routines import compose_selector
from yt.units import YTArray
from yt.utilities.exceptions import (
Expand Down Expand Up @@ -530,20 +531,28 @@ def __init__(self, axis, ds, field_parameters=None, data_source=None):
def _convert_field_name(self, field):
return field

def _get_pw(self, fields, center, width, origin, plot_type):
def _get_pw(
self, fields, center, width, origin, plot_type, *, axes_transform: AxesTransform
):
from yt.visualization.fixed_resolution import FixedResolutionBuffer as frb
from yt.visualization.plot_window import PWViewerMPL, get_window_parameters

axis = self.axis
skip = self._key_fields
skip += list(set(frb._exclude_fields).difference(set(self._key_fields)))
# this line works, but mypy incorrectly flags it, so turning it off locally
skip += list(set(frb._exclude_fields).difference(set(self._key_fields))) # type: ignore [arg-type]
self.fields = [k for k in self.field_data if k not in skip]
if fields is not None:
self.fields = list(iter_fields(fields)) + self.fields
if len(self.fields) == 0:
raise ValueError("No fields found to plot in get_pw")

(bounds, center, display_center) = get_window_parameters(
axis, center, width, self.ds
axis,
center,
width,
self.ds,
axes_transform=axes_transform,
)
pw = PWViewerMPL(
self,
Expand Down
16 changes: 14 additions & 2 deletions yt/data_objects/selection_objects/slices.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np

from yt.data_objects.selection_objects.data_selection_objects import (
Expand All @@ -15,6 +17,7 @@
validate_object,
validate_width_tuple,
)
from yt.geometry.coordinates._axes_transforms import parse_axes_transform
from yt.utilities.exceptions import YTNotInsideNotebook
from yt.utilities.minimal_representation import MinimalSliceData
from yt.utilities.orientation import Orientation
Expand Down Expand Up @@ -104,15 +107,24 @@ def _generate_container_field(self, field):
def _mrep(self):
return MinimalSliceData(self)

def to_pw(self, fields=None, center="center", width=None, origin="center-window"):
def to_pw(
self,
fields=None,
center="center",
width=None,
origin="center-window",
*,
axes_transform: Optional[str] = None,
):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.

This is a bare-bones mechanism of creating a plot window from this
object, which can then be moved around, zoomed, and on and on. All
behavior of the plot window is relegated to that routine.
"""
pw = self._get_pw(fields, center, width, origin, "Slice")
_axt = parse_axes_transform(axes_transform)
pw = self._get_pw(fields, center, width, origin, "Slice", axes_transform=_axt)
return pw

def plot(self, fields=None):
Expand Down
1 change: 1 addition & 0 deletions yt/frontends/nc4_cm1/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _setup_coordinate_handler(self, axis_order: Optional[AxisOrder]) -> None:
# type checking is deactivated in the following two lines because changing them is not
# within the scope of the PR that _enabled_ typechecking here (#4244), but it'd be worth
# having a careful look at *why* these warnings appear, as they may point to rotten code
# TODO(4179): refactor this out
self.coordinates._x_pairs = (("x", "y"), ("y", "x"), ("z", "x")) # type: ignore [union-attr]
self.coordinates._y_pairs = (("x", "z"), ("y", "z"), ("z", "y")) # type: ignore [union-attr]

Expand Down
22 changes: 22 additions & 0 deletions yt/geometry/coordinates/_axes_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
from enum import auto
from typing import Optional

if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from yt._maintenance.backports import StrEnum


class AxesTransform(StrEnum):
DEFAULT = auto()
GEOMETRY_NATIVE = auto()
POLAR = auto()
AITOFF_HAMMER = auto()


def parse_axes_transform(axes_transform: Optional[str]) -> AxesTransform:
if axes_transform is None:
# pass the responsability to ds.coordinates
axes_transform = "default"
return AxesTransform(axes_transform)
62 changes: 61 additions & 1 deletion yt/geometry/coordinates/cartesian_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from yt.utilities.math_utils import compute_stddev_image
from yt.utilities.nodal_data_utils import get_nodal_data

from ._axes_transforms import AxesTransform
from .coordinate_handler import (
CoordinateHandler,
DefaultProperties,
_get_coord_fields,
_get_vert_fields,
cartesian_to_cylindrical,
Expand Down Expand Up @@ -161,13 +163,29 @@ def _check_fields(self, registry):
)

def pixelize(
self, dimension, data_source, field, bounds, size, antialias=True, periodic=True
self,
dimension,
data_source,
field,
bounds,
size,
antialias=True,
periodic=True,
*,
axes_transform=AxesTransform.DEFAULT,
):
"""
Method for pixelizing datasets in preparation for
two-dimensional image plots. Relies on several sampling
routines written in cython
"""
if axes_transform is AxesTransform.DEFAULT:
axes_transform = AxesTransform.GEOMETRY_NATIVE

if axes_transform is not AxesTransform.GEOMETRY_NATIVE:
raise NotImplementedError(
f"cartesian coordinates don't implement {axes_transform} yet"
)
index = data_source.ds.index
if hasattr(index, "meshes") and not isinstance(
index.meshes[0], SemiStructuredMesh
Expand Down Expand Up @@ -624,3 +642,45 @@ def convert_from_spherical(self, coord):
@property
def period(self):
return self.ds.domain_width

def _get_plot_axes_default_properties(
self, normal_axis_name: str, axes_transform: AxesTransform
) -> DefaultProperties:
if axes_transform is AxesTransform.DEFAULT:
axes_transform = AxesTransform.GEOMETRY_NATIVE

if axes_transform is not AxesTransform.GEOMETRY_NATIVE:
raise NotImplementedError(
f"cartesian coordinates don't implement {axes_transform} yet"
)

if normal_axis_name == "x":
return {
"x_axis_label": "y",
"y_axis_label": "z",
"x_axis_units": None,
"y_axis_units": None,
}
elif normal_axis_name == "y":
return {
"x_axis_label": "z",
"y_axis_label": "x",
"x_axis_units": None,
"y_axis_units": None,
}
elif normal_axis_name == "z":
return {
"x_axis_label": "x",
"y_axis_label": "y",
"x_axis_units": None,
"y_axis_units": None,
}
elif normal_axis_name == "oblique":
return {
"x_axis_label": "Image x",
"y_axis_label": "Image y",
"x_axis_units": None,
"y_axis_units": None,
}
else:
raise ValueError(f"Unknown axis {normal_axis_name!r}")
90 changes: 85 additions & 5 deletions yt/geometry/coordinates/coordinate_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,27 @@
import weakref
from functools import cached_property
from numbers import Number
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple, TypedDict

import numpy as np

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import AxisOrder
from yt.funcs import fix_unitary, is_sequence, parse_center_array, validate_width_tuple
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.exceptions import YTCoordinateNotImplemented, YTInvalidWidthError
from yt.utilities.lib.pixelization_routines import pixelize_cartesian

from ._axes_transforms import AxesTransform


class DefaultProperties(TypedDict):
x_axis_label: str
y_axis_label: str
# note that an empty string maps to "dimensionless",
# while None means "figure it out yourself"
x_axis_units: Optional[str]
y_axis_units: Optional[str]


def _unknown_coord(field, data):
Expand Down Expand Up @@ -133,6 +146,7 @@ def validate_sequence_width(width, ds, unit=None):
class CoordinateHandler(abc.ABC):
name: str
_default_axis_order: AxisOrder
_default_axes_transforms: Dict[str, AxesTransform]

def __init__(self, ds, ordering: Optional[AxisOrder] = None):
self.ds = weakref.proxy(ds)
Expand All @@ -147,7 +161,17 @@ def setup_fields(self):
pass

@abc.abstractmethod
def pixelize(self, dimension, data_source, field, bounds, size, antialias=True):
def pixelize(
self,
dimension,
data_source,
field,
bounds,
size,
antialias=True,
*,
axes_transform=AxesTransform.DEFAULT,
):
# This should *actually* be a pixelize call, not just returning the
# pixelizer
pass
Expand Down Expand Up @@ -185,8 +209,15 @@ def convert_to_spherical(self, coord):
def convert_from_spherical(self, coord):
pass

@abc.abstractmethod
def _get_plot_axes_default_properties(
self, normal_axis_name: str, axes_transform: AxesTransform
) -> DefaultProperties:
...

@cached_property
def data_projection(self):
# see https://github.com/yt-project/yt/issues/4182
return {ax: None for ax in self.axis_order}

@cached_property
Expand All @@ -211,6 +242,12 @@ def axis_id(self):

@property
def image_axis_name(self):
issue_deprecation_warning(
"The image_axis_name property isn't used "
"internally in yt anymore and is deprecated",
since="4.2.0",
stacklevel=3,
)
rv = {}
for i in range(3):
rv[i] = (self.axis_name[self.x_axis[i]], self.axis_name[self.y_axis[i]])
Expand Down Expand Up @@ -253,7 +290,16 @@ def sanitize_depth(self, depth):
raise YTInvalidWidthError(depth)
return depth

def sanitize_width(self, axis, width, depth):
def sanitize_width(
self, axis, width, depth, *, axes_transform=AxesTransform.DEFAULT
):
if axes_transform is AxesTransform.DEFAULT:
axes_transform = AxesTransform.GEOMETRY_NATIVE

if axes_transform is not AxesTransform.GEOMETRY_NATIVE:
raise NotImplementedError(
f"generic coordinate handler doesn't implement {axes_transform}"
)
if width is None:
# initialize the index if it is not already initialized
self.ds.index
Expand Down Expand Up @@ -284,12 +330,46 @@ def sanitize_width(self, axis, width, depth):
return width + depth
return width

def sanitize_center(self, center, axis):
def _get_display_center(self, center, axes_transform: AxesTransform):
# default implementation
return self.convert_to_cartesian(center)

def sanitize_center(self, center, axis, *, axes_transform=AxesTransform.DEFAULT):
if axes_transform is AxesTransform.DEFAULT:
axes_transform = AxesTransform.GEOMETRY_NATIVE

if axes_transform is not AxesTransform.GEOMETRY_NATIVE:
raise NotImplementedError(
f"generic coordinate handler doesn't implement {axes_transform}"
)
center = parse_center_array(center, ds=self.ds, axis=axis)
# This has to return both a center and a display_center
display_center = self.convert_to_cartesian(center)
display_center = self._get_display_center(center, axes_transform)
return center, display_center

def _ortho_pixelize(
self, data_source, field, bounds, size, antialias, dim, periodic
):
period = self.period[:2].copy() # dummy here
period[0] = self.period[self.x_axis[dim]]
period[1] = self.period[self.y_axis[dim]]
if hasattr(period, "in_units"):
period = period.in_units("code_length").d
buff = np.full(size, np.nan, dtype="float64")
pixelize_cartesian(
buff,
data_source["px"],
data_source["py"],
data_source["pdx"],
data_source["pdy"],
data_source[field],
bounds,
int(antialias),
period,
int(periodic),
)
return buff


def cartesian_to_cylindrical(coord, center=(0, 0, 0)):
c2 = np.zeros_like(coord)
Expand Down
Loading