Source code for estimagic.logging.read_log

"""Functions to read data from the database used for logging.

The functions in this module are meant for end users of estimagic.
They do not require any knowledge of databases.

When using them internally (e.g. in the dashboard), pass an already loaded database
(an sqlalchemy.MetaData object) as path_or_database. Otherwise, the database is
loaded from the given path on every call, which may be very slow.

"""
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
from estimagic.logging.database_utilities import load_database
from estimagic.logging.database_utilities import read_last_rows
from estimagic.logging.database_utilities import read_new_rows
from estimagic.logging.database_utilities import read_specific_row
from estimagic.parameters.tree_registry import get_registry
from pybaum import tree_flatten
from pybaum import tree_unflatten
from sqlalchemy import MetaData


def read_start_params(path_or_database):
    """Return the start parameters that were logged for an optimization.

    The start parameters are taken from the ``params`` entry of the most recent
    row of the ``optimization_problem`` table.

    Args:
        path_or_database (pathlib.Path, str or sqlalchemy.MetaData): Path to the
            log database or an already loaded database.

    Returns:
        params: The start parameters as stored in the log, see :ref:`params`.

    """
    db = _load_database(path_or_database)
    last_problem = read_last_rows(
        database=db,
        table_name="optimization_problem",
        n_rows=1,
        return_type="dict_of_lists",
    )
    params_column = last_problem["params"]
    return params_column[0]


def _load_database(path_or_database):
    """Get an sqlalchemy.MetaData object from a path or an existing database.

    Args:
        path_or_database (pathlib.Path, str or sqlalchemy.MetaData): Either an
            already loaded database, which is returned unchanged, or the path to
            a database file on disk.

    Returns:
        The loaded database: ``path_or_database`` itself if it is a MetaData
        object, otherwise the result of ``load_database(path=...)``.

    Raises:
        FileNotFoundError: If a path is given but no file exists there.
        ValueError: If ``path_or_database`` is neither a path nor a MetaData
            object.

    """
    if isinstance(path_or_database, MetaData):
        return path_or_database

    if isinstance(path_or_database, (Path, str)):
        path = Path(path_or_database)
        # Fail early with a clear message instead of handing a non-existent
        # path to load_database.
        if not path.exists():
            raise FileNotFoundError(f"No such database file: {path}")
        return load_database(path=path)

    raise ValueError(
        "path_or_database must be a path or sqlalchemy.MetaData object"
    )


def read_steps_table(path_or_database):
    """Return the ``steps`` table of the log database as a DataFrame.

    Args:
        path_or_database (pathlib.Path, str or sqlalchemy.MetaData): Path to the
            log database or an already loaded database.

    Returns:
        pandas.DataFrame: One row per entry of the ``steps`` table (all rows
            retrieved with ``last_retrieved=0``).

    """
    db = _load_database(path_or_database)
    step_rows = read_new_rows(
        database=db,
        table_name="steps",
        last_retrieved=0,
        return_type="list_of_dicts",
    )[0]
    return pd.DataFrame(step_rows)


def read_optimization_problem_table(path_or_database):
    """Load the ``optimization_problem`` table as a DataFrame.

    Args:
        path_or_database (pathlib.Path, str or sqlalchemy.MetaData): Path to the
            log database or an already loaded database.

    Returns:
        pandas.DataFrame: One row per entry of the ``optimization_problem``
            table (all rows retrieved with ``last_retrieved=0``).

    """
    database = _load_database(path_or_database)
    problem_rows, _ = read_new_rows(
        database=database,
        table_name="optimization_problem",
        last_retrieved=0,
        return_type="list_of_dicts",
    )
    problem_df = pd.DataFrame(problem_rows)

    return problem_df


@dataclass
class OptimizeLogReader:
    """Read information about an optimization from a sqlite database.

    On construction, the database is loaded and the start parameters are read.
    Their tree structure is then used to turn the flat parameter vectors stored
    in the log back into the original params format.

    Attributes:
        path (str or pathlib.Path): Location of the log database. Anything
            accepted by ``_load_database`` works here.

    """

    path: Union[str, Path]

    def __post_init__(self):
        # These are plain instance attributes, not dataclass fields. They are
        # deliberately not annotated at class level.
        self._database = _load_database(self.path)
        self._start_params = read_start_params(self._database)
        self._registry = get_registry(extended=True)
        # Only the tree definition is needed; the flat leaves are discarded.
        self._treedef = tree_flatten(self._start_params, registry=self._registry)[1]

    def read_iteration(self, iteration):
        """Return the logged information for one iteration.

        Args:
            iteration (int): Non-negative values count from the first logged
                iteration, negative values count back from the last one.

        Returns:
            dict: The logged row, with ``"params"`` in the original params format.

        """
        return _read_optimization_iteration(
            database=self._database,
            iteration=iteration,
            params_treedef=self._treedef,
            registry=self._registry,
        )

    def read_history(self):
        """Return parameters, criterion values and runtimes of all iterations."""
        return _read_optimization_history(
            database=self._database,
            params_treedef=self._treedef,
            registry=self._registry,
        )

    def read_start_params(self):
        """Return the start parameters read at construction time."""
        return self._start_params
def _read_optimization_iteration(database, iteration, params_treedef, registry):
    """Get information about one optimization iteration.

    Args:
        database (sqlalchemy.MetaData): The loaded log database.
        iteration (int): Index of the iteration. Non-negative values count from
            the first logged iteration (0 is the first). Negative values count
            back from the last logged iteration (-1 is the last), as for Python
            sequences.
        params_treedef: Tree definition used to rebuild the params pytree from
            the flat parameters stored in the log.
        registry (dict): pybaum registry matching ``params_treedef``.

    Returns:
        dict: The row of the ``optimization_iterations`` table, with
            ``"params"`` replaced by the unflattened params.

    Raises:
        IndexError: If no logged row corresponds to ``iteration``.

    """
    if iteration >= 0:
        # Iteration 0 is mapped to rowid 1. This assumes rowids are contiguous.
        rowid = iteration + 1
    else:
        last_iteration = read_last_rows(
            database=database,
            table_name="optimization_iterations",
            n_rows=1,
            return_type="list_of_dicts",
        )
        if not last_iteration:
            # Empty table: there is no last iteration to count back from.
            raise IndexError(f"Invalid iteration requested: {iteration}")
        highest_rowid = last_iteration[0]["rowid"]
        # iteration is negative here, so this counts back from the last row.
        rowid = highest_rowid + iteration + 1

    data = read_specific_row(
        database,
        table_name="optimization_iterations",
        rowid=rowid,
        return_type="list_of_dicts",
    )
    if not data:
        raise IndexError(f"Invalid iteration requested: {iteration}")

    row = data[0]
    row["params"] = tree_unflatten(params_treedef, row["params"], registry=registry)
    return row


def _read_optimization_history(database, params_treedef, registry):
    """Read the history of parameters, criterion values and runtimes.

    Only iterations with a logged criterion value (``value`` is not None) are
    included.

    Args:
        database (sqlalchemy.MetaData): The loaded log database.
        params_treedef: Tree definition used to rebuild the params pytree from
            the flat parameters stored in the log.
        registry (dict): pybaum registry matching ``params_treedef``.

    Returns:
        dict: With keys ``"params"`` (list of params pytrees), ``"criterion"``
            (list of criterion values) and ``"runtime"`` (numpy array of
            timestamps relative to the first included iteration; empty if no
            iteration was included).

    """
    raw_res, _ = read_new_rows(
        database=database,
        table_name="optimization_iterations",
        last_retrieved=0,
        return_type="list_of_dicts",
    )

    history = {"params": [], "criterion": [], "runtime": []}
    for data in raw_res:
        if data["value"] is None:
            continue
        params = tree_unflatten(params_treedef, data["params"], registry=registry)
        history["params"].append(params)
        history["criterion"].append(data["value"])
        history["runtime"].append(data["timestamp"])

    times = np.array(history["runtime"])
    # Guard against an empty history (e.g. no criterion value logged yet),
    # where indexing times[0] would raise.
    if times.size > 0:
        times -= times[0]
    history["runtime"] = times
    return history