From e976fa27f0c33b15af207eb236895af8639ff2dd Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 28 Jan 2024 02:42:29 +0530 Subject: [PATCH 01/22] feat: add a dispatch on `PIDEProblem` where `f` (non linear function) is `nothing` --- src/HighDimPDE.jl | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index 5eb3568..f70b847 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -229,6 +229,57 @@ function ParabolicPDEProblem(μ, kwargs) end +""" +Defines a Kolmogorov Backward PDE : +## Arguments : +* `g` : terminal condition, of the form `g(x)`. +* `μ` : drift function, of the form `μ(x, p, t)`. +* `σ` : diffusion function `σ(x, p, t)`. +* `tspan`: timespan of the problem. +* `xspan`: the domain of system state +* `xdims`: the numner of the state variables + +## Keyword Arguments: +* `noise_rate_prototype` : Incase of a non diagonal noise, the prototype of `dx` in `σ` +""" +function PIDEProblem(g, + μ, + σ, + tspan, + xspan, + xdims; + p = nothing, + x0_sample = NoSampling(), + noise_rate_prototype = nothing, + kwargs...) + x = first(xspan) + x = fill(x, xdims, 1) + kwargs = merge(NamedTuple(kwargs), + (xspan = xspan, xdims = xdims, noise_rate_prototype = noise_rate_prototype)) + + PIDEProblem{typeof(g(x)), + typeof(g), + Nothing, + typeof(μ), + typeof(σ), + typeof(x), + eltype(tspan), + typeof(p), + typeof(x0_sample), + Nothing, + typeof(kwargs)}(g(x), + g, + nothing, + μ, + σ, + x, + tspan, + p, + x0_sample, + nothing, + kwargs) +end + struct PIDESolution{X0, Ts, L, Us, NNs, Ls} x0::X0 ts::Ts @@ -267,6 +318,7 @@ include("DeepBSDE.jl") include("DeepBSDE_Han.jl") include("MLP.jl") include("NNStopping.jl") +include("NNKolmogorov.jl") export PIDEProblem, ParabolicPDEProblem, PIDESolution, DeepSplitting, DeepBSDE, MLP, NNStopping From 5169183ce876bd2d9237e9dee0641938fb62ac53 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 28 Jan 2024 02:43:23 +0530 Subject: [PATCH 02/22] feat: add `NNKolmogorov` --- src/NNKolmogorov.jl | 108 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 src/NNKolmogorov.jl diff --git a/src/NNKolmogorov.jl b/src/NNKolmogorov.jl new file mode 100644 index 0000000..27a160e --- /dev/null +++ b/src/NNKolmogorov.jl @@ -0,0 +1,108 @@ +""" +Algorithm for solving Backward Kolmogorov Equations. + +```julia +NeuralPDE.NNKolmogorov(chain, opt, sdealg, ensemblealg ) +``` +Arguments: +- `chain`: A Chain neural network with a d-dimensional output. +- `opt`: The optimizer to train the neural network. Defaults to `ADAM(0.1)`. +- `sdealg`: The algorithm used to solve the discretized SDE according to the process that X follows. Defaults to `EM()`. +- `ensemblealg`: The algorithm used to solve the Ensemble Problem that performs Ensemble simulations for the SDE. Defaults to `EnsembleThreads()`. See + the [Ensemble Algorithms](https://diffeq.sciml.ai/stable/features/ensemble/#EnsembleAlgorithms-1) + documentation for more details. +- - `kwargs`: Additional arguments splatted to the SDE solver. See the + [Common Solver Arguments](https://diffeq.sciml.ai/dev/basics/common_solver_opts/) + documentation for more details. +[1]Beck, Christian, et al. "Solving stochastic differential equations and Kolmogorov equations by means of deep learning." arXiv preprint arXiv:1806.00421 (2018). +""" +struct NNKolmogorov{C, O} <: HighDimPDEAlgorithm + chain::C + opt::O +end +NNKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNKolmogorov(chain, opt) + +function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, + pdealg::HighDimPDE.NNKolmogorov, + sdealg; + ensemblealg = EnsembleThreads(), + abstol = 1.0f-6, + verbose = false, + maxiters = 300, + trajectories = 1000, + save_everystep = false, + use_gpu = false, + dt, + dx, + kwargs...) + tspan = prob.tspan + sigma = prob.σ + μ = prob.μ + noise_rate_prototype = prob.kwargs.noise_rate_prototype + phi = prob.g + + xspan = prob.kwargs.xspan + d = prob.kwargs.xdims + + ts = tspan[1]:dt:tspan[2] + xs = xspan[1]:dx:xspan[2] + N = size(ts) + T = tspan[2] + + #hidden layer + chain = pdealg.chain + opt = pdealg.opt + ps = Flux.params(chain) + xi = rand(xs, d, trajectories) + #Finding Solution to the SDE having initial condition xi. Y = Phi(S(X , T)) + sdeproblem = SDEProblem(μ, + sigma, + xi, + tspan, + noise_rate_prototype = noise_rate_prototype) + function prob_func(prob, i, repeat) + SDEProblem(prob.f, + xi[:, i], + prob.tspan, + noise_rate_prototype = prob.noise_rate_prototype) + end + output_func(sol, i) = (sol.u[end], false) + ensembleprob = EnsembleProblem(sdeproblem, + prob_func = prob_func, + output_func = output_func) + sim = solve(ensembleprob, + sdealg, + ensemblealg, + dt = dt, + trajectories = trajectories, + adaptive = false) + + x_sde = Array(sim) + + y = reduce(hcat, phi.(eachcol(x_sde))) + + if use_gpu == true + y = y |> gpu + xi = xi |> gpu + end + data = Iterators.repeated((xi, y), maxiters) + if use_gpu == true + data = data |> gpu + end + + #MSE Loss Function + loss(x, y) = Flux.mse(chain(x), y) + + losses = AbstractFloat[] + callback = function () + l = loss(xi, y) + verbose && println("Current loss is: $l") + push!(losses, l) + l < abstol && Flux.stop() + end + + Flux.train!(loss, ps, data, opt; cb = callback) + chainout = chain(xi) + xi, chainout + return PIDESolution(xi, ts, losses, chainout, chain, nothing) +end From 2eaf6c0db038d487fa9896706c8f540de9dcdef5 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 28 Jan 2024 02:44:00 +0530 Subject: [PATCH 03/22] test: add tests for NNKolmogorov --- test/NNKolmogorov.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 92 insertions(+) create mode 100644 test/NNKolmogorov.jl diff --git a/test/NNKolmogorov.jl b/test/NNKolmogorov.jl new file mode 100644 index 0000000..d1164c5 --- /dev/null +++ b/test/NNKolmogorov.jl @@ -0,0 +1,91 @@ +using Test, Flux, StochasticDiffEq +using HighDimPDE +using Distributions + +using Random +Random.seed!(100) + +# For a diract delta take u0 = Normal(0 , sigma) where sigma --> 0 +u0 = Normal(1.00, 1.00) +xspan = (-2.0, 6.0) +tspan = (0.0, 1.0) +σ(u, p, t) = 2.00 +μ(u, p, t) = -2.00 + +d = 1 +sdealg = EM() +g(x) = pdf(u0, x) +prob = PIDEProblem(g, μ, σ, tspan, xspan, d) +opt = Flux.ADAM(0.01) +m = Chain(Dense(1, 5, elu), Dense(5, 5, elu), Dense(5, 5, elu), Dense(5, 1)) +ensemblealg = EnsembleThreads() +sol = solve(prob, NNKolmogorov(m, opt), sdealg; ensemblealg = ensemblealg, verbose = true, + dt = 0.01, + abstol = 1e-10, dx = 0.0001, trajectories = 100000, maxiters = 500) + +## The solution is obtained taking the Fourier Transform. +analytical(xi) = pdf.(Normal(3, sqrt(1.0 + 5.00)), xi) +##Validation +xs = -5:0.00001:5 +x_1 = rand(xs, 1, 1000) +err_l2 = Flux.mse(analytical(x_1), sol.ufuns(x_1)) +@test err_l2 < 0.01 + +xspan = (-6.0, 6.0) +tspan = (0.0, 1.0) +σ(u, p, t) = 0.5 * u +μ(u, p, t) = 0.5 * 0.25 * u +d = 1 +function g(x) + 1.77 .* x .- 0.015 .* x .^ 3 +end + +sdealg = EM() +prob = PIDEProblem(g, μ, σ, tspan, xspan, d) +opt = Flux.ADAM(0.01) +m = Chain(Dense(1, 16, elu), Dense(16, 32, elu), Dense(32, 16, elu), Dense(16, 1)) +sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, + dx = 0.0001, trajectories = 1000, abstol = 1e-6, maxiters = 300) + +function analytical(xi) + y = Float64[] + a = 1.77 * exp(0.5 * (0.5)^2 * 1.0) + b = -0.015 * exp(0.5 * (0.5 * 3)^2 * 1.0) + for x in xi + y = push!(y, a * x + b * x^3) + end + y = reshape(y, size(xi)[1], size(xi)[2]) + return y +end +xs = -5.00:0.01:5.00 +x_val = rand(xs, d, 50) +errorl2 = Flux.mse(analytical(x_val), sol.ufuns(x_val)) +println("error_l2 = ", errorl2, "\n") +@test errorl2 < 0.4 + +##Non-Diagonal Test +μ_noise = (du, u, p, t) -> du .= 1.01u +σ_noise = function (du, u, p, t) + du[1, 1] = 0.3u[1] + du[1, 2] = 0.6u[1] + du[1, 3] = 0.9u[1] + du[1, 4] = 0.12u[2] + du[2, 1] = 1.2u[1] + du[2, 2] = 0.2u[2] + du[2, 3] = 0.3u[2] + du[2, 4] = 1.8u[2] +end +Σ = [1.0 0.3; 0.3 1.0] +uo3 = MvNormal([0.0; 0.0], Σ) +g(x) = pdf(uo3, x) + +sdealg = EM() +xspan = (-10.0, 10.0) +tspan = (0.0, 1.0) +d = 2 +prob = PIDEProblem(g, μ_noise, σ_noise, tspan, xspan, d; noise_rate_prototype = zeros(2, 4)) +opt = Flux.ADAM(0.01) +m = Chain(Dense(d, 32, elu), Dense(32, 64, elu), Dense(64, 1)) +sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.001, + abstol = 1e-6, dx = 0.001, trajectories = 1000, maxiters = 200) +println("Non-Diagonal test working.") diff --git a/test/runtests.jl b/test/runtests.jl index 016c8fc..3f5bfe6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,5 @@ using SafeTestsets, Test @time @safetestset "Deep Splitting" include("DeepSplitting.jl") @time @safetestset "MC Sample" include("MCSample.jl") @time @safetestset "NNStopping" include("NNStopping.jl") + @time @safetestset "NNKolmogorov" include("NNKolmogorov.jl") end From 6ff67d17fa6014615234d924d27d70000e33b525 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 28 Jan 2024 02:44:56 +0530 Subject: [PATCH 04/22] fix: add `Distributions` to test deps --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 98e8793..9b9e32a 100644 --- a/Project.toml +++ b/Project.toml @@ -44,8 +44,9 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "SafeTestsets"] +test = ["Aqua", "Distributions", "Test", "SafeTestsets"] From ecbe769f42d96fee3a9b14b0e58397fb7e9de150 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 28 Jan 2024 10:50:21 +0530 Subject: [PATCH 05/22] fix: remove need for xdims. allow multiple domains for `x` --- src/HighDimPDE.jl | 11 ++++------- src/NNKolmogorov.jl | 10 +++++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index f70b847..45e17c6 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -236,8 +236,7 @@ Defines a Kolmogorov Backward PDE : * `μ` : drift function, of the form `μ(x, p, t)`. * `σ` : diffusion function `σ(x, p, t)`. * `tspan`: timespan of the problem. -* `xspan`: the domain of system state -* `xdims`: the numner of the state variables +* `xspan`: the domain of system state. This can be a tuple of floats for single dimension, and a vector of tuples for multiple dimensions. Where each tuple corresponds to a dimension of state vector. ## Keyword Arguments: * `noise_rate_prototype` : Incase of a non diagonal noise, the prototype of `dx` in `σ` @@ -246,16 +245,14 @@ function PIDEProblem(g, μ, σ, tspan, - xspan, - xdims; + xspan; p = nothing, x0_sample = NoSampling(), noise_rate_prototype = nothing, kwargs...) - x = first(xspan) - x = fill(x, xdims, 1) + x = isa(xspan, Vector) ? first.(xspan) : first(xspan) kwargs = merge(NamedTuple(kwargs), - (xspan = xspan, xdims = xdims, noise_rate_prototype = noise_rate_prototype)) + (xspan = xspan, noise_rate_prototype = noise_rate_prototype)) PIDEProblem{typeof(g(x)), typeof(g), diff --git a/src/NNKolmogorov.jl b/src/NNKolmogorov.jl index 27a160e..395312f 100644 --- a/src/NNKolmogorov.jl +++ b/src/NNKolmogorov.jl @@ -42,10 +42,14 @@ function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, phi = prob.g xspan = prob.kwargs.xspan - d = prob.kwargs.xdims + xspans = isa(xspan, Tuple) ? [xspan] : xspan + + d = length(xspans) ts = tspan[1]:dt:tspan[2] - xs = xspan[1]:dx:xspan[2] + xs = map(xspans) do xspan + xspan[1]:dx:xspan[2] + end N = size(ts) T = tspan[2] @@ -53,7 +57,7 @@ function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, chain = pdealg.chain opt = pdealg.opt ps = Flux.params(chain) - xi = rand(xs, d, trajectories) + xi = mapreduce(x -> rand(x, 1, trajectories), vcat, xs) #Finding Solution to the SDE having initial condition xi. Y = Phi(S(X , T)) sdeproblem = SDEProblem(μ, sigma, From 4426a9505f64d954da81427e10c273ff9a81424c Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 28 Jan 2024 10:51:05 +0530 Subject: [PATCH 06/22] test: update tests --- test/NNKolmogorov.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/NNKolmogorov.jl b/test/NNKolmogorov.jl index d1164c5..b121e51 100644 --- a/test/NNKolmogorov.jl +++ b/test/NNKolmogorov.jl @@ -15,7 +15,7 @@ tspan = (0.0, 1.0) d = 1 sdealg = EM() g(x) = pdf(u0, x) -prob = PIDEProblem(g, μ, σ, tspan, xspan, d) +prob = PIDEProblem(g, μ, σ, tspan, xspan) opt = Flux.ADAM(0.01) m = Chain(Dense(1, 5, elu), Dense(5, 5, elu), Dense(5, 5, elu), Dense(5, 1)) ensemblealg = EnsembleThreads() @@ -35,13 +35,13 @@ xspan = (-6.0, 6.0) tspan = (0.0, 1.0) σ(u, p, t) = 0.5 * u μ(u, p, t) = 0.5 * 0.25 * u -d = 1 + function g(x) 1.77 .* x .- 0.015 .* x .^ 3 end sdealg = EM() -prob = PIDEProblem(g, μ, σ, tspan, xspan, d) +prob = PIDEProblem(g, μ, σ, tspan, xspan) opt = Flux.ADAM(0.01) m = Chain(Dense(1, 16, elu), Dense(16, 32, elu), Dense(32, 16, elu), Dense(16, 1)) sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, @@ -80,10 +80,10 @@ uo3 = MvNormal([0.0; 0.0], Σ) g(x) = pdf(uo3, x) sdealg = EM() -xspan = (-10.0, 10.0) +xspan = [(-10.0, 10.0), (-10.0, 10.0)] tspan = (0.0, 1.0) +prob = PIDEProblem(g, μ_noise, σ_noise, tspan, xspan; noise_rate_prototype = zeros(2, 4)) d = 2 -prob = PIDEProblem(g, μ_noise, σ_noise, tspan, xspan, d; noise_rate_prototype = zeros(2, 4)) opt = Flux.ADAM(0.01) m = Chain(Dense(d, 32, elu), Dense(32, 64, elu), Dense(64, 1)) sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.001, From 7c41b3a7aaf39d7c54029e6d1ccb58e159e368e4 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 30 Jan 2024 11:29:45 +0530 Subject: [PATCH 07/22] feat: add NNParamKolmogorov alg --- src/NNParamKolmogorov.jl | 138 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 src/NNParamKolmogorov.jl diff --git a/src/NNParamKolmogorov.jl b/src/NNParamKolmogorov.jl new file mode 100644 index 0000000..ab063bb --- /dev/null +++ b/src/NNParamKolmogorov.jl @@ -0,0 +1,138 @@ + +struct NNParamKolmogorov{C, O} <: HighDimPDEAlgorithm + chain::C + opt::O +end + +NNParamKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNParamKolmogorov(chain, opt) + +function DiffEqBase.solve(prob::PIDEProblem, + pdealg::NNParamKolmogorov, + sdealg = EM(); + ensemblealg = EnsembleThreads(), + abstol = 1.0f-6, + verbose = false, + maxiters = 300, + trajectories = 1000, + save_everystep = false, + use_gpu = false, + dps = (0.01,), + dt, + dx, + kwargs...) + tspan = prob.tspan + sigma = prob.σ + mu = prob.μ + noise_rate_prototype = get(prob.kwargs, :noise_rate_prototype, nothing) + + p_domain = prob.kwargs.p_domain + xspan = prob.kwargs.xspan + + xspans = isa(xspan, Tuple) ? [xspan] : xspan + d = length(xspans) + + phi = prob.g + + ts = tspan[1]:dt:tspan[2] + xs = map(xspans) do xspan + xspan[1]:dx:xspan[2] + end + + p_domain = prob.kwargs.p_domain + p_prototype = prob.kwargs.p_prototype + + chain = pdealg.chain + ps = Flux.params(chain) + opt = pdealg.opt + + xi = mapreduce(x -> rand(x, 1, trajectories), vcat, xs) + ti = rand(ts, 1, trajectories) + + ps_sigma, ps_mu, ps_phi = map(zip(p_domain, + p_prototype, + dps)) do (domain, prototype, dp) + # domain , prototype, dp = p_domain[key], p_prototype[key], dps[key] + isnothing(domain) && return + return rand(domain[1]:dp:domain[2], size(prototype)..., trajectories) + end + + total_dims = mapreduce(*, (ti, xi, ps_mu, ps_sigma, ps_phi)) do y + isnothing(y) && return 1 + *(size(y)[1:(end - 1)]...) + end + + train_data = mapreduce(vcat, (ti, xi, ps_mu, ps_sigma, ps_phi)) do y + isnothing(y) && return rand(0, trajectories) # empty matrix + reshape(y, :, trajectories) + end + + ps_sigma_iterator = !isnothing(ps_sigma) ? + eachslice(ps_sigma, dims = length(size(ps_sigma))) : + collect(Iterators.repeated(nothing, trajectories)) + ps_mu_iterator = !isnothing(ps_mu) ? eachslice(ps_mu, dims = length(size(ps_mu))) : + collect(Iterators.repeated(nothing, trajectories)) + ps_phi_iterator = !isnothing(ps_phi) ? eachslice(ps_phi, dims = length(size(ps_phi))) : + collect(Iterators.repeated(nothing, trajectories)) + # return xi, ti, ps_sigma_iterator[1] + prob_func = (prob, i, repeat) -> begin + sigma_(dx, x, p, t) = sigma(dx, x, ps_sigma_iterator[i], t) + mu_(dx, x, p, t) = mu(dx, x, ps_mu_iterator[i], t) + SDEProblem(mu_, + sigma_, + xi[:, i], + (tspan[1], ti[:, 1][1]), + noise_rate_prototype = noise_rate_prototype) + end + + output_func = (sol, i) -> (sol[end], false) + + sdeprob = SDEProblem(mu, + sigma, + xi[:, 1], + tspan; + noise_rate_prototype = noise_rate_prototype) + + ensembleprob = EnsembleProblem(sdeprob, + prob_func = prob_func, + output_func = output_func) + # return train_data , ps_sigma_iterator, ps_mu_iterator + sol = solve(ensembleprob, sdealg, ensemblealg; trajectories = trajectories, dt = dt) + # return train_data, sol + # Y = reduce(hcat, phi.(eachcol(x_sde))) + Y = reduce(hcat, phi.(eachcol(Array(sol)), ps_phi_iterator)) + if use_gpu == true + Y = Y |> gpu + train_data = train_data |> gpu + end + + data = Iterators.repeated((train_data, Y), maxiters) + if use_gpu == true + data = data |> gpu + end + + #MSE Loss Function + loss(x, y) = Flux.mse(chain(x), y) + + losses = AbstractFloat[] + callback = function () + l = loss(train_data, Y) + verbose && println("Current loss is: $l") + push!(losses, l) + l < abstol && Flux.stop() + end + + Flux.train!(loss, ps, data, opt; cb = callback) + + sol_func = (x0, t, _p_sigma, _p_mu, _p_phi) -> begin + ps = map(zip(p_prototype, (_p_sigma, _p_mu, _p_phi))) do (prototype, p) + @assert typeof(prototype) == typeof(p) + !isnothing(prototype) && return reshape(p, :, 1) + return nothing + end + ps = filter(x -> !isnothing(x), ps) + chain(vcat(reshape(t, :, 1), reshape(x0, :, 1), ps...)) + end + + train_out = chain(train_data) + PIDESolution(xi, ts, losses, train_out, sol_func, nothing) +end #solve From cde9de875734d66f43481276f0c0c544146651f2 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 30 Jan 2024 11:30:08 +0530 Subject: [PATCH 08/22] fix: update PIDEProblem dispatch --- src/HighDimPDE.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index 45e17c6..6828b8d 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -254,7 +254,16 @@ function PIDEProblem(g, kwargs = merge(NamedTuple(kwargs), (xspan = xspan, noise_rate_prototype = noise_rate_prototype)) - PIDEProblem{typeof(g(x)), + g_ = try + g(x) + catch e + if e isa MethodError + g(x, kwargs[:p_domain].p_phi) + else + throw(e) + end + end + PIDEProblem{typeof(g_), typeof(g), Nothing, typeof(μ), @@ -264,7 +273,7 @@ function PIDEProblem(g, typeof(p), typeof(x0_sample), Nothing, - typeof(kwargs)}(g(x), + typeof(kwargs)}(g_, g, nothing, μ, @@ -316,6 +325,7 @@ include("DeepBSDE_Han.jl") include("MLP.jl") include("NNStopping.jl") include("NNKolmogorov.jl") +include("NNParamKolmogorov.jl") export PIDEProblem, ParabolicPDEProblem, PIDESolution, DeepSplitting, DeepBSDE, MLP, NNStopping From 192b3f5dd88d57b34b583e722c01009d7a19e0eb Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 30 Jan 2024 11:30:52 +0530 Subject: [PATCH 09/22] test: add tests for NNParamKolmogorov --- test/NNParamKolmogorov.jl | 63 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 64 insertions(+) create mode 100644 test/NNParamKolmogorov.jl diff --git a/test/NNParamKolmogorov.jl b/test/NNParamKolmogorov.jl new file mode 100644 index 0000000..0751326 --- /dev/null +++ b/test/NNParamKolmogorov.jl @@ -0,0 +1,63 @@ +using Test, Flux +using StochasticDiffEq +using LinearAlgebra +using HighDimPDE +using Random +Random.seed!(100) + +d = 1 +m = Chain(Dense(3, 16, tanh), Dense(16, 16, tanh), Dense(16, 5, tanh), Dense(5, 1)) +ensemblealg = EnsembleThreads() +γ_mu_prototype = nothing +γ_sigma_prototype = zeros(d, d, 1) +γ_phi_prototype = nothing + +sdealg = EM() +tspan = (0.00, 1.00) +trajectories = 10000 +function phi(x, y_phi) + x .^ 2 +end +sigma(dx, x, γ_sigma, t) = dx .= γ_sigma[:, :, 1] +mu(dx, x, γ_mu, t) = dx .= 0.00 + +xspan = (0.00, 3.00) + +p_domain = (p_sigma = (0.00, 2.00), p_mu = nothing, p_phi = nothing) +p_prototype = (p_sigma = γ_sigma_prototype, p_mu = γ_mu_prototype, p_phi = γ_phi_prototype) +dps = (p_sigma = 0.01, p_mu = nothing, p_phi = nothing) + +dt = 0.01 +dx = 0.01 +opt = Flux.ADAM(1e-2) + +prob = PIDEProblem(phi, + mu, + sigma, + tspan, + xspan; + p_domain = p_domain, + p_prototype = p_prototype) + +sol = solve(prob, HighDimPDE.NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, + abstol = 1e-10, dx = 0.01, trajectories = trajectories, maxiters = 1000, + use_gpu = false, dps = dps) + +x_test = rand(xspan[1]:dx:xspan[2], d, 1, 1000) +t_test = rand(tspan[1]:dt:tspan[2], 1, 1000) +γ_sigma_test = rand(0.3:(dps.p_sigma):0.5, d, d, 1, 1000) + +function analytical(x, t, y) + return x .^ 2 .+ t .* (y .* y) +end + +preds = map((i) -> sol.ufuns(x_test[:, :, i], + t_test[:, i], + γ_sigma_test[:, :, :, i], + nothing, + nothing), + 1:1000) +y_test = map((i) -> analytical(x_test[:, :, i], t_test[:, i], γ_sigma_test[:, :, :, i]), + 1:1000) + +@test Flux.mse(reduce(hcat, preds), reduce(hcat, y_test)) < 0.1 diff --git a/test/runtests.jl b/test/runtests.jl index 3f5bfe6..35af2ad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,4 +9,5 @@ using SafeTestsets, Test @time @safetestset "MC Sample" include("MCSample.jl") @time @safetestset "NNStopping" include("NNStopping.jl") @time @safetestset "NNKolmogorov" include("NNKolmogorov.jl") + @time @safetestset "NNKolmogorov" include("NNParamKolmogorov.jl") end From beadd2a6c65445a4c3696183f33bef1cbe53ae7c Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Thu, 1 Feb 2024 23:24:44 +0530 Subject: [PATCH 10/22] fix: remove dispatch on PIDEProblem --- src/HighDimPDE.jl | 57 ----------------------------------------------- 1 file changed, 57 deletions(-) diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index 6828b8d..515c317 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -229,63 +229,6 @@ function ParabolicPDEProblem(μ, kwargs) end -""" -Defines a Kolmogorov Backward PDE : -## Arguments : -* `g` : terminal condition, of the form `g(x)`. -* `μ` : drift function, of the form `μ(x, p, t)`. -* `σ` : diffusion function `σ(x, p, t)`. -* `tspan`: timespan of the problem. -* `xspan`: the domain of system state. This can be a tuple of floats for single dimension, and a vector of tuples for multiple dimensions. Where each tuple corresponds to a dimension of state vector. - -## Keyword Arguments: -* `noise_rate_prototype` : Incase of a non diagonal noise, the prototype of `dx` in `σ` -""" -function PIDEProblem(g, - μ, - σ, - tspan, - xspan; - p = nothing, - x0_sample = NoSampling(), - noise_rate_prototype = nothing, - kwargs...) - x = isa(xspan, Vector) ? first.(xspan) : first(xspan) - kwargs = merge(NamedTuple(kwargs), - (xspan = xspan, noise_rate_prototype = noise_rate_prototype)) - - g_ = try - g(x) - catch e - if e isa MethodError - g(x, kwargs[:p_domain].p_phi) - else - throw(e) - end - end - PIDEProblem{typeof(g_), - typeof(g), - Nothing, - typeof(μ), - typeof(σ), - typeof(x), - eltype(tspan), - typeof(p), - typeof(x0_sample), - Nothing, - typeof(kwargs)}(g_, - g, - nothing, - μ, - σ, - x, - tspan, - p, - x0_sample, - nothing, - kwargs) -end - struct PIDESolution{X0, Ts, L, Us, NNs, Ls} x0::X0 ts::Ts From 541a11c838c0bd213995dbc7d8c14fcba27af331 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Thu, 1 Feb 2024 23:54:53 +0530 Subject: [PATCH 11/22] test: change Flux.ADAM to Flux.Optimisers.Adam --- test/NNKolmogorov.jl | 6 +++--- test/NNParamKolmogorov.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/NNKolmogorov.jl b/test/NNKolmogorov.jl index b121e51..9bd58b7 100644 --- a/test/NNKolmogorov.jl +++ b/test/NNKolmogorov.jl @@ -16,7 +16,7 @@ d = 1 sdealg = EM() g(x) = pdf(u0, x) prob = PIDEProblem(g, μ, σ, tspan, xspan) -opt = Flux.ADAM(0.01) +opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(1, 5, elu), Dense(5, 5, elu), Dense(5, 5, elu), Dense(5, 1)) ensemblealg = EnsembleThreads() sol = solve(prob, NNKolmogorov(m, opt), sdealg; ensemblealg = ensemblealg, verbose = true, @@ -42,7 +42,7 @@ end sdealg = EM() prob = PIDEProblem(g, μ, σ, tspan, xspan) -opt = Flux.ADAM(0.01) +opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(1, 16, elu), Dense(16, 32, elu), Dense(32, 16, elu), Dense(16, 1)) sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, dx = 0.0001, trajectories = 1000, abstol = 1e-6, maxiters = 300) @@ -84,7 +84,7 @@ xspan = [(-10.0, 10.0), (-10.0, 10.0)] tspan = (0.0, 1.0) prob = PIDEProblem(g, μ_noise, σ_noise, tspan, xspan; noise_rate_prototype = zeros(2, 4)) d = 2 -opt = Flux.ADAM(0.01) +opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(d, 32, elu), Dense(32, 64, elu), Dense(64, 1)) sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.001, abstol = 1e-6, dx = 0.001, trajectories = 1000, maxiters = 200) diff --git a/test/NNParamKolmogorov.jl b/test/NNParamKolmogorov.jl index 0751326..8aadcd9 100644 --- a/test/NNParamKolmogorov.jl +++ b/test/NNParamKolmogorov.jl @@ -29,7 +29,7 @@ dps = (p_sigma = 0.01, p_mu = nothing, p_phi = nothing) dt = 0.01 dx = 0.01 -opt = Flux.ADAM(1e-2) +opt = Flux.Optimisers.Adam(1e-2) prob = PIDEProblem(phi, mu, From ac6bde7c4b95a88fd6873ebf1e1ad337f565c4b1 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 4 Feb 2024 21:51:59 +0530 Subject: [PATCH 12/22] fix: update NNKolmogorov and NNParamKolmogorov to ParabolicPDEProblem --- src/NNKolmogorov.jl | 36 ++++++++++++++++++------------------ src/NNParamKolmogorov.jl | 30 +++++++++++++----------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/NNKolmogorov.jl b/src/NNKolmogorov.jl index 395312f..92d73c1 100644 --- a/src/NNKolmogorov.jl +++ b/src/NNKolmogorov.jl @@ -22,7 +22,7 @@ struct NNKolmogorov{C, O} <: HighDimPDEAlgorithm end NNKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNKolmogorov(chain, opt) -function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, +function DiffEqBase.solve(prob::ParabolicPDEProblem, pdealg::HighDimPDE.NNKolmogorov, sdealg; ensemblealg = EnsembleThreads(), @@ -38,7 +38,8 @@ function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, tspan = prob.tspan sigma = prob.σ μ = prob.μ - noise_rate_prototype = prob.kwargs.noise_rate_prototype + + noise_rate_prototype = get(prob.kwargs, :noise_rate_prototype, nothing) phi = prob.g xspan = prob.kwargs.xspan @@ -61,9 +62,10 @@ function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, #Finding Solution to the SDE having initial condition xi. Y = Phi(S(X , T)) sdeproblem = SDEProblem(μ, sigma, - xi, + xi[:, 1], tspan, noise_rate_prototype = noise_rate_prototype) + function prob_func(prob, i, repeat) SDEProblem(prob.f, xi[:, i], @@ -85,27 +87,25 @@ function DiffEqBase.solve(prob::Union{PIDEProblem, SDEProblem}, y = reduce(hcat, phi.(eachcol(x_sde))) - if use_gpu == true - y = y |> gpu - xi = xi |> gpu - end - data = Iterators.repeated((xi, y), maxiters) - if use_gpu == true - data = data |> gpu - end + y = use_gpu ? y |> gpu : y + xi = use_gpu ? xi |> gpu : xi #MSE Loss Function - loss(x, y) = Flux.mse(chain(x), y) + loss(m, x, y) = Flux.mse(m(x), y) losses = AbstractFloat[] - callback = function () - l = loss(xi, y) - verbose && println("Current loss is: $l") + + opt_state = Flux.setup(opt, chain) + for epoch in 1:maxiters + gs = Flux.gradient(chain) do model + loss(model, xi, y) + end + Flux.update!(opt_state, chain, gs[1]) + l = loss(chain, xi, y) + @info "Current Epoch: $epoch Current Loss: $l" push!(losses, l) - l < abstol && Flux.stop() end - - Flux.train!(loss, ps, data, opt; cb = callback) + # Flux.train!(loss, chain, data, opt; cb = callback) chainout = chain(xi) xi, chainout return PIDESolution(xi, ts, losses, chainout, chain, nothing) diff --git a/src/NNParamKolmogorov.jl b/src/NNParamKolmogorov.jl index ab063bb..b55aae3 100644 --- a/src/NNParamKolmogorov.jl +++ b/src/NNParamKolmogorov.jl @@ -6,7 +6,7 @@ end NNParamKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNParamKolmogorov(chain, opt) -function DiffEqBase.solve(prob::PIDEProblem, +function DiffEqBase.solve(prob::ParabolicPDEProblem, pdealg::NNParamKolmogorov, sdealg = EM(); ensemblealg = EnsembleThreads(), @@ -100,29 +100,25 @@ function DiffEqBase.solve(prob::PIDEProblem, # return train_data, sol # Y = reduce(hcat, phi.(eachcol(x_sde))) Y = reduce(hcat, phi.(eachcol(Array(sol)), ps_phi_iterator)) - if use_gpu == true - Y = Y |> gpu - train_data = train_data |> gpu - end - - data = Iterators.repeated((train_data, Y), maxiters) - if use_gpu == true - data = data |> gpu - end + Y = use_gpu ? Y |> gpu : Y + train_data = use_gpu ? train_data |> gpu : train_data #MSE Loss Function - loss(x, y) = Flux.mse(chain(x), y) + loss(m, x, y) = Flux.mse(m(x), y) losses = AbstractFloat[] - callback = function () - l = loss(train_data, Y) - verbose && println("Current loss is: $l") + + opt_state = Flux.setup(opt, chain) + for epoch in 1:maxiters + gs = Flux.gradient(chain) do model + loss(model, train_data, Y) + end + Flux.update!(opt_state, chain, gs[1]) + l = loss(chain, train_data, Y) + @info "Current Epoch: $epoch Current Loss: $l" push!(losses, l) - l < abstol && Flux.stop() end - Flux.train!(loss, ps, data, opt; cb = callback) - sol_func = (x0, t, _p_sigma, _p_mu, _p_phi) -> begin ps = map(zip(p_prototype, (_p_sigma, _p_mu, _p_phi))) do (prototype, p) @assert typeof(prototype) == typeof(p) From e1185aa51364a3bec396fa8a84050822b39d0a9c Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Sun, 4 Feb 2024 21:52:56 +0530 Subject: [PATCH 13/22] test: update tests for NNKolmogorov and NNParamKolmogorov --- test/NNKolmogorov.jl | 12 +++++++++--- test/NNParamKolmogorov.jl | 9 +++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/test/NNKolmogorov.jl b/test/NNKolmogorov.jl index 9bd58b7..0be5266 100644 --- a/test/NNKolmogorov.jl +++ b/test/NNKolmogorov.jl @@ -15,7 +15,7 @@ tspan = (0.0, 1.0) d = 1 sdealg = EM() g(x) = pdf(u0, x) -prob = PIDEProblem(g, μ, σ, tspan, xspan) +prob = ParabolicPDEProblem(μ, σ, nothing, tspan; g, xspan) opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(1, 5, elu), Dense(5, 5, elu), Dense(5, 5, elu), Dense(5, 1)) ensemblealg = EnsembleThreads() @@ -41,7 +41,7 @@ function g(x) end sdealg = EM() -prob = PIDEProblem(g, μ, σ, tspan, xspan) +prob = ParabolicPDEProblem(μ, σ, nothing, tspan; g, xspan) opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(1, 16, elu), Dense(16, 32, elu), Dense(32, 16, elu), Dense(16, 1)) sol = solve(prob, NNKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, @@ -82,7 +82,13 @@ g(x) = pdf(uo3, x) sdealg = EM() xspan = [(-10.0, 10.0), (-10.0, 10.0)] tspan = (0.0, 1.0) -prob = PIDEProblem(g, μ_noise, σ_noise, tspan, xspan; noise_rate_prototype = zeros(2, 4)) +prob = ParabolicPDEProblem(μ_noise, + σ_noise, + nothing, + tspan; + g, + xspan, + noise_rate_prototype = zeros(2, 4)) d = 2 opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(d, 32, elu), Dense(32, 64, elu), Dense(64, 1)) diff --git a/test/NNParamKolmogorov.jl b/test/NNParamKolmogorov.jl index 8aadcd9..6944198 100644 --- a/test/NNParamKolmogorov.jl +++ b/test/NNParamKolmogorov.jl @@ -31,11 +31,12 @@ dt = 0.01 dx = 0.01 opt = Flux.Optimisers.Adam(1e-2) -prob = PIDEProblem(phi, - mu, +prob = ParabolicPDEProblem(mu, sigma, - tspan, - xspan; + nothing, + tspan; + g = phi, + xspan, p_domain = p_domain, p_prototype = p_prototype) From 6999b39e4cea5f328439d43a2c1a409e8a3d82f7 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Wed, 7 Feb 2024 23:47:58 +0530 Subject: [PATCH 14/22] chore: export NNKolmogorov and NNParamKolmogorov --- src/HighDimPDE.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index 515c317..4f3a525 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -271,6 +271,6 @@ include("NNKolmogorov.jl") include("NNParamKolmogorov.jl") export PIDEProblem, ParabolicPDEProblem, PIDESolution, DeepSplitting, DeepBSDE, MLP, NNStopping - +export NNKolmogorov, NNParamKolmogorov export NormalSampling, UniformSampling, NoSampling, solve end From b8f2f455ee88bff1718b48c89d06d0cefacf45e6 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Wed, 7 Feb 2024 23:48:23 +0530 Subject: [PATCH 15/22] test: update tests --- test/NNParamKolmogorov.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/NNParamKolmogorov.jl b/test/NNParamKolmogorov.jl index 6944198..06377bc 100644 --- a/test/NNParamKolmogorov.jl +++ b/test/NNParamKolmogorov.jl @@ -40,7 +40,7 @@ prob = ParabolicPDEProblem(mu, p_domain = p_domain, p_prototype = p_prototype) -sol = solve(prob, HighDimPDE.NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, +sol = solve(prob, NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, abstol = 1e-10, dx = 0.01, trajectories = trajectories, maxiters = 1000, use_gpu = false, dps = dps) From 934cc85d347ad6083082c4e98556c87d1360d895 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Wed, 7 Feb 2024 23:49:04 +0530 Subject: [PATCH 16/22] test: fix typo --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 35af2ad..e8f7c0d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,5 +9,5 @@ using SafeTestsets, Test @time @safetestset "MC Sample" include("MCSample.jl") @time @safetestset "NNStopping" include("NNStopping.jl") @time @safetestset "NNKolmogorov" include("NNKolmogorov.jl") - @time @safetestset "NNKolmogorov" include("NNParamKolmogorov.jl") + @time @safetestset "NNParamKolmogorov" include("NNParamKolmogorov.jl") end From d7da58e9397dc30499caf5215f3ca7500007a8ff Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Fri, 9 Feb 2024 14:45:10 +0530 Subject: [PATCH 17/22] build: add compat for test dependency `Distributions` --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 9b9e32a..345acd1 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" Aqua = "0.8" CUDA = "4.4, 5" DiffEqBase = "6.137" +Distributions = "v0.25.107" DocStringExtensions = "0.9" Flux = "0.13.12, 0.14" Functors = "0.4" From 691304dc663d8e8fe47452e8cf76746ac6ff0b9b Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 12 Feb 2024 22:34:12 +0530 Subject: [PATCH 18/22] docs: add docstrings and docs for NNKolmogorov and NNParamKolmogorov --- docs/pages.jl | 6 ++- docs/src/NNKolmogorov.md | 24 ++++++++++ docs/src/NNParamKolmogorov.md | 24 ++++++++++ docs/src/tutorials/nnkolmogorov.md | 33 +++++++++++++ docs/src/tutorials/nnparamkolmogorov.md | 61 +++++++++++++++++++++++++ src/NNKolmogorov.jl | 26 +++++++---- src/NNParamKolmogorov.jl | 37 +++++++++++++-- 7 files changed, 199 insertions(+), 12 deletions(-) create mode 100644 docs/src/NNKolmogorov.md create mode 100644 docs/src/NNParamKolmogorov.md create mode 100644 docs/src/tutorials/nnkolmogorov.md create mode 100644 docs/src/tutorials/nnparamkolmogorov.md diff --git a/docs/pages.jl b/docs/pages.jl index 5a0650e..e733586 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -5,12 +5,16 @@ pages = [ "Solver Algorithms" => ["MLP.md", "DeepSplitting.md", "DeepBSDE.md", - "NNStopping.md"], + "NNStopping.md", + "NNKolmogorov.md", + "NNParamKolmogorov.md"], "Tutorials" => [ "tutorials/deepsplitting.md", "tutorials/deepbsde.md", "tutorials/mlp.md", "tutorials/nnstopping.md", + "tutorials/nnkolmogorov.md", + "tutorials/nnparamkolmogorov.md", ], "Feynman Kac formula" => "Feynman_Kac.md", ] diff --git a/docs/src/NNKolmogorov.md b/docs/src/NNKolmogorov.md new file mode 100644 index 0000000..61ce630 --- /dev/null +++ b/docs/src/NNKolmogorov.md @@ -0,0 +1,24 @@ +# [The `NNKolmogorov` algorithm](@id nn_komogorov) + +### Problems Supported: +1. [`ParabolicPDEProblem`](@ref) + +```@autodocs +Modules = [HighDimPDE] +Pages = ["NNKolmogorov.jl"] +``` + +`NNKolmogorov` obtains a terminal solution for Backward Kolmogorov Equations of the form: +```math +\partial_t u(t,x) = \mu(t, x) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) +``` +with initial condition given by `g(x)` + +We can use the Feynman-Kac formula : +```math +S_t^x = \int_{0}^{t}\mu(S_s^x)ds + \int_{0}^{t}\sigma(S_s^x)dB_s +``` +And the solution is given by: +```math +f(T, x) = \mathbb{E}[g(S_T^x)] +``` \ No newline at end of file diff --git a/docs/src/NNParamKolmogorov.md b/docs/src/NNParamKolmogorov.md new file mode 100644 index 0000000..4b865ca --- /dev/null +++ b/docs/src/NNParamKolmogorov.md @@ -0,0 +1,24 @@ +# [The `NNParamKolmogorov` algorithm](@id nn_komogorov) + +### Problems Supported: +1. [`ParabolicPDEProblem`](@ref) + +```@autodocs +Modules = [HighDimPDE] +Pages = ["NNParamKolmogorov.jl"] +``` + +`NNParamKolmogorov` obtains a terminal solution for parametric families of Backward Kolmogorov Equations of the form: +```math +\partial_t u(t,x) = \mu(t, x, γ_mu) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x, γ_sigma) \Delta_x u(t,x) +``` +with initial condition given by `g(x, γ_phi)` + +We can use the Feynman-Kac formula : +```math +S_t^x = \int_{0}^{t}\mu(S_s^x)ds + \int_{0}^{t}\sigma(S_s^x)dB_s +``` +And the solution is given by: +```math +f(T, x) = \mathbb{E}[g(S_T^x, γ_phi)] +``` \ No newline at end of file diff --git a/docs/src/tutorials/nnkolmogorov.md b/docs/src/tutorials/nnkolmogorov.md new file mode 100644 index 0000000..9f35f95 --- /dev/null +++ b/docs/src/tutorials/nnkolmogorov.md @@ -0,0 +1,33 @@ +# `NNKolmogorov` + +## Solving high dimensional Rainbow European Options for a range of initial stock prices: + +```julia +d = 10 # dims +T = 1/12 +sigma = 0.01 .+ 0.03.*Matrix(Diagonal(ones(d))) # volatility +mu = 0.06 # interest rate +K = 100.0 # strike price +function μ_func(du, u, p, t) + du .= mu*u +end + +function σ_func(du, u, p, t) + du .= sigma * u +end + +tspan = (0.0, T) +# The range for initial stock price +xspan = [(98.00, 102.00) for i in 1:d] + +g(x) = max(maximum(x) -K, 0) + +sdealg = EM() +# provide `x0` as nothing to the problem since we are provinding a range for `x0`. +prob = ParabolicPDEProblem(μ_func, σ_func, nothing, tspan, g = g, xspan = xspan) +opt = Flux.Optimisers.Adam(0.01) +alg = NNKolmogorov(m, opt) +m = Chain(Dense(d, 16, elu), Dense(16, 32, elu), Dense(32, 16, elu), Dense(16, 1)) +sol = solve(prob, alg, sdealg, verbose = true, dt = 0.01, + dx = 0.0001, trajectories = 1000, abstol = 1e-6, maxiters = 300) +``` diff --git a/docs/src/tutorials/nnparamkolmogorov.md b/docs/src/tutorials/nnparamkolmogorov.md new file mode 100644 index 0000000..735cf2d --- /dev/null +++ b/docs/src/tutorials/nnparamkolmogorov.md @@ -0,0 +1,61 @@ +# `NNParamKolmogorov` + +## Solving Parametric Family of High Dimensional Heat Equation. + +In this example we will solve the high dimensional heat equation over a range of initial values, and also over a range of thermal diffusivity. +```julia +d = 10 +# models input is `d` for initial values, `d` for thermal diffusivity, and last dimension is for stopping time. +m = Chain(Dense(d + 1 + 1, 32, relu), Dense(32, 16, relu), Dense(16, 8, relu), Dense(8, 1)) +ensemblealg = EnsembleThreads() +γ_mu_prototype = nothing +γ_sigma_prototype = zeros(1, 1) +γ_phi_prototype = nothing + +sdealg = EM() +tspan = (0.00, 1.00) +trajectories = 100000 +function phi(x, y_phi) + sum(x .^ 2) +end +function sigma_(dx, x, γ_sigma, t) + dx .= γ_sigma[:, :, 1] +end +mu_(dx, x, γ_mu, t) = dx .= 0.00 + +xspan = [(0.00, 3.00) for i in 1:d] + +p_domain = (p_sigma = (0.00, 2.00), p_mu = nothing, p_phi = nothing) +p_prototype = (p_sigma = γ_sigma_prototype, p_mu = γ_mu_prototype, p_phi = γ_phi_prototype) +dps = (p_sigma = 0.1, p_mu = nothing, p_phi = nothing) + +dt = 0.01 +dx = 0.01 +opt = Flux.Optimisers.Adam(5e-2) + +prob = ParabolicPDEProblem(mu_, + sigma_, + nothing, + tspan; + g = phi, + xspan, + p_domain = p_domain, + p_prototype = p_prototype) + +sol = solve(prob, NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, + abstol = 1e-10, dx = 0.1, trajectories = trajectories, maxiters = 1000, + use_gpu = false, dps = dps) +``` +Similarly we can parametrize the drift function `mu_` and the initial function `g`, and obtain a solution over all parameters and initial values. + +# Inferring on the solution from `NNParamKolmogorov`: +```julia +x_test = rand(xspan[1][1]:0.1:xspan[1][2], d) +p_sigma_test = rand(p_domain.p_sigma[1]:dps.p_sigma:p_domain.p_sigma[2], 1, 1) +t_test = rand(tspan[1]:dt:tspan[2], 1, 1) +p_mu_test = nothing +p_phi_test = nothing +``` +```julia +sol.ufuns(x_test, t_test, p_sigma_test, p_mu_test, p_phi_test) +``` \ No newline at end of file diff --git a/src/NNKolmogorov.jl b/src/NNKolmogorov.jl index 92d73c1..e8599ca 100644 --- a/src/NNKolmogorov.jl +++ b/src/NNKolmogorov.jl @@ -2,18 +2,11 @@ Algorithm for solving Backward Kolmogorov Equations. ```julia -NeuralPDE.NNKolmogorov(chain, opt, sdealg, ensemblealg ) +HighDimPDE.NNKolmogorov(chain, opt) ``` Arguments: - `chain`: A Chain neural network with a d-dimensional output. - `opt`: The optimizer to train the neural network. Defaults to `ADAM(0.1)`. -- `sdealg`: The algorithm used to solve the discretized SDE according to the process that X follows. Defaults to `EM()`. -- `ensemblealg`: The algorithm used to solve the Ensemble Problem that performs Ensemble simulations for the SDE. Defaults to `EnsembleThreads()`. See - the [Ensemble Algorithms](https://diffeq.sciml.ai/stable/features/ensemble/#EnsembleAlgorithms-1) - documentation for more details. -- - `kwargs`: Additional arguments splatted to the SDE solver. See the - [Common Solver Arguments](https://diffeq.sciml.ai/dev/basics/common_solver_opts/) - documentation for more details. [1]Beck, Christian, et al. "Solving stochastic differential equations and Kolmogorov equations by means of deep learning." arXiv preprint arXiv:1806.00421 (2018). """ struct NNKolmogorov{C, O} <: HighDimPDEAlgorithm @@ -22,6 +15,23 @@ struct NNKolmogorov{C, O} <: HighDimPDEAlgorithm end NNKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNKolmogorov(chain, opt) +""" +$(TYPEDSIGNATURES) + +Returns a `PIDESolution` object. + +# Arguments + +- `sdealg`: a SDE solver from [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/solvers/sde_solve/). + If not provided, the plain vanilla [DeepBSDE](https://arxiv.org/abs/1707.02568) method will be applied. + If provided, the SDE associated with the PDE problem will be solved relying on + methods from DifferentialEquations.jl, using [Ensemble solves](https://diffeq.sciml.ai/stable/features/ensemble/) + via `sdealg`. Check the available `sdealg` on the + [DifferentialEquations.jl doc](https://diffeq.sciml.ai/stable/solvers/sde_solve/). +- `maxiters`: The number of training epochs. Defaults to `300` +- `trajectories`: The number of trajectories simulated for training. Defaults to `100` +- Extra keyword arguments passed to `solve` will be further passed to the SDE solver. +""" function DiffEqBase.solve(prob::ParabolicPDEProblem, pdealg::HighDimPDE.NNKolmogorov, sdealg; diff --git a/src/NNParamKolmogorov.jl b/src/NNParamKolmogorov.jl index b55aae3..f134eef 100644 --- a/src/NNParamKolmogorov.jl +++ b/src/NNParamKolmogorov.jl @@ -1,4 +1,14 @@ - +""" +Algorithm for solving Backward Kolmogorov Equations. + +```julia +HighDimPDE.NNKolmogorov(chain, opt) +``` +Arguments: +- `chain`: A Chain neural network with a d-dimensional output. +- `opt`: The optimizer to train the neural network. Defaults to `ADAM(0.1)`. +[1] Berner Julius et al. "Numerically solving parametric families of high-dimensional Kolmogorov partial differential equations via deep learning." +""" struct NNParamKolmogorov{C, O} <: HighDimPDEAlgorithm chain::C opt::O @@ -6,6 +16,24 @@ end NNParamKolmogorov(chain; opt = Flux.ADAM(0.1)) = NNParamKolmogorov(chain, opt) +""" +$(TYPEDSIGNATURES) + +Returns a `PIDESolution` object. + +# Arguments + +- `sdealg`: a SDE solver from [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/solvers/sde_solve/). + If not provided, the plain vanilla [DeepBSDE](https://arxiv.org/abs/1707.02568) method will be applied. + If provided, the SDE associated with the PDE problem will be solved relying on + methods from DifferentialEquations.jl, using [Ensemble solves](https://diffeq.sciml.ai/stable/features/ensemble/) + via `sdealg`. Check the available `sdealg` on the + [DifferentialEquations.jl doc](https://diffeq.sciml.ai/stable/solvers/sde_solve/). +- `maxiters`: The number of training epochs. Defaults to `300` +- `trajectories`: The number of trajectories simulated for training. Defaults to `100` +- `dps::NamedTuple`: The sampling interval for ranges of parameters. Should have keys : `p_sigma`, 'p_mu` and `p_phi` +- Extra keyword arguments passed to `solve` will be further passed to the SDE solver. +""" function DiffEqBase.solve(prob::ParabolicPDEProblem, pdealg::NNParamKolmogorov, sdealg = EM(); @@ -38,8 +66,11 @@ function DiffEqBase.solve(prob::ParabolicPDEProblem, xspan[1]:dx:xspan[2] end - p_domain = prob.kwargs.p_domain - p_prototype = prob.kwargs.p_prototype + p_defaults = (p_sigma = nothing, p_mu = nothing, p_phi = nothing) + + p_domain = merge(p_defaults, prob.kwargs.p_domain) + p_prototype = merge(p_defaults, prob.kwargs.p_prototype) + dps = merge(p_defaults, dps) chain = pdealg.chain ps = Flux.params(chain) From e55a1b47886b9d361042e7871c6f40bb9d80a306 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 12 Feb 2024 22:35:33 +0530 Subject: [PATCH 19/22] fix: fix bug in indexing solutions in NNParamKolmogorov --- src/NNParamKolmogorov.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NNParamKolmogorov.jl b/src/NNParamKolmogorov.jl index f134eef..401e5dc 100644 --- a/src/NNParamKolmogorov.jl +++ b/src/NNParamKolmogorov.jl @@ -115,7 +115,7 @@ function DiffEqBase.solve(prob::ParabolicPDEProblem, noise_rate_prototype = noise_rate_prototype) end - output_func = (sol, i) -> (sol[end], false) + output_func = (sol, i) -> (sol.u[end], false) sdeprob = SDEProblem(mu, sigma, From f0bfef04763d8ef77ddfbb5952e9a4ca6f81db95 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 12 Feb 2024 22:37:21 +0530 Subject: [PATCH 20/22] test: replace `pdf(d,x)` with `map(Base.Fix1(pdf, d), x)` --- test/NNKolmogorov.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/NNKolmogorov.jl b/test/NNKolmogorov.jl index 0be5266..f30483f 100644 --- a/test/NNKolmogorov.jl +++ b/test/NNKolmogorov.jl @@ -14,7 +14,7 @@ tspan = (0.0, 1.0) d = 1 sdealg = EM() -g(x) = pdf(u0, x) +g(x) = map(Base.Fix1(pdf, u0), x) prob = ParabolicPDEProblem(μ, σ, nothing, tspan; g, xspan) opt = Flux.Optimisers.Adam(0.01) m = Chain(Dense(1, 5, elu), Dense(5, 5, elu), Dense(5, 5, elu), Dense(5, 1)) @@ -24,8 +24,7 @@ sol = solve(prob, NNKolmogorov(m, opt), sdealg; ensemblealg = ensemblealg, verbo abstol = 1e-10, dx = 0.0001, trajectories = 100000, maxiters = 500) ## The solution is obtained taking the Fourier Transform. -analytical(xi) = pdf.(Normal(3, sqrt(1.0 + 5.00)), xi) -##Validation +analytical(xi) = map(Base.Fix1(pdf, Normal(3, sqrt(1.0 + 5.00))), xi)##Validation xs = -5:0.00001:5 x_1 = rand(xs, 1, 1000) err_l2 = Flux.mse(analytical(x_1), sol.ufuns(x_1)) From 8fb87fe946ed8d1ba821299bf34f705a69ba7017 Mon Sep 17 00:00:00 2001 From: Ashutosh Bharambe Date: Mon, 12 Feb 2024 22:45:04 +0530 Subject: [PATCH 21/22] docs: bump HighDimPDE version to 2 --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index c092516..c6d1296 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ HighDimPDE = "57c578d5-59d4-4db8-a490-a9fc372d19d2" [compat] Documenter = "1" Flux = "0.13, 0.14" -HighDimPDE = "1.2" +HighDimPDE = "2" From aaa400d6d359ca05fa4186a13ba3c2a73a9696d6 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Thu, 15 Feb 2024 16:09:14 +0530 Subject: [PATCH 22/22] docs: update docs --- docs/src/NNKolmogorov.md | 8 +++++++- docs/src/NNParamKolmogorov.md | 8 +++++++- src/NNKolmogorov.jl | 2 +- src/NNParamKolmogorov.jl | 2 +- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/docs/src/NNKolmogorov.md b/docs/src/NNKolmogorov.md index 61ce630..73d554b 100644 --- a/docs/src/NNKolmogorov.md +++ b/docs/src/NNKolmogorov.md @@ -8,11 +8,17 @@ Modules = [HighDimPDE] Pages = ["NNKolmogorov.jl"] ``` -`NNKolmogorov` obtains a terminal solution for Backward Kolmogorov Equations of the form: +`NNKolmogorov` obtains a +- terminal solution for Forward Kolmogorov Equations of the form: ```math \partial_t u(t,x) = \mu(t, x) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) ``` with initial condition given by `g(x)` +- or an initial condition for Backward Kolmogorov Equations of the form: +```math +\partial_t u(t,x) = - \mu(t, x) \nabla_x u(t,x) - \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) +``` +with terminal condition given by `g(x)` We can use the Feynman-Kac formula : ```math diff --git a/docs/src/NNParamKolmogorov.md b/docs/src/NNParamKolmogorov.md index 4b865ca..5635f8c 100644 --- a/docs/src/NNParamKolmogorov.md +++ b/docs/src/NNParamKolmogorov.md @@ -8,11 +8,17 @@ Modules = [HighDimPDE] Pages = ["NNParamKolmogorov.jl"] ``` -`NNParamKolmogorov` obtains a terminal solution for parametric families of Backward Kolmogorov Equations of the form: +`NNParamKolmogorov` obtains a +- terminal solution for parametric families of Forward Kolmogorov Equations of the form: ```math \partial_t u(t,x) = \mu(t, x, γ_mu) \nabla_x u(t,x) + \frac{1}{2} \sigma^2(t, x, γ_sigma) \Delta_x u(t,x) ``` with initial condition given by `g(x, γ_phi)` +- or an initial condition for parametric families of Backward Kolmogorov Equations of the form: +```math +\partial_t u(t,x) = - \mu(t, x) \nabla_x u(t,x) - \frac{1}{2} \sigma^2(t, x) \Delta_x u(t,x) +``` +with terminal condition given by `g(x, γ_phi)` We can use the Feynman-Kac formula : ```math diff --git a/src/NNKolmogorov.jl b/src/NNKolmogorov.jl index e8599ca..6eedd1a 100644 --- a/src/NNKolmogorov.jl +++ b/src/NNKolmogorov.jl @@ -1,5 +1,5 @@ """ -Algorithm for solving Backward Kolmogorov Equations. +Algorithm for solving Kolmogorov Equations. ```julia HighDimPDE.NNKolmogorov(chain, opt) diff --git a/src/NNParamKolmogorov.jl b/src/NNParamKolmogorov.jl index 401e5dc..71fd368 100644 --- a/src/NNParamKolmogorov.jl +++ b/src/NNParamKolmogorov.jl @@ -1,5 +1,5 @@ """ -Algorithm for solving Backward Kolmogorov Equations. +Algorithm for solving paramateric families of Kolmogorov Equations. ```julia HighDimPDE.NNKolmogorov(chain, opt)