diff --git a/src/mlj_helpers.jl b/src/mlj_helpers.jl index c6b05bf..7311011 100644 --- a/src/mlj_helpers.jl +++ b/src/mlj_helpers.jl @@ -84,7 +84,7 @@ from_categorical(categorical) = MLJ.unwrap.(categorical) from_categorical(categorical::MLJ.Node) = MLJ.node(from_categorical, categorical) # transform a fitresult (containing only the model) back to a Fit containing the model and training scores -to_fitresult(mach::MLJ.Machine{<:OD.Detector})::Fit = (mach.fitresult, mach.report.scores) +to_fitresult(mach::MLJ.Machine{<:OD.Detector})::Fit = (mach.fitresult, MLJ.report(mach).scores) # this includes all composites defined in mlj_wrappers.jl const DetectorComposites = Union{ @@ -142,16 +142,27 @@ function augmented_transform(mach::MLJ.Machine{<:OD.Detector}; rows=:) return _augmented_transform(mach.model, to_fitresult(mach), selectrows(mach.model, rows, mach.data[1])...) end +function get_scores_from_composite_report(mach) + # new #banana API + # fit_report = MLJ.report_given_method(mach)[:fit] + fit_report = mach.report + if haskey(fit_report, :additions) && haskey(fit_report.additions, :scores) + return fit_report.additions.scores + else + return fit_report.scores + end +end + function augmented_transform(mach::DetectorComposites; rows=:) check_mach(mach) - scores_train = mach.report.scores + scores_train = get_scores_from_composite_report(mach) scores_test = mach.fitresult.transform(selectrows(mach.model, rows, mach.data[1])...) return scores_train, scores_test end function augmented_transform(mach::DetectorSurrogates; rows=:) check_mach(mach) - scores_train = mach.report.scores + scores_train = get_scores_from_composite_report(mach) scores_test = mach.fitresult.transform(rows=rows) return scores_train, scores_test end @@ -164,7 +175,7 @@ end function augmented_transform(mach::DetectorComposites, X) check_mach(mach) - scores_train = mach.report.scores + scores_train = get_scores_from_composite_report(mach) scores_test = mach.fitresult.transform(X) return scores_train, scores_test end