Skip to content

Commit

Permalink
Stabilize MulAddMul strategically (#52439)
Browse files Browse the repository at this point in the history
Co-authored-by: Ashley Milsted <ashmilsted@gmail.com>
  • Loading branch information
dkarrasch and amilsted authored May 8, 2024
1 parent 999dde7 commit 29ced9e
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 105 deletions.
15 changes: 10 additions & 5 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,16 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const TriSym = Union{Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
# Scalar-coefficient entry points for banded-matrix products: each method wraps the
# `alpha`/`beta` pair in `MulAddMul(alpha, beta)` and forwards to the `_mul!` method that
# takes a `MulAddMul` as its last argument.
# NOTE(review): these five methods have the same signatures as the `@stable_muladdmul`
# definitions that follow and appear to be the pre-change side of the diff, superseded by
# them — confirm before treating both sets as live code.
@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
# Scalar-coefficient entry points for products involving a banded matrix. Each method
# forwards to the `_mul!` method that takes a `MulAddMul` as its last argument; the
# `@stable_muladdmul` wrapper branches on `isone(alpha)` / `iszero(beta)` so that the
# `MulAddMul` type parameters are explicit in every branch.

# banded matrix * vector -> vector
@inline function _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end

# banded matrix * vector -> matrix destination
@inline function _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end

# banded matrix * general matrix
@inline function _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end

# general matrix * banded matrix
@inline function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end

# banded matrix * banded matrix
@inline function _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end

# Left-multiply by a `Bidiagonal`: `B` is passed to `_mul!` both as the destination and
# as the right-hand operand, with the default `MulAddMul()`.
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
    return @inline _mul!(B, A, B, MulAddMul())
end
# Right-multiply by a `Bidiagonal`: `B` is passed to `_mul!` both as the destination and
# as the left-hand operand, with the default `MulAddMul()`.
function rmul!(B::AbstractMatrix, A::Bidiagonal)
    return @inline _mul!(B, B, A, MulAddMul())
end
Expand Down
68 changes: 68 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,74 @@ end
end
end

"""
@stable_muladdmul
Replaces a function call that has a `MulAddMul(alpha, beta)` constructor as an
argument with a branch over the possible values of `isone(alpha)` and `iszero(beta)`,
and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch.
For example, `f(x, y, MulAddMul(alpha, beta))` is transformed into
```
if isone(alpha)
if iszero(beta)
f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta))
else
f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta))
end
else
if iszero(beta)
f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta))
else
f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta))
end
end
```
This avoids the type instability of the `MulAddMul(alpha, beta)` constructor,
which causes runtime dispatch in case `alpha` and `beta` are not constants.
"""
macro stable_muladdmul(expr)
    # Reject anything that is not a call expression up front. The `isa Expr` guard
    # matters: a bare symbol or literal has no `head` field, so without it the caller
    # would see a field-access error instead of the intended `ArgumentError`.
    (expr isa Expr && expr.head === :call) ||
        throw(ArgumentError("Can only handle function calls."))

    for (i, e) in enumerate(expr.args)
        e isa Expr || continue
        # Only a two-argument `MulAddMul(alpha, beta)` call is rewritten. The length is
        # checked before `e.args` is indexed.
        (e.head === :call && length(e.args) == 3 && e.args[1] === :MulAddMul) || continue

        asym, bsym = e.args[2], e.args[3]
        # Both coefficients must be plain symbols; they are spliced into the expansion
        # several times (presumably so that no arbitrary expression gets re-evaluated).
        (asym isa Symbol && bsym isa Symbol) || continue

        # Build a copy of the original call in which argument `i` is replaced by a
        # `MulAddMul` constructor with its two boolean type parameters spelled out.
        specialized = (ais1, bis0) -> begin
            call = copy(expr)
            call.args[i] =
                :(MulAddMul{$ais1, $bis0, typeof($asym), typeof($bsym)}($asym, $bsym))
            return call
        end

        e_out = quote
            if isone($asym)
                if iszero($bsym)
                    $(specialized(true, true))
                else
                    $(specialized(true, false))
                end
            else
                if iszero($bsym)
                    $(specialized(false, true))
                else
                    $(specialized(false, false))
                end
            end
        end
        # The whole expansion is escaped, so every name in it resolves in the caller's scope.
        return esc(e_out)
    end

    # No argument matched the `MulAddMul(alpha, beta)` pattern with symbol coefficients.
    throw(ArgumentError("No valid MulAddMul expression found."))
end

# Zero-argument constructor: `alpha = true`, `beta = false`, with both boolean type
# parameters set to `true`.
function MulAddMul()
    return MulAddMul{true,true,Bool,Bool}(true, false)
end

# One-argument call on a `MulAddMul` whose first type parameter is `true`: the input is
# returned unchanged.
@inline function (::MulAddMul{true})(x)
    return x
end
Expand Down
Loading

0 comments on commit 29ced9e

Please sign in to comment.