Skip to content

Commit

Permalink
Merge branch 'complex_utils' of https://github.com/simeonschaub/Chain…
Browse files Browse the repository at this point in the history
…Rules.jl into complex_utils
  • Loading branch information
simeonschaub committed Aug 29, 2019
2 parents 31b9bf4 + b75dba4 commit 63bd4ea
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,19 @@

@scalar_rule(hypot(x, y), (x / Ω, y / Ω))
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx)
@scalar_rule(atan(x, y), @setup(u = x^2 + y^2), (y / u, -x / u))

@scalar_rule(atan(y, x), @setup(u = x^2 + y^2), (x / u, -y / u))
@scalar_rule(max(x, y), @setup(gt = x > y), (gt, !gt))
@scalar_rule(min(x, y), @setup(gt = x > y), (!gt, gt))
@scalar_rule(mod(x, y), @setup((u, nan) = promote(x / y, NaN16)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -floor(u))))
@scalar_rule(rem(x, y), @setup((u, nan) = promote(x / y, NaN16)),
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
@scalar_rule(angle(x::Real), Zero())
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
@scalar_rule(real(x::Real), One())
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
@scalar_rule(imag(x::Real), Zero())

# product rule requires special care for arguments where `mul` is non-commutative

Expand Down
12 changes: 12 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@
end
end

@testset "Unary complex functions" begin
for x in rand.((Int, Float32, Float64, complex.((Float32, Float64))...))
test_scalar(real, x)
test_scalar(imag, x)
test_scalar(abs, x)
test_scalar(abs, x)
test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
end
end

@testset "*(x, y)" begin
x, y = rand(3, 2), rand(2, 5)
z, (dx, dy) = rrule(*, x, y)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using Test
using ChainRulesCore: add, cast, extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
DNE, Thunk, Casted, DNERule
DNE, Thunk, Casted, DNERule, AbstractDifferential

include("test_util.jl")

Expand Down

0 comments on commit 63bd4ea

Please sign in to comment.