diff --git a/Project.toml b/Project.toml index 06f320be..a5829d3c 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" diff --git a/ext/LoggersExt/LoggersExt.jl b/ext/LoggersExt/LoggersExt.jl index e571e824..7d722512 100644 --- a/ext/LoggersExt/LoggersExt.jl +++ b/ext/LoggersExt/LoggersExt.jl @@ -1,7 +1,8 @@ module LoggersExt using MLJBase: info, name, Model, - params, Machine, Measure + params, Machine, Measure, + flat_params import MLJBase: save, evaluate!, MLFlowLogger diff --git a/ext/LoggersExt/mlflow.jl b/ext/LoggersExt/mlflow.jl index 1de46b63..d066de18 100644 --- a/ext/LoggersExt/mlflow.jl +++ b/ext/LoggersExt/mlflow.jl @@ -12,7 +12,7 @@ MLFlowLogger(base_uri::String, experiment_name::String, MLFlowInstance(MLFlow(base_uri), experiment_name, artifact_location) function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model) - model_params = params(model) |> pairs + model_params = params(model) |> flat_params |> collect for (name, value) in model_params logparam(client, run, name, value) end @@ -47,14 +47,16 @@ function evaluate!(mach::Machine, resampling, weights, end function save(logger::MLFlowInstance, mach::Machine) + model_name = name(mach.model) + fname = "$(model_name).jls" + save(fname, mach) + experiment = getorcreateexperiment(logger.client, logger.experiment_name, artifact_location=logger.artifact_location) - run = createrun(logger.client, experiment) + run = createrun(logger.client, experiment; + run_name="$(model_name) run") _logmodelparams(logger.client, run, mach.model) - - fname = "$(name(mach.model)).jls" - save(fname, mach) logartifact(logger.client, run, fname) rm(fname) end diff --git a/src/utilities.jl b/src/utilities.jl index 66dd62b7..f19eceec 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -33,6 +33,38 @@ function flat_values(params::NamedTuple) return Tuple(values) end +""" + flat_params(t::NamedTuple) + +View a nested named tuple `t` as a tree and return, as a Dict, the key subtrees +and the values at the leaves, in the order they appear in the original tuple. + +```julia-repl +julia> t = (X = (x = 1, y = 2), Y = 3) +julia> flat_params(t) +LittleDict{...} with 3 entries: +"X_x" => 1 +"X_y" => 2 +"Y" => 3 +``` +""" +function flat_params(params::NamedTuple) + result = LittleDict{String, Any}() + for key in keys(params) + value = getproperty(params, key) + if value isa NamedTuple + sub_dict = flat_params(value) + for (sub_key, sub_value) in pairs(sub_dict) + new_key = string(key, "_", sub_key) + result[new_key] = sub_value + end + else + result[string(key)] = value + end + end + return result +end + ## RECURSIVE VERSIONS OF getproperty and setproperty! # applying the following to `:(a.b.c)` returns `(:(a.b), :c)`