Skip to content

Commit

Permalink
Add compression of data at the save to reduce the size of the file (#694
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Antoine-Averland authored Dec 12, 2024
1 parent 047f722 commit 5722846
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
7 changes: 4 additions & 3 deletions smt/surrogate_models/tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
class TestSaveLoad(unittest.TestCase):
def test_save_load_GEKPLS(self):
filename = "sm_save_test"
fun = Sphere(ndim=2)
ndim = 2
fun = Sphere(ndim=ndim)

sampling = LHS(xlimits=fun.xlimits, criterion="m")
xt = sampling(20)
yt = fun(xt)

for i in range(2):
for i in range(ndim):
yd = fun(xt, kx=i)
yt = np.concatenate((yt, yd), axis=1)

Expand All @@ -43,7 +44,7 @@ def test_save_load_GEKPLS(self):

sm = GEKPLS(print_global=False)
sm.set_training_values(xt, yt[:, 0])
for i in range(2):
for i in range(ndim):
sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm.train()
for i in range(X.shape[0]):
Expand Down
12 changes: 9 additions & 3 deletions smt/utils/persistence.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pickle
import zlib


def save(self, filename):
serialized_data = pickle.dumps(self, protocol=5)
compressed_data = zlib.compress(serialized_data)
with open(filename, "wb") as file:
pickle.dump(self, file)
file.write(compressed_data)


def load(filename):
sm = None
with open(filename, "rb") as file:
sm = pickle.load(file)
compressed_data = file.read()

serialized_data = zlib.decompress(compressed_data)
sm = pickle.loads(serialized_data)

return sm

0 comments on commit 5722846

Please sign in to comment.