From c4f915f23384d8ed6d890afd8b60f9c2e723a6cb Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Mon, 6 Feb 2023 19:27:03 +0100 Subject: [PATCH 1/7] Make logging compatible with sqlalchemy 2.0. --- .envs/testenv-linux.yml | 2 +- .envs/testenv-others.yml | 2 +- environment.yml | 2 +- src/estimagic/dashboard/dashboard_app.py | 2 +- src/estimagic/logging/database_utilities.py | 470 +++++++++--------- src/estimagic/logging/read_log.py | 33 +- src/estimagic/optimization/get_algorithm.py | 12 +- .../internal_criterion_template.py | 10 +- .../optimization/optimization_logging.py | 14 +- src/estimagic/optimization/optimize.py | 19 +- src/estimagic/optimization/tiktak.py | 12 +- tests/logging/test_database_utilities.py | 61 +-- ...ernal_criterion_and_derivative_template.py | 2 +- tests/optimization/test_multistart.py | 2 +- 14 files changed, 303 insertions(+), 340 deletions(-) diff --git a/.envs/testenv-linux.yml b/.envs/testenv-linux.yml index 29e481681..dfafb1fa2 100644 --- a/.envs/testenv-linux.yml +++ b/.envs/testenv-linux.yml @@ -26,7 +26,7 @@ dependencies: - plotly # run, tests - pybaum >= 0.1.2 # run, tests - scipy>=1.2.1 # run, tests - - sqlalchemy <2.0 # run, tests + - sqlalchemy # run, tests - pip: # dev, tests, docs diff --git a/.envs/testenv-others.yml b/.envs/testenv-others.yml index 66efcff60..96b5d7c96 100644 --- a/.envs/testenv-others.yml +++ b/.envs/testenv-others.yml @@ -25,7 +25,7 @@ dependencies: - plotly # run, tests - pybaum >= 0.1.2 # run, tests - scipy>=1.2.1 # run, tests - - sqlalchemy <2.0 # run, tests + - sqlalchemy # run, tests - pip: # dev, tests, docs diff --git a/environment.yml b/environment.yml index 0c1c3bf52..09c02e14d 100644 --- a/environment.yml +++ b/environment.yml @@ -32,7 +32,7 @@ dependencies: - plotly # run, tests - pybaum >= 0.1.2 # run, tests - scipy>=1.2.1 # run, tests - - sqlalchemy <2.0 # run, tests + - sqlalchemy # run, tests - pydata-sphinx-theme>=0.3.0 # docs - myst-parser # docs diff --git a/src/estimagic/dashboard/dashboard_app.py b/src/estimagic/dashboard/dashboard_app.py index c5b9ec254..c43dec3db 100644 --- a/src/estimagic/dashboard/dashboard_app.py +++ b/src/estimagic/dashboard/dashboard_app.py @@ -42,7 +42,7 @@ def dashboard_app( doc.template = env.get_template("index.html") # process inputs - database = load_database(path=session_data["database_path"]) + database = load_database(path_or_database=session_data["database_path"]) start_point = _calculate_start_point(database, updating_options) session_data["last_retrieved"] = start_point diff --git a/src/estimagic/logging/database_utilities.py b/src/estimagic/logging/database_utilities.py index eccb394cc..bebc44a47 100644 --- a/src/estimagic/logging/database_utilities.py +++ b/src/estimagic/logging/database_utilities.py @@ -12,86 +12,115 @@ import io import traceback import warnings -from pathlib import Path import cloudpickle import pandas as pd -from sqlalchemy import ( - BLOB, - Boolean, - Column, - Float, - Integer, - MetaData, - PickleType, - String, - Table, - and_, - create_engine, - event, - update, -) +import sqlalchemy as sql from estimagic.exceptions import TableExistsError, get_traceback -def load_database(metadata=None, path=None, fast_logging=False): - """Return a bound sqlalchemy.MetaData object for the database stored in ``path``. +class DataBase: + """Class containing everything to work with a logging database. - This is the only acceptable way of creating or loading databases in estimagic! 
+    Importantly, the class is pickle-serializable, which makes it possible to share
+    it across multiple processes.
 
-    If metadata is a bound MetaData object, it is just returned. If metadata is given
-    but not bound, we bind it to an engine that connects to the database stored under
-    ``path``. If only the path is provided, we generate an appropriate MetaData object
-    and bind it to the database.
+    """
+
+    def __init__(self, metadata, path, fast_logging, engine=None):
+        self.metadata = metadata
+        self.path = path
+        self.fast_logging = fast_logging
+        if isinstance(engine, sql.Engine):
+            self.engine = engine
+        else:
+            self.engine = _create_engine(path, fast_logging)
+
+    def __reduce__(self):
+        return (DataBase, (self.metadata, self.path, self.fast_logging))
 
-    For speed reasons we do not make any checks that MetaData is compatible with the
-    database stored under path.
+
+def load_database(path_or_database, fast_logging=False):
+    """Load or create a database from a path and configure it for our needs.
+
+    This is the only acceptable way of loading or creating a database in estimagic!
 
     Args:
-        metadata (sqlalchemy.MetaData): MetaData object that might or might not be
-            bound to the database under path. In any case it needs to be compatible
-            with the database stored under ``path``. For speed reasons, this is not
-            checked.
-        path (str or pathlib.Path): location of the database file. If the file does
-            not exist, it will be created.
+        path_or_database (str or pathlib.Path or DataBase): Path to the database or
+            an existing DataBase instance, which is returned unchanged.
+        fast_logging (bool): If True, use unsafe optimizations to speed up the logging.
+            If False, only use ultra safe optimizations.
 
     Returns:
-        metadata (sqlalchemy.MetaData). MetaData object that is bound to the database
-        under ``path``.
+        database (DataBase): Object containing everything to work with the
+            database.
 
     """
-    path = Path(path) if isinstance(path, str) else path
-
-    if isinstance(metadata, MetaData):
-        if metadata.bind is None:
-            assert (
-                path is not None
-            ), "If metadata is not bound, you need to provide a path."
-            engine = create_engine(f"sqlite:///{path}")
-            _configure_engine(engine, fast_logging)
-            metadata.bind = engine
-    elif metadata is None:
-        assert path is not None, "If metadata is None you need to provide a path."
-        path_existed = path.exists()
-        engine = create_engine(f"sqlite:///{path}")
-        _configure_engine(engine, fast_logging)
-        metadata = MetaData()
-        metadata.bind = engine
-        if path_existed:
-            _configure_reflect()
-            metadata.reflect()
+    if isinstance(path_or_database, DataBase):
+        out = path_or_database
     else:
-        raise ValueError("metadata must be sqlalchemy.MetaData or None.")
+        engine = _create_engine(path_or_database, fast_logging)
+        metadata = sql.MetaData()
+        _configure_reflect()
+        metadata.reflect(engine)
+
+        out = DataBase(
+            metadata=metadata,
+            path=path_or_database,
+            fast_logging=fast_logging,
+            engine=engine,
+        )
+    return out
 
-    return metadata
+
+def _create_engine(path, fast_logging):
+    engine = sql.create_engine(f"sqlite:///{path}")
+    _configure_engine(engine, fast_logging)
+    return engine
+
+
+def _configure_engine(engine, fast_logging):
+    """Configure the sqlite engine.
+
+    The two functions that configure the emission of the begin statement are taken from
+    the sqlalchemy documentation: https://tinyurl.com/u9xea5z and are
+    the recommended way of working around a bug in the pysqlite driver.
+
+    The other function speeds up the write process. If fast_logging is False, it does so
+    using only completely safe optimizations. If fast_logging is True, it also uses
+    unsafe optimizations.
+
+    """
+
+    @sql.event.listens_for(engine, "connect")
+    def do_connect(dbapi_connection, connection_record):  # noqa: ARG001
+        # disable pysqlite's emitting of the BEGIN statement entirely.
+        # also stops it from emitting COMMIT before absolutely necessary.
+        dbapi_connection.isolation_level = None
+
+    @sql.event.listens_for(engine, "begin")
+    def do_begin(conn):
+        # emit our own BEGIN
+        conn.exec_driver_sql("BEGIN DEFERRED")
+
+    @sql.event.listens_for(engine, "connect")
+    def set_sqlite_pragma(dbapi_connection, connection_record):  # noqa: ARG001
+        cursor = dbapi_connection.cursor()
+        cursor.execute("PRAGMA journal_mode = WAL")
+        if fast_logging:
+            cursor.execute("PRAGMA synchronous = OFF")
+        else:
+            cursor.execute("PRAGMA synchronous = NORMAL")
+        cursor.close()
 
 
 def make_optimization_iteration_table(database, if_exists="extend"):
     """Generate a table for information that is generated with each function evaluation.
 
     Args:
-        database (sqlalchemy.MetaData): Bound metadata object.
+        database (DataBase): DataBase object holding the metadata and engine.
+        if_exists (str): What to do if the table already exists. Can be "extend",
+            "replace" or "raise".
 
     Returns:
         database (sqlalchemy.MetaData):Bound metadata object with added table.
@@ -101,39 +130,117 @@ def make_optimization_iteration_table(database, if_exists="extend"):
     _handle_existing_table(database, "optimization_iterations", if_exists)
 
     columns = [
-        Column("rowid", Integer, primary_key=True),
-        Column("params", PickleType(pickler=RobustPickler)),
-        Column("internal_derivative", PickleType(pickler=RobustPickler)),
-        Column("timestamp", Float),
-        Column("exceptions", String),
-        Column("valid", Boolean),
-        Column("hash", String),
-        Column("value", Float),
-        Column("step", Integer),
-        Column("criterion_eval", PickleType(pickler=RobustPickler)),
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("params", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("internal_derivative", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("timestamp", sql.Float),
+        sql.Column("exceptions", sql.String),
+        sql.Column("valid", sql.Boolean),
+        sql.Column("hash", sql.String),
+        sql.Column("value", sql.Float),
+        sql.Column("step", sql.Integer),
+        sql.Column("criterion_eval", sql.PickleType(pickler=RobustPickler)),
     ]
 
-    Table(
-        table_name, database, *columns, sqlite_autoincrement=True, extend_existing=True
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        sqlite_autoincrement=True,
+        extend_existing=True,
     )
 
-    database.create_all(database.bind)
+    database.metadata.create_all(database.engine)
+
+
+def _handle_existing_table(database, table_name, if_exists):
+    assert if_exists in ["replace", "extend", "raise"]
+
+    if table_name in database.metadata.tables:
+        if if_exists == "replace":
+            database.metadata.tables[table_name].drop(database.engine)
+        elif if_exists == "raise":
+            raise TableExistsError(f"The table {table_name} already exists.")
+
+
+def _configure_reflect():
+    """Mark all BLOB dtypes as PickleType with our custom pickle reader.
+
+    Code is taken from the documentation: https://tinyurl.com/y7q287jr
+
+    """
+
+    @sql.event.listens_for(sql.Table, "column_reflect")
+    def _setup_pickletype(inspector, table, column_info):  # noqa: ARG001
+        if isinstance(column_info["type"], sql.BLOB):
+            column_info["type"] = sql.PickleType(pickler=RobustPickler)
+
+
+class RobustPickler:
+    @staticmethod
+    def loads(
+        data,
+        fix_imports=True,  # noqa: ARG004
+        encoding="ASCII",  # noqa: ARG004
+        errors="strict",  # noqa: ARG004
+        buffers=None,  # noqa: ARG004
+    ):
+        """Robust pickle loading.
+
+        We first try to unpickle the object with pd.read_pickle. This makes no
+        difference for non-pandas objects but makes the de-serialization
+        of pandas objects more robust across pandas versions. If that fails, we use
+        cloudpickle. If that fails, we return None but do not raise an error.
+
+        See: https://github.com/pandas-dev/pandas/issues/16474
+
+        """
+        try:
+            res = pd.read_pickle(io.BytesIO(data), compression=None)
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            try:
+                res = cloudpickle.loads(data)
+            except (KeyboardInterrupt, SystemExit):
+                raise
+            except Exception:
+                res = None
+                tb = get_traceback()
+                warnings.warn(
+                    f"Unable to read PickleType column from database:\n{tb}\n "
+                    "The entry was replaced by None."
+                )
+
+        return res
+
+    @staticmethod
+    def dumps(
+        obj, protocol=None, *, fix_imports=True, buffer_callback=None  # noqa: ARG004
+    ):
+        return cloudpickle.dumps(obj, protocol=protocol)
 
 
 def make_steps_table(database, if_exists="extend"):
     table_name = "steps"
     _handle_existing_table(database, table_name, if_exists)
     columns = [
-        Column("rowid", Integer, primary_key=True),
-        Column("type", String),  # e.g. optimization
-        Column("status", String),  # e.g. running
-        Column("n_iterations", Integer),  # optional
-        Column("name", String),  # e.g. "optimization-1", "exploration", not unique
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("type", sql.String),  # e.g. optimization
+        sql.Column("status", sql.String),  # e.g. running
+        sql.Column("n_iterations", sql.Integer),  # optional
+        sql.Column(
+            "name", sql.String
+        ),  # e.g.
"optimization-1", "exploration", not unique ] - Table( - table_name, database, *columns, extend_existing=True, sqlite_autoincrement=True + sql.Table( + table_name, + database.metadata, + *columns, + extend_existing=True, + sqlite_autoincrement=True, ) - database.create_all(database.bind) + database.metadata.create_all(database.engine) def make_optimization_problem_table(database, if_exists="extend"): @@ -141,73 +248,60 @@ def make_optimization_problem_table(database, if_exists="extend"): _handle_existing_table(database, table_name, if_exists) columns = [ - Column("rowid", Integer, primary_key=True), - Column("direction", String), - Column("params", PickleType(pickler=RobustPickler)), - Column("algorithm", PickleType(pickler=RobustPickler)), - Column("algo_options", PickleType(pickler=RobustPickler)), - Column("numdiff_options", PickleType(pickler=RobustPickler)), - Column("log_options", PickleType(pickler=RobustPickler)), - Column("error_handling", String), - Column("error_penalty", PickleType(pickler=RobustPickler)), - Column("constraints", PickleType(pickler=RobustPickler)), - Column("free_mask", PickleType(pickler=RobustPickler)), + sql.Column("rowid", sql.Integer, primary_key=True), + sql.Column("direction", sql.String), + sql.Column("params", sql.PickleType(pickler=RobustPickler)), + sql.Column("algorithm", sql.PickleType(pickler=RobustPickler)), + sql.Column("algo_options", sql.PickleType(pickler=RobustPickler)), + sql.Column("numdiff_options", sql.PickleType(pickler=RobustPickler)), + sql.Column("log_options", sql.PickleType(pickler=RobustPickler)), + sql.Column("error_handling", sql.String), + sql.Column("error_penalty", sql.PickleType(pickler=RobustPickler)), + sql.Column("constraints", sql.PickleType(pickler=RobustPickler)), + sql.Column("free_mask", sql.PickleType(pickler=RobustPickler)), ] - Table( - table_name, database, *columns, extend_existing=True, sqlite_autoincrement=True + sql.Table( + table_name, + database.metadata, + *columns, + extend_existing=True, + sqlite_autoincrement=True, ) - database.create_all(database.bind) + database.metadata.create_all(database.engine) -def _handle_existing_table(database, table_name, if_exists): - assert if_exists in ["replace", "extend", "raise"] +# ====================================================================================== - if table_name in database.tables: - if if_exists == "replace": - database.tables[table_name].drop(database.bind) - elif if_exists == "raise": - raise TableExistsError(f"The table {table_name} already exists.") +def update_row(data, rowid, table_name, database): + table = database.metadata.tables[table_name] + stmt = sql.update(table).where(table.c.rowid == rowid).values(**data) -def update_row(data, rowid, table_name, database, path, fast_logging): - database = load_database(database, path, fast_logging) + _execute_write_statement(stmt, database) - table = database.tables[table_name] - stmt = update(table).where(table.c.rowid == rowid).values(**data) - _execute_write_statement(stmt, database, path, table_name, data) - - -def append_row(data, table_name, database, path, fast_logging): +def append_row(data, table_name, database): """ Args: data (dict): The keys correspond to columns in the database table. table_name (str): Name of the database table to which the row is added. - database (sqlalchemy.MetaData): Sqlachlemy metadata object of the database. - path (str or pathlib.Path): Path to the database file. Using a path is much - slower than a MetaData object and we advise to only use it as a fallback. 
- fast_logging (bool) + database (DataBase): The database to which the row is added. """ - # this is necessary because database.bind gets lost when the database is pickled. - # it has no cost when database.bind is set. - database = load_database(database, path, fast_logging) - stmt = database.tables[table_name].insert().values(**data) + stmt = database.metadata.tables[table_name].insert().values(**data) - _execute_write_statement(stmt, database, path, table_name, data) + _execute_write_statement(stmt, database) -def _execute_write_statement( - statement, database, path, table_name, data # noqa: ARG001 -): +def _execute_write_statement(statement, database): try: # this will automatically roll back the transaction if any exception is raised # and then raise the exception - with database.bind.begin() as connection: + with database.engine.begin() as connection: connection.execute(statement) except (KeyboardInterrupt, SystemExit): raise @@ -223,8 +317,6 @@ def read_new_rows( table_name, last_retrieved, return_type, - path=None, - fast_logging=False, limit=None, stride=1, step=None, @@ -232,7 +324,7 @@ def read_new_rows( """Read all iterations after last_retrieved up to a limit. Args: - database (sqlalchemy.MetaData) + database (DataBase) table_name (str): name of the table to retrieve. last_retrieved (int): The last iteration that was retrieved. return_type (str): either "list_of_dicts" or "dict_of_lists". @@ -250,11 +342,10 @@ def read_new_rows( int: The new last_retrieved value. """ - database = load_database(database, path, fast_logging) last_retrieved = int(last_retrieved) limit = int(limit) if limit is not None else limit - table = database.tables[table_name] + table = database.metadata.tables[table_name] stmt = table.select().where(table.c.rowid > last_retrieved).limit(limit) conditions = [table.c.rowid > last_retrieved] @@ -264,7 +355,7 @@ def read_new_rows( if step is not None: conditions.append(table.c.step == int(step)) - stmt = table.select().where(and_(*conditions)).limit(limit) + stmt = table.select().where(sql.and_(*conditions)).limit(limit) data = _execute_read_statement(database, table_name, stmt, return_type) @@ -281,8 +372,6 @@ def read_last_rows( table_name, n_rows, return_type, - path=None, - fast_logging=False, stride=1, step=None, ): @@ -291,14 +380,10 @@ def read_last_rows( If a table has less than n_rows rows, the whole table is returned. Args: - database (sqlalchemy.MetaData) + database (DataBase) table_name (str): name of the table to retrieve. n_rows (int): number of rows to retrieve. return_type (str): either "list_of_dicts" or "dict_of_lists". - path (str or pathlib.Path): location of the database file. If the file does - not exist, it will be created. Using a path is much slower than a - MetaData object and we advise to only use it as a fallback. - fast_logging (bool) stride (int): Only return every n-th row. Default is every row (stride=1). step (int): Only return rows that belong to step. @@ -306,10 +391,9 @@ def read_last_rows( result (return_type): the last rows of the `table_name` table as `return_type`. 
""" - database = load_database(database, path, fast_logging) n_rows = int(n_rows) - table = database.tables[table_name] + table = database.metadata.tables[table_name] conditions = [] @@ -323,7 +407,7 @@ def read_last_rows( stmt = ( table.select() .order_by(table.c.rowid.desc()) - .where(and_(*conditions)) + .where(sql.and_(*conditions)) .limit(n_rows) ) else: @@ -338,9 +422,7 @@ def read_last_rows( return out -def read_specific_row( - database, table_name, rowid, return_type, path=None, fast_logging=False -): +def read_specific_row(database, table_name, rowid, return_type): """Read a specific row from a table. Args: @@ -348,26 +430,20 @@ def read_specific_row( table_name (str): name of the table to retrieve. n_rows (int): number of rows to retrieve. return_type (str): either "list_of_dicts" or "dict_of_lists". - path (str or pathlib.Path): location of the database file. - Using a path is much slower than a MetaData object and we - advise to only use it as a fallback. - fast_logging (bool) Returns: dict or list: The requested row from the database. """ - database = load_database(database, path, fast_logging) rowid = int(rowid) - table = database.tables[table_name] + table = database.metadata.tables[table_name] stmt = table.select().where(table.c.rowid == rowid) data = _execute_read_statement(database, table_name, stmt, return_type) return data -def read_table(database, table_name, return_type, path=None, fast_logging=False): - database = load_database(database, path, fast_logging) - table = database.tables[table_name] +def read_table(database, table_name, return_type): + table = database.metadata.tables[table_name] stmt = table.select() data = _execute_read_statement(database, table_name, stmt, return_type) return data @@ -375,7 +451,7 @@ def read_table(database, table_name, return_type, path=None, fast_logging=False) def _execute_read_statement(database, table_name, statement, return_type): try: - with database.bind.begin() as connection: + with database.engine.begin() as connection: raw_result = list(connection.execute(statement)) except (KeyboardInterrupt, SystemExit): raise @@ -388,7 +464,7 @@ def _execute_read_statement(database, table_name, statement, return_type): # if we only want to warn we must provide a raw_result to be processed below. raw_result = [] - columns = database.tables[table_name].columns.keys() + columns = database.metadata.tables[table_name].columns.keys() if return_type == "list_of_dicts": result = [dict(zip(columns, row)) for row in raw_result] @@ -407,6 +483,9 @@ def _execute_read_statement(database, table_name, statement, return_type): return result +# ====================================================================================== + + def transpose_nested_list(nested_list): """Transpose a list of lists. @@ -457,96 +536,3 @@ def dict_of_lists_to_list_of_dicts(dict_of_lists): """ return [dict(zip(dict_of_lists, t)) for t in zip(*dict_of_lists.values())] - - -def _configure_engine(engine, fast_logging): - """Configure the sqlite engine. - - The two functions that configure the emission of the begin statement are taken from - the sqlalchemy documentation the documentation: https://tinyurl.com/u9xea5z and are - the recommended way of working around a bug in the pysqlite driver. - - The other function speeds up the write process. If fast_logging is False, it does so - using only completely safe optimizations. Of fast_logging is True, it also uses - unsafe optimizations. 
- - """ - - @event.listens_for(engine, "connect") - def do_connect(dbapi_connection, connection_record): # noqa: ARG001 - # disable pysqlite's emitting of the BEGIN statement entirely. - # also stops it from emitting COMMIT before absolutely necessary. - dbapi_connection.isolation_level = None - - @event.listens_for(engine, "begin") - def do_begin(conn): - # emit our own BEGIN - conn.execute("BEGIN DEFERRED") - - @event.listens_for(engine, "connect") - def set_sqlite_pragma(dbapi_connection, connection_record): # noqa: ARG001 - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA journal_mode = WAL") - if fast_logging: - cursor.execute("PRAGMA synchronous = OFF") - else: - cursor.execute("PRAGMA synchronous = NORMAL") - cursor.close() - - -def _configure_reflect(): - """Mark all BLOB dtypes as PickleType with our custom pickle reader. - - Code ist taken from the documentation: https://tinyurl.com/y7q287jr - - """ - - @event.listens_for(Table, "column_reflect") - def _setup_pickletype(inspector, table, column_info): # noqa: ARG001 - if isinstance(column_info["type"], BLOB): - column_info["type"] = PickleType(pickler=RobustPickler) - - -class RobustPickler: - @staticmethod - def loads( - data, - fix_imports=True, # noqa: ARG004 - encoding="ASCII", # noqa: ARG004 - errors="strict", # noqa: ARG004 - buffers=None, # noqa: ARG004 - ): - """Robust pickle loading. - - We first try to unpickle the object with pd.read_pickle. This makes no - difference for non-pandas objects but makes the de-serialization - of pandas objects more robust across pandas versions. If that fails, we use - cloudpickle. If that fails, we return None but do not raise an error. - - See: https://github.com/pandas-dev/pandas/issues/16474 - - """ - try: - res = pd.read_pickle(io.BytesIO(data), compression=None) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - try: - res = cloudpickle.loads(data) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - res = None - tb = get_traceback() - warnings.warn( - f"Unable to read PickleType column from database:\n{tb}\n " - "The entry was replaced by None." - ) - - return res - - @staticmethod - def dumps( - obj, protocol=None, *, fix_imports=True, buffer_callback=None # noqa: ARG004 - ): - return cloudpickle.dumps(obj, protocol=protocol) diff --git a/src/estimagic/logging/read_log.py b/src/estimagic/logging/read_log.py index 8b22e5366..9297ec6d7 100644 --- a/src/estimagic/logging/read_log.py +++ b/src/estimagic/logging/read_log.py @@ -14,7 +14,6 @@ import numpy as np import pandas as pd from pybaum import tree_flatten, tree_unflatten -from sqlalchemy import MetaData from estimagic.logging.database_utilities import ( load_database, @@ -25,6 +24,14 @@ from estimagic.parameters.tree_registry import get_registry +def load_existing_database(path_or_database): + if isinstance(path_or_database, (Path, str)): + path = Path(path_or_database) + if not path.exists(): + raise FileNotFoundError(f"Database {path} does not exist.") + return load_database(path_or_database) + + def read_start_params(path_or_database): """Load the start parameters DataFrame. @@ -35,7 +42,7 @@ def read_start_params(path_or_database): params (pd.DataFrame): see :ref:`params`. 
""" - database = _load_database(path_or_database) + database = load_existing_database(path_or_database) optimization_problem = read_last_rows( database=database, table_name="optimization_problem", @@ -46,22 +53,6 @@ def read_start_params(path_or_database): return start_params -def _load_database(path_or_database): - """Get an sqlalchemy.MetaDate object from path or database.""" - - res = {"path": None, "metadata": None, "fast_logging": False} - if isinstance(path_or_database, MetaData): - res = path_or_database - elif isinstance(path_or_database, (Path, str)): - path = Path(path_or_database) - if not path.exists(): - raise FileNotFoundError(f"No such database file: {path}") - res = load_database(path=path) - else: - raise TypeError("path_or_database must be a path or sqlalchemy.MetaData object") - return res - - def read_steps_table(path_or_database): """Load the steps table. @@ -72,7 +63,7 @@ def read_steps_table(path_or_database): steps_df (pandas.DataFrame) """ - database = _load_database(path_or_database) + database = load_existing_database(path_or_database) steps_table, _ = read_new_rows( database=database, table_name="steps", @@ -94,7 +85,7 @@ def read_optimization_problem_table(path_or_database): params (pd.DataFrame): see :ref:`params`. """ - database = _load_database(path_or_database) + database = load_existing_database(path_or_database) steps_table, _ = read_new_rows( database=database, table_name="optimization_problem", @@ -113,7 +104,7 @@ class OptimizeLogReader: path: Union[str, Path] def __post_init__(self): - _database = _load_database(self.path) + _database = load_existing_database(self.path) _start_params = read_start_params(_database) _registry = get_registry(extended=True) _, _treedef = tree_flatten(_start_params, registry=_registry) diff --git a/src/estimagic/optimization/get_algorithm.py b/src/estimagic/optimization/get_algorithm.py index 395957e02..ffb2df516 100644 --- a/src/estimagic/optimization/get_algorithm.py +++ b/src/estimagic/optimization/get_algorithm.py @@ -55,7 +55,7 @@ def get_final_algorithm( nonlinear_constraints, algo_options, logging, - db_kwargs, + database, collect_history, ): """Get algorithm-function with partialled options. @@ -74,7 +74,7 @@ def get_final_algorithm( algorithm. Entries that are not used by the algorithm are ignored with a warning. logging (bool): Whether the algorithm should do logging. - db_kwargs (dict): Dict with the entries "database", "path" and "fast_logging" + database (DataBase): Database to which the logging should be written. Returns: callable: The algorithm. 
@@ -96,7 +96,7 @@ def get_final_algorithm( algorithm = _add_logging( algorithm, logging=logging, - db_kwargs=db_kwargs, + database=database, ) is_parallel = internal_options.get("n_cores") not in (None, 1) @@ -110,7 +110,7 @@ def get_final_algorithm( return algorithm -def _add_logging(algorithm=None, *, logging=None, db_kwargs=None): +def _add_logging(algorithm=None, *, logging=None, database=None): """Add logging of status to the algorithm.""" def decorator_add_logging_to_algorithm(algorithm): @@ -125,7 +125,7 @@ def wrapper_add_logging_algorithm(**kwargs): data={"status": "running"}, rowid=step_id, table_name="steps", - **db_kwargs, + database=database, ) for task in ["criterion", "derivative", "criterion_and_derivative"]: @@ -141,7 +141,7 @@ def wrapper_add_logging_algorithm(**kwargs): data={"status": "complete"}, rowid=step_id, table_name="steps", - **db_kwargs, + database=database, ) return res diff --git a/src/estimagic/optimization/internal_criterion_template.py b/src/estimagic/optimization/internal_criterion_template.py index 7d87c2354..754d23a3d 100644 --- a/src/estimagic/optimization/internal_criterion_template.py +++ b/src/estimagic/optimization/internal_criterion_template.py @@ -19,7 +19,7 @@ def internal_criterion_and_derivative_template( criterion_and_derivative, numdiff_options, logging, - db_kwargs, + database, error_handling, error_penalty_func, fixed_log_data, @@ -68,7 +68,7 @@ def internal_criterion_and_derivative_template( derivatives. See :ref:`first_derivative` for details. Note that the default method is changed to "forward" for speed reasons. logging (bool): Whether logging is used. - db_kwargs (dict): Dictionary with entries "database", "path" and "fast_logging". + database (DataBase): Database to which the logs are written. error_handling (str): Either "raise" or "continue". Note that "continue" does not absolutely guarantee that no error is raised but we try to handle as many errors as possible in that case without aborting the optimization. @@ -233,7 +233,7 @@ def func(x): new_derivative=new_derivative, external_x=external_x, caught_exceptions=caught_exceptions, - db_kwargs=db_kwargs, + database=database, fixed_log_data=fixed_log_data, scalar_value=scalar_critval, now=now, @@ -312,7 +312,7 @@ def _log_new_evaluations( new_derivative, external_x, caught_exceptions, - db_kwargs, + database, fixed_log_data, scalar_value, now, @@ -343,7 +343,7 @@ def _log_new_evaluations( name = "optimization_iterations" - append_row(data, name, **db_kwargs) + append_row(data, name, database=database) def _get_output_for_optimizer( diff --git a/src/estimagic/optimization/optimization_logging.py b/src/estimagic/optimization/optimization_logging.py index b634dc307..d5a123cb5 100644 --- a/src/estimagic/optimization/optimization_logging.py +++ b/src/estimagic/optimization/optimization_logging.py @@ -1,7 +1,7 @@ from estimagic.logging.database_utilities import append_row, read_last_rows, update_row -def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): +def log_scheduled_steps_and_get_ids(steps, logging, database): """Add scheduled steps to the steps table of the database and get their ids. The ids are only determined once the steps are written to the database and the @@ -9,8 +9,8 @@ def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): Args: steps (list): List of dicts with entries for the steps table. - logging (bool): Whether to actually write to the databes. 
- db_kwargs (dict): Dict with the entries "database", "path" and "fast_logging" + logging (bool): Whether to actually write to the database. + database (DataBase): Returns: list: List of integers with the step ids. @@ -24,14 +24,14 @@ def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): append_row( data=data, table_name="steps", - **db_kwargs, + database=database, ) step_ids = read_last_rows( table_name="steps", n_rows=len(steps), return_type="dict_of_lists", - **db_kwargs, + database=database, )["rowid"] else: step_ids = list(range(len(steps))) @@ -39,7 +39,7 @@ def log_scheduled_steps_and_get_ids(steps, logging, db_kwargs): return step_ids -def update_step_status(step, new_status, db_kwargs): +def update_step_status(step, new_status, database): step = int(step) assert new_status in ["scheduled", "running", "complete", "skipped"] @@ -48,5 +48,5 @@ def update_step_status(step, new_status, db_kwargs): data={"status": new_status}, rowid=step, table_name="steps", - **db_kwargs, + database=database, ) diff --git a/src/estimagic/optimization/optimize.py b/src/estimagic/optimization/optimize.py index 74a378e61..6244e25ee 100644 --- a/src/estimagic/optimization/optimize.py +++ b/src/estimagic/optimization/optimize.py @@ -655,13 +655,8 @@ def _optimize( if logging: problem_data["free_mask"] = internal_params.free_mask database = _create_and_initialize_database(logging, log_options, problem_data) - db_kwargs = { - "database": database, - "path": logging, - "fast_logging": log_options.get("fast_logging", False), - } else: - db_kwargs = {"database": None, "path": None, "fast_logging": False} + database = None # ================================================================================== # Do some things that require internal parameters or bounds @@ -711,7 +706,7 @@ def _optimize( nonlinear_constraints=internal_constraints, algo_options=algo_options, logging=logging, - db_kwargs=db_kwargs, + database=database, collect_history=collect_history, ) # ================================================================================== @@ -725,7 +720,7 @@ def _optimize( "criterion_and_derivative": criterion_and_derivative, "numdiff_options": numdiff_options, "logging": logging, - "db_kwargs": db_kwargs, + "database": database, "algo_info": algo_info, "error_handling": error_handling, "error_penalty_func": error_penalty_func, @@ -754,7 +749,7 @@ def _optimize( step_ids = log_scheduled_steps_and_get_ids( steps=steps, logging=logging, - db_kwargs=db_kwargs, + database=database, ) raw_res = internal_algorithm(**problem_functions, x=x, step_id=step_ids[0]) @@ -776,7 +771,7 @@ def _optimize( upper_sampling_bounds=internal_params.soft_upper_bounds, options=multistart_options, logging=logging, - db_kwargs=db_kwargs, + database=database, error_handling=error_handling, ) @@ -827,7 +822,7 @@ def _create_and_initialize_database(logging, log_options, problem_data): elif if_database_exists == "replace": logging.unlink() - database = load_database(path=path, fast_logging=fast_logging) + database = load_database(path_or_database=path, fast_logging=fast_logging) # create the optimization_iterations table make_optimization_iteration_table( @@ -854,7 +849,7 @@ def _create_and_initialize_database(logging, log_options, problem_data): key: val for key, val in problem_data.items() if key not in not_saved } - append_row(problem_data, "optimization_problem", database, path, fast_logging) + append_row(problem_data, "optimization_problem", database=database) return database diff --git 
a/src/estimagic/optimization/tiktak.py b/src/estimagic/optimization/tiktak.py index 143d9e6de..e1b290228 100644 --- a/src/estimagic/optimization/tiktak.py +++ b/src/estimagic/optimization/tiktak.py @@ -34,7 +34,7 @@ def run_multistart_optimization( upper_sampling_bounds, options, logging, - db_kwargs, + database, error_handling, ): steps = determine_steps(options["n_samples"], options["n_optimizations"]) @@ -42,7 +42,7 @@ def run_multistart_optimization( scheduled_steps = log_scheduled_steps_and_get_ids( steps=steps, logging=logging, - db_kwargs=db_kwargs, + database=database, ) if options["sample"] is not None: @@ -65,7 +65,7 @@ def run_multistart_optimization( update_step_status( step=scheduled_steps[0], new_status="running", - db_kwargs=db_kwargs, + database=database, ) if "criterion" in problem_functions: @@ -87,7 +87,7 @@ def run_multistart_optimization( update_step_status( step=scheduled_steps[0], new_status="complete", - db_kwargs=db_kwargs, + database=database, ) scheduled_steps = scheduled_steps[1:] @@ -112,7 +112,7 @@ def run_multistart_optimization( update_step_status( step=step, new_status="skipped", - db_kwargs=db_kwargs, + database=database, ) batched_sample = get_batched_optimization_sample( @@ -183,7 +183,7 @@ def run_multistart_optimization( update_step_status( step=step, new_status="skipped", - db_kwargs=db_kwargs, + database=database, ) break diff --git a/tests/logging/test_database_utilities.py b/tests/logging/test_database_utilities.py index 13878bd33..7d79d33b5 100644 --- a/tests/logging/test_database_utilities.py +++ b/tests/logging/test_database_utilities.py @@ -2,8 +2,9 @@ import numpy as np import pytest -import sqlalchemy +import sqlalchemy as sql from estimagic.logging.database_utilities import ( + DataBase, append_row, load_database, make_optimization_iteration_table, @@ -44,9 +45,10 @@ def problem_data(): def test_load_database_from_path(tmp_path): """Test that database is generated because it does not exist.""" path = tmp_path / "test.db" - database = load_database(path=path) - assert isinstance(database, sqlalchemy.MetaData) - assert database.bind is not None + database = load_database(path_or_database=path, fast_logging=False) + assert isinstance(database, DataBase) + assert database.path is not None + assert database.fast_logging is False def test_load_database_after_pickling(tmp_path): @@ -56,25 +58,16 @@ def test_load_database_after_pickling(tmp_path): """ path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path, fast_logging=False) database = pickle.loads(pickle.dumps(database)) - database = load_database(metadata=database, path=path) - assert database.bind is not None - - -def test_load_database_with_bound_metadata(tmp_path): - """Test that nothing happens when load_database is called with bound MetaData.""" - path = tmp_path / "test.db" - database = load_database(path=path) - new_database = load_database(metadata=database) - assert new_database is database + assert isinstance(database.engine, sql.Engine) def test_optimization_iteration_table_scalar(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_last_rows(database, "optimization_iterations", 1, "list_of_dicts") assert isinstance(res, list) assert 
isinstance(res[0], dict) @@ -88,7 +81,7 @@ def test_optimization_iteration_table_scalar(tmp_path, iteration_data): def test_steps_table(tmp_path): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_steps_table(database) for status in ["scheduled", "running", "completed"]: append_row( @@ -100,8 +93,6 @@ def test_steps_table(tmp_path): }, "steps", database, - path, - False, ) res, _ = read_new_rows(database, "steps", 1, "dict_of_lists") @@ -118,9 +109,9 @@ def test_steps_table(tmp_path): def test_optimization_problem_table(tmp_path, problem_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_problem_table(database) - append_row(problem_data, "optimization_problem", database, path, False) + append_row(problem_data, "optimization_problem", database) res = read_last_rows(database, "optimization_problem", 1, "list_of_dicts")[0] assert res["rowid"] == 1 for key, expected in problem_data.items(): @@ -134,11 +125,11 @@ def test_optimization_problem_table(tmp_path, problem_data): def test_read_new_rows_stride(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_new_rows( database=database, @@ -154,13 +145,13 @@ def test_read_new_rows_stride(tmp_path, iteration_data): def test_update_row(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) - update_row({"value": 20}, 8, "optimization_iterations", database, path, False) + update_row({"value": 20}, 8, "optimization_iterations", database) res = read_new_rows( database=database, @@ -175,11 +166,11 @@ def test_update_row(tmp_path, iteration_data): def test_read_last_rows_stride(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_last_rows( database=database, @@ -195,12 +186,12 @@ def test_read_last_rows_stride(tmp_path, iteration_data): def test_read_new_rows_with_step(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i iteration_data["step"] = i % 2 - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res, _ = read_new_rows( database=database, @@ -216,12 +207,12 @@ def 
test_read_new_rows_with_step(tmp_path, iteration_data): def test_read_last_rows_with_step(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i iteration_data["step"] = i % 2 - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) res = read_last_rows( database=database, @@ -237,12 +228,12 @@ def test_read_last_rows_with_step(tmp_path, iteration_data): def test_read_table(tmp_path, iteration_data): path = tmp_path / "test.db" - database = load_database(path=path) + database = load_database(path_or_database=path) make_optimization_iteration_table(database) for i in range(1, 11): # sqlalchemy starts counting at 1 iteration_data["value"] = i iteration_data["step"] = i % 2 - append_row(iteration_data, "optimization_iterations", database, path, False) + append_row(iteration_data, "optimization_iterations", database) table = read_table( database=database, diff --git a/tests/optimization/test_internal_criterion_and_derivative_template.py b/tests/optimization/test_internal_criterion_and_derivative_template.py index 7822a5751..ddb7a0866 100644 --- a/tests/optimization/test_internal_criterion_and_derivative_template.py +++ b/tests/optimization/test_internal_criterion_and_derivative_template.py @@ -47,7 +47,7 @@ def base_inputs(): "error_handling": "raise", "numdiff_options": {}, "logging": False, - "db_kwargs": {"database": False, "fast_logging": False, "path": "logging.db"}, + "database": None, "error_penalty_func": None, "fixed_log_data": {"stage": "optimization", "substage": 0}, } diff --git a/tests/optimization/test_multistart.py b/tests/optimization/test_multistart.py index d7ded9339..cd9328bdc 100644 --- a/tests/optimization/test_multistart.py +++ b/tests/optimization/test_multistart.py @@ -127,7 +127,7 @@ def test_all_steps_occur_in_optimization_iterations_if_no_convergence(params): logging="logging.db", ) - database = load_database(path="logging.db") + database = load_database(path_or_database="logging.db") iterations, _ = read_new_rows( database=database, table_name="optimization_iterations", From 0efbfd2f23df8df1b7b3f9986f9866350df57e8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 18:34:56 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../inference/bootstrap_montecarlo_comparison.ipynb | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/source/explanations/inference/bootstrap_montecarlo_comparison.ipynb b/docs/source/explanations/inference/bootstrap_montecarlo_comparison.ipynb index 60656e8bc..889fb09fa 100644 --- a/docs/source/explanations/inference/bootstrap_montecarlo_comparison.ipynb +++ b/docs/source/explanations/inference/bootstrap_montecarlo_comparison.ipynb @@ -81,9 +81,7 @@ " cluster = []\n", "\n", " for g in range(nclusters):\n", - "\n", " for i in range(nobs_per_cluster):\n", - "\n", " key = (i + 1) * (g + 1) - 1\n", "\n", " arg = (\n", @@ -136,7 +134,6 @@ " \"\"\"\n", "\n", " def logit_wrap(df):\n", - "\n", " y = df[\"y\"]\n", " x = df[\"x\"]\n", "\n", @@ -145,12 +142,10 @@ " return pd.Series(result, index=[\"constant\", \"x\"])\n", "\n", " if cluster is False:\n", - "\n", " result = 
em.bootstrap(data=data, outcome=logit_wrap, n_draws=sample_size)\n", " estimates = pd.DataFrame(result.outcomes)[\"x\"]\n", "\n", " else:\n", - "\n", " result = em.bootstrap(\n", " data=data,\n", " outcome=logit_wrap,\n", @@ -189,7 +184,6 @@ " np.zeros(nsim)\n", "\n", " def loop():\n", - "\n", " df = create_clustered_data(nclusters, nobs_per_cluster, true_beta)\n", "\n", " return [get_t_values(df), get_t_values(df, cluster=True)]\n", @@ -231,7 +225,6 @@ "results_list = []\n", "\n", "for g, k in [[20, 50], [100, 10], [500, 2]]:\n", - "\n", " results_list.append(monte_carlo(nsim=100, nclusters=g, nobs_per_cluster=k))" ] }, From 6ded77bc99b2cd51ef22729d1dff6c775873a7ea Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Tue, 7 Feb 2023 12:50:18 +0100 Subject: [PATCH 3/7] Start to split database_utilities. --- src/estimagic/dashboard/dashboard_app.py | 3 +- src/estimagic/logging/database_utilities.py | 158 +------------------ src/estimagic/logging/load_database.py | 160 ++++++++++++++++++++ src/estimagic/logging/read_log.py | 2 +- src/estimagic/optimization/optimize.py | 2 +- tests/logging/test_database_utilities.py | 3 +- tests/optimization/test_multistart.py | 3 +- 7 files changed, 169 insertions(+), 162 deletions(-) create mode 100644 src/estimagic/logging/load_database.py diff --git a/src/estimagic/dashboard/dashboard_app.py b/src/estimagic/dashboard/dashboard_app.py index c43dec3db..ca3458a53 100644 --- a/src/estimagic/dashboard/dashboard_app.py +++ b/src/estimagic/dashboard/dashboard_app.py @@ -11,7 +11,8 @@ from estimagic.dashboard.callbacks import reset_and_start_convergence from estimagic.dashboard.plot_functions import plot_time_series -from estimagic.logging.database_utilities import load_database, read_last_rows +from estimagic.logging.database_utilities import read_last_rows +from estimagic.logging.load_database import load_database from estimagic.logging.read_log import read_start_params from estimagic.parameters.parameter_groups import get_params_groups_and_short_names from estimagic.parameters.tree_registry import get_registry diff --git a/src/estimagic/logging/database_utilities.py b/src/estimagic/logging/database_utilities.py index 9b89a464f..390236903 100644 --- a/src/estimagic/logging/database_utilities.py +++ b/src/estimagic/logging/database_utilities.py @@ -9,109 +9,13 @@ ``read_log.py`` instead. """ -import io import traceback import warnings -import cloudpickle -import pandas as pd import sqlalchemy as sql -from estimagic.exceptions import TableExistsError, get_traceback - - -class DataBase: - """Class containing everything to work with a logging database. - - Importantly, the class is pickle-serializable which is important to share it across - multiple processes. - - """ - - def __init__(self, metadata, path, fast_logging, engine=None): - self.metadata = metadata - self.path = path - self.fast_logging = fast_logging - if engine is None: - self.engine = _create_engine(path, fast_logging) - else: - self.engine = engine - - def __reduce__(self): - return (DataBase, (self.metadata, self.path, self.fast_logging)) - - -def load_database(path_or_database, fast_logging=False): - """Load or create a database from a path and configure it for our needs. - - This is the only acceptable way of loading or creating a database in estimagic! - - Args: - path (str or pathlib.Path): Path to the database. - fast_logging (bool): If True, use unsafe optimizations to speed up the logging. - If False, only use ultra safe optimizations. 
- - Returns: - database (Database): Object containing everything to work with the - database. - - """ - if isinstance(path_or_database, DataBase): - out = path_or_database - else: - engine = _create_engine(path_or_database, fast_logging) - metadata = sql.MetaData() - _configure_reflect() - metadata.reflect(engine) - - out = DataBase( - metadata=metadata, - path=path_or_database, - fast_logging=fast_logging, - engine=engine, - ) - return out - - -def _create_engine(path, fast_logging): - engine = sql.create_engine(f"sqlite:///{path}") - _configure_engine(engine, fast_logging) - return engine - - -def _configure_engine(engine, fast_logging): - """Configure the sqlite engine. - - The two functions that configure the emission of the begin statement are taken from - the sqlalchemy documentation the documentation: https://tinyurl.com/u9xea5z and are - the recommended way of working around a bug in the pysqlite driver. - - The other function speeds up the write process. If fast_logging is False, it does so - using only completely safe optimizations. Of fast_logging is True, it also uses - unsafe optimizations. - - """ - - @sql.event.listens_for(engine, "connect") - def do_connect(dbapi_connection, connection_record): # noqa: ARG001 - # disable pysqlite's emitting of the BEGIN statement entirely. - # also stops it from emitting COMMIT before absolutely necessary. - dbapi_connection.isolation_level = None - - @sql.event.listens_for(engine, "begin") - def do_begin(conn): - # emit our own BEGIN - conn.exec_driver_sql("BEGIN DEFERRED") - - @sql.event.listens_for(engine, "connect") - def set_sqlite_pragma(dbapi_connection, connection_record): # noqa: ARG001 - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA journal_mode = WAL") - if fast_logging: - cursor.execute("PRAGMA synchronous = OFF") - else: - cursor.execute("PRAGMA synchronous = NORMAL") - cursor.close() +from estimagic.exceptions import TableExistsError +from estimagic.logging.load_database import RobustPickler def make_optimization_iteration_table(database, if_exists="extend"): @@ -163,64 +67,6 @@ def _handle_existing_table(database, table_name, if_exists): raise TableExistsError(f"The table {table_name} already exists.") -def _configure_reflect(): - """Mark all BLOB dtypes as PickleType with our custom pickle reader. - - Code ist taken from the documentation: https://tinyurl.com/y7q287jr - - """ - - @sql.event.listens_for(sql.Table, "column_reflect") - def _setup_pickletype(inspector, table, column_info): # noqa: ARG001 - if isinstance(column_info["type"], sql.BLOB): - column_info["type"] = sql.PickleType(pickler=RobustPickler) - - -class RobustPickler: - @staticmethod - def loads( - data, - fix_imports=True, # noqa: ARG004 - encoding="ASCII", # noqa: ARG004 - errors="strict", # noqa: ARG004 - buffers=None, # noqa: ARG004 - ): - """Robust pickle loading. - - We first try to unpickle the object with pd.read_pickle. This makes no - difference for non-pandas objects but makes the de-serialization - of pandas objects more robust across pandas versions. If that fails, we use - cloudpickle. If that fails, we return None but do not raise an error. 
-
-        See: https://github.com/pandas-dev/pandas/issues/16474
-
-        """
-        try:
-            res = pd.read_pickle(io.BytesIO(data), compression=None)
-        except (KeyboardInterrupt, SystemExit):
-            raise
-        except Exception:
-            try:
-                res = cloudpickle.loads(data)
-            except (KeyboardInterrupt, SystemExit):
-                raise
-            except Exception:
-                res = None
-                tb = get_traceback()
-                warnings.warn(
-                    f"Unable to read PickleType column from database:\n{tb}\n "
-                    "The entry was replaced by None."
-                )
-
-        return res
-
-    @staticmethod
-    def dumps(
-        obj, protocol=None, *, fix_imports=True, buffer_callback=None  # noqa: ARG004
-    ):
-        return cloudpickle.dumps(obj, protocol=protocol)
diff --git a/src/estimagic/logging/load_database.py b/src/estimagic/logging/load_database.py
new file mode 100644
index 000000000..2d77e90aa
--- /dev/null
+++ b/src/estimagic/logging/load_database.py
@@ -0,0 +1,160 @@
+import io
+import warnings
+
+import cloudpickle
+import pandas as pd
+import sqlalchemy as sql
+
+from estimagic.exceptions import get_traceback
+
+
+class DataBase:
+    """Class containing everything to work with a logging database.
+
+    Importantly, the class is pickle-serializable, which makes it possible to share
+    it across multiple processes.
+
+    """
+
+    def __init__(self, metadata, path, fast_logging, engine=None):
+        self.metadata = metadata
+        self.path = path
+        self.fast_logging = fast_logging
+        if engine is None:
+            self.engine = _create_engine(path, fast_logging)
+        else:
+            self.engine = engine
+
+    def __reduce__(self):
+        return (DataBase, (self.metadata, self.path, self.fast_logging))
+
+
+def load_database(path_or_database, fast_logging=False):
+    """Load or create a database from a path and configure it for our needs.
+
+    This is the only acceptable way of loading or creating a database in estimagic!
+
+    Args:
+        path_or_database (str or pathlib.Path or DataBase): Path to the database or
+            an existing DataBase instance, which is returned unchanged.
+        fast_logging (bool): If True, use unsafe optimizations to speed up the logging.
+            If False, only use ultra safe optimizations.
+
+    Returns:
+        database (DataBase): Object containing everything to work with the
+            database.
+
+    """
+    if isinstance(path_or_database, DataBase):
+        out = path_or_database
+    else:
+        engine = _create_engine(path_or_database, fast_logging)
+        metadata = sql.MetaData()
+        _configure_reflect()
+        metadata.reflect(engine)
+
+        out = DataBase(
+            metadata=metadata,
+            path=path_or_database,
+            fast_logging=fast_logging,
+            engine=engine,
+        )
+    return out
+
+
+def _create_engine(path, fast_logging):
+    engine = sql.create_engine(f"sqlite:///{path}")
+    _configure_engine(engine, fast_logging)
+    return engine
+
+
+def _configure_engine(engine, fast_logging):
+    """Configure the sqlite engine.
+
+    The two functions that configure the emission of the begin statement are taken from
+    the sqlalchemy documentation: https://tinyurl.com/u9xea5z and are
+    the recommended way of working around a bug in the pysqlite driver.
+
+    The other function speeds up the write process. If fast_logging is False, it does so
+    using only completely safe optimizations. If fast_logging is True, it also uses
+    unsafe optimizations.
+
+    """
+
+    @sql.event.listens_for(engine, "connect")
+    def do_connect(dbapi_connection, connection_record):  # noqa: ARG001
+        # disable pysqlite's emitting of the BEGIN statement entirely.
+        # also stops it from emitting COMMIT before absolutely necessary.
+        dbapi_connection.isolation_level = None
+
+    @sql.event.listens_for(engine, "begin")
+    def do_begin(conn):
+        # emit our own BEGIN
+        conn.exec_driver_sql("BEGIN DEFERRED")
+
+    @sql.event.listens_for(engine, "connect")
+    def set_sqlite_pragma(dbapi_connection, connection_record):  # noqa: ARG001
+        cursor = dbapi_connection.cursor()
+        cursor.execute("PRAGMA journal_mode = WAL")
+        if fast_logging:
+            cursor.execute("PRAGMA synchronous = OFF")
+        else:
+            cursor.execute("PRAGMA synchronous = NORMAL")
+        cursor.close()
+
+
+def _configure_reflect():
+    """Mark all BLOB dtypes as PickleType with our custom pickle reader.
+
+    Code is taken from the documentation: https://tinyurl.com/y7q287jr
+
+    """
+
+    @sql.event.listens_for(sql.Table, "column_reflect")
+    def _setup_pickletype(inspector, table, column_info):  # noqa: ARG001
+        if isinstance(column_info["type"], sql.BLOB):
+            column_info["type"] = sql.PickleType(pickler=RobustPickler)
+
+
+class RobustPickler:
+    @staticmethod
+    def loads(
+        data,
+        fix_imports=True,  # noqa: ARG004
+        encoding="ASCII",  # noqa: ARG004
+        errors="strict",  # noqa: ARG004
+        buffers=None,  # noqa: ARG004
+    ):
+        """Robust pickle loading.
+
+        We first try to unpickle the object with pd.read_pickle. This makes no
+        difference for non-pandas objects but makes the de-serialization
+        of pandas objects more robust across pandas versions. If that fails, we use
+        cloudpickle. If that fails, we return None but do not raise an error.
+
+        See: https://github.com/pandas-dev/pandas/issues/16474
+
+        """
+        try:
+            res = pd.read_pickle(io.BytesIO(data), compression=None)
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            try:
+                res = cloudpickle.loads(data)
+            except (KeyboardInterrupt, SystemExit):
+                raise
+            except Exception:
+                res = None
+                tb = get_traceback()
+                warnings.warn(
+                    f"Unable to read PickleType column from database:\n{tb}\n "
+                    "The entry was replaced by None."
+ ) + + return res + + @staticmethod + def dumps( + obj, protocol=None, *, fix_imports=True, buffer_callback=None # noqa: ARG004 + ): + return cloudpickle.dumps(obj, protocol=protocol) diff --git a/src/estimagic/logging/read_log.py b/src/estimagic/logging/read_log.py index 339c4ed23..8a1140320 100644 --- a/src/estimagic/logging/read_log.py +++ b/src/estimagic/logging/read_log.py @@ -16,11 +16,11 @@ from pybaum import tree_flatten, tree_unflatten from estimagic.logging.database_utilities import ( - load_database, read_last_rows, read_new_rows, read_specific_row, ) +from estimagic.logging.load_database import load_database from estimagic.parameters.tree_registry import get_registry diff --git a/src/estimagic/optimization/optimize.py b/src/estimagic/optimization/optimize.py index e2a48f0ec..a64224975 100644 --- a/src/estimagic/optimization/optimize.py +++ b/src/estimagic/optimization/optimize.py @@ -6,11 +6,11 @@ from estimagic.exceptions import InvalidFunctionError, InvalidKwargsError from estimagic.logging.database_utilities import ( append_row, - load_database, make_optimization_iteration_table, make_optimization_problem_table, make_steps_table, ) +from estimagic.logging.load_database import load_database from estimagic.optimization.check_arguments import check_optimize_kwargs from estimagic.optimization.error_penalty import get_error_penalty_function from estimagic.optimization.get_algorithm import ( diff --git a/tests/logging/test_database_utilities.py b/tests/logging/test_database_utilities.py index e38ff02a3..fd468ec00 100644 --- a/tests/logging/test_database_utilities.py +++ b/tests/logging/test_database_utilities.py @@ -3,9 +3,7 @@ import numpy as np import pytest from estimagic.logging.database_utilities import ( - DataBase, append_row, - load_database, make_optimization_iteration_table, make_optimization_problem_table, make_steps_table, @@ -14,6 +12,7 @@ read_table, update_row, ) +from estimagic.logging.load_database import DataBase, load_database from numpy.testing import assert_array_equal diff --git a/tests/optimization/test_multistart.py b/tests/optimization/test_multistart.py index cd9328bdc..d4068e499 100644 --- a/tests/optimization/test_multistart.py +++ b/tests/optimization/test_multistart.py @@ -8,7 +8,8 @@ sos_dict_criterion, sos_scalar_criterion, ) -from estimagic.logging.database_utilities import load_database, read_new_rows +from estimagic.logging.database_utilities import read_new_rows +from estimagic.logging.load_database import load_database from estimagic.logging.read_log import read_steps_table from estimagic.optimization.optimize import maximize, minimize from estimagic.optimization.optimize_result import OptimizeResult From 4271f535c983e1a0af96300332a71a765c966e79 Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Tue, 7 Feb 2023 12:55:17 +0100 Subject: [PATCH 4/7] Continue splitting. 
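
With this split, table creation goes through the new create_tables module
while the DataBase object still comes from load_database. A minimal usage
sketch (the file name "example.db" is illustrative; the functions and
arguments are the ones introduced in this series):

    from estimagic.logging.create_tables import make_steps_table
    from estimagic.logging.load_database import load_database

    # creates the sqlite file if it does not exist and wraps it in a DataBase
    database = load_database(path_or_database="example.db", fast_logging=False)

    # "extend" re-uses an existing steps table instead of replacing or raising
    make_steps_table(database, if_exists="extend")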
---
 src/estimagic/logging/create_tables.py      | 104 ++++++++++++++++++++
 src/estimagic/logging/database_utilities.py | 104 --------------------
 src/estimagic/optimization/optimize.py      |   6 +-
 tests/logging/test_database_utilities.py    |   6 +-
 4 files changed, 112 insertions(+), 108 deletions(-)
 create mode 100644 src/estimagic/logging/create_tables.py

diff --git a/src/estimagic/logging/create_tables.py b/src/estimagic/logging/create_tables.py
new file mode 100644
index 000000000..bf82e1bea
--- /dev/null
+++ b/src/estimagic/logging/create_tables.py
@@ -0,0 +1,104 @@
+import sqlalchemy as sql
+
+from estimagic.exceptions import TableExistsError
+from estimagic.logging.load_database import RobustPickler
+
+
+def make_optimization_iteration_table(database, if_exists="extend"):
+    """Generate a table for information that is generated with each function evaluation.
+
+    Args:
+        database (DataBase): Bound metadata object.
+        if_exists (str): What to do if the table already exists. Can be "extend",
+            "replace" or "raise".
+
+    Returns:
+        database (sqlalchemy.MetaData): Bound metadata object with added table.
+
+    """
+    table_name = "optimization_iterations"
+    _handle_existing_table(database, table_name, if_exists)
+
+    columns = [
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("params", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("internal_derivative", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("timestamp", sql.Float),
+        sql.Column("exceptions", sql.String),
+        sql.Column("valid", sql.Boolean),
+        sql.Column("hash", sql.String),
+        sql.Column("value", sql.Float),
+        sql.Column("step", sql.Integer),
+        sql.Column("criterion_eval", sql.PickleType(pickler=RobustPickler)),
+    ]
+
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        sqlite_autoincrement=True,
+        extend_existing=True,
+    )
+
+    database.metadata.create_all(database.engine)
+
+
+def _handle_existing_table(database, table_name, if_exists):
+    assert if_exists in ["replace", "extend", "raise"]
+
+    if table_name in database.metadata.tables:
+        if if_exists == "replace":
+            database.metadata.tables[table_name].drop(database.engine)
+        elif if_exists == "raise":
+            raise TableExistsError(f"The table {table_name} already exists.")
+
+
+def make_steps_table(database, if_exists="extend"):
+    table_name = "steps"
+    _handle_existing_table(database, table_name, if_exists)
+    columns = [
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("type", sql.String),  # e.g. optimization
+        sql.Column("status", sql.String),  # e.g. running
+        sql.Column("n_iterations", sql.Integer),  # optional
+        sql.Column(
+            "name", sql.String
+        ),  # e.g. "optimization-1", "exploration", not unique
+    ]
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        extend_existing=True,
+        sqlite_autoincrement=True,
+    )
+    database.metadata.create_all(database.engine)
+
+
+def make_optimization_problem_table(database, if_exists="extend"):
+    table_name = "optimization_problem"
+    _handle_existing_table(database, table_name, if_exists)
+
+    columns = [
+        sql.Column("rowid", sql.Integer, primary_key=True),
+        sql.Column("direction", sql.String),
+        sql.Column("params", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("algorithm", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("algo_options", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("numdiff_options", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("log_options", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("error_handling", sql.String),
+        sql.Column("error_penalty", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("constraints", sql.PickleType(pickler=RobustPickler)),
+        sql.Column("free_mask", sql.PickleType(pickler=RobustPickler)),
+    ]
+
+    sql.Table(
+        table_name,
+        database.metadata,
+        *columns,
+        extend_existing=True,
+        sqlite_autoincrement=True,
+    )
+
+    database.metadata.create_all(database.engine)
diff --git a/src/estimagic/logging/database_utilities.py b/src/estimagic/logging/database_utilities.py
index 390236903..d5876a908 100644
--- a/src/estimagic/logging/database_utilities.py
+++ b/src/estimagic/logging/database_utilities.py
@@ -14,110 +14,6 @@
 import sqlalchemy as sql
 
-from estimagic.exceptions import TableExistsError
-from estimagic.logging.load_database import RobustPickler
-
-
-def make_optimization_iteration_table(database, if_exists="extend"):
-    """Generate a table for information that is generated with each function evaluation.
-
-    Args:
-        database (DataBase): Bound metadata object.
-        if_exists (str): What to do if the table already exists. Can be "extend",
-            "replace" or "raise".
-
-    Returns:
-        database (sqlalchemy.MetaData):Bound metadata object with added table.
-
-    """
-    table_name = "optimization_iterations"
-    _handle_existing_table(database, "optimization_iterations", if_exists)
-
-    columns = [
-        sql.Column("rowid", sql.Integer, primary_key=True),
-        sql.Column("params", sql.PickleType(pickler=RobustPickler)),
-        sql.Column("internal_derivative", sql.PickleType(pickler=RobustPickler)),
-        sql.Column("timestamp", sql.Float),
-        sql.Column("exceptions", sql.String),
-        sql.Column("valid", sql.Boolean),
-        sql.Column("hash", sql.String),
-        sql.Column("value", sql.Float),
-        sql.Column("step", sql.Integer),
-        sql.Column("criterion_eval", sql.PickleType(pickler=RobustPickler)),
-    ]
-
-    sql.Table(
-        table_name,
-        database.metadata,
-        *columns,
-        sqlite_autoincrement=True,
-        extend_existing=True,
-    )
-
-    database.metadata.create_all(database.engine)
-
-
-def _handle_existing_table(database, table_name, if_exists):
-    assert if_exists in ["replace", "extend", "raise"]
-
-    if table_name in database.metadata.tables:
-        if if_exists == "replace":
-            database.metadata.tables[table_name].drop(database.engine)
-        elif if_exists == "raise":
-            raise TableExistsError(f"The table {table_name} already exists.")
-
-
-def make_steps_table(database, if_exists="extend"):
-    table_name = "steps"
-    _handle_existing_table(database, table_name, if_exists)
-    columns = [
-        sql.Column("rowid", sql.Integer, primary_key=True),
-        sql.Column("type", sql.String),  # e.g. optimization
-        sql.Column("status", sql.String),  # e.g.
running - sql.Column("n_iterations", sql.Integer), # optional - sql.Column( - "name", sql.String - ), # e.g. "optimization-1", "exploration", not unique - ] - sql.Table( - table_name, - database.metadata, - *columns, - extend_existing=True, - sqlite_autoincrement=True, - ) - database.metadata.create_all(database.engine) - - -def make_optimization_problem_table(database, if_exists="extend"): - table_name = "optimization_problem" - _handle_existing_table(database, table_name, if_exists) - - columns = [ - sql.Column("rowid", sql.Integer, primary_key=True), - sql.Column("direction", sql.String), - sql.Column("params", sql.PickleType(pickler=RobustPickler)), - sql.Column("algorithm", sql.PickleType(pickler=RobustPickler)), - sql.Column("algo_options", sql.PickleType(pickler=RobustPickler)), - sql.Column("numdiff_options", sql.PickleType(pickler=RobustPickler)), - sql.Column("log_options", sql.PickleType(pickler=RobustPickler)), - sql.Column("error_handling", sql.String), - sql.Column("error_penalty", sql.PickleType(pickler=RobustPickler)), - sql.Column("constraints", sql.PickleType(pickler=RobustPickler)), - sql.Column("free_mask", sql.PickleType(pickler=RobustPickler)), - ] - - sql.Table( - table_name, - database.metadata, - *columns, - extend_existing=True, - sqlite_autoincrement=True, - ) - - database.metadata.create_all(database.engine) - - # ====================================================================================== diff --git a/src/estimagic/optimization/optimize.py b/src/estimagic/optimization/optimize.py index a64224975..429c9ec98 100644 --- a/src/estimagic/optimization/optimize.py +++ b/src/estimagic/optimization/optimize.py @@ -4,12 +4,14 @@ from estimagic.batch_evaluators import process_batch_evaluator from estimagic.exceptions import InvalidFunctionError, InvalidKwargsError -from estimagic.logging.database_utilities import ( - append_row, +from estimagic.logging.create_tables import ( make_optimization_iteration_table, make_optimization_problem_table, make_steps_table, ) +from estimagic.logging.database_utilities import ( + append_row, +) from estimagic.logging.load_database import load_database from estimagic.optimization.check_arguments import check_optimize_kwargs from estimagic.optimization.error_penalty import get_error_penalty_function diff --git a/tests/logging/test_database_utilities.py b/tests/logging/test_database_utilities.py index fd468ec00..35631bcb1 100644 --- a/tests/logging/test_database_utilities.py +++ b/tests/logging/test_database_utilities.py @@ -2,11 +2,13 @@ import numpy as np import pytest -from estimagic.logging.database_utilities import ( - append_row, +from estimagic.logging.create_tables import ( make_optimization_iteration_table, make_optimization_problem_table, make_steps_table, +) +from estimagic.logging.database_utilities import ( + append_row, read_last_rows, read_new_rows, read_table, From 0618c9db2a4f07b5854760670cc81fb7ace23b6a Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Tue, 7 Feb 2023 12:58:48 +0100 Subject: [PATCH 5/7] Continue splitting. 
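
The write helpers keep their signatures; only the import location changes.
A sketch of the intended use (the values are illustrative and assume that
the row written by append_row receives rowid 1 in a fresh database):

    from estimagic.logging.create_tables import make_steps_table
    from estimagic.logging.load_database import load_database
    from estimagic.logging.write_to_database import append_row, update_row

    database = load_database(path_or_database="example.db")
    make_steps_table(database)

    # keys of data must match column names of the target table
    append_row(
        data={"type": "optimization", "status": "scheduled"},
        table_name="steps",
        database=database,
    )

    # update the row that was just written via its rowid
    update_row(
        data={"status": "running"}, rowid=1, table_name="steps", database=database
    )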
--- src/estimagic/logging/database_utilities.py | 39 ------------------ src/estimagic/logging/write_to_database.py | 41 +++++++++++++++++++ src/estimagic/optimization/get_algorithm.py | 2 +- .../internal_criterion_template.py | 2 +- .../optimization/optimization_logging.py | 3 +- src/estimagic/optimization/optimize.py | 4 +- tests/logging/test_database_utilities.py | 3 +- 7 files changed, 47 insertions(+), 47 deletions(-) create mode 100644 src/estimagic/logging/write_to_database.py diff --git a/src/estimagic/logging/database_utilities.py b/src/estimagic/logging/database_utilities.py index d5876a908..36d535b21 100644 --- a/src/estimagic/logging/database_utilities.py +++ b/src/estimagic/logging/database_utilities.py @@ -14,45 +14,6 @@ import sqlalchemy as sql -# ====================================================================================== - - -def update_row(data, rowid, table_name, database): - table = database.metadata.tables[table_name] - stmt = sql.update(table).where(table.c.rowid == rowid).values(**data) - - _execute_write_statement(stmt, database) - - -def append_row(data, table_name, database): - """ - - Args: - data (dict): The keys correspond to columns in the database table. - table_name (str): Name of the database table to which the row is added. - database (DataBase): The database to which the row is added. - - """ - - stmt = database.metadata.tables[table_name].insert().values(**data) - - _execute_write_statement(stmt, database) - - -def _execute_write_statement(statement, database): - try: - # this will automatically roll back the transaction if any exception is raised - # and then raise the exception - with database.engine.begin() as connection: - connection.execute(statement) - except (KeyboardInterrupt, SystemExit): - raise - except Exception: - exception_info = traceback.format_exc() - warnings.warn( - f"Unable to write to database. The traceback was:\n\n{exception_info}" - ) - def read_new_rows( database, diff --git a/src/estimagic/logging/write_to_database.py b/src/estimagic/logging/write_to_database.py new file mode 100644 index 000000000..b1e64c639 --- /dev/null +++ b/src/estimagic/logging/write_to_database.py @@ -0,0 +1,41 @@ +import traceback +import warnings + +import sqlalchemy as sql + + +def update_row(data, rowid, table_name, database): + table = database.metadata.tables[table_name] + stmt = sql.update(table).where(table.c.rowid == rowid).values(**data) + + _execute_write_statement(stmt, database) + + +def append_row(data, table_name, database): + """ + + Args: + data (dict): The keys correspond to columns in the database table. + table_name (str): Name of the database table to which the row is added. + database (DataBase): The database to which the row is added. + + """ + + stmt = database.metadata.tables[table_name].insert().values(**data) + + _execute_write_statement(stmt, database) + + +def _execute_write_statement(statement, database): + try: + # this will automatically roll back the transaction if any exception is raised + # and then raise the exception + with database.engine.begin() as connection: + connection.execute(statement) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + exception_info = traceback.format_exc() + warnings.warn( + f"Unable to write to database. 
The traceback was:\n\n{exception_info}" + ) diff --git a/src/estimagic/optimization/get_algorithm.py b/src/estimagic/optimization/get_algorithm.py index f4f340bc5..455f5f62e 100644 --- a/src/estimagic/optimization/get_algorithm.py +++ b/src/estimagic/optimization/get_algorithm.py @@ -7,8 +7,8 @@ from estimagic.batch_evaluators import process_batch_evaluator from estimagic.logging.database_utilities import ( list_of_dicts_to_dict_of_lists, - update_row, ) +from estimagic.logging.write_to_database import update_row from estimagic.optimization import ALL_ALGORITHMS from estimagic.utilities import propose_alternatives diff --git a/src/estimagic/optimization/internal_criterion_template.py b/src/estimagic/optimization/internal_criterion_template.py index b15c2b6d2..fa34bf4ae 100644 --- a/src/estimagic/optimization/internal_criterion_template.py +++ b/src/estimagic/optimization/internal_criterion_template.py @@ -3,7 +3,7 @@ from estimagic.differentiation.derivatives import first_derivative from estimagic.exceptions import UserFunctionRuntimeError, get_traceback -from estimagic.logging.database_utilities import append_row +from estimagic.logging.write_to_database import append_row from estimagic.parameters.conversion import aggregate_func_output_to_value diff --git a/src/estimagic/optimization/optimization_logging.py b/src/estimagic/optimization/optimization_logging.py index d5a123cb5..cb2549dbb 100644 --- a/src/estimagic/optimization/optimization_logging.py +++ b/src/estimagic/optimization/optimization_logging.py @@ -1,4 +1,5 @@ -from estimagic.logging.database_utilities import append_row, read_last_rows, update_row +from estimagic.logging.database_utilities import read_last_rows +from estimagic.logging.write_to_database import append_row, update_row def log_scheduled_steps_and_get_ids(steps, logging, database): diff --git a/src/estimagic/optimization/optimize.py b/src/estimagic/optimization/optimize.py index 429c9ec98..6bde57eb6 100644 --- a/src/estimagic/optimization/optimize.py +++ b/src/estimagic/optimization/optimize.py @@ -9,10 +9,8 @@ make_optimization_problem_table, make_steps_table, ) -from estimagic.logging.database_utilities import ( - append_row, -) from estimagic.logging.load_database import load_database +from estimagic.logging.write_to_database import append_row from estimagic.optimization.check_arguments import check_optimize_kwargs from estimagic.optimization.error_penalty import get_error_penalty_function from estimagic.optimization.get_algorithm import ( diff --git a/tests/logging/test_database_utilities.py b/tests/logging/test_database_utilities.py index 35631bcb1..815320fb9 100644 --- a/tests/logging/test_database_utilities.py +++ b/tests/logging/test_database_utilities.py @@ -8,13 +8,12 @@ make_steps_table, ) from estimagic.logging.database_utilities import ( - append_row, read_last_rows, read_new_rows, read_table, - update_row, ) from estimagic.logging.load_database import DataBase, load_database +from estimagic.logging.write_to_database import append_row, update_row from numpy.testing import assert_array_equal From 7da21082a778ffcf498c158081c746c6fe6f9ea5 Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Tue, 7 Feb 2023 13:02:04 +0100 Subject: [PATCH 6/7] Finish splitting. 
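
Only the module name changes; call sites need nothing but a new import.
A sketch of reading from a log with the renamed module (the file name is
illustrative; see the read_new_rows docstring for the exact return format):

    from estimagic.logging.load_database import load_database
    from estimagic.logging.read_from_database import read_new_rows

    database = load_database(path_or_database="example.db")

    # everything that was logged to the steps table after rowid 0
    res = read_new_rows(
        database=database,
        table_name="steps",
        last_retrieved=0,
        return_type="dict_of_lists",
    )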
--- src/estimagic/dashboard/callbacks.py | 2 +- src/estimagic/dashboard/dashboard_app.py | 2 +- .../logging/{database_utilities.py => read_from_database.py} | 0 src/estimagic/logging/read_log.py | 4 ++-- src/estimagic/optimization/get_algorithm.py | 2 +- src/estimagic/optimization/optimization_logging.py | 2 +- src/estimagic/optimization/simopt_optimizers.py | 2 +- tests/logging/test_database_utilities.py | 4 ++-- tests/optimization/test_multistart.py | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) rename src/estimagic/logging/{database_utilities.py => read_from_database.py} (100%) diff --git a/src/estimagic/dashboard/callbacks.py b/src/estimagic/dashboard/callbacks.py index b461da25d..85e70a25a 100644 --- a/src/estimagic/dashboard/callbacks.py +++ b/src/estimagic/dashboard/callbacks.py @@ -2,7 +2,7 @@ import numpy as np -from estimagic.logging.database_utilities import read_new_rows, transpose_nested_list +from estimagic.logging.read_from_database import read_new_rows, transpose_nested_list def reset_and_start_convergence( diff --git a/src/estimagic/dashboard/dashboard_app.py b/src/estimagic/dashboard/dashboard_app.py index ca3458a53..987933774 100644 --- a/src/estimagic/dashboard/dashboard_app.py +++ b/src/estimagic/dashboard/dashboard_app.py @@ -11,8 +11,8 @@ from estimagic.dashboard.callbacks import reset_and_start_convergence from estimagic.dashboard.plot_functions import plot_time_series -from estimagic.logging.database_utilities import read_last_rows from estimagic.logging.load_database import load_database +from estimagic.logging.read_from_database import read_last_rows from estimagic.logging.read_log import read_start_params from estimagic.parameters.parameter_groups import get_params_groups_and_short_names from estimagic.parameters.tree_registry import get_registry diff --git a/src/estimagic/logging/database_utilities.py b/src/estimagic/logging/read_from_database.py similarity index 100% rename from src/estimagic/logging/database_utilities.py rename to src/estimagic/logging/read_from_database.py diff --git a/src/estimagic/logging/read_log.py b/src/estimagic/logging/read_log.py index 8a1140320..5aafd1c09 100644 --- a/src/estimagic/logging/read_log.py +++ b/src/estimagic/logging/read_log.py @@ -15,12 +15,12 @@ import pandas as pd from pybaum import tree_flatten, tree_unflatten -from estimagic.logging.database_utilities import ( +from estimagic.logging.load_database import load_database +from estimagic.logging.read_from_database import ( read_last_rows, read_new_rows, read_specific_row, ) -from estimagic.logging.load_database import load_database from estimagic.parameters.tree_registry import get_registry diff --git a/src/estimagic/optimization/get_algorithm.py b/src/estimagic/optimization/get_algorithm.py index 455f5f62e..ba495a159 100644 --- a/src/estimagic/optimization/get_algorithm.py +++ b/src/estimagic/optimization/get_algorithm.py @@ -5,7 +5,7 @@ import numpy as np from estimagic.batch_evaluators import process_batch_evaluator -from estimagic.logging.database_utilities import ( +from estimagic.logging.read_from_database import ( list_of_dicts_to_dict_of_lists, ) from estimagic.logging.write_to_database import update_row diff --git a/src/estimagic/optimization/optimization_logging.py b/src/estimagic/optimization/optimization_logging.py index cb2549dbb..50a65c600 100644 --- a/src/estimagic/optimization/optimization_logging.py +++ b/src/estimagic/optimization/optimization_logging.py @@ -1,4 +1,4 @@ -from estimagic.logging.database_utilities import read_last_rows +from 
estimagic.logging.read_from_database import read_last_rows from estimagic.logging.write_to_database import append_row, update_row diff --git a/src/estimagic/optimization/simopt_optimizers.py b/src/estimagic/optimization/simopt_optimizers.py index 741e5610a..09268e877 100644 --- a/src/estimagic/optimization/simopt_optimizers.py +++ b/src/estimagic/optimization/simopt_optimizers.py @@ -8,7 +8,7 @@ from estimagic.config import IS_SIMOPT_INSTALLED from estimagic.decorators import mark_minimizer -from estimagic.logging.database_utilities import list_of_dicts_to_dict_of_lists +from estimagic.logging.read_from_database import list_of_dicts_to_dict_of_lists from estimagic.optimization.algo_options import ( STOPPING_MAX_CRITERION_EVALUATIONS_GLOBAL, ) diff --git a/tests/logging/test_database_utilities.py b/tests/logging/test_database_utilities.py index 815320fb9..4651593e3 100644 --- a/tests/logging/test_database_utilities.py +++ b/tests/logging/test_database_utilities.py @@ -7,12 +7,12 @@ make_optimization_problem_table, make_steps_table, ) -from estimagic.logging.database_utilities import ( +from estimagic.logging.load_database import DataBase, load_database +from estimagic.logging.read_from_database import ( read_last_rows, read_new_rows, read_table, ) -from estimagic.logging.load_database import DataBase, load_database from estimagic.logging.write_to_database import append_row, update_row from numpy.testing import assert_array_equal diff --git a/tests/optimization/test_multistart.py b/tests/optimization/test_multistart.py index d4068e499..145bbff59 100644 --- a/tests/optimization/test_multistart.py +++ b/tests/optimization/test_multistart.py @@ -8,8 +8,8 @@ sos_dict_criterion, sos_scalar_criterion, ) -from estimagic.logging.database_utilities import read_new_rows from estimagic.logging.load_database import load_database +from estimagic.logging.read_from_database import read_new_rows from estimagic.logging.read_log import read_steps_table from estimagic.optimization.optimize import maximize, minimize from estimagic.optimization.optimize_result import OptimizeResult From 4d11b7145b40dbb0cc103cac4ea27fd5c0bd3d6e Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Tue, 7 Feb 2023 13:25:18 +0100 Subject: [PATCH 7/7] Polishing. --- src/estimagic/logging/create_tables.py | 22 ++++++++++----------- src/estimagic/logging/load_database.py | 5 +++-- src/estimagic/logging/read_from_database.py | 9 +-------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/estimagic/logging/create_tables.py b/src/estimagic/logging/create_tables.py index bf82e1bea..28f349ecc 100644 --- a/src/estimagic/logging/create_tables.py +++ b/src/estimagic/logging/create_tables.py @@ -8,7 +8,7 @@ def make_optimization_iteration_table(database, if_exists="extend"): """Generate a table for information that is generated with each function evaluation. Args: - database (DataBase): Bound metadata object. + database (DataBase): DataBase object containing the engine and metadata. if_exists (str): What to do if the table already exists. Can be "extend", "replace" or "raise". 
@@ -43,16 +43,6 @@ def make_optimization_iteration_table(database, if_exists="extend"): database.metadata.create_all(database.engine) -def _handle_existing_table(database, table_name, if_exists): - assert if_exists in ["replace", "extend", "raise"] - - if table_name in database.metadata.tables: - if if_exists == "replace": - database.metadata.tables[table_name].drop(database.engine) - elif if_exists == "raise": - raise TableExistsError(f"The table {table_name} already exists.") - - def make_steps_table(database, if_exists="extend"): table_name = "steps" _handle_existing_table(database, table_name, if_exists) @@ -102,3 +92,13 @@ def make_optimization_problem_table(database, if_exists="extend"): ) database.metadata.create_all(database.engine) + + +def _handle_existing_table(database, table_name, if_exists): + assert if_exists in ["replace", "extend", "raise"] + + if table_name in database.metadata.tables: + if if_exists == "replace": + database.metadata.tables[table_name].drop(database.engine) + elif if_exists == "raise": + raise TableExistsError(f"The table {table_name} already exists.") diff --git a/src/estimagic/logging/load_database.py b/src/estimagic/logging/load_database.py index 2d77e90aa..b309f1196 100644 --- a/src/estimagic/logging/load_database.py +++ b/src/estimagic/logging/load_database.py @@ -12,7 +12,8 @@ class DataBase: """Class containing everything to work with a logging database. Importantly, the class is pickle-serializable which is important to share it across - multiple processes. + multiple processes. Upon unpickling, it will automatically re-create an engine to + connect to the database. """ @@ -35,7 +36,7 @@ def load_database(path_or_database, fast_logging=False): This is the only acceptable way of loading or creating a database in estimagic! Args: - path (str or pathlib.Path): Path to the database. + path_or_database (str or pathlib.Path): Path to the database or DataBase. fast_logging (bool): If True, use unsafe optimizations to speed up the logging. If False, only use ultra safe optimizations. diff --git a/src/estimagic/logging/read_from_database.py b/src/estimagic/logging/read_from_database.py index 36d535b21..af28ecf6e 100644 --- a/src/estimagic/logging/read_from_database.py +++ b/src/estimagic/logging/read_from_database.py @@ -27,14 +27,10 @@ def read_new_rows( """Read all iterations after last_retrieved up to a limit. Args: - database (DataBase) + database (DataBase): Object containing everything to work with the database. table_name (str): name of the table to retrieve. last_retrieved (int): The last iteration that was retrieved. return_type (str): either "list_of_dicts" or "dict_of_lists". - path (str or pathlib.Path): location of the database file. If the file does - not exist, it will be created. Using a path is much slower than a - MetaData object and we advise to only use it as a fallback. - fast_logging (bool) limit (int): maximum number of rows to extract from the table. stride (int): Only return every n-th row. Default is every row (stride=1). step (int): Only return iterations that belong to step. @@ -186,9 +182,6 @@ def _execute_read_statement(database, table_name, statement, return_type): return result -# ====================================================================================== - - def transpose_nested_list(nested_list): """Transpose a list of lists.