Skip to content

Commit

Permalink
Move Zero check out of pullback
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Dec 15, 2020
1 parent f67db2e commit d41ef75
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)

function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy)
Ω = BLAS.dot(n, X, incx, Y, incy)
function blas_dot_pullback(::Zero)
return (NO_FIELDS, DoesNotExist(), Zero(), DoesNotExist(), Zero(), DoesNotExist())
end
function blas_dot_pullback(ΔΩ)
if ΔΩ isa Zero
∂X = Zero()
∂Y = Zero()
else
ΔΩ = extern(ΔΩ)
∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx)
∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy)
end
ΔΩ = extern(ΔΩ)
∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx)
∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy)
return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist())
end
return Ω, blas_dot_pullback
Expand Down

0 comments on commit d41ef75

Please sign in to comment.