Broadcast with closure
avik-pal committed Jun 30, 2024
1 parent c20aec5 commit 1189ae3
Showing 5 changed files with 16 additions and 9 deletions.
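The substantive change is in the LuxForwardDiffExt extension: two broadcasted calls that passed bare `Type` arguments (`Tag`, `T`) through the broadcast are rewritten to broadcast a small local closure that captures those types instead, which is what the commit title refers to. The remaining files adjust CI worker settings and tests.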
1 change: 1 addition & 0 deletions .buildkite/testing.yml
@@ -68,6 +68,7 @@ steps:
       JULIA_AMDGPU_HIP_MUST_LOAD: "1"
       JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
       BACKEND_GROUP: "AMDGPU"
+      RETESTITEMS_NWORKERS: 2
     agents:
       queue: "juliagpu"
       rocm: "*"
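RETESTITEMS_NWORKERS is the environment variable ReTestItems.jl reads for its default worker count, so the AMDGPU job now runs test items across two worker processes. A rough local equivalent, as a sketch (assuming the suite is driven through ReTestItems):

    using ReTestItems
    # Spawn two worker processes for test items, mirroring the CI setting
    # RETESTITEMS_NWORKERS=2 above.
    runtests("test"; nworkers = 2)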
9 changes: 6 additions & 3 deletions ext/LuxForwardDiffExt/utils.jl
@@ -1,7 +1,10 @@
 # Low-Level functions
 @inline function Lux.__partials(::Type{Tag}, x, i) where {Tag}
     x isa ForwardDiff.Dual && return ForwardDiff.partials(Tag, x, i)
-    x isa AbstractArray && return ForwardDiff.partials.(Tag, x, i)
+    if x isa AbstractArray
+        bfn(xᵢ, iᵢ) = ForwardDiff.partials(Tag, xᵢ, iᵢ)
+        return bfn.(x, i)
+    end
     map_fn = @closure(xᵢ->Lux.__partials(Tag, xᵢ, i))
     x isa Tuple && return map(map_fn, x)
     x isa NamedTuple && return NamedTuple{keys(x)}(map(map_fn, values(x)))

(Codecov/patch: added lines L4-L6 were not covered by tests.)
@@ -12,8 +15,8 @@ end

 @inline function Lux.__dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T}
     if x isa AbstractArray
-        return ForwardDiff.Dual{
-            Tag, T, 1}.(x, ForwardDiff.Partials{1, T}.(tuple.(reshape(u, size(x)))))
+        bfn(xᵢ, uᵢ) = ForwardDiff.Dual{Tag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}(uᵢ))
+        return bfn.(x, tuple.(reshape(u, size(x))))
     end
     x isa Tuple && return map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u)
     x isa NamedTuple &&

(Codecov/patch: added lines L18-L19 were not covered by tests.)
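Both hunks apply the same rewrite: instead of broadcasting a call whose arguments include bare types, e.g. `ForwardDiff.partials.(Tag, x, i)` or a broadcasted `ForwardDiff.Dual{Tag, T, 1}` constructor, they define a local closure `bfn` that captures `Tag`/`T` and broadcast the closure over the element-level arguments only. The commit does not state the motivation, but keeping `Type` objects out of the broadcast arguments is a common way to help inference and GPU broadcast kernels. A self-contained sketch of the `__dualify` variant, using a hypothetical tag type:

    using ForwardDiff

    struct MyTag end  # hypothetical tag type, only for this sketch

    # Seed each element of `x` with a single partial taken from `u`. The
    # closure captures the type parameters, so the broadcast only sees
    # element-level arguments.
    function dualify_sketch(::Type{Tag}, ::Type{T}, x::AbstractArray, u) where {Tag, T}
        bfn(xᵢ, uᵢ) = ForwardDiff.Dual{Tag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}(uᵢ))
        return bfn.(x, tuple.(reshape(u, size(x))))
    end

    x = rand(Float32, 3)
    u = rand(Float32, 3)
    dualify_sketch(MyTag, Float32, x, u)  # 3-element Vector of Dual{MyTag, Float32, 1}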
10 changes: 6 additions & 4 deletions test/helpers/nestedad_tests.jl
@@ -50,8 +50,9 @@
         (loss_function1, loss_function2, loss_function3, loss_function4)

     for loss_fn in loss_fns
-        @test_nowarn loss_fn(model, X, ps, st)
-        @test loss_fn(model, X, ps, st) isa Number
+        l = loss_fn(model, X, ps, st)
+        @test l isa Number
+        @test isfinite(l) && !isnan(l)

         _, ∂x, ∂ps, _ = Zygote.gradient(loss_fn, model, X, ps, st)
@@ -130,8 +131,9 @@ end
         (loss_function1, loss_function2, loss_function3, loss_function4)

     for loss_fn in loss_fns
-        @test_nowarn loss_fn(model, X, ps, st)
-        @test loss_fn(model, X, ps, st) isa Number
+        l = loss_fn(model, X, ps, st)
+        @test l isa Number
+        @test isfinite(l) && !isnan(l)

         _, ∂x, ∂ps, _ = Zygote.gradient(loss_fn, model, X, ps, st)
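The loss is now computed once and reused: the old checks evaluated `loss_fn` twice (once under `@test_nowarn`, once for the `isa Number` assertion), whereas the new version binds the result to `l`, checks its type, and additionally asserts finiteness before the gradient calls.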
3 changes: 2 additions & 1 deletion test/layers/type_stability_tests.jl
@@ -70,12 +70,13 @@
             model in model_list,
             input in inputs

+            model = maybe_rewrite_to_crosscor(mode, model)
             ps, st = Lux.setup(rng, model) |> dev
             x = input |> dev

             @inferred model(x, ps, st)
             @inferred loss_function(model, x, ps, st)
-            if mode == "amdgpu" && model isa Conv
+            if mode == "amdgpu" && (model isa Conv || model isa CrossCor)
                 @test_broken @inferred Zygote.gradient(loss_function, model, x, ps, st)
             else
                 @inferred Zygote.gradient(loss_function, model, x, ps, st)
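`maybe_rewrite_to_crosscor` is defined elsewhere in the test suite (not in this commit); judging from its name and the widened branch below it, it presumably rewrites some `Conv` models to `CrossCor` for certain modes, which is why the AMDGPU `@test_broken` carve-out now covers `CrossCor` as well as `Conv`.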
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -84,7 +84,7 @@ end
 end

 # Distributed Tests
-if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP) && BACKEND_GROUP != "amdgpu"
+if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP)
     using MPI

     nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "")
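With the `BACKEND_GROUP != "amdgpu"` guard removed, the MPI-based distributed tests also run on AMDGPU CI jobs. To run just that group locally, something like the following should work; this is a sketch that assumes `LUX_TEST_GROUP` and `JULIA_MPI_TEST_NPROCS` are read from the environment, as the surrounding runtests.jl logic suggests:

    # Select only the distributed test group and use two MPI ranks.
    ENV["LUX_TEST_GROUP"] = "distributed"
    ENV["JULIA_MPI_TEST_NPROCS"] = "2"

    using Pkg
    Pkg.test("Lux")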
