From 9bf64f600d7b896bc8db660d9268eef3f1d4cc3a Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Tue, 10 Dec 2024 14:15:14 +0100 Subject: [PATCH] save / load for the surrogates using cpp --- smt/surrogate_models/idw.py | 18 ++++++ smt/surrogate_models/rbf.py | 18 ++++++ smt/surrogate_models/rmtb.py | 10 ++++ smt/surrogate_models/rmtc.py | 10 ++++ smt/surrogate_models/rmts.py | 8 +++ smt/surrogate_models/tests/test_save_load.py | 61 +++++++++++++++++++- 6 files changed, 124 insertions(+), 1 deletion(-) diff --git a/smt/surrogate_models/idw.py b/smt/surrogate_models/idw.py index aa08b3776..83c40ee67 100644 --- a/smt/surrogate_models/idw.py +++ b/smt/surrogate_models/idw.py @@ -9,6 +9,7 @@ from smt.surrogate_models.idwclib import PyIDW from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence from smt.utils.caching import cached_operation @@ -47,6 +48,16 @@ def _setup(self): self.idwc = PyIDW() self.idwc.setup(nx, nt, self.options["p"], xt.flatten()) + def __getstate__(self): + state = self.__dict__.copy() + state["idwc"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + self._setup() + ############################################################################ # Model functions ############################################################################ @@ -132,3 +143,10 @@ def _predict_output_derivatives(self, x): dy_dyt = {None: jac} return dy_dyt + + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) diff --git a/smt/surrogate_models/rbf.py b/smt/surrogate_models/rbf.py index 024299901..1a2e50065 100644 --- a/smt/surrogate_models/rbf.py +++ b/smt/surrogate_models/rbf.py @@ -9,6 +9,7 @@ from smt.surrogate_models.rbfclib import PyRBF from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence from smt.utils.caching import cached_operation from smt.utils.linear_solvers import get_solver @@ -95,6 +96,16 @@ def _setup(self): xt.flatten(), ) + def __getstate__(self): + state = self.__dict__.copy() + state["rbfc"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + self._setup() + def _new_train(self): num = self.num @@ -213,3 +224,10 @@ def _predict_output_derivatives(self, x): dy_dyt = (dytl_dyt.T.dot(dstates_dytl.T).dot(dy_dstates.T)).T dy_dyt = np.einsum("ij,k->ijk", dy_dyt, np.ones(ny)) return {None: dy_dyt} + + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) diff --git a/smt/surrogate_models/rmtb.py b/smt/surrogate_models/rmtb.py index 7ce62274e..d04944153 100644 --- a/smt/surrogate_models/rmtb.py +++ b/smt/surrogate_models/rmtb.py @@ -96,6 +96,16 @@ def _setup(self): np.array(num["ctrl_list"], np.int32), ) + def __getstate__(self): + state = self.__dict__.copy() + state["rmtsc"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + self._setup() + def _compute_jac_raw(self, ix1, ix2, x): xlimits = self.options["xlimits"] diff --git a/smt/surrogate_models/rmtc.py b/smt/surrogate_models/rmtc.py index 9ec69ec9e..db038771d 100644 --- a/smt/surrogate_models/rmtc.py +++ b/smt/surrogate_models/rmtc.py @@ -91,6 +91,16 @@ def _setup(self): np.array(num["term_list"], np.int32), ) + def __getstate__(self): + state = self.__dict__.copy() + state["rmtsc"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + self._setup() + def _compute_jac_raw(self, ix1, ix2, x): n = x.shape[0] nnz = n * self.num["term"] diff --git a/smt/surrogate_models/rmts.py b/smt/surrogate_models/rmts.py index 9dc903c44..7064614fc 100644 --- a/smt/surrogate_models/rmts.py +++ b/smt/surrogate_models/rmts.py @@ -10,6 +10,7 @@ import scipy.sparse from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence from smt.utils.caching import cached_operation from smt.utils.line_search import VALID_LINE_SEARCHES, LineSearch, get_line_search_class from smt.utils.linear_solvers import VALID_SOLVERS, LinearSolver, get_solver @@ -583,3 +584,10 @@ def _predict_output_derivatives(self, x): dy_dyt[kx - 1] = np.einsum("ij,jkl->ikl", dy_dw, dw_dyt) return dy_dyt + + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) diff --git a/smt/surrogate_models/tests/test_save_load.py b/smt/surrogate_models/tests/test_save_load.py index 81fe94ce9..f1a6d8e49 100644 --- a/smt/surrogate_models/tests/test_save_load.py +++ b/smt/surrogate_models/tests/test_save_load.py @@ -4,7 +4,21 @@ from smt.problems import Sphere from smt.sampling_methods import LHS -from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN +from smt.surrogate_models import ( + KRG, + LS, + KPLS, + GEKPLS, + KPLSK, + MGP, + QP, + SGP, + GENN, + RBF, + RMTB, + RMTC, + IDW, +) class TestSaveLoad(unittest.TestCase): @@ -82,6 +96,51 @@ def test_save_load_surrogates(self): os.remove(filename) + def test_save_load_surrogates_cpp(self): + surrogates_cpp = [RBF, RMTC, RMTB, IDW] + + num = 100 + xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + x = np.linspace(0.0, 4.0, num) + xlimits = np.array([[0.0, 4.0]]) + + filename = "sm_save_test" + + for surrogate in surrogates_cpp: + if surrogate == RMTB: + sm = RMTB( + xlimits=xlimits, + order=4, + num_ctrl_pts=20, + energy_weight=1e-15, + regularization_weight=0.0, + ) + elif surrogate == RMTC: + sm = RMTC( + xlimits=xlimits, + num_elements=6, + energy_weight=1e-15, + regularization_weight=0.0, + ) + elif surrogate == RBF: + sm = RBF(d0=5) + else: + sm = IDW(p=2) + + sm.set_training_values(xt, yt) + sm.train() + + y1 = sm.predict_values(x) + sm.save(filename) + + sm2 = surrogate.load(filename) + y2 = sm2.predict_values(x) + + np.testing.assert_allclose(y1, y2) + + os.remove(filename) + if __name__ == "__main__": unittest.main()