From 5c3e28690696f4f84e59b7504c196239d88cffc0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Sep 2024 10:49:15 -0400 Subject: [PATCH] fix: static vector input to dense --- Project.toml | 2 +- src/utils.jl | 3 +++ test/Project.toml | 2 ++ test/layers/basic_tests.jl | 19 +++++++++++++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9a4bad210..971cf4a21 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.0.4" +version = "1.0.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/utils.jl b/src/utils.jl index 13e442d08..1e2929dc7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,6 +9,7 @@ using ForwardDiff: Dual using Functors: fmapstructure using Random: AbstractRNG using Static: Static, StaticBool, StaticInteger, StaticSymbol +using StaticArraysCore: SMatrix, SVector using LuxCore: LuxCore, AbstractLuxLayer using MLDataDevices: get_device @@ -199,10 +200,12 @@ function named_tuple_layers(layers::Vararg{AbstractLuxLayer, N}) where {N} end make_abstract_matrix(x::AbstractVector) = reshape(x, :, 1) +make_abstract_matrix(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x) make_abstract_matrix(x::AbstractMatrix) = x make_abstract_matrix(x::AbstractArray{T, N}) where {T, N} = reshape(x, Base.size(x, 1), :) matrix_to_array(x::AbstractMatrix, ::AbstractVector) = vec(x) +matrix_to_array(x::SMatrix{L, 1, T}, ::AbstractVector) where {L, T} = SVector{L, T}(x) matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...) diff --git a/test/Project.toml b/test/Project.toml index 6b10e13e8..c7256fc0f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,6 +32,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -71,6 +72,7 @@ Setfield = "1.1.1" SimpleChains = "0.4.7" StableRNGs = "1.0.2" Static = "1" +StaticArrays = "1.9" Statistics = "1.11.1" Test = "1.10" Tracker = "0.2.34" diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index bbf4e3763..beee1536e 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -170,6 +170,25 @@ end end end +@testitem "Dense StaticArrays" setup=[SharedTestSetup] tags=[:core_layers] begin + using StaticArrays, Enzyme, ForwardDiff, ComponentArrays + + N = 8 + d = Lux.Dense(N => N) + ps = (; + weight=randn(SMatrix{N, N, Float64}), + bias=randn(SVector{N, Float64}) + ) + x = randn(SVector{N, Float64}) + + fun = let d = d, x = x + ps -> sum(d(x, ps, (;))[1]) + end + grad1 = ForwardDiff.gradient(fun, ComponentVector(ps)) + grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps) + @test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6 +end + @testitem "Scale" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345)