From a158d47edc2a6dc2b9af99f6a38f875ffd838622 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 30 Jan 2024 11:30:52 +0530 Subject: [PATCH] test: add tests for NNParamKolmogorov --- test/NNParamKolmogorov.jl | 63 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 64 insertions(+) create mode 100644 test/NNParamKolmogorov.jl diff --git a/test/NNParamKolmogorov.jl b/test/NNParamKolmogorov.jl new file mode 100644 index 0000000..0751326 --- /dev/null +++ b/test/NNParamKolmogorov.jl @@ -0,0 +1,63 @@ +using Test, Flux +using StochasticDiffEq +using LinearAlgebra +using HighDimPDE +using Random +Random.seed!(100) + +d = 1 +m = Chain(Dense(3, 16, tanh), Dense(16, 16, tanh), Dense(16, 5, tanh), Dense(5, 1)) +ensemblealg = EnsembleThreads() +γ_mu_prototype = nothing +γ_sigma_prototype = zeros(d, d, 1) +γ_phi_prototype = nothing + +sdealg = EM() +tspan = (0.00, 1.00) +trajectories = 10000 +function phi(x, y_phi) + x .^ 2 +end +sigma(dx, x, γ_sigma, t) = dx .= γ_sigma[:, :, 1] +mu(dx, x, γ_mu, t) = dx .= 0.00 + +xspan = (0.00, 3.00) + +p_domain = (p_sigma = (0.00, 2.00), p_mu = nothing, p_phi = nothing) +p_prototype = (p_sigma = γ_sigma_prototype, p_mu = γ_mu_prototype, p_phi = γ_phi_prototype) +dps = (p_sigma = 0.01, p_mu = nothing, p_phi = nothing) + +dt = 0.01 +dx = 0.01 +opt = Flux.ADAM(1e-2) + +prob = PIDEProblem(phi, + mu, + sigma, + tspan, + xspan; + p_domain = p_domain, + p_prototype = p_prototype) + +sol = solve(prob, HighDimPDE.NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01, + abstol = 1e-10, dx = 0.01, trajectories = trajectories, maxiters = 1000, + use_gpu = false, dps = dps) + +x_test = rand(xspan[1]:dx:xspan[2], d, 1, 1000) +t_test = rand(tspan[1]:dt:tspan[2], 1, 1000) +γ_sigma_test = rand(0.3:(dps.p_sigma):0.5, d, d, 1, 1000) + +function analytical(x, t, y) + return x .^ 2 .+ t .* (y .* y) +end + +preds = map((i) -> sol.ufuns(x_test[:, :, i], + t_test[:, i], + γ_sigma_test[:, :, :, i], + nothing, + nothing), + 1:1000) +y_test = map((i) -> analytical(x_test[:, :, i], t_test[:, i], γ_sigma_test[:, :, :, i]), + 1:1000) + +@test Flux.mse(reduce(hcat, preds), reduce(hcat, y_test)) < 0.1 diff --git a/test/runtests.jl b/test/runtests.jl index 8f4caa3..00fab0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,4 +7,5 @@ using SafeTestsets, Test @time @safetestset "Deep Splitting" include("DeepSplitting.jl") @time @safetestset "MC Sample" include("MCSample.jl") @time @safetestset "NNKolmogorov" include("NNKolmogorov.jl") + @time @safetestset "NNKolmogorov" include("NNParamKolmogorov.jl") end