Skip to content

Commit

Permalink
adding a new way to flat params recursively to preserve names, and sa…
Browse files Browse the repository at this point in the history
…ving model names in runs
  • Loading branch information
pebeto committed Jul 2, 2023
1 parent f80cbf8 commit cba8599
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand Down
3 changes: 2 additions & 1 deletion ext/LoggersExt/LoggersExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module LoggersExt

using MLJBase: info, name, Model,
params, Machine, Measure
params, Machine, Measure,
flat_params

import MLJBase: save, evaluate!, MLFlowLogger

Expand Down
12 changes: 7 additions & 5 deletions ext/LoggersExt/mlflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MLFlowLogger(base_uri::String, experiment_name::String,
MLFlowInstance(MLFlow(base_uri), experiment_name, artifact_location)

function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
model_params = params(model) |> pairs
model_params = params(model) |> flat_params |> collect
for (name, value) in model_params
logparam(client, run, name, value)
end
Expand Down Expand Up @@ -47,14 +47,16 @@ function evaluate!(mach::Machine, resampling, weights,
end

function save(logger::MLFlowInstance, mach::Machine)
model_name = name(mach.model)
fname = "$(model_name).jls"
save(fname, mach)

experiment = getorcreateexperiment(logger.client, logger.experiment_name,
artifact_location=logger.artifact_location)
run = createrun(logger.client, experiment)
run = createrun(logger.client, experiment;
run_name="$(model_name) run")

_logmodelparams(logger.client, run, mach.model)

fname = "$(name(mach.model)).jls"
save(fname, mach)
logartifact(logger.client, run, fname)
rm(fname)
end
32 changes: 32 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,38 @@ function flat_values(params::NamedTuple)
return Tuple(values)
end

"""
flat_params(t::NamedTuple)
View a nested named tuple `t` as a tree and return, as a Dict, the key subtrees
and the values at the leaves, in the order they appear in the original tuple.
```julia-repl
julia> t = (X = (x = 1, y = 2), Y = 3)
julia> flat_params(t)
LittleDict{...} with 3 entries:
"X_x" => 1
"X_y" => 2
"Y" => 3
```
"""
function flat_params(params::NamedTuple)
result = LittleDict{String, Any}()
for key in keys(params)
value = getproperty(params, key)
if value isa NamedTuple
sub_dict = flat_params(value)
for (sub_key, sub_value) in pairs(sub_dict)
new_key = string(key, "_", sub_key)
result[new_key] = sub_value
end
else
result[string(key)] = value
end
end
return result
end

## RECURSIVE VERSIONS OF getproperty and setproperty!

# applying the following to `:(a.b.c)` returns `(:(a.b), :c)`
Expand Down

0 comments on commit cba8599

Please sign in to comment.