diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl index 66db0123..a44b79c7 100644 --- a/lib/mkl/linalg.jl +++ b/lib/mkl/linalg.jl @@ -220,7 +220,7 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr end if all(in(('N', 'T', 'C')), (tA, tB)) - if T <: onemklFloat && eltype(A) == eltype(B) == T + if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype(A) == eltype(B) == T return gemm!(tA, tB, alpha, A, B, beta, C) end end diff --git a/lib/mkl/oneMKL.jl b/lib/mkl/oneMKL.jl index 730694bd..58734a7e 100644 --- a/lib/mkl/oneMKL.jl +++ b/lib/mkl/oneMKL.jl @@ -17,10 +17,10 @@ using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdia using SparseArrays -# Exclude Float16 for now, since many oneMKL functions - copy, scal, do not take Float16 +# Exclude Float16 for now, since many oneMKL functions do not take Float16 const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32} const onemklComplex = Union{ComplexF32,ComplexF64} -const onemklHalf = Union{Float16,ComplexF16} +const onemklHalf = Float16 include("array.jl") include("utils.jl") diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl index f72ea81c..23fa4fa5 100644 --- a/lib/mkl/wrappers_blas.jl +++ b/lib/mkl/wrappers_blas.jl @@ -537,7 +537,7 @@ for (fname, elty, cty, sty, supty) in ((:onemklSrot,:Float32,:Float32,:Float32,: end end -function axpy!(n::Integer, +function axpy!(n::Integer, alpha::Number, x::oneStridedArray{ComplexF16}, y::oneStridedArray{ComplexF16}) @@ -1260,7 +1260,7 @@ function dgmm(mode::Char, A::oneStridedMatrix{T}, X::oneStridedVector{T}) where dgmm!( mode, A, X, similar(A, (m,n) ) ) end -for (fname, elty) in +for (fname, elty) in ((:onemklSgemmBatchStrided, Float32), (:onemklDgemmBatchStrided, Float64), (:onemklCgemmBatchStrided, ComplexF32),