Skip to content

Commit

Permalink
further tweaks to make replace() work
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Dec 20, 2021
1 parent db7c05e commit 6b0f2cf
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 58 deletions.
99 changes: 53 additions & 46 deletions src/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

const DOC_SIGNATURES =
"""
A learning network *signature* is a named tuple that we define when
constructing a learning network machine, `mach`. Examples are
A learning network *signature* is an intermediate object defined when
a user constructs a learning network machine, `mach`. They are named
tuples whose values are the nodes constituting interface points
between the network and the machine. Examples are
(predict=yhat, )
(transform=Xsmall,)
Expand All @@ -25,35 +27,31 @@ whose keys are arbitrary, and whose values are nodes of the
network. For each such key-value pair `k=n`, the value returned by
`n()` is included in the named tuple `report(mach)`, with
corresponding key `k`. So, in the third example above,
`report(mach).loss` then returns the value of `loss_node()`.
`report(mach).loss` will return the value of `loss_node()`.
"""

"""
call_report_nodes(signature)
Return `NamedTuple()` if `:report` is not a key of
`signature`. Otherwise, expecting `signature.report` to be some named
tuple `(k1=n1, k2=n2, ..., kn=nn)`, return the named tuple `(k1=n1(),
k2=n2(), ..., kn=nn())`
For example, if
signature = (predict=yhat, report=(x=n, a=m))
"""
    _operation_part(signature)

Return the sub-named-tuple of the learning network `signature` whose
keys are operation names (those appearing in `OPERATIONS`, e.g.
`:predict`, `:transform`), preserving their associated nodes as values.
"""
function _operation_part(signature)
    op_keys = filter(in(OPERATIONS), keys(signature))
    op_values = map(k -> getproperty(signature, k), op_keys)
    return NamedTuple{op_keys}(op_values)
end
"""
    _report_part(signature)

Return the `report` component of the learning network `signature`, or
an empty `NamedTuple()` when the signature declares no `:report` key.
"""
function _report_part(signature)
    if !(:report in keys(signature))
        return NamedTuple()
    end
    return signature.report
end

then return `(x=n(), a=m())`.
# Names of the operations (e.g. `:predict`, `:transform`) appearing in
# `signature`, in signature order.
function _operations(signature)
    return keys(_operation_part(signature))
end

$DOC_SIGNATURES
"""
    _nodes(signature)

Return, as a flat tuple, all interface nodes of the learning network
`signature`: the operation nodes first, followed by the report nodes.
"""
function _nodes(signature)
    op_nodes = values(_operation_part(signature))
    rep_nodes = values(_report_part(signature))
    return (op_nodes..., rep_nodes...)
end

"""
function call_report_nodes(signature)
"""
    _call(nt::NamedTuple)

Call each node appearing as a value of `nt` and return a named tuple,
with the same keys, of the results. Each result is `deepcopy`ed so the
returned tuple is insulated from subsequent mutation of network state
— NOTE(review): `deepcopy` (rather than `copy`) appears deliberate
here; confirm against callers before relaxing.

This span in the scraped diff had removed-hunk lines (referencing an
undefined `signature`) interleaved with the new body; this is the
reconstructed added-hunk version.
"""
function _call(nt::NamedTuple)
    # local closure shadows the outer `_call` inside this body
    _call(n) = deepcopy(n())
    _keys = keys(nt)
    _values = values(nt)
    return NamedTuple{_keys}(_call.(_values))
end

"""
Expand All @@ -76,7 +74,7 @@ will not error but it will not a give meaningful return value either.
"""
function model_supertype(signature)

operations = keys(signature)
operations = _operations(signature)

length(intersect(operations, (:predict_mean, :predict_median))) == 1 &&
return Deterministic
Expand All @@ -97,8 +95,6 @@ function model_supertype(signature)

end

operations(s::NamedTuple) = filter(in(OPERATIONS), keys(s))


# # FITRESULTS FOR COMPOSITE MODELS

Expand All @@ -107,12 +103,7 @@ mutable struct CompositeFitresult
glb
report_additions
function CompositeFitresult(signature)
ops = operations(signature)
vals = tuple([getproperty(signature, op) for op in ops]...)
if :report in keys(signature)
vals = tuple(vals..., values(signature.report)...)
end
glb = MLJBase.glb(vals...)
glb = MLJBase.glb(_nodes(signature)...)
new(signature, glb)
end
end
Expand All @@ -121,7 +112,7 @@ glb(c::CompositeFitresult) = getfield(c, :glb)
report_additions(c::CompositeFitresult) = getfield(c, :report_additions)

update!(c::CompositeFitresult) =
setfield!(c, :report_additions, call_report_nodes(signature(c)))
setfield!(c, :report_additions, _call(_report_part(signature(c))))

# To accommodate pre-existing design (operations.jl) arrange
# that `fitresult.predict` returns the predict node, etc:
Expand Down Expand Up @@ -156,23 +147,23 @@ const ERR_EXPECTED_NODE_IN_SIGNATURE = ArgumentError(
"Did not enounter `Node` in place one was expected. ")

# A learning network machine with a generic `Surrogate` model must declare
# at least one operation and be given at least one source; otherwise throw
# the corresponding `ArgumentError` constant.
function check_surrogate_machine(::Surrogate, signature, _sources)
    if isempty(_operations(signature))
        throw(ERR_MUST_OPERATE)
    end
    if isempty(_sources)
        throw(ERR_MUST_SPECIFY_SOURCES)
    end
    return nothing
end

# Supervised (or supervised-annotator) surrogates must declare at least one
# operation and receive more than one source — presumably input and target;
# confirm against `machine` callers.
function check_surrogate_machine(::Union{Supervised,SupervisedAnnotator},
                                 signature,
                                 _sources)
    if isempty(_operations(signature))
        throw(ERR_MUST_PREDICT)
    end
    if length(_sources) <= 1
        throw(err_supervised_nargs())
    end
    return nothing
end

# Unsupervised surrogates must declare at least one operation and receive
# at most one source.
function check_surrogate_machine(::Union{Unsupervised},
                                 signature,
                                 _sources)
    if isempty(_operations(signature))
        throw(ERR_MUST_TRANSFORM)
    end
    if length(_sources) >= 2
        throw(err_unsupervised_nargs())
    end
    return nothing
end
Expand All @@ -183,7 +174,7 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
signature = (; pair_itr...)

# signature checks:
isempty(operations(signature)) && throw(ERR_MUST_OPERATE)
isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
for k in keys(signature)
if k in OPERATIONS
getproperty(signature, k) isa AbstractNode ||
Expand Down Expand Up @@ -441,9 +432,10 @@ function Base.replace(mach::Machine{<:Surrogate},
pairs::Pair...; empty_unspecified_sources=false)

signature = MLJBase.signature(mach.fitresult)
interface_nodes = values(signature)
operation_nodes = values(_operation_part(signature))
report_nodes = values(_report_part(signature))

W = glb(interface_nodes...)
W = glb(operation_nodes..., report_nodes...)

# Note: We construct nodes of the new network as values of a
# dictionary keyed on the nodes of the old network. Additionally,
Expand Down Expand Up @@ -492,7 +484,9 @@ function Base.replace(mach::Machine{<:Surrogate},
newnode_given_old =
IdDict{AbstractNode,AbstractNode}(all_source_pairs)
newsources = [newnode_given_old[s] for s in sources_]
newinterface_node_given_old =
newoperation_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newreport_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newmach_given_old = IdDict{Machine,Machine}()

Expand All @@ -512,14 +506,27 @@ function Base.replace(mach::Machine{<:Surrogate},
end
newnode_given_old[N] = N.operation(m, args...)
end
if N in interface_nodes
newinterface_node_given_old[N] = newnode_given_old[N]
if N in operation_nodes
newoperation_node_given_old[N] = newnode_given_old[N]
elseif N in report_nodes
newreport_node_given_old[N] = newnode_given_old[N]
end
end

newinterface_nodes = Tuple(newinterface_node_given_old[N] for N in
interface_nodes)
newsignature = NamedTuple{keys(signature)}(newinterface_nodes)
newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in
operation_nodes)
newreport_nodes = Tuple(newreport_node_given_old[N] for N in
report_nodes)
report_tuple =
NamedTuple{keys(_report_part(signature))}(newreport_nodes)
operation_tuple =
NamedTuple{keys(_operation_part(signature))}(newoperation_nodes)

newsignature = if isempty(report_tuple)
operation_tuple
else
merge(operation_tuple, (report=report_tuple,))
end

return machine(mach.model, newsources...; newsignature...)

Expand Down
2 changes: 1 addition & 1 deletion src/composition/models/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function update(model::M,
network_model_names = cache.network_model_names
old_model = cache.old_model

glb_node = glb(values(signature(fitresult))...) # greatest lower bound
glb_node = glb(fitresult) # greatest lower bound

if fallback(model, old_model, network_model_names, glb_node)
return fit(model, verbosity, args...)
Expand Down
27 changes: 20 additions & 7 deletions test/composition/learning_networks/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@ MLJBase.predict(model::DummyClusterer, fitresult, Xnew) =
N = 20
X = (a = rand(N), b = categorical(rand("FM", N)))

@testset "call_report_nodes" begin
@test MLJBase.call_report_nodes((predict=source(:yhat), )) == NamedTuple()
s = (transform=source(:transform),
report=(a=source(:a), b=source(:b)),
predict=source(:yhat))
R = MLJBase.call_report_nodes(s)
# Exercise the private signature helpers on a toy signature built from
# source nodes.
@testset "signature helpers" begin
    @test MLJBase._call(NamedTuple()) == NamedTuple()
    a = source(:a)
    b = source(:b)
    W = source(:W)
    yhat = source(:yhat)
    sig = (transform=W,
           report=(a=a, b=b),
           predict=yhat)
    @test MLJBase._report_part(sig) == (a=a, b=b)
    @test MLJBase._operation_part(sig) == (transform=W, predict=yhat)
    @test MLJBase._nodes(sig) == (W, yhat, a, b)
    @test MLJBase._operations(sig) == (:transform, :predict)
    result = MLJBase._call(MLJBase._report_part(sig))
    @test result.a == :a
    @test result.b == :b
end
Expand Down Expand Up @@ -144,6 +152,7 @@ oakM = machine(oak, W, u)
uhat = 0.5*(predict(knnM, W) + predict(oakM, W))
zhat = inverse_transform(standM, uhat)
yhat = exp(zhat)
enode = @node mae(ys, yhat)

@testset "replace method for learning network machines" begin

Expand All @@ -160,7 +169,9 @@ yhat = exp(zhat)
knn2 = deepcopy(knn)

# duplicate a learning network machine:
mach = machine(Deterministic(), Xs, ys; predict=yhat)
mach = machine(Deterministic(), Xs, ys;
predict=yhat,
report=(mae=enode,))
mach2 = replace(mach, hot=>hot2, knn=>knn2,
ys=>source(ys.data);
empty_unspecified_sources=true)
Expand All @@ -173,6 +184,8 @@ yhat = exp(zhat)
fit!(mach, verbosity=0)
fit!(mach2, verbosity=0)
@test predict(mach, X) predict(mach2, X)
@test MLJBase.report_additions(mach.fitresult).mae
MLJBase.report_additions(mach2.fitresult).mae

@test mach2.args[1]() == Xs()
@test mach2.args[2]() == ys()
Expand Down
5 changes: 4 additions & 1 deletion test/composition/models/from_network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,10 @@ m = machine(clust, W)
yhat = predict(m, W)
Wout = transform(m, W)
foo = first(yhat)
mach = machine(Unsupervised(), Xs; predict=yhat, transform=Wout, foo=foo)
mach = machine(Unsupervised(), Xs;
predict=yhat,
transform=Wout,
report=(foo=foo,))

@from_network mach begin
mutable struct WrappedClusterer
Expand Down
6 changes: 3 additions & 3 deletions test/composition/models/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ WrappedDummyClusterer(; model=DummyClusterer()) =
Xs;
predict=yhat,
transform=Wout,
foo=foo)
report=(foo=foo,))
return!(mach, model, verbosity)
end
X, _ = make_regression(10, 5);
Expand Down Expand Up @@ -310,8 +310,8 @@ function MLJBase.fit(m::TwoStages, verbosity, X, y)
Xs,
ys;
predict=ypred3,
μpred=μpred,
σpred=σpred)
report=(μpred=μpred,
σpred=σpred))
return!(mach, m, verbosity)
end

Expand Down

0 comments on commit 6b0f2cf

Please sign in to comment.