Skip to content

Commit

Permalink
Merge pull request #737 from JuliaAI/fix-issue-377-test
Browse files Browse the repository at this point in the history
Fix the test that addresses #377
  • Loading branch information
ablaom authored Feb 3, 2022
2 parents b8c8b43 + 7606c7b commit 6914722
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 36 deletions.
28 changes: 16 additions & 12 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,20 @@ MLJModelInterface.fitted_params(mach::Machine{<:Surrogate}) =

# # CONSTRUCTING THE RETURN VALUE FOR A COMPOSITE FIT METHOD

logerr_identical_models(name, model) =
"The hyperparameters $name of "*
"$model have identical model "*
"instances as values. "
const ERR_IDENTICAL_MODELS = ArgumentError(
"Two distinct hyper-parameters of a "*
"composite model that are both "*
"associated with models in the underlying learning "*
"network (eg, any two components of a `@pipeline` model) "*
"cannot have identical values, although they can be `==` "*
"(corresponding nested properties are `==`). "*
"Consider constructing instances "*
"separately or use `deepcopy`. ")

# Identify which properties of `model` have, as values, a model in the
# learning network wrapped by `mach`, and check that no two such
# properties have have identical values (#377). Return the property name
Expand Down Expand Up @@ -299,20 +313,10 @@ function network_model_names(model::M,
if !no_duplicates
for (id, name) in name_given_id
if length(name) > 1
@error "The hyperparameters $name of "*
"$model have identical model "*
"instances as values. "
@error logerr_identical_models(name, model)
end
end
throw(ArgumentError(
"Two distinct hyper-parameters of a "*
"composite model that are both "*
"associated with models in the underlying learning "*
"network (eg, any two components of a `@pipeline` model) "*
"cannot have identical values, although they can be `==` "*
"(corresponding nested properties are `==`). "*
"Consider constructing instances "*
"separately or use `deepcopy`. "))
throw(ERR_IDENTICAL_MODELS)
end

return map(network_model_ids) do id
Expand Down
48 changes: 24 additions & 24 deletions test/composition/models/from_network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ end

mach = fit!(machine(NoTraining()), verbosity=0)
@test transform(mach, X) == 2*X.age


## TESTINGS A STACK AND IN PARTICULAR FITTED_PARAMS

Expand Down Expand Up @@ -557,37 +557,37 @@ fp = fitted_params(mach)

## ISSUE #377

target_stand = Standardizer()
stand = Standardizer()
rgs = KNNRegressor()
stand1 = Standardizer()
stand2 = Standardizer()

X = source()
y = source()
Xraw = (x=[-2.0, 0.0, 2.0],)
X = source(Xraw)

mach1 = machine(target_stand, y)
z = transform(mach1, y)
mach2 = machine(stand, X)
W = transform(mach2, X)
mach3 = machine(rgs, W, z)
zhat = predict(mach3, W)
yhat = inverse_transform(mach1, zhat)
mach1 = machine(stand1, X)
X2 = transform(mach1, X)

@from_network machine(Deterministic(), X, y; predict=yhat) begin
mutable struct CompositeA
rgs=rgs
stand=stand
target=target_stand
mach2 = machine(stand2, X2)
X3 = transform(mach2, X2)

@from_network machine(Unsupervised(), X; transform=X3) begin
mutable struct CompositeZ
s1=stand1
s2=stand2
end
end

X, y = make_regression(20, 2);
model = CompositeA(stand=stand, target=stand)
mach = machine(model, X, y)
@test_logs((:error, r"The hyper"),
# check no problems with network:
fit!(X3)
@test X3().x [-1.0, 0.0, 1.0]

# instantiate with identical (===) models in two places:
model = CompositeZ(s1=stand1, s2=stand1)
mach = machine(model, Xraw)
@test_logs((:error, MLJBase.logerr_identical_models([:s1, :s2], model)),
(:error, r"Problem"),
(:info, r"Running"),
(:info, r"Type checks okay"),
@test_throws(ArgumentError,
@test_throws(MLJBase.ERR_IDENTICAL_MODELS,
fit!(mach, verbosity=-1)))


Expand All @@ -611,7 +611,7 @@ X = (x = Float64[1, 2, 3],)
mach = machine(AppleComposite(), X)
fit!(mach, verbosity=0, force=true)
@test transform(mach, X).x Float64[-1, 0, 1]
@test inverse_transform(mach, X) == X
@test inverse_transform(mach, X) == X

end

Expand Down

0 comments on commit 6914722

Please sign in to comment.