Skip to content

Commit

Permalink
Add new blas extension and update dpnp.matmul func (IntelPython#1616)
Browse files Browse the repository at this point in the history
* Add new blas extension and update matmul impl

* Add support for N-D array

add N-dimension

* support more special cases + add new tests

* fix random behavior on cpu

* correct dtypes + support more keywords

* add strided support

* check input arrays

* address comments - first round

* address comments - second round

* address comments - third round

* fix pre-commit

* improve test coverage

* address comments

* update _gemm_res_dtype func

* fix a test for result_type

* fix minor issues

* skip tests for matmul

---------

Co-authored-by: Vahid Tavanashad <vahid.tavanashad@intel.com>
Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
4 people authored Jan 13, 2024
1 parent 7e54eb8 commit f95ceb9
Show file tree
Hide file tree
Showing 22 changed files with 1,740 additions and 233 deletions.
1 change: 1 addition & 0 deletions dpnp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ endfunction()

# Build the `dparray` extension from dparray.pyx using the
# build_dpnp_cython_ext_with_backend helper.
build_dpnp_cython_ext_with_backend(dparray ${CMAKE_CURRENT_SOURCE_DIR}/dparray.pyx dpnp)
# Add the backend and each of its extension sub-directories
# (blas, lapack, vm, sycl_ext) to the build.
add_subdirectory(backend)
add_subdirectory(backend/extensions/blas)
add_subdirectory(backend/extensions/lapack)
add_subdirectory(backend/extensions/vm)
add_subdirectory(backend/extensions/sycl_ext)
Expand Down
83 changes: 83 additions & 0 deletions dpnp/backend/extensions/blas/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# *****************************************************************************
# Copyright (c) 2016-2023, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************


# Build the `_blas_impl` pybind11 extension module from the BLAS extension
# sources and install it under dpnp/backend/extensions/blas.
set(python_module_name _blas_impl)
set(_module_src
    ${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src})

if (WIN32)
    if (CMAKE_VERSION VERSION_LESS "3.27")
        # Work-around: target_link_options inserts the option after the -link
        # option, causing the linker to ignore it, so the flag is also
        # appended to the global C++ link flags.
        set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel")
    endif()
endif()

# POSITION_INDEPENDENT_CODE is the target property; the CMAKE_-prefixed name is
# only the variable that initializes it and has no effect when set as a
# property on a target.
set_target_properties(${python_module_name} PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Headers shared with the dpnp backend (relative to this extension directory).
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)

# dpctl headers.
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

# Same two floating-point options on every platform; the Windows branch spells
# them with the /clang: prefix.
if (WIN32)
    target_compile_options(${python_module_name} PRIVATE
        /clang:-fno-approx-func
        /clang:-fno-finite-math-only
    )
else()
    target_compile_options(${python_module_name} PRIVATE
        -fno-approx-func
        -fno-finite-math-only
    )
endif()

target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel)
if (UNIX)
    # This option is supported on Linux only.
    target_link_options(${python_module_name} PUBLIC -fsycl-link-huge-device-code)
endif()

# Add the profiling/coverage instrumentation options to the link step when
# DPNP_GENERATE_COVERAGE is set.
if (DPNP_GENERATE_COVERAGE)
    target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping)
endif()

# Link MKL::MKL_SYCL::BLAS when MKL_VERSION_2024 is set, otherwise fall back
# to MKL::MKL_DPCPP.
if (MKL_VERSION_2024)
    target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS)
else()
    target_link_libraries(${python_module_name} PUBLIC MKL::MKL_DPCPP)
endif()

install(TARGETS ${python_module_name}
    DESTINATION "dpnp/backend/extensions/blas"
)
66 changes: 66 additions & 0 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************
//
// This file defines functions of dpnp.backend._blas_impl extensions
//
//*****************************************************************************

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "gemm.hpp"

namespace blas_ext = dpnp::backend::ext::blas;
namespace py = pybind11;

// Fill the gemm_batch and gemm dispatch tables of the BLAS extension.
// Called from the module initialization code below.
void init_dispatch_tables()
{
    blas_ext::init_gemm_batch_dispatch_table();
    blas_ext::init_gemm_dispatch_table();
}

// Definition of the `_blas_impl` Python module: populates the dispatch tables
// on import and exposes the `_gemm` and `_gemm_batch` bindings.
PYBIND11_MODULE(_blas_impl, m)
{
    // The dispatch tables are filled before any binding is registered.
    init_dispatch_tables();

    {
        // `depends` is optional and defaults to an empty list.
        m.def("_gemm", &blas_ext::gemm,
              "Call `gemm` from OneMKL BLAS library to return "
              "the matrix-matrix product with 2-D matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("result"), py::arg("depends") = py::list());
    }

    {
        // Batched variant: additionally takes the batch size and one stride
        // argument per operand; `depends` defaults to an empty list.
        m.def("_gemm_batch", &blas_ext::gemm_batch,
              "Call `gemm_batch` from OneMKL BLAS library to return "
              "the matrix-matrix product for a batch of 2-D matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
              py::arg("strideb"), py::arg("stridec"),
              py::arg("depends") = py::list());
    }
}
Loading

0 comments on commit f95ceb9

Please sign in to comment.