Skip to content

Commit

Permalink
[oneMKL] gemv_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Johnson committed Nov 1, 2021
1 parent e3e83e0 commit 48480d6
Showing 1 changed file with 196 additions and 23 deletions.
219 changes: 196 additions & 23 deletions source/elements/oneMKL/source/domains/blas/gemv_batch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,179 @@ The strided API operation is defined as
Y := alpha * op(A) * X + beta * Y
end for

where:

``alpha`` and ``beta`` are scalars,

op(``A``) is one of op(``A``) = ``A``, or op(``A``) = ``A``\ :sup:`T`, or op(``A``) = ``A``\ :sup:`H`,

op(``A``) is ``m`` x ``n``.


For group API, ``a`` array contains the pointers for all the input matrices.
The total number of matrices in ``a`` is given by:

.. math::
total\_batch\_count = \sum_{i=0}^{group\_count-1}group\_size[i]
For strided API, ``a`` array contains all the input matrices. The total number of matrices
in ``a`` is given by the ``batch_size`` parameter.

**Group API**

.. rubric:: Syntax

.. code-block:: cpp
namespace oneapi::mkl::blas::column_major {
sycl::event gemv_batch(sycl::queue &queue,
onemkl::transpose *trans,
std::int64_t *m,
std::int64_t *n,
T *alpha,
const T **a,
std::int64_t *lda,
const T **x,
std::int64_t *incx,
T *beta,
T **y,
std::int64_t *incy,
std::int64_t group_count,
std::int64_t *group_size,
const sycl::vector_class<sycl::event> &dependencies = {})
}
.. code-block:: cpp
namespace oneapi::mkl::blas::row_major {
sycl::event gemv_batch(sycl::queue &queue,
onemkl::transpose *trans,
std::int64_t *m,
std::int64_t *n,
T *alpha,
const T **a,
std::int64_t *lda,
const T **x,
std::int64_t *incx,
T *beta,
T **y,
std::int64_t *incy,
std::int64_t group_count,
std::int64_t *group_size,
const sycl::vector_class<sycl::event> &dependencies = {})
}
.. container:: section

.. rubric:: Input Parameters

queue
The queue where the routine should be executed.

trans
Array of ``group_count`` ``onemkl::transpose`` values. ``trans[i]`` specifies the form of op(``A``) used in
the matrix multiplication in group ``i``. See :ref:`onemkl_datatypes` for more details.

m
Array of ``group_count`` integers. ``m[i]`` specifies the
number of rows of op(``A``) the matrix in group ``i``. All entries must be at least zero.

n
Array of ``group_count`` integers. ``n[i]`` specifies the
number of columns of op(``A``) the matrix in group ``i``. All entries must be at least zero.

alpha
Array of ``group_count`` scalar elements. ``alpha[i]`` specifies the scaling factor for every matrix-vector
product in group ``i``.

a
Array of pointers to input matrices ``A`` with size ``total_batch_count``.

See :ref:`matrix-storage` for more details.

lda
Array of ``group_count`` integers. ``lda[i]`` specifies the
leading dimension of ``A`` for every matrix in group ``i``. All
entries must be positive.

.. list-table::
:header-rows: 1

* -
- ``A`` not transposed
- ``A`` transposed
* - Column major
- ``lda[i]`` must be at least ``m[i]``.
- ``lda[i]`` must be at least ``n[i]``.
* - Row major
- ``lda[i]`` must be at least ``n[i]``.
- ``lda[i]`` must be at least ``m[i]``.

x
Array of pointers to input vectors ``X`` with size ``total_batch_count``.
The size of array allocated for the ``X`` vector of the group ``i`` must be at least (1 + (``n[i]`` – 1)*abs(``incx[i]``))``.
See :ref:`matrix-storage` for more details.

incx
Array of ``group_count`` integers. ``incx[i]`` specifies the stride of vector ``X`` in group ``i``.

beta
Array of ``group_count`` scalar elements. ``beta[i]`` specifies the scaling factor for vector ``X`` in group ``i``.

y
Array of pointers to input/output vectors ``Y`` with size ``total_batch_count``.
The size of array allocated for the ``Y`` vector of the group ``i`` must be at least (1 + (``n[i]`` – 1)*abs(``incy[i]``))``.
See :ref:`matrix-storage` for more details.

incy
Array of ``group_count`` integers. ``incy[i]`` specifies the stride of vector ``Y`` in group ``i``.

group_count
Specifies the number of groups. Must be at least 0.

group_size
Array of ``group_count`` integers. ``group_size[i]`` specifies the
number of ``gemv`` products in group ``i``. All entries must be at least 0.

dependencies
List of events to wait for before starting computation, if any.
If omitted, defaults to no dependencies.

.. container:: section

.. rubric:: Output Parameters

y
The pointer to updated vector ``y``.

.. container:: section

.. rubric:: Return Values

Output event to wait on to ensure computation is complete.

.. container:: section

.. rubric:: Throws

This routine shall throw the following exceptions if the associated condition is detected. An implementation may throw additional implementation-specific exception(s) in case of error conditions not covered here.

:ref:`oneapi::mkl::invalid_argument<onemkl_exception_invalid_argument>`


:ref:`oneapi::mkl::unsupported_device<onemkl_exception_unsupported_device>`


:ref:`oneapi::mkl::host_bad_alloc<onemkl_exception_host_bad_alloc>`


:ref:`oneapi::mkl::device_bad_alloc<onemkl_exception_device_bad_alloc>`


:ref:`oneapi::mkl::unimplemented<onemkl_exception_unimplemented>`


**Strided API**

.. rubric:: Syntax

Expand All @@ -254,11 +426,15 @@ The strided API operation is defined as
T alpha,
const T *a,
std::int64_t lda,
std::int64_t stridea,
const T *x,
std::int64_t incx,
std::int64_t stridex,
T beta,
T *y,
std::int64_t incy,
std::int64_t stridey,
std::int64_t batch_size,
const sycl::vector_class<sycl::event> &dependencies = {})
}
.. code-block:: cpp
Expand All @@ -271,11 +447,15 @@ The strided API operation is defined as
T alpha,
const T *a,
std::int64_t lda,
std::int64_t stridea,
const T *x,
std::int64_t incx,
std::int64_t stridex,
T beta,
T *y,
std::int64_t incy,
std::int64_t stridey,
std::int64_t batch_size,
const sycl::vector_class<sycl::event> &dependencies = {})
}
Expand All @@ -288,9 +468,7 @@ The strided API operation is defined as

trans
Specifies ``op(A)``, the transposition operation applied to
``A``. See
:ref:`onemkl_datatypes` for
more details.
``A``. See :ref:`onemkl_datatypes` for more details.

m
Specifies the number of rows of the matrix ``A``. The value of
Expand All @@ -304,39 +482,36 @@ The strided API operation is defined as
Scaling factor for the matrix-vector product.

a
The pointer to the input matrix ``A``. Must have a size of at
least ``lda``\ \*\ ``n`` if column major layout is used or at
least ``lda``\ \*\ ``m`` if row major layout is used. See
:ref:`matrix-storage` for more details.
Pointer to input matrices ``A`` with size ``stridea`` * ``batch_size``.

lda
Leading dimension of matrix ``A``. Must be positive and at least
``m`` if column major layout is used or at least ``n`` if row
major layout is used.

stridea
Stride between different ``A`` matrices.

x
Pointer to the input vector ``x``. The length ``len`` of vector
``x`` is ``n`` if ``A`` is not transposed, and ``m`` if ``A``
is transposed. The array holding vector ``x`` must be of size
at least (1 + (``len`` - 1)*abs(``incx``)). See :ref:`matrix-storage` for
more details.
Pointer to input vectors ``X`` with size ``stridex`` * ``batchs_size``.

incx
The stride of vector ``x``.

beta
The scaling factor for vector ``y``.
The scaling factor for vector ``Y``.

y
Pointer to input/output vector ``y``. The length ``len`` of
vector ``y`` is ``m``, if ``A`` is not transposed, and ``n`` if
``A`` is transposed. The array holding input/output vector
``y`` must be of size at least (1 + (``len`` -
1)*abs(``incy``)) where ``len`` is this length. See :ref:`matrix-storage` for
more details.
Pointer to input/output vectors ``Y`` with size ``stridey`` * ``batch_size``.

incy
The stride of vector ``y``.
The stride of vector ``Y``.

stridey
Stride between different ``y`` vectors.

batch_size
Specifies the number of ``gemv`` operations to performs.

dependencies
List of events to wait for before starting computation, if any.
Expand All @@ -347,7 +522,7 @@ The strided API operation is defined as
.. rubric:: Output Parameters

y
The pointer to updated vector ``y``.
The pointer to updated array of ``Y`` vectors.

.. container:: section

Expand All @@ -364,7 +539,6 @@ The strided API operation is defined as
:ref:`oneapi::mkl::invalid_argument<onemkl_exception_invalid_argument>`



:ref:`oneapi::mkl::unsupported_device<onemkl_exception_unsupported_device>`


Expand All @@ -376,5 +550,4 @@ The strided API operation is defined as

:ref:`oneapi::mkl::unimplemented<onemkl_exception_unimplemented>`


**Parent topic:** :ref:`blas-level-2-routines`

0 comments on commit 48480d6

Please sign in to comment.