Skip to content

Commit

Permalink
Seperate vars_code checking
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeR227 committed Jun 10, 2024
1 parent ee7cc5e commit 7a232c8
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ end

is_form(d::SummationDecapode, var_name::Symbol) = is_form(d, first(incident(d, var_name, :name)))

function getgeneric_type(type::Symbol)
function getgeneric_type(type::Symbol)
if (type == :Form0 || type == :Form1 || type == :Form2 ||
type == :DualForm0 || type == :DualForm1 || type == :DualForm2)
return :Form
Expand Down Expand Up @@ -259,29 +259,14 @@ This initalizes all input variables according to their Decapodes type.
"""
function get_vars_code(d::AbstractNamedDecapode, vars::Vector{Symbol}, ::Type{stateeltype}, code_target::GenerationTarget = CPUTarget()) where stateeltype
stmts = quote end

map(vars) do s
# If name is not unique (or not just literals) then error
found_names_idxs = incident(d, s, :name)
# TODO: we should handle the case of same literals better
is_singular = length(found_names_idxs) == 1
is_all_literals = all(d[found_names_idxs, :type] .== :Literal)

if !is_singular && !is_all_literals
throw(AmbiguousNameException(s, found_names_idxs))
end

if is_all_literals
push!(stmts.args, :($s = $(parse(stateeltype, String(s)))))
return
end

s_type = getgeneric_type(d[only(found_names_idxs), :type])
map(vars) do s
s_type = getgeneric_type(d[only(incident(d, s, :name)), :type])

# Literals don't need assignments, because they are literals, but we stored them as Symbols.
# TODO: we should fix that upstream so that we don't need this.
line = @match s_type begin
# :Literal => :($s = $(parse(stateeltype, String(s))))
:Literal => :($s = $(parse(stateeltype, String(s))))
:Constant => :($s = p.$s)
:Parameter => :($s = (p.$s)(t))
_ => hook_GVC_get_form(s, s_type, code_target) # ! WARNING: This assumes a form
Expand All @@ -297,6 +282,34 @@ function hook_GVC_get_form(var_name::Symbol, var_type::Symbol, code_target::Unio
return :($var_name = u.$var_name)
end

"""
validate_var_names!(d::AbstractNamedDecapode)
This function checks for any ambiguous names, meaning names that are shared by two or
more different indexed variables. Literals are currently an exception to this so in the
case of a repeat, these are collapsed into a single variable.
"""
function validate_var_names!(d::AbstractNamedDecapode)
remove_cache = Int[]

for name in Set(d[:name])
found_names_idxs = incident(d, name, :name)
is_singular = length(found_names_idxs) == 1
is_all_literals = all(d[found_names_idxs, :type] .== :Literal)

if !is_singular && !is_all_literals
throw(AmbiguousNameException(s, found_names_idxs))
end

if is_all_literals && !is_singular
append!(remove_cache, found_names_idxs[2:end])
end

end

rem_parts!(d, :Var, sort(remove_cache))
end

"""
set_tanvars_code(d::AbstractNamedDecapode)
Expand Down Expand Up @@ -611,6 +624,7 @@ function gensim(user_d::AbstractNamedDecapode, input_vars::Vector{Symbol}; dimen
dec_matrices = Vector{Symbol}();
alloc_vectors = Vector{AllocVecCall}();

validate_var_names!(gen_d)
vars = get_vars_code(gen_d, input_vars, stateeltype, code_target)
tars = set_tanvars_code(gen_d, code_target)

Expand Down

0 comments on commit 7a232c8

Please sign in to comment.