Skip to content

Commit

Permalink
Merge pull request #82 from ACEsuit/asp
Browse files Browse the repository at this point in the history
Smart ASP
  • Loading branch information
cortner authored Sep 8, 2024
2 parents daff721 + eb55759 commit e0fea2e
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 123 deletions.
1 change: 1 addition & 0 deletions src/ACEfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("bayesianlinear.jl")
include("data.jl")
include("assemble.jl")
include("solvers.jl")
include("asp.jl")
include("fit.jl")

end
157 changes: 157 additions & 0 deletions src/asp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

@doc raw"""
`ASP` : Active Set Pursuit solver

Solves the lasso optimization problem
```math
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
```
subject to
```math
\|A^T y\|_{\infty} \leq 1.
```

### Constructor keyword arguments
```julia
ACEfit.ASP(; P = I, select = (:byerror, 1.0),
             params...)
```

* `select` : Selection criterion for the final solution (required; it has no
  default, the value shown above is only an example). One of
    * `:final` : final solution (largest computed basis)
    * `(:byerror, q)` : first solution along the path whose error is within
      `q` times the minimum error along the path; if training error is used
      and `q == 1.0`, then this is equivalent to `:final`.
    * `(:bysize, n)` : best solution with at most `n` non-zero features; if
      training error is used, then it will be the solution with exactly `n`
      non-zero features.
* `P = I` : prior / regularizer (optional)

The constructor also accepts a `mode` keyword; it is currently neither stored
nor used.

The remaining keyword arguments to `ASP` are parameters for the ASP homotopy
solver.

* `actMax` : Maximum number of active constraints.
* `min_lambda` : Minimum value for `λ`. (defaults to 0)
* `loglevel` : Logging level.
* `itnMax` : Maximum number of iterations.

### Extended syntax for `solve`
```julia
solve(solver::ASP, A, y, Aval=A, yval=y)
```

* `A` : `m`-by-`n` design matrix. (required)
* `y` : `m`-vector. (required)
* `Aval = A` : `p`-by-`n` validation matrix (optional).
* `yval = y` : `p`-vector of validation targets (optional).

If independent `Aval` and `yval` are provided (instead of the defaults `A, y`),
then the solver will use this separate validation set instead of the training
set to select the best solution along the model path.

`solve` returns a `Dict` with the keys `"C"` (the selected solution), `"path"`
(all solutions along the path together with their `λ`) and `"nnzs"` (number of
stored non-zeros of the selected solution).
"""
struct ASP
# prior / regularizer; `solve` forms `A / P` and maps each path solution back with `P \ x`
P
# selection criterion read by `select_solution`: `:final`, `(:byerror, q)` or `(:bysize, n)`
select
# remaining keyword arguments of the constructor; `solve` splats them into `asp_homotopy`
params
end

# Keyword constructor for `ASP`.
#
# * `P`      : prior / regularizer, defaults to the identity `I`.
# * `select` : selection criterion; required, there is no default.
# * `mode`   : accepted so that callers may pass it, but it is not stored in the
#              solver and has no effect.
# * `params` : every other keyword is collected and stored unchanged; `solve`
#              forwards these to `asp_homotopy`.
ASP(; P = I, select, mode = :train, params...) = ASP(P, select, params)

# Fit with the ASP homotopy solver and pick one solution along the path.
#
# * `A`, `y`       : training design matrix and target vector.
# * `Aval`, `yval` : data used by `select_solution` to choose the final
#                    solution; they default to the training data.
#
# Returns a `Dict` with
# * "C"    : the selected solution,
# * "path" : every solution along the path (mapped back through `P`) with its `λ`,
# * "nnzs" : number of stored non-zeros (`nzind`) of the selected solution.
function solve(solver::ASP, A, y, Aval=A, yval=y)
    # Run the homotopy on the right-preconditioned system A * P^{-1}.
    homotopy_path = asp_homotopy(A / solver.P, y; solver.params...)

    # Undo the preconditioning for every point on the path. The loosely typed
    # element type is kept so that "path" has the same container type as before.
    PathEntry = NamedTuple{(:solution, :λ), Tuple{Any, Any}}
    new_tracer = PathEntry[(solution = solver.P \ step[1], λ = step[2])
                           for step in homotopy_path]

    xs, chosen = select_solution(new_tracer, solver, Aval, yval)
    nnzs = length((new_tracer[chosen][:solution]).nzind)

    return Dict("C" => xs, "path" => new_tracer, "nnzs" => nnzs)
end


"""
    select_solution(tracer, solver, A, y) -> (solution, index)

Choose one solution from the homotopy path `tracer` according to
`solver.select` and return it together with its index in `tracer`.

# Arguments
- `tracer` : vector of path entries; each entry is indexed with `:solution`,
  and for `:bysize` the solution must expose an `nzind` field.
- `solver` : any object with a `select` field holding `:final`,
  `(:final, anything)`, `(:byerror, q)` or `(:bysize, n)`.
- `A`, `y` : matrix and target vector used to measure the error
  `norm(A * x - y)` of each candidate `x`.

# Selection rules
- `:final` : the last entry of the path.
- `(:byerror, q)` : the first entry whose error is at most `q` times the
  smallest error along the whole path.
- `(:bysize, n)` : among the leading entries that come before the first
  solution with more than `n` non-zeros, the first one attaining the smallest
  error. If no entry exceeds `n` non-zeros the whole path is eligible.

# Throws
An `ErrorException` for an unknown criterion, when `:bysize` leaves no
candidate (the very first solution already has more than `n` non-zeros), or
when no candidate satisfies the error threshold.
"""
function select_solution(tracer, solver, A, y)
    if solver.select == :final
        criterion, p = :final, nothing
    else
        criterion, p = solver.select
    end

    if criterion == :final
        return tracer[end][:solution], length(tracer)
    end

    if criterion == :byerror
        maxind = length(tracer)
        threshold = p
    elseif criterion == :bysize
        # `findfirst` returns `nothing` when no solution on the path has more
        # than `p` non-zeros; in that case every entry is a valid candidate.
        first_too_large = findfirst(t -> length((t[:solution]).nzind) > p, tracer)
        maxind = first_too_large === nothing ? length(tracer) : first_too_large - 1
        threshold = 1.0
    else
        error("Unknown selection criterion: $criterion")
    end

    # Without this guard `minimum` would fail on an empty collection with a
    # message that says nothing about the actual cause.
    if maxind < 1
        error("no solution along the path satisfies the selection criterion " *
              "$(solver.select)")
    end

    errors = [norm(A * t[:solution] - y) for t in @view tracer[1:maxind]]
    min_error = minimum(errors)
    best = findfirst(err -> err <= threshold * min_error, errors)
    if best === nothing
        error("selection failed for unknown reasons; please file an issue with a MWE to reproduce this error.")
    end
    return tracer[best][:solution], best
end


#=
function select_smart(tracer, solver, Aval, yval)
best_metric = Inf
best_iteration = 0
validation_metric = 0
q = length(tracer)
errors = [norm(Aval * t[:solution] - yval) for t in tracer]
nnzss = [(t[:solution]).nzind for t in tracer]
best_iteration = argmin(errors)
validation_metric = errors[best_iteration]
validation_end = norm(Aval * tracer[end][:solution] - yval)
if validation_end < validation_metric #make sure to check the last one too in case q<<100
best_iteration = q
end
criterion, threshold = solver.select
if criterion == :val
return tracer[best_iteration][:solution], best_iteration
elseif criterion == :byerror
for (i, error) in enumerate(errors)
if error <= threshold * validation_metric
return tracer[i][:solution], i
end
end
elseif criterion == :bysize
first_index = findfirst(sublist -> threshold in sublist, nnzss)
relevant_errors = errors[1:first_index - 1]
min_error = minimum(relevant_errors)
min_error_index = findfirst(==(min_error), relevant_errors)
return tracer[min_error_index][:solution], min_error_index
else
@error("Unknown selection criterion: $criterion")
end
end
=#
92 changes: 0 additions & 92 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,95 +196,3 @@ function solve(solver::TruncatedSVD, A, y)
return Dict{String, Any}("C" => solver.P \ θP)
end


@doc raw"""
`struct ASP` : Active Set Pursuit sparse solver

Solves the following optimization problem using the homotopy approach:
```math
\max_{y} \left( b^T y - \frac{1}{2} λ y^T y \right)
```
subject to
```math
\|A^T y\|_{\infty} \leq 1.
```

* Input
    * `A` : `m`-by-`n` explicit matrix or linear operator.
    * `b` : `m`-vector.
* Solver parameters
    * `min_lambda` : Minimum value for `λ`. Defaults to zero if not provided.
    * `loglevel` : Logging level.
    * `itnMax` : Maximum number of iterations.
    * `actMax` : Maximum number of active constraints.

Constructor
```julia
ACEfit.ASP(; P = I, select, params)
```
where
- `P` : right-preconditioner / Tikhonov operator
- `select` : Selection mode for the final solution.
    - `(:byerror, th)` : Selects the smallest active set fit within a factor
      `th` of the smallest fit error.
    - `(:bysize, n)` : Selects the first solution along the path with exactly
      `n` non-zero entries.
    - `(:final, nothing)` : Returns the final iterate.
- `params` : The solver parameters, passed as named arguments.
"""
struct ASP
# right-preconditioner; `solve` forms `A / P` and maps solutions back with `P \ x`
P::Any
# `(criterion, value)` pair read by `select_solution`
select::Tuple
# NOTE(review): `solve` splats `solver.params[1]`, i.e. the first stored entry, so the
# solver parameters are presumably expected as one `params = (; ...)` keyword; confirm.
params::NamedTuple
end

# Keyword constructor: `P` defaults to the identity, `select` is required, and
# all remaining keywords are gathered into a `NamedTuple` stored as `params`.
ASP(; P = I, select, params...) = ASP(P, select, NamedTuple(params))

# Fit with the ASP homotopy solver and pick one solution along the path using
# the training data `A`, `y`.
#
# Returns a `Dict` with
# * "C"    : the selected solution,
# * "path" : every solution along the path (mapped back through `P`) with its `λ`,
# * "nnzs" : number of stored non-zeros (`nzind`) of the selected entry of the
#            preconditioned path.
function solve(solver::ASP, A, y)
    # Homotopy on the right-preconditioned system; the first stored parameter
    # entry is splatted as the keyword arguments of the solver.
    homotopy_path = asp_homotopy(A / solver.P, y; solver.params[1]...)

    # Undo the preconditioning for every point on the path.
    PathEntry = NamedTuple{(:solution, :λ), Tuple{Any, Any}}
    new_tracer = PathEntry[(solution = solver.P \ step[1], λ = step[2])
                           for step in homotopy_path]

    # Select the final solution based on the criterion.
    xs, chosen = select_solution(new_tracer, solver, A, y)

    println("done.")
    nnzs = length((homotopy_path[chosen][1]).nzind)
    return Dict("C" => xs, "path" => new_tracer, "nnzs" => nnzs)
end

"""
    select_solution(tracer, solver, A, y) -> (solution, index)

Choose one solution from the homotopy path `tracer` according to the
`(criterion, threshold)` pair stored in `solver.select`, and return it together
with its index in `tracer`.

Each path entry is indexed positionally: `t[1]` is the solution vector.

# Selection rules
- `:final` : the last entry of the path (`threshold` is ignored).
- `:byerror` : the first entry whose error `norm(A * x - y)` is at most
  `threshold` times the smallest error along the path.
- `:bysize` : the first entry whose solution has exactly `threshold` stored
  non-zeros (`nzind`).

# Throws
An `ErrorException` when the criterion is unknown or when no entry of the path
satisfies it. Previously these cases only logged or silently returned
`nothing`, which made the caller fail later with an unrelated error.
"""
function select_solution(tracer, solver, A, y)
    criterion, threshold = solver.select

    if criterion == :final
        return tracer[end][1], length(tracer)
    elseif criterion == :byerror
        errors = [norm(A * t[1] - y) for t in tracer]
        min_error = minimum(errors)
        # First (smallest active set) solution whose error is within the threshold.
        chosen = findfirst(err -> err <= threshold * min_error, errors)
        if chosen === nothing
            error("no solution along the path has an error within " *
                  "$threshold times the minimum error")
        end
        return tracer[chosen][1], chosen
    elseif criterion == :bysize
        chosen = findfirst(t -> length((t[1]).nzind) == threshold, tracer)
        if chosen === nothing
            error("no solution along the path has exactly $threshold non-zero entries")
        end
        return tracer[chosen][1], chosen
    else
        error("Unknown selection criterion: $criterion")
    end
end

1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Loading

0 comments on commit e0fea2e

Please sign in to comment.