Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argmax and argmin #84

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
Changelog
=========

0.9.3 (unreleased)
adityagoel4512 marked this conversation as resolved.
Show resolved Hide resolved
------------------

- Reduced the number of unnecessary casts in :func:`ndonnx.argmax` and :func:`ndonnx.argmin`.


0.9.2 (2024-10-03)
------------------

Expand Down
Binary file modified docs/_static/classify_iris.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 30 additions & 34 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,43 +355,39 @@ def matrix_transpose(self, x) -> ndx.Array:

@validate_core
def argmax(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x])
out = via_upcast(
lambda x: opx.arg_max(
x,
axis=axis or 0,
keepdims=int(keepdims),
),
[ndx.reshape(x, [-1]) if axis is None else x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

while keepdims and out.ndim < x.ndim:
out = ndx.expand_dims(out, axis=0)
return out

@validate_core
def argmin(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x])
out = via_upcast(
lambda x: opx.arg_min(
x,
axis=axis or 0,
keepdims=int(keepdims),
),
[ndx.reshape(x, [-1]) if axis is None else x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

while keepdims and out.ndim < x.ndim:
out = ndx.expand_dims(out, axis=0)
return out

@validate_core
def nonzero(self, x) -> tuple[Array, ...]:
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
if (
out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims)
out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'")
Expand Down
46 changes: 46 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import re
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -992,3 +993,48 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)),
],
)
def test_argmaxmin(func, x, keepdims):
np_result = func(x, keepdims=keepdims)
ndx_result = getattr(ndx, func.__name__)(
ndx.asarray(x), keepdims=keepdims
).to_numpy()
assert_array_equal(np_result, ndx_result)


# Pending ORT 1.19 conda-forge release before this becomes supported:
# https://github.com/conda-forge/onnxruntime-feedstock/pull/128
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
],
)
def test_argmaxmin_unsupported_kernels(func, x):
import onnxruntime as ort

if ort.__version__.startswith("19"):
warnings.warn(
"Please remove this test and update `argmax` and `argmin` to reflect expanded kernel support.",
Warning,
)

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where
Expand Down