Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Julia 1.11 #415

Merged
merged 10 commits into from
Apr 13, 2024
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ steps:
- "1.8"
- "1.9"
- "1.10"
- "1.11"
- "nightly"
adjustments:
- with:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ KernelAbstractions = "0.9.1"
LLVM = "6"
NEO_jll = "=24.09.28717"
Preferences = "1"
SPIRV_LLVM_Translator_unified_jll = "0.3"
SPIRV_LLVM_Translator_unified_jll = "0.4"
SpecialFunctions = "1.3, 2"
StaticArrays = "1"
julia = "1.8"
Expand Down
59 changes: 31 additions & 28 deletions deps/onemkl_prologue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,54 +554,55 @@ extern "C" int onemklZtrsmBatched(syclQueue_t device_queue, onemklSide left_righ

extern "C" int onemklHgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
uint16_t alpha, const short *a, int64_t lda, int64_t stridea,
const short *b, int64_t ldb, int64_t strideb, uint16_t beta,
uint16_t *alpha, const short *a, int64_t lda, int64_t stridea,
const short *b, int64_t ldb, int64_t strideb, uint16_t *beta,
short *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
convert(transb), m, n, k, sycl::bit_cast<sycl::half>(alpha),
convert(transb), m, n, k,
*reinterpret_cast<const sycl::half *>(alpha),
reinterpret_cast<const sycl::half *>(a), lda, stridea,
reinterpret_cast<const sycl::half *>(b), ldb, strideb,
sycl::bit_cast<sycl::half>(beta),
*reinterpret_cast<const sycl::half *>(beta),
reinterpret_cast<sycl::half *>(c), ldc, stridec, batch_size, {});
__FORCE_MKL_FLUSH__(status);
return 0;
}

extern "C" int onemklSgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a, int64_t lda, int64_t stridea,
const float *b, int64_t ldb, int64_t strideb, float beta,
float *alpha, const float *a, int64_t lda, int64_t stridea,
const float *b, int64_t ldb, int64_t strideb, float *beta,
float *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
convert(transb), m, n, k, alpha, a, lda, stridea,
b, ldb, strideb, beta, c, ldc, stridec, batch_size, {});
convert(transb), m, n, k, *alpha, a, lda, stridea,
b, ldb, strideb, *beta, c, ldc, stridec, batch_size, {});
__FORCE_MKL_FLUSH__(status);
return 0;
}

extern "C" int onemklDgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a, int64_t lda, int64_t stridea,
const double *b, int64_t ldb, int64_t strideb, double beta,
double *alpha, const double *a, int64_t lda, int64_t stridea,
const double *b, int64_t ldb, int64_t strideb, double *beta,
double *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
convert(transb), m, n, k, alpha, a, lda, stridea,
b, ldb, strideb, beta, c, ldc, stridec, batch_size, {});
convert(transb), m, n, k, *alpha, a, lda, stridea,
b, ldb, strideb, *beta, c, ldc, stridec, batch_size, {});
__FORCE_MKL_FLUSH__(status);
return 0;
}

extern "C" int onemklCgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
float _Complex alpha, const float _Complex *a, int64_t lda, int64_t stridea,
const float _Complex *b, int64_t ldb, int64_t strideb, float _Complex beta,
float _Complex *alpha, const float _Complex *a, int64_t lda, int64_t stridea,
const float _Complex *b, int64_t ldb, int64_t strideb, float _Complex *beta,
float _Complex *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
convert(transb), m, n, k, alpha,
convert(transb), m, n, k, *alpha,
reinterpret_cast<const std::complex<float> *>(a),
lda, stridea,
reinterpret_cast<const std::complex<float> *>(b),
ldb, strideb, beta,
ldb, strideb, *beta,
reinterpret_cast<std::complex<float> *>(c),
ldc, stridec, batch_size, {});
__FORCE_MKL_FLUSH__(status);
Expand All @@ -610,15 +611,15 @@ extern "C" int onemklCgemmBatchStrided(syclQueue_t device_queue, onemklTranspose

extern "C" int onemklZgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
double _Complex alpha, const double _Complex *a, int64_t lda, int64_t stridea,
const double _Complex *b, int64_t ldb, int64_t strideb, double _Complex beta,
double _Complex *alpha, const double _Complex *a, int64_t lda, int64_t stridea,
const double _Complex *b, int64_t ldb, int64_t strideb, double _Complex *beta,
double _Complex *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
convert(transb), m, n, k, alpha,
convert(transb), m, n, k, *alpha,
reinterpret_cast<const std::complex<double> *>(a),
lda, stridea,
reinterpret_cast<const std::complex<double> *>(b),
ldb, strideb, beta,
ldb, strideb, *beta,
reinterpret_cast<std::complex<double> *>(c),
ldc, stridec, batch_size, {});
__FORCE_MKL_FLUSH__(status);
Expand All @@ -627,14 +628,15 @@ extern "C" int onemklZgemmBatchStrided(syclQueue_t device_queue, onemklTranspose

extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
onemklTranspose transB, int64_t m, int64_t n,
int64_t k, uint16_t alpha, const short *A, int64_t lda,
const short *B, int64_t ldb, uint16_t beta, short *C,
int64_t k, uint16_t *alpha, const short *A, int64_t lda,
const short *B, int64_t ldb, uint16_t *beta, short *C,
int64_t ldc) {
auto status = oneapi::mkl::blas::column_major::gemm(device_queue->val, convert(transA),
convert(transB), m, n, k, sycl::bit_cast<sycl::half>(alpha),
convert(transB), m, n, k,
*reinterpret_cast<const sycl::half *>(alpha),
reinterpret_cast<const sycl::half *>(A), lda,
reinterpret_cast<const sycl::half *>(B), ldb,
sycl::bit_cast<sycl::half>(beta),
*reinterpret_cast<const sycl::half *>(beta),
reinterpret_cast<sycl::half *>(C), ldc, {});
__FORCE_MKL_FLUSH__(status);
return 0;
Expand All @@ -651,19 +653,20 @@ extern "C" int onemklHdot(syclQueue_t device_queue, int64_t n,
return 0;
}

extern "C" int onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha,
extern "C" int onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t *alpha,
const short *x, std::int64_t incx, short *y, int64_t incy) {
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n,
sycl::bit_cast<sycl::half>(alpha),
*reinterpret_cast<const sycl::half *>(alpha),
reinterpret_cast<const sycl::half *>(x),
incx, reinterpret_cast<sycl::half *>(y), incy, {});
__FORCE_MKL_FLUSH__(status);
return 0;
}

extern "C" int onemklHscal(syclQueue_t device_queue, int64_t n, uint16_t alpha,
extern "C" int onemklHscal(syclQueue_t device_queue, int64_t n, uint16_t *alpha,
short *x, int64_t incx) {
auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n, sycl::bit_cast<sycl::half>(alpha),
auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n,
*reinterpret_cast<const sycl::half *>(alpha),
reinterpret_cast<sycl::half *>(x), incx, {});
__FORCE_MKL_FLUSH__(status);
return 0;
Expand Down
28 changes: 14 additions & 14 deletions deps/onemkl_prologue.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,46 +195,46 @@ int onemklZtrsmBatched(syclQueue_t device_queue, onemklSide left_right,

int onemklHgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
uint16_t alpha, const short *a, int64_t lda, int64_t stridea,
const short *b, int64_t ldb, int64_t strideb, uint16_t beta,
uint16_t *alpha, const short *a, int64_t lda, int64_t stridea,
const short *b, int64_t ldb, int64_t strideb, uint16_t *beta,
short *c, int64_t ldc, int64_t stridec, int64_t batch_size);

int onemklSgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a, int64_t lda, int64_t stridea,
const float *b, int64_t ldb, int64_t strideb, float beta,
float *alpha, const float *a, int64_t lda, int64_t stridea,
const float *b, int64_t ldb, int64_t strideb, float *beta,
float *c, int64_t ldc, int64_t stridec, int64_t batch_size);

int onemklDgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a, int64_t lda, int64_t stridea,
const double *b, int64_t ldb, int64_t strideb, double beta,
double *alpha, const double *a, int64_t lda, int64_t stridea,
const double *b, int64_t ldb, int64_t strideb, double *beta,
double *c, int64_t ldc, int64_t stridec, int64_t batch_size);

int onemklCgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
float _Complex alpha, const float _Complex *a, int64_t lda,
float _Complex *alpha, const float _Complex *a, int64_t lda,
int64_t stridea, const float _Complex *b, int64_t ldb,
int64_t strideb, float _Complex beta, float _Complex *c,
int64_t strideb, float _Complex *beta, float _Complex *c,
int64_t ldc, int64_t stridec, int64_t batch_size);

int onemklZgemmBatchStrided(syclQueue_t device_queue, onemklTranspose transa,
onemklTranspose transb, int64_t m, int64_t n, int64_t k,
double _Complex alpha, const double _Complex *a, int64_t lda,
double _Complex *alpha, const double _Complex *a, int64_t lda,
int64_t stridea, const double _Complex *b, int64_t ldb,
int64_t strideb, double _Complex beta, double _Complex *c,
int64_t strideb, double _Complex *beta, double _Complex *c,
int64_t ldc, int64_t stridec, int64_t batch_size);

int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
onemklTranspose transB, int64_t m, int64_t n,
int64_t k, uint16_t alpha, const short *A, int64_t lda,
const short *B, int64_t ldb, uint16_t beta, short *C,
int64_t k, uint16_t *alpha, const short *A, int64_t lda,
const short *B, int64_t ldb, uint16_t *beta, short *C,
int64_t ldc);

int onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha, const short *x,
int onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t *alpha, const short *x,
int64_t incx, short *y, int64_t incy);

int onemklHscal(syclQueue_t device_queue, int64_t n, uint16_t alpha,
int onemklHscal(syclQueue_t device_queue, int64_t n, uint16_t *alpha,
short *x, int64_t incx);

int onemklHnrm2(syclQueue_t device_queue, int64_t n, const short *x,
Expand Down
Loading