Skip to content

Commit

Permalink
Added a random argument to `named_arrays.AbstractArray.cell_centers…
Browse files Browse the repository at this point in the history
…()` which selects a random point within each cell. (#97)
  • Loading branch information
byrdie authored Nov 11, 2024
1 parent 3fa0c95 commit 9dc13dd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 9 deletions.
59 changes: 53 additions & 6 deletions named_arrays/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,40 @@ def combine_axes(
Array with the specified axes combined
"""

@classmethod
def _lerp(
cls,
i: float | na.AbstractScalar,
a0: float | na.AbstractScalar,
a1: float | na.AbstractScalar,
) -> na.AbstractScalar:
return a0 * (1 - i) + a1 * i

def _nlerp(
self,
i: dict[str, na.AbstractScalar],
) -> na.AbstractExplicitArray:

if not i:
return self.explicit

axis = next(iter(i))

i_new = {ax: i[ax] for ax in i if ax != axis}

a0 = self[{axis: slice(None, ~0)}]
a1 = self[{axis: slice(1, None)}]

if i_new:
a0 = a0._nlerp(i_new)
a1 = a1._nlerp(i_new)

return self._lerp(i[axis], a0, a1)

def cell_centers(
self,
axis: None | str | Sequence[str] = None,
random: bool = False,
) -> na.AbstractExplicitArray:
"""
Convert an array from cell vertices to cell centers.
Expand All @@ -466,6 +497,9 @@ def cell_centers(
----------
axis
The axes of the array to average over.
random
If true, select a random point within each cell instead of the
geometric center.
"""

if axis is None:
Expand All @@ -477,13 +511,26 @@ def cell_centers(

shape = result.shape

for a in axis:
if a in shape:
lower = {a: slice(None, ~0)}
upper = {a: slice(+1, None)}
result = (result[lower] + result[upper]) / 2
axis = tuple(a for a in axis if a in shape)

return result
shape_centers = {
ax: shape[ax] - 1 if ax in axis
else shape[ax]
for ax in shape
}

if not random:
i = {
a: 0.5
for a in axis
}
else:
i = {
a: na.random.uniform(0, 1, shape_random=shape_centers)
for a in axis
}

return self._nlerp(i)

def volume_cell(self, axis: None | str | Sequence[str]) -> na.AbstractScalar:
"""
Expand Down
5 changes: 3 additions & 2 deletions named_arrays/_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ def combine_axes(
def cell_centers(
self,
axis: None | str | Sequence[str] = None,
random: bool = False,
) -> na.AbstractExplicitArray:
return dataclasses.replace(
self,
inputs=self.inputs.cell_centers(axis),
outputs=self.outputs.cell_centers(axis),
inputs=self.inputs.cell_centers(axis, random=random),
outputs=self.outputs.cell_centers(axis, random=random),
)

def to_string_array(
Expand Down
7 changes: 6 additions & 1 deletion named_arrays/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,15 @@ def test_combine_axes(
("x", "y"),
]
)
@pytest.mark.parametrize(
argnames="random",
argvalues=[False, True],
)
def test_cell_centers(
self,
array: na.AbstractArray,
axis: None | str | Sequence[str],
random: bool,
):
if axis is None:
axis_normalized = array.axes
Expand All @@ -278,7 +283,7 @@ def test_cell_centers(
else:
axis_normalized = axis

result = array.cell_centers(axis)
result = array.cell_centers(axis, random=random)

for a in axis_normalized:
if a in array.shape:
Expand Down

0 comments on commit 9dc13dd

Please sign in to comment.