Skip to content

Commit

Permalink
Final cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeR227 committed Jun 12, 2024
1 parent 807961a commit 92d2cff
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,18 +375,20 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
# Get the Vars of the inputs (probably state Vars).
visited_Var = falses(nparts(d, :Var))

# input_numbers = incident(d, inputs, :name)
# TODO: Pass in state indices instead of names
input_numbers = reduce(vcat, incident(d, inputs, :name))
# visited_Var[collect(flatten(input_numbers))] .= true

visited_Var[input_numbers] .= true
visited_Var[incident(d, :Literal, :type)] .= true

# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ
visited_1 = falses(nparts(d, :Op1))
visited_2 = falses(nparts(d, :Op2))
visited_Σ = falses(nparts(d, ))

# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = []
op_order = AbstractCall[]

for _ in 1:(nparts(d, :Op1) + nparts(d,:Op2) + nparts(d, ))
for op in parts(d, :Op1)
s = d[op, :src]
Expand Down Expand Up @@ -581,10 +583,10 @@ end
A combined `infer_types` and `resolve_overloads` pipeline with default DEC rules.
"""
function infer_overload_compiler!(d::SummationDecapode, dimension::Int)
if(dimension == 1)
if dimension == 1
infer_types!(d, op1_inf_rules_1D, op2_inf_rules_1D)
resolve_overloads!(d, op1_res_rules_1D, op2_res_rules_1D)
elseif(dimension == 2)
elseif dimension == 2
infer_types!(d, op1_inf_rules_2D, op2_inf_rules_2D)
resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D)
end
Expand All @@ -596,16 +598,9 @@ end
Collects all DEC operators that are concrete matrices.
"""
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})

for op1_name in d[:op1]
if(op1_name optimizable_dec_operators)
push!(dec_matrices, op1_name)
end
end

for op2_name in d[:op2]
if(op2_name optimizable_dec_operators)
push!(dec_matrices, op2_name)
for op_name in vcat(d[:op1], d[:op2])
if(op_name in optimizable_dec_operators)
push!(dec_matrices, op_name)
end
end
end
Expand Down Expand Up @@ -696,10 +691,13 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
(stateeltype == Float32 || stateeltype == Float64) ||
throw(UnsupportedStateeltypeException(stateeltype))

recognize_types(user_d)
# Explicit copy for safety
gen_d = deepcopy(user_d)

recognize_types(gen_d)

# Makes copy
gen_d = expand_operators(user_d)
gen_d = expand_operators(gen_d)

dec_matrices = Vector{Symbol}()
alloc_vectors = Vector{AllocVecCall}()
Expand Down

0 comments on commit 92d2cff

Please sign in to comment.