Skip to content

Commit

Permalink
Make logging compatible with sqlalchemy 2.0. (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg authored Feb 7, 2023
1 parent b651439 commit 9acf723
Show file tree
Hide file tree
Showing 21 changed files with 634 additions and 669 deletions.
2 changes: 1 addition & 1 deletion .envs/testenv-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies:
- plotly # run, tests
- pybaum >= 0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy <2.0 # run, tests
- sqlalchemy # run, tests
- pip: # dev, tests, docs
- DFO-LS # dev, tests
- Py-BOBYQA # dev, tests
Expand Down
2 changes: 1 addition & 1 deletion .envs/testenv-others.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- plotly # run, tests
- pybaum >= 0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy <2.0 # run, tests
- sqlalchemy # run, tests
- pip: # dev, tests, docs
- DFO-LS # dev, tests
- Py-BOBYQA # dev, tests
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies:
- plotly # run, tests
- pybaum >= 0.1.2 # run, tests
- scipy>=1.2.1 # run, tests
- sqlalchemy <2.0 # run, tests
- sqlalchemy # run, tests
- pydata-sphinx-theme>=0.3.0 # docs
- myst-nb # docs
- sphinx # docs
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ filterwarnings = [
"ignore:Widget._widget_types is deprecated",
"ignore:Widget.widget_types is deprecated",
"ignore:Widget.widgets is deprecated",
"ignore:Deprecated API features detected",
]
addopts = ["--doctest-modules"]
markers = [
Expand Down
2 changes: 1 addition & 1 deletion src/estimagic/dashboard/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from estimagic.logging.database_utilities import read_new_rows, transpose_nested_list
from estimagic.logging.read_from_database import read_new_rows, transpose_nested_list


def reset_and_start_convergence(
Expand Down
5 changes: 3 additions & 2 deletions src/estimagic/dashboard/dashboard_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from estimagic.dashboard.callbacks import reset_and_start_convergence
from estimagic.dashboard.plot_functions import plot_time_series
from estimagic.logging.database_utilities import load_database, read_last_rows
from estimagic.logging.load_database import load_database
from estimagic.logging.read_from_database import read_last_rows
from estimagic.logging.read_log import read_start_params
from estimagic.parameters.parameter_groups import get_params_groups_and_short_names
from estimagic.parameters.tree_registry import get_registry
Expand Down Expand Up @@ -42,7 +43,7 @@ def dashboard_app(
doc.template = env.get_template("index.html")

# process inputs
database = load_database(path=session_data["database_path"])
database = load_database(path_or_database=session_data["database_path"])
start_point = _calculate_start_point(database, updating_options)
session_data["last_retrieved"] = start_point

Expand Down
104 changes: 104 additions & 0 deletions src/estimagic/logging/create_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import sqlalchemy as sql

from estimagic.exceptions import TableExistsError
from estimagic.logging.load_database import RobustPickler


def make_optimization_iteration_table(database, if_exists="extend"):
"""Generate a table for information that is generated with each function evaluation.
Args:
database (DataBase): DataBase object containing the engine and metadata.
if_exists (str): What to do if the table already exists. Can be "extend",
"replace" or "raise".
Returns:
database (sqlalchemy.MetaData):Bound metadata object with added table.
"""
table_name = "optimization_iterations"
_handle_existing_table(database, "optimization_iterations", if_exists)

columns = [
sql.Column("rowid", sql.Integer, primary_key=True),
sql.Column("params", sql.PickleType(pickler=RobustPickler)),
sql.Column("internal_derivative", sql.PickleType(pickler=RobustPickler)),
sql.Column("timestamp", sql.Float),
sql.Column("exceptions", sql.String),
sql.Column("valid", sql.Boolean),
sql.Column("hash", sql.String),
sql.Column("value", sql.Float),
sql.Column("step", sql.Integer),
sql.Column("criterion_eval", sql.PickleType(pickler=RobustPickler)),
]

sql.Table(
table_name,
database.metadata,
*columns,
sqlite_autoincrement=True,
extend_existing=True,
)

database.metadata.create_all(database.engine)


def make_steps_table(database, if_exists="extend"):
table_name = "steps"
_handle_existing_table(database, table_name, if_exists)
columns = [
sql.Column("rowid", sql.Integer, primary_key=True),
sql.Column("type", sql.String), # e.g. optimization
sql.Column("status", sql.String), # e.g. running
sql.Column("n_iterations", sql.Integer), # optional
sql.Column(
"name", sql.String
), # e.g. "optimization-1", "exploration", not unique
]
sql.Table(
table_name,
database.metadata,
*columns,
extend_existing=True,
sqlite_autoincrement=True,
)
database.metadata.create_all(database.engine)


def make_optimization_problem_table(database, if_exists="extend"):
table_name = "optimization_problem"
_handle_existing_table(database, table_name, if_exists)

columns = [
sql.Column("rowid", sql.Integer, primary_key=True),
sql.Column("direction", sql.String),
sql.Column("params", sql.PickleType(pickler=RobustPickler)),
sql.Column("algorithm", sql.PickleType(pickler=RobustPickler)),
sql.Column("algo_options", sql.PickleType(pickler=RobustPickler)),
sql.Column("numdiff_options", sql.PickleType(pickler=RobustPickler)),
sql.Column("log_options", sql.PickleType(pickler=RobustPickler)),
sql.Column("error_handling", sql.String),
sql.Column("error_penalty", sql.PickleType(pickler=RobustPickler)),
sql.Column("constraints", sql.PickleType(pickler=RobustPickler)),
sql.Column("free_mask", sql.PickleType(pickler=RobustPickler)),
]

sql.Table(
table_name,
database.metadata,
*columns,
extend_existing=True,
sqlite_autoincrement=True,
)

database.metadata.create_all(database.engine)


def _handle_existing_table(database, table_name, if_exists):
assert if_exists in ["replace", "extend", "raise"]

if table_name in database.metadata.tables:
if if_exists == "replace":
database.metadata.tables[table_name].drop(database.engine)
elif if_exists == "raise":
raise TableExistsError(f"The table {table_name} already exists.")
Loading

0 comments on commit 9acf723

Please sign in to comment.