-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix usm_ndarray ctor when shape is integral numpy scalar #1467
Fix usm_ndarray ctor when shape is integral numpy scalar #1467
Conversation
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1467/index.html |
Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_4 ran successfully. |
dpctl/tensor/_usmarray.pyx
Outdated
try: | ||
<Py_ssize_t> shape | ||
shape = [shape, ] | ||
except Exception: | ||
raise TypeError( | ||
"Argument shape must be a list or a tuple." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps the logic of this exception and message could be improved a bit. For instance:
In [5]: x = dpt.ones(np.prod((2, 3, 4), dtype="f4"), dtype="i8")
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/repos/dpctl/dpctl/tensor/_usmarray.pyx:192, in dpctl.tensor._usmarray.usm_ndarray.__cinit__()
191 try:
--> 192 <Py_ssize_t> shape
193 shape = [shape, ]
TypeError: 'float' object cannot be interpreted as an integer
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[5], line 1
----> 1 x = dpt.ones(np.prod((2, 3, 4), dtype="f4"), dtype="i8")
File ~/repos/dpctl/dpctl/tensor/_ctors.py:968, in ones(shape, dtype, order, device, usm_type, sycl_queue)
966 sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
967 dtype = _get_dtype(dtype, sycl_queue)
--> 968 res = dpt.usm_ndarray(
969 shape,
970 dtype=dtype,
971 buffer=usm_type,
972 order=order,
973 buffer_ctor_kwargs={"queue": sycl_queue},
974 )
975 hev, _ = ti._full_usm_ndarray(1, res, sycl_queue)
976 hev.wait()
File ~/repos/dpctl/dpctl/tensor/_usmarray.pyx:195, in dpctl.tensor._usmarray.usm_ndarray.__cinit__()
193 shape = [shape, ]
194 except Exception:
--> 195 raise TypeError(
196 "Argument shape must be a list or a tuple."
197 )
TypeError: Argument shape must be a list or a tuple.
It seems a bit misleading at first, because it would work for np.prod((2, 3, 4))
.
Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_5 ran successfully. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This enables
dpt.usm_ndarray(np.prod((2,3,4)))
.