diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml
index 5910767a71..e6ced872e2 100644
--- a/.buildkite/testing.yml
+++ b/.buildkite/testing.yml
@@ -68,6 +68,7 @@ steps:
           JULIA_AMDGPU_HIP_MUST_LOAD: "1"
           JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
           BACKEND_GROUP: "AMDGPU"
+          RETESTITEMS_NWORKERS: 2
       agents:
         queue: "juliagpu"
         rocm: "*"
diff --git a/ext/LuxForwardDiffExt/utils.jl b/ext/LuxForwardDiffExt/utils.jl
index 1556a00c2d..f7afd78227 100644
--- a/ext/LuxForwardDiffExt/utils.jl
+++ b/ext/LuxForwardDiffExt/utils.jl
@@ -1,7 +1,10 @@
 # Low-Level functions
 @inline function Lux.__partials(::Type{Tag}, x, i) where {Tag}
     x isa ForwardDiff.Dual && return ForwardDiff.partials(Tag, x, i)
-    x isa AbstractArray && return ForwardDiff.partials.(Tag, x, i)
+    if x isa AbstractArray
+        bfn(xᵢ, iᵢ) = ForwardDiff.partials(Tag, xᵢ, iᵢ)
+        return bfn.(x, i)
+    end
     map_fn = @closure(xᵢ->Lux.__partials(Tag, xᵢ, i))
     x isa Tuple && return map(map_fn, x)
     x isa NamedTuple && return NamedTuple{keys(x)}(map(map_fn, values(x)))
@@ -12,8 +15,8 @@ end
 
 @inline function Lux.__dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T}
     if x isa AbstractArray
-        return ForwardDiff.Dual{
-            Tag, T, 1}.(x, ForwardDiff.Partials{1, T}.(tuple.(reshape(u, size(x)))))
+        bfn(xᵢ, uᵢ) = ForwardDiff.Dual{Tag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}(uᵢ))
+        return bfn.(x, tuple.(reshape(u, size(x))))
     end
     x isa Tuple && return map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u)
     x isa NamedTuple &&
diff --git a/test/helpers/nestedad_tests.jl b/test/helpers/nestedad_tests.jl
index 05c42c1e12..75c497d9d5 100644
--- a/test/helpers/nestedad_tests.jl
+++ b/test/helpers/nestedad_tests.jl
@@ -50,8 +50,9 @@
             (loss_function1, loss_function2, loss_function3, loss_function4)
 
         for loss_fn in loss_fns
-            @test_nowarn loss_fn(model, X, ps, st)
-            @test loss_fn(model, X, ps, st) isa Number
+            l = loss_fn(model, X, ps, st)
+            @test l isa Number
+            @test isfinite(l) && !isnan(l)
 
             _, ∂x, ∂ps, _ = Zygote.gradient(loss_fn, model, X, ps, st)
 
@@ -130,8 +131,9 @@ end
             (loss_function1, loss_function2, loss_function3, loss_function4)
 
         for loss_fn in loss_fns
-            @test_nowarn loss_fn(model, X, ps, st)
-            @test loss_fn(model, X, ps, st) isa Number
+            l = loss_fn(model, X, ps, st)
+            @test l isa Number
+            @test isfinite(l) && !isnan(l)
 
             _, ∂x, ∂ps, _ = Zygote.gradient(loss_fn, model, X, ps, st)
 
diff --git a/test/layers/type_stability_tests.jl b/test/layers/type_stability_tests.jl
index 6e6fc184b6..2ef7a902c7 100644
--- a/test/layers/type_stability_tests.jl
+++ b/test/layers/type_stability_tests.jl
@@ -70,12 +70,13 @@
            model in model_list, input in inputs
 
+        model = maybe_rewrite_to_crosscor(mode, model)
         ps, st = Lux.setup(rng, model) |> dev
         x = input |> dev
 
         @inferred model(x, ps, st)
         @inferred loss_function(model, x, ps, st)
 
-        if mode == "amdgpu" && model isa Conv
+        if mode == "amdgpu" && (model isa Conv || model isa CrossCor)
             @test_broken @inferred Zygote.gradient(loss_function, model, x, ps, st)
         else
             @inferred Zygote.gradient(loss_function, model, x, ps, st)
diff --git a/test/runtests.jl b/test/runtests.jl
index 247b3d1283..1b66ea2041 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -84,7 +84,7 @@ end
 end
 
 # Distributed Tests
-if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP) && BACKEND_GROUP != "amdgpu"
+if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP)
     using MPI
 
     nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "")