Source code for estimagic.visualization.history_plots

from pathlib import Path

import numpy as np
import plotly.graph_objects as go
from estimagic.config import PLOTLY_TEMPLATE
from estimagic.logging.read_log import OptimizeLogReader
from estimagic.logging.read_log import read_optimization_problem_table
from estimagic.optimization.history_tools import get_history_arrays
from estimagic.optimization.optimize_result import OptimizeResult
from estimagic.parameters.tree_registry import get_registry
from pybaum import leaf_names
from pybaum import tree_flatten
from pybaum import tree_just_flatten
from pybaum import tree_unflatten


def criterion_plot(
    res,
    max_evaluations=None,
    template=PLOTLY_TEMPLATE,
    highlight_color="#497ea7",
    base_color="#bab0ac",
    monotone=False,
):
    """Plot the criterion history of an optimization.

    Args:
        res (OptimizeResult or str or pathlib.Path): An optimization result with
            collected history, or the path to a log file from which the history
            is read. Multistart information is ignored when reading from a log.
        max_evaluations (int): Clip the criterion history after that many entries.
        template (str): A plotly template.
        highlight_color (str): Hex code of the line color.
        base_color (str): Hex code of the line color for local optimizations in a
            multistart optimization.
        monotone (bool): If True, the criterion plot becomes monotone in the sense
            that at each iteration only the current best criterion value is
            displayed.

    Returns:
        plotly.graph_objects.Figure: The figure with one line for the main history
            and, for multistart results, one line per local optimization.

    Raises:
        ValueError: If ``res`` is an OptimizeResult without history, or if ``res``
            is neither an OptimizeResult nor a path.

    """
    if isinstance(res, OptimizeResult):
        if res.history is None:
            raise ValueError(
                "Criterion_plot requires a optimize_result with history. "
                "Enable history collection by setting collect_history=True "
                "when calling maximize or minimize."
            )
        is_multistart = res.multistart_info is not None
        main_history = res.history
        direction = res.direction
    elif isinstance(res, (str, Path)):
        reader = OptimizeLogReader(res)
        main_history = reader.read_history()
        _problem_table = read_optimization_problem_table(res)
        # Use the direction of the most recent entry in the problem table.
        direction = _problem_table["direction"].tolist()[-1]
        # For now we ignore multistart when getting history from a log
        is_multistart = False
    else:
        raise ValueError("res must be an OptimizeResult or a path to a log file.")

    key = "monotone_criterion" if monotone else "criterion"

    fig = go.Figure()

    if is_multistart:
        # Local optimizations are drawn first, in the base color and without
        # legend entries, so that the main history is plotted on top of them.
        scatter_kws = {
            "connectgaps": True,
            "showlegend": False,
        }
        line_kws = {
            "color": base_color,
        }

        for i, opt in enumerate(res.multistart_info["local_optima"]):
            history = get_history_arrays(opt.history, direction)[key]

            if max_evaluations is not None and len(history) > max_evaluations:
                history = history[:max_evaluations]

            trace = go.Scatter(
                x=np.arange(len(history)),
                y=history,
                mode="lines",
                name=str(i),
                line=line_kws,
                **scatter_kws,
            )
            fig.add_trace(trace)

    history = get_history_arrays(main_history, direction)[key]

    # Compare against the length of the extracted criterion array. The raw
    # history container is indexed by string keys elsewhere in this module
    # (e.g. ["params"]), so its own length is presumably unrelated to the number
    # of evaluations and must not decide whether clipping is applied.
    if max_evaluations is not None and len(history) > max_evaluations:
        history = history[:max_evaluations]

    scatter_kws = {
        "connectgaps": True,
        # A legend is only informative when there are several lines.
        "showlegend": is_multistart,
    }
    line_kws = {
        "color": highlight_color,
    }

    trace = go.Scatter(
        x=np.arange(len(history)),
        y=history,
        mode="lines",
        name="best result",
        line=line_kws,
        **scatter_kws,
    )
    fig.add_trace(trace)

    fig.update_layout(
        template=template,
        xaxis_title_text="No. of criterion evaluations",
        yaxis_title_text="Criterion value",
    )
    return fig
def params_plot(
    res,
    selector=None,
    max_evaluations=None,
    template=PLOTLY_TEMPLATE,
):
    """Plot the params history of an optimization.

    Args:
        res (OptimizeResult or str or pathlib.Path): An optimization result with
            collected history, or the path to a log file from which the history
            and start params are read.
        selector (callable): A callable that takes params and returns a subset
            of params. If provided, only the selected subset of params is plotted.
        max_evaluations (int): Clip the params history after that many entries.
        template (str): A plotly template.

    Returns:
        plotly.graph_objects.Figure: The figure with one line per (selected)
            parameter.

    Raises:
        ValueError: If ``res`` is an OptimizeResult without history, or if ``res``
            is neither an OptimizeResult nor a path.

    """
    if isinstance(res, OptimizeResult):
        if res.history is None:
            raise ValueError(
                "params_plot requires a optimize_result with history. "
                "Enable history collection by setting collect_history=True "
                "when calling maximize or minimize."
            )
        history = res.history["params"]
        start_params = res.start_params
    elif isinstance(res, (str, Path)):
        reader = OptimizeLogReader(res)
        start_params = reader.read_start_params()
        history = reader.read_history()["params"]
    else:
        raise ValueError("res must be an OptimizeResult or a path to a log file.")

    fig = go.Figure()

    registry = get_registry(extended=True)

    # One row per flat parameter, one column per entry of the history.
    hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T
    names = leaf_names(start_params, registry=registry)

    if selector is not None:
        # Build a tree with the structure of start_params whose leaves are their
        # own flat positions. Applying the selector to it yields the positions
        # of the selected parameters.
        flat, treedef = tree_flatten(start_params, registry=registry)
        helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry)
        selected = np.array(tree_just_flatten(selector(helper), registry=registry))
        if selected.size == 0:
            # An empty array defaults to a float dtype, which numpy rejects as an
            # index. Cast it so that an empty selection yields an empty plot.
            selected = selected.astype(int)
        names = [names[i] for i in selected]
        hist_arr = hist_arr[selected]

    for name, data in zip(names, hist_arr):
        if max_evaluations is not None and len(data) > max_evaluations:
            data = data[:max_evaluations]

        trace = go.Scatter(
            x=np.arange(len(data)),
            y=data,
            mode="lines",
            name=name,
        )
        fig.add_trace(trace)

    fig.update_layout(
        template=template,
        xaxis_title_text="No. of criterion evaluations",
        yaxis_title_text="Parameter value",
    )

    return fig