Skip to content

Commit

Permalink
Test ChainRules integration directly
Browse files Browse the repository at this point in the history
Update test/chainrules.jl

Update test/chainrules.jl
  • Loading branch information
oxinabox committed Oct 19, 2019
1 parent c89c5d0 commit 74396b4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
49 changes: 49 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using Zygote, Test, ChainRules

const cr_inner_demo_rrule_hitcount = Ref(0)
const cr_inner_demo_pullback_hitcount = Ref(0)
cr_inner_demo(x) = 5x
function ChainRules.rrule(::typeof(cr_inner_demo), x)
cr_inner_demo_rrule_hitcount[] += 1
function cr_inner_demo_pullback(Δx)
cr_inner_demo_pullback_hitcount[] += 1
return ChainRules.NO_FIELDS, 5.0*Δx
end
return cr_inner_demo(x), cr_inner_demo_pullback
end

function cr_outer_demo(x)
2 + 10cr_inner_demo(x)
end

@testset "ChainRules Integration" begin
@testset "gradient inner" begin
cr_inner_demo_rrule_hitcount[] = 0
cr_inner_demo_pullback_hitcount[] = 0
@test (5.0,) == gradient(cr_inner_demo, 11)
@test cr_inner_demo_rrule_hitcount[] == 1
@test cr_inner_demo_pullback_hitcount[] == 1
end

@testset "gradient outer" begin
cr_inner_demo_rrule_hitcount[] = 0
cr_inner_demo_pullback_hitcount[] = 0
@test (50.0,) == gradient(cr_outer_demo, 11)
@test cr_inner_demo_rrule_hitcount[] == 1
@test cr_inner_demo_pullback_hitcount[] == 1
end

@testset "pullback inner" begin
cr_inner_demo_rrule_hitcount[] = 0
cr_inner_demo_pullback_hitcount[] = 0
y, pb = pullback(cr_inner_demo, 11)
@test y == 55
@test cr_inner_demo_rrule_hitcount[] == 1
@test cr_inner_demo_pullback_hitcount[] == 0
@test pb(1)==(5.0,);
@test pb(2)==(10.0,);
@test pb(3)==(15.0,);
@test cr_inner_demo_pullback_hitcount[] == 3
@test cr_inner_demo_rrule_hitcount[] == 1
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ end
include("structures.jl")
end

@info "Testing ChainRules integration"

@testset "ChainRules" begin
include("chainrules.jl")
end

@info "Running Gradient Checks"

@testset "Gradients" begin
Expand Down

0 comments on commit 74396b4

Please sign in to comment.