Skip to content

Commit

Permalink
reorganize ASP a little
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 8, 2024
1 parent 21a9f8b commit eb55759
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 158 deletions.
1 change: 1 addition & 0 deletions src/ACEfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("bayesianlinear.jl")
include("data.jl")
include("assemble.jl")
include("solvers.jl")
include("asp.jl")
include("fit.jl")

end
157 changes: 157 additions & 0 deletions src/asp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

@doc raw"""
`ASP` : Active Set Pursuit solver

Solves the lasso optimization problem
```math
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
```
subject to
```math
\|A^T y\|_{\infty} \leq 1.
```

### Constructor keyword arguments
```julia
ACEfit.ASP(; P = I, select, params...)
```

* `select` : selection criterion for the final solution (required, no default)
    * `:final` : the last solution on the computed path (largest computed basis)
    * `(:byerror, q)` : the first solution along the path whose error is within
      `q` times the minimum error along the path; if the training error is used
      and `q == 1.0`, this is equivalent to `:final`.
    * `(:bysize, n)` : the best solution with at most `n` non-zero features; if
      the training error is used, this is the solution with exactly `n`
      non-zero features.
* `P = I` : prior / regularizer (optional)
* `mode` : accepted by the keyword constructor but currently ignored.

All remaining keyword arguments are stored in `params` and forwarded to the ASP
homotopy solver:

* `actMax` : maximum number of active constraints.
* `min_lambda` : minimum value for `λ` (defaults to 0).
* `loglevel` : logging level.
* `itnMax` : maximum number of iterations.

### Extended syntax for `solve`
```julia
solve(solver::ASP, A, y, Aval=A, yval=y)
```

* `A` : `m`-by-`n` design matrix (required).
* `y` : `m`-vector (required); plays the role of `b` in the problem above.
* `Aval = A` : `p`-by-`n` validation matrix.
* `yval = y` : `p`-vector of validation targets.

If an independent `Aval`, `yval` pair is provided (instead of the defaults
`A, y`), the solver uses this separate validation set instead of the training
set to select the best solution along the model path.
"""
struct ASP
    P::Any
    select::Any
    params::Any
end

# Keyword constructor for `ASP`. `mode` is accepted but neither stored nor
# forwarded; every keyword other than `P`, `select` and `mode` is collected
# into `params`.
ASP(; P = I, select, mode = :train, params...) = ASP(P, select, params)

"""
    solve(solver::ASP, A, y, Aval=A, yval=y)

Run the ASP homotopy on the right-preconditioned system `A / solver.P`, map every
solution on the path back through `solver.P`, and pick one of them with
`select_solution`, evaluated on `Aval`, `yval`.

Returns a `Dict` with the entries
* `"C"` : the selected solution,
* `"path"` : the full path as a vector of `(solution, λ)` named tuples,
* `"nnzs"` : the number of stored entries (`nzind`) of the selected solution.
"""
function solve(solver::ASP, A, y, Aval=A, yval=y)
    # right-precondition the design matrix before running the homotopy
    AP = A / solver.P
    raw_path = asp_homotopy(AP, y; solver.params...)

    # undo the preconditioning on every solution; the element type is kept loose
    PathEntry = NamedTuple{(:solution, :λ), Tuple{Any, Any}}
    path = PathEntry[(solution = solver.P \ raw_path[k][1], λ = raw_path[k][2])
                     for k in 1:length(raw_path)]

    coeffs, idx = select_solution(path, solver, Aval, yval)
    nstored = length((path[idx][:solution]).nzind)

    return Dict("C" => coeffs, "path" => path, "nnzs" => nstored)
end


"""
    select_solution(tracer, solver, A, y) -> (solution, index)

Pick one entry of the model path `tracer` according to `solver.select` and return
its `:solution` together with its index in `tracer`.

* `:final` : the last entry of the path.
* `(:byerror, q)` : the first entry whose error `norm(A * x - y)` is at most `q`
  times the minimum error along the path.
* `(:bysize, n)` : among the entries preceding the first one with more than `n`
  stored non-zeros (all entries if none exceeds `n`), the first entry attaining
  the minimum error.

Throws an `ArgumentError` if `(:bysize, n)` leaves no candidate (the first path
entry already has more than `n` non-zeros), and an `ErrorException` for an
unknown criterion or if no entry satisfies the error threshold.
"""
function select_solution(tracer, solver, A, y)
    if solver.select == :final
        criterion, p = :final, nothing
    else
        criterion, p = solver.select
    end

    if criterion == :final
        return tracer[end][:solution], length(tracer)
    elseif criterion == :byerror
        maxind = length(tracer)
        threshold = p
    elseif criterion == :bysize
        # `findfirst` returns `nothing` when every path entry fits into the size
        # budget; in that case the whole path is eligible.
        first_too_big = findfirst(t -> length((t[:solution]).nzind) > p, tracer)
        maxind = something(first_too_big, length(tracer) + 1) - 1
        if maxind < 1
            throw(ArgumentError(
                "no solution along the path has at most $p non-zero features"))
        end
        threshold = 1.0
    else
        error("Unknown selection criterion: $criterion")
    end

    errors = [norm(A * t[:solution] - y) for t in view(tracer, 1:maxind)]
    cutoff = threshold * minimum(errors)
    idx = findfirst(<=(cutoff), errors)
    if idx === nothing
        # reachable e.g. for a threshold below one or NaN errors
        error("selection failed for unknown reasons; please file an issue with a MWE to reproduce this error.")
    end

    return tracer[idx][:solution], idx
end


#=
function select_smart(tracer, solver, Aval, yval)
best_metric = Inf
best_iteration = 0
validation_metric = 0
q = length(tracer)
errors = [norm(Aval * t[:solution] - yval) for t in tracer]
nnzss = [(t[:solution]).nzind for t in tracer]
best_iteration = argmin(errors)
validation_metric = errors[best_iteration]
validation_end = norm(Aval * tracer[end][:solution] - yval)
if validation_end < validation_metric #make sure to check the last one too in case q<<100
best_iteration = q
end
criterion, threshold = solver.select
if criterion == :val
return tracer[best_iteration][:solution], best_iteration
elseif criterion == :byerror
for (i, error) in enumerate(errors)
if error <= threshold * validation_metric
return tracer[i][:solution], i
end
end
elseif criterion == :bysize
first_index = findfirst(sublist -> threshold in sublist, nnzss)
relevant_errors = errors[1:first_index - 1]
min_error = minimum(relevant_errors)
min_error_index = findfirst(==(min_error), relevant_errors)
return tracer[min_error_index][:solution], min_error_index
else
@error("Unknown selection criterion: $criterion")
end
end
=#
136 changes: 0 additions & 136 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,139 +196,3 @@ function solve(solver::TruncatedSVD, A, y)
return Dict{String, Any}("C" => solver.P \ θP)
end


@doc raw"""
`struct ASP` : Active Set Pursuit solver

Solves the optimization problem
```math
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
```
subject to
```math
\|A^T y\|_{\infty} \leq 1
```
using either the basic or the smart homotopy approach.

* Input
    * `A` : `m`-by-`n` matrix.
    * `b` : `m`-vector.
    * `Aval` : `p`-by-`n` validation matrix (only for smart mode).
    * `bval` : `p`-vector of validation targets (only for smart mode).

* Solver parameters
    * `min_lambda` : minimum value for `λ`; defaults to zero if not provided.
    * `loglevel` : logging level.
    * `itnMax` : maximum number of iterations.
    * `actMax` : maximum number of active constraints.
    * `mode` : either `:basic` or `:smart`.

Constructor
```julia
ACEfit.ASP(; P = I, select, mode, params)
```
where
- `P` : right-preconditioner / Tikhonov operator
- `select` : selection mode for the final solution (a tuple).
- `mode` : `:basic` selects on the training data, `:smart` on the validation data.
- `params` : the solver parameters, passed as named arguments.
"""
struct ASP
    P
    select::Tuple
    mode::Symbol
    params::NamedTuple
end

# Keyword constructor: every keyword other than `P`, `select` and `mode` is
# gathered into a NamedTuple and stored in the `params` field.
function ASP(; P = I, select, mode = :basic, params...)
    return ASP(P, select, mode, NamedTuple(params))
end

"""
    solve(solver::ASP, A, y, Aval=nothing, yval=nothing)

Run the ASP homotopy on the right-preconditioned system `A / solver.P`, map every
solution on the path back through `solver.P`, and select one of them.

The first entry of `solver.params` is splatted as keyword arguments into
`asp_homotopy`. In mode `:basic` the selection is done by `select_solution` on
`A`, `y`; in mode `:smart` it is done by `select_smart` on `Aval`, `yval`, which
must then both be provided.

Returns a `Dict` with the entries `"C"` (selected solution), `"path"` (vector of
`(solution, λ)` named tuples) and `"nnzs"` (number of stored entries of the
selected solution).

Throws an `ArgumentError` for an unknown `solver.mode`, or if mode `:smart` is
requested without validation data.
"""
function solve(solver::ASP, A, y, Aval=nothing, yval=nothing)
    # Fail early and loudly: previously an unknown mode was only logged (with a
    # broken interpolation) and then crashed later with an UndefVarError.
    if !(solver.mode in (:basic, :smart))
        throw(ArgumentError("Unknown mode: $(solver.mode)"))
    end
    if solver.mode == :smart && (Aval === nothing || yval === nothing)
        throw(ArgumentError("mode :smart requires both Aval and yval"))
    end

    # Apply preconditioning
    AP = A / solver.P

    tracer = asp_homotopy(AP, y; solver.params[1]...)
    q = length(tracer)
    new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, q)

    # undo the preconditioning on every solution of the path
    for i in 1:q
        new_tracer[i] = (solution = solver.P \ tracer[i][1], λ = tracer[i][2])
    end

    xs, idx = if solver.mode == :basic
        select_solution(new_tracer, solver, A, y)
    else
        select_smart(new_tracer, solver, Aval, yval)
    end

    println("done.")
    return Dict("C" => xs,
                "path" => new_tracer,
                "nnzs" => length((new_tracer[idx][:solution]).nzind))
end

"""
    select_solution(tracer, solver, A, y) -> (solution, index)

Pick one entry of the model path `tracer` according to the pair
`(criterion, threshold) = solver.select` and return its `:solution` together with
its index in `tracer`.

* `:final` : the last entry of the path.
* `:byerror` : the first entry whose error `norm(A * x - y)` is at most
  `threshold` times the minimum error along the path.
* `:bysize` : the first entry with exactly `threshold` stored non-zeros.

Throws an `ErrorException` for an unknown criterion or when no entry of the path
satisfies the criterion (previously these cases silently returned `nothing`).
"""
function select_solution(tracer, solver, A, y)
    criterion, threshold = solver.select

    if criterion == :final
        return tracer[end][:solution], length(tracer)
    elseif criterion == :byerror
        errors = [norm(A * t[:solution] - y) for t in tracer]
        cutoff = threshold * minimum(errors)
        idx = findfirst(<=(cutoff), errors)
        if idx === nothing
            error("no solution along the path has an error within $threshold times the minimum error")
        end
        return tracer[idx][:solution], idx
    elseif criterion == :bysize
        idx = findfirst(t -> length((t[:solution]).nzind) == threshold, tracer)
        if idx === nothing
            error("no solution along the path has exactly $threshold non-zero features")
        end
        return tracer[idx][:solution], idx
    else
        # `@error` only logs; an unknown criterion must abort the selection
        error("Unknown selection criterion: $criterion")
    end
end


"""
    select_smart(tracer, solver, Aval, yval) -> (solution, index)

Pick one entry of the model path `tracer` using the validation errors
`norm(Aval * x - yval)` and the pair `(criterion, threshold) = solver.select`.
Returns the selected `:solution` together with its index in `tracer`.

* `:val` : the entry with the smallest validation error.
* `:byerror` : the first entry whose validation error is at most `threshold`
  times the minimum validation error.
* `:bysize` : among the entries preceding the first one with more than
  `threshold` stored non-zeros (all entries if none exceeds it), the first entry
  attaining the minimum validation error.

Throws an `ArgumentError` if `:bysize` leaves no candidate, and an
`ErrorException` for an unknown criterion or if no entry meets the error
threshold.
"""
function select_smart(tracer, solver, Aval, yval)
    criterion, threshold = solver.select
    errors = [norm(Aval * t[:solution] - yval) for t in tracer]

    if criterion == :val
        # `argmin` already covers the last entry, no separate check is needed
        best = argmin(errors)
        return tracer[best][:solution], best
    elseif criterion == :byerror
        cutoff = threshold * minimum(errors)
        idx = findfirst(<=(cutoff), errors)
        if idx === nothing
            error("no solution along the path has an error within $threshold times the minimum error")
        end
        return tracer[idx][:solution], idx
    elseif criterion == :bysize
        # Compare the *number* of stored non-zeros with the size budget. The
        # previous test `threshold in nzind` checked whether the feature with
        # index `threshold` was active, and failed on `nothing - 1` when it
        # never was.
        first_too_big = findfirst(t -> length((t[:solution]).nzind) > threshold, tracer)
        maxind = something(first_too_big, length(tracer) + 1) - 1
        if maxind < 1
            throw(ArgumentError(
                "no solution along the path has at most $threshold non-zero features"))
        end
        best = argmin(view(errors, 1:maxind))
        return tracer[best][:solution], best
    else
        # `@error` only logs; an unknown criterion must abort the selection
        error("Unknown selection criterion: $criterion")
    end
end
Loading

0 comments on commit eb55759

Please sign in to comment.