from functools import partial
import numpy as np
import pandas as pd
import plotly.express as px
from plotly import graph_objects as go
from pybaum import tree_just_flatten
from estimagic.batch_evaluators import process_batch_evaluator
from estimagic.config import DEFAULT_N_CORES, PLOTLY_TEMPLATE
from estimagic.parameters.conversion import get_converter
from estimagic.parameters.tree_registry import get_registry
from estimagic.visualization.plotting_utilities import combine_plots, get_layout_kwargs
def slice_plot(
    func,
    params,
    lower_bounds=None,
    upper_bounds=None,
    func_kwargs=None,
    selector=None,
    n_cores=DEFAULT_N_CORES,
    n_gridpoints=20,
    plots_per_row=2,
    param_names=None,
    share_y=True,
    expand_yrange=0.02,
    share_x=False,
    color="#497ea7",
    template=PLOTLY_TEMPLATE,
    title=None,
    return_dict=False,
    make_subplot_kwargs=None,
    batch_evaluator="joblib",
):
    """Plot the criterion function along each parameter coordinate.

    For every selected parameter, ``func`` is evaluated on an evenly spaced grid
    between that parameter's lower and upper bound while all other parameters are
    held fixed at their values in ``params``. One line plot is generated per
    parameter, with the function value at ``params`` added as a separate marker.
    The individual plots are optionally combined into a figure with subplots.

    Args:
        func (callable): Criterion function that takes params and returns a
            scalar value or dictionary with the entry "value".
        params (pytree): A pytree with parameters.
        lower_bounds (pytree): A pytree with same structure as params. Must be
            finite for all selected parameters. Can be omitted if params is a
            DataFrame containing a "lower_bound" column.
        upper_bounds (pytree): A pytree with same structure as params. Must be
            finite for all selected parameters. Can be omitted if params is a
            DataFrame containing an "upper_bound" column.
        func_kwargs (dict or NoneType): Additional keyword arguments that are
            partialled into ``func`` before it is evaluated.
        selector (callable): Function that takes params and returns a subset
            of params for which we actually want to generate the plot.
        n_cores (int): Number of cores.
        n_gridpoints (int): Number of gridpoints on which the criterion function is
            evaluated. This is the number per plotted line.
        plots_per_row (int): Number of plots per row.
        param_names (dict or NoneType): Dictionary mapping old parameter names
            to new ones.
        share_y (bool): If True, the individual plots share the scale on the
            yaxis and plots in one row actually share the y axis.
        expand_yrange (float): The ratio by which to expand the range of the
            (shared) y axis, such that the axis is not cropped at exactly the
            minimum and maximum function value.
        share_x (bool): If True, set the same range of x axis for all plots and share
            the x axis for all plots in one column.
        color: The line color.
        template (str): The template for the figure. Defaults to
            ``PLOTLY_TEMPLATE`` from ``estimagic.config``.
        title (str): The figure title.
        return_dict (bool): If True, return dictionary with individual plots of each
            parameter, else, combine individual plots into a figure with subplots.
        make_subplot_kwargs (dict or NoneType): Dictionary of keyword arguments used
            to instantiate plotly Figure with multiple subplots. Is used to define
            properties such as, for example, the spacing between subplots (governed by
            'horizontal_spacing' and 'vertical_spacing'). If None, default arguments
            defined in the function are used.
        batch_evaluator (str or callable): See :ref:`batch_evaluators`.

    Returns:
        out (dict or plotly.Figure): Returns either dictionary with individual slice
            plots for each parameter or a plotly Figure combining the individual plots.

    Raises:
        ValueError: If any selected parameter has a non-finite lower or upper bound.

    """
    title_kwargs = {"text": title} if title is not None else None

    if func_kwargs is not None:
        func = partial(func, **func_kwargs)

    # Evaluate once at the start parameters. The result is needed to build the
    # converter and is reused below as the function value at the start point.
    func_eval = func(params)

    converter, internal_params = get_converter(
        params=params,
        constraints=None,
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        func_eval=func_eval,
        primary_key="value",
        scaling=False,
        scaling_options=None,
    )

    n_params = len(internal_params.values)

    selected = np.arange(n_params, dtype=int)
    if selector is not None:
        # Convert the vector of flat positions into a params-shaped pytree so that
        # the user's selector picks out the positions of the parameters to plot.
        helper = converter.params_from_internal(selected)
        registry = get_registry(extended=True)
        selected = np.array(
            tree_just_flatten(selector(helper), registry=registry), dtype=int
        )

    if not np.isfinite(internal_params.lower_bounds[selected]).all():
        raise ValueError("All selected parameters must have finite lower bounds.")

    if not np.isfinite(internal_params.upper_bounds[selected]).all():
        raise ValueError("All selected parameters must have finite upper bounds.")

    evaluation_points, metadata = [], []
    for pos in selected:
        lb = internal_params.lower_bounds[pos]
        ub = internal_params.upper_bounds[pos]
        name = internal_params.names[pos]
        for param_value in np.linspace(lb, ub, n_gridpoints):
            # The start value is skipped here: func_eval is already known for it
            # and is appended after the batch evaluation.
            if param_value == internal_params.values[pos]:
                continue
            x = internal_params.values.copy()
            x[pos] = param_value
            evaluation_points.append(converter.params_from_internal(x))
            metadata.append({"name": name, "Parameter Value": param_value})

    batch_evaluator = process_batch_evaluator(batch_evaluator)

    func_values = batch_evaluator(
        func=func,
        arguments=evaluation_points,
        error_handling="continue",
        n_cores=n_cores,
    )

    # A string result marks a failed evaluation; it enters the plot data as NaN.
    func_values = [
        converter.func_to_internal(val) if not isinstance(val, str) else np.nan
        for val in func_values
    ]

    # Loop invariant: the function value at the start parameters is shared by
    # every slice, so convert it only once.
    start_value = converter.func_to_internal(func_eval)

    # Add the start point to the data of every plotted parameter.
    func_values += [start_value] * len(selected)
    for pos in selected:
        metadata.append(
            {
                "name": internal_params.names[pos],
                "Parameter Value": internal_params.values[pos],
            }
        )

    plot_data = pd.DataFrame(metadata)
    plot_data["Function Value"] = func_values

    if param_names is not None:
        plot_data["name"] = plot_data["name"].replace(param_names)

    # Common y range over all slices, padded so lines do not touch the frame.
    y_min = plot_data["Function Value"].min()
    y_max = plot_data["Function Value"].max()
    y_range = y_max - y_min
    yaxis_ub = y_max + y_range * expand_yrange
    yaxis_lb = y_min - y_range * expand_yrange

    layout_kwargs = get_layout_kwargs(
        None,
        None,
        title_kwargs,
        template,
        False,
    )

    plots_dict = {}
    for pos in selected:
        par_name = internal_params.names[pos]
        if param_names is not None and par_name in param_names:
            par_name = param_names[par_name]

        df = plot_data[plot_data["name"] == par_name].sort_values("Parameter Value")
        subfig = px.line(
            df,
            y="Function Value",
            x="Parameter Value",
            color_discrete_sequence=[color],
        )
        # Extra single-point trace at the start parameter value.
        subfig.add_trace(
            go.Scatter(
                x=[internal_params.values[pos]],
                y=[start_value],
                marker={"color": color},
            )
        )
        subfig.update_layout(**layout_kwargs)
        subfig.update_xaxes(title={"text": par_name})
        subfig.update_yaxes(title={"text": "Function Value"})
        if share_y:
            subfig.update_yaxes(range=[yaxis_lb, yaxis_ub])
        plots_dict[par_name] = subfig

    if return_dict:
        out = plots_dict
    else:
        out = combine_plots(
            plots=list(plots_dict.values()),
            plots_per_row=plots_per_row,
            sharex=share_x,
            sharey=share_y,
            share_yrange_all=share_y,
            share_xrange_all=share_x,
            expand_yrange=expand_yrange,
            make_subplot_kwargs=make_subplot_kwargs,
            showlegend=False,
            template=template,
            clean_legend=True,
            layout_kwargs=layout_kwargs,
            legend_kwargs={},
            title_kwargs=title_kwargs,
        )

    return out