concrete_solve: How to detect if solution was successful? Or how to get retcode? #234

Closed
scheidan opened this issue Apr 20, 2020 · 6 comments

Comments

@scheidan

During optimization we may hit an unstable region in the parameter space. The classic workaround is to return -Inf as nicely documented here. Unfortunately concrete_solve doesn't give a retcode.
So I'm wondering:

  • How can we "catch" potential instability warnings in the loss functions?
  • Would returning -Inf be a good idea at all? It seems to me that this would break gradient descent.

Thanks a lot!

@ChrisRackauckas
Member

I think we need to find a safe way for concrete_solve to return the original solution object, one where interpolating it errors properly (since those gradients won't propagate). Basically, instead of converting to a DiffEqArray, it should scrub sol.k, change the interpolant to linear (since that's okay for derivatives), and return this linear-interpolation-only solution. It's not too difficult to do, but it would take some work, and we'd want to test it well so that users aren't left with tricky edge cases that give 0 gradients.
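
To make the "linear interpolation only" idea concrete, here is a minimal sketch in plain Julia. It is not the implementation used by the SciML packages; `linear_interp` is a hypothetical helper, and it assumes the saved times `ts` are sorted and contain at least two points. It interpolates only between saved states and never touches the solver's internal stages (`sol.k`).

```julia
# Hypothetical helper (not part of any SciML package): piecewise-linear
# interpolation between saved states us[i] at sorted times ts[i].
function linear_interp(ts::AbstractVector, us::AbstractVector, t::Real)
    first(ts) <= t <= last(ts) ||
        throw(DomainError(t, "t lies outside the saved time span"))
    # Index of the last saved time <= t, clamped so that i + 1 stays valid.
    i = clamp(searchsortedlast(ts, t), 1, length(ts) - 1)
    θ = (t - ts[i]) / (ts[i + 1] - ts[i])   # fraction of the way to ts[i+1]
    return (1 - θ) .* us[i] .+ θ .* us[i + 1]
end

ts = [0.0, 1.0, 2.0]
us = [[0.0, 10.0], [2.0, 20.0], [4.0, 40.0]]
linear_interp(ts, us, 0.5)   # halfway between the first two saved states
```

Because the result is a weighted sum of saved states, its derivative with respect to those states is well defined, which is the property the comment above relies on.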

@ChrisRackauckas
Member

SciML/DifferentialEquations.jl#610 is towards a real solution. I'll close this and we can track there.

@ArnoStrouwen
Member

Did getting rid of concrete_solve fix this issue entirely?

using DiffEqSensitivity
using OrdinaryDiffEq
using ReverseDiff
using Zygote

# Lotka-Volterra right-hand side, in-place form
function fiip(du,u,p,t)
    du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end

function sum_of_solution(p)
    _prob = remake(prob,p=p)
    sol = solve(_prob,Tsit5(),saveat=0.1,sensealg=QuadratureAdjoint())
    # On a failed solve, return a value that is still differentiable in p
    if sol.retcode != :Success
        return p[1]
    end
    sum(sol)
end
# The NaN in u0 forces a failed solve, so the retcode branch is taken
p = [1.5,1.0,3.0,1.0]; u0 = [NaN;1.0]
prob = ODEProblem(fiip,u0,(0.0,10.0),p)
dp = ReverseDiff.gradient(sum_of_solution,p) # does not work
dp = Zygote.gradient(sum_of_solution,p) # works

@ChrisRackauckas
Member

Yes it did.

@ChrisRackauckas
Member

Oh, that last piece is an upstream issue: ReverseDiff.jl cannot allow derivative rules to return general structs, so we have to return an array. That's a ReverseDiff.jl issue though.

@ChrisRackauckas
Member

concrete_solve is gone, so the retcode exists.
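
For the original question, a guarded loss function built on a plain `solve` call might look like the sketch below. It reuses the `sol.retcode != :Success` check from the example earlier in this thread. `lotka_volterra!` and `guarded_loss` are hypothetical names, and returning `Inf` on failure is an assumption for a loss that is minimized; a log-likelihood that is maximized would return `-Inf` instead. Newer package versions may represent return codes differently, so the comparison against `:Success` should be checked against the installed version.

```julia
using OrdinaryDiffEq

# Lotka-Volterra right-hand side, same form as the example in this thread.
function lotka_volterra!(du, u, p, t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
    return nothing
end

u0 = [1.0, 1.0]
p0 = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p0)

# Hypothetical loss: sum of all saved states, or Inf if the solve failed.
function guarded_loss(p)
    sol = solve(remake(prob, p = p), Tsit5(), saveat = 0.1)
    # Assumes the retcode can be compared against :Success, as in this thread.
    sol.retcode != :Success && return Inf
    return sum(sum, sol.u)
end

guarded_loss(p0)                     # ordinary solve
guarded_loss([NaN, 1.0, 3.0, 1.0])   # NaN parameter, meant to trigger the guard
```

A constant `Inf` carries no gradient information, so for gradient-based optimizers a finite, parameter-dependent fallback such as the `p[1]` used in the example above may be the more practical choice.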
