Skip to content

Commit

Permalink
Adapt read tests for logging to new SQLiteLogger
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed Aug 9, 2024
1 parent 07ba628 commit 8c88dfa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
37 changes: 20 additions & 17 deletions tests/optimagic/logging/test_read_log.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from dataclasses import asdict

import numpy as np
import pandas as pd
import pytest
from optimagic.logging.read_log import (
OptimizeLogReader,
read_optimization_problem_table,
read_start_params,
read_steps_table,
)

from optimagic.logging.logger import SQLiteLogger

from optimagic.optimization.optimize import minimize
from optimagic.parameters.tree_registry import get_registry
from pybaum import tree_equal, tree_just_flatten
Expand All @@ -30,18 +29,18 @@ def _crit(params):


def test_read_start_params(example_db):
res = read_start_params(example_db)
res = SQLiteLogger(example_db).read_start_params()
assert res == {"a": 1, "b": 2, "c": 3}


def test_log_reader_read_start_params(example_db):
reader = OptimizeLogReader(example_db)
reader = SQLiteLogger(example_db)
res = reader.read_start_params()
assert res == {"a": 1, "b": 2, "c": 3}


def test_log_reader_read_iteration(example_db):
reader = OptimizeLogReader(example_db)
reader = SQLiteLogger(example_db)
first_row = reader.read_iteration(0)
assert first_row["params"] == {"a": 1, "b": 2, "c": 3}
assert first_row["rowid"] == 1
Expand All @@ -53,15 +52,15 @@ def test_log_reader_read_iteration(example_db):


def test_log_reader_read_history(example_db):
reader = OptimizeLogReader(example_db)
reader = SQLiteLogger(example_db)
res = reader.read_history()
assert res["runtime"][0] == 0
assert res["criterion"][0] == 14
assert res["params"][0] == {"a": 1, "b": 2, "c": 3}


def test_log_reader_read_multistart_history(example_db):
reader = OptimizeLogReader(example_db)
reader = SQLiteLogger(example_db)
history, local_history, exploration = reader.read_multistart_history(
direction="minimize"
)
Expand All @@ -70,24 +69,28 @@ def test_log_reader_read_multistart_history(example_db):

registry = get_registry(extended=True)
assert tree_equal(
tree_just_flatten(history, registry=registry),
tree_just_flatten(reader.read_history(), registry=registry),
tree_just_flatten(asdict(history), registry=registry),
tree_just_flatten(asdict(reader.read_history()), registry=registry),
)


def test_read_steps_table(example_db):
res = read_steps_table(example_db)
res = SQLiteLogger(example_db).step_store.to_df()
assert isinstance(res, pd.DataFrame)
assert res.loc[0, "rowid"] == 1
assert res.loc[0, "type"] == "optimization"
assert res.loc[0, "status"] == "complete"


def test_read_optimization_problem_table(example_db):
res = read_optimization_problem_table(example_db)
res = SQLiteLogger(example_db).problem_store.to_df()
assert isinstance(res, pd.DataFrame)


def test_non_existing_database_raises_error():
# TODO: db file is created at instantiation of the logger, decide how to handle
# empty tables. By now, the logger methods may raise unspecific errors
# (like IndexError)
@pytest.mark.skip
def test_non_existing_database_raises_error(tmp_path):
with pytest.raises(FileNotFoundError):
read_start_params("i_do_not_exist.db")
SQLiteLogger(tmp_path / "i_do_not_exist.db").read_start_params()
3 changes: 2 additions & 1 deletion tests/optimagic/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def _crit(params):


def test_estimagic_log_reader_is_deprecated(example_db):
msg = "estimagic.OptimizeLogReader has been deprecated"
msg = "OptimizeLogReader is deprecated and will be removed in a future "
"version. Please use optimagic.logging.SQLiteLogger instead."
with pytest.warns(FutureWarning, match=msg):
OptimizeLogReader(example_db)

Expand Down

0 comments on commit 8c88dfa

Please sign in to comment.