Skip to content

Commit

Permalink
Use oneMKL with Float64 matmul. (#416)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Apr 11, 2024
1 parent fadcd8d commit f20d471
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
end

if all(in(('N', 'T', 'C')), (tA, tB))
if T <: onemklFloat && eltype(A) == eltype(B) == T
if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype(A) == eltype(B) == T
return gemm!(tA, tB, alpha, A, B, beta, C)
end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdia

using SparseArrays

# Exclude Float16 for now, since many oneMKL functions - copy, scal, do not take Float16
# Exclude Float16 for now, since many oneMKL functions do not take Float16
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
const onemklComplex = Union{ComplexF32,ComplexF64}
const onemklHalf = Union{Float16,ComplexF16}
const onemklHalf = Float16

include("array.jl")
include("utils.jl")
Expand Down
4 changes: 2 additions & 2 deletions lib/mkl/wrappers_blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ for (fname, elty, cty, sty, supty) in ((:onemklSrot,:Float32,:Float32,:Float32,:
end
end

function axpy!(n::Integer,
function axpy!(n::Integer,
alpha::Number,
x::oneStridedArray{ComplexF16},
y::oneStridedArray{ComplexF16})
Expand Down Expand Up @@ -1260,7 +1260,7 @@ function dgmm(mode::Char, A::oneStridedMatrix{T}, X::oneStridedVector{T}) where
dgmm!( mode, A, X, similar(A, (m,n) ) )
end

for (fname, elty) in
for (fname, elty) in
((:onemklSgemmBatchStrided, Float32),
(:onemklDgemmBatchStrided, Float64),
(:onemklCgemmBatchStrided, ComplexF32),
Expand Down

0 comments on commit f20d471

Please sign in to comment.