-
Notifications
You must be signed in to change notification settings - Fork 8
/
snopt_integration.patch
94 lines (79 loc) · 3.86 KB
/
snopt_integration.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
diff --git a/torchdiffeq/_impl/adjoint.py b/torchdiffeq/_impl/adjoint.py
index ca9ba7f..001a53d 100644
--- a/torchdiffeq/_impl/adjoint.py
+++ b/torchdiffeq/_impl/adjoint.py
@@ -5,6 +5,8 @@ from .odeint import SOLVERS, odeint
from .misc import _check_inputs, _flat_to_shape
from .misc import _mixed_norm
+from snopt import SNOptAdjointCollector
+
class OdeintAdjointMethod(torch.autograd.Function):
@@ -109,6 +111,8 @@ class OdeintAdjointMethod(torch.autograd.Function):
# Solve adjoint ODE #
##################################
+ snopt_collector = SNOptAdjointCollector(func) if adjoint_options['use_snopt'] else None
+
if t_requires_grad:
time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
else:
@@ -123,9 +127,14 @@ class OdeintAdjointMethod(torch.autograd.Function):
time_vjps[i] = dLd_cur_t
# Run the augmented system backwards in time.
+ samp_t = t[i - 1:i + 1].flip(0)
+
+ if snopt_collector:
+ adjoint_options, samp_t = snopt_collector.check_inputs(adjoint_options, samp_t)
+
aug_state = odeint(
augmented_dynamics, tuple(aug_state),
- t[i - 1:i + 1].flip(0),
+ samp_t,
rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
)
aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py
index 6915f2b..4a79f59 100644
--- a/torchdiffeq/_impl/solvers.py
+++ b/torchdiffeq/_impl/solvers.py
@@ -5,7 +5,7 @@ from .misc import _handle_unused_kwargs
class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta):
- def __init__(self, dtype, y0, norm, **unused_kwargs):
+ def __init__(self, dtype, y0, norm, snopt_collector=None, **unused_kwargs):
_handle_unused_kwargs(self, unused_kwargs)
del unused_kwargs
@@ -13,6 +13,7 @@ class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta):
self.dtype = dtype
self.norm = norm
+ self.snopt_collector = snopt_collector
def _before_integrate(self, t):
pass
@@ -28,6 +29,8 @@ class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta):
self._before_integrate(t)
for i in range(1, len(t)):
solution[i] = self._advance(t[i])
+ if self.snopt_collector:
+ self.snopt_collector.call_invoke(self.func, t[i], solution[i])
return solution
@@ -48,7 +51,7 @@ class AdaptiveStepsizeEventODESolver(AdaptiveStepsizeODESolver, metaclass=abc.AB
class FixedGridODESolver(metaclass=abc.ABCMeta):
order: int
- def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs):
+ def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, snopt_collector=None, **unused_kwargs):
self.atol = unused_kwargs.pop('atol')
unused_kwargs.pop('rtol', None)
unused_kwargs.pop('norm', None)
@@ -62,6 +65,7 @@ class FixedGridODESolver(metaclass=abc.ABCMeta):
self.step_size = step_size
self.interp = interp
self.perturb = perturb
+ self.snopt_collector = snopt_collector
if step_size is None:
if grid_constructor is None:
@@ -113,6 +117,8 @@ class FixedGridODESolver(metaclass=abc.ABCMeta):
solution[j] = self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t[j])
else:
raise ValueError(f"Unknown interpolation method {self.interp}")
+ if self.snopt_collector:
+ self.snopt_collector.call_invoke(self.func, t[j], solution[j])
j += 1
y0 = y1