diff --git a/Project.toml b/Project.toml index ec3b1115a..358f1ef83 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.21.1" [deps] CRlibm = "96374032-68de-5a5b-8d9e-752f78720389" +DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" FastRounding = "fa42c844-2597-5d31-933b-ebd51ab2693f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -13,6 +14,7 @@ SetRounding = "3cc68bcd-71a2-5612-b932-767ffbe40ab0" [compat] CRlibm = "0.7, 0.8, 1" +DiffRules = "1" EnumX = "1" FastRounding = "0.2, 0.3" RoundingEmulator = "0.2" diff --git a/src/IntervalArithmetic.jl b/src/IntervalArithmetic.jl index 9f116991e..2b29739b8 100644 --- a/src/IntervalArithmetic.jl +++ b/src/IntervalArithmetic.jl @@ -63,6 +63,8 @@ export bisect include("decorations/decorations.jl") export decoration, DecoratedInterval, com, dac, def, trv, ill +include("ad.jl") + include("rand.jl") include("parsing.jl") diff --git a/src/ad.jl b/src/ad.jl new file mode 100644 index 000000000..6b80868b0 --- /dev/null +++ b/src/ad.jl @@ -0,0 +1,11 @@ +function _abs_deriv(x::Interval{T}) where T + if inf(x) == 0 + return Interval{T}(1) + elseif sup(x) == 0 + return Interval{T}(-1, 1) + else + return sign(x) + end +end + +_abs_deriv(x::DecoratedInterval) = DecoratedInterval(_abs_deriv(interval(x)), decoration(x)) \ No newline at end of file diff --git a/test/interval_tests/ad.jl b/test/interval_tests/ad.jl new file mode 100644 index 000000000..85d427038 --- /dev/null +++ b/test/interval_tests/ad.jl @@ -0,0 +1,16 @@ +using IntervalArithmetic, ForwardDiff +using Test + +@testset "AD" begin + for F in (Interval, DecoratedInterval) + @test ForwardDiff.derivative(abs, F(-2.0 .. 2.0)) == F(-1.0 .. 1.0) + @test ForwardDiff.derivative(abs, F(1.0 .. 2.0)) == F(1.0 .. 1.0) + @test ForwardDiff.derivative(abs, F(-2.0 .. -1.0)) == F(-1.0 .. -1.0) + @test ForwardDiff.derivative(abs, F(0.0)) == F(1.0) + @test ForwardDiff.derivative(abs, F(-2.0 .. 0.0)) == F(-1.0 .. 1.0) + @test ForwardDiff.derivative(abs, F(0.0 .. 2.0)) == F(1.0) + + # Test proper handeling at abs(0) + @test ForwardDiff.hessian(t -> abs(t[1])^2, [F(0.0)])[1] == F(2.0) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index eb312a2d8..4864c9b8c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,8 @@ end include_test("interval_tests/intervals.jl") include_test("decoration_tests/decoration_tests.jl") +include_test("interval_tests/ad.jl") + include_test("rand.jl") # Display tests: @@ -24,3 +26,12 @@ include_test("multidim_tests/multidim.jl") # ITF1788 tests include_test("test_ITF1788/run_ITF1788.jl") # TODO fix these tests + +# Display tests: +include_test("display_tests/display.jl") + +# Multidim tests +include_test("multidim_tests/multidim.jl") + +# ITF1788 tests +include_test("test_ITF1788/run_ITF1788.jl") # TODO fix these tests