Skip to content

Commit

Permalink
fix: missing module
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 26, 2024
1 parent a01768c commit eff3618
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/helpers/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using ADTypes: ADTypes, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote
using Compat: @compat
using ConcreteStructs: @concrete
using FastClosures: @closure
using LuxCore: AbstractExplicitLayer
using ..Lux: Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxDeviceUtils: AbstractLuxDevice, gpu_device
using Optimisers: Optimisers
using Random: AbstractRNG
Expand Down Expand Up @@ -43,10 +44,11 @@ Internal fields:
end

"""
TrainState(rng::Random.AbstractRNG, model::Lux.AbstractExplicitLayer,
TrainState(rng::Random.AbstractRNG, model::LuxCore.AbstractExplicitLayer,
optimizer::Optimisers.AbstractRule;
transform_variables::Union{Function, AbstractLuxDevice}=gpu_device())
TrainState(model::Lux.AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule)
TrainState(model::LuxCore.AbstractExplicitLayer, ps, st,
optimizer::Optimisers.AbstractRule)
Constructor for [`TrainState`](@ref).
Expand Down Expand Up @@ -74,7 +76,7 @@ function TrainState(
`TrainState(model::AbstractExplicitLayer, ps, st, \
optimizer::Optimisers.AbstractRule)`",
:TrainState)
ps, st = Lux.setup(rng, model) .|> transform_variables
ps, st = LuxCore.setup(rng, model) .|> transform_variables
return TrainState(model, ps, st, optimizer)
end

Expand All @@ -94,8 +96,8 @@ end
function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
println(io, "TrainState")
println(io, " model: ", ts.model)
println(io, " # of parameters: ", Lux.parameterlength(ts.parameters))
println(io, " # of states: ", Lux.statelength(ts.states))
println(io, " # of parameters: ", LuxCore.parameterlength(ts.parameters))
println(io, " # of states: ", LuxCore.statelength(ts.states))
println(io, " optimizer: ", ts.optimizer)
print(io, " step: ", ts.step)
if ts.cache !== nothing
Expand Down

0 comments on commit eff3618

Please sign in to comment.