Commit

Broadcast with closure
avik-pal committed Jul 1, 2024
1 parent c20aec5 commit 6d12cbf
Showing 7 changed files with 34 additions and 18 deletions.
18 changes: 13 additions & 5 deletions .buildkite/testing.yml
@@ -1,7 +1,7 @@
steps:
- group: ":julia: CUDA GPU"
steps:
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
- label: ":julia: Julia {{matrix.julia}} + {{matrix.testing_group}} + CUDA GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
@@ -17,12 +17,16 @@ steps:
cuda: "*"
env:
BACKEND_GROUP: "CUDA"
+ LUX_TEST_GROUP: "{{matrix.testing_group}}"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
- timeout_in_minutes: 240
+ timeout_in_minutes: 35
matrix:
setup:
julia:
- "1"
+ testing_group:
+   - "distributed"
+   - "!distributed"

- group: ":telescope: Downstream CUDA"
steps:
@@ -42,7 +46,7 @@ steps:
env:
RETESTITEMS_NWORKERS: 2
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test"
- timeout_in_minutes: 240
+ timeout_in_minutes: 60
matrix:
setup:
repo:
@@ -52,7 +56,7 @@

- group: ":julia: AMD GPU"
steps:
- label: ":julia: Julia: {{matrix.julia}} + AMD GPU"
- label: ":julia: Julia: {{matrix.julia}} + {{matrix.testing_group}} + AMD GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
@@ -68,16 +72,20 @@ steps:
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
BACKEND_GROUP: "AMDGPU"
+ LUX_TEST_GROUP: "{{matrix.testing_group}}"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
- timeout_in_minutes: 60
+ timeout_in_minutes: 35
matrix:
setup:
julia:
- "1"
+ testing_group:
+   - "distributed"
+   - "!distributed"

- group: ":telescope: Downstream AMD GPU"
steps:
9 changes: 6 additions & 3 deletions ext/LuxForwardDiffExt/utils.jl
@@ -1,7 +1,10 @@
# Low-Level functions
@inline function Lux.__partials(::Type{Tag}, x, i) where {Tag}
x isa ForwardDiff.Dual && return ForwardDiff.partials(Tag, x, i)
- x isa AbstractArray && return ForwardDiff.partials.(Tag, x, i)
+ if x isa AbstractArray
+     bfn(xᵢ, iᵢ) = ForwardDiff.partials(Tag, xᵢ, iᵢ)
+     return bfn.(x, i)
+ end
map_fn = @closure(xᵢ->Lux.__partials(Tag, xᵢ, i))
x isa Tuple && return map(map_fn, x)
x isa NamedTuple && return NamedTuple{keys(x)}(map(map_fn, values(x)))
@@ -12,8 +15,8 @@ end

@inline function Lux.__dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T}
if x isa AbstractArray
- return ForwardDiff.Dual{
-     Tag, T, 1}.(x, ForwardDiff.Partials{1, T}.(tuple.(reshape(u, size(x)))))
+ bfn(xᵢ, uᵢ) = ForwardDiff.Dual{Tag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}(uᵢ))
+ return bfn.(x, tuple.(reshape(u, size(x))))
end
x isa Tuple && return map((xᵢ, uᵢ) -> Lux.__dualify(Tag, T, xᵢ, uᵢ), x, u)
x isa NamedTuple &&
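Note on the change above: rather than broadcasting a multi-argument call through ForwardDiff internals, the new code defines a small local closure that captures the tag and element type and broadcasts that two-argument closure over the arrays only, matching the commit title. A minimal standalone sketch of the same pattern, assuming a hypothetical tag type MyTag; the names dualify_sketch, x, and u are illustrative and not part of Lux:

using ForwardDiff

struct MyTag end  # hypothetical tag type, for illustration only

function dualify_sketch(x::AbstractArray{T}, u::AbstractArray{T}) where {T}
    # Close over the non-array arguments (the tag and element type) in a local function...
    bfn(xᵢ, uᵢ) = ForwardDiff.Dual{MyTag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}((uᵢ,)))
    # ...and broadcast that closure over the arrays only.
    return bfn.(x, reshape(u, size(x)))
end

dualify_sketch(rand(3), rand(3))  # 3-element Vector of Dual numbers with one partial each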
3 changes: 1 addition & 2 deletions src/layers/basic.jl
@@ -247,8 +247,7 @@ function WrappedFunction(f::F) where {F}
# Not a depwarn but helpful to call this
Base.depwarn("The current default of `:direct_call` will be replaced with \
`:runtime_check` from v0.6). Please make sure that the assumptions of \
- this function are correct or specific \
- `WrappedFunction{:direct_call}(f)`",
+ this function are correct or specify `WrappedFunction{:direct_call}(f)`",
:WrappedFunction)
return WrappedFunction{:direct_call}(f)
end
10 changes: 6 additions & 4 deletions test/helpers/nestedad_tests.jl
@@ -50,8 +50,9 @@
(loss_function1, loss_function2, loss_function3, loss_function4)

for loss_fn in loss_fns
- @test_nowarn loss_fn(model, X, ps, st)
- @test loss_fn(model, X, ps, st) isa Number
+ l = loss_fn(model, X, ps, st)
+ @test l isa Number
+ @test isfinite(l) && !isnan(l)

_, ∂x, ∂ps, _ = Zygote.gradient(loss_fn, model, X, ps, st)

@@ -130,8 +131,9 @@ end
(loss_function1, loss_function2, loss_function3, loss_function4)

for loss_fn in loss_fns
- @test_nowarn loss_fn(model, X, ps, st)
- @test loss_fn(model, X, ps, st) isa Number
+ l = loss_fn(model, X, ps, st)
+ @test l isa Number
+ @test isfinite(l) && !isnan(l)

_, ∂x, ∂ps, _ = Zygote.gradient(loss_fn, model, X, ps, st)

7 changes: 5 additions & 2 deletions test/layers/containers_tests.jl
@@ -4,7 +4,8 @@

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "zero sum" begin
- layer = SkipConnection(WrappedFunction(zero), (a, b) -> a .+ b)
+ layer = SkipConnection(
+     WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), .+)
display(layer)
ps, st = Lux.setup(rng, layer) .|> device
x = randn(rng, 10, 10, 10, 10) |> aType
@@ -36,7 +37,9 @@ end

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "zero sum" begin
- layer = Parallel(+, WrappedFunction(zero), NoOpLayer())
+ layer = Parallel(
+     +, WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)),
+     NoOpLayer())
@test :layer_1 in keys(layer) && :layer_2 in keys(layer)
display(layer)
ps, st = Lux.setup(rng, layer) .|> device
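For context on the replacement above: Broadcast.BroadcastFunction(f) (available since Julia 1.6) wraps a function so that calling the wrapper applies f elementwise, i.e. it behaves like x -> f.(x). A small illustration with our own variable names, not taken from the tests:

bzero = Broadcast.BroadcastFunction(zero)   # callable that applies `zero` elementwise
x = randn(4, 4)
bzero(x) == zero.(x)                        # true: an array of zeros with the shape and eltype of `x`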
3 changes: 2 additions & 1 deletion test/layers/type_stability_tests.jl
@@ -70,12 +70,13 @@
model in model_list,
input in inputs

+ model = maybe_rewrite_to_crosscor(mode, model)
ps, st = Lux.setup(rng, model) |> dev
x = input |> dev

@inferred model(x, ps, st)
@inferred loss_function(model, x, ps, st)
if mode == "amdgpu" && model isa Conv
if mode == "amdgpu" && (model isa Conv || model isa CrossCor)
@test_broken @inferred Zygote.gradient(loss_function, model, x, ps, st)
else
@inferred Zygote.gradient(loss_function, model, x, ps, st)
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -84,7 +84,7 @@ end
end

# Distributed Tests
if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP) && BACKEND_GROUP != "amdgpu"
if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP)
using MPI

nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "")
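The Buildkite matrix earlier in this commit splits the GPU jobs into a "distributed" group and its complement "!distributed", passed to the test process through the LUX_TEST_GROUP environment variable that this file consumes. As a hedged sketch (not Lux's actual implementation) of how such a value could be expanded into the LUX_TEST_GROUP collection checked above, with an assumed group list and helper name:

# Illustrative only: the group names and helper below are assumptions, not Lux code.
const ALL_TEST_GROUPS = ["core_layers", "autodiff", "helpers", "distributed"]

function parse_test_groups(spec::AbstractString)
    spec == "all" && return ["all"]
    # "!distributed" selects every group except the distributed one
    startswith(spec, "!") && return filter(!=(spec[2:end]), ALL_TEST_GROUPS)
    return [spec]
end

LUX_TEST_GROUP = parse_test_groups(lowercase(get(ENV, "LUX_TEST_GROUP", "all")))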
