Skip to content

Commit

Permalink
Add differentiation rules from ChainRules (#238)
Browse files Browse the repository at this point in the history
* Add differentiation rules from ChainRules

* Allow test failures on Julia nightly

* Allow failures (correctly?)

* Try to avoid spurious test failures by setting seed

* Throw error instead of returning NaN

* Fix test errors
  • Loading branch information
devmotion authored Dec 4, 2020
1 parent e8a1e5c commit 9f230e6
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ julia:
- 1.3
- 1
- nightly
matrix:
allow_failures:
- julia: nightly
notifications:
email: false

Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e"

[compat]
ChainRulesCore = "0.9"
OpenSpecFun_jll = "0.5.3"
julia = "1.3"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["ChainRulesTestUtils", "Random", "Test"]
3 changes: 3 additions & 0 deletions src/SpecialFunctions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module SpecialFunctions

import ChainRulesCore

using OpenSpecFun_jll

export
Expand Down Expand Up @@ -71,6 +73,7 @@ include("gamma.jl")
include("gamma_inc.jl")
include("betanc.jl")
include("beta_inc.jl")
include("chainrules.jl")
include("deprecated.jl")

for f in (:digamma, :erf, :erfc, :erfcinv, :erfcx, :erfi, :erfinv, :logerfc, :logerfcx,
Expand Down
95 changes: 95 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x))
ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x))
ChainRulesCore.@scalar_rule(
besselj1(x),
(besselj0(x) - besselj(2, x)) / 2,
)
ChainRulesCore.@scalar_rule(bessely0(x), -bessely1(x))
ChainRulesCore.@scalar_rule(
bessely1(x),
(bessely0(x) - bessely(2, x)) / 2,
)
ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω))
ChainRulesCore.@scalar_rule(digamma(x), trigamma(x))
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x))
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x))
ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp^2))
ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x))
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp^2))
ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
ChainRulesCore.@scalar_rule(
invdigamma(x),
inv(trigamma(invdigamma(x))),
)
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))

# binary
ChainRulesCore.@scalar_rule(
besselj(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
(besselj- 1, x) - besselj+ 1, x)) / 2
),
)
ChainRulesCore.@scalar_rule(
besseli(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
(besseli- 1, x) + besseli+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
bessely(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
(bessely- 1, x) - bessely+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
besselk(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
-(besselk- 1, x) + besselk+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh1(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
(hankelh1- 1, x) - hankelh1+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh2(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
(hankelh2- 1, x) - hankelh2+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
polygamma(m, x),
(
ChainRulesCore.@thunk(error("not implemented")),
polygamma(m + 1, x),
),
)
# todo: setup for common expr
ChainRulesCore.@scalar_rule(
beta(a, b),
*(digamma(a) - digamma(a + b)),
Ω*(digamma(b) - digamma(a + b)),)
)
ChainRulesCore.@scalar_rule(
logbeta(a, b),
(digamma(a) - digamma(a + b),
digamma(b) - digamma(a + b),)
)

# actually is the absolute value of the logorithm of gamma paired with sign gamma
ChainRulesCore.@scalar_rule(logabsgamma(x), digamma(x), ChainRulesCore.Zero())

ChainRulesCore.@scalar_rule(loggamma(x), digamma(x))
74 changes: 74 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
@testset "chainrules" begin
Random.seed!(1)

@testset "general" begin
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
test_scalar(erf, x)
test_scalar(erfc, x)
test_scalar(erfi, x)

test_scalar(airyai, x)
test_scalar(airyaiprime, x)
test_scalar(airybi, x)
test_scalar(airybiprime, x)

test_scalar(besselj0, x)
test_scalar(besselj1, x)

test_scalar(erfcx, x)
test_scalar(dawson, x)

if x isa Real
test_scalar(invdigamma, x)
end

if x isa Real && 0 < x < 1
test_scalar(erfinv, x)
test_scalar(erfcinv, x)
end

if x isa Real && x > 0 || x isa Complex
test_scalar(bessely0, x)
test_scalar(bessely1, x)
test_scalar(gamma, x)
test_scalar(digamma, x)
test_scalar(trigamma, x)
end
end

@testset "beta and logbeta" begin
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
for _x in test_points, _y in test_points
# ensure all complex if any complex for FiniteDifferences
x, y = promote(_x, _y)
T = typeof(x)

Δx, x̄ = randn(T, 2)
Δy, ȳ = randn(T, 2)
Δz = randn(T)

frule_test(beta, (x, Δx), (y, Δy))
rrule_test(beta, Δz, (x, x̄), (y, ȳ))

frule_test(logbeta, (x, Δx), (y, Δy))
rrule_test(logbeta, Δz, (x, x̄), (y, ȳ))
end
end

@testset "log gamma and co" begin
# It is important that we have negative numbers with both odd and even integer parts
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
isreal(x) && x < 0 && continue
test_scalar(loggamma, x)

isreal(x) || continue

Δx, x̄ = randn(2)
Δz = (randn(), randn())

frule_test(logabsgamma, (x, Δx))
rrule_test(logabsgamma, Δz, (x, x̄))
end
end
end
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This file contains code that was formerly a part of Julia. License is MIT: http://julialang.org/license

using SpecialFunctions
using ChainRulesTestUtils
using Random
using Test
using Base.MathConstants: γ

Expand Down Expand Up @@ -28,7 +30,8 @@ tests = [
"gamma_inc",
"gamma",
"sincosint",
"other_tests"
"other_tests",
"chainrules"
]

const testdir = dirname(@__FILE__)
Expand Down

0 comments on commit 9f230e6

Please sign in to comment.