Skip to content

Commit

Permalink
Refactor accessors on containers/arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 18, 2016
1 parent 3dd548e commit 6473008
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 20 deletions.
37 changes: 26 additions & 11 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ isdefined(Base, :__precompile__) && __precompile__()
module JuMP

importall Base.Operators
import Base.map

import MathProgBase

Expand Down Expand Up @@ -123,7 +124,7 @@ type Model <: AbstractModel
conDict::Dict{Symbol,Any} # dictionary from constraint names to constraint objects
varData::ObjectIdDict

getvalue_counter::Int # number of times we call getvalue on a JuMPContainer, so that we can print out a warning
map_counter::Int # number of times we call getvalue, getdual, getlowerbound and getupperbound on a JuMPContainer, so that we can print out a warning
operator_counter::Int # number of times we add large expressions

# Extension dictionary - e.g. for robust
Expand Down Expand Up @@ -177,7 +178,7 @@ function Model(;solver=UnsetSolver(), simplify_nonlinear_expressions::Bool=false
Dict{Symbol,Any}(), # varDict
Dict{Symbol,Any}(), # conDict
ObjectIdDict(), # varData
0, # getvalue_counter
0, # map_counter
0, # operator_counter
Dict{Symbol,Any}(), # ext
)
Expand Down Expand Up @@ -394,10 +395,12 @@ end
# internal method that doesn't print a warning if the value is NaN
_getValue(v::Variable) = v.m.colVal[v.col]

getvaluewarn(v) = Base.warn("Variable value not defined for $(getname(v)). Check that the model was properly solved.")

function getvalue(v::Variable)
ret = _getValue(v)
if isnan(ret)
Base.warn("Variable value not defined for $(getname(v)). Check that the model was properly solved.")
getvaluewarn(v)
end
ret
end
Expand Down Expand Up @@ -434,11 +437,17 @@ function getvalue(arr::Array{Variable})
end

# Dual value (reduced cost) getter

# internal method that doesn't print a warning if the value is NaN
_getDual(v::Variable) = v.m.redCosts[v.col]

getdualwarn(::Variable) = error("Variable bound duals (reduced costs) not available. Check that the model was properly solved and no integer variables are present.")

function getdual(v::Variable)
if length(v.m.redCosts) < MathProgBase.numvar(v.m)
error("Variable bound duals (reduced costs) not available. Check that the model was properly solved and no integer variables are present.")
getdualwarn(v)
end
return v.m.redCosts[v.col]
return _getDual(v)
end

const var_cats = [:Cont, :Int, :Bin, :SemiCont, :SemiInt]
Expand Down Expand Up @@ -538,11 +547,16 @@ LinearConstraint(ref::LinConstrRef) = ref.m.linconstr[ref.idx]::LinearConstraint

linearindex(x::ConstraintRef) = x.idx

# internal method that doesn't print a warning if the value is NaN
_getDual(c::LinConstrRef) = c.m.linconstrDuals[c.idx]

getdualwarn(::LinConstrRef) = error("Dual solution not available. Check that the model was properly solved and no integer variables are present.")

function getdual(c::LinConstrRef)
if length(c.m.linconstrDuals) != MathProgBase.numlinconstr(c.m)
error("Dual solution not available. Check that the model was properly solved and no integer variables are present.")
warnnan(_getDual)
end
return c.m.linconstrDuals[c.idx]
return _getDual(c)
end

# Returns the number of non-infinity and nonzero bounds on variables
Expand Down Expand Up @@ -796,13 +810,14 @@ function getconstraint(m::Model, conname::Symbol)
end

# usage warnings
function getvalue_warn(x::JuMPContainer{Variable})
function mapcontainer_warn(f, x::JuMPContainer{Variable})
isempty(x) && return
v = first(values(x))
m = v.m
m.getvalue_counter += 1
if m.getvalue_counter > 400
Base.warn_once("getvalue has been called on a collection of variables a large number of times. For performance reasons, this should be avoided. Instead of getvalue(x)[a,b,c], use getvalue(x[a,b,c]) to avoid temporary allocations.")
m.map_counter += 1
if m.map_counter > 400
# It might not be f that was called the 400 first times but most probably it is f
Base.warn_once("$f has been called on a collection of variables a large number of times. For performance reasons, this should be avoided. Instead of $f(x)[a,b,c], use $f(x[a,b,c]) to avoid temporary allocations.")
end
return
end
Expand Down
35 changes: 27 additions & 8 deletions src/JuMPContainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ type JuMPContainerData
condition
end

# Needed by getvaluewarn when called by _mapInner
getname(data::JuMPContainerData) = data.name

#JuMPDict{T,N}(name::AbstractString) =
# JuMPDict{T,N}(Dict{NTuple{N},T}(), name)

Expand Down Expand Up @@ -98,25 +101,40 @@ pushmeta!(x::JuMPContainer, sym::Symbol, val) = (x.meta[sym] = val)
getmeta(x::JuMPContainer, sym::Symbol) = x.meta[sym]

# duck typing approach -- if eltype(innerArray) doesn't support accessor, will fail
for accessor in (:getdual, :getlowerbound, :getupperbound)
@eval $accessor(x::Union{JuMPContainer,Array}) = map($accessor,x)
for accessor in (:getdual, :getlowerbound, :getupperbound, :getvalue)
@eval $accessor(x::AbstractArray) = map($accessor,x)
end
# With JuMPContainer, we take care in _mapInner of the warning if NaN values are returned
# by the accessor so we use the inner accessor that does not generate warnings
for (accessor, inner) in ((:getdual, :_getDual), (:getlowerbound, :getlowerbound), (:getupperbound, :getupperbound), (:getvalue, :_getValue))
@eval $accessor(x::JuMPContainer) = _map($inner,x)
end


_similar(x::Array) = Array(Float64,size(x))
_similar{T}(x::Dict{T}) = Dict{T,Float64}()

_innercontainer(x::JuMPArray) = x.innerArray
_innercontainer(x::JuMPDict) = x.tupledict

function _getValueInner(x)
# Warning for getter returning NaN
function _warnnan(f, data)
if f === _getValue
getvaluewarn(data)
elseif f === _getDual
getdualwarn(data)
end
end

function _mapInner(f, x::JuMPContainer)
vars = _innercontainer(x)
vals = _similar(vars)
data = printdata(x)
warnedyet = false
for I in eachindex(vars)
tmp = _getValue(vars[I])
tmp = f(vars[I])
if isnan(tmp) && !warnedyet
warn("Variable value not defined for entry of $(data.name). Check that the model was properly solved.")
_warnnan(f, data.name)
warnedyet = true
end
vals[I] = tmp
Expand All @@ -127,9 +145,10 @@ end
JuMPContainer_from(x::JuMPDict,inner) = JuMPDict(inner)
JuMPContainer_from(x::JuMPArray,inner) = JuMPArray(inner, x.indexsets)

function getvalue(x::JuMPContainer)
getvalue_warn(x)
ret = JuMPContainer_from(x,_getValueInner(x))
# The name _map is used instead of map so that this function is called instead of map(::Function, ::JuMPArray)
function _map(f, x::JuMPContainer)
mapcontainer_warn(f, x)
ret = JuMPContainer_from(x, _mapInner(f, x))
# I guess copy!(::Dict, ::Dict) isn't defined, so...
for (key,val) in x.meta
ret.meta[key] = val
Expand Down
2 changes: 1 addition & 1 deletion src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function solve(m::Model; suppress_warnings=false,
isempty(kwargs) || error("Unrecognized keyword arguments: $(join([k[1] for k in kwargs], ", "))")

# Clear warning counters
m.getvalue_counter = 0
m.map_counter = 0
m.operator_counter = 0

# Remember if the solver was initially unset so we can restore
Expand Down
16 changes: 16 additions & 0 deletions test/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,19 @@ end
@test MathProgBase.numvar(m) == 4
end
end

@testset "[variable] getvalue on sparse array #889" begin
m = Model()
@variable(m, x)
@variable(m, y)
X = sparse([1, 3], [2, 3], [x, y])

@test typeof(X) == SparseMatrixCSC{Variable, Int}

setvalue(x, 1)
setvalue(y, 2)

Y = getvalue(X)
@test typeof(Y) == SparseMatrixCSC{Float64, Int}
@test Y == sparse([1, 3], [2, 3], [1, 2])
end

0 comments on commit 6473008

Please sign in to comment.