From eff36185cebd5c81feba9e3d1447ab7b38aa3709 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 08:41:11 -0700 Subject: [PATCH] fix: missing module --- src/helpers/training.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/helpers/training.jl b/src/helpers/training.jl index ebc19a31ba..670a8a870c 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -4,7 +4,8 @@ using ADTypes: ADTypes, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure -using LuxCore: AbstractExplicitLayer +using ..Lux: Lux +using LuxCore: LuxCore, AbstractExplicitLayer using LuxDeviceUtils: AbstractLuxDevice, gpu_device using Optimisers: Optimisers using Random: AbstractRNG @@ -43,10 +44,11 @@ Internal fields: end """ - TrainState(rng::Random.AbstractRNG, model::Lux.AbstractExplicitLayer, + TrainState(rng::Random.AbstractRNG, model::LuxCore.AbstractExplicitLayer, optimizer::Optimisers.AbstractRule; transform_variables::Union{Function, AbstractLuxDevice}=gpu_device()) - TrainState(model::Lux.AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) + TrainState(model::LuxCore.AbstractExplicitLayer, ps, st, + optimizer::Optimisers.AbstractRule) Constructor for [`TrainState`](@ref). @@ -74,7 +76,7 @@ function TrainState( `TrainState(model::AbstractExplicitLayer, ps, st, \ optimizer::Optimisers.AbstractRule)`", :TrainState) - ps, st = Lux.setup(rng, model) .|> transform_variables + ps, st = LuxCore.setup(rng, model) .|> transform_variables return TrainState(model, ps, st, optimizer) end @@ -94,8 +96,8 @@ end function Base.show(io::IO, ::MIME"text/plain", ts::TrainState) println(io, "TrainState") println(io, " model: ", ts.model) - println(io, " # of parameters: ", Lux.parameterlength(ts.parameters)) - println(io, " # of states: ", Lux.statelength(ts.states)) + println(io, " # of parameters: ", LuxCore.parameterlength(ts.parameters)) + println(io, " # of states: ", LuxCore.statelength(ts.states)) println(io, " optimizer: ", ts.optimizer) print(io, " step: ", ts.step) if ts.cache !== nothing