Skip to content

Commit

Permalink
Merge pull request #741 from LuxDL/ap/clean
Browse files Browse the repository at this point in the history
Use shorthand syntax of @concrete
  • Loading branch information
avik-pal authored Jun 28, 2024
2 parents 76eba2b + 747ec9a commit 1a61165
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
4 changes: 2 additions & 2 deletions ext/LuxOptimisersExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ function Optimisers.adjust(ts::Lux.Experimental.TrainState; kwargs...)
end

# DistributedUtils
@concrete struct DistributedOptimizer{B <: AbstractLuxDistributedBackend} <: AbstractRule
backend::B
@concrete struct DistributedOptimizer <: AbstractRule
backend <: AbstractLuxDistributedBackend
opt
end

Expand Down
13 changes: 6 additions & 7 deletions src/helpers/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ julia> logitbce_ls(y_model, y_bin) > logitbce(y_model, y_bin)
true
```
"""
@concrete struct BinaryCrossEntropyLoss{logits, L <: Union{Nothing, Real}} <:
AbstractLossFunction
label_smoothing::L
@concrete struct BinaryCrossEntropyLoss{logits} <: AbstractLossFunction
label_smoothing <: Union{Nothing, Real}
agg
epsilon
end
Expand Down Expand Up @@ -192,8 +191,8 @@ julia> CrossEntropyLoss(label_smoothing=0.15)(y_model, y) ≈ 1.5776052f0
true
```
"""
@concrete struct CrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction
label_smoothing::L
@concrete struct CrossEntropyLoss{logits} <: AbstractLossFunction
label_smoothing <: Union{Nothing, Real}
dims
agg
epsilon
Expand Down Expand Up @@ -412,10 +411,10 @@ julia> KLDivergenceLoss(; epsilon=0)(p1, p2)
Inf
```
"""
@concrete struct KLDivergenceLoss{C <: CrossEntropyLoss} <: AbstractLossFunction
@concrete struct KLDivergenceLoss <: AbstractLossFunction
agg
dims
celoss::C
celoss <: CrossEntropyLoss
end

function KLDivergenceLoss(; dims=1, agg=mean, epsilon=nothing, label_smoothing=nothing)
Expand Down
13 changes: 6 additions & 7 deletions src/layers/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ julia> size.(first(model((x1, x2), ps, st)))
((1,), (1,))
```
"""
@concrete struct Parallel{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
@concrete struct Parallel <: AbstractExplicitContainerLayer{(:layers,)}
connection
layers::T
layers <: NamedTuple
name
end

Expand Down Expand Up @@ -346,10 +346,9 @@ end
- States of each `layer` wrapped in a NamedTuple with
`fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API)
"""
@concrete struct PairwiseFusion{T <: NamedTuple} <:
AbstractExplicitContainerLayer{(:layers,)}
@concrete struct PairwiseFusion <: AbstractExplicitContainerLayer{(:layers,)}
connection
layers::T
layers <: NamedTuple
name
end

Expand Down Expand Up @@ -457,8 +456,8 @@ Chain(
# plus 7 states.
```
"""
@concrete struct Chain{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
layers::T
@concrete struct Chain <: AbstractExplicitContainerLayer{(:layers,)}
layers <: NamedTuple
name
end

Expand Down

1 comment on commit 1a61165

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 1a61165 Previous: 76eba2b Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3658.125 ns 3729.375 ns 0.98
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7185.4 ns 7106.666666666667 ns 1.01
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21610 ns 21039 ns 1.03
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9954.4 ns 9638 ns 1.03
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8986.625 ns 8896.666666666666 ns 1.01
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4523.375 ns 4458.375 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1162.9044117647059 ns 1153.442857142857 ns 1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1112.9741935483871 ns 1112.3375796178343 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1163.123287671233 ns 1152.2266666666667 ns 1.01
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1774.3508771929824 ns 1770.2068965517242 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.07943262411348 ns 179.34615384615384 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17313 ns 17413 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16952 ns 17172 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 39563 ns 37320 ns 1.06
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29234 ns 29255 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 21470 ns 21420 ns 1.00
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17242 ns 17332 ns 0.99
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4325.285714285715 ns 4320.857142857143 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3880.875 ns 3852.25 ns 1.01
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3936 ns 3963.625 ns 0.99
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4889 ns 4873.428571428572 ns 1.00
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1663.1 ns 1664.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 41122113 ns 46289048 ns 0.89
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57716330 ns 57928426 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 81840420 ns 109544567 ns 0.75
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 91437945.5 ns 106428149 ns 0.86
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 75926448 ns 105745221 ns 0.72
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11662872.5 ns 11590862.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17787447 ns 17707811 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7032740 ns 7000562 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6996461 ns 6959465 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12114935 ns 18180821 ns 0.67
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6387272 ns 6376496 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 767844159 ns 736289591 ns 1.04
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2602811621 ns 2518006964 ns 1.03
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 144091616 ns 132959679 ns 1.08
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 893102323 ns 903910474.5 ns 0.99
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3246592469 ns 3242525967 ns 1.00
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 222301050.5 ns 236061029 ns 0.94
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 723583721 ns 729785075 ns 0.99
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2632526936 ns 2997367730 ns 0.88
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 126310974.5 ns 127309525 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 176488869 ns 179367196 ns 0.98
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 649092634 ns 649655820.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34416017 ns 45338820 ns 0.76
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 163630557 ns 164121883 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 640318006 ns 636120217 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29797077.5 ns 30331214.5 ns 0.98
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 225565619 ns 202050007 ns 1.12
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 851688514 ns 902440023 ns 0.94
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 38079734.5 ns 37215471 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1261029479 ns 1226513913.5 ns 1.03
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1854091416 ns 1870479613 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2553348643 ns 2368702024 ns 1.08
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2528691188 ns 2555200453 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1925582416 ns 1948553539.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 562461809 ns 555581777 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 320567710 ns 318311145 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 318656712 ns 318132255 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 379455021 ns 395277295 ns 0.96
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12014151 ns 11904025 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17777207 ns 17973921.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 18982505.5 ns 19198939 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23814355 ns 23864657.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17776145 ns 17949860 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1163571 ns 1152816 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5835756.5 ns 5788207 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2049324.5 ns 2056086 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2030939 ns 2035808 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2074821 ns 2082205 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 202707 ns 195797 ns 1.04
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 291222 ns 296394 ns 0.98
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 265629.5 ns 266483.5 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 364529 ns 368559 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 407123 ns 411520 ns 0.99
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 273649.5 ns 275104.5 ns 0.99
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 407198 ns 409701.5 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83466 ns 83085 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81592 ns 81983 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 82003 ns 82910.5 ns 0.99
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86571 ns 86782 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104455 ns 104706 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 200999670 ns 204785709 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 324420843 ns 325733461 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 417216292 ns 449856941 ns 0.93
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 459212467 ns 482834181 ns 0.95
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 391131363 ns 421978168.5 ns 0.93
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 320280603 ns 320008908 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 102161858 ns 100911992.5 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43990307.5 ns 43729818 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43757610.5 ns 43491960 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 57183117.5 ns 57142792 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28324912 ns 28372687.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18822434 ns 18881824 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19540731 ns 19470234 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23356974 ns 23433086 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24094714 ns 24127103 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19581858.5 ns 19572785.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6520727 ns 6512344 ns 1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6499891.5 ns 6519102 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6481714 ns 6469715 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6495020 ns 6473077 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.