# Source code for estimagic.logging.create_database

"""Functions to create new databases or load existing ones.

Note: Most functions in this module change their arguments in place since this is the
recommended way of doing things in sqlalchemy and makes sense for database code.

"""
from pathlib import Path

import numpy as np
import pandas as pd
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import Float
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import PickleType
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.dialects.sqlite import DATETIME

from estimagic.logging.update_database import append_rows


def load_database(path):
    """Return database metadata object for the database stored in ``path``.

    This is the default way of loading a database for read-only purposes in estimagic.

    Args:
        path (str, pathlib.Path or sqlalchemy.MetaData): location of the database
            file. If the file does not exist, it will be created. If a
            ``sqlalchemy.MetaData`` object is passed, it is returned unchanged.

    Returns:
        database (sqlalchemy.MetaData). The engine that connects to the database can be
            accessed via ``database.bind``.

    Raises:
        TypeError: if ``path`` is not a str, a pathlib.Path or a
            sqlalchemy.MetaData.

    """
    # An already loaded database is passed through untouched.
    if isinstance(path, MetaData):
        return path

    if isinstance(path, str):
        path = Path(path)

    if not isinstance(path, Path):
        # The exception has to be raised explicitly; merely constructing it would
        # let the function run on and fail later with an unrelated error.
        raise TypeError("'path' is neither a pathlib.Path nor a sqlalchemy.MetaData.")

    engine = create_engine(f"sqlite:///{path}")
    _make_engine_thread_safe(engine)
    database = MetaData()
    database.bind = engine
    # Load the definitions of all tables that already exist in the file.
    database.reflect()

    return database


def _make_engine_thread_safe(engine):
    """Attach listeners to ``engine`` that take over transaction handling.

    The recipe comes from the sqlalchemy documentation:
    https://tinyurl.com/u9xea5z

    The goal is to emit the BEGIN statement of a connection as late as possible
    so that the database stays locked for as short a time as possible.

    Args:
        engine: the sqlalchemy engine the listeners are registered on. It is
            modified in place; nothing is returned.

    """

    def _disable_implicit_begin(dbapi_connection, connection_record):
        # Stop pysqlite from emitting BEGIN on its own; this also keeps it from
        # emitting COMMIT before absolutely necessary.
        dbapi_connection.isolation_level = None

    def _emit_deferred_begin(conn):
        # Issue the BEGIN statement ourselves.
        conn.execute("BEGIN DEFERRED")

    event.listen(engine, "connect", _disable_implicit_begin)
    event.listen(engine, "begin", _emit_deferred_begin)


def prepare_database(
    path,
    params,
    comparison_plot_data=None,
    dash_options=None,
    constraints=None,
    optimization_status="scheduled",
    gradient_status=0,
):
    """Return database metadata object with all relevant tables for the optimization.

    This should always be used to create entirely new databases or to create the
    tables needed during optimization in an existing database.

    A new database is created if path does not exist yet. Otherwise the existing
    database is loaded and all tables needed to log the optimization are
    overwritten. Other tables remain unchanged.

    The resulting database has the following tables:

    - params_history: the complete history of parameters from the optimization. The
      index column is "iteration", the remaining columns are parameter names taken
      from params["name"].
    - gradient_history: the complete history of gradient evaluations from the
      optimization. Same columns as params_history.
    - criterion_history: the complete history of criterion values from the
      optimization. The index column is "iteration", the second column is "value".
    - timestamps: timestamps from the end of each criterion evaluation. Same
      column names as criterion_history.
    - convergence_history: the complete history of convergence criteria from the
      optimization. The index column is "iteration", the other columns are "ftol",
      "gtol" and "xtol".
    - start_params: copy of user provided ``params``. This is not just the first
      entry of params_history because it contains all columns and has a different
      index.
    - comparison_plot: table with the index column "iteration" and a pickled
      "value" column. Initialized to ``comparison_plot_data``.
    - optimization_status: table with one row and one column called "value" which
      takes the values "scheduled", "running", "success" or "failure". Initialized
      to ``optimization_status``.
    - gradient_status: table with one row and one column called "value" which
      can be any float between 0 and 1 and indicates the progress of the gradient
      calculation. Initialized to ``gradient_status``.
    - dash_options: table with one row and one column called "value". It contains
      a dictionary with the dashboard options. Internally this is a PickleType, so
      the dictionary must be pickle serializable. Initialized to ``dash_options``.
    - exceptions: table with one column called "value" with exceptions.
    - constraints: table with one row and one column called "value". It contains
      the list of constraints. Internally this is a PickleType, so the list must be
      pickle serializable.

    Args:
        path (str or pathlib.Path): location of the database file. If the file does
            not exist, it will be created.
        params (pd.DataFrame): see :ref:`params`.
        comparison_plot_data (numpy.ndarray or pandas.Series or pandas.DataFrame):
            Contains the data for the comparison plot. Later updates will only
            deliver the value column whereas this input has an index and other
            invariant information. If None, a DataFrame with a single NaN in the
            column "value" is stored instead.
        dash_options (dict): Dictionary with the dashboard options.
        constraints (list): List of constraints.
        optimization_status (str): One of "scheduled", "running", "success",
            "failure".
        gradient_status (float): Progress of gradient calculation between 0 and 1.

    Returns:
        database (sqlalchemy.MetaData). The engine that connects to the database can
            be accessed via ``database.bind``.

    """
    gradient_status = float(gradient_status)
    database = load_database(path)

    opt_tables = [
        "params_history",
        "gradient_history",
        "criterion_history",
        "timestamps",
        "convergence_history",
        "start_params",
        "comparison_plot",
        "optimization_status",
        "gradient_status",
        "dash_options",
        "exceptions",
        "constraints",
    ]

    for table in opt_tables:
        if table in database.tables:
            existing = database.tables[table]
            existing.drop(database.bind)
            # Also forget the reflected definition. The tables below are defined
            # with ``extend_existing=True``; if the old definition stayed in the
            # metadata, columns of the dropped table (e.g. parameters of an
            # earlier run) would be merged into the newly created table.
            database.remove(existing)

    _define_table_formatted_with_params(database, params, "params_history")
    _define_table_formatted_with_params(database, params, "gradient_history")
    _define_fitness_history_table(database, "criterion_history")
    _define_time_stamps_table(database)
    _define_convergence_history_table(database)
    _define_start_params_table(database)
    _define_one_column_pickle_table(database, "comparison_plot")
    _define_optimization_status_table(database)
    _define_gradient_status_table(database)
    _define_scalar_pickle_table(database, "dash_options")
    _define_string_table(database, "exceptions")
    _define_scalar_pickle_table(database, "constraints")

    engine = database.bind
    database.create_all(engine)

    # Fill the tables that hold exactly one initial entry.
    append_rows(database, "start_params", {"value": params})
    append_rows(database, "optimization_status", {"value": optimization_status})
    append_rows(database, "gradient_status", {"value": gradient_status})
    append_rows(database, "dash_options", {"value": dash_options})
    append_rows(database, "constraints", {"value": constraints})

    if comparison_plot_data is None:
        comparison_plot_data = pd.DataFrame({"value": [np.nan]})
    append_rows(database, "comparison_plot", {"value": comparison_plot_data})

    return database
def _define_table_formatted_with_params(database, params, table_name):
    """Define a table with one Float column per parameter.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.
        params (pd.DataFrame): the entries of ``params["name"]`` become the column
            names.
        table_name (str): name of the table.

    Returns:
        sqlalchemy.Table: table with the autoincrementing primary key "iteration"
            followed by the parameter columns.

    """
    cols = [Column(name, Float) for name in params["name"]]
    return Table(
        table_name,
        database,
        Column("iteration", Integer, primary_key=True),
        *cols,
        sqlite_autoincrement=True,
        extend_existing=True,
    )


def _define_fitness_history_table(database, table_name):
    """Define a table with the primary key "iteration" and a Float "value" column.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.
        table_name (str): name of the table.

    Returns:
        sqlalchemy.Table

    """
    return Table(
        table_name,
        database,
        Column("iteration", Integer, primary_key=True),
        Column("value", Float),
        sqlite_autoincrement=True,
        extend_existing=True,
    )


def _define_time_stamps_table(database):
    """Define the "timestamps" table.

    The table has the primary key "iteration" and a DATETIME column "value".

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.

    Returns:
        sqlalchemy.Table

    """
    return Table(
        "timestamps",
        database,
        Column("iteration", Integer, primary_key=True),
        Column("value", DATETIME),
        sqlite_autoincrement=True,
        extend_existing=True,
    )


def _define_convergence_history_table(database):
    """Define the "convergence_history" table.

    The table has the primary key "iteration" and the Float columns "ftol",
    "gtol" and "xtol".

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.

    Returns:
        sqlalchemy.Table

    """
    names = ["ftol", "gtol", "xtol"]
    cols = [Column(name, Float) for name in names]
    return Table(
        "convergence_history",
        database,
        Column("iteration", Integer, primary_key=True),
        *cols,
        sqlite_autoincrement=True,
        extend_existing=True,
    )


def _define_start_params_table(database):
    """Define the "start_params" table with a single pickled "value" column.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.

    Returns:
        sqlalchemy.Table

    """
    return _define_scalar_pickle_table(database, "start_params")


def _define_one_column_pickle_table(database, table):
    """Define a table with the primary key "iteration" and a pickled "value" column.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.
        table (str): name of the table.

    Returns:
        sqlalchemy.Table

    """
    return Table(
        table,
        database,
        Column("iteration", Integer, primary_key=True),
        Column("value", PickleType),
        sqlite_autoincrement=True,
        extend_existing=True,
    )


def _define_optimization_status_table(database):
    """Define the "optimization_status" table with a single String "value" column.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.

    Returns:
        sqlalchemy.Table

    """
    return _define_string_table(database, "optimization_status")


def _define_gradient_status_table(database):
    """Define the "gradient_status" table with a single Float "value" column.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.

    Returns:
        sqlalchemy.Table

    """
    return Table(
        "gradient_status", database, Column("value", Float), extend_existing=True
    )


def _define_scalar_pickle_table(database, name):
    """Define a table without index whose only column "value" is a PickleType.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.
        name (str): name of the table.

    Returns:
        sqlalchemy.Table

    """
    return Table(name, database, Column("value", PickleType), extend_existing=True)


def _define_string_table(database, name):
    """Define a table without index whose only column "value" is a String.

    Args:
        database (sqlalchemy.MetaData): metadata object the table is attached to.
        name (str): name of the table.

    Returns:
        sqlalchemy.Table

    """
    return Table(name, database, Column("value", String), extend_existing=True)