Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolve gh-1871 #1872

Merged
merged 5 commits on Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,6 @@ def matmul(

"""

dpnp.check_supported_arrays_type(x1, x2)
if subok is False:
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
Expand Down
101 changes: 73 additions & 28 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import dpctl.tensor._tensor_elementwise_impl as tei
import dpctl.tensor._tensor_impl as ti
import numpy
from dpctl.utils import ExecutionPlacementError
from numpy.core.numeric import normalize_axis_tuple

import dpnp
Expand Down Expand Up @@ -218,7 +219,9 @@ def _compute_size(start, shape):
return ret


def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
def _copy_array(
x, dep_events, host_events, copy_flag=False, dtype=None, order="C"
):
"""
Creating a copy of input array if needed.

Expand All @@ -236,7 +239,7 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
copy = x.dtype != dtype if dtype is not None else False

if copy:
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
x_copy = dpnp.empty_like(x, dtype=dtype, order=order)
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpnp.get_usm_ndarray(x),
dst=x_copy.get_array(),
Expand All @@ -248,7 +251,9 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
return x


def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
def _create_result_array(
x1, x2, out, shape, dtype, usm_type, sycl_queue, order="C"
):
"""
Create the result array.

Expand All @@ -263,13 +268,12 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
x1_usm = dpnp.get_usm_ndarray(x1)
x2_usm = dpnp.get_usm_ndarray(x2)
out_usm = dpnp.get_usm_ndarray(out)
contig_flag = _define_contig_flag(out)
contig_flag, _, _ = _define_contig_flag(out)

if (
out.dtype == dtype
and out.shape == shape
and out.usm_type == usm_type
and out.sycl_queue == sycl_queue
and contig_flag
and not ti._array_overlap(x1_usm, out_usm)
and not ti._array_overlap(x2_usm, out_usm)
Expand All @@ -279,6 +283,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
return dpnp.empty(
shape,
dtype=dtype,
order=order,
usm_type=usm_type,
sycl_queue=sycl_queue,
)
Expand All @@ -295,14 +300,14 @@ def _define_contig_flag(x):
x_strides = x.strides
x_shape = x.shape
if x.ndim < 2:
return True
return True, True, True

x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
if x_is_c_contiguous or x_is_f_contiguous:
flag = True
return flag
return flag, x_is_c_contiguous, x_is_f_contiguous


def _define_dim_flags(x, pos):
Expand Down Expand Up @@ -746,17 +751,26 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, dev_tasks_list):
)
ht_tasks_list.append(ht_blas_ev)
dpctl.SyclEvent.wait_for(ht_tasks_list)

res_shape = res.shape
if not row_major:
res = dpnp.reshape(
res.ravel(), (batch_size, res_shape[2], res_shape[1])
).transpose(0, 2, 1)
_, res_is_c_contig, res_is_f_contig = _define_contig_flag(res)
if row_major:
if res_is_f_contig:
res = dpnp.reshape(
dpnp.ravel(res, order="F"),
(res_shape[1], res_shape[2], batch_size),
).transpose(2, 0, 1)
else:
if res_is_c_contig:
res = dpnp.reshape(
dpnp.ravel(res, order="C"),
(batch_size, res_shape[2], res_shape[1]),
).transpose(0, 2, 1)

if res_shape != orig_shape:
res = res.reshape(orig_shape)

res = dpnp.ascontiguousarray(res)
return res
return dpnp.ascontiguousarray(res)


def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
Expand All @@ -769,14 +783,16 @@ def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
)
ht_blas_ev.wait()

if not row_major:
# TODO: investigate the possibility of defining result
# array with "F" order for this case
res = dpnp.ascontiguousarray(
dpnp.reshape(res.ravel(), res.shape, order="F")
)
if row_major:
if res.flags.f_contiguous is True:
# read data in "F" order and write it in "C" order
res = dpnp.reshape(dpnp.ravel(res, order="F"), res.shape, order="C")
else:
if res.flags.c_contiguous is True:
# read data in "C" order and write it in "F" order
res = dpnp.reshape(dpnp.ravel(res, order="C"), res.shape, order="F")

return res
return dpnp.ascontiguousarray(res)


def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
Expand Down Expand Up @@ -1746,6 +1762,13 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
)

res_usm_type, exec_q = get_usm_allocations([a, b])
if (
out is not None
and dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None
):
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

# Determine the appropriate data types
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
Expand Down Expand Up @@ -1812,6 +1835,12 @@ def dpnp_einsum(
arrays.append(a)

res_usm_type, exec_q = get_usm_allocations(arrays)
if out is not None:
dpnp.check_supported_arrays_type(out)
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
raise ExecutionPlacementError(
[Review thread: vtavana marked this conversation as resolved.]
"Input and output allocation queues are not compatible"
)
result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
for id, a in enumerate(operands):
if dpnp.isscalar(a):
Expand Down Expand Up @@ -2056,10 +2085,17 @@ def dpnp_matmul(

"""

x1_ndim = x1.ndim
x2_ndim = x2.ndim
dpnp.check_supported_arrays_type(x1, x2)
res_usm_type, exec_q = get_usm_allocations([x1, x2])
if out is not None:
dpnp.check_supported_arrays_type(out)
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

x1_ndim = x1.ndim
x2_ndim = x2.ndim
if axes is not None:
axes = _validate_axes(x1, x2, axes)

Expand All @@ -2072,7 +2108,6 @@ def dpnp_matmul(
x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) if x2_ndim != 1 else x2
out_orig = out
if out is not None:
dpnp.check_supported_arrays_type(out)
# out that is passed to the backend should have the correct shape
if len(axes_res) == 2:
out = dpnp.moveaxis(out, axes_res, (-2, -1))
Expand Down Expand Up @@ -2161,8 +2196,18 @@ def dpnp_matmul(
res = dpnp_dot(x1, x2, out=out)
res_shape = res.shape
else:
x1_contig_flag, _, x1_f = _define_contig_flag(x1)
x2_contig_flag, _, x2_f = _define_contig_flag(x2)
res_order = "F" if (x1_f and x2_f and call_flag == "gemm") else "C"
res = _create_result_array(
x1, x2, out, res_shape, compute_dtype, res_usm_type, exec_q
x1,
x2,
out,
res_shape,
compute_dtype,
res_usm_type,
exec_q,
res_order,
)

# calculate result
Expand All @@ -2175,21 +2220,21 @@ def dpnp_matmul(
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
dep_events_list = []
host_tasks_list = []
contig_flag = _define_contig_flag(x1)
x1 = _copy_array(
x1,
dep_events_list,
host_tasks_list,
copy_flag=not contig_flag,
copy_flag=not x1_contig_flag,
dtype=compute_dtype,
order=res_order,
)
contig_flag = _define_contig_flag(x2)
x2 = _copy_array(
x2,
dep_events_list,
host_tasks_list,
copy_flag=not contig_flag,
copy_flag=not x2_contig_flag,
dtype=compute_dtype,
order=res_order,
)

if call_flag == "gemv":
Expand Down
16 changes: 16 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,28 @@ def test_einsum_trivial_cases(self):
expected = numpy.einsum("i,i,i", b_np, b_np, b_np, optimize="greedy")
assert_dtype_allclose(result, expected)

def test_einsum_out(self):
    """Check that ``einsum`` fills and returns the array passed as ``out``."""
    subscripts = "ii->i"

    # device-side operands and their host-side counterparts
    a = inp.ones((5, 5))
    out = inp.empty((5,))
    a_np = a.asnumpy()
    out_np = out.asnumpy()

    # the returned object must be the very same array given as ``out``
    result = inp.einsum(subscripts, a, out=out)
    assert result is out

    # values must agree with the NumPy reference computation
    expected = numpy.einsum(subscripts, a_np, out=out_np)
    assert_dtype_allclose(result, expected)

def test_einsum_error(self):
a = inp.ones((5, 5))
# unknown keyword argument
with pytest.raises(TypeError):
inp.einsum("ii->i", a, copy=False)

a = inp.ones((5, 5))
out = inp.empty((5,), sycl_queue=dpctl.SyclQueue())
# inconsistent sycl_queue
with pytest.raises(ExecutionPlacementError):
inp.einsum("ii->i", a, out=out)

# unknown value for optimize keyword
with pytest.raises(TypeError):
inp.einsum("ii->i", a, optimize="average")
Expand Down
Loading
Loading