Skip to content

Commit

Permalink
fix(imagetool): allow coords of any dtype coercible to float64
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Oct 13, 2024
1 parent bdec5ff commit 4342ebc
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/erlab/interactive/imagetool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def itool(
Every DataArray in the Dataset will be displayed across multiple ImageTool
windows. Data variables that have less than 2 dimensions or more than 4
dimensions are ignored.
dimensions are ignored. Dimensions with length 1 are automatically squeezed.
link
Whether to enable linking between multiple ImageTool windows when `data` is a
sequence or a `xarray.Dataset`, by default `False`.
Expand Down
6 changes: 3 additions & 3 deletions src/erlab/interactive/imagetool/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ def convert_index(
return round(
index * source.array_slicer.incs[axis] / target.array_slicer.incs[axis]
)
value: np.float32 = source.array_slicer.value_of_index(
axis, index, uniform=False
value = source.array_slicer.value_of_index(axis, index, uniform=False)
new_index: int = target.array_slicer.index_of_value(

Check warning on line 338 in src/erlab/interactive/imagetool/core.py

View check run for this annotation

Codecov / codecov/patch

src/erlab/interactive/imagetool/core.py#L337-L338

Added lines #L337 - L338 were not covered by tests
axis, float(value), uniform=False
)
new_index: int = target.array_slicer.index_of_value(axis, value, uniform=False)
return new_index


Expand Down
87 changes: 45 additions & 42 deletions src/erlab/interactive/imagetool/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ class ArraySlicerState(TypedDict):
VALID_NDIM = (2, 3, 4)

_signature_array_rect = [
numba.types.UniTuple(numba.float32, 4)(
numba.types.UniTuple(ftype, 4)(
numba.int64,
numba.int64,
numba.types.UniTuple(numba.types.UniTuple(numba.float32, 2), i),
numba.types.UniTuple(numba.float32, i),
numba.types.UniTuple(numba.types.UniTuple(ftype, 2), i),
numba.types.UniTuple(ftype, i),
)
for ftype in (numba.float32, numba.float64)
for i in VALID_NDIM
]
_signature_index_of_value = [
Expand All @@ -58,15 +59,15 @@ class ArraySlicerState(TypedDict):
def _array_rect(
i: int,
j: int,
lims: tuple[tuple[np.float32, np.float32], ...],
incs: tuple[np.float32, ...],
) -> tuple[np.float32, np.float32, np.float32, np.float32]:
lims: tuple[tuple[np.floating, np.floating], ...],
incs: tuple[np.floating, ...],
) -> tuple[np.floating, np.floating, np.floating, np.floating]:
x = lims[i][0] - incs[i]
y = lims[j][0] - incs[j]
w = lims[i][-1] - x
h = lims[j][-1] - y
x += np.float32(0.5 * incs[i])
y += np.float32(0.5 * incs[j])
x += 0.5 * incs[i]
y += 0.5 * incs[j]
return x, y, w, h


Expand Down Expand Up @@ -96,10 +97,8 @@ def _transposed(arr: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
return arr.transpose(1, 2, 3, 0)


@numba.njit(
[numba.boolean(numba.float32[::1]), numba.boolean(numba.int64[::1])], cache=True
)
def _is_uniform(arr: npt.NDArray[np.float32]) -> bool:
@numba.njit(numba.boolean(numba.float64[::1]), cache=True)
def _is_uniform(arr: npt.NDArray[np.float64]) -> bool:
dif = np.diff(arr)
return np.allclose(dif, dif[0], rtol=3e-05, atol=3e-05, equal_nan=True)

Expand All @@ -112,7 +111,7 @@ def _is_uniform(arr: npt.NDArray[np.float32]) -> bool:
cache=True,
)
def _index_of_value_nonuniform(
arr: npt.NDArray[np.float32], val: np.float32
arr: npt.NDArray[np.floating], val: np.floating
) -> np.int_:
return np.searchsorted((arr[:-1] + arr[1:]) / 2, val)

Expand Down Expand Up @@ -192,13 +191,13 @@ def set_array(
self._indices: list[list[int]] = [
[s // 2 - (1 if s % 2 == 0 else 0) for s in self._obj.shape]
]
self._values: list[list[np.float32]] = [
self._values: list[list[np.floating]] = [
[c[i] for c, i in zip(self.coords, self._indices[0], strict=True)]
]
self.snap_to_data: bool = False

@functools.cached_property
def coords(self) -> tuple[npt.NDArray[np.float32], ...]:
def coords(self) -> tuple[npt.NDArray[np.floating], ...]:
if self._nonuniform_axes:
return tuple(
(
Expand All @@ -211,19 +210,19 @@ def coords(self) -> tuple[npt.NDArray[np.float32], ...]:
return self.coords_uniform

@functools.cached_property
def coords_uniform(self) -> tuple[npt.NDArray[np.float32], ...]:
def coords_uniform(self) -> tuple[npt.NDArray[np.floating], ...]:
return tuple(self.values_of_dim(dim) for dim in self._obj.dims)

@functools.cached_property
def incs(self) -> tuple[np.float32, ...]:
def incs(self) -> tuple[np.floating, ...]:
return tuple(coord[1] - coord[0] for coord in self.coords)

@functools.cached_property
def incs_uniform(self) -> tuple[np.float32, ...]:
def incs_uniform(self) -> tuple[np.floating, ...]:
return tuple(coord[1] - coord[0] for coord in self.coords_uniform)

@functools.cached_property
def lims(self) -> tuple[tuple[np.float32, np.float32], ...]:
def lims(self) -> tuple[tuple[np.floating, np.floating], ...]:
if self._nonuniform_axes:
return tuple(
(
Expand All @@ -236,7 +235,7 @@ def lims(self) -> tuple[tuple[np.float32, np.float32], ...]:
return tuple((coord[0], coord[-1]) for coord in self.coords)

@functools.cached_property
def lims_uniform(self) -> tuple[tuple[np.float32, np.float32], ...]:
def lims_uniform(self) -> tuple[tuple[np.floating, np.floating], ...]:
return tuple((coord[0], coord[-1]) for coord in self.coords_uniform)

@functools.cached_property
Expand Down Expand Up @@ -309,7 +308,7 @@ def validate_array(data: xr.DataArray) -> xr.DataArray:
If data has two momentum axes (``kx`` and ``ky``), set them (and ``eV`` if
exists) as the first two (or three) dimensions. Then, checks the data for
non-uniform coordinates, which are converted to indices. Finally, converts the
coordinates to C-contiguous float32.
coordinates to C-contiguous arrays.
If input data values are neither float32 nor float64, a conversion to float64 is
attempted.
Expand All @@ -325,15 +324,17 @@ def validate_array(data: xr.DataArray) -> xr.DataArray:
The converted data.
"""
data = data.squeeze()

if data.ndim < 2:
raise ValueError("Data must have at least two dimensions.")

if data.ndim > 4:
raise ValueError("Data must have at most four dimensions.")

# convert coords to C-contiguous float32
# Convert coords to C-contiguous array
data = data.assign_coords(
{d: data[d].astype(np.float32, order="C") for d in data.dims}
{d: data[d].astype(data[d].dtype, order="C") for d in data.dims}
)

if data.dtype not in (np.float32, np.float64):
Expand All @@ -348,7 +349,9 @@ def validate_array(data: xr.DataArray) -> xr.DataArray:
data = data.transpose(*new_dims)

nonuniform_dims: list[str] = [
str(d) for d in data.dims if not _is_uniform(data[d].values)
str(d)
for d in data.dims
if not _is_uniform(data[d].values.astype(np.float64))
]
for d in nonuniform_dims:
data = data.assign_coords(
Expand Down Expand Up @@ -407,7 +410,7 @@ def clear_cache(self) -> None:
self.clear_dim_cache()
self.clear_val_cache(include_vals=True)

def values_of_dim(self, dim: Hashable) -> npt.NDArray[np.float32]:
def values_of_dim(self, dim: Hashable) -> npt.NDArray[np.floating]:
"""Fast equivalent of :code:`self._obj[dim].values`.
Returns the cached pointer of the underlying coordinate array, achieving a ~80x
Expand Down Expand Up @@ -556,22 +559,22 @@ def step_index(
self.sigIndexChanged.emit(cursor, (axis,))

@QtCore.Slot(int, bool, result=list)
def get_values(self, cursor: int, uniform: bool = False) -> list[np.float32]:
def get_values(self, cursor: int, uniform: bool = False) -> list[np.floating]:
if uniform and self._nonuniform_axes:
val = list(self._values[cursor])
for ax in self._nonuniform_axes:
val[ax] = np.float32(self._indices[cursor][ax])
return val
return self._values[cursor]

@QtCore.Slot(int, int, bool, result=np.float32)
def get_value(self, cursor: int, axis: int, uniform: bool = False) -> np.float32:
@QtCore.Slot(int, int, bool, result=float)
def get_value(self, cursor: int, axis: int, uniform: bool = False) -> float:
if uniform and axis in self._nonuniform_axes:
return np.float32(self._indices[cursor][axis])
return self._values[cursor][axis]
return float(self._indices[cursor][axis])

Check warning on line 573 in src/erlab/interactive/imagetool/slicer.py

View check run for this annotation

Codecov / codecov/patch

src/erlab/interactive/imagetool/slicer.py#L573

Added line #L573 was not covered by tests
return float(self._values[cursor][axis])

def set_values(
self, cursor: int, value: list[np.float32], update: bool = True
self, cursor: int, value: list[np.floating], update: bool = True
) -> None:
if not len(value) == self._obj.ndim:
raise ValueError(
Expand All @@ -583,12 +586,12 @@ def set_values(
if update:
self.sigIndexChanged.emit(cursor, tuple(axes))

@QtCore.Slot(int, int, np.float32, bool, bool, result=list)
@QtCore.Slot(int, int, float, bool, bool, result=list)
def set_value(
self,
cursor: int,
axis: int,
value: np.float32,
value: float,
update: bool = True,
uniform: bool = False,
) -> list[int | None]:
Expand All @@ -601,7 +604,7 @@ def set_value(
return []
self._values[cursor][axis] = new
else:
self._values[cursor][axis] = np.float32(value)
self._values[cursor][axis] = np.float64(value)
if update:
self.sigIndexChanged.emit(cursor, (axis,))
return []
Expand Down Expand Up @@ -640,7 +643,8 @@ def swap_axes(self, ax1: int, ax2: int) -> None:
def array_rect(
self, i: int | None = None, j: int | None = None
) -> (
tuple[np.float32, np.float32, np.float32, np.float32] | npt.NDArray[np.float32]
tuple[np.floating, np.floating, np.floating, np.floating]
| npt.NDArray[np.floating]
):
if i is None:
i = 0
Expand All @@ -650,14 +654,12 @@ def array_rect(

def value_of_index(
self, axis: int, value: int, uniform: bool = False
) -> np.float32:
) -> np.floating:
if uniform or (axis not in self._nonuniform_axes):
return self.coords_uniform[axis][value]
return self.coords[axis][value]

def index_of_value(
self, axis: int, value: np.float32, uniform: bool = False
) -> int:
def index_of_value(self, axis: int, value: float, uniform: bool = False) -> int:
if uniform or (axis not in self._nonuniform_axes):
return _index_of_value(
axis, value, self.lims_uniform, self.incs_uniform, self._obj.shape
Expand Down Expand Up @@ -746,8 +748,9 @@ def xslice(self, cursor: int, disp: Sequence[int]) -> xr.DataArray:
def slice_with_coord(
self, cursor: int, disp: Sequence[int]
) -> tuple[
tuple[np.float32, np.float32, np.float32, np.float32] | npt.NDArray[np.float32],
npt.NDArray[np.float32] | np.float32,
tuple[np.floating, np.floating, np.floating, np.floating]
| npt.NDArray[np.floating],
npt.NDArray[np.floating] | np.floating,
]:
axis = sorted(set(range(self._obj.ndim)) - set(disp))
return self.array_rect(*disp), self.extract_avg_slice(cursor, axis)
Expand All @@ -761,7 +764,7 @@ def extract_avg_slice(
return self._bin_along_axis(cursor, axis[0])
return self._bin_along_multiaxis(cursor, axis)

def span_bounds(self, cursor: int, axis: int) -> npt.NDArray[np.float32]:
def span_bounds(self, cursor: int, axis: int) -> npt.NDArray[np.floating]:
slc = self._bin_slice(cursor, axis)
if isinstance(slc, int):
return self.coords_uniform[axis][slc : slc + 1]
Expand Down
23 changes: 23 additions & 0 deletions tests/interactive/test_imagetool.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,29 @@ def move_and_compare_values(qtbot, win, expected, cursor=0, target_win=None):
assert_almost_equal(win.array_slicer.point_value(cursor), expected[0])


@pytest.mark.parametrize("val_dtype", [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize("coord_dtype", [np.float32, np.float64, np.int32, np.int64])
def test_itool_dtypes(qtbot, val_dtype, coord_dtype):
data = xr.DataArray(
np.arange(25).reshape((5, 5)).astype(val_dtype),
dims=["x", "y"],
coords={
"x": np.arange(5, dtype=coord_dtype),
"y": np.arange(5, dtype=coord_dtype),
},
)
win = itool(data, execute=False)
qtbot.addWidget(win)

with qtbot.waitExposed(win):
win.show()
win.activateWindow()
win.raise_()

move_and_compare_values(qtbot, win, [12.0, 7.0, 6.0, 11.0])
win.close()


def test_itool(qtbot):
data = xr.DataArray(np.arange(25).reshape((5, 5)), dims=["x", "y"])
win = itool(data, execute=False)
Expand Down

0 comments on commit 4342ebc

Please sign in to comment.