Skip to content

Commit

Permalink
Adding a flat_params test, preparing resampling to receive loggers an…
Browse files Browse the repository at this point in the history
…d making _logmachinemeasures more generic
  • Loading branch information
pebeto committed Jul 23, 2023
1 parent 73d9bfb commit 314987f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
3 changes: 1 addition & 2 deletions ext/LoggersExt/mlflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
end
end

function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{T},
measurements::Vector{Float64}) where T<:Measure
function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures, measurements)
measure_names = measures .|> info .|> x -> x.name
for (name, value) in zip(measure_names, measurements)
logmetric(client, run, name, value)
Expand Down
12 changes: 8 additions & 4 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,8 @@ end
operation=predict,
repeats = 1,
acceleration=default_resource(),
check_measure=true
check_measure=true,
logger=nothing
)
Resampling model wrapper, used internally by the `fit` method of
Expand Down Expand Up @@ -1374,7 +1375,7 @@ are not to be confused with any weights bound to a `Resampler` instance
in a machine, used for training the wrapped `model` when supported.
"""
mutable struct Resampler{S} <: Model
mutable struct Resampler{S, L} <: Model
model
resampling::S # resampling strategy
measure
Expand All @@ -1385,6 +1386,7 @@ mutable struct Resampler{S} <: Model
check_measure::Bool
repeats::Int
cache::Bool
logger::L
end

# Some traits are markded as `missing` because we cannot determine
Expand Down Expand Up @@ -1423,7 +1425,8 @@ function Resampler(;
acceleration=default_resource(),
check_measure=true,
repeats=1,
cache=true
cache=true,
logger=nothing
)
resampler = Resampler(
model,
Expand All @@ -1435,7 +1438,8 @@ function Resampler(;
acceleration,
check_measure,
repeats,
cache
cache,
logger
)
message = MLJModelInterface.clean!(resampler)
isempty(message) || @warn message
Expand Down
5 changes: 4 additions & 1 deletion test/extensions/loggers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ using ..Models
experiment_name = "mlflow logger tests"

@testset "outside extension tests" begin
@test_throws ErrorException mlflow_logger("http://localhost:5000")
# This test is only relevant for Julia 1.9 and above
if VERSION >= v"1.9"
@test_throws ErrorException mlflow_logger("http://localhost:5000")
end

using MLFlowClient
logger = mlflow_logger("http://localhost:5000";
Expand Down
11 changes: 11 additions & 0 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,16 @@ end
"sin, cos, tan, ..."
end

@testset "flattening parameters" begin
t = (a = (ax = (ax1 = 1, ax2 = 2), ay = 3), b = 4)
dict_t = Dict(
"a__ax__ax1" => 1,
"a__ax__ax2" => 2,
"a__ay" => 3,
"b" => 4,
)
@test MLJBase.flat_params(t) == dict_t
end

end # module
true

0 comments on commit 314987f

Please sign in to comment.