Skip to content

Commit

Permalink
improve typing, validation
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Oct 4, 2024
1 parent 562955d commit 44777ca
Showing 1 changed file with 40 additions and 37 deletions.
77 changes: 40 additions & 37 deletions yt_experiments/tiled_grid/tiled_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,36 @@
from typing import Any

import numpy as np
import unyt
from numpy import typing as npt
from yt._typing import FieldKey
from yt.data_objects.construction_data_containers import YTArbitraryGrid
from yt.data_objects.static_output import Dataset


def _validate_edge(edge: npt.ArrayLike, ds: Dataset):
if not isinstance(edge, unyt.unyt_array):
return ds.arr(edge, 'code_length')
return edge

def _validate_nd_int(nd: int, x: int | npt.ArrayLike) -> npt.NDArray:
if isinstance(x, int):
x = (x,) * nd
x = np.array(x).astype(int)
if len(x) != 3:
raise ValueError("Variable must have a length of 3")
return x

class YTTiledArbitraryGrid:

_ndim = 3

def __init__(
self,
left_edge,
right_edge,
dims: tuple[int, int, int],
chunks: int | tuple[int, int, int],
left_edge: npt.ArrayLike,
right_edge: npt.ArrayLike,
dims: int | npt.ArrayLike,
chunks: int | npt.ArrayLike,
*,
ds: Dataset = None,
field_parameters=None,
Expand Down Expand Up @@ -45,18 +60,15 @@ def __init__(
"""

self.left_edge = left_edge
self.right_edge = right_edge
self.ds = ds
self.left_edge = _validate_edge(left_edge, ds)
self.right_edge = _validate_edge(right_edge, ds)
self.data_source = data_source
self.field_parameters = field_parameters
self.dims = _validate_nd_int(self._ndim, dims)
self.chunks = _validate_nd_int(self._ndim, chunks)

self.dims = dims
if isinstance(chunks, int):
chunks = (chunks,) * self._ndim
self.chunks = chunks

nchunks = self._dims / self._chunks
nchunks = self.dims / self.chunks
if np.any(np.mod(nchunks, nchunks.astype(int)) != 0):
msg = (
"The dimensions and chunks provide result in partially filled "
Expand All @@ -65,8 +77,7 @@ def __init__(
raise NotImplementedError(msg)
self.nchunks = nchunks.astype(int)

self.dds = (self.right_edge - self.left_edge) / self._dims

self.dds = (self.right_edge - self.left_edge) / self.dims
self._grids: list[YTArbitraryGrid] = []
self._grid_slc: list[tuple[slice, slice, slice]] = []
self._ngrids = np.prod(self.nchunks)
Expand All @@ -84,16 +95,8 @@ def __repr__(self):
)
return msg

@property
def _chunks(self):
return np.array(self.chunks, dtype=int)

@property
def _dims(self):
return np.array(self.dims, dtype=int)

def _get_grid_by_ijk(self, ijk_grid):
chunksizes = self._chunks
chunksizes = self.chunks

le_index = []
re_index = []
Expand Down Expand Up @@ -242,10 +245,10 @@ class YTArbitraryGridPyramid:

def __init__(
self,
left_edge,
right_edge,
level_dims: [tuple[int, int, int]],
level_chunks,
left_edge: npt.ArrayLike,
right_edge: npt.ArrayLike,
level_dims: [int | npt.ArrayLike],
level_chunks: int | npt.ArrayLike,
ds: Dataset = None,
field_parameters=None,
data_source: Any | None = None,
Expand All @@ -268,18 +271,17 @@ def __init__(
n_levels = len(level_dims)
self.n_levels = n_levels

if isinstance(level_chunks, int):
level_chunks = (level_chunks,) * self._ndim
level_chunks = _validate_nd_int(self._ndim, level_chunks)

if isinstance(level_chunks, tuple):
if isinstance(level_chunks, np.ndarray):
level_chunks = [level_chunks for _ in range(n_levels)]

if len(level_chunks) != n_levels:
msg = (
"length of level_chunks must match the total number of levels."
f" Found {len(level_chunks)}, expected {n_levels}"
)
raise ValueError(msg)
raise RuntimeError(msg)

for ilev in range(n_levels):
if isinstance(level_chunks[ilev], int):
Expand Down Expand Up @@ -344,18 +346,19 @@ def __getitem__(self, item: int) -> YTTiledArbitraryGrid:
class YTArbitraryGridOctPyramid(YTArbitraryGridPyramid):
def __init__(
self,
left_edge,
right_edge,
dims: tuple[int, int, int],
chunks: int | tuple[int, int, int],
left_edge: npt.ArrayLike,
right_edge: npt.ArrayLike,
dims: int | npt.ArrayLike,
chunks: int | npt.ArrayLike,
n_levels: int,
factor: int | tuple[int, int, int] = 2,
ds: Dataset = None,
field_parameters=None,
data_source: Any | None = None,
):

dims_ = np.array(dims, dtype=int)
dims = _validate_nd_int(self._ndim, dims)

if isinstance(chunks, int):
chunks = (chunks,) * self._ndim

Expand All @@ -366,7 +369,7 @@ def __init__(

level_dims = []
for lev in range(n_levels):
current_dims = dims_ / factor**lev
current_dims = dims / factor**lev
level_dims.append(current_dims)

super().__init__(
Expand Down

0 comments on commit 44777ca

Please sign in to comment.