Skip to content

Commit

Permalink
remove PINO code
Browse files Browse the repository at this point in the history
  • Loading branch information
YichengDWu committed Aug 21, 2023
1 parent bb8af16 commit 8d254c4
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 176 deletions.
100 changes: 0 additions & 100 deletions docs/src/tutorials/burgers.jl

This file was deleted.

61 changes: 0 additions & 61 deletions src/pde/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,36 +61,6 @@ function build_loss_function(pde_system::PDESystem, pinn::PINN,
return pde_and_bcs_loss_function
end

#=
function build_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
strategy::AbstractTrainingAlg, coord_branch_net;
derivative=finitediff)
(; eqs, bcs, ivs, dvs, pvs) = pde_system
(; phi, init_params) = pinn
depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(ivs, dvs)
_, _, _, dict_pmdepvars, dict_pmdepvar_input = get_vars(ivs, pvs)
multioutput = false
pinnrep = (; eqs, bcs, depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input,
dict_pmdepvars, dict_pmdepvar_input, multioutput, pvs, init_params, pinn,
derivative, strategy, fdtype, coord_branch_net)
datafree_pde_loss_functions = Tuple(build_loss_function(pinnrep, first(eq), i)
for (i, eq) in enumerate(eqs))
datafree_bc_loss_functions = Tuple(build_loss_function(pinnrep, first(bc),
i +
length(datafree_pde_loss_functions))
for (i, bc) in enumerate(bcs))
pde_and_bcs_loss_function = scalarize(strategy, phi, datafree_pde_loss_functions,
datafree_bc_loss_functions)
return pde_and_bcs_loss_function
end
=#

"""
discretize(pde_system::PDESystem, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg; derivative=finitediff,
Expand Down Expand Up @@ -122,37 +92,6 @@ function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
return Optimization.OptimizationProblem(f, init_params, datasets)
end

# ParametricPDESystem no long supported
#=function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg, functionsampler::FunctionSampler,
coord_branch_net::AbstractArray;
additional_loss=Sophon.null_additional_loss, derivative=finitediff,
fdtype=Float64,
adtype=Optimization.AutoZygote())
datasets = sample(pde_system, sampler)
init_params = Lux.fmap(Base.Fix1(broadcast, fdtype), pinn.init_params)
init_params = _ComponentArray(init_params)
datasets = map(Base.Fix1(broadcast, fdtype), datasets)
datasets = init_params isa AbstractGPUComponentVector ?
map(Base.Fix1(adapt, CuArray), datasets) : datasets
pfs = sample(functionsampler)
coord_branch_net = coord_branch_net isa Union{AbstractVector, StepRangeLen} ?
[coord_branch_net] : coord_branch_net
pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
coord_branch_net, derivative,
fdtype)
function full_loss_function(θ, p)
return pde_and_bcs_loss_function(θ, p) + additional_loss(pinn.phi, θ)
end
f = OptimizationFunction(full_loss_function, adtype)
p = PINOParameterHandler(datasets, pfs)
return Optimization.OptimizationProblem(f, init_params, p)
end
=#

function symbolic_discretize(pde_system, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg;
additional_loss=Sophon.null_additional_loss, derivative=finitediff,
Expand Down
15 changes: 0 additions & 15 deletions src/pde/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,3 @@ function Base.show(io::IO, ::MIME"text/plain", sys::ParametricPDESystem)
println(io, "Parametric Variables: ", sys.pvs)
return nothing
end

mutable struct PINOParameterHandler
coords::Any
fs::Any
end

get_local_ps(p::PINOParameterHandler) = p.coords
get_global_ps(p::PINOParameterHandler) = p.fs
Base.getindex(p::PINOParameterHandler, i) = getindex(p.coords, i)

@inline get_local_ps(p::Vector{<:AbstractMatrix}) = p
@inline get_global_ps(::Vector{<:AbstractMatrix}) = nothing

ChainRulesCore.@non_differentiable get_local_ps(::Any...)
ChainRulesCore.@non_differentiable get_global_ps(::Any...)

0 comments on commit 8d254c4

Please sign in to comment.