Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameter Function Bug Fix #144

Merged
merged 1 commit into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,29 +709,72 @@ end
################################################################################
const _BadOperators = (:in, :<=, :>=, :(==), :≥, :≤, ∈)

## Make methods to find and replace a part of an expression
# Expr
function _expr_replace!(ex::Expr, old, new)
for (i, a) in enumerate(ex.args)
if a == old
ex.args[i] = new
elseif a isa Expr
_expr_replace!(a, old, new)
end
end
return ex
end

# Symbol
function _expr_replace!(ex::Symbol, old, new)
return ex == old ? new : ex
end

## Safely extract the parameter references (copy as needed)
# Symbol
function _extract_parameters(ex::Symbol)
return esc(ex)
end

# Expr
function _extract_parameters(ex::Expr)
return esc(copy(ex))
end

# Helper method to process parameter function expressions
function _process_func_expr(_error::Function, raw_expr)
if isexpr(raw_expr, :call)
# check that the call is not some operator
if raw_expr.args[1] in _BadOperators
_error("Invalid input syntax.")
end
func_expr = _esc_non_constant(raw_expr.args[1])
func_expr = esc(raw_expr.args[1])
is_anon = false
# check for keywords
if isexpr(raw_expr.args[2], :parameters) ||
any(isexpr(a, :kw) for a in raw_expr.args[2:end])
_error("Cannot specify keyword arguements directly, try using an ",
"anonymous function.")
end
# extract the parameter support inputs
# extract the parameter inputs
pref_expr = esc(Expr(:tuple, raw_expr.args[2:end]...))
elseif isexpr(raw_expr, :(->))
pref_expr = _esc_non_constant(raw_expr.args[1])
func_expr = _esc_non_constant(raw_expr)
# extract the parameter inputs
pref_expr = _extract_parameters(raw_expr.args[1]) # this will create a copy if needed
# fix the function if the parameter arguments were given as references
if isexpr(raw_expr.args[1], :ref)
raw_expr = _expr_replace!(raw_expr, raw_expr.args[1], gensym())
elseif isexpr(raw_expr.args[1], :tuple)
for a in raw_expr.args[1].args
if isexpr(a, :ref)
raw_expr = _expr_replace!(raw_expr, a, gensym())
end
end
end
# extract the function expression
func_expr = esc(raw_expr)
is_anon = true
else
_error("Unrecognized syntax.")
end
return func_expr, pref_expr
return func_expr, pref_expr, is_anon
end

"""
Expand Down Expand Up @@ -811,11 +854,11 @@ macro parameter_function(model, args...)
end
if isexpr(expr, :call) && expr.args[1] === :(==)
var = expr.args[2]
func, prefs = _process_func_expr(_error, expr.args[3])
func, prefs, is_anon_func = _process_func_expr(_error, expr.args[3])
is_anon = isexpr(var, (:vect, :vcat))
elseif isexpr(expr, (:call, :(->)))
var = gensym()
func, prefs = _process_func_expr(_error, expr)
func, prefs, is_anon_func = _process_func_expr(_error, expr)
is_anon = true
else
_error("Unrecognized syntax.")
Expand All @@ -833,8 +876,10 @@ macro parameter_function(model, args...)
_error("Index $(model) is the same symbol as the model. Use a ",
"different name for the index.")
end
if base_name === nothing && is_anon
if base_name === nothing && is_anon && !is_anon_func
name_code = :( string(nameof($func)) )
elseif base_name === nothing && is_anon
name_code = _name_call("", idxvars)
elseif base_name === nothing
name_code = _name_call(string(name), idxvars)
else
Expand Down
32 changes: 29 additions & 3 deletions test/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,33 @@ end
# setup
m = InfiniteModel()
@infinite_parameter(m, t in [0, 1])
@infinite_parameter(m, x[1:2] in [0, 1])
@infinite_parameter(m, x[1:2] in [0, 1], independent = true)
f5(t, x, a...; b...) = 42
# test _expr_replace!
@testset "_expr_replace!" begin
@test InfiniteOpt._expr_replace!(:(t -> f(t, x)), :t, :y) == :(y -> f(y, x))
@test InfiniteOpt._expr_replace!(:((t, x[1]) -> f(t, x[1])), :(x[1]), :y) == :((t, y) -> f(t, y))
@test InfiniteOpt._expr_replace!(:t, :t, :y) == :y
@test InfiniteOpt._expr_replace!(:t, :f, :y) == :t
end
# test _extract_parameters
@testset "_extract_parameters" begin
@test InfiniteOpt._extract_parameters(:y) == esc(:y)
@test InfiniteOpt._extract_parameters(:(x[1])) == esc(:(x[1]))
ex = :((t, x[1]))
@test InfiniteOpt._extract_parameters(ex).args[1] == ex
@test InfiniteOpt._extract_parameters(ex).args[1] !== ex
end
# test _process_func_expr
@testset "_process_func_expr" begin
# test normal
@test InfiniteOpt._process_func_expr(error, :(f(t, x))) == (esc(:f), esc(:(t, x)))
@test InfiniteOpt._process_func_expr(error, :(f(t, x))) == (esc(:f), esc(:(t, x)), false)
anon = :((t, x) -> f(t, x, 1, d = 1))
@test InfiniteOpt._process_func_expr(error, anon) == (esc(anon), esc(:(t,x)))
@test InfiniteOpt._process_func_expr(error, anon) == (esc(anon), esc(:(t,x)), true)
anon = :((t, x[1]) -> sin(t + x[1]))
@test eval(InfiniteOpt._process_func_expr(error, anon)[1].args[1])(0.5, 0.2) == sin(0.5 + 0.2)
anon = :(t[i] -> t[i] + 3)
@test eval(InfiniteOpt._process_func_expr(error, anon)[1].args[1])(2) == 5
# test errors
@test_throws ErrorException InfiniteOpt._process_func_expr(error, :(f(t, d = 2)))
@test_throws ErrorException InfiniteOpt._process_func_expr(error, :(f(x, t; d = 2)))
Expand Down Expand Up @@ -306,6 +325,13 @@ end
@test raw_function(refs[2]) != f5
@test name.(refs) == ["e[1]", "e[2]"]
idx += 2
# test infinite parameter with reference
refs = [GeneralVariableRef(m, idx + i, ParameterFunctionIndex) for i in 0:1]
@test @parameter_function(m, [i = 1:2] == (t, x[i]) -> sin(t + x[i])) == refs
@test parameter_refs.(refs) == [(t, x[1]), (t, x[2])]
@test call_function.(refs, 0.5, 0.2) == [sin(0.5 + 0.2), sin(0.5 + 0.2)]
@test name.(refs) == ["", ""]
idx += 2
end
end

Expand Down