docs: update docs from downstream changes (#790)
* docs: make the interface specification sound mandatory

* docs: add list of supported rngs for init

* docs: remove packages no longer explicitly needed
avik-pal authored Jul 26, 2024
1 parent 708c10c commit e1594bf
Showing 5 changed files with 60 additions and 51 deletions.
22 changes: 21 additions & 1 deletion docs/src/api/Building_Blocks/WeightInitializers.md
@@ -9,6 +9,22 @@ learning models.
Pages = ["WeightInitializers.md"]
```

## [Supported RNG Types](@id Supported-RNG-Types-WeightInit)

| **RNG Type / Package** | **Returned Array Type** | **Unsupported Functions** |
| --------------------------------- | ----------------------- | ------------------------------------------------ |
| `Random.jl` | `Array` | |
| `StableRNGs.jl` | `Array` | |
| `CUDA.default_rng()` | `CuArray` | |
| `GPUArrays.default_rng(CuArray)` | `CuArray` | |
| `AMDGPU.rocrand_rng()` | `ROCArray` | |
| `AMDGPU.gpuarrays_rng()` | `ROCArray` | |
| `GPUArrays.default_rng(ROCArray)` | `ROCArray` | |
| `Metal.gpuarrays_rng()` | `MtlArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) |
| `GPUArrays.default_rng(MtlArray)` | `MtlArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) |
| `oneAPI.gpuarrays_rng()` | `oneArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) |
| `GPUArrays.default_rng(oneArray)` | `oneArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) |
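
For instance, passing one of the CPU RNGs above to an initializer returns a plain `Array` (a minimal sketch, assuming `StableRNGs.jl` and `WeightInitializers.jl` are available):

```julia
using StableRNGs, WeightInitializers

rng = StableRNG(1234)
W = kaiming_normal(rng, 4, 3)  # 4×3 Matrix{Float32}; CPU RNGs return plain `Array`s
b = zeros32(rng, 4)            # 4-element Vector{Float32}
```

Passing `CUDA.default_rng()` instead returns `CuArray`s, as listed in the table above.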

## API Reference

### Main Functions
@@ -24,7 +40,11 @@ truncated_normal
orthogonal
```

### Commonly Used Wrappers
### Other Convenience Functions

!!! warning "Beware"

    Unlike the other functions, these don't take a type argument.
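
As a small illustration of the difference (a sketch assuming `WeightInitializers.jl` is loaded): the main initializers accept an optional element-type argument, whereas these wrappers fix the element type in their name:

```julia
using WeightInitializers, Random

rng = Random.default_rng()
w64 = glorot_uniform(rng, Float64, 3, 2)  # element type passed explicitly
w16 = zeros16(rng, 3, 2)                  # Float16 implied by the name; no type argument
```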

```@docs
zeros16
73 changes: 31 additions & 42 deletions docs/src/manual/interface.md
@@ -1,26 +1,16 @@
# [Lux Interface](@id lux-interface)

!!! tip
!!! tip "Lux.jl vs LuxCore.jl"

If you just want to define compatibility with Lux without actually using any of the
other functionality provided by Lux (like layers), it is recommended to depend on
`LuxCore.jl` instead of `Lux.jl`. `LuxCore.jl` is a significantly lighter dependency.

First let's set the expectations straight.

- Do you **have to** follow the interface? *No*.
- **Should you** follow it? *Probably yes*.
- **Why?** It provides the ability for frameworks built on top of Lux to be cross
compatible. Additionally, any new functionality built into Lux will just work for your
framework.
Following this interface provides the ability for frameworks built on top of Lux to be cross
compatible. Additionally, any new functionality built into Lux will just work for your
framework.

!!! warning

The interface is optional for frameworks being developed independent of Lux. All
functionality in the core library (and officially supported ones) **must** adhere to
    the interface.

!!! tip
!!! tip "`@compact` macro"

While writing out a custom struct and defining dispatches manually is a good way to
understand the interface, it is not the most concise way. We recommend using the
@@ -46,17 +36,16 @@ architecture cannot change.
out [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding.

```@example layer_interface
using Lux, Random
using LuxCore, Random, WeightInitializers # Importing `Lux` also gives you access to `LuxCore`
struct Linear{F1, F2} <: Lux.AbstractExplicitLayer
struct Linear{F1, F2} <: LuxCore.AbstractExplicitLayer
in_dims::Int
out_dims::Int
init_weight::F1
init_bias::F2
end
function Linear(in_dims::Int, out_dims::Int; init_weight=Lux.glorot_uniform,
init_bias=Lux.zeros32)
function Linear(in_dims::Int, out_dims::Int; init_weight=glorot_uniform, init_bias=zeros32)
return Linear{typeof(init_weight), typeof(init_bias)}(in_dims, out_dims, init_weight,
init_bias)
end
@@ -71,31 +60,31 @@ etc. The recommended data structure for returning parameters is a NamedTuple, though
anything satisfying the [Parameter Interface](#parameter-interface) is valid.

```@example layer_interface
function Lux.initialparameters(rng::AbstractRNG, l::Linear)
function LuxCore.initialparameters(rng::AbstractRNG, l::Linear)
return (weight=l.init_weight(rng, l.out_dims, l.in_dims),
bias=l.init_bias(rng, l.out_dims, 1))
end
Lux.initialstates(::AbstractRNG, ::Linear) = NamedTuple()
LuxCore.initialstates(::AbstractRNG, ::Linear) = NamedTuple()
```

You could also implement `Lux.parameterlength` and `Lux.statelength` to prevent wasteful
reconstruction of the parameters and states.
You could also implement `LuxCore.parameterlength` and `LuxCore.statelength` to prevent
wasteful reconstruction of the parameters and states.

```@example layer_interface
# This works
println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ",
Lux.statelength(l))
println("Parameter Length: ", LuxCore.parameterlength(l), "; State Length: ",
LuxCore.statelength(l))
# But still recommended to define these
Lux.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims
LuxCore.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims
Lux.statelength(::Linear) = 0
LuxCore.statelength(::Linear) = 0
```

!!! tip
!!! tip "No RNG in `initialparameters` and `initialstates`"

You might notice that we don't pass in a `PRNG` for these functions. If your parameter
    You might notice that we don't pass in an `RNG` for these functions. If your parameter
length and/or state length depend on a random number generator, you should think
**really hard** about what you are trying to do and why.

@@ -117,27 +106,27 @@ feel you need a refresher on that.
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, l)
ps, st = LuxCore.setup(rng, l)
println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ",
Lux.statelength(l))
println("Parameter Length: ", LuxCore.parameterlength(l), "; State Length: ",
LuxCore.statelength(l))
x = randn(rng, Float32, 2, 1)
Lux.apply(l, x, ps, st) # or `l(x, ps, st)`
LuxCore.apply(l, x, ps, st) # or `l(x, ps, st)`
```

### [Container Layer](@id Container-Layer)

If your layer comprises other Lux layers, then it is a `Container Layer`. Note that you
could treat it as a [`Singular Layer`](#singular-layer), and it is still fine. FWIW, if you
cannot subtype your layer with `Lux.AbstractExplicitContainerLayer` then you
cannot subtype your layer with `LuxCore.AbstractExplicitContainerLayer` then you
should go down the [`Singular Layer`](#singular-layer) route. But subtyping allows us to
bypass some of these common definitions. Let us now define a layer, which is basically a
composition of two linear layers.

```@example layer_interface
struct ComposedLinear{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:linear_1, :linear_2)}
struct ComposedLinear{L1, L2} <: LuxCore.AbstractExplicitContainerLayer{(:linear_1, :linear_2)}
linear_1::L1
linear_2::L2
end
@@ -160,17 +149,17 @@ and we need to construct parameters and states for those. Let's construct these
model = ComposedLinear(Linear(2, 4), Linear(4, 2))
display(model)
ps, st = Lux.setup(rng, model)
ps, st = LuxCore.setup(rng, model)
println("Parameters: ", ps)
println("States: ", st)
println("Parameter Length: ", Lux.parameterlength(model), "; State Length: ",
Lux.statelength(model))
println("Parameter Length: ", LuxCore.parameterlength(model), "; State Length: ",
LuxCore.statelength(model))
x = randn(rng, Float32, 2, 1)
Lux.apply(model, x, ps, st) # or `model(x, ps, st)`
LuxCore.apply(model, x, ps, st) # or `model(x, ps, st)`
```

## Parameter Interface
@@ -180,7 +169,7 @@ We accept any parameter type as long as we can fetch the parameters using
and `ComponentArray`s. Let us go through a concrete example of what it means. Consider
[`Dense`](@ref) which expects two parameters named `weight` and `bias`.

!!! info
!!! note "Automatic Differentiation"

If you are defining your own parameter type, it is your responsibility to make sure that
it works with the AutoDiff System you are using.
@@ -192,7 +181,7 @@ d = Dense(2, 3)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps_default, st = Lux.setup(rng, d)
ps_default, st = LuxCore.setup(rng, d)
x = randn(rng, Float32, 2, 1)
@@ -202,7 +191,7 @@ println("Result with `NamedTuple` parameters: ", first(d(x, ps_default, st)))
Let us define a custom parameter type with fields `myweight` and `mybias`, but if we try to
access `weight` we get back `myweight`, and similarly for `bias`.

!!! warning
!!! warning "Beware!"

This is for demonstrative purposes, don't try this at home!
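
One way such a type could look (a hypothetical sketch, not necessarily the definition used in the actual docs) is to forward `weight` and `bias` to the renamed fields via `Base.getproperty`:

```julia
struct MyParams{W, B}
    myweight::W
    mybias::B
end

function Base.getproperty(ps::MyParams, name::Symbol)
    name === :weight && return getfield(ps, :myweight)  # expose `myweight` as `weight`
    name === :bias && return getfield(ps, :mybias)      # expose `mybias` as `bias`
    return getfield(ps, name)
end

# Reuse `d`, `x`, `ps_default`, and `st` from the example above
ps_custom = MyParams(ps_default.weight, ps_default.bias)
println("Result with `MyParams` parameters: ", first(d(x, ps_custom, st)))
```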

4 changes: 2 additions & 2 deletions docs/src/manual/weight_initializers.md
@@ -32,8 +32,8 @@ weights_cl = kaiming_normal(; gain=1.0)
weights = weights_cl(2, 5)
```

To generate weights directly on GPU, pass in a `CUDA.RNG`. (Note that this is currently
implemented only for NVIDIA GPUs)
To generate weights directly on GPU, pass in a `CUDA.RNG`. For a complete list of supported
RNG types, see [Supported RNG Types](@ref Supported-RNG-Types-WeightInit).
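
A minimal sketch of what that looks like (assuming a CUDA-capable GPU with `CUDA.jl` and `WeightInitializers.jl` loaded):

```julia
using CUDA, WeightInitializers

weights = kaiming_normal(CUDA.default_rng(), 2, 5)  # 2×5 CuArray{Float32}
```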

```@example weight-init
using LuxCUDA
10 changes: 5 additions & 5 deletions src/helpers/autodiff.jl
@@ -49,7 +49,7 @@ products efficiently using mixed-mode AD.
| Supported Backends | Packages Needed |
| :----------------- | :--------------- |
| `AutoForwardDiff` | `ForwardDiff.jl` |
| `AutoForwardDiff` | |
!!! warning
@@ -85,10 +85,10 @@ the following properties for `y = f(x)`:
## Backends & AD Packages
| Supported Backends | Packages Needed |
|:------------------ |:---------------- |
| `AutoForwardDiff` | `ForwardDiff.jl` |
| `AutoZygote` | `Zygote.jl` |
| Supported Backends | Packages Needed |
|:------------------ |:--------------- |
| `AutoForwardDiff` | |
| `AutoZygote` | `Zygote.jl` |
## Arguments
2 changes: 1 addition & 1 deletion src/layers/conv.jl
@@ -5,7 +5,7 @@
Standard convolutional layer.
!!! tip "Conv2D"
!!! tip "Conv 2D"
Image data should be stored in WHCN order (width, height, channels, batch). In other
words, a `100 x 100` RGB image would be a `100 x 100 x 3 x 1` array, and a batch of 50
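
As a quick illustration of the WHCN convention (a sketch, not part of the diff above):

```julia
using Lux, Random

rng = Random.default_rng()
layer = Conv((3, 3), 3 => 8)             # 3×3 kernel, 3 input channels, 8 output channels
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 100, 100, 3, 1)  # one 100×100 RGB image in WHCN order
y, _ = layer(x, ps, st)
size(y)                                  # (98, 98, 8, 1) with default stride 1 and no padding
```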

1 comment on commit e1594bf

@github-actions

Benchmark Results

Benchmark suite Current: e1594bf Previous: 708c10c Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3693.125 ns 3663.125 ns 1.01
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7382.428571428572 ns 7402.428571428572 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20759 ns 21370 ns 0.97
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9812.4 ns 9748.2 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9094.5 ns 8969.25 ns 1.01
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4493.375 ns 4458.375 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 4562.25 ns 4643.625 ns 0.98
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1109.9668874172185 ns 1171.4923076923078 ns 0.95
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1181.937062937063 ns 1197.5373134328358 ns 0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1793.1929824561403 ns 1840.7045454545455 ns 0.97
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.68289290681503 ns 179.47899159663865 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17342 ns 17282.5 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16932 ns 17142 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37520 ns 37821 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28584 ns 28674 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20278 ns 20158 ns 1.01
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17392.5 ns 17573 ns 0.99
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 25578 ns 25558 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3938.625 ns 3867.25 ns 1.02
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3991.25 ns 3951.125 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 5029.428571428572 ns 4960.714285714285 ns 1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1652.1 ns 1652.1 ns 1
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 39521781 ns 40993082 ns 0.96
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58585547 ns 58497971 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 78099068 ns 82819165 ns 0.94
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 90769018 ns 92585992 ns 0.98
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 72690086 ns 78243436 ns 0.93
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12054967 ns 12365359.5 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 83400155.5 ns 90927109 ns 0.92
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7721292 ns 7696963 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7581495.5 ns 7605182 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 9955016 ns 11852587 ns 0.84
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6388090.5 ns 6417475.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 725603932 ns 715800734 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2550320476 ns 2585785068 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 138974187 ns 145144322 ns 0.96
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 803301267 ns 857691023 ns 0.94
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2816236020 ns 3013396972 ns 0.93
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 235878850.5 ns 271704927.5 ns 0.87
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 650791127 ns 796866243 ns 0.82
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2603380595.5 ns 2712566258 ns 0.96
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 135543456 ns 130670825 ns 1.04
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 175275149 ns 176972248 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 658937990.5 ns 660618124.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 35195315 ns 37524620 ns 0.94
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 166640478.5 ns 167203678 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 647602016 ns 650758175 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30477927 ns 30535622 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 186029467 ns 212791073.5 ns 0.87
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 718031730.5 ns 854087743 ns 0.84
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 38445179.5 ns 36391416 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1233631654 ns 1349278132.5 ns 0.91
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1887704692.5 ns 1887134256.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2231708206 ns 2450133270 ns 0.91
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2471164221 ns 2519902348 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1848957394.5 ns 1970487480.5 ns 0.94
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 1990760702 ns 2023950360 ns 0.98
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 333383541 ns 347833228 ns 0.96
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 330228531 ns 342952632 ns 0.96
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 365192205 ns 404098087 ns 0.90
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11909407 ns 12035460 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18028312 ns 18156508.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19050490 ns 19338119 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23803362 ns 24010708 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17829963 ns 17942213 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1163076.5 ns 1209037 ns 0.96
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 22925657 ns 23232956 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2329735 ns 2293500 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2215453 ns 2221085 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2068279 ns 2088387 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 200525.5 ns 210493 ns 0.95
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 293690.5 ns 297254.5 ns 0.99
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 267291 ns 268511 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 369002 ns 372155 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 409669 ns 414683 ns 0.99
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 275557 ns 275874.5 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 407905 ns 413892 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 397495 ns 399165 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81824 ns 81923 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 82175 ns 82824 ns 0.99
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87213 ns 87643 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104706 ns 104415 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 193496448 ns 193631929 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 330361771 ns 331527065 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 425098445.5 ns 412143695 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 494104527 ns 469727799.5 ns 1.05
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 375779602 ns 384904410 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 317049031.5 ns 355709715 ns 0.89
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 471078717.5 ns 472893365 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 47539749 ns 47367307 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 46905764 ns 46829064.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 56265522 ns 57797619.5 ns 0.97
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28878315 ns 28515793 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19552999 ns 19208519 ns 1.02
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19606842 ns 19878820 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23577960 ns 23974472 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24212297 ns 24477689.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19666577 ns 19960697.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 21011715 ns 21327622 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6537708 ns 6617984 ns 0.99
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6511419 ns 6558899 ns 0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6527700 ns 6559310 ns 1.00

This comment was automatically generated by a workflow using github-action-benchmark.
