Skip to content

Commit

Permalink
Refactoring, IO implementation for saving, tests and mlflow running on
Browse files Browse the repository at this point in the history
CI
  • Loading branch information
pebeto committed Jul 17, 2023
1 parent cba8599 commit 7977fd2
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 38 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ env:
TEST_MLJBASE: "true"
jobs:
test:
services:
mlflow:
image: adacotechjp/mlflow:2.3.1
ports:
- 5000:5000
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables", "MLFlowClient"]
7 changes: 4 additions & 3 deletions ext/LoggersExt/LoggersExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module LoggersExt

using MLJBase: info, name, Model,
params, Machine, Measure,
flat_params
Machine, Measure, flat_params

import MLJBase: save, evaluate!, MLFlowLogger
import MLJBase: save, evaluate!, mlflow_logger

include("utils.jl")

include("mlflow.jl")

Expand Down
24 changes: 13 additions & 11 deletions ext/LoggersExt/mlflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ using MLFlowClient: MLFlow, logparam, logmetric,
createrun, MLFlowRun, updaterun,
logartifact, getorcreateexperiment

struct MLFlowInstance
struct MLFlowLogger
client::MLFlow
experiment_name::String
artifact_location::Union{String, Missing}
end
MLFlowLogger(base_uri::String, experiment_name::String,
mlflow_logger(base_uri::String, experiment_name::String,
artifact_location::Union{String, Missing}) =
MLFlowInstance(MLFlow(base_uri), experiment_name, artifact_location)
MLFlowLogger(MLFlow(base_uri), experiment_name, artifact_location)

function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
model_params = params(model) |> flat_params |> collect
Expand All @@ -18,8 +18,8 @@ function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
end
end

function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{Measure},
measurements::Vector{Float64})
function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{T},
measurements::Vector{Float64}) where T<:Measure
measure_names = measures .|> info .|> x -> x.name
for (name, value) in zip(measure_names, measurements)
logmetric(client, run, name, value)
Expand All @@ -29,7 +29,7 @@ end
function evaluate!(mach::Machine, resampling, weights,
class_weights, rows, verbosity,
repeats, measures, operations,
acceleration, force, logger::MLFlowInstance)
acceleration, force, logger::MLFlowLogger)
performance_evaluation = evaluate!(mach, resampling, weights,
class_weights, rows, verbosity,
repeats, measures, operations,
Expand All @@ -46,17 +46,19 @@ function evaluate!(mach::Machine, resampling, weights,
return performance_evaluation
end

function save(logger::MLFlowInstance, mach::Machine)
function save(logger::MLFlowLogger, mach::Machine)
io = IOBuffer()
save(io, mach)

model_name = name(mach.model)
fname = "$(model_name).jls"
save(fname, mach)

experiment = getorcreateexperiment(logger.client, logger.experiment_name,
artifact_location=logger.artifact_location)
run = createrun(logger.client, experiment;
run_name="$(model_name) run")

_logmodelparams(logger.client, run, mach.model)
logartifact(logger.client, run, fname)
rm(fname)
fname = "$(model_name).jls"
logartifact(logger.client, run, fname, io)
updaterun(logger.client, run, "FINISHED")
end
14 changes: 7 additions & 7 deletions src/parameter_inspection.jl → ext/LoggersExt/utils.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
istransparent(::Any) = false
istransparent(::MLJType) = true
isamodel(::Any) = false
isamodel(::Model) = true

"""
params(m::MLJType)
params(m::Model)
Recursively convert any transparent object `m` into a named tuple,
keyed on the property names of `m`. An object is *transparent* if
`MLJBase.istransparent(m) == true`. The named tuple is possibly nested
`isamodel(m) == true`. The named tuple is possibly nested
because `params` is recursively applied to the property values, which
themselves might be transparent.
For most `MLJType` objects, properties are synonymous with fields, but
For most `Model` objects, properties are synonymous with fields, but
this is not a hard requirement.
Most objects of type `MLJType` are transparent.
Most objects of type `Model` are transparent.
julia> params(EnsembleModel(atom=ConstantClassifier()))
(atom = (target_type = Bool,),
Expand All @@ -24,7 +24,7 @@ Most objects of type `MLJType` are transparent.
parallel = true,)
"""
params(m) = params(m, Val(istransparent(m)))
params(m) = params(m, Val(isamodel(m)))
params(m, ::Val{false}) = m
function params(m, ::Val{true})
fields = propertynames(m)
Expand Down
11 changes: 6 additions & 5 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,15 @@ export coerce, coerce!, autotype, schema, info
export UnivariateFiniteArray, UnivariateFiniteVector

# -----------------------------------------------------------------------
# abstract model types defined in MLJModelInterface.jl and extended here:
# re-export from MLJModelInterface.jl

#abstract model types defined in MLJModelInterface.jl and extended here:
for T in EXTENDED_ABSTRACT_MODEL_TYPES
@eval(export $T)
end

export params

# -------------------------------------------------------------------
# exports from this module, MLJBase

Expand All @@ -308,9 +312,6 @@ export default_resource
# one_dimensional_ranges.jl:
export ParamRange, NumericRange, NominalRange, iterator, scale

# parameter_inspection.jl:
export params # note this is *not* an extension of StatsBase.params

# data.jl:
export partition, unpack, complement, restrict, corestrict

Expand Down Expand Up @@ -381,7 +382,7 @@ export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
levels, levels!, std, Not, support, logpdf, LittleDict

# loggers.jl
export MLFlowLogger
export mlflow_logger

if !isdefined(Base, :get_extension)
include("../ext/LoggersExt/LoggersExt.jl")
Expand Down
15 changes: 8 additions & 7 deletions src/loggers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
MLFlowLogger(; base_uri="localhost:5000", experiment_name=missing)
mlflow_logger(; base_uri="localhost:5000", experiment_name=missing)
Base type for MLFlow logger. Creates an instance of MLFlow, as defined in
Constructor for the base type for MLFlow logger. Creates an instance of MLFlow,
as defined in
[`MLFlowClient.jl`](https://juliaai.github.io/MLFlowClient.jl/dev/), and logs
to an experiment.
Expand All @@ -13,14 +14,14 @@ If `experiment_name` is not provided, a new experiment with the name
"MLJ.jl experiments" will be created.
### Return value
A `MLFlowInstance` object, containing a
A `MLFlowLogger` object, containing a
[`MLFlow`](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow)
object and the experiment name
"""
MLFlowLogger(; base_uri="http://localhost:5000",
mlflow_logger(; base_uri="http://localhost:5000",
experiment_name="MLJ experiments",
artifact_location=missing) =
MLFlowLogger(base_uri, experiment_name, artifact_location)
MLFlowLogger(_, _, _) =
error("Please run `import MLFlowClient` to use MLFlowLogger.")
mlflow_logger(base_uri, experiment_name, artifact_location)
mlflow_logger(_, _, _) =
error("Please run `import MLFlowClient` to use mlflow_logger.")
8 changes: 4 additions & 4 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ LittleDict{...} with 3 entries:
"Y" => 3
```
"""
function flat_params(params::NamedTuple)
function flat_params(parameters::NamedTuple)
result = LittleDict{String, Any}()
for key in keys(params)
value = getproperty(params, key)
for key in keys(parameters)
value = params(getproperty(parameters, key))
if value isa NamedTuple
sub_dict = flat_params(value)
for (sub_key, sub_value) in pairs(sub_dict)
new_key = string(key, "_", sub_key)
new_key = string(key, "__", sub_key)
result[new_key] = sub_value
end
else
Expand Down
60 changes: 60 additions & 0 deletions test/extensions/loggers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module TestLoggers

using Test
using MLJBase
using ..Models

@testset "mlflow logger" begin
artifact_directory = "mlj-test"
experiment_name = "mlflow logger tests"

@testset "outside extension tests" begin
@test_throws ErrorException mlflow_logger()

using MLFlowClient
logger = mlflow_logger(; experiment_name=experiment_name, artifact_location=artifact_directory)

@test logger.client isa MLFlow
@test logger.experiment_name == experiment_name
@test logger.artifact_location == artifact_directory
end # @testset "outside extension tests"

@testset "extension tests" begin
X = (x=rand(4),)
y = ["Chenta", "Missy", "Gala", "Wendy"] |> categorical

mach = machine(ConstantClassifier(), X, y)
fit!(mach, verbosity=0)

logger = mlflow_logger(; experiment_name=experiment_name, artifact_location=artifact_directory)

@testset "save" begin
run = MLJBase.save(logger, mach)
experiment = getexperiment(logger.client, run.info.experiment_id)
@test run isa MLFlowRun
@test experiment isa MLFlowExperiment

deleterun(logger.client, run)
deleteexperiment(logger.client, experiment)
end # @testset "save"

@testset "evaluate!" begin
evaluate!(mach, resampling=Holdout(), logger=logger)

experiments = searchexperiments(logger.client)
experiments_ids = experiments .|> (e -> e.experiment_id)
runs = searchruns(logger.client, experiments_ids)

# it's 2 because of the default experiment
@test length(experiments_ids) == 2
@test length(runs) == 1

deleterun(logger.client, runs[1])
deleteexperiment(logger.client, experiments[2])
end # @testset "evaluate!"
end # @testset "extension tests"
end # @testset "mlflow logger"

end # module

true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,7 @@ end
@test include("hyperparam/one_dimensional_ranges.jl")
@test include("hyperparam/one_dimensional_range_methods.jl")
end

@conditional_testset "extensions" begin
@test include("extensions/loggers.jl")
end

0 comments on commit 7977fd2

Please sign in to comment.