diff --git a/Project.toml b/Project.toml index d6bdd624..35abedcc 100644 --- a/Project.toml +++ b/Project.toml @@ -8,8 +8,8 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e" [compat] -ChainRulesCore = "0.9" -ChainRulesTestUtils = "0.6.3" +ChainRulesCore = "0.9.40" +ChainRulesTestUtils = "0.6.8" LogExpFunctions = "0.2" OpenSpecFun_jll = "0.5" julia = "1.3" diff --git a/src/chainrules.jl b/src/chainrules.jl index c678c15d..4e9ecc25 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,3 +1,8 @@ +const BESSEL_ORDER_INFO = """ +derivatives of Bessel functions with respect to the order are not implemented currently: +https://github.com/JuliaMath/SpecialFunctions.jl/issues/160 +""" + ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x)) ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x)) ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x)) @@ -31,49 +36,49 @@ ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x)) ChainRulesCore.@scalar_rule( besselj(ν, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), (besselj(ν - 1, x) - besselj(ν + 1, x)) / 2 ), ) ChainRulesCore.@scalar_rule( besseli(ν, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), (besseli(ν - 1, x) + besseli(ν + 1, x)) / 2, ), ) ChainRulesCore.@scalar_rule( bessely(ν, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), (bessely(ν - 1, x) - bessely(ν + 1, x)) / 2, ), ) ChainRulesCore.@scalar_rule( besselk(ν, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), -(besselk(ν - 1, x) + besselk(ν + 1, x)) / 2, ), ) ChainRulesCore.@scalar_rule( hankelh1(ν, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), (hankelh1(ν - 1, x) - hankelh1(ν + 1, x)) / 2, ), ) ChainRulesCore.@scalar_rule( hankelh2(ν, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO), (hankelh2(ν - 1, x) - hankelh2(ν + 1, x)) / 2, ), ) ChainRulesCore.@scalar_rule( polygamma(m, x), ( - ChainRulesCore.@thunk(error("not implemented")), + ChainRulesCore.DoesNotExist(), polygamma(m + 1, x), ), ) diff --git a/test/chainrules.jl b/test/chainrules.jl index fedaf6c1..5dc23c4c 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,7 +1,7 @@ @testset "chainrules" begin Random.seed!(1) - @testset "general" begin + @testset "general: single input" begin for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im) test_scalar(erf, x) test_scalar(erfc, x) @@ -12,9 +12,6 @@ test_scalar(airybi, x) test_scalar(airybiprime, x) - test_scalar(besselj0, x) - test_scalar(besselj1, x) - test_scalar(erfcx, x) test_scalar(dawson, x) @@ -28,37 +25,74 @@ end if x isa Real && x > 0 || x isa Complex - test_scalar(bessely0, x) - test_scalar(bessely1, x) test_scalar(gamma, x) test_scalar(digamma, x) test_scalar(trigamma, x) end end + end + + @testset "Bessel functions" begin + for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) + test_scalar(besselj0, x) + test_scalar(besselj1, x) + + isreal(x) && x < 0 && continue + + test_scalar(bessely0, x) + test_scalar(bessely1, x) + + for nu in (-1.5, 2.2, 4.0) + test_frule(besseli, nu, x) + test_rrule(besseli, nu, x) - @testset "beta and logbeta" begin - test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im) - for _x in test_points, _y in test_points - # ensure all complex if any complex for FiniteDifferences - x, y = promote(_x, _y) - test_frule(beta, x, y) - test_rrule(beta, x, y) + test_frule(besselj, nu, x) + test_rrule(besselj, nu, x) - test_frule(logbeta, x, y) - test_rrule(logbeta, x, y) + test_frule(besselk, nu, x) + test_rrule(besselk, nu, x) + + test_frule(bessely, nu, x) + test_rrule(bessely, nu, x) + + # use complex numbers in `rrule` for FiniteDifferences + test_frule(hankelh1, nu, x) + test_rrule(hankelh1, nu, complex(x)) + + # use complex numbers in `rrule` for FiniteDifferences + test_frule(hankelh2, nu, x) + test_rrule(hankelh2, nu, complex(x)) end end + end - @testset "log gamma and co" begin - # It is important that we have negative numbers with both odd and even integer parts - for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) - isreal(x) && x < 0 && continue - test_scalar(loggamma, x) + @testset "beta and logbeta" begin + test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im) + for _x in test_points, _y in test_points + # ensure all complex if any complex for FiniteDifferences + x, y = promote(_x, _y) + test_frule(beta, x, y) + test_rrule(beta, x, y) - isreal(x) || continue - test_frule(logabsgamma, x) - test_rrule(logabsgamma, x; output_tangent=(randn(), randn())) + test_frule(logbeta, x, y) + test_rrule(logbeta, x, y) + end + end + + @testset "log gamma and co" begin + # It is important that we have negative numbers with both odd and even integer parts + for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im) + for m in (0, 1, 2, 3) + test_frule(polygamma, m, x) + test_rrule(polygamma, m, x) end + + isreal(x) && x < 0 && continue + test_scalar(loggamma, x) + + isreal(x) || continue + test_frule(logabsgamma, x) + test_rrule(logabsgamma, x; output_tangent=(randn(), randn())) end end end