diff --git a/.envs/testenv-linux.yml b/.envs/testenv-linux.yml index 205b92b63..12d83b468 100644 --- a/.envs/testenv-linux.yml +++ b/.envs/testenv-linux.yml @@ -23,7 +23,7 @@ dependencies: - plotly # run, tests - pybaum >= 0.1.2 # run, tests - scipy>=1.2.1 # run, tests - - sqlalchemy <2.0 # run, tests + - sqlalchemy # run, tests - pip: # dev, tests, docs - DFO-LS # dev, tests - Py-BOBYQA # dev, tests diff --git a/.envs/testenv-others.yml b/.envs/testenv-others.yml index 995208d74..d35153d2e 100644 --- a/.envs/testenv-others.yml +++ b/.envs/testenv-others.yml @@ -22,7 +22,7 @@ dependencies: - plotly # run, tests - pybaum >= 0.1.2 # run, tests - scipy>=1.2.1 # run, tests - - sqlalchemy <2.0 # run, tests + - sqlalchemy # run, tests - pip: # dev, tests, docs - DFO-LS # dev, tests - Py-BOBYQA # dev, tests diff --git a/environment.yml b/environment.yml index 4915f898d..79e11d0ab 100644 --- a/environment.yml +++ b/environment.yml @@ -29,7 +29,7 @@ dependencies: - plotly # run, tests - pybaum >= 0.1.2 # run, tests - scipy>=1.2.1 # run, tests - - sqlalchemy <2.0 # run, tests + - sqlalchemy # run, tests - pydata-sphinx-theme>=0.3.0 # docs - myst-nb # docs - sphinx # docs diff --git a/pyproject.toml b/pyproject.toml index 35d6fffaf..3d4f4bb88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,6 @@ filterwarnings = [ "ignore:Widget._widget_types is deprecated", "ignore:Widget.widget_types is deprecated", "ignore:Widget.widgets is deprecated", - "ignore:Deprecated API features detected", ] addopts = ["--doctest-modules"] markers = [ diff --git a/src/estimagic/dashboard/callbacks.py b/src/estimagic/dashboard/callbacks.py index b461da25d..85e70a25a 100644 --- a/src/estimagic/dashboard/callbacks.py +++ b/src/estimagic/dashboard/callbacks.py @@ -2,7 +2,7 @@ import numpy as np -from estimagic.logging.database_utilities import read_new_rows, transpose_nested_list +from estimagic.logging.read_from_database import read_new_rows, transpose_nested_list def reset_and_start_convergence( diff --git a/src/estimagic/dashboard/dashboard_app.py b/src/estimagic/dashboard/dashboard_app.py index c5b9ec254..987933774 100644 --- a/src/estimagic/dashboard/dashboard_app.py +++ b/src/estimagic/dashboard/dashboard_app.py @@ -11,7 +11,8 @@ from estimagic.dashboard.callbacks import reset_and_start_convergence from estimagic.dashboard.plot_functions import plot_time_series -from estimagic.logging.database_utilities import load_database, read_last_rows +from estimagic.logging.load_database import load_database +from estimagic.logging.read_from_database import read_last_rows from estimagic.logging.read_log import read_start_params from estimagic.parameters.parameter_groups import get_params_groups_and_short_names from estimagic.parameters.tree_registry import get_registry @@ -42,7 +43,7 @@ def dashboard_app( doc.template = env.get_template("index.html") # process inputs - database = load_database(path=session_data["database_path"]) + database = load_database(path_or_database=session_data["database_path"]) start_point = _calculate_start_point(database, updating_options) session_data["last_retrieved"] = start_point diff --git a/src/estimagic/logging/create_tables.py b/src/estimagic/logging/create_tables.py new file mode 100644 index 000000000..28f349ecc --- /dev/null +++ b/src/estimagic/logging/create_tables.py @@ -0,0 +1,104 @@ +import sqlalchemy as sql + +from estimagic.exceptions import TableExistsError +from estimagic.logging.load_database import RobustPickler + + +def make_optimization_iteration_table(database, 
if_exists="extend"):
+    """Generate a table for information logged with each function evaluation.
+
+    Args:
+        database (DataBase): DataBase object containing the engine and metadata.
+        if_exists (str): What to do if the table already exists. Can be "extend",
+            "replace" or "raise".
+
+    Returns:
+        None. The table is registered on database.metadata and created in the database.
+
+    """
+    table_name = "optimization_iterations"
+    _handle_existing_table(database, table_name, if_exists)
+
+    columns = [
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("params", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("internal_derivative", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("timestamp", sql.Float),
+        sql.Column("exceptions", sql.String),
+        sql.Column("valid", sql.Boolean),
+        sql.Column("hash", sql.String),
+        sql.Column("value", sql.Float),
+        sql.Column("step", sql.Integer),
+        sql.Column("criterion_eval", sql.PickleType(pickler=RobustPickler)),
+    ]
+
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        sqlite_autoincrement=True,
+        extend_existing=True,
+    )
+
+    database.metadata.create_all(database.engine)
+
+
+def make_steps_table(database, if_exists="extend"):
+    table_name = "steps"
+    _handle_existing_table(database, table_name, if_exists)
+    columns = [
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("type", sql.String),  # e.g. optimization
+        sql.Column("status", sql.String),  # e.g. running
+        sql.Column("n_iterations", sql.Integer),  # optional
+        sql.Column(
+            "name", sql.String
+        ),  # e.g. "optimization-1", "exploration", not unique
+    ]
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        extend_existing=True,
+        sqlite_autoincrement=True,
+    )
+    database.metadata.create_all(database.engine)
+
+
+def make_optimization_problem_table(database, if_exists="extend"):
+    table_name = "optimization_problem"
+    _handle_existing_table(database, table_name, if_exists)
+
+    columns = [
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("direction", sql.String),
+        sql.Column("params", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("algorithm", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("algo_options", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("numdiff_options", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("log_options", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("error_handling", sql.String),
+        sql.Column("error_penalty", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("constraints", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("free_mask", sql.PickleType(pickler=RobustPickler)),
+    ]
+
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        extend_existing=True,
+        sqlite_autoincrement=True,
+    )
+
+    database.metadata.create_all(database.engine)
+
+
+def _handle_existing_table(database, table_name, if_exists):
+    assert if_exists in ["replace", "extend", "raise"]
+
+    if table_name in database.metadata.tables:
+        if if_exists == "replace":
+            database.metadata.tables[table_name].drop(database.engine)
+        elif if_exists == "raise":
+            raise TableExistsError(f"The table {table_name} already exists.")
diff --git a/src/estimagic/logging/database_utilities.py b/src/estimagic/logging/database_utilities.py
deleted file mode 100644
index d05525e1f..000000000
--- a/src/estimagic/logging/database_utilities.py
+++ /dev/null
@@ -1,554 +0,0 @@
-"""Functions to generate, load, write to and read from databases.
- -The functions here are meant for internal use in estimagic, e.g. for logging during -the optimization and reading from the database in the dashboard. They do not require -detailed knowledge of databases in general but some knowledge of the schema -(e.g. table names) of the database we use for logging. - -Therefore, users who simply want to read the database should use the functions in -``read_log.py`` instead. - -""" -import io -import traceback -import warnings -from pathlib import Path - -import cloudpickle -import pandas as pd -from sqlalchemy import ( - BLOB, - Boolean, - Column, - Float, - Integer, - MetaData, - PickleType, - String, - Table, - and_, - create_engine, - event, - update, -) - -from estimagic.exceptions import TableExistsError, get_traceback - - -def load_database(metadata=None, path=None, fast_logging=False): - """Return a bound sqlalchemy.MetaData object for the database stored in ``path``. - - This is the only acceptable way of creating or loading databases in estimagic! - - If metadata is a bound MetaData object, it is just returned. If metadata is given - but not bound, we bind it to an engine that connects to the database stored under - ``path``. If only the path is provided, we generate an appropriate MetaData object - and bind it to the database. - - For speed reasons we do not make any checks that MetaData is compatible with the - database stored under path. - - Args: - metadata (sqlalchemy.MetaData): MetaData object that might or might not be - bound to the database under path. In any case it needs to be compatible - with the database stored under ``path``. For speed reasons, this is not - checked. - path (str or pathlib.Path): location of the database file. If the file does - not exist, it will be created. - - Returns: - metadata (sqlalchemy.MetaData). MetaData object that is bound to the database - under ``path``. - - """ - path = Path(path) if isinstance(path, str) else path - - if isinstance(metadata, MetaData): - if metadata.bind is None: - assert ( - path is not None - ), "If metadata is not bound, you need to provide a path." - engine = create_engine(f"sqlite:///{path}") - _configure_engine(engine, fast_logging) - metadata.bind = engine - elif metadata is None: - assert path is not None, "If metadata is None you need to provide a path." - path_existed = path.exists() - engine = create_engine(f"sqlite:///{path}") - _configure_engine(engine, fast_logging) - metadata = MetaData() - metadata.bind = engine - if path_existed: - _configure_reflect() - metadata.reflect() - else: - raise ValueError("metadata must be sqlalchemy.MetaData or None.") - - return metadata - - -def make_optimization_iteration_table(database, if_exists="extend"): - """Generate a table for information that is generated with each function evaluation. - - Args: - database (sqlalchemy.MetaData): Bound metadata object. - - Returns: - database (sqlalchemy.MetaData):Bound metadata object with added table. 
- - """ - table_name = "optimization_iterations" - _handle_existing_table(database, "optimization_iterations", if_exists) - - columns = [ - Column("rowid", Integer, primary_key=True), - Column("params", PickleType(pickler=RobustPickler)), - Column("internal_derivative", PickleType(pickler=RobustPickler)), - Column("timestamp", Float), - Column("exceptions", String), - Column("valid", Boolean), - Column("hash", String), - Column("value", Float), - Column("step", Integer), - Column("criterion_eval", PickleType(pickler=RobustPickler)), - ] - - Table( - table_name, database, *columns, sqlite_autoincrement=True, extend_existing=True - ) - - database.create_all(database.bind) - - -def make_steps_table(database, if_exists="extend"): - table_name = "steps" - _handle_existing_table(database, table_name, if_exists) - columns = [ - Column("rowid", Integer, primary_key=True), - Column("type", String), # e.g. optimization - Column("status", String), # e.g. running - Column("n_iterations", Integer), # optional - Column("name", String), # e.g. "optimization-1", "exploration", not unique - ] - Table( - table_name, database, *columns, extend_existing=True, sqlite_autoincrement=True - ) - database.create_all(database.bind) - - -def make_optimization_problem_table(database, if_exists="extend"): - table_name = "optimization_problem" - _handle_existing_table(database, table_name, if_exists) - - columns = [ - Column("rowid", Integer, primary_key=True), - Column("direction", String), - Column("params", PickleType(pickler=RobustPickler)), - Column("algorithm", PickleType(pickler=RobustPickler)), - Column("algo_options", PickleType(pickler=RobustPickler)), - Column("numdiff_options", PickleType(pickler=RobustPickler)), - Column("log_options", PickleType(pickler=RobustPickler)), - Column("error_handling", String), - Column("error_penalty", PickleType(pickler=RobustPickler)), - Column("constraints", PickleType(pickler=RobustPickler)), - Column("free_mask", PickleType(pickler=RobustPickler)), - ] - - Table( - table_name, database, *columns, extend_existing=True, sqlite_autoincrement=True - ) - - database.create_all(database.bind) - - -def _handle_existing_table(database, table_name, if_exists): - assert if_exists in ["replace", "extend", "raise"] - - if table_name in database.tables: - if if_exists == "replace": - database.tables[table_name].drop(database.bind) - elif if_exists == "raise": - raise TableExistsError(f"The table {table_name} already exists.") - - -def update_row(data, rowid, table_name, database, path, fast_logging): - database = load_database(database, path, fast_logging) - - table = database.tables[table_name] - stmt = update(table).where(table.c.rowid == rowid).values(**data) - - _execute_write_statement(stmt, database, path, table_name, data) - - -def append_row(data, table_name, database, path, fast_logging): - """ - - Args: - data (dict): The keys correspond to columns in the database table. - table_name (str): Name of the database table to which the row is added. - database (sqlalchemy.MetaData): Sqlachlemy metadata object of the database. - path (str or pathlib.Path): Path to the database file. Using a path is much - slower than a MetaData object and we advise to only use it as a fallback. - fast_logging (bool) - - """ - # this is necessary because database.bind gets lost when the database is pickled. - # it has no cost when database.bind is set. 
- database = load_database(database, path, fast_logging) - - stmt = database.tables[table_name].insert().values(**data) - - _execute_write_statement(stmt, database, path, table_name, data) - - -def _execute_write_statement( - statement, database, path, table_name, data # noqa: ARG001 -): - try: - # this will automatically roll back the transaction if any exception is raised - # and then raise the exception - with database.bind.begin() as connection: - connection.execute(statement) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - exception_info = traceback.format_exc() - warnings.warn( - f"Unable to write to database. The traceback was:\n\n{exception_info}" - ) - - -def read_new_rows( - database, - table_name, - last_retrieved, - return_type, - path=None, - fast_logging=False, - limit=None, - stride=1, - step=None, -): - """Read all iterations after last_retrieved up to a limit. - - Args: - database (sqlalchemy.MetaData) - table_name (str): name of the table to retrieve. - last_retrieved (int): The last iteration that was retrieved. - return_type (str): either "list_of_dicts" or "dict_of_lists". - path (str or pathlib.Path): location of the database file. If the file does - not exist, it will be created. Using a path is much slower than a - MetaData object and we advise to only use it as a fallback. - fast_logging (bool) - limit (int): maximum number of rows to extract from the table. - stride (int): Only return every n-th row. Default is every row (stride=1). - step (int): Only return iterations that belong to step. - - Returns: - result (return_type): up to limit rows after last_retrieved of the - `table_name` table as `return_type`. - int: The new last_retrieved value. - - """ - database = load_database(database, path, fast_logging) - last_retrieved = int(last_retrieved) - limit = int(limit) if limit is not None else limit - - table = database.tables[table_name] - stmt = table.select().where(table.c.rowid > last_retrieved).limit(limit) - conditions = [table.c.rowid > last_retrieved] - - if stride != 1: - conditions.append(table.c.rowid % stride == 0) - - if step is not None: - conditions.append(table.c.step == int(step)) - - stmt = table.select().where(and_(*conditions)).limit(limit) - - data = _execute_read_statement(database, table_name, stmt, return_type) - - if return_type == "list_of_dicts": - new_last = data[-1]["rowid"] if data else last_retrieved - else: - new_last = data["rowid"][-1] if data["rowid"] else last_retrieved - - return data, new_last - - -def read_last_rows( - database, - table_name, - n_rows, - return_type, - path=None, - fast_logging=False, - stride=1, - step=None, -): - """Read the last n_rows rows from a table. - - If a table has less than n_rows rows, the whole table is returned. - - Args: - database (sqlalchemy.MetaData) - table_name (str): name of the table to retrieve. - n_rows (int): number of rows to retrieve. - return_type (str): either "list_of_dicts" or "dict_of_lists". - path (str or pathlib.Path): location of the database file. If the file does - not exist, it will be created. Using a path is much slower than a - MetaData object and we advise to only use it as a fallback. - fast_logging (bool) - stride (int): Only return every n-th row. Default is every row (stride=1). - step (int): Only return rows that belong to step. - - Returns: - result (return_type): the last rows of the `table_name` table as `return_type`. 
- - """ - database = load_database(database, path, fast_logging) - n_rows = int(n_rows) - - table = database.tables[table_name] - - conditions = [] - - if stride != 1: - conditions.append(table.c.rowid % stride == 0) - - if step is not None: - conditions.append(table.c.step == int(step)) - - if conditions: - stmt = ( - table.select() - .order_by(table.c.rowid.desc()) - .where(and_(*conditions)) - .limit(n_rows) - ) - else: - stmt = table.select().order_by(table.c.rowid.desc()).limit(n_rows) - - reversed_ = _execute_read_statement(database, table_name, stmt, return_type) - if return_type == "list_of_dicts": - out = reversed_[::-1] - else: - out = {key: val[::-1] for key, val in reversed_.items()} - - return out - - -def read_specific_row( - database, table_name, rowid, return_type, path=None, fast_logging=False -): - """Read a specific row from a table. - - Args: - database (sqlalchemy.MetaData) - table_name (str): name of the table to retrieve. - n_rows (int): number of rows to retrieve. - return_type (str): either "list_of_dicts" or "dict_of_lists". - path (str or pathlib.Path): location of the database file. - Using a path is much slower than a MetaData object and we - advise to only use it as a fallback. - fast_logging (bool) - - Returns: - dict or list: The requested row from the database. - - """ - database = load_database(database, path, fast_logging) - rowid = int(rowid) - table = database.tables[table_name] - stmt = table.select().where(table.c.rowid == rowid) - data = _execute_read_statement(database, table_name, stmt, return_type) - return data - - -def read_table(database, table_name, return_type, path=None, fast_logging=False): - database = load_database(database, path, fast_logging) - table = database.tables[table_name] - stmt = table.select() - data = _execute_read_statement(database, table_name, stmt, return_type) - return data - - -def _execute_read_statement(database, table_name, statement, return_type): - try: - with database.bind.begin() as connection: - raw_result = list(connection.execute(statement)) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - exception_info = traceback.format_exc() - warnings.warn( - "Unable to read {table_name} from database. Try again later. The traceback " - f"was: \n\n{exception_info}" - ) - # if we only want to warn we must provide a raw_result to be processed below. - raw_result = [] - - columns = database.tables[table_name].columns.keys() - - if return_type == "list_of_dicts": - result = [dict(zip(columns, row)) for row in raw_result] - - elif return_type == "dict_of_lists": - raw_result = transpose_nested_list(raw_result) - result = dict(zip(columns, raw_result)) - if result == {}: - result = {col: [] for col in columns} - else: - raise NotImplementedError( - "The return_type must be 'list_of_dicts' or 'dict_of_lists', " - f"not {return_type}." - ) - - return result - - -def transpose_nested_list(nested_list): - """Transpose a list of lists. - - Args: - nested_list (list): Nested list where all sublists have the same length. - - Returns: - list - - Examples: - >>> transpose_nested_list([[1, 2], [3, 4]]) - [[1, 3], [2, 4]] - - """ - return list(map(list, zip(*nested_list))) - - -def list_of_dicts_to_dict_of_lists(list_of_dicts): - """Convert a list of dicts to a dict of lists. - - Args: - list_of_dicts (list): List of dictionaries. All dictionaries have the same keys. 
- - Returns: - dict - - Examples: - >>> list_of_dicts_to_dict_of_lists([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) - {'a': [1, 3], 'b': [2, 4]} - - """ - return {k: [dic[k] for dic in list_of_dicts] for k in list_of_dicts[0]} - - -def dict_of_lists_to_list_of_dicts(dict_of_lists): - """Convert a dict of lists to a list of dicts. - - Args: - dict_of_lists (dict): Dictionary of lists where all lists have the same length. - - Returns: - list - - Examples: - - >>> dict_of_lists_to_list_of_dicts({'a': [1, 3], 'b': [2, 4]}) - [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] - - """ - return [dict(zip(dict_of_lists, t)) for t in zip(*dict_of_lists.values())] - - -def _configure_engine(engine, fast_logging): - """Configure the sqlite engine. - - The two functions that configure the emission of the begin statement are taken from - the sqlalchemy documentation the documentation: - https://tinyurl.com/u9xea5z - and are - the recommended way of working around a bug in the pysqlite driver. - - The other function speeds up the write process. If fast_logging is False, it does so - using only completely safe optimizations. Of fast_logging is True, it also uses - unsafe optimizations. - - """ - - @event.listens_for(engine, "connect") - def do_connect(dbapi_connection, connection_record): # noqa: ARG001 - # disable pysqlite's emitting of the BEGIN statement entirely. - # also stops it from emitting COMMIT before absolutely necessary. - dbapi_connection.isolation_level = None - - @event.listens_for(engine, "begin") - def do_begin(conn): - # emit our own BEGIN - conn.execute("BEGIN DEFERRED") - - @event.listens_for(engine, "connect") - def set_sqlite_pragma(dbapi_connection, connection_record): # noqa: ARG001 - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA journal_mode = WAL") - if fast_logging: - cursor.execute("PRAGMA synchronous = OFF") - else: - cursor.execute("PRAGMA synchronous = NORMAL") - cursor.close() - - -def _configure_reflect(): - """Mark all BLOB dtypes as PickleType with our custom pickle reader. - - Code ist taken from the documentation: https://tinyurl.com/y7q287jr - - """ - - @event.listens_for(Table, "column_reflect") - def _setup_pickletype(inspector, table, column_info): # noqa: ARG001 - if isinstance(column_info["type"], BLOB): - column_info["type"] = PickleType(pickler=RobustPickler) - - -class RobustPickler: - @staticmethod - def loads( - data, - fix_imports=True, # noqa: ARG004 - encoding="ASCII", # noqa: ARG004 - errors="strict", # noqa: ARG004 - buffers=None, # noqa: ARG004 - ): - """Robust pickle loading. - - We first try to unpickle the object with pd.read_pickle. This makes no - difference for non-pandas objects but makes the de-serialization - of pandas objects more robust across pandas versions. If that fails, we use - cloudpickle. If that fails, we return None but do not raise an error. - - See: https://github.com/pandas-dev/pandas/issues/16474 - - """ - try: - res = pd.read_pickle(io.BytesIO(data), compression=None) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - try: - res = cloudpickle.loads(data) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - res = None - tb = get_traceback() - warnings.warn( - f"Unable to read PickleType column from database:\n{tb}\n " - "The entry was replaced by None." 
- ) - - return res - - @staticmethod - def dumps( - obj, protocol=None, *, fix_imports=True, buffer_callback=None # noqa: ARG004 - ): - return cloudpickle.dumps(obj, protocol=protocol) diff --git a/src/estimagic/logging/load_database.py b/src/estimagic/logging/load_database.py new file mode 100644 index 000000000..b309f1196 --- /dev/null +++ b/src/estimagic/logging/load_database.py @@ -0,0 +1,161 @@ +import io +import warnings + +import cloudpickle +import pandas as pd +import sqlalchemy as sql + +from estimagic.exceptions import get_traceback + + +class DataBase: + """Class containing everything to work with a logging database. + + Importantly, the class is pickle-serializable which is important to share it across + multiple processes. Upon unpickling, it will automatically re-create an engine to + connect to the database. + + """ + + def __init__(self, metadata, path, fast_logging, engine=None): + self.metadata = metadata + self.path = path + self.fast_logging = fast_logging + if engine is None: + self.engine = _create_engine(path, fast_logging) + else: + self.engine = engine + + def __reduce__(self): + return (DataBase, (self.metadata, self.path, self.fast_logging)) + + +def load_database(path_or_database, fast_logging=False): + """Load or create a database from a path and configure it for our needs. + + This is the only acceptable way of loading or creating a database in estimagic! + + Args: + path_or_database (str or pathlib.Path): Path to the database or DataBase. + fast_logging (bool): If True, use unsafe optimizations to speed up the logging. + If False, only use ultra safe optimizations. + + Returns: + database (Database): Object containing everything to work with the + database. + + """ + if isinstance(path_or_database, DataBase): + out = path_or_database + else: + engine = _create_engine(path_or_database, fast_logging) + metadata = sql.MetaData() + _configure_reflect() + metadata.reflect(engine) + + out = DataBase( + metadata=metadata, + path=path_or_database, + fast_logging=fast_logging, + engine=engine, + ) + return out + + +def _create_engine(path, fast_logging): + engine = sql.create_engine(f"sqlite:///{path}") + _configure_engine(engine, fast_logging) + return engine + + +def _configure_engine(engine, fast_logging): + """Configure the sqlite engine. + + The two functions that configure the emission of the begin statement are taken from + the sqlalchemy documentation the documentation: https://tinyurl.com/u9xea5z and are + the recommended way of working around a bug in the pysqlite driver. + + The other function speeds up the write process. If fast_logging is False, it does so + using only completely safe optimizations. Of fast_logging is True, it also uses + unsafe optimizations. + + """ + + @sql.event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): # noqa: ARG001 + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before absolutely necessary. 
+ dbapi_connection.isolation_level = None + + @sql.event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN DEFERRED") + + @sql.event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): # noqa: ARG001 + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA journal_mode = WAL") + if fast_logging: + cursor.execute("PRAGMA synchronous = OFF") + else: + cursor.execute("PRAGMA synchronous = NORMAL") + cursor.close() + + +def _configure_reflect(): + """Mark all BLOB dtypes as PickleType with our custom pickle reader. + + Code ist taken from the documentation: https://tinyurl.com/y7q287jr + + """ + + @sql.event.listens_for(sql.Table, "column_reflect") + def _setup_pickletype(inspector, table, column_info): # noqa: ARG001 + if isinstance(column_info["type"], sql.BLOB): + column_info["type"] = sql.PickleType(pickler=RobustPickler) + + +class RobustPickler: + @staticmethod + def loads( + data, + fix_imports=True, # noqa: ARG004 + encoding="ASCII", # noqa: ARG004 + errors="strict", # noqa: ARG004 + buffers=None, # noqa: ARG004 + ): + """Robust pickle loading. + + We first try to unpickle the object with pd.read_pickle. This makes no + difference for non-pandas objects but makes the de-serialization + of pandas objects more robust across pandas versions. If that fails, we use + cloudpickle. If that fails, we return None but do not raise an error. + + See: https://github.com/pandas-dev/pandas/issues/16474 + + """ + try: + res = pd.read_pickle(io.BytesIO(data), compression=None) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + try: + res = cloudpickle.loads(data) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + res = None + tb = get_traceback() + warnings.warn( + f"Unable to read PickleType column from database:\n{tb}\n " + "The entry was replaced by None." + ) + + return res + + @staticmethod + def dumps( + obj, protocol=None, *, fix_imports=True, buffer_callback=None # noqa: ARG004 + ): + return cloudpickle.dumps(obj, protocol=protocol) diff --git a/src/estimagic/logging/read_from_database.py b/src/estimagic/logging/read_from_database.py new file mode 100644 index 000000000..af28ecf6e --- /dev/null +++ b/src/estimagic/logging/read_from_database.py @@ -0,0 +1,234 @@ +"""Functions to generate, load, write to and read from databases. + +The functions here are meant for internal use in estimagic, e.g. for logging during +the optimization and reading from the database in the dashboard. They do not require +detailed knowledge of databases in general but some knowledge of the schema +(e.g. table names) of the database we use for logging. + +Therefore, users who simply want to read the database should use the functions in +``read_log.py`` instead. + +""" +import traceback +import warnings + +import sqlalchemy as sql + + +def read_new_rows( + database, + table_name, + last_retrieved, + return_type, + limit=None, + stride=1, + step=None, +): + """Read all iterations after last_retrieved up to a limit. + + Args: + database (DataBase): Object containing everything to work with the database. + table_name (str): name of the table to retrieve. + last_retrieved (int): The last iteration that was retrieved. + return_type (str): either "list_of_dicts" or "dict_of_lists". + limit (int): maximum number of rows to extract from the table. + stride (int): Only return every n-th row. Default is every row (stride=1). + step (int): Only return iterations that belong to step. 
+
+    Returns:
+        result (return_type): up to limit rows after last_retrieved of the
+            `table_name` table as `return_type`.
+        int: The new last_retrieved value.
+
+    """
+    last_retrieved = int(last_retrieved)
+    limit = int(limit) if limit is not None else limit
+
+    table = database.metadata.tables[table_name]
+    stmt = table.select().where(table.c.rowid > last_retrieved).limit(limit)
+    conditions = [table.c.rowid > last_retrieved]
+
+    if stride != 1:
+        conditions.append(table.c.rowid % stride == 0)
+
+    if step is not None:
+        conditions.append(table.c.step == int(step))
+
+    stmt = table.select().where(sql.and_(*conditions)).limit(limit)
+
+    data = _execute_read_statement(database, table_name, stmt, return_type)
+
+    if return_type == "list_of_dicts":
+        new_last = data[-1]["rowid"] if data else last_retrieved
+    else:
+        new_last = data["rowid"][-1] if data["rowid"] else last_retrieved
+
+    return data, new_last
+
+
+def read_last_rows(
+    database,
+    table_name,
+    n_rows,
+    return_type,
+    stride=1,
+    step=None,
+):
+    """Read the last n_rows rows from a table.
+
+    If a table has less than n_rows rows, the whole table is returned.
+
+    Args:
+        database (DataBase)
+        table_name (str): name of the table to retrieve.
+        n_rows (int): number of rows to retrieve.
+        return_type (str): either "list_of_dicts" or "dict_of_lists".
+        stride (int): Only return every n-th row. Default is every row (stride=1).
+        step (int): Only return rows that belong to step.
+
+    Returns:
+        result (return_type): the last rows of the `table_name` table as `return_type`.
+
+    """
+    n_rows = int(n_rows)
+
+    table = database.metadata.tables[table_name]
+
+    conditions = []
+
+    if stride != 1:
+        conditions.append(table.c.rowid % stride == 0)
+
+    if step is not None:
+        conditions.append(table.c.step == int(step))
+
+    if conditions:
+        stmt = (
+            table.select()
+            .order_by(table.c.rowid.desc())
+            .where(sql.and_(*conditions))
+            .limit(n_rows)
+        )
+    else:
+        stmt = table.select().order_by(table.c.rowid.desc()).limit(n_rows)
+
+    reversed_ = _execute_read_statement(database, table_name, stmt, return_type)
+    if return_type == "list_of_dicts":
+        out = reversed_[::-1]
+    else:
+        out = {key: val[::-1] for key, val in reversed_.items()}
+
+    return out
+
+
+def read_specific_row(database, table_name, rowid, return_type):
+    """Read a specific row from a table.
+
+    Args:
+        database (DataBase)
+        table_name (str): name of the table to retrieve.
+        rowid (int): id of the row to retrieve.
+        return_type (str): either "list_of_dicts" or "dict_of_lists".
+
+    Returns:
+        dict or list: The requested row from the database.
+
+    """
+    rowid = int(rowid)
+    table = database.metadata.tables[table_name]
+    stmt = table.select().where(table.c.rowid == rowid)
+    data = _execute_read_statement(database, table_name, stmt, return_type)
+    return data
+
+
+def read_table(database, table_name, return_type):
+    table = database.metadata.tables[table_name]
+    stmt = table.select()
+    data = _execute_read_statement(database, table_name, stmt, return_type)
+    return data
+
+
+def _execute_read_statement(database, table_name, statement, return_type):
+    try:
+        with database.engine.begin() as connection:
+            raw_result = list(connection.execute(statement))
+    except (KeyboardInterrupt, SystemExit):
+        raise
+    except Exception:
+        exception_info = traceback.format_exc()
+        warnings.warn(
+            f"Unable to read {table_name} from database. Try again later. The traceback "
+            f"was: \n\n{exception_info}"
+        )
+        # if we only want to warn we must provide a raw_result to be processed below.
+ raw_result = [] + + columns = database.metadata.tables[table_name].columns.keys() + + if return_type == "list_of_dicts": + result = [dict(zip(columns, row)) for row in raw_result] + + elif return_type == "dict_of_lists": + raw_result = transpose_nested_list(raw_result) + result = dict(zip(columns, raw_result)) + if result == {}: + result = {col: [] for col in columns} + else: + raise NotImplementedError( + "The return_type must be 'list_of_dicts' or 'dict_of_lists', " + f"not {return_type}." + ) + + return result + + +def transpose_nested_list(nested_list): + """Transpose a list of lists. + + Args: + nested_list (list): Nested list where all sublists have the same length. + + Returns: + list + + Examples: + >>> transpose_nested_list([[1, 2], [3, 4]]) + [[1, 3], [2, 4]] + + """ + return list(map(list, zip(*nested_list))) + + +def list_of_dicts_to_dict_of_lists(list_of_dicts): + """Convert a list of dicts to a dict of lists. + + Args: + list_of_dicts (list): List of dictionaries. All dictionaries have the same keys. + + Returns: + dict + + Examples: + >>> list_of_dicts_to_dict_of_lists([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + {'a': [1, 3], 'b': [2, 4]} + + """ + return {k: [dic[k] for dic in list_of_dicts] for k in list_of_dicts[0]} + + +def dict_of_lists_to_list_of_dicts(dict_of_lists): + """Convert a dict of lists to a list of dicts. + + Args: + dict_of_lists (dict): Dictionary of lists where all lists have the same length. + + Returns: + list + + Examples: + + >>> dict_of_lists_to_list_of_dicts({'a': [1, 3], 'b': [2, 4]}) + [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] + + """ + return [dict(zip(dict_of_lists, t)) for t in zip(*dict_of_lists.values())] diff --git a/src/estimagic/logging/read_log.py b/src/estimagic/logging/read_log.py index d8695bea3..5aafd1c09 100644 --- a/src/estimagic/logging/read_log.py +++ b/src/estimagic/logging/read_log.py @@ -14,10 +14,9 @@ import numpy as np import pandas as pd from pybaum import tree_flatten, tree_unflatten -from sqlalchemy import MetaData -from estimagic.logging.database_utilities import ( - load_database, +from estimagic.logging.load_database import load_database +from estimagic.logging.read_from_database import ( read_last_rows, read_new_rows, read_specific_row, @@ -25,6 +24,14 @@ from estimagic.parameters.tree_registry import get_registry +def load_existing_database(path_or_database): + if isinstance(path_or_database, (Path, str)): + path = Path(path_or_database) + if not path.exists(): + raise FileNotFoundError(f"Database {path} does not exist.") + return load_database(path_or_database) + + def read_start_params(path_or_database): """Load the start parameters DataFrame. @@ -35,7 +42,7 @@ def read_start_params(path_or_database): params (pd.DataFrame): see :ref:`params`. 
""" - database = _load_database(path_or_database) + database = load_existing_database(path_or_database) optimization_problem = read_last_rows( database=database, table_name="optimization_problem", @@ -46,21 +53,6 @@ def read_start_params(path_or_database): return start_params -def _load_database(path_or_database): - """Get an sqlalchemy.MetaDate object from path or database.""" - res = {"path": None, "metadata": None, "fast_logging": False} - if isinstance(path_or_database, MetaData): - res = path_or_database - elif isinstance(path_or_database, (Path, str)): - path = Path(path_or_database) - if not path.exists(): - raise FileNotFoundError(f"No such database file: {path}") - res = load_database(path=path) - else: - raise TypeError("path_or_database must be a path or sqlalchemy.MetaData object") - return res - - def read_steps_table(path_or_database): """Load the steps table. @@ -71,7 +63,7 @@ def read_steps_table(path_or_database): steps_df (pandas.DataFrame) """ - database = _load_database(path_or_database) + database = load_existing_database(path_or_database) steps_table, _ = read_new_rows( database=database, table_name="steps", @@ -93,7 +85,7 @@ def read_optimization_problem_table(path_or_database): params (pd.DataFrame): see :ref:`params`. """ - database = _load_database(path_or_database) + database = load_existing_database(path_or_database) steps_table, _ = read_new_rows( database=database, table_name="optimization_problem", @@ -112,7 +104,7 @@ class OptimizeLogReader: path: Union[str, Path] def __post_init__(self): - _database = _load_database(self.path) + _database = load_existing_database(self.path) _start_params = read_start_params(_database) _registry = get_registry(extended=True) _, _treedef = tree_flatten(_start_params, registry=_registry) diff --git a/src/estimagic/logging/write_to_database.py b/src/estimagic/logging/write_to_database.py new file mode 100644 index 000000000..b1e64c639 --- /dev/null +++ b/src/estimagic/logging/write_to_database.py @@ -0,0 +1,41 @@ +import traceback +import warnings + +import sqlalchemy as sql + + +def update_row(data, rowid, table_name, database): + table = database.metadata.tables[table_name] + stmt = sql.update(table).where(table.c.rowid == rowid).values(**data) + + _execute_write_statement(stmt, database) + + +def append_row(data, table_name, database): + """ + + Args: + data (dict): The keys correspond to columns in the database table. + table_name (str): Name of the database table to which the row is added. + database (DataBase): The database to which the row is added. + + """ + + stmt = database.metadata.tables[table_name].insert().values(**data) + + _execute_write_statement(stmt, database) + + +def _execute_write_statement(statement, database): + try: + # this will automatically roll back the transaction if any exception is raised + # and then raise the exception + with database.engine.begin() as connection: + connection.execute(statement) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + exception_info = traceback.format_exc() + warnings.warn( + f"Unable to write to database. 
The traceback was:\n\n{exception_info}" + ) diff --git a/src/estimagic/optimization/get_algorithm.py b/src/estimagic/optimization/get_algorithm.py index 4ec57c2fd..ba495a159 100644 --- a/src/estimagic/optimization/get_algorithm.py +++ b/src/estimagic/optimization/get_algorithm.py @@ -5,10 +5,10 @@ import numpy as np from estimagic.batch_evaluators import process_batch_evaluator -from estimagic.logging.database_utilities import ( +from estimagic.logging.read_from_database import ( list_of_dicts_to_dict_of_lists, - update_row, ) +from estimagic.logging.write_to_database import update_row from estimagic.optimization import ALL_ALGORITHMS from estimagic.utilities import propose_alternatives @@ -55,7 +55,7 @@ def get_final_algorithm( nonlinear_constraints, algo_options, logging, - db_kwargs, + database, collect_history, ): """Get algorithm-function with partialled options. @@ -74,7 +74,7 @@ def get_final_algorithm( algorithm. Entries that are not used by the algorithm are ignored with a warning. logging (bool): Whether the algorithm should do logging. - db_kwargs (dict): Dict with the entries "database", "path" and "fast_logging" + database (DataBase): Database to which the logging should be written. Returns: callable: The algorithm. @@ -96,7 +96,7 @@ def get_final_algorithm( algorithm = _add_logging( algorithm, logging=logging, - db_kwargs=db_kwargs, + database=database, ) is_parallel = internal_options.get("n_cores") not in (None, 1) @@ -110,7 +110,7 @@ def get_final_algorithm( return algorithm -def _add_logging(algorithm=None, *, logging=None, db_kwargs=None): +def _add_logging(algorithm=None, *, logging=None, database=None): """Add logging of status to the algorithm.""" def decorator_add_logging_to_algorithm(algorithm): @@ -125,7 +125,7 @@ def wrapper_add_logging_algorithm(**kwargs): data={"status": "running"}, rowid=step_id, table_name="steps", - **db_kwargs, + database=database, ) for task in ["criterion", "derivative", "criterion_and_derivative"]: @@ -141,7 +141,7 @@ def wrapper_add_logging_algorithm(**kwargs): data={"status": "complete"}, rowid=step_id, table_name="steps", - **db_kwargs, + database=database, ) return res diff --git a/src/estimagic/optimization/internal_criterion_template.py b/src/estimagic/optimization/internal_criterion_template.py index d8248f865..fa34bf4ae 100644 --- a/src/estimagic/optimization/internal_criterion_template.py +++ b/src/estimagic/optimization/internal_criterion_template.py @@ -3,7 +3,7 @@ from estimagic.differentiation.derivatives import first_derivative from estimagic.exceptions import UserFunctionRuntimeError, get_traceback -from estimagic.logging.database_utilities import append_row +from estimagic.logging.write_to_database import append_row from estimagic.parameters.conversion import aggregate_func_output_to_value @@ -19,7 +19,7 @@ def internal_criterion_and_derivative_template( criterion_and_derivative, numdiff_options, logging, - db_kwargs, + database, error_handling, error_penalty_func, fixed_log_data, @@ -68,7 +68,7 @@ def internal_criterion_and_derivative_template( derivatives. See :ref:`first_derivative` for details. Note that the default method is changed to "forward" for speed reasons. logging (bool): Whether logging is used. - db_kwargs (dict): Dictionary with entries "database", "path" and "fast_logging". + database (DataBase): Database to which the logs are written. error_handling (str): Either "raise" or "continue". 
Note that "continue" does not absolutely guarantee that no error is raised but we try to handle as many errors as possible in that case without aborting the optimization. @@ -228,7 +228,7 @@ def func(x): new_derivative=new_derivative, external_x=external_x, caught_exceptions=caught_exceptions, - db_kwargs=db_kwargs, + database=database, fixed_log_data=fixed_log_data, scalar_value=scalar_critval, now=now, @@ -307,7 +307,7 @@ def _log_new_evaluations( new_derivative, external_x, caught_exceptions, - db_kwargs, + database, fixed_log_data, scalar_value, now, @@ -338,7 +338,7 @@ def _log_new_evaluations( name = "optimization_iterations" - append_row(data, name, **db_kwargs) + append_row(data, name, database=database) def _get_output_for_optimizer( diff --git a/src/estimagic/optimization/optimization_logging.py b/src/estimagic/optimization/optimization_logging.py index b634dc307..50a65c600 100644 --- a/src/estimagic/optimization/optimization_logging.py +++ b/src/estimagic/optimization/optimization_logging.py @@ -1,7 +1,8 @@ -from estimagic.logging.database_utilities import append_row, read_last_rows, update_row +from estimagic.logging.read_from_database import read_last_rows +from estimagic.logging.write_to_database import append_row, update_row -def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): +def log_scheduled_steps_and_get_ids(steps, logging, database): """Add scheduled steps to the steps table of the database and get their ids. The ids are only determined once the steps are written to the database and the @@ -9,8 +10,8 @@ def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): Args: steps (list): List of dicts with entries for the steps table. - logging (bool): Whether to actually write to the databes. - db_kwargs (dict): Dict with the entries "database", "path" and "fast_logging" + logging (bool): Whether to actually write to the database. + database (DataBase): Returns: list: List of integers with the step ids. 
@@ -24,14 +25,14 @@ def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): append_row( data=data, table_name="steps", - **db_kwargs, + database=database, ) step_ids = read_last_rows( table_name="steps", n_rows=len(steps), return_type="dict_of_lists", - **db_kwargs, + database=database, )["rowid"] else: step_ids = list(range(len(steps))) @@ -39,7 +40,7 @@ def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): return step_ids -def update_step_status(step, new_status, db_kwargs): +def update_step_status(step, new_status, database): step = int(step) assert new_status in ["scheduled", "running", "complete", "skipped"] @@ -48,5 +49,5 @@ def update_step_status(step, new_status, db_kwargs): data={"status": new_status}, rowid=step, table_name="steps", - **db_kwargs, + database=database, ) diff --git a/src/estimagic/optimization/optimize.py b/src/estimagic/optimization/optimize.py index b64583924..6bde57eb6 100644 --- a/src/estimagic/optimization/optimize.py +++ b/src/estimagic/optimization/optimize.py @@ -4,13 +4,13 @@ from estimagic.batch_evaluators import process_batch_evaluator from estimagic.exceptions import InvalidFunctionError, InvalidKwargsError -from estimagic.logging.database_utilities import ( - append_row, - load_database, +from estimagic.logging.create_tables import ( make_optimization_iteration_table, make_optimization_problem_table, make_steps_table, ) +from estimagic.logging.load_database import load_database +from estimagic.logging.write_to_database import append_row from estimagic.optimization.check_arguments import check_optimize_kwargs from estimagic.optimization.error_penalty import get_error_penalty_function from estimagic.optimization.get_algorithm import ( @@ -655,13 +655,8 @@ def _optimize( if logging: problem_data["free_mask"] = internal_params.free_mask database = _create_and_initialize_database(logging, log_options, problem_data) - db_kwargs = { - "database": database, - "path": logging, - "fast_logging": log_options.get("fast_logging", False), - } else: - db_kwargs = {"database": None, "path": None, "fast_logging": False} + database = None # ================================================================================== # Do some things that require internal parameters or bounds @@ -711,7 +706,7 @@ def _optimize( nonlinear_constraints=internal_constraints, algo_options=algo_options, logging=logging, - db_kwargs=db_kwargs, + database=database, collect_history=collect_history, ) # ================================================================================== @@ -725,7 +720,7 @@ def _optimize( "criterion_and_derivative": criterion_and_derivative, "numdiff_options": numdiff_options, "logging": logging, - "db_kwargs": db_kwargs, + "database": database, "algo_info": algo_info, "error_handling": error_handling, "error_penalty_func": error_penalty_func, @@ -753,7 +748,7 @@ def _optimize( step_ids = log_scheduled_steps_and_get_ids( steps=steps, logging=logging, - db_kwargs=db_kwargs, + database=database, ) raw_res = internal_algorithm(**problem_functions, x=x, step_id=step_ids[0]) @@ -774,7 +769,7 @@ def _optimize( upper_sampling_bounds=internal_params.soft_upper_bounds, options=multistart_options, logging=logging, - db_kwargs=db_kwargs, + database=database, error_handling=error_handling, ) @@ -825,7 +820,7 @@ def _create_and_initialize_database(logging, log_options, problem_data): elif if_database_exists == "replace": logging.unlink() - database = load_database(path=path, fast_logging=fast_logging) + database = load_database(path_or_database=path, 
fast_logging=fast_logging) # create the optimization_iterations table make_optimization_iteration_table( @@ -852,7 +847,7 @@ def _create_and_initialize_database(logging, log_options, problem_data): key: val for key, val in problem_data.items() if key not in not_saved } - append_row(problem_data, "optimization_problem", database, path, fast_logging) + append_row(problem_data, "optimization_problem", database=database) return database diff --git a/src/estimagic/optimization/simopt_optimizers.py b/src/estimagic/optimization/simopt_optimizers.py index 741e5610a..09268e877 100644 --- a/src/estimagic/optimization/simopt_optimizers.py +++ b/src/estimagic/optimization/simopt_optimizers.py @@ -8,7 +8,7 @@ from estimagic.config import IS_SIMOPT_INSTALLED from estimagic.decorators import mark_minimizer -from estimagic.logging.database_utilities import list_of_dicts_to_dict_of_lists +from estimagic.logging.read_from_database import list_of_dicts_to_dict_of_lists from estimagic.optimization.algo_options import ( STOPPING_MAX_CRITERION_EVALUATIONS_GLOBAL, ) diff --git a/src/estimagic/optimization/tiktak.py b/src/estimagic/optimization/tiktak.py index 97a442deb..69d0f862d 100644 --- a/src/estimagic/optimization/tiktak.py +++ b/src/estimagic/optimization/tiktak.py @@ -46,7 +46,7 @@ def run_multistart_optimization( upper_sampling_bounds, options, logging, - db_kwargs, + database, error_handling, ): steps = determine_steps(options["n_samples"], options["n_optimizations"]) @@ -54,7 +54,7 @@ def run_multistart_optimization( scheduled_steps = log_scheduled_steps_and_get_ids( steps=steps, logging=logging, - db_kwargs=db_kwargs, + database=database, ) if options["sample"] is not None: @@ -77,7 +77,7 @@ def run_multistart_optimization( update_step_status( step=scheduled_steps[0], new_status="running", - db_kwargs=db_kwargs, + database=database, ) if "criterion" in problem_functions: @@ -99,7 +99,7 @@ def run_multistart_optimization( update_step_status( step=scheduled_steps[0], new_status="complete", - db_kwargs=db_kwargs, + database=database, ) scheduled_steps = scheduled_steps[1:] @@ -124,7 +124,7 @@ def run_multistart_optimization( update_step_status( step=step, new_status="skipped", - db_kwargs=db_kwargs, + database=database, ) batched_sample = get_batched_optimization_sample( @@ -194,7 +194,7 @@ def run_multistart_optimization( update_step_status( step=step, new_status="skipped", - db_kwargs=db_kwargs, + database=database, ) break diff --git a/tests/logging/test_database_utilities.py b/tests/logging/test_database_utilities.py index 13878bd33..4651593e3 100644 --- a/tests/logging/test_database_utilities.py +++ b/tests/logging/test_database_utilities.py @@ -2,18 +2,18 @@ import numpy as np import pytest -import sqlalchemy -from estimagic.logging.database_utilities import ( - append_row, - load_database, +from estimagic.logging.create_tables import ( make_optimization_iteration_table, make_optimization_problem_table, make_steps_table, +) +from estimagic.logging.load_database import DataBase, load_database +from estimagic.logging.read_from_database import ( read_last_rows, read_new_rows, read_table, - update_row, ) +from estimagic.logging.write_to_database import append_row, update_row from numpy.testing import assert_array_equal @@ -44,9 +44,10 @@ def problem_data(): def test_load_database_from_path(tmp_path): """Test that database is generated because it does not exist.""" path = tmp_path / "test.db" - database = load_database(path=path) - assert isinstance(database, sqlalchemy.MetaData) - assert database.bind 
is not None + database = load_database(path_or_database=path, fast_logging=False) + assert isinstance(database, DataBase) + assert database.path is not None + assert database.fast_logging is False def test_load_database_after_pickling(tmp_path): @@ -56,25 +57,16 @@ def test_load_database_after_pickling(tmp_path): """ path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path, fast_logging=False) database = pickle.loads(pickle.dumps(database)) - database = load_database(metadata=database, path=path) - assert database.bind is not None - - -def test_load_database_with_bound_metadata(tmp_path): - """Test that nothing happens when load_database is called with bound MetaData.""" - path = tmp_path / "test.db" - database = load_database(path=path) - new_database = load_database(metadata=database) - assert new_database is database + assert hasattr(database.engine, "connect") def test_optimization_iteration_table_scalar(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_last_rows(database, "optimization_iterations", 1, "list_of_dicts") assert isinstance(res, list) assert isinstance(res[0], dict) @@ -88,7 +80,7 @@ def test_optimization_iteration_table_scalar(tmp_path, iteration_data): def test_steps_table(tmp_path): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_steps_table(database) for status in ["scheduled", "running", "completed"]: append_row( @@ -100,8 +92,6 @@ def test_steps_table(tmp_path): }, "steps", database, - path, - False, ) res, _ = read_new_rows(database, "steps", 1, "dict_of_lists") @@ -118,9 +108,9 @@ def test_steps_table(tmp_path): def test_optimization_problem_table(tmp_path, problem_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_problem_table(database) - append_row(problem_data, "optimization_problem", database, path, False) + append_row(problem_data, "optimization_problem", database) res = read_last_rows(database, "optimization_problem", 1, "list_of_dicts")[0] assert res["rowid"] == 1 for key, expected in problem_data.items(): @@ -134,11 +124,11 @@ def test_optimization_problem_table(tmp_path, problem_data): def test_read_new_rows_stride(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_new_rows( database=database, @@ -154,13 +144,13 @@ def test_read_new_rows_stride(tmp_path, iteration_data): def test_update_row(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, 
"optimization_iterations", database) - update_row({"value": 20}, 8, "optimization_iterations", database, path, False) + update_row({"value": 20}, 8, "optimization_iterations", database) res = read_new_rows( database=database, @@ -175,11 +165,11 @@ def test_update_row(tmp_path, iteration_data): def test_read_last_rows_stride(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_last_rows( database=database, @@ -195,12 +185,12 @@ def test_read_last_rows_stride(tmp_path, iteration_data): def test_read_new_rows_with_step(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i iteration_data["step"] = i % 2 - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res, _ = read_new_rows( database=database, @@ -216,12 +206,12 @@ def test_read_new_rows_with_step(tmp_path, iteration_data): def test_read_last_rows_with_step(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i iteration_data["step"] = i % 2 - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_last_rows( database=database, @@ -237,12 +227,12 @@ def test_read_last_rows_with_step(tmp_path, iteration_data): def test_read_table(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i iteration_data["step"] = i % 2 - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) table = read_table( database=database, diff --git a/tests/optimization/test_internal_criterion_and_derivative_template.py b/tests/optimization/test_internal_criterion_and_derivative_template.py index 7822a5751..ddb7a0866 100644 --- a/tests/optimization/test_internal_criterion_and_derivative_template.py +++ b/tests/optimization/test_internal_criterion_and_derivative_template.py @@ -47,7 +47,7 @@ def base_inputs(): "error_handling": "raise", "numdiff_options": {}, "logging": False, - "db_kwargs": {"database": False, "fast_logging": False, "path": "logging.db"}, + "database": None, "error_penalty_func": None, "fixed_log_data": {"stage": "optimization", "substage": 0}, } diff --git a/tests/optimization/test_multistart.py b/tests/optimization/test_multistart.py index d7ded9339..145bbff59 100644 --- a/tests/optimization/test_multistart.py +++ b/tests/optimization/test_multistart.py @@ -8,7 +8,8 @@ sos_dict_criterion, sos_scalar_criterion, ) -from estimagic.logging.database_utilities import 
load_database, read_new_rows +from estimagic.logging.load_database import load_database +from estimagic.logging.read_from_database import read_new_rows from estimagic.logging.read_log import read_steps_table from estimagic.optimization.optimize import maximize, minimize from estimagic.optimization.optimize_result import OptimizeResult @@ -127,7 +128,7 @@ def test_all_steps_occur_in_optimization_iterations_if_no_convergence(params): logging="logging.db", ) - database = load_database(path="logging.db") + database = load_database(path_or_database="logging.db") iterations, _ = read_new_rows( database=database, table_name="optimization_iterations",
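
Usage sketch for the refactored logging layer. This is a minimal, illustrative example that mirrors the calls exercised in tests/logging/test_database_utilities.py; the path `logging.db` and the column values are placeholders, not part of the patch.

```python
from estimagic.logging.create_tables import make_optimization_iteration_table
from estimagic.logging.load_database import load_database
from estimagic.logging.read_from_database import read_last_rows
from estimagic.logging.write_to_database import append_row

# load_database now accepts a path (or an existing DataBase) instead of the old
# metadata/path/fast_logging combination and returns a DataBase object.
database = load_database(path_or_database="logging.db", fast_logging=False)

# Tables are registered on database.metadata and created via database.engine.
make_optimization_iteration_table(database, if_exists="extend")

# The write and read helpers take the DataBase object directly; the separate
# path and fast_logging arguments are gone.
append_row({"value": 1.5, "valid": True}, "optimization_iterations", database)
rows = read_last_rows(database, "optimization_iterations", 1, "list_of_dicts")
```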
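The DataBase object is pickle-serializable because `__reduce__` drops the engine and re-creates it from `(metadata, path, fast_logging)` on unpickling, which is what lets worker processes log to the same sqlite file. A small sketch of that behavior, based on test_load_database_after_pickling:

```python
import pickle

from estimagic.logging.load_database import load_database

database = load_database(path_or_database="logging.db", fast_logging=False)
roundtripped = pickle.loads(pickle.dumps(database))

# The unpickled object has a freshly created engine, so it can be used for
# logging without sharing a live connection across processes.
assert hasattr(roundtripped, "engine")
assert roundtripped.path == database.path
```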
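The patch also moves to SQLAlchemy 2.0 compatible idioms, which is why the `sqlalchemy <2.0` pin can be dropped: reflection is run against an explicit engine instead of a bound MetaData, and raw SQL in the event listener goes through `exec_driver_sql`. The snippet below shows those two patterns in isolation, with an illustrative path:

```python
import sqlalchemy as sql

engine = sql.create_engine("sqlite:///logging.db")

# MetaData is no longer bound to an engine; reflection takes the engine
# explicitly, as in load_database().
metadata = sql.MetaData()
metadata.reflect(engine)

# Raw SQL on a connection uses exec_driver_sql, as in the "begin" listener
# that emits BEGIN DEFERRED in _configure_engine().
with engine.connect() as connection:
    connection.exec_driver_sql("PRAGMA journal_mode")
```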