From 6e96f8903693d24ac5c8e0b55eddf9ee43f36c82 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 00:09:06 -0700 Subject: [PATCH 1/7] refactor: move training functions into `Training` module --- Project.toml | 2 +- src/Lux.jl | 8 ++++++-- src/contrib/contrib.jl | 29 +++------------------------- src/contrib/deprecated.jl | 13 ++++++------- src/{contrib => helpers}/training.jl | 28 +++++++++++++++++++++++---- 5 files changed, 40 insertions(+), 40 deletions(-) rename src/{contrib => helpers}/training.jl (93%) diff --git a/Project.toml b/Project.toml index cfd57f053..0114aac93 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.62" +version = "0.5.63-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/Lux.jl b/src/Lux.jl index 6677cf9cf..26d565753 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -28,8 +28,8 @@ using UnrolledUtilities: unrolled_map, unrolled_mapreduce @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, - initialstates, parameterlength, statelength, inputsize, outputsize, - update_state, trainmode, testmode, setup, apply, display_name, replicate + initialstates, parameterlength, statelength, outputsize, apply, + display_name, replicate const CRC = ChainRulesCore @@ -44,6 +44,9 @@ include("preferences.jl") include("custom_errors.jl") include("utils.jl") +# Training Helpers +include("helpers/training.jl") + # Experimental include("contrib/contrib.jl") @@ -106,6 +109,7 @@ export SamePad, TimeLastIndex, BatchLastIndex export StatefulLuxLayer export CompactLuxLayer export @compact, @init_fn, @non_trainable +export Training export jacobian_vector_product, vector_jacobian_product export batched_jacobian diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index 15dd99c46..b55bb6254 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -2,7 +2,6 @@ module Experimental using ..Lux: Lux using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, apply -using LuxDeviceUtils: AbstractLuxDevice, gpu_device using ADTypes: ADTypes using ArgCheck: @argcheck @@ -15,46 +14,24 @@ using Markdown: @doc_str using Optimisers: Optimisers using Random: AbstractRNG, Random using Setfield: Setfield +using ..Training: Training const CRC = ChainRulesCore include("map.jl") -include("training.jl") include("freeze.jl") include("share_parameters.jl") include("debug.jl") include("deprecated.jl") @compat public layer_map, @layer_map -@compat(public, - (TrainState, apply_gradients, apply_gradients!, - compute_gradients, single_train_step, single_train_step!)) @compat public FrozenLayer, freeze, unfreeze @compat public share_parameters @compat public DebugLayer, @debug_mode end -# Deprecations for v0.6 -module Training - -using ADTypes: AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote -using ..Experimental: Experimental - -for f in (:TrainState, :apply_gradients, :compute_gradients) - msg = lazy"`Lux.Training.$(f)` has been deprecated in favor of `Lux.Experimental.$(f)`" - @eval begin - function $(f)(args...; kwargs...) - Base.depwarn($(msg), Symbol($(f))) - return Experimental.$(f)(args...; kwargs...) - end - end -end - -export AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote - -end - +# Deprecations for v1.0 macro layer_map(f, l, ps, st) Base.depwarn( "`Lux.@layer_map` has been deprecated in favor of `Lux.Experimental.@layer_map`", @@ -65,7 +42,7 @@ macro layer_map(f, l, ps, st) end for f in (:layer_map, :share_parameters, :FrozenLayer, :freeze, :unfreeze) - msg = lazy"`Lux.$(f)` has been deprecated in favor of `Lux.Experimental.$(f)`" + msg = "`Lux.$(f)` has been deprecated in favor of `Lux.Experimental.$(f)`" @eval begin $(f)(args...; kwargs...) = begin Base.depwarn($(msg), Symbol($(f))) diff --git a/src/contrib/deprecated.jl b/src/contrib/deprecated.jl index 6496b6ddd..4acc15a00 100644 --- a/src/contrib/deprecated.jl +++ b/src/contrib/deprecated.jl @@ -7,11 +7,10 @@ macro compact(exs...) return Lux.__compact_macro_impl(exs...) end -function StatefulLuxLayer(args...; kwargs...) - Base.depwarn( - "Lux.Experimental.StatefulLuxLayer` has been promoted out of `Lux.Experimental` \ - and is now available in `Lux`. In other words this has been deprecated and will \ - be removed in v0.6. Use `Lux.StatefulLuxLayer` instead.", - :StatefulLuxLayer) - return Lux.StatefulLuxLayer(args...; kwargs...) +Base.@deprecate StatefulLuxLayer(args...; kwargs...) Lux.StatefulLuxLayer( + args...; kwargs...) false + +for f in (:TrainState, :TrainingBackendCache, :single_train_step, :single_train_step!, + :apply_gradients, :apply_gradients!, :compute_gradients) + @eval Base.@deprecate $f(args...; kwargs...) Training.$f(args...; kwargs...) false end diff --git a/src/contrib/training.jl b/src/helpers/training.jl similarity index 93% rename from src/contrib/training.jl rename to src/helpers/training.jl index 720100dc9..ebc19a31b 100644 --- a/src/contrib/training.jl +++ b/src/helpers/training.jl @@ -1,3 +1,14 @@ +module Training + +using ADTypes: ADTypes, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote +using Compat: @compat +using ConcreteStructs: @concrete +using FastClosures: @closure +using LuxCore: AbstractExplicitLayer +using LuxDeviceUtils: AbstractLuxDevice, gpu_device +using Optimisers: Optimisers +using Random: AbstractRNG + """ TrainState @@ -20,9 +31,9 @@ Internal fields: Constructing this object directly shouldn't be considered a stable API. Use the version with the Optimisers API. """ -@concrete struct TrainState{C, F} - cache::C - objective_function::F +@concrete struct TrainState + cache + objective_function model parameters states @@ -188,7 +199,8 @@ function compute_gradients(ad::ADTypes.AbstractADType, ::F, _, ::TrainState) whe end function __maybe_implemented_compute_gradients(::T) where {T <: ADTypes.AbstractADType} - throw(ArgumentError(lazy"Support for AD backend $(nameof(T)) has not been implemented yet!!!")) + throw(ArgumentError("Support for AD backend $(nameof(T)) has not been implemented \ + yet!")) end for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme) @@ -301,3 +313,11 @@ function Optimisers.adjust(ts::TrainState; kwargs...) return TrainState(ts.cache, ts.objective_function, ts.model, ts.parameters, ts.states, optimizer, st_opt, ts.step) end + +@compat(public, + (TrainState, apply_gradients, apply_gradients!, + compute_gradients, single_train_step, single_train_step!)) + +export AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote + +end From 27cb993b1617459f1221f60af8148702099c9718 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 00:37:47 -0700 Subject: [PATCH 2/7] refactor: remove Experimental usage in extensions --- ext/LuxEnzymeExt/LuxEnzymeExt.jl | 2 +- ext/LuxEnzymeExt/training.jl | 12 ++++++------ ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 2 +- ext/LuxReverseDiffExt/training.jl | 4 ++-- ext/LuxTrackerExt.jl | 8 ++++---- ext/LuxZygoteExt/training.jl | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 36bb067e1..bf1c41c13 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -4,7 +4,7 @@ using ADTypes: AutoEnzyme using Enzyme: Enzyme, Active, Const, Duplicated using EnzymeCore: EnzymeCore using Lux: Lux -using Lux.Experimental: TrainingBackendCache, TrainState +using Lux.Training: TrainingBackendCache, TrainState include("training.jl") diff --git a/ext/LuxEnzymeExt/training.jl b/ext/LuxEnzymeExt/training.jl index 235628c50..01ce7be31 100644 --- a/ext/LuxEnzymeExt/training.jl +++ b/ext/LuxEnzymeExt/training.jl @@ -1,8 +1,8 @@ -function Lux.Experimental.compute_gradients( +function Lux.Training.compute_gradients( ::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F} dps = Lux.recursive_make_zero(ts.parameters) - obj_fn_wrap, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( + obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.__wrap_objective_function( obj_fn, ts.model, ts.parameters, ts.states, data, Val(true)) _, loss = Enzyme.autodiff( @@ -20,7 +20,7 @@ end const AUTODIFF_CACHE_TYPE = TrainingBackendCache{ :Enzyme, false, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS} -function Lux.Experimental.compute_gradients( +function Lux.Training.compute_gradients( ::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F} dps = Lux.recursive_make_zero!!(ts.cache.dparameters) @@ -35,7 +35,7 @@ function Lux.Experimental.compute_gradients( return dps, loss, ts.cache.extras.stats_wrap[], ts_new end -function Lux.Experimental.compute_gradients(ad::AutoEnzyme, obj_fn::F, data, +function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Enzyme, false}}) where {F} @warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \ function that is changing across function calls. This can lead to the \ @@ -51,13 +51,13 @@ function Lux.Experimental.compute_gradients(ad::AutoEnzyme, obj_fn::F, data, ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new) + return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new) end const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{ :Enzyme, false, PS, <:NamedTuple{(:forward, :reverse)}} where {PS} -function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data, +function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F} dps = Lux.recursive_make_zero!!(ts.cache.dparameters) params = Duplicated(ts.parameters, dps) diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 9669a013b..35eecb733 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -4,7 +4,7 @@ using ADTypes: ADTypes, AutoReverseDiff using ArrayInterface: ArrayInterface using FunctionWrappers: FunctionWrapper using Lux: Lux, LuxCPUDevice -using Lux.Experimental: TrainingBackendCache, TrainState +using Lux.Training: TrainingBackendCache, TrainState using LuxCore: LuxCore, AbstractExplicitLayer using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index baa871931..a2be5a58e 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -1,6 +1,6 @@ @static if pkgversion(ADTypes) < v"1.5" # older versions did not have `compile` type parameter. Use slower type-unstable code - function Lux.Experimental.compute_gradients( + function Lux.Training.compute_gradients( ad::AutoReverseDiff, obj_fn::F, data, ts::TrainState) where {F} ad.compile && return __compiled_reverse_diff(obj_fn, data, ts) return __uncompiled_reverse_diff(obj_fn, data, ts) @@ -8,7 +8,7 @@ else for compiled in (false, true) fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff - @eval function Lux.Experimental.compute_gradients( + @eval function Lux.Training.compute_gradients( ::AutoReverseDiff{$(compiled)}, obj_fn::F, data, ts::TrainState) where {F} return $(fname)(obj_fn, data, ts) end diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index c65ed3644..8950b98ef 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -4,7 +4,7 @@ using ADTypes: AutoTracker using ArrayInterface: ArrayInterface using ChainRulesCore: ChainRulesCore using Lux: Lux, LuxCPUDevice -using Lux.Experimental: TrainingBackendCache, TrainState +using Lux.Training: TrainingBackendCache, TrainState using LuxCore: LuxCore, AbstractExplicitLayer using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules @@ -23,7 +23,7 @@ function __construct_tracked_params(ps, dps) end # Lux.Training -function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data, +function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) ps_tracked = __construct_tracked_params(ts.parameters, dparams) @@ -38,13 +38,13 @@ function Lux.Experimental.compute_gradients(::AutoTracker, obj_fn::F, data, return dparams, loss.data, stats, ts_new end -function Lux.Experimental.compute_gradients( +function Lux.Training.compute_gradients( ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) ts_new = TrainState( TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return Lux.Experimental.compute_gradients(AutoTracker(), obj_fn, data, ts_new) + return Lux.Training.compute_gradients(AutoTracker(), obj_fn, data, ts_new) end # AoS to SoA conversion diff --git a/ext/LuxZygoteExt/training.jl b/ext/LuxZygoteExt/training.jl index 35c115572..3832800cd 100644 --- a/ext/LuxZygoteExt/training.jl +++ b/ext/LuxZygoteExt/training.jl @@ -1,5 +1,5 @@ -function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F, data, - ts::Lux.Experimental.TrainState) where {F} +function Lux.Training.compute_gradients( + ::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F} (loss, st, stats), back = Zygote.pullback( objective_function, ts.model, ts.parameters, ts.states, data) grads = back((one(loss), nothing, nothing))[2] From 04dcd244d14e8a68934d5cb84ca3183199343cbc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 00:42:20 -0700 Subject: [PATCH 3/7] fix: drop very old ADTypes --- Project.toml | 2 +- bench/Project.toml | 2 -- bench/runbenchmarks.jl | 4 +--- ext/LuxReverseDiffExt/training.jl | 33 +++++++++---------------------- 4 files changed, 11 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index 0114aac93..52f436bbb 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ LuxTrackerExt = "Tracker" LuxZygoteExt = "Zygote" [compat] -ADTypes = "0.2, 1" +ADTypes = "1.5" Adapt = "4" Aqua = "0.8.4" ArgCheck = "2.1" diff --git a/bench/Project.toml b/bench/Project.toml index d930e322d..669a0cad2 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -19,5 +19,3 @@ Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -[compat] -ADTypes = "0.2" diff --git a/bench/runbenchmarks.jl b/bench/runbenchmarks.jl index 0a0af462f..ff0c1bfdb 100644 --- a/bench/runbenchmarks.jl +++ b/bench/runbenchmarks.jl @@ -1,4 +1,4 @@ -using ADTypes: ADTypes, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoZygote +using ADTypes: ADTypes, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoTapir, AutoZygote using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @btime, @benchmarkable using ComponentArrays: ComponentArray using InteractiveUtils: versioninfo @@ -19,8 +19,6 @@ using Tapir: Tapir using Tracker: Tracker using Zygote: Zygote -struct AutoTapir <: ADTypes.AbstractReverseMode end - BLAS.set_num_threads(min(4, Threads.nthreads())) @info sprint(versioninfo) diff --git a/ext/LuxReverseDiffExt/training.jl b/ext/LuxReverseDiffExt/training.jl index a2be5a58e..024424901 100644 --- a/ext/LuxReverseDiffExt/training.jl +++ b/ext/LuxReverseDiffExt/training.jl @@ -1,30 +1,14 @@ -@static if pkgversion(ADTypes) < v"1.5" - # older versions did not have `compile` type parameter. Use slower type-unstable code - function Lux.Training.compute_gradients( - ad::AutoReverseDiff, obj_fn::F, data, ts::TrainState) where {F} - ad.compile && return __compiled_reverse_diff(obj_fn, data, ts) - return __uncompiled_reverse_diff(obj_fn, data, ts) - end -else - for compiled in (false, true) - fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff - @eval function Lux.Training.compute_gradients( - ::AutoReverseDiff{$(compiled)}, obj_fn::F, data, ts::TrainState) where {F} - return $(fname)(obj_fn, data, ts) - end - end -end - # Uncompiled ReverseDiff -@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} +@inline function Lux.Training.compute_gradients( + ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) ts_new = TrainState( TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return __uncompiled_reverse_diff(obj_fn, data, ts_new) + return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new) end -@inline function __uncompiled_reverse_diff(obj_fn::F, data, +@inline function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) tape = ReverseDiff.InstructionTape() @@ -42,7 +26,8 @@ end end # Compiled ReverseDiff -@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F} +@inline function Lux.Training.compute_gradients( + ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F} grads = Lux.recursive_make_zero(ts.parameters) data_cache = deepcopy(data) ps_cache = deepcopy(ts.parameters) @@ -51,11 +36,11 @@ end ts_new = TrainState( TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model, ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return __compiled_reverse_diff(obj_fn, data, ts_new) + return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new) end ## Tape hasn't been compiled yet / Function mismatch so recompile -@inline function __compiled_reverse_diff(obj_fn::F, data, +@inline function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT} if LuxCore.statelength(ts.states) != 0 throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \ @@ -106,7 +91,7 @@ end return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new end -@inline function __compiled_reverse_diff(obj_fn::F, data, +@inline function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F} (; ps_cache, data_cache, output) = ts.cache.extras From 090fbbf3505bede12c9705990c9508222f0f9272 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 00:44:37 -0700 Subject: [PATCH 4/7] test: move the Training tests into a different group --- test/{contrib => helpers}/training_tests.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) rename test/{contrib => helpers}/training_tests.jl (97%) diff --git a/test/contrib/training_tests.jl b/test/helpers/training_tests.jl similarity index 97% rename from test/contrib/training_tests.jl rename to test/helpers/training_tests.jl index 4ab145529..51801cf3d 100644 --- a/test/contrib/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -1,4 +1,4 @@ -@testitem "TrainState" setup=[SharedTestSetup] tags=[:contrib] begin +@testitem "TrainState" setup=[SharedTestSetup] tags=[:helpers] begin using Optimisers rng = StableRNG(12345) @@ -22,7 +22,7 @@ end end -@testitem "AbstractADTypes" setup=[SharedTestSetup] tags=[:contrib] begin +@testitem "AbstractADTypes" setup=[SharedTestSetup] tags=[:helpers] begin using ADTypes, Optimisers, Enzyme function _loss_function(model, ps, st, data) @@ -53,7 +53,7 @@ end end end -@testitem "Training API" setup=[SharedTestSetup] tags=[:contrib] begin +@testitem "Training API" setup=[SharedTestSetup] tags=[:helpers] begin using ADTypes, Optimisers import Enzyme, Tracker, ReverseDiff, Zygote @@ -136,7 +136,7 @@ end end end -@testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:contrib] begin +@testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:helpers] begin using ADTypes, Optimisers using Enzyme @@ -179,7 +179,7 @@ end @test hasfield(typeof(tstate_new2.cache.extras), :reverse) end -@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:contrib] begin +@testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:helpers] begin using ADTypes, Optimisers, ReverseDiff mse1 = MSELoss() From 1243817df6a005e57dd6c182a522329bb071907a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 00:48:37 -0700 Subject: [PATCH 5/7] docs: update docs location for Training --- docs/src/api/Lux/contrib.md | 16 ---------------- docs/src/api/Lux/utilities.md | 18 +++++++++++++++++- examples/Basics/main.jl | 4 ++-- examples/ConvMixer/main.jl | 4 ++-- examples/DDIM/main.jl | 4 ++-- examples/HyperNet/main.jl | 5 ++--- examples/NeuralODE/main.jl | 4 ++-- examples/PolynomialFitting/main.jl | 11 +++++------ examples/SimpleChains/main.jl | 6 +++--- examples/SimpleRNN/main.jl | 5 ++--- 10 files changed, 37 insertions(+), 40 deletions(-) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index a5a4c3f15..a5143ae01 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -20,22 +20,6 @@ All features listed on this page are **experimental** which means: Pages = ["contrib.md"] ``` -## [Training](@id Training-API) - -Helper Functions making it easier to train `Lux.jl` models. - -Lux.Training is meant to be simple and provide extremely basic functionality. We provide -basic building blocks which can be seamlessly composed to create complex training pipelines. - -```@docs -Lux.Experimental.TrainState -Lux.Experimental.compute_gradients -Lux.Experimental.apply_gradients -Lux.Experimental.apply_gradients! -Lux.Experimental.single_train_step -Lux.Experimental.single_train_step! -``` - ## Parameter Freezing !!! info diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 67151939f..53faad239 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -6,6 +6,22 @@ Pages = ["utilities.md"] ``` +## [Training API](@id Training-API) + +Helper Functions making it easier to train `Lux.jl` models. + +Training is meant to be simple and provide extremely basic functionality. We provide +basic building blocks which can be seamlessly composed to create complex training pipelines. + +```@docs +Training.TrainState +Training.compute_gradients +Training.apply_gradients +Training.apply_gradients! +Training.single_train_step +Training.single_train_step! +``` + ## Loss Functions Loss Functions Objects take 2 forms of inputs: @@ -14,7 +30,7 @@ Loss Functions Objects take 2 forms of inputs: 2. `model`, `ps`, `st`, `(x, y)` where `model` is the model, `ps` are the parameters, `st` are the states and `(x, y)` are the input and target pair. Then it returns the loss, updated states, and an empty named tuple. This makes them compatible with the - [Experimental Training API](@ref Training-API). + [Training API](@ref Training-API). !!! warning diff --git a/examples/Basics/main.jl b/examples/Basics/main.jl index 15f03f741..ec2696365 100644 --- a/examples/Basics/main.jl +++ b/examples/Basics/main.jl @@ -303,9 +303,9 @@ println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y # We will train the model using our training API. function train_model!(model, ps, st, opt, nepochs::Int) - tstate = Lux.Experimental.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) for i in 1:nepochs - grads, loss, _, tstate = Lux.Experimental.single_train_step!( + grads, loss, _, tstate = Training.single_train_step!( AutoZygote(), lossfn, (x_samples, y_samples), tstate) if i % 1000 == 1 || i == nepochs @printf "Loss Value after %6d iterations: %.8f\n" i loss diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index a5d5a532c..372259d81 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -81,7 +81,7 @@ end opt = AdamW(; eta=lr_max, lambda=weight_decay) clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) - train_state = Lux.Experimental.TrainState( + train_state = Training.TrainState( rng, model, AdamW(; eta=lr_max, lambda=weight_decay); transform_variables=gdev) lr_schedule = linear_interpolation( @@ -97,7 +97,7 @@ end train_state = Optimisers.adjust!(train_state, lr) x = x |> gdev y = y |> gdev - (_, _, _, train_state) = Lux.Experimental.single_train_step!( + (_, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, (x, y), train_state) end ttime = time() - stime diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index 448bceaac..093438199 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -354,7 +354,7 @@ end @info "Logging Tensorboard logs to $(tb_dir). Run tensorboard with `tensorboard --logdir $(dirname(tb_dir))`" tb_logger = TBLogger(tb_dir) - tstate = Lux.Experimental.TrainState( + tstate = Training.TrainState( rng, model, AdamW(; eta=learning_rate_start, lambda=weight_decay); transform_variables=gdev) @@ -378,7 +378,7 @@ end for (i, data) in enumerate(data_loader) step += 1 data = data |> gdev - (_, _, stats, tstate) = Lux.Experimental.single_train_step!( + (_, _, stats, tstate) = Training.single_train_step!( AutoZygote(), loss_function, data, tstate) image_losses[i] = stats.image_loss noise_losses[i] = stats.noise_loss diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 3c6cdbd81..9045091c7 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -83,8 +83,7 @@ function train() rng = Xoshiro(0) - train_state = Lux.Experimental.TrainState( - rng, model, Adam(3.0f-4); transform_variables=dev) + train_state = Training.TrainState(rng, model, Adam(3.0f-4); transform_variables=dev) ### Lets train the model nepochs = 10 @@ -95,7 +94,7 @@ function train() for (x, y) in train_dataloader x = x |> dev y = y |> dev - (_, _, _, train_state) = Lux.Experimental.single_train_step!( + (_, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, ((data_idx, x), y), train_state) end ttime = time() - stime diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 039654491..55e0ae944 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -128,7 +128,7 @@ function train(model_function; cpu::Bool=false, kwargs...) ## Training train_dataloader, test_dataloader = loadmnist(128, 0.9) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) ### Lets train the model nepochs = 9 @@ -137,7 +137,7 @@ function train(model_function; cpu::Bool=false, kwargs...) for (x, y) in train_dataloader x = dev(x) y = dev(y) - _, _, _, tstate = Lux.Experimental.single_train_step!( + _, _, _, tstate = Training.single_train_step!( AutoZygote(), logitcrossentropy, (x, y), tstate) end ttime = time() - stime diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index a3650f185..2db9b6591 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -49,7 +49,7 @@ opt = Adam(0.03f0) # ## Loss Function -# We will use the `Lux.Training` API so we need to ensure that our loss function takes 4 +# We will use the `Training` API so we need to ensure that our loss function takes 4 # inputs -- model, parameters, states and data. The function must return 3 values -- loss, # updated_state, and any computed statistics. This is already satisfied by the loss # functions provided by Lux. @@ -57,10 +57,10 @@ const loss_function = MSELoss() # ## Training -# First we will create a [`Lux.Experimental.TrainState`](@ref) which is essentially a +# First we will create a [`Training.TrainState`](@ref) which is essentially a # convenience wrapper over parameters, states and optimizer states. -tstate = Lux.Experimental.TrainState(rng, model, opt) +tstate = Training.TrainState(rng, model, opt) # Now we will use Zygote for our AD requirements. @@ -68,11 +68,10 @@ vjp_rule = AutoZygote() # Finally the training loop. -function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs) +function main(tstate::Training.TrainState, vjp, data, epochs) data = data .|> gpu_device() for epoch in 1:epochs - _, loss, _, tstate = Lux.Experimental.single_train_step!( - vjp, loss_function, data, tstate) + _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate) if epoch % 50 == 1 || epoch == epochs @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss end diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 8bde951cc..480865cd2 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -62,20 +62,20 @@ end function train(model; rng=Xoshiro(0), kwargs...) train_dataloader, test_dataloader = loadmnist(128, 0.9) - train_state = Lux.Experimental.TrainState( + train_state = Training.TrainState( rng, model, Adam(3.0f-4); transform_variables=identity) ### Warmup the model x_proto = randn(rng, Float32, 28, 28, 1, 1) y_proto = onehotbatch([1], 0:9) - Lux.Experimental.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state) + Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state) ### Lets train the model nepochs = 10 for epoch in 1:nepochs stime = time() for (x, y) in train_dataloader - (gs, _, _, train_state) = Lux.Experimental.single_train_step!( + (gs, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, (x, y), train_state) end ttime = time() - stime diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index ffa3982ee..0cd861f1b 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -138,8 +138,7 @@ function main(model_type) rng = Xoshiro(0) dev = gpu_device() - train_state = Lux.Experimental.TrainState( - rng, model, Adam(0.01f0); transform_variables=dev) + train_state = Training.TrainState(rng, model, Adam(0.01f0); transform_variables=dev) for epoch in 1:25 ## Train the model @@ -147,7 +146,7 @@ function main(model_type) x = x |> dev y = y |> dev - (_, loss, _, train_state) = Lux.Experimental.single_train_step!( + (_, loss, _, train_state) = Training.single_train_step!( AutoZygote(), lossfn, (x, y), train_state) @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss From a01768cd2ee2c660f358b5ec810f042b76723c88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 00:50:23 -0700 Subject: [PATCH 6/7] revert: re-add unused LuxCore imports --- src/Lux.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Lux.jl b/src/Lux.jl index 26d565753..73cddcbc9 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -28,8 +28,8 @@ using UnrolledUtilities: unrolled_map, unrolled_mapreduce @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, - initialstates, parameterlength, statelength, outputsize, apply, - display_name, replicate + initialstates, parameterlength, statelength, inputsize, outputsize, + update_state, trainmode, testmode, setup, apply, display_name, replicate const CRC = ChainRulesCore From 9e262df4a4108fc4d058af9249e3ef727ceef552 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 08:41:11 -0700 Subject: [PATCH 7/7] fix: missing module --- src/helpers/training.jl | 14 ++++++++------ test/helpers/training_tests.jl | 10 +++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/helpers/training.jl b/src/helpers/training.jl index ebc19a31b..670a8a870 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -4,7 +4,8 @@ using ADTypes: ADTypes, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure -using LuxCore: AbstractExplicitLayer +using ..Lux: Lux +using LuxCore: LuxCore, AbstractExplicitLayer using LuxDeviceUtils: AbstractLuxDevice, gpu_device using Optimisers: Optimisers using Random: AbstractRNG @@ -43,10 +44,11 @@ Internal fields: end """ - TrainState(rng::Random.AbstractRNG, model::Lux.AbstractExplicitLayer, + TrainState(rng::Random.AbstractRNG, model::LuxCore.AbstractExplicitLayer, optimizer::Optimisers.AbstractRule; transform_variables::Union{Function, AbstractLuxDevice}=gpu_device()) - TrainState(model::Lux.AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) + TrainState(model::LuxCore.AbstractExplicitLayer, ps, st, + optimizer::Optimisers.AbstractRule) Constructor for [`TrainState`](@ref). @@ -74,7 +76,7 @@ function TrainState( `TrainState(model::AbstractExplicitLayer, ps, st, \ optimizer::Optimisers.AbstractRule)`", :TrainState) - ps, st = Lux.setup(rng, model) .|> transform_variables + ps, st = LuxCore.setup(rng, model) .|> transform_variables return TrainState(model, ps, st, optimizer) end @@ -94,8 +96,8 @@ end function Base.show(io::IO, ::MIME"text/plain", ts::TrainState) println(io, "TrainState") println(io, " model: ", ts.model) - println(io, " # of parameters: ", Lux.parameterlength(ts.parameters)) - println(io, " # of states: ", Lux.statelength(ts.states)) + println(io, " # of parameters: ", LuxCore.parameterlength(ts.parameters)) + println(io, " # of states: ", LuxCore.statelength(ts.states)) println(io, " optimizer: ", ts.optimizer) print(io, " step: ", ts.step) if ts.cache !== nothing diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 51801cf3d..a61d7888f 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -76,12 +76,12 @@ end AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme()) ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue - @test_throws ArgumentError Lux.Experimental.__maybe_implemented_compute_gradients(ad) + @test_throws ArgumentError Lux.Training.__maybe_implemented_compute_gradients(ad) - @test_deprecated Lux.Training.TrainState( + @test_deprecated Lux.Experimental.TrainState( Lux.replicate(rng), model, opt; transform_variables=dev) - tstate = Lux.Experimental.TrainState( + tstate = Lux.Training.TrainState( Lux.replicate(rng), model, opt; transform_variables=dev) initial_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) @@ -93,10 +93,10 @@ end end (x, y) = first(dataset_) - @test_deprecated Lux.Training.compute_gradients(ad, mse, (x, y), tstate) + @test_deprecated Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) grads, loss, _, tstate = Lux.Experimental.compute_gradients( ad, mse, (x, y), tstate) - @test_deprecated Lux.Training.apply_gradients(tstate, grads) + @test_deprecated Lux.Experimental.apply_gradients(tstate, grads) for epoch in 1:100, (x, y) in dataset_ grads, loss, _, tstate = Lux.Experimental.single_train_step!(