Skip to content

Commit

Permalink
ENH: enable plotting curvilinear geometries in native coordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Apr 19, 2023
1 parent d73939e commit e5fc788
Show file tree
Hide file tree
Showing 16 changed files with 893 additions and 209 deletions.
17 changes: 15 additions & 2 deletions yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial, wraps
from re import finditer
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Optional

import numpy as np
from more_itertools import always_iterable
Expand All @@ -30,6 +31,7 @@
validate_moment,
)
from yt.geometry import particle_deposit as particle_deposit
from yt.geometry.coordinates._axes_transforms import parse_axes_transform
from yt.geometry.coordinates.cartesian_coordinates import all_data
from yt.loaders import load_uniform_grid
from yt.units._numpy_wrapper_functions import uconcatenate
Expand Down Expand Up @@ -350,15 +352,26 @@ def _sq_field(field, data, fname: FieldKey):
self.ds.field_info.pop(field)
self.tree = tree

def to_pw(self, fields=None, center="center", width=None, origin="center-window"):
def to_pw(
self,
fields=None,
center="center",
width=None,
origin="center-window",
*,
axes_transform: Optional[str] = None,
):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.
This is a bare-bones mechanism of creating a plot window from this
object, which can then be moved around, zoomed, and on and on. All
behavior of the plot window is relegated to that routine.
"""
pw = self._get_pw(fields, center, width, origin, "Projection")
_axt = parse_axes_transform(axes_transform)
pw = self._get_pw(
fields, center, width, origin, "Projection", axes_transform=_axt
)
return pw

def plot(self, fields=None):
Expand Down
15 changes: 12 additions & 3 deletions yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from yt.fields.field_exceptions import NeedsGridType
from yt.funcs import fix_axis, is_sequence, iter_fields, validate_width_tuple
from yt.geometry.api import Geometry
from yt.geometry.coordinates._axes_transforms import AxesTransform
from yt.geometry.selection_routines import compose_selector
from yt.units import YTArray
from yt.utilities.exceptions import (
Expand Down Expand Up @@ -530,20 +531,28 @@ def __init__(self, axis, ds, field_parameters=None, data_source=None):
def _convert_field_name(self, field):
return field

def _get_pw(self, fields, center, width, origin, plot_type):
def _get_pw(
self, fields, center, width, origin, plot_type, *, axes_transform: AxesTransform
):
from yt.visualization.fixed_resolution import FixedResolutionBuffer as frb
from yt.visualization.plot_window import PWViewerMPL, get_window_parameters

axis = self.axis
skip = self._key_fields
skip += list(set(frb._exclude_fields).difference(set(self._key_fields)))
# this line works, but mypy incorrectly flags it, so turning it off locally
skip += list(set(frb._exclude_fields).difference(set(self._key_fields))) # type: ignore [arg-type]
self.fields = [k for k in self.field_data if k not in skip]
if fields is not None:
self.fields = list(iter_fields(fields)) + self.fields
if len(self.fields) == 0:
raise ValueError("No fields found to plot in get_pw")

(bounds, center, display_center) = get_window_parameters(
axis, center, width, self.ds
axis,
center,
width,
self.ds,
axes_transform=axes_transform,
)
pw = PWViewerMPL(
self,
Expand Down
16 changes: 14 additions & 2 deletions yt/data_objects/selection_objects/slices.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np

from yt.data_objects.selection_objects.data_selection_objects import (
Expand All @@ -15,6 +17,7 @@
validate_object,
validate_width_tuple,
)
from yt.geometry.coordinates._axes_transforms import parse_axes_transform
from yt.utilities.exceptions import YTNotInsideNotebook
from yt.utilities.minimal_representation import MinimalSliceData
from yt.utilities.orientation import Orientation
Expand Down Expand Up @@ -104,15 +107,24 @@ def _generate_container_field(self, field):
def _mrep(self):
return MinimalSliceData(self)

def to_pw(self, fields=None, center="center", width=None, origin="center-window"):
def to_pw(
self,
fields=None,
center="center",
width=None,
origin="center-window",
*,
axes_transform: Optional[str] = None,
):
r"""Create a :class:`~yt.visualization.plot_window.PWViewerMPL` from this
object.
This is a bare-bones mechanism of creating a plot window from this
object, which can then be moved around, zoomed, and on and on. All
behavior of the plot window is relegated to that routine.
"""
pw = self._get_pw(fields, center, width, origin, "Slice")
_axt = parse_axes_transform(axes_transform)
pw = self._get_pw(fields, center, width, origin, "Slice", axes_transform=_axt)
return pw

def plot(self, fields=None):
Expand Down
14 changes: 4 additions & 10 deletions yt/frontends/nc4_cm1/data_structures.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os
import weakref
from collections import OrderedDict
from typing import Optional

import numpy as np

from yt._typing import AxisOrder
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.static_output import Dataset
from yt.geometry.grid_geometry_handler import GridIndex
Expand Down Expand Up @@ -89,16 +87,12 @@ def __init__(
)
self.storage_filename = storage_filename

def _setup_coordinate_handler(self, axis_order: Optional[AxisOrder]) -> None:
def _setup_coordinate_handler(self):
# ensure correct ordering of axes so plots aren't rotated (z should always be
# on the vertical axis).
super()._setup_coordinate_handler(axis_order)

# type checking is deactivated in the following two lines because changing them is not
# within the scope of the PR that _enabled_ typechecking here (#4244), but it'd be worth
# having a careful look at *why* these warnings appear, as they may point to rotten code
self.coordinates._x_pairs = (("x", "y"), ("y", "x"), ("z", "x")) # type: ignore [union-attr]
self.coordinates._y_pairs = (("x", "z"), ("y", "z"), ("z", "y")) # type: ignore [union-attr]
super()._setup_coordinate_handler()
self.coordinates._x_pairs = (("x", "y"), ("y", "x"), ("z", "x"))
self.coordinates._y_pairs = (("x", "z"), ("y", "z"), ("z", "y"))

def _set_code_unit_attributes(self):
# This is where quantities are created that represent the various
Expand Down
23 changes: 23 additions & 0 deletions yt/geometry/coordinates/_axes_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from enum import Enum, auto
from typing import Optional


class AxesTransform(Enum):
DEFAULT = auto()
GEOMETRY_NATIVE = auto()
POLAR = auto()
AITOFF_HAMMER = auto()


def parse_axes_transform(axes_transform: Optional[str]) -> AxesTransform:
if axes_transform is None:
# pass the responsability to ds.coordinates
return AxesTransform.DEFAULT
elif axes_transform == "geometry_native":
return AxesTransform.GEOMETRY_NATIVE
elif axes_transform == "polar":
return AxesTransform.POLAR
elif axes_transform == "aitoff_hammer":
return AxesTransform.AITOFF_HAMMER
else:
raise ValueError(f"Unknown axes transform {axes_transform!r}")
63 changes: 62 additions & 1 deletion yt/geometry/coordinates/cartesian_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from yt.utilities.math_utils import compute_stddev_image
from yt.utilities.nodal_data_utils import get_nodal_data

from ._axes_transforms import AxesTransform
from .coordinate_handler import (
CoordinateHandler,
DefaultProperties,
_get_coord_fields,
_get_vert_fields,
cartesian_to_cylindrical,
Expand Down Expand Up @@ -161,13 +163,29 @@ def _check_fields(self, registry):
)

def pixelize(
self, dimension, data_source, field, bounds, size, antialias=True, periodic=True
self,
dimension,
data_source,
field,
bounds,
size,
antialias=True,
periodic=True,
*,
axes_transform=AxesTransform.DEFAULT,
):
"""
Method for pixelizing datasets in preparation for
two-dimensional image plots. Relies on several sampling
routines written in cython
"""
if axes_transform is AxesTransform.DEFAULT:
axes_transform = AxesTransform.GEOMETRY_NATIVE

if axes_transform is not AxesTransform.GEOMETRY_NATIVE:
raise NotImplementedError(
f"cartesian coordinates don't implement {axes_transform} yet"
)
index = data_source.ds.index
if hasattr(index, "meshes") and not isinstance(
index.meshes[0], SemiStructuredMesh
Expand Down Expand Up @@ -624,3 +642,46 @@ def convert_from_spherical(self, coord):
@property
def period(self):
return self.ds.domain_width

@classmethod
def _get_plot_axes_default_properties(
cls, normal_axis_name: str, axes_transform: AxesTransform
) -> DefaultProperties:
if axes_transform is AxesTransform.DEFAULT:
axes_transform = AxesTransform.GEOMETRY_NATIVE

if axes_transform is not AxesTransform.GEOMETRY_NATIVE:
raise NotImplementedError(
f"cartesian coordinates don't implement {axes_transform} yet"
)

if normal_axis_name == "x":
return dict(
x_axis_label="y",
y_axis_label="z",
x_axis_units=None,
y_axis_units=None,
)
elif normal_axis_name == "y":
return dict(
x_axis_label="z",
y_axis_label="x",
x_axis_units=None,
y_axis_units=None,
)
elif normal_axis_name == "z":
return dict(
x_axis_label="x",
y_axis_label="y",
x_axis_units=None,
y_axis_units=None,
)
elif normal_axis_name == "oblique":
return dict(
x_axis_label="Image x",
y_axis_label="Image y",
x_axis_units=None,
y_axis_units=None,
)
else:
raise ValueError(f"Unknown axis {normal_axis_name!r}")
Loading

0 comments on commit e5fc788

Please sign in to comment.