Skip to content

Commit

Permalink
parameter function bug fix (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
pulsipher authored Jun 18, 2021
1 parent caf640a commit c772d3f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 11 deletions.
61 changes: 53 additions & 8 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,29 +709,72 @@ end
################################################################################
const _BadOperators = (:in, :<=, :>=, :(==), :, :, )

## Make methods to find and replace a part of an expression
# Expr
function _expr_replace!(ex::Expr, old, new)
for (i, a) in enumerate(ex.args)
if a == old
ex.args[i] = new
elseif a isa Expr
_expr_replace!(a, old, new)
end
end
return ex
end

# Symbol
function _expr_replace!(ex::Symbol, old, new)
return ex == old ? new : ex
end

## Safely extract the parameter references (copy as needed)
# Symbol
function _extract_parameters(ex::Symbol)
return esc(ex)
end

# Expr
function _extract_parameters(ex::Expr)
return esc(copy(ex))
end

# Helper method to process parameter function expressions
function _process_func_expr(_error::Function, raw_expr)
if isexpr(raw_expr, :call)
# check that the call is not some operator
if raw_expr.args[1] in _BadOperators
_error("Invalid input syntax.")
end
func_expr = _esc_non_constant(raw_expr.args[1])
func_expr = esc(raw_expr.args[1])
is_anon = false
# check for keywords
if isexpr(raw_expr.args[2], :parameters) ||
any(isexpr(a, :kw) for a in raw_expr.args[2:end])
_error("Cannot specify keyword arguements directly, try using an ",
"anonymous function.")
end
# extract the parameter support inputs
# extract the parameter inputs
pref_expr = esc(Expr(:tuple, raw_expr.args[2:end]...))
elseif isexpr(raw_expr, :(->))
pref_expr = _esc_non_constant(raw_expr.args[1])
func_expr = _esc_non_constant(raw_expr)
# extract the parameter inputs
pref_expr = _extract_parameters(raw_expr.args[1]) # this will create a copy if needed
# fix the function if the parameter arguments were given as references
if isexpr(raw_expr.args[1], :ref)
raw_expr = _expr_replace!(raw_expr, raw_expr.args[1], gensym())
elseif isexpr(raw_expr.args[1], :tuple)
for a in raw_expr.args[1].args
if isexpr(a, :ref)
raw_expr = _expr_replace!(raw_expr, a, gensym())
end
end
end
# extract the function expression
func_expr = esc(raw_expr)
is_anon = true
else
_error("Unrecognized syntax.")
end
return func_expr, pref_expr
return func_expr, pref_expr, is_anon
end

"""
Expand Down Expand Up @@ -811,11 +854,11 @@ macro parameter_function(model, args...)
end
if isexpr(expr, :call) && expr.args[1] === :(==)
var = expr.args[2]
func, prefs = _process_func_expr(_error, expr.args[3])
func, prefs, is_anon_func = _process_func_expr(_error, expr.args[3])
is_anon = isexpr(var, (:vect, :vcat))
elseif isexpr(expr, (:call, :(->)))
var = gensym()
func, prefs = _process_func_expr(_error, expr)
func, prefs, is_anon_func = _process_func_expr(_error, expr)
is_anon = true
else
_error("Unrecognized syntax.")
Expand All @@ -833,8 +876,10 @@ macro parameter_function(model, args...)
_error("Index $(model) is the same symbol as the model. Use a ",
"different name for the index.")
end
if base_name === nothing && is_anon
if base_name === nothing && is_anon && !is_anon_func
name_code = :( string(nameof($func)) )
elseif base_name === nothing && is_anon
name_code = _name_call("", idxvars)
elseif base_name === nothing
name_code = _name_call(string(name), idxvars)
else
Expand Down
32 changes: 29 additions & 3 deletions test/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,33 @@ end
# setup
m = InfiniteModel()
@infinite_parameter(m, t in [0, 1])
@infinite_parameter(m, x[1:2] in [0, 1])
@infinite_parameter(m, x[1:2] in [0, 1], independent = true)
f5(t, x, a...; b...) = 42
# test _expr_replace!
@testset "_expr_replace!" begin
@test InfiniteOpt._expr_replace!(:(t -> f(t, x)), :t, :y) == :(y -> f(y, x))
@test InfiniteOpt._expr_replace!(:((t, x[1]) -> f(t, x[1])), :(x[1]), :y) == :((t, y) -> f(t, y))
@test InfiniteOpt._expr_replace!(:t, :t, :y) == :y
@test InfiniteOpt._expr_replace!(:t, :f, :y) == :t
end
# test _extract_parameters
@testset "_extract_parameters" begin
@test InfiniteOpt._extract_parameters(:y) == esc(:y)
@test InfiniteOpt._extract_parameters(:(x[1])) == esc(:(x[1]))
ex = :((t, x[1]))
@test InfiniteOpt._extract_parameters(ex).args[1] == ex
@test InfiniteOpt._extract_parameters(ex).args[1] !== ex
end
# test _process_func_expr
@testset "_process_func_expr" begin
# test normal
@test InfiniteOpt._process_func_expr(error, :(f(t, x))) == (esc(:f), esc(:(t, x)))
@test InfiniteOpt._process_func_expr(error, :(f(t, x))) == (esc(:f), esc(:(t, x)), false)
anon = :((t, x) -> f(t, x, 1, d = 1))
@test InfiniteOpt._process_func_expr(error, anon) == (esc(anon), esc(:(t,x)))
@test InfiniteOpt._process_func_expr(error, anon) == (esc(anon), esc(:(t,x)), true)
anon = :((t, x[1]) -> sin(t + x[1]))
@test eval(InfiniteOpt._process_func_expr(error, anon)[1].args[1])(0.5, 0.2) == sin(0.5 + 0.2)
anon = :(t[i] -> t[i] + 3)
@test eval(InfiniteOpt._process_func_expr(error, anon)[1].args[1])(2) == 5
# test errors
@test_throws ErrorException InfiniteOpt._process_func_expr(error, :(f(t, d = 2)))
@test_throws ErrorException InfiniteOpt._process_func_expr(error, :(f(x, t; d = 2)))
Expand Down Expand Up @@ -306,6 +325,13 @@ end
@test raw_function(refs[2]) != f5
@test name.(refs) == ["e[1]", "e[2]"]
idx += 2
# test infinite parameter with reference
refs = [GeneralVariableRef(m, idx + i, ParameterFunctionIndex) for i in 0:1]
@test @parameter_function(m, [i = 1:2] == (t, x[i]) -> sin(t + x[i])) == refs
@test parameter_refs.(refs) == [(t, x[1]), (t, x[2])]
@test call_function.(refs, 0.5, 0.2) == [sin(0.5 + 0.2), sin(0.5 + 0.2)]
@test name.(refs) == ["", ""]
idx += 2
end
end

Expand Down

0 comments on commit c772d3f

Please sign in to comment.