feat: improvements to the Training API #794

Merged 7 commits on Jul 26, 2024
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.62"
version = "0.5.63-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ LuxTrackerExt = "Tracker"
LuxZygoteExt = "Zygote"

[compat]
ADTypes = "0.2, 1"
ADTypes = "1.5"
Adapt = "4"
Aqua = "0.8.4"
ArgCheck = "2.1"
2 changes: 0 additions & 2 deletions bench/Project.toml
@@ -19,5 +19,3 @@ Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2"
4 changes: 1 addition & 3 deletions bench/runbenchmarks.jl
@@ -1,4 +1,4 @@
using ADTypes: ADTypes, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoZygote
using ADTypes: ADTypes, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoTapir, AutoZygote
using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @btime, @benchmarkable
using ComponentArrays: ComponentArray
using InteractiveUtils: versioninfo
@@ -19,8 +19,6 @@ using Tapir: Tapir
using Tracker: Tracker
using Zygote: Zygote

struct AutoTapir <: ADTypes.AbstractReverseMode end

BLAS.set_num_threads(min(4, Threads.nthreads()))

@info sprint(versioninfo)
16 changes: 0 additions & 16 deletions docs/src/api/Lux/contrib.md
@@ -20,22 +20,6 @@ All features listed on this page are **experimental** which means:
Pages = ["contrib.md"]
```

## [Training](@id Training-API)

Helper Functions making it easier to train `Lux.jl` models.

Lux.Training is meant to be simple and provide extremely basic functionality. We provide
basic building blocks which can be seamlessly composed to create complex training pipelines.

```@docs
Lux.Experimental.TrainState
Lux.Experimental.compute_gradients
Lux.Experimental.apply_gradients
Lux.Experimental.apply_gradients!
Lux.Experimental.single_train_step
Lux.Experimental.single_train_step!
```

## Parameter Freezing

!!! info
18 changes: 17 additions & 1 deletion docs/src/api/Lux/utilities.md
@@ -6,6 +6,22 @@
Pages = ["utilities.md"]
```

## [Training API](@id Training-API)

Helper Functions making it easier to train `Lux.jl` models.

Training is meant to be simple and provide extremely basic functionality. We provide
basic building blocks which can be seamlessly composed to create complex training pipelines.

```@docs
Training.TrainState
Training.compute_gradients
Training.apply_gradients
Training.apply_gradients!
Training.single_train_step
Training.single_train_step!
```
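
As a rough illustration of how these helpers fit together, here is a minimal sketch (not part of this PR): it assumes `Training` is in scope via `using Lux`, Zygote as the AD backend, and a toy `Dense` model, with all names below being illustrative only.

```julia
using ADTypes: AutoZygote
using Lux, Optimisers, Random, Zygote

# Hypothetical toy problem: fit a single Dense layer with MSE
model = Dense(4 => 2)
ps, st = Lux.setup(Xoshiro(0), model)
x, y = rand(Float32, 4, 32), rand(Float32, 2, 32)

function fit(model, ps, st, data; nepochs=100)
    # TrainState bundles the model, parameters, states, and optimizer state
    tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
    for _ in 1:nepochs
        # single_train_step! computes gradients and applies them in one call,
        # returning (gradients, loss, stats, updated_train_state)
        _, _, _, tstate = Training.single_train_step!(
            AutoZygote(), MSELoss(), data, tstate)
    end
    return tstate
end

tstate = fit(model, ps, st, (x, y))
```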

## Loss Functions

Loss Functions Objects take 2 forms of inputs:
@@ -14,7 +30,7 @@ Loss Functions Objects take 2 forms of inputs:
2. `model`, `ps`, `st`, `(x, y)` where `model` is the model, `ps` are the parameters,
`st` are the states and `(x, y)` are the input and target pair. Then it returns the
loss, updated states, and an empty named tuple. This makes them compatible with the
[Experimental Training API](@ref Training-API).
[Training API](@ref Training-API).
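
For instance, the second call form can be used directly as in this sketch (hypothetical variables; assumes `model`, `ps`, `st`, and a data pair `(x, y)` are already set up):

```julia
# Loss objects such as MSELoss() also accept (model, ps, st, (x, y)) and return
# (loss, updated_state, stats) -- exactly the signature the Training API expects.
loss, st_new, stats = MSELoss()(model, ps, st, (x, y))
```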

!!! warning

4 changes: 2 additions & 2 deletions examples/Basics/main.jl
@@ -303,9 +303,9 @@ println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y

# We will train the model using our training API.
function train_model!(model, ps, st, opt, nepochs::Int)
tstate = Lux.Experimental.TrainState(model, ps, st, opt)
tstate = Training.TrainState(model, ps, st, opt)
for i in 1:nepochs
grads, loss, _, tstate = Lux.Experimental.single_train_step!(
grads, loss, _, tstate = Training.single_train_step!(
AutoZygote(), lossfn, (x_samples, y_samples), tstate)
if i % 1000 == 1 || i == nepochs
@printf "Loss Value after %6d iterations: %.8f\n" i loss
4 changes: 2 additions & 2 deletions examples/ConvMixer/main.jl
@@ -81,7 +81,7 @@ end
opt = AdamW(; eta=lr_max, lambda=weight_decay)
clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

train_state = Lux.Experimental.TrainState(
train_state = Training.TrainState(
rng, model, AdamW(; eta=lr_max, lambda=weight_decay); transform_variables=gdev)

lr_schedule = linear_interpolation(
@@ -97,7 +97,7 @@ end
train_state = Optimisers.adjust!(train_state, lr)
x = x |> gdev
y = y |> gdev
(_, _, _, train_state) = Lux.Experimental.single_train_step!(
(_, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, (x, y), train_state)
end
ttime = time() - stime
4 changes: 2 additions & 2 deletions examples/DDIM/main.jl
@@ -354,7 +354,7 @@ end
@info "Logging Tensorboard logs to $(tb_dir). Run tensorboard with `tensorboard --logdir $(dirname(tb_dir))`"
tb_logger = TBLogger(tb_dir)

tstate = Lux.Experimental.TrainState(
tstate = Training.TrainState(
rng, model, AdamW(; eta=learning_rate_start, lambda=weight_decay);
transform_variables=gdev)

@@ -378,7 +378,7 @@ end
for (i, data) in enumerate(data_loader)
step += 1
data = data |> gdev
(_, _, stats, tstate) = Lux.Experimental.single_train_step!(
(_, _, stats, tstate) = Training.single_train_step!(
AutoZygote(), loss_function, data, tstate)
image_losses[i] = stats.image_loss
noise_losses[i] = stats.noise_loss
5 changes: 2 additions & 3 deletions examples/HyperNet/main.jl
@@ -83,8 +83,7 @@ function train()

rng = Xoshiro(0)

train_state = Lux.Experimental.TrainState(
rng, model, Adam(3.0f-4); transform_variables=dev)
train_state = Training.TrainState(rng, model, Adam(3.0f-4); transform_variables=dev)

### Lets train the model
nepochs = 10
@@ -95,7 +94,7 @@ function train()
for (x, y) in train_dataloader
x = x |> dev
y = y |> dev
(_, _, _, train_state) = Lux.Experimental.single_train_step!(
(_, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, ((data_idx, x), y), train_state)
end
ttime = time() - stime
4 changes: 2 additions & 2 deletions examples/NeuralODE/main.jl
@@ -128,7 +128,7 @@ function train(model_function; cpu::Bool=false, kwargs...)
## Training
train_dataloader, test_dataloader = loadmnist(128, 0.9)

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0))
tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

### Lets train the model
nepochs = 9
@@ -137,7 +137,7 @@ function train(model_function; cpu::Bool=false, kwargs...)
for (x, y) in train_dataloader
x = dev(x)
y = dev(y)
_, _, _, tstate = Lux.Experimental.single_train_step!(
_, _, _, tstate = Training.single_train_step!(
AutoZygote(), logitcrossentropy, (x, y), tstate)
end
ttime = time() - stime
11 changes: 5 additions & 6 deletions examples/PolynomialFitting/main.jl
@@ -49,30 +49,29 @@ opt = Adam(0.03f0)

# ## Loss Function

# We will use the `Lux.Training` API so we need to ensure that our loss function takes 4
# We will use the `Training` API so we need to ensure that our loss function takes 4
# inputs -- model, parameters, states and data. The function must return 3 values -- loss,
# updated_state, and any computed statistics. This is already satisfied by the loss
# functions provided by Lux.
const loss_function = MSELoss()
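
# For illustration only (not needed here), a hand-written loss obeying the same
# four-input / three-output contract could look like the sketch below; the
# function name is hypothetical and `MSELoss()` already covers this case.

function custom_sse_loss(model, ps, st, (x, y))
    ŷ, st_new = model(x, ps, st)                    # Lux models return (output, state)
    return sum(abs2, ŷ .- y), st_new, NamedTuple()  # (loss, updated_state, stats)
end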

# ## Training

# First we will create a [`Lux.Experimental.TrainState`](@ref) which is essentially a
# First we will create a [`Training.TrainState`](@ref) which is essentially a
# convenience wrapper over parameters, states and optimizer states.

tstate = Lux.Experimental.TrainState(rng, model, opt)
tstate = Training.TrainState(rng, model, opt)

# Now we will use Zygote for our AD requirements.

vjp_rule = AutoZygote()

# Finally the training loop.

function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
function main(tstate::Training.TrainState, vjp, data, epochs)
data = data .|> gpu_device()
for epoch in 1:epochs
_, loss, _, tstate = Lux.Experimental.single_train_step!(
vjp, loss_function, data, tstate)
_, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
if epoch % 50 == 1 || epoch == epochs
@printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
end
6 changes: 3 additions & 3 deletions examples/SimpleChains/main.jl
@@ -62,20 +62,20 @@ end
function train(model; rng=Xoshiro(0), kwargs...)
train_dataloader, test_dataloader = loadmnist(128, 0.9)

train_state = Lux.Experimental.TrainState(
train_state = Training.TrainState(
rng, model, Adam(3.0f-4); transform_variables=identity)

### Warmup the model
x_proto = randn(rng, Float32, 28, 28, 1, 1)
y_proto = onehotbatch([1], 0:9)
Lux.Experimental.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)
Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

### Lets train the model
nepochs = 10
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
(gs, _, _, train_state) = Lux.Experimental.single_train_step!(
(gs, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, (x, y), train_state)
end
ttime = time() - stime
5 changes: 2 additions & 3 deletions examples/SimpleRNN/main.jl
@@ -138,16 +138,15 @@ function main(model_type)
rng = Xoshiro(0)

dev = gpu_device()
train_state = Lux.Experimental.TrainState(
rng, model, Adam(0.01f0); transform_variables=dev)
train_state = Training.TrainState(rng, model, Adam(0.01f0); transform_variables=dev)

for epoch in 1:25
## Train the model
for (x, y) in train_loader
x = x |> dev
y = y |> dev

(_, loss, _, train_state) = Lux.Experimental.single_train_step!(
(_, loss, _, train_state) = Training.single_train_step!(
AutoZygote(), lossfn, (x, y), train_state)

@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
2 changes: 1 addition & 1 deletion ext/LuxEnzymeExt/LuxEnzymeExt.jl
@@ -4,7 +4,7 @@ using ADTypes: AutoEnzyme
using Enzyme: Enzyme, Active, Const, Duplicated
using EnzymeCore: EnzymeCore
using Lux: Lux
using Lux.Experimental: TrainingBackendCache, TrainState
using Lux.Training: TrainingBackendCache, TrainState

include("training.jl")

12 changes: 6 additions & 6 deletions ext/LuxEnzymeExt/training.jl
@@ -1,8 +1,8 @@
function Lux.Experimental.compute_gradients(
function Lux.Training.compute_gradients(
::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)

obj_fn_wrap, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.__wrap_objective_function(
obj_fn, ts.model, ts.parameters, ts.states, data, Val(true))

_, loss = Enzyme.autodiff(
@@ -20,7 +20,7 @@ end
const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
:Enzyme, false, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}

function Lux.Experimental.compute_gradients(
function Lux.Training.compute_gradients(
::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)

@@ -35,7 +35,7 @@ function Lux.Experimental.compute_gradients(
return dps, loss, ts.cache.extras.stats_wrap[], ts_new
end

function Lux.Experimental.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, false}}) where {F}
@warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \
function that is changing across function calls. This can lead to the \
@@ -51,13 +51,13 @@ function Lux.Experimental.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)

return Lux.Experimental.compute_gradients(ad, obj_fn, data, ts_new)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
end

const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
:Enzyme, false, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}

function Lux.Experimental.compute_gradients(::AutoEnzyme, obj_fn::F, data,
function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
params = Duplicated(ts.parameters, dps)
2 changes: 1 addition & 1 deletion ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -4,7 +4,7 @@ using ADTypes: ADTypes, AutoReverseDiff
using ArrayInterface: ArrayInterface
using FunctionWrappers: FunctionWrapper
using Lux: Lux, LuxCPUDevice
using Lux.Experimental: TrainingBackendCache, TrainState
using Lux.Training: TrainingBackendCache, TrainState
using LuxCore: LuxCore, AbstractExplicitLayer
using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal,
@grad_from_chainrules
33 changes: 9 additions & 24 deletions ext/LuxReverseDiffExt/training.jl
@@ -1,30 +1,14 @@
@static if pkgversion(ADTypes) < v"1.5"
# older versions did not have `compile` type parameter. Use slower type-unstable code
function Lux.Experimental.compute_gradients(
ad::AutoReverseDiff, obj_fn::F, data, ts::TrainState) where {F}
ad.compile && return __compiled_reverse_diff(obj_fn, data, ts)
return __uncompiled_reverse_diff(obj_fn, data, ts)
end
else
for compiled in (false, true)
fname = compiled ? :__compiled_reverse_diff : :__uncompiled_reverse_diff
@eval function Lux.Experimental.compute_gradients(
::AutoReverseDiff{$(compiled)}, obj_fn::F, data, ts::TrainState) where {F}
return $(fname)(obj_fn, data, ts)
end
end
end

# Uncompiled ReverseDiff
@inline function __uncompiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F}
@inline function Lux.Training.compute_gradients(
ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return __uncompiled_reverse_diff(obj_fn, data, ts_new)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
end

@inline function __uncompiled_reverse_diff(obj_fn::F, data,
@inline function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
tape = ReverseDiff.InstructionTape()
@@ -42,7 +26,8 @@ end
end

# Compiled ReverseDiff
@inline function __compiled_reverse_diff(obj_fn::F, data, ts::TrainState) where {F}
@inline function Lux.Training.compute_gradients(
ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
data_cache = deepcopy(data)
ps_cache = deepcopy(ts.parameters)
@@ -51,11 +36,11 @@ end
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return __compiled_reverse_diff(obj_fn, data, ts_new)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
end

## Tape hasn't been compiled yet / Function mismatch so recompile
@inline function __compiled_reverse_diff(obj_fn::F, data,
@inline function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
if LuxCore.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
@@ -106,7 +91,7 @@ end
return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new
end

@inline function __compiled_reverse_diff(obj_fn::F, data,
@inline function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras
