Skip to content

Commit

Permalink
fix: static vector input to dense
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 18, 2024
1 parent 4b6fa05 commit 5c3e286
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.0.4"
version = "1.0.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ForwardDiff: Dual
using Functors: fmapstructure
using Random: AbstractRNG
using Static: Static, StaticBool, StaticInteger, StaticSymbol
using StaticArraysCore: SMatrix, SVector

using LuxCore: LuxCore, AbstractLuxLayer
using MLDataDevices: get_device
Expand Down Expand Up @@ -199,10 +200,12 @@ function named_tuple_layers(layers::Vararg{AbstractLuxLayer, N}) where {N}
end

make_abstract_matrix(x::AbstractVector) = reshape(x, :, 1)
make_abstract_matrix(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x)
make_abstract_matrix(x::AbstractMatrix) = x
make_abstract_matrix(x::AbstractArray{T, N}) where {T, N} = reshape(x, Base.size(x, 1), :)

matrix_to_array(x::AbstractMatrix, ::AbstractVector) = vec(x)
matrix_to_array(x::SMatrix{L, 1, T}, ::AbstractVector) where {L, T} = SVector{L, T}(x)
matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x
matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...)

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand Down Expand Up @@ -71,6 +72,7 @@ Setfield = "1.1.1"
SimpleChains = "0.4.7"
StableRNGs = "1.0.2"
Static = "1"
StaticArrays = "1.9"
Statistics = "1.11.1"
Test = "1.10"
Tracker = "0.2.34"
Expand Down
19 changes: 19 additions & 0 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ end
end
end

@testitem "Dense StaticArrays" setup=[SharedTestSetup] tags=[:core_layers] begin
using StaticArrays, Enzyme, ForwardDiff, ComponentArrays

N = 8
d = Lux.Dense(N => N)
ps = (;
weight=randn(SMatrix{N, N, Float64}),
bias=randn(SVector{N, Float64})
)
x = randn(SVector{N, Float64})

fun = let d = d, x = x
ps -> sum(d(x, ps, (;))[1])
end
grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)
@test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6
end

@testitem "Scale" setup=[SharedTestSetup] tags=[:core_layers] begin
rng = StableRNG(12345)

Expand Down

0 comments on commit 5c3e286

Please sign in to comment.