Skip to content

Commit

Permalink
test: add tests for NNParamKolmogorov
Browse files Browse the repository at this point in the history
  • Loading branch information
ashutosh-b-b committed Jan 30, 2024
1 parent c1ba558 commit a158d47
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
63 changes: 63 additions & 0 deletions test/NNParamKolmogorov.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Test, Flux
using StochasticDiffEq
using LinearAlgebra
using HighDimPDE
using Random
Random.seed!(100)

d = 1
m = Chain(Dense(3, 16, tanh), Dense(16, 16, tanh), Dense(16, 5, tanh), Dense(5, 1))
ensemblealg = EnsembleThreads()
γ_mu_prototype = nothing
γ_sigma_prototype = zeros(d, d, 1)
γ_phi_prototype = nothing

sdealg = EM()
tspan = (0.00, 1.00)
trajectories = 10000
function phi(x, y_phi)
x .^ 2
end
sigma(dx, x, γ_sigma, t) = dx .= γ_sigma[:, :, 1]
mu(dx, x, γ_mu, t) = dx .= 0.00

xspan = (0.00, 3.00)

p_domain = (p_sigma = (0.00, 2.00), p_mu = nothing, p_phi = nothing)
p_prototype = (p_sigma = γ_sigma_prototype, p_mu = γ_mu_prototype, p_phi = γ_phi_prototype)
dps = (p_sigma = 0.01, p_mu = nothing, p_phi = nothing)

dt = 0.01
dx = 0.01
opt = Flux.ADAM(1e-2)

prob = PIDEProblem(phi,
mu,
sigma,
tspan,
xspan;
p_domain = p_domain,
p_prototype = p_prototype)

sol = solve(prob, HighDimPDE.NNParamKolmogorov(m, opt), sdealg, verbose = true, dt = 0.01,
abstol = 1e-10, dx = 0.01, trajectories = trajectories, maxiters = 1000,
use_gpu = false, dps = dps)

x_test = rand(xspan[1]:dx:xspan[2], d, 1, 1000)
t_test = rand(tspan[1]:dt:tspan[2], 1, 1000)
γ_sigma_test = rand(0.3:(dps.p_sigma):0.5, d, d, 1, 1000)

function analytical(x, t, y)
return x .^ 2 .+ t .* (y .* y)
end

preds = map((i) -> sol.ufuns(x_test[:, :, i],
t_test[:, i],
γ_sigma_test[:, :, :, i],
nothing,
nothing),
1:1000)
y_test = map((i) -> analytical(x_test[:, :, i], t_test[:, i], γ_sigma_test[:, :, :, i]),
1:1000)

@test Flux.mse(reduce(hcat, preds), reduce(hcat, y_test)) < 0.1
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ using SafeTestsets, Test
@time @safetestset "Deep Splitting" include("DeepSplitting.jl")
@time @safetestset "MC Sample" include("MCSample.jl")
@time @safetestset "NNKolmogorov" include("NNKolmogorov.jl")
@time @safetestset "NNKolmogorov" include("NNParamKolmogorov.jl")
end

0 comments on commit a158d47

Please sign in to comment.