Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SacessOptimizer: tmpdir option #1115

Merged
merged 1 commit into from
Aug 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import time
from multiprocessing import Manager, Process
from multiprocessing.managers import SyncManager
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid1
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -49,6 +51,7 @@ def __init__(
max_walltime_s: float = np.inf,
sacess_loglevel: int = logging.INFO,
ess_loglevel: int = logging.WARNING,
tmpdir: Union[Path, str] = None,
):
"""Construct.

Expand All @@ -75,6 +78,11 @@ def __init__(
Loglevel for ESS runs.
sacess_loglevel:
Loglevel for SACESS runs.
tmpdir:
Directory for temporary files. Defaults to a directory in the current
working directory named ``SacessOptimizerTemp-{random suffix}``.
When setting this option, make sure any optimizers running in parallel
have unique tmpdirs.
"""
if (num_workers is None and ess_init_args is None) or (
num_workers is not None and ess_init_args is not None
Expand All @@ -96,6 +104,13 @@ def __init__(
self.sacess_loglevel = sacess_loglevel
logger.setLevel(self.sacess_loglevel)

self._tmpdir = tmpdir
if self._tmpdir is None:
while self._tmpdir is None or self._tmpdir.exists():
self._tmpdir = Path(f"SacessOptimizerTemp-{str(uuid1())[:8]}")
self._tmpdir = Path(self._tmpdir).absolute()
self._tmpdir.mkdir(parents=True, exist_ok=True)

def minimize(
self,
problem: Problem,
Expand Down Expand Up @@ -141,12 +156,15 @@ def minimize(
SacessWorker(
manager=sacess_manager,
ess_kwargs=ess_kwargs,
worker_idx=i,
worker_idx=worker_idx,
max_walltime_s=self.max_walltime_s,
loglevel=self.sacess_loglevel,
ess_loglevel=self.ess_loglevel,
tmp_result_file=SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
),
)
for i, ess_kwargs in enumerate(ess_init_args)
for worker_idx, ess_kwargs in enumerate(ess_init_args)
]
# launch worker processes
worker_processes = [
Expand Down Expand Up @@ -187,7 +205,7 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
result = None
for worker_idx in range(self.num_workers):
tmp_result_filename = SacessWorker.get_temp_result_filename(
worker_idx
worker_idx, self._tmpdir
)
try:
tmp_result = read_result(
Expand All @@ -210,10 +228,16 @@ def _create_result(self, problem: Problem) -> pypesto.Result:
)

# delete temporary files only after successful consolidation
for filename in map(
SacessWorker.get_temp_result_filename, range(self.num_workers)
):
for worker_idx in range(self.num_workers):
filename = SacessWorker.get_temp_result_filename(
worker_idx, self._tmpdir
)
os.remove(filename)
# delete tmpdir if empty
try:
self._tmpdir.rmdir()
except OSError:
pass

result.optimize_result.sort()

Expand Down Expand Up @@ -380,6 +404,7 @@ class SacessWorker:
_logger: A Logger instance.
_loglevel: Logging level for sacess
_ess_loglevel: Logging level for ESS runs
_tmp_result_file: Path of a temporary file to be created.
"""

def __init__(
Expand All @@ -390,6 +415,7 @@ def __init__(
max_walltime_s: float = np.inf,
loglevel: int = logging.INFO,
ess_loglevel: int = logging.WARNING,
tmp_result_file: str = None,
):
self._manager = manager
self._worker_idx = worker_idx
Expand All @@ -404,6 +430,7 @@ def __init__(
self._loglevel = loglevel
self._ess_loglevel = ess_loglevel
self._logger = None
self._tmp_result_file = tmp_result_file

def run(
self,
Expand Down Expand Up @@ -458,12 +485,13 @@ def run(
ess_results.optimize_result.list = (
ess_results.optimize_result.list[:50]
)
write_result(
ess_results,
self.get_temp_result_filename(self._worker_idx),
overwrite=True,
optimize=True,
)
if self._tmp_result_file:
write_result(
ess_results,
self._tmp_result_file,
overwrite=True,
optimize=True,
)
# check if the best solution of the last local ESS is sufficiently
# better than the sacess-wide best solution
self.maybe_update_best(ess.x_best, ess.fx_best)
Expand Down Expand Up @@ -608,8 +636,10 @@ def _keep_going(self):
return True

@staticmethod
def get_temp_result_filename(worker_idx: int) -> str:
return f"sacess-{worker_idx:02d}_tmp.h5"
def get_temp_result_filename(
worker_idx: int, tmpdir: Union[str, Path]
) -> str:
return str(Path(tmpdir, f"sacess-{worker_idx:02d}_tmp.h5").absolute())


def _run_worker(
Expand Down