Skip to content

Commit

Permalink
fix LIBSVM model type mapping
Browse files Browse the repository at this point in the history
didn't quite understand why it didn't work before but works now
  • Loading branch information
ValdarT committed May 27, 2019
1 parent b65fbfd commit cceb87f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
29 changes: 27 additions & 2 deletions src/LIBSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,32 @@ function EpsilonSVR(
end


const SVM = Union{LinearSVC, SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM} # all SVM models defined here


"""
map_model_type(model::SVM)
Helper function to map the model to the correct LIBSVM model type needed for function dispatch.
"""
function map_model_type(model::SVM)
if isa(model, LinearSVC)
return LIBSVM.LinearSVC
elseif isa(model, SVC)
return LIBSVM.SVC
elseif isa(model, NuSVC)
return LIBSVM.NuSVC
elseif isa(model, NuSVR)
return LIBSVM.NuSVR
elseif isa(model, EpsilonSVR)
return LIBSVM.EpsilonSVR
elseif isa(model, OneClassSVM)
return LIBSVM.OneClassSVM
else
error("Got unsupported model type: $(typeof(model))")
end
end

"""
get_svm_parameters(model::Union{SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM})
Expand All @@ -287,7 +313,7 @@ Helper function to get the parameters from the SVM model struct.
function get_svm_parameters(model::Union{SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM})
#Build arguments for calling svmtrain
params = Tuple{Symbol, Any}[]
push!(params, (:svmtype, eval(Meta.parse("LIBSVM.$(typeof(model))")))) # LIBSVM model type
push!(params, (:svmtype, map_model_type(model))) # get LIBSVM model type
for fn in fieldnames(typeof(model))
push!(params, (fn, getfield(model, fn)))
end
Expand Down Expand Up @@ -402,7 +428,6 @@ MLJBase.load_path(::Type{<:NuSVR}) = "MLJModels.LIBSVM_.NuSVR"
MLJBase.load_path(::Type{<:EpsilonSVR}) = "MLJModels.LIBSVM_.EpsilonSVR"
MLJBase.load_path(::Type{<:OneClassSVM}) = "MLJModels.LIBSVM_.OneClassSVM"

const SVM = Union{LinearSVC, SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM} # all SVM models defined here
MLJBase.package_name(::Type{<:SVM}) = "LIBSVM"
MLJBase.package_uuid(::Type{<:SVM}) = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
MLJBase.is_pure_julia(::Type{<:SVM}) = false
Expand Down
1 change: 1 addition & 0 deletions test/LIBSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using MLJBase
using Test
using LinearAlgebra

import MLJModels
import LIBSVM
using MLJModels.LIBSVM_
using CategoricalArrays
Expand Down

0 comments on commit cceb87f

Please sign in to comment.