Skip to content

Commit

Permalink
Gr/parsing overhaul (#95)
Browse files Browse the repository at this point in the history
* Basic literal support

* Added Lit data term

* Fix composed op2s and add test

* Replaced Var for term in \circ decomp

* Quick patch for arbitrary multiplication support.

* Fixed Tvar's unable to parse on rhs

* Added inferred spaces for variables

* Do not infer literals as states

* Check types are in recognized set

* Fixed strange TVar behavior

* Fix cache invalidation error and add tests

* Add new parsing features tests

* Updated visualization for support new parsing.

Also cleaned up the parsing code.

* Support Literals in simulations

* Do not namespace Literals, and do not vcat Nothing

* Move CombinatorialSpaces using

* Remove erroneous for loop

---------

Co-authored-by: Luke Morris <lukelukemorrismorris@gmail.com>
Co-authored-by: Luke Morris <70283489+lukem12345@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 21, 2023
1 parent ead3181 commit 69d63b5
Show file tree
Hide file tree
Showing 11 changed files with 521 additions and 283 deletions.
6 changes: 4 additions & 2 deletions src/Decapodes2/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@ function oapply_rename(relation::RelationDiagram, decapodes::Vector{D}) where D<
for b boxes(r)
box_name = r[b, :name]
for v parts(decapodes_vars[b], :Var)
var_name = decapodes_vars[b][v, :name]
decapodes_vars[b][v, :name] = Symbol(box_name, '_', var_name)
if decapodes_vars[b][v, :type] != :Literal
var_name = decapodes_vars[b][v, :name]
decapodes_vars[b][v, :name] = Symbol(box_name, '_', var_name)
end
end
end

Expand Down
47 changes: 30 additions & 17 deletions src/Decapodes2/decapodeacset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,45 @@ end
add new variable names to all the variables that don't have names.
"""
function fill_names!(d::AbstractNamedDecapode)
bulletcount = 1
for i in parts(d, :Var)
if !isassigned(d[:,:name],i) || isnothing(d[i, :name])
d[i,:name] = Symbol("$bulletcount")
bulletcount += 1
end
end
for e in incident(d, :∂ₜ, :op1)
s = d[e,:src]
t = d[e, :tgt]
d[t, :name] = append_dot(d[s,:name])
bulletcount = 1
for i in parts(d, :Var)
if !isassigned(d[:,:name],i) || isnothing(d[i, :name])
d[i,:name] = Symbol("$bulletcount")
bulletcount += 1
end
return d
end
for e in incident(d, :∂ₜ, :op1)
s = d[e,:src]
t = d[e, :tgt]
d[t, :name] = append_dot(d[s,:name])
end
return d
end

function make_sum_unique!(d::AbstractNamedDecapode)
num = 1
function make_sum_mult_unique!(d::AbstractNamedDecapode)
snum = 1
mnum = 1
for (i, name) in enumerate(d[:name])
if(name == :sum)
d[i, :name] = Symbol(join([String(name), string(num)] , "_"))
num += 1
d[i, :name] = Symbol("sum_$(snum)")
snum += 1
elseif(name == :mult)
d[i, :name] = Symbol("mult_$(mnum)")
mnum += 1
end
end
end

# Note: This hard-bakes in Form0 through Form2, and higher Forms are not
# allowed.
function recognize_types(d::AbstractNamedDecapode)
unrecognized_types = setdiff(d[:type], [:Form0, :Form1, :Form2, :DualForm0,
:DualForm1, :DualForm2, :Literal, :Parameter,
:Constant, :infer])
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized.")
end

function expand_operators(d::AbstractNamedDecapode)
e = SummationDecapode{Symbol, Symbol, Symbol}()
copy_parts!(e, d, (:Var, :TVar, :Op2))
Expand Down Expand Up @@ -476,4 +490,3 @@ end
resolve_overloads!(d::SummationDecapode) =
resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D)

+
2 changes: 0 additions & 2 deletions src/Decapodes2/decapodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ using Catlab.Programs
using Catlab.CategoricalAlgebra
using Catlab.WiringDiagrams
using Catlab.WiringDiagrams.DirectedWiringDiagrams
using CombinatorialSpaces
using CombinatorialSpaces.ExteriorCalculus
using LinearAlgebra
using MLStyle
using Base.Iterators
Expand Down
124 changes: 95 additions & 29 deletions src/Decapodes2/language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
# - way to represent this in a Decapode ACSet.
@data Term begin
Var(Symbol)
Lit(Symbol)
Judge(Var, Symbol, Symbol) # Symbol 1: Form0 Symbol 2: X
AppCirc1(Vector{Symbol}, Term)
AppCirc2(Vector{Symbol}, Term, Term)
App1(Symbol, Term)
App2(Symbol, Term, Term)
Plus(Vector{Term})
Mult(Vector{Term})
Tan(Term)
end

Expand All @@ -26,17 +27,22 @@ struct DecaExpr
end

term(s::Symbol) = Var(normalize_unicode(s))
term(s::Number) = Lit(Symbol(s))

term(expr::Expr) = begin
@match expr begin
Expr(a) => Var(normalize_unicode(a))
Expr(:call, :∂ₜ, b) => Tan(term(b))
Expr(:call, Expr(:call, :, a...), b) => AppCirc1(a, Var(b))
#TODO: Would we want ∂ₜ to be used with general expressions or just Vars?
Expr(:call, :∂ₜ, b) => Tan(Var(b))

Expr(:call, Expr(:call, :, a...), b) => AppCirc1(a, term(b))
Expr(:call, a, b) => App1(a, term(b))
Expr(:call, Expr(:call, :, f...), x, y) => AppCirc1(f, Var(x), Var(y))

Expr(:call, :+, xs...) => Plus(term.(xs))
Expr(:call, f, x, y) => App2(f, term(x), term(y))
Expr(:call, :, a...) => (:AppCirc1, map(term, a))

# TODO: Will later be converted to Op2's or schema has to be changed to include multiplication
Expr(:call, :*, xs...) => Mult(term.(xs))

x => error("Cannot construct term from $x")
end
end
Expand All @@ -45,10 +51,16 @@ function parse_decapode(expr::Expr)
stmts = map(expr.args) do line
@match line begin
::LineNumberNode => missing
Expr(:(::), a::Symbol, b) => Judge(Var(a),b.args[1], b.args[2])
# TODO: If user doesn't provide space, this gives a temp space so we can continue to construction
# For now spaces don't matter so this is fine but if they do, this will need to change
Expr(:(::), a::Symbol, b::Symbol) => Judge(Var(a), b, :I)
Expr(:(::), a::Expr, b::Symbol) => map(sym -> Judge(Var(sym), b, :I), a.args)

Expr(:(::), a::Symbol, b) => Judge(Var(a), b.args[1], b.args[2])
Expr(:(::), a::Expr, b) => map(sym -> Judge(Var(sym), b.args[1], b.args[2]), a.args)

Expr(:call, :(==), lhs, rhs) => Eq(term(lhs), term(rhs))
x => x
_ => error("The line $line is malformed")
end
end |> skipmissing |> collect
judges = []
Expand All @@ -63,13 +75,27 @@ function parse_decapode(expr::Expr)
end
DecaExpr(judges, eqns)
end

# to_decapode helper functions
reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) =
let ! = reduce_term!
@match t begin
Var(x) => syms[x]
App1(f, t) => begin
Var(x) => begin
if haskey(syms, x)
syms[x]
else
res_var = add_part!(d, :Var, name = x, type=:infer)
syms[x] = res_var
end
end
Lit(x) => begin
if haskey(syms, x)
syms[x]
else
res_var = add_part!(d, :Var, name = x, type=:Literal)
syms[x] = res_var
end
end
App1(f, t) || AppCirc1(f, t) => begin
res_var = add_part!(d, :Var, type=:infer)
add_part!(d, :Op1, src=!(t,d,syms), tgt=res_var, op1=f)
return res_var
Expand All @@ -79,16 +105,6 @@ reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) =
add_part!(d, :Op2, proj1=!(t1,d,syms), proj2=!(t2,d,syms), res=res_var, op2=f)
return res_var
end
AppCirc1(fs, t) => begin
res_var = add_part!(d, :Var, type=:infer)
add_part!(d, :Op1, src=!(t,d,syms), tgt=res_var, op1=fs)
return res_var
end
AppCirc2(f, t1, t2) => begin
res_var = add_part!(d, :Var, type=:infer)
add_part!(d, :Op2, proj1=!(t1,d,syms), proj2=!(t2,d,syms), res=res_var, op2=fs)
return res_var
end
Plus(ts) => begin
summands = [!(t,d,syms) for t in ts]
res_var = add_part!(d, :Var, type=:infer, name=:sum)
Expand All @@ -98,8 +114,22 @@ reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) =
end
return res_var
end
# TODO: Just for now assuming we have 2 or more terms
Mult(ts) => begin
multiplicands = [!(t,d,syms) for t in ts]
res_var = add_part!(d, :Var, type=:infer, name=:mult)
m1,m2 = multiplicands[1:2]
add_part!(d, :Op2, proj1=m1, proj2=m2, res=res_var, op2=Symbol("*"))
for m in multiplicands[3:end]
m1 = res_var
m2 = m
res_var = add_part!(d, :Var, type=:infer, name=:mult)
add_part!(d, :Op2, proj1=m1, proj2=m2, res=res_var, op2=Symbol("*"))
end
return res_var
end
Tan(t) => begin
# TODO: this is creating a spurious variablbe with the same name
# TODO: this is creating a spurious variable with the same name
txv = add_part!(d, :Var, type=:infer)
tx = add_part!(d, :TVar, incl=txv)
tanop = add_part!(d, :Op1, src=!(t,d,syms), tgt=txv, op1=DerivOp)
Expand All @@ -109,13 +139,32 @@ reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) =
end
end

function eval_eq!(eq::Equation, d::AbstractDecapode, syms::Dict{Symbol, Int})
function eval_eq!(eq::Equation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
@match eq begin
Eq(t1, t2) => begin
lhs_ref = reduce_term!(t1,d,syms)
rhs_ref = reduce_term!(t2,d,syms)
deletions = []

# Always let the a named variable take precedence
# TODO: If we have variable to variable equality, we want
# some kind of way to check track of this equality
ref_pair = (t1, t2)
@match ref_pair begin
(Var(a), Var(b)) => return d
(t1, Var(b)) => begin
lhs_ref, rhs_ref = rhs_ref, lhs_ref
end
_ => nothing
end

# Make rhs_ref equal to lhs_ref and adjust all its incidents

# Case rhs_ref is a Tan
# WARNING: Don't push to deletion here because all TanVars should have a
# corresponding Op1. Pushing here would create a duplicate which breaks rem_parts!
for rhs in incident(d, rhs_ref, :incl)
d[rhs, :incl] = lhs_ref
end
# Case rhs_ref is a Op1
for rhs in incident(d, rhs_ref, :tgt)
d[rhs, :tgt] = lhs_ref
Expand All @@ -136,7 +185,7 @@ function eval_eq!(eq::Equation, d::AbstractDecapode, syms::Dict{Symbol, Int})
end
# TODO: delete unused vars. The only thing stopping me from doing
# this is I don't know if CSet deletion preserves incident relations
rem_parts!(d, :Var, sort(deletions))
#rem_parts!(d, :Var, sort(deletions))
end
end
return d
Expand All @@ -156,9 +205,12 @@ function Decapode(e::DecaExpr)
var_id = add_part!(d, :Var, type=(judgement._2, judgement._3))
symbol_table[judgement._1._1] = var_id
end
deletions = Vector{Int64}()
for eq in e.equations
eval_eq!(eq, d, symbol_table)
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))
recognize_types(d)
return d
end

Expand All @@ -169,26 +221,40 @@ function NamedDecapode(e::DecaExpr)
var_id = add_part!(d, :Var, name=judgement._1._1, type=judgement._2)
symbol_table[judgement._1._1] = var_id
end
deletions = Vector{Int64}()
for eq in e.equations
eval_eq!(eq, d, symbol_table)
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))
fill_names!(d)
d[:name] = map(normalize_unicode,d[:name])
recognize_types(d)
return d
end

function SummationDecapode(e::DecaExpr)
d = SummationDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()

for judgement in e.judgements
var_id = add_part!(d, :Var, name=judgement._1._1, type=judgement._2)
symbol_table[judgement._1._1] = var_id
end

deletions = Vector{Int64}()
for eq in e.equations
eval_eq!(eq, d, symbol_table)
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))

recognize_types(d)

fill_names!(d)
d[:name] .= normalize_unicode.(d[:name])
make_sum_unique!(d)
make_sum_mult_unique!(d)
return d
end

macro SummationDecapode(e)
:(SummationDecapode(parse_decapode($(Meta.quot(e)))))
end
Loading

0 comments on commit 69d63b5

Please sign in to comment.