Skip to content

Commit

Permalink
Remove NamedDecapode constructor and compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 authored and mehalter committed Apr 24, 2023
1 parent d2d5cfb commit 3b3ab55
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 74 deletions.
4 changes: 2 additions & 2 deletions examples/sw/sw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ end


diffExpr = parse_decapode(DiffusionExprBody)
ddp = NamedDecapode(diffExpr)
ddp = SummationDecapode(diffExpr)
gensim(expand_operators(ddp), [:C])
f = eval(gensim(expand_operators(ddp), [:C]))

Expand Down Expand Up @@ -110,7 +110,7 @@ AdvDiff = quote
end

advdiff = parse_decapode(AdvDiff)
advdiffdp = NamedDecapode(advdiff)
advdiffdp = SummationDecapode(advdiff)
gensim(expand_operators(advdiffdp), [:C, :V])
sim = eval(gensim(expand_operators(advdiffdp), [:C, :V]))

Expand Down
18 changes: 0 additions & 18 deletions src/language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,6 @@ function Decapode(e::DecaExpr)
return d
end

function NamedDecapode(e::DecaExpr)
d = NamedDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()
for judgement in e.judgements
var_id = add_part!(d, :Var, name=judgement._1._1, type=judgement._2)
symbol_table[judgement._1._1] = var_id
end
deletions = Vector{Int64}()
for eq in e.equations
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))
fill_names!(d)
d[:name] = map(normalize_unicode,d[:name])
recognize_types(d)
return d
end

function SummationDecapode(e::DecaExpr)
d = SummationDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()
Expand Down
54 changes: 0 additions & 54 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,60 +155,6 @@ end

compile(d::AbstractNamedDecapode) = compile(d, infer_state_names(d))

function compile(d::NamedDecapode, inputs::Vector)
input_numbers = incident(d, inputs, :name)
visited = falses(nparts(d, :Var))
visited[collect(flatten(input_numbers))] .= true
consumed1 = falses(nparts(d, :Op1))
consumed2 = falses(nparts(d, :Op2))
# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = []
for iter in 1:(nparts(d, :Op1) + nparts(d,:Op2))
for op in parts(d, :Op1)
s = d[op, :src]
if !consumed1[op] && visited[s]
# skip the derivative edges
operator = d[op, :op1]
t = d[op, :tgt]
if operator == DerivOp
continue
end
consumed1[op] = true
visited[t] = true
sname = d[s, :name]
tname = d[t, :name]
c = UnaryCall(operator, sname, tname)
push!(op_order, c)
end
end

for op in parts(d, :Op2)
arg1 = d[op, :proj1]
arg2 = d[op, :proj2]
if !consumed2[op] && visited[arg1] && visited[arg2]
r = d[op, :res]
a1name = d[arg1, :name]
a2name = d[arg2, :name]
rname = d[r, :name]
operator = d[op, :op2]
consumed2[op] = true
visited[r] = true
c = BinaryCall(operator, a1name, a2name, rname)
push!(op_order, c)
end
end
end
assigns = map(Expr, op_order)
ret = :(return)
ret.args = d[d[:,:incl], :name]
return quote f(du, u, p, t) = begin
$(get_vars_code(d, inputs))
$(assigns...)
du .= 0.0
$(set_tanvars_code(d, inputs))
end; end
end

function compile(d::SummationDecapode, inputs::Vector)
# Get the Vars of the inputs (probably state Vars).
input_numbers = incident(d, inputs, :name)
Expand Down

0 comments on commit 3b3ab55

Please sign in to comment.