diff --git a/pypesto/optimize/ess/sacess.py b/pypesto/optimize/ess/sacess.py index 7b8cde97d..34b5941ec 100644 --- a/pypesto/optimize/ess/sacess.py +++ b/pypesto/optimize/ess/sacess.py @@ -7,7 +7,9 @@ import time from multiprocessing import Manager, Process from multiprocessing.managers import SyncManager -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from uuid import uuid1 from warnings import warn import numpy as np @@ -49,6 +51,7 @@ def __init__( max_walltime_s: float = np.inf, sacess_loglevel: int = logging.INFO, ess_loglevel: int = logging.WARNING, + tmpdir: Union[Path, str] = None, ): """Construct. @@ -75,6 +78,11 @@ def __init__( Loglevel for ESS runs. sacess_loglevel: Loglevel for SACESS runs. + tmpdir: + Directory for temporary files. Defaults to a directory in the current + working directory named ``SacessOptimizerTemp-{random suffix}``. + When setting this option, make sure any optimizers running in parallel + have unique tmpdirs. """ if (num_workers is None and ess_init_args is None) or ( num_workers is not None and ess_init_args is not None @@ -96,6 +104,13 @@ def __init__( self.sacess_loglevel = sacess_loglevel logger.setLevel(self.sacess_loglevel) + self._tmpdir = tmpdir + if self._tmpdir is None: + while self._tmpdir is None or self._tmpdir.exists(): + self._tmpdir = Path(f"SacessOptimizerTemp-{str(uuid1())[:8]}") + self._tmpdir = Path(self._tmpdir).absolute() + self._tmpdir.mkdir(parents=True, exist_ok=True) + def minimize( self, problem: Problem, @@ -141,12 +156,15 @@ def minimize( SacessWorker( manager=sacess_manager, ess_kwargs=ess_kwargs, - worker_idx=i, + worker_idx=worker_idx, max_walltime_s=self.max_walltime_s, loglevel=self.sacess_loglevel, ess_loglevel=self.ess_loglevel, + tmp_result_file=SacessWorker.get_temp_result_filename( + worker_idx, self._tmpdir + ), ) - for i, ess_kwargs in enumerate(ess_init_args) + for worker_idx, ess_kwargs in enumerate(ess_init_args) ] # launch worker processes worker_processes = [ @@ -187,7 +205,7 @@ def _create_result(self, problem: Problem) -> pypesto.Result: result = None for worker_idx in range(self.num_workers): tmp_result_filename = SacessWorker.get_temp_result_filename( - worker_idx + worker_idx, self._tmpdir ) try: tmp_result = read_result( @@ -210,10 +228,16 @@ def _create_result(self, problem: Problem) -> pypesto.Result: ) # delete temporary files only after successful consolidation - for filename in map( - SacessWorker.get_temp_result_filename, range(self.num_workers) - ): + for worker_idx in range(self.num_workers): + filename = SacessWorker.get_temp_result_filename( + worker_idx, self._tmpdir + ) os.remove(filename) + # delete tmpdir if empty + try: + self._tmpdir.rmdir() + except OSError: + pass result.optimize_result.sort() @@ -380,6 +404,7 @@ class SacessWorker: _logger: A Logger instance. _loglevel: Logging level for sacess _ess_loglevel: Logging level for ESS runs + _tmp_result_file: Path of a temporary file to be created. """ def __init__( @@ -390,6 +415,7 @@ def __init__( max_walltime_s: float = np.inf, loglevel: int = logging.INFO, ess_loglevel: int = logging.WARNING, + tmp_result_file: str = None, ): self._manager = manager self._worker_idx = worker_idx @@ -404,6 +430,7 @@ def __init__( self._loglevel = loglevel self._ess_loglevel = ess_loglevel self._logger = None + self._tmp_result_file = tmp_result_file def run( self, @@ -458,12 +485,13 @@ def run( ess_results.optimize_result.list = ( ess_results.optimize_result.list[:50] ) - write_result( - ess_results, - self.get_temp_result_filename(self._worker_idx), - overwrite=True, - optimize=True, - ) + if self._tmp_result_file: + write_result( + ess_results, + self._tmp_result_file, + overwrite=True, + optimize=True, + ) # check if the best solution of the last local ESS is sufficiently # better than the sacess-wide best solution self.maybe_update_best(ess.x_best, ess.fx_best) @@ -608,8 +636,10 @@ def _keep_going(self): return True @staticmethod - def get_temp_result_filename(worker_idx: int) -> str: - return f"sacess-{worker_idx:02d}_tmp.h5" + def get_temp_result_filename( + worker_idx: int, tmpdir: Union[str, Path] + ) -> str: + return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute()) def _run_worker(