diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 972ff3e7ae..ecf3eade35 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -26,6 +26,7 @@ import dpctl.utils from dpctl.tensor._data_types import _get_dtype from dpctl.tensor._device import normalize_queue_device +from dpctl.tensor._type_utils import _dtype_supported_by_device_impl __doc__ = ( "Implementation module for copy- and cast- operations on " @@ -121,7 +122,7 @@ def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None): output array is created. Device can be specified by a a filter selector string, an instance of :class:`dpctl.SyclDevice`, an instance of - :class:`dpctl.SyclQueue`, an instance of + :class:`dpctl.SyclQueue`, or an instance of :class:`dpctl.tensor.Device`. If the value is `None`, returned array is created on the default-selected device. Default: `None`. @@ -564,9 +565,11 @@ def copy(usm_ary, order="K"): return R -def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): +def astype( + usm_ary, newdtype, /, order="K", casting="unsafe", *, copy=True, device=None +): """ astype(array, new_dtype, order="K", casting="unsafe", \ - copy=True) + copy=True, device=None) Returns a copy of the :class:`dpctl.tensor.usm_ndarray`, cast to a specified type. @@ -576,7 +579,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): An input array. new_dtype (dtype): The data type of the resulting array. If `None`, gives default - floating point type supported by device where `array` is allocated. + floating point type supported by device where the resulting array + will be located. order ({"C", "F", "A", "K"}, optional): Controls memory layout of the resulting array if a copy is returned. @@ -587,6 +591,14 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): By default, `astype` always returns a newly allocated array. 
If this keyword is set to `False`, a view of the input array may be returned when possible. + device (object): array API specification of device where the + output array is created. Device can be specified by a + filter selector string, an instance of + :class:`dpctl.SyclDevice`, an instance of + :class:`dpctl.SyclQueue`, or an instance of + :class:`dpctl.tensor.Device`. If the value is `None`, + returned array is created on the same device as `array`. + Default: `None`. Returns: usm_ndarray: @@ -604,7 +616,25 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): ) order = order[0].upper() ary_dtype = usm_ary.dtype - target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) + if device is not None: + if not isinstance(device, dpctl.SyclQueue): + if isinstance(device, dpt.Device): + device = device.sycl_queue + else: + device = dpt.Device.create_device(device).sycl_queue + d = device.sycl_device + target_dtype = _get_dtype(newdtype, device) + if not _dtype_supported_by_device_impl( + target_dtype, d.has_aspect_fp16, d.has_aspect_fp64 + ): + raise ValueError( + f"Requested dtype `{target_dtype}` is not supported by the " + "target device" + ) + usm_ary = usm_ary.to_device(device) + else: + target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) + if not dpt.can_cast(ary_dtype, target_dtype, casting=casting): raise TypeError( f"Can not cast from {ary_dtype} to {newdtype} " diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 00e47db97a..144215e2d6 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import numpy as np diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index cf2c6f0331..7c5765332b 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1313,6 +1313,20 @@ def test_astype_invalid_order(): dpt.astype(X, "i4", order="WRONG") +def test_astype_device(): + get_queue_or_skip() + q1 = dpctl.SyclQueue() + q2 = dpctl.SyclQueue() + + x = dpt.arange(5, dtype="i4", sycl_queue=q1) + r = dpt.astype(x, "f4") + assert r.sycl_queue == x.sycl_queue + assert r.sycl_device == x.sycl_device + + r = dpt.astype(x, "f4", device=q2) + assert r.sycl_queue == q2 + + def test_copy(): try: X = dpt.usm_ndarray((5, 5), "i4")[2:4, 1:4]