Skip to content

Commit

Permalink
Partial 1.10 enablement (JuliaGPU#330)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <tim.besard@gmail.com>
  • Loading branch information
2 people authored and amontoison committed Oct 18, 2023
1 parent b07fd5f commit d5514e3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
6 changes: 5 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ steps:
- "1.6"
- "1.7"
- "1.8"
- "1.9-nightly"
- "1.9"
- "1.10-nightly"
- "nightly"
adjustments:
- with:
julia: "nightly"
soft_fail: true
- with:
julia: "1.10-nightly"
soft_fail: true

# Special tests
- group: ":eyes: Special"
Expand Down
36 changes: 28 additions & 8 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ if VERSION < v"1.10.0-DEV.1365"
end

# triangular
if isdefined(LinearAlgebra, :generic_trimatmul!) # VERSION >= v"1.10-DEVXYZ"
# multiplication
LinearAlgebra.generic_trimatmul!(c::oneStridedVector{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, b::AbstractVector{T}) where {T<:onemklFloat} =
trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b))
# division
LinearAlgebra.generic_trimatdiv!(C::oneStridedVector{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractVector{T}) where {T<:onemklFloat} =
trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
else
## direct multiplication/division
for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
(:UnitLowerTriangular, 'L', 'U'),
Expand Down Expand Up @@ -183,6 +191,7 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'U', 'N'),
trsv!($uploc, 'C', $isunitc, parent(parent(A)), B)
end
end
end # VERSION


#
Expand Down Expand Up @@ -254,23 +263,34 @@ end
end # VERSION

# triangular
if isdefined(LinearAlgebra, :generic_trimatmul!) # VERSION >= v"1.10-DEVXYZ"
LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
else
## direct multiplication/division
for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
(:UnitLowerTriangular, 'L', 'U'),
(:UpperTriangular, 'U', 'N'),
(:UnitUpperTriangular, 'U', 'U'))
@eval begin
# Multiplication
LinearAlgebra.lmul!(A::$t{T,<:oneStridedVecOrMat},
B::oneStridedVecOrMat{T}) where {T<:onemklFloat} =
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B, B)
LinearAlgebra.rmul!(A::oneStridedVecOrMat{T},
B::$t{T,<:oneStridedVecOrMat}) where {T<:onemklFloat} =
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A, A)
LinearAlgebra.lmul!(A::$t{T,<:oneStridedMatrix},
B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
LinearAlgebra.rmul!(A::oneStridedMatrix{T},
B::$t{T,<:oneStridedMatrix}) where {T<:onemklFloat} =
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A)

# Left division
LinearAlgebra.ldiv!(A::$t{T,<:oneStridedVecOrMat},
B::oneStridedVecOrMat{T}) where {T<:onemklFloat} =
LinearAlgebra.ldiv!(A::$t{T,<:oneStridedMatrix},
B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
end
end
end # VERSION
2 changes: 1 addition & 1 deletion lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ function trmm(side::Char,
alpha::Number,
A::oneStridedMatrix{T},
B::oneStridedMatrix{T}) where T
trmm!(side, uplo, transa, diag, alpha, A, B)
trmm!(side, uplo, transa, diag, alpha, A, copy(B))
end
function trsm(side::Char,
uplo::Char,
Expand Down
4 changes: 2 additions & 2 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,9 @@ end
dA = oneArray(A)
dB = oneArray(B)
C = alpha*A*B
oneMKL.trmm('L','U','N','N',alpha,dA,dB)
dC = oneMKL.trmm('L','U','N','N',alpha,dA,dB)
# move to host and compare
h_C = Array(dB)
h_C = Array(dC)
@test C ≈ h_C
end
Expand Down

0 comments on commit d5514e3

Please sign in to comment.