Skip to content

Commit

Permalink
uploading tests and including mlflow on CI
Browse files Browse the repository at this point in the history
  • Loading branch information
pebeto committed Jul 17, 2023
1 parent 4c2f3a0 commit 18e60c7
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ env:
TEST_MLJBASE: "true"
jobs:
test:
services:
mlflow:
image: adacotechjp/mlflow:2.3.1
ports:
- 5000:5000
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables", "MLFlowClient"]
5 changes: 3 additions & 2 deletions ext/LoggersExt/mlflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ function _logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
end
end

function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{Measure},
measurements::Vector{Float64})
function _logmachinemeasures(client::MLFlow, run::MLFlowRun, measures::Vector{T},
measurements::Vector{Float64}) where T<:Measure
measure_names = measures .|> info .|> x -> x.name
for (name, value) in zip(measure_names, measurements)
logmetric(client, run, name, value)
Expand Down Expand Up @@ -60,4 +60,5 @@ function save(logger::MLFlowLogger, mach::Machine)
_logmodelparams(logger.client, run, mach.model)
fname = "$(model_name).jls"
logartifact(logger.client, run, fname, io)
updaterun(logger.client, run, "FINISHED")
end
9 changes: 5 additions & 4 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,15 @@ export coerce, coerce!, autotype, schema, info
export UnivariateFiniteArray, UnivariateFiniteVector

# -----------------------------------------------------------------------
# abstract model types defined in MLJModelInterface.jl and extended here:
# re-export from MLJModelInterface.jl

#abstract model types defined in MLJModelInterface.jl and extended here:
for T in EXTENDED_ABSTRACT_MODEL_TYPES
@eval(export $T)
end

export params

# -------------------------------------------------------------------
# exports from this module, MLJBase

Expand All @@ -308,9 +312,6 @@ export default_resource
# one_dimensional_ranges.jl:
export ParamRange, NumericRange, NominalRange, iterator, scale

# parameter_inspection.jl:
export params # note this is *not* an extension of StatsBase.params

# data.jl:
export partition, unpack, complement, restrict, corestrict

Expand Down
60 changes: 60 additions & 0 deletions test/extensions/loggers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module TestLoggers

using Test
using MLJBase
using ..Models

@testset "mlflow logger" begin
artifact_directory = "mlj-test"
experiment_name = "mlflow logger tests"

@testset "outside extension tests" begin
@test_throws ErrorException mlflow_logger()

using MLFlowClient
logger = mlflow_logger(; experiment_name=experiment_name, artifact_location=artifact_directory)

@test logger.client isa MLFlow
@test logger.experiment_name == experiment_name
@test logger.artifact_location == artifact_directory
end # @testset "outside extension tests"

@testset "extension tests" begin
X = (x=rand(4),)
y = ["Chenta", "Missy", "Gala", "Wendy"] |> categorical

mach = machine(ConstantClassifier(), X, y)
fit!(mach, verbosity=0)

logger = mlflow_logger(; experiment_name=experiment_name, artifact_location=artifact_directory)

@testset "save" begin
run = MLJBase.save(logger, mach)
experiment = getexperiment(logger.client, run.info.experiment_id)
@test run isa MLFlowRun
@test experiment isa MLFlowExperiment

deleterun(logger.client, run)
deleteexperiment(logger.client, experiment)
end # @testset "save"

@testset "evaluate!" begin
evaluate!(mach, resampling=Holdout(), logger=logger)

experiments = searchexperiments(logger.client)
experiments_ids = experiments .|> (e -> e.experiment_id)
runs = searchruns(logger.client, experiments_ids)

# it's 2 because of the default experiment
@test length(experiments_ids) == 2
@test length(runs) == 1

deleterun(logger.client, runs[1])
deleteexperiment(logger.client, experiments[2])
end # @testset "evaluate!"
end # @testset "extension tests"
end # @testset "mlflow logger"

end # module

true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,7 @@ end
@test include("hyperparam/one_dimensional_ranges.jl")
@test include("hyperparam/one_dimensional_range_methods.jl")
end

@conditional_testset "extensions" begin
@test include("extensions/loggers.jl")
end

0 comments on commit 18e60c7

Please sign in to comment.