Skip to content

Commit

Permalink
update to new MLJ report APi
Browse files Browse the repository at this point in the history
  • Loading branch information
davnn committed Oct 18, 2022
1 parent 87c4aae commit 0398eba
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/mlj_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ from_categorical(categorical) = MLJ.unwrap.(categorical)
from_categorical(categorical::MLJ.Node) = MLJ.node(from_categorical, categorical)

# transform a fitresult (containing only the model) back to a Fit containing the model and training scores
to_fitresult(mach::MLJ.Machine{<:OD.Detector})::Fit = (mach.fitresult, mach.report.scores)
to_fitresult(mach::MLJ.Machine{<:OD.Detector})::Fit = (mach.fitresult, MLJ.report(mach).scores)

# this includes all composites defined in mlj_wrappers.jl
const DetectorComposites = Union{
Expand Down Expand Up @@ -142,16 +142,27 @@ function augmented_transform(mach::MLJ.Machine{<:OD.Detector}; rows=:)
return _augmented_transform(mach.model, to_fitresult(mach), selectrows(mach.model, rows, mach.data[1])...)
end

function get_scores_from_composite_report(mach)
# new #banana API
# fit_report = MLJ.report_given_method(mach)[:fit]
fit_report = mach.report
if haskey(fit_report, :additions) && haskey(fit_report.additions, :scores)
return fit_report.additions.scores
else
return fit_report.scores
end
end

function augmented_transform(mach::DetectorComposites; rows=:)
check_mach(mach)
scores_train = mach.report.scores
scores_train = get_scores_from_composite_report(mach)
scores_test = mach.fitresult.transform(selectrows(mach.model, rows, mach.data[1])...)
return scores_train, scores_test
end

function augmented_transform(mach::DetectorSurrogates; rows=:)
check_mach(mach)
scores_train = mach.report.scores
scores_train = get_scores_from_composite_report(mach)
scores_test = mach.fitresult.transform(rows=rows)
return scores_train, scores_test
end
Expand All @@ -164,7 +175,7 @@ end

function augmented_transform(mach::DetectorComposites, X)
check_mach(mach)
scores_train = mach.report.scores
scores_train = get_scores_from_composite_report(mach)
scores_test = mach.fitresult.transform(X)
return scores_train, scores_test
end
Expand Down

0 comments on commit 0398eba

Please sign in to comment.