Skip to content

Commit

Permalink
Generate sum([...]) for large vectors (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 authored Jul 18, 2024
1 parent a9ec2eb commit f9ee003
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 15 deletions.
12 changes: 7 additions & 5 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,17 @@ Base.Expr(c::BinaryCall) = begin
return Expr(c.equality, c.output, Expr(:call, c.operator, c.input1, c.input2))
end

struct VarargsCall <: AbstractCall
operator::Union{Symbol, Expr}
struct SummationCall <: AbstractCall
equality::Symbol
inputs::Vector{Symbol}
output::Symbol
end

Base.Expr(c::VarargsCall) = begin
return Expr(c.equality, c.output, Expr(:call, c.operator, c.inputs...))
# The output of @code_llvm (.+) of more than 32 variables is inefficient.
Base.Expr(c::SummationCall) = begin
length(c.inputs) 32 ?
Expr(c.equality, c.output, Expr(:call, Expr(:., :+), c.inputs...)) : # (.+)(a,b,c)
Expr(c.equality, c.output, Expr(:call, :sum, Expr(:vect, c.inputs...))) # sum([a,b,c])
end

struct AllocVecCall <: AbstractCall
Expand Down Expand Up @@ -508,7 +510,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve

visited_Σ[op] = true
visited_Var[r] = true
c = VarargsCall(operator, equality, argnames, rname)
c = SummationCall(equality, argnames, rname)
push!(op_order, c)
end
end
Expand Down
41 changes: 41 additions & 0 deletions test/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,44 @@ for prealloc in [false, true]
end

end

@testset "Large Summations" begin
# Elementwise summations of more than 32 variables are not pre-compiled by our
# host language.

# Test that (.+)(...) is generated for small sums.
SmallSum = @decapode begin
(A00, A01, A02, A03, A04, A05, A06, A07, A08, A09,
A10, A11, A12, A13, A14, A15, A16, A17, A18, A19,
A20, A21, A22, A23, A24, A25, A26, A27, A28, A29,
A30, A31, A32)::Form0

∂ₜ(A00) ==
A01 + A02 + A03 + A04 + A05 + A06 + A07 + A08 + A09 +
A10 + A11 + A12 + A13 + A14 + A15 + A16 + A17 + A18 + A19 +
A20 + A21 + A22 + A23 + A24 + A25 + A26 + A27 + A28 + A29 +
A30 + A31 + A32
end
needle = "A00̇ .= (.+)(A01, A02, A03, A04, A05, A06, A07, A08, A09, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, A30, A31, A32)"
haystack = string(gensim(SmallSum))
@test occursin(needle, haystack)

# Test that sum([...]) is generated for large sums.
LargeSum = @decapode begin
(A00, A01, A02, A03, A04, A05, A06, A07, A08, A09,
A10, A11, A12, A13, A14, A15, A16, A17, A18, A19,
A20, A21, A22, A23, A24, A25, A26, A27, A28, A29,
A30, A31, A32, A33)::Form0

∂ₜ(A00) ==
A01 + A02 + A03 + A04 + A05 + A06 + A07 + A08 + A09 +
A10 + A11 + A12 + A13 + A14 + A15 + A16 + A17 + A18 + A19 +
A20 + A21 + A22 + A23 + A24 + A25 + A26 + A27 + A28 + A29 +
A30 + A31 + A32 + A33
end
needle = "A00̇ .= sum([A01, A02, A03, A04, A05, A06, A07, A08, A09, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, A30, A31, A32, A33])"
haystack = string(gensim(LargeSum))
@test occursin(needle, haystack)

end

26 changes: 16 additions & 10 deletions test/simulation_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,38 @@ import Decapodes: BinaryCall
end
end

import Decapodes: VarargsCall
import Decapodes: SummationCall

@testset "Test VarargsCall" begin
@testset "Test SummationCall" begin
# Test equality, 2 inputs
@test Expr(VarargsCall(:F, EQUALS, [:x, :y], :z)) == :(z = F(x, y))
@test Expr(SummationCall(EQUALS, [:x, :y], :z)) == :(z = (.+)(x, y))

# Test equality, 3 inputs
@test Expr(VarargsCall(:F, EQUALS, [:x, :y, :w], :z)) == :(z = F(x, y, w))
@test Expr(SummationCall(EQUALS, [:x, :y, :w], :z)) == :(z = (.+)(x, y, w))

# Test equality, 1 input
@test Expr(VarargsCall(:F, EQUALS, [:x], :z)) == :(z = F(x))
@test Expr(SummationCall(EQUALS, [:x], :z)) == :(z = (.+)(x))

# Test equality, 0 inputs
@test Expr(VarargsCall(:F, EQUALS, [], :z)) == :(z = F())
@test Expr(SummationCall(EQUALS, [], :z)) == :(z = (.+)())

# Test broadcast equality, 33 inputs
@test Expr(SummationCall(EQUALS, fill(:x, 33), :z)) == Meta.parse("z = sum([" * foldl(*, fill("x, ", 32)) * "x])")

# Test broadcast equality, 2 inputs
@test Expr(VarargsCall(:F, DOT_EQUALS, [:x, :y], :z)) == :(z .= F(x, y))
@test Expr(SummationCall(DOT_EQUALS, [:x, :y], :z)) == :(z .= (.+)(x, y))

# Test broadcast equality, 3 inputs
@test Expr(VarargsCall(:F, DOT_EQUALS, [:x, :y, :w], :z)) == :(z .= F(x, y, w))
@test Expr(SummationCall(DOT_EQUALS, [:x, :y, :w], :z)) == :(z .= (.+)(x, y, w))

# Test broadcast equality, 1 input
@test Expr(VarargsCall(:F, DOT_EQUALS, [:x], :z)) == :(z .= F(x))
@test Expr(SummationCall(DOT_EQUALS, [:x], :z)) == :(z .= (.+)(x))

# Test broadcast equality, 0 inputs
@test Expr(VarargsCall(:F, DOT_EQUALS, [], :z)) == :(z .= F())
@test Expr(SummationCall(DOT_EQUALS, [], :z)) == :(z .= (.+)())

# Test broadcast equality, 33 inputs
@test Expr(SummationCall(DOT_EQUALS, fill(:x, 33), :z)) == Meta.parse("z .= sum([" * foldl(*, fill("x, ", 32)) * "x])")
end

#######################
Expand Down

0 comments on commit f9ee003

Please sign in to comment.