I've been trying to get a custom Enzyme rule for LinearAlgebra.BLAS.gemm! working, but I'm struggling to complete the implementation. Opening an issue seemed like a better option than a discussion on Slack or Discourse.
I'm not worrying about the mathematical details for now, just trying to get the thing to run.
This is my attempt thus far:
using Enzyme
using .EnzymeRules
using LinearAlgebra
function EnzymeRules.augmented_primal(
    config::ConfigWidth{1},
    ::Const{typeof(BLAS.gemm!)},
    ::Type{<:Const},
    transA::Const{<:AbstractChar},
    transB::Const{<:AbstractChar},
    alpha::Const,
    A::Duplicated{<:AbstractVecOrMat{T}},
    B::Duplicated{<:AbstractVecOrMat{T}},
    beta::Const,
    C::Duplicated{<:AbstractVecOrMat{T}},
) where {T<:Union{Float32, Float64}}
    println("in the forwards-pass")
    tape = (copy(A.val), copy(B.val), C.dval)
    BLAS.gemm!(transA.val, transB.val, alpha.val, A.val, B.val, beta.val, C.val)
    primal = needs_primal(config) ? C.val : nothing
    shadow = needs_shadow(config) ? C.dval : nothing
    @show needs_primal(config), needs_shadow(config)
    return AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    config::ConfigWidth{1},
    ::Const{typeof(BLAS.gemm!)},
    ::Type{<:Const},
    tape,
    transA::Const{<:AbstractChar},
    transB::Const{<:AbstractChar},
    alpha::Const,
    A::Duplicated{<:AbstractVecOrMat{T}},
    B::Duplicated{<:AbstractVecOrMat{T}},
    beta::Const,
    C::Duplicated{<:AbstractVecOrMat{T}},
) where {T<:Union{Float32, Float64}}
    println("In the reverse-pass")
    B.dval .= 1.0 # dummy implementation to see what happens
    return (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
end
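For when the mathematical details do come up, here is a minimal sketch of what the reverse pass would eventually need to compute, restricted to the transA = transB = 'N' case. For C_new = alpha * A * B + beta * C_old with incoming cotangent dC, the updates are dA += alpha * dC * B', dB += alpha * A' * dC, and the cotangent of C_old is beta * dC. The helper name gemm_nn_pullback! is hypothetical (not part of Enzyme or LinearAlgebra), and this is plain Julia that has not been tried inside an Enzyme rule.

```julia
using LinearAlgebra

# Hypothetical helper: reverse-pass update for C = alpha*A*B + beta*C with
# transA = transB = 'N'. dA and dB are accumulated into; dC is overwritten
# in place with the cotangent of the old C. alpha must be Bool or the arrays'
# element type, as BLAS.gemm! requires.
function gemm_nn_pullback!(dA, dB, dC, alpha, A, B, beta)
    # dA += alpha * dC * B'
    BLAS.gemm!('N', 'T', alpha, dC, B, true, dA)
    # dB += alpha * A' * dC
    BLAS.gemm!('T', 'N', alpha, A, dC, true, dB)
    # Cotangent of the old C is beta * dC. This must come last because the
    # two updates above read dC.
    dC .*= beta
    return nothing
end
```

Inside the reverse rule this would be called roughly as gemm_nn_pullback!(A.dval, B.dval, C.dval, alpha.val, tape[1], tape[2], beta.val), using the copies of A and B stored on the tape. Other transA/transB combinations need different transpose flags in the two gemm! calls.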
When I attempt to compute a pullback:
D = 5;
A = Duplicated(randn(D, D), zeros(D, D));
B = Duplicated(randn(D, 2D), zeros(D, 2D));
C = Duplicated(zeros(D, 2D), zeros(D, 2D));
autodiff(Reverse, BLAS.gemm!, Const, 'N', 'N', true, A, B, false, C)
I get the following error:
It also doesn't look like the reverse implementation is getting hit, because:
Any advice on what might be going on would be appreciated!
I'm checked out to the latest commit on main.
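For reference, the primal call being differentiated in the pullback above can be checked without Enzyme. With alpha = true and beta = false, BLAS.gemm! overwrites C with A * B (the old contents of C do not contribute) and returns C itself rather than a copy, so the function's return value aliases the Duplicated argument C. The sketch below only confirms those two facts with plain arrays; it says nothing about how Enzyme treats the Const return annotation.

```julia
using LinearAlgebra

D = 5
A = randn(D, D)
B = randn(D, 2D)
C = zeros(D, 2D)

# Same primal call as in the autodiff example, on plain arrays.
r = BLAS.gemm!('N', 'N', true, A, B, false, C)

@show r === C    # the return value is the same array object as C
@show C ≈ A * B  # C now holds the matrix product
```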