diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 1af4561ca..1535ba825 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -17,15 +17,13 @@ rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y) function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) Ω = BLAS.dot(n, X, incx, Y, incy) + function blas_dot_pullback(::Zero) + return (NO_FIELDS, DoesNotExist(), Zero(), DoesNotExist(), Zero(), DoesNotExist()) + end function blas_dot_pullback(ΔΩ) - if ΔΩ isa Zero - ∂X = Zero() - ∂Y = Zero() - else - ΔΩ = extern(ΔΩ) - ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) - ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) - end + ΔΩ = extern(ΔΩ) + ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) + ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist()) end return Ω, blas_dot_pullback