Skip to content

Commit

Permalink
Implements dpctl.tensor.clip (#1444)
Browse files Browse the repository at this point in the history
* Implements dpctl.tensor.clip

* Clip now consistently yields max where max < min

sycl::clamp would yield max or min depending on the platform

A test has been added for this behavior

* Adds more tests for clip

* Removed redundant branches in clip and elementwise function calls

As the result dtype of the out array is already checked when overlap is checked, checking again later is superfluous

* Removed more redundant logic from clip

* Fixed order logic in clip

Now properly accounts for all three arrays in all branches

* Adds more compute follows data tests for clip

* Tests to increase coverage of _clip.py (#1451)

* Clip raises ValueError when types cannot be resolved

---------

Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
  • Loading branch information
ndgrigorian and oleksandr-pavlyk authored Oct 25, 2023
1 parent 442e46f commit 2eba93e
Show file tree
Hide file tree
Showing 9 changed files with 2,115 additions and 6 deletions.
2 changes: 2 additions & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ set(_tensor_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
)
list(APPEND _tensor_impl_sources
${_elementwise_sources}
Expand All @@ -138,6 +139,7 @@ set(_no_fast_math_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
)
list(APPEND _no_fast_math_sources
${_elementwise_sources}
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
from dpctl.tensor._usmarray import usm_ndarray
from dpctl.tensor._utility_functions import all, any

from ._clip import clip
from ._constants import e, inf, nan, newaxis, pi
from ._elementwise_funcs import (
abs,
Expand Down Expand Up @@ -322,4 +323,5 @@
"exp2",
"copysign",
"rsqrt",
"clip",
]
Loading

0 comments on commit 2eba93e

Please sign in to comment.