diff --git a/yt_experiments/tiled_grid/tiled_grid.py b/yt_experiments/tiled_grid/tiled_grid.py index e1cd2db..c6d2eef 100644 --- a/yt_experiments/tiled_grid/tiled_grid.py +++ b/yt_experiments/tiled_grid/tiled_grid.py @@ -2,21 +2,36 @@ from typing import Any import numpy as np +import unyt +from numpy import typing as npt from yt._typing import FieldKey from yt.data_objects.construction_data_containers import YTArbitraryGrid from yt.data_objects.static_output import Dataset +def _validate_edge(edge: npt.ArrayLike, ds: Dataset): + if not isinstance(edge, unyt.unyt_array): + return ds.arr(edge, 'code_length') + return edge + +def _validate_nd_int(nd: int, x: int | npt.ArrayLike) -> npt.NDArray: + if isinstance(x, int): + x = (x,) * nd + x = np.array(x).astype(int) + if len(x) != 3: + raise ValueError("Variable must have a length of 3") + return x + class YTTiledArbitraryGrid: _ndim = 3 def __init__( self, - left_edge, - right_edge, - dims: tuple[int, int, int], - chunks: int | tuple[int, int, int], + left_edge: npt.ArrayLike, + right_edge: npt.ArrayLike, + dims: int | npt.ArrayLike, + chunks: int | npt.ArrayLike, *, ds: Dataset = None, field_parameters=None, @@ -45,18 +60,15 @@ def __init__( """ - self.left_edge = left_edge - self.right_edge = right_edge self.ds = ds + self.left_edge = _validate_edge(left_edge, ds) + self.right_edge = _validate_edge(right_edge, ds) self.data_source = data_source self.field_parameters = field_parameters + self.dims = _validate_nd_int(self._ndim, dims) + self.chunks = _validate_nd_int(self._ndim, chunks) - self.dims = dims - if isinstance(chunks, int): - chunks = (chunks,) * self._ndim - self.chunks = chunks - - nchunks = self._dims / self._chunks + nchunks = self.dims / self.chunks if np.any(np.mod(nchunks, nchunks.astype(int)) != 0): msg = ( "The dimensions and chunks provide result in partially filled " @@ -65,8 +77,7 @@ def __init__( raise NotImplementedError(msg) self.nchunks = nchunks.astype(int) - self.dds = (self.right_edge - self.left_edge) / self._dims - + self.dds = (self.right_edge - self.left_edge) / self.dims self._grids: list[YTArbitraryGrid] = [] self._grid_slc: list[tuple[slice, slice, slice]] = [] self._ngrids = np.prod(self.nchunks) @@ -84,16 +95,8 @@ def __repr__(self): ) return msg - @property - def _chunks(self): - return np.array(self.chunks, dtype=int) - - @property - def _dims(self): - return np.array(self.dims, dtype=int) - def _get_grid_by_ijk(self, ijk_grid): - chunksizes = self._chunks + chunksizes = self.chunks le_index = [] re_index = [] @@ -242,10 +245,10 @@ class YTArbitraryGridPyramid: def __init__( self, - left_edge, - right_edge, - level_dims: [tuple[int, int, int]], - level_chunks, + left_edge: npt.ArrayLike, + right_edge: npt.ArrayLike, + level_dims: [int | npt.ArrayLike], + level_chunks: int | npt.ArrayLike, ds: Dataset = None, field_parameters=None, data_source: Any | None = None, @@ -268,10 +271,9 @@ def __init__( n_levels = len(level_dims) self.n_levels = n_levels - if isinstance(level_chunks, int): - level_chunks = (level_chunks,) * self._ndim + level_chunks = _validate_nd_int(self._ndim, level_chunks) - if isinstance(level_chunks, tuple): + if isinstance(level_chunks, np.ndarray): level_chunks = [level_chunks for _ in range(n_levels)] if len(level_chunks) != n_levels: @@ -279,7 +281,7 @@ def __init__( "length of level_chunks must match the total number of levels." f" Found {len(level_chunks)}, expected {n_levels}" ) - raise ValueError(msg) + raise RuntimeError(msg) for ilev in range(n_levels): if isinstance(level_chunks[ilev], int): @@ -344,10 +346,10 @@ def __getitem__(self, item: int) -> YTTiledArbitraryGrid: class YTArbitraryGridOctPyramid(YTArbitraryGridPyramid): def __init__( self, - left_edge, - right_edge, - dims: tuple[int, int, int], - chunks: int | tuple[int, int, int], + left_edge: npt.ArrayLike, + right_edge: npt.ArrayLike, + dims: int | npt.ArrayLike, + chunks: int | npt.ArrayLike, n_levels: int, factor: int | tuple[int, int, int] = 2, ds: Dataset = None, @@ -355,7 +357,8 @@ def __init__( data_source: Any | None = None, ): - dims_ = np.array(dims, dtype=int) + dims = _validate_nd_int(self._ndim, dims) + if isinstance(chunks, int): chunks = (chunks,) * self._ndim @@ -366,7 +369,7 @@ def __init__( level_dims = [] for lev in range(n_levels): - current_dims = dims_ / factor**lev + current_dims = dims / factor**lev level_dims.append(current_dims) super().__init__(