EEP-01: Pytrees#


Janos Gabler




Standards Track





This EEP explains how we will use pytrees to allow for more flexible specification of parameters for optimization or differentiation, more convenient ways of writing moment functions for MSM estimation, and more. The actual code to work with pytrees is implemented in Pybaum, developed by @janosg and @tobiasraabe.

Backwards compatibility#

All changes are fully backwards compatible.


Estimagic has many functions that require user written functions as inputs. Examples are:

  • criterion functions and their derivatives for optimization

  • functions of which numerical derivatives are taken

  • functions that calculate simulated moments

  • functions that calculate bootstrap statistics

In all cases, there are some restrictions on possible inputs and outputs of the user written functions. For example, parameters for numerical optimization need to be provided as a pandas.DataFrame with a "value" column. Simulated moments and bootstrap statistics need to be returned as a pandas.Series, etc.

Pytrees allow us to relax many of those restrictions on the interfaces of user provided functions. This is not only more convenient for users, but can also reduce overhead because users can choose the data structures best suited to their problem.

Background: What is a pytree#

Pytree is a term used in TensorFlow and JAX to refer to a tree-like structure built out of container-like Python objects with arbitrary levels of nesting.

What is a container can be re-defined for each application. By default, lists, tuples and dicts are considered containers and everything else is a leaf. Then the following are examples of pytrees:

[1, "a", np.arange(3)]  # 3 leaves

[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves

np.arange(5)  # 1 leaf

What makes pytrees so powerful are the operations defined for them. The most important ones are:

  • tree_flatten: Convert any pytree into a flat list of leaves + metadata

  • tree_unflatten: The inverse of tree_flatten

  • tree_map: Apply a function to all leaves in a pytree

  • leaf_names: Generate a list of names for all leaves in a pytree

The above examples of pytrees would look as follows when flattened (with a default definition of containers):

[1, "a", np.arange(3)]

[1, 2, 3, 4, 5]

[np.arange(5)]


By adding numpy arrays to the registry of container-like objects, each of the three examples above would have five leaves. The flattened versions would look as follows:

[1, "a", 0, 1, 2]

[1, 2, 3, 4, 5]

[0, 1, 2, 3, 4]

Needless to say, it is possible to register anything as a container. For example, we would add pandas.Series and pandas.DataFrame (with varying definitions, depending on the application).
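The three core operations can be sketched in a few lines of plain Python. This is only an illustration of the idea with the default containers (lists, tuples and dicts); it is not the pybaum or JAX implementation, and the treedef format used here is an assumption made up for this sketch.

```python
def tree_flatten(tree):
    """Return (leaves, treedef) for nested lists, tuples and dicts."""
    if isinstance(tree, dict):
        keys = list(tree)  # insertion order is kept here (a simplification)
        children = [tree[key] for key in keys]
        kind = ("dict", keys)
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
        kind = (type(tree).__name__, None)
    else:
        return [tree], None  # everything else is a leaf

    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append((child_def, len(child_leaves)))
    return leaves, (kind, child_defs)


def tree_unflatten(treedef, leaves):
    """Inverse of tree_flatten."""
    if treedef is None:
        return leaves[0]
    (name, keys), child_defs = treedef
    children, pos = [], 0
    for child_def, n_leaves in child_defs:
        children.append(tree_unflatten(child_def, leaves[pos : pos + n_leaves]))
        pos += n_leaves
    if name == "dict":
        return dict(zip(keys, children))
    return tuple(children) if name == "tuple" else children


def tree_map(func, tree):
    """Apply func to every leaf and keep the tree structure."""
    leaves, treedef = tree_flatten(tree)
    return tree_unflatten(treedef, [func(leaf) for leaf in leaves])


tree = [1, {"k1": 2, "k2": (3, 4)}, 5]
leaves, treedef = tree_flatten(tree)
rebuilt = tree_unflatten(treedef, leaves)
scaled = tree_map(lambda leaf: leaf * 10, tree)
```

Here `leaves` is the flat list from the second example above, `rebuilt` equals the original tree, and `scaled` has the same nesting with every leaf multiplied by ten.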

Difference between pytrees in JAX and estimagic#

Most JAX functions only work with pytrees of arrays and scalars, i.e. pytrees where container types are dicts, lists and tuples and all leaves are arrays or scalars. We will just call them pytrees of arrays because scalars are converted to arrays by JAX.

There are two ways to look at such pytrees:

  1. As pytree of arrays -> tree_flatten produces a list of arrays

  2. As pytree of numbers -> tree_flatten produces a list of numbers

The only difference between the two perspectives is that for the second one, arrays have been registered as container types that can be flattened. In JAX the term ravel instead of flatten is sometimes used to make clear that the second perspective is meant.

Estimagic functions work with slightly more general pytrees. On top of arrays, they can also contain scalars, pandas.Series and pandas.DataFrames.

Again, there are two possible ways to look at such pytrees:

  1. As pytree of arrays, numbers, Series and DataFrames -> tree_flatten produces a list of arrays, numbers, Series and DataFrames.

  2. As pytree of numbers -> tree_flatten produces a list of numbers

Again, the difference between the two is which objects are registered as container types and how the rules to flatten and unflatten them are defined.
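The two perspectives differ only in the registry that is passed to the flattening function. The following sketch makes this concrete. It is an assumption-laden illustration: `array.array` from the standard library stands in for numpy arrays, and the registry is a plain dict from types to "get children" functions, which is not the actual pybaum registry format.

```python
from array import array

# Hypothetical registry format: type -> function returning the children.
DEFAULT_REGISTRY = {
    list: lambda node: list(node),
    tuple: lambda node: list(node),
    dict: lambda node: list(node.values()),
}


def flatten(tree, registry):
    """Flatten tree; only types found in registry count as containers."""
    get_children = registry.get(type(tree))
    if get_children is None:
        return [tree]
    return [leaf for child in get_children(tree) for leaf in flatten(child, registry)]


tree = {"a": array("d", [1.0, 2.0]), "b": [array("d", [3.0]), 4.0]}

# Perspective 1: arrays are leaves.
as_arrays = flatten(tree, DEFAULT_REGISTRY)

# Perspective 2: arrays are registered as containers, so leaves are numbers.
extended_registry = {**DEFAULT_REGISTRY, array: lambda node: list(node)}
as_numbers = flatten(tree, extended_registry)
```

With the default registry the result contains two arrays and one float; with the extended registry it is the list of four numbers.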

While numpy arrays, scalars and pandas.Series have only one natural way of defining the flattening rules, this becomes more complex for DataFrames due to the way params DataFrames were used in estimagic before.

We define the following rules: If a DataFrame contains a column called "value", we interpret it as a classical estimagic params DataFrame and only consider the entries in the "value" column when flattening the DataFrame into a list of numbers. If there is no "value" column, all numeric columns of the DataFrame are considered.

Note that internally, we will sometimes define flattening rules such that only some other column, e.g. only "lower_bound", is considered. However, we never look at more than one column of a classical estimagic params DataFrame at a time.

To distinguish between the different pytrees we use the terms JAX-pytree and estimagic-pytree.

Optimization with pytrees#

In this section we look at possible ways to specify optimizations when parameters and some outputs of criterion functions can be estimagic-pytrees.

As an example we use a hypothetical criterion function with pytree inputs and outputs to describe how a user can optimize it. We also give a rough intuition of what happens behind the scenes.

The criterion function#

Consider a criterion function that takes parameters in the following format:

import numpy as np
import pandas as pd

params = {
    "delta": 0.95,
    "utility": pd.DataFrame(
        [[0.5, 0]] * 3, index=["a", "b", "c"], columns=["value", "lower_bound"]
    ),
    "probs": np.array([[0.8, 0.2], [0.3, 0.7]]),
}

The criterion function returns a dictionary of the form:

{
    "value": 1.1,
    "contributions": {"a": np.array([0.36, 0.25]), "b": 0.49},
    "root_contributions": {"a": np.array([0.6, 0.5]), "b": 0.7},
}

Run an optimization#

from estimagic import minimize

minimize(criterion=crit, params=params, algorithm="scipy_lbfgsb")


The internal optimizer (in this case the lbfgsb algorithm from scipy) will see a wrapped version of crit. That version takes a 1d numpy array as its only argument and returns a scalar float (the "value" entry of the result of crit). Numerical derivatives are also taken on that function.

If instead a derivative based least squares optimizer like "scipy_ls_dogbox" had been used, the internal optimizer would see a modified version of crit that takes a 1d numpy array and returns a 1d numpy array (the flattened version of the "root_contributions" entry of the result of crit).
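The wrapping described above can be sketched as follows. This is a simplified illustration and not estimagic's internal code: nested lists stand in for numpy arrays, the DataFrame entry is left out, and `crit`, `make_scalar_criterion` and `make_least_squares_criterion` are hypothetical names invented for this example.

```python
def flatten(tree):
    """Flatten nested dicts, lists and tuples into a flat list of leaves."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in flatten(value)]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(template, leaves):
    """Rebuild a tree with the structure of template from a flat sequence."""
    leaves = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {key: build(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)

    return build(template)


def crit(params):
    """Hypothetical criterion with pytree input and dictionary output."""
    residuals = [params["delta"] - 1.0]
    residuals += [p - 0.5 for row in params["probs"] for p in row]
    return {
        "value": sum(r**2 for r in residuals),
        "root_contributions": {"delta": residuals[0], "probs": residuals[1:]},
    }


def make_scalar_criterion(crit, params):
    """What a scalar optimizer would see: flat vector in, float out."""
    return lambda x: crit(unflatten(params, x))["value"]


def make_least_squares_criterion(crit, params):
    """What a least squares optimizer would see: flat vector in and out."""
    return lambda x: flatten(crit(unflatten(params, x))["root_contributions"])


params = {"delta": 0.95, "probs": [[0.8, 0.2], [0.3, 0.7]]}
x0 = flatten(params)
scalar_crit = make_scalar_criterion(crit, params)
ls_crit = make_least_squares_criterion(crit, params)
```

Both wrappers take the same flat start vector `x0`; the sum of squares of `ls_crit(x0)` equals `scalar_crit(x0)`.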

The optimization output#

The following entries of the output of minimize are affected by the change:

  • "solution_params": A pytree with the same structure as params

  • "solution_criterion": The output dictionary of crit evaluated at the solution params

  • "solution_derivative": Maybe we should not even have this entry.


We need to discuss if and in which form we want to have a solution derivative entry. In its current form it is useless if constraints are used. This gets worse when we allow for pytrees, and translating it into a meaningful shape might be very difficult.

Add bounds#

Bounds on parameters that are inside a DataFrame with "value" column can simply be specified as before. For all others, there are separate lower_bounds and upper_bounds arguments in maximize and minimize.

lower_bounds and upper_bounds are pytrees of the same structure as params or a subtree that preserves enough structure to match all bounds. For example:

minimize(
    criterion=crit,
    params=params,
    algorithm="scipy_lbfgsb",
    lower_bounds={"delta": 0},
    upper_bounds={"delta": 1},
)

This would add bounds for delta, keep the bounds on all "utility" parameters, and leave the "probs" parameters unbounded.
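One way to match a partial bounds tree against params is to walk both trees together and fall back to an infinite bound wherever the bounds tree has no entry. The sketch below shows this under simplifying assumptions: nested lists stand in for arrays, the "utility" DataFrame (whose bounds come from its own columns) is left out, and `flat_bounds` is a hypothetical helper, not an estimagic function.

```python
import math


def flat_bounds(params, bounds, default):
    """Flat list of bounds aligned with the flattened params."""
    if isinstance(params, dict):
        bounds = {} if bounds is None else bounds
        out = []
        for key, value in params.items():
            out += flat_bounds(value, bounds.get(key), default)
        return out
    if isinstance(params, (list, tuple)):
        bounds = [None] * len(params) if bounds is None else bounds
        out = []
        for value, bound in zip(params, bounds):
            out += flat_bounds(value, bound, default)
        return out
    # params is a leaf: use the given bound or the default
    return [default if bounds is None else bounds]


params = {"delta": 0.95, "probs": [[0.8, 0.2], [0.3, 0.7]]}
lower = flat_bounds(params, {"delta": 0}, -math.inf)
upper = flat_bounds(params, {"delta": 1}, math.inf)
```

Only the first flat parameter (delta) receives finite bounds; all entries of "probs" stay unbounded.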

Add a constraint#

Currently, parameters to which a constraint is applied are selected via a "loc" or "query" entry in the constraints dictionary.

This keeps working as long as params are specified as a single DataFrame containing a "value" column. If a more general pytree is used, we need a "selector" entry instead. The value of that entry is a callable that takes the pytree and returns the selected parameters.

The selector function may return the parameters in the form of an estimagic-pytree. Should order play a role for the constraints (e.g., increasing) the constraint will be applied to the flattened version of the pytree returned by the selector function. However, in the case that order matters, we advise users to return one-dimensional arrays (explicit is better than implicit).

As an example, let’s add probability constraints for each row of "probs":

constraints = [
    {"selector": lambda params: params["probs"][0], "type": "probability"},
    {"selector": lambda params: params["probs"][1], "type": "probability"},
]


The required changes to support this are relatively simple. This is because most functions that deal with constraints already work with a 1d array of parameters and the "loc" and "query" entries of constraints are internally translated to positions in that array very early on.
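A selector can be translated into positions in the flat parameter vector by applying it to a copy of params whose leaves are replaced by their own flat positions. The sketch below illustrates this idea; it is an assumption about one possible approach, not a description of estimagic's implementation, and it uses nested lists in place of numpy arrays.

```python
def flatten(tree):
    """Flatten nested dicts, lists and tuples into a flat list of leaves."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in flatten(value)]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(template, leaves):
    """Rebuild a tree with the structure of template from a flat sequence."""
    leaves = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {key: build(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)

    return build(template)


def selector_to_positions(selector, params):
    """Positions in the flat parameter vector picked by the selector."""
    n_leaves = len(flatten(params))
    index_tree = unflatten(params, range(n_leaves))  # leaves are 0, 1, 2, ...
    return flatten(selector(index_tree))


params = {"delta": 0.95, "probs": [[0.8, 0.2], [0.3, 0.7]]}
constraints = [
    {"selector": lambda params: params["probs"][0], "type": "probability"},
    {"selector": lambda params: params["probs"][1], "type": "probability"},
]
positions = [selector_to_positions(c["selector"], params) for c in constraints]
```

After this translation, each constraint refers to positions in the 1d parameter array, just like constraints that were specified via "loc" or "query".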

Derivatives during optimization#

If numerical derivatives are used, they are already taken on a modified function that maps from 1d numpy array to scalars or 1d numpy arrays. Allowing for estimagic-pytrees in parameters and criterion outputs will not pose any difficulties here.
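To illustrate why pytrees pose no difficulty here, the sketch below takes a forward-difference Jacobian of a flat wrapper around a function with pytree input and output. The plain forward difference is a stand-in chosen for brevity, not estimagic's numerical derivative routine, and `tree_func`, `flat_version` and `forward_difference_jacobian` are hypothetical names.

```python
def flatten(tree):
    """Flatten nested dicts, lists and tuples into a flat list of leaves."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in flatten(value)]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(template, leaves):
    """Rebuild a tree with the structure of template from a flat sequence."""
    leaves = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {key: build(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)

    return build(template)


def tree_func(x):
    """Hypothetical function with pytree input and pytree output."""
    return {
        "c": [v**2 for v in x["a"]],
        "d": [v**2 for row in x["b"] for v in row],
    }


def flat_version(func, template):
    """Wrap func so that it maps a flat list to a flat list."""
    return lambda flat_x: flatten(func(unflatten(template, flat_x)))


def forward_difference_jacobian(func, x, step=1e-6):
    """Jacobian of func (flat list -> flat list) as a list of rows."""
    f0 = func(x)
    jac = [[0.0] * len(x) for _ in f0]
    for j in range(len(x)):
        shifted = list(x)
        shifted[j] += step
        f1 = func(shifted)
        for i in range(len(f0)):
            jac[i][j] = (f1[i] - f0[i]) / step
    return jac


tree_x = {"a": [1.0, 2.0], "b": [[3.0, 4.0], [5.0, 6.0]]}
flat_x = flatten(tree_x)
jac = forward_difference_jacobian(flat_version(tree_func, tree_x), flat_x)
```

The result is a 6 by 6 matrix that is approximately diagonal with entries 2, 4, ..., 12, independent of how the inputs and outputs were nested.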

Closed form derivatives need to have the following interface: They expect params in the exact same format as the criterion function as first argument. They return a derivative in the same format as our numerical derivative functions or JAX's autodiff functions when applied to the criterion function.

Numerical derivatives with pytrees#

Problem: Higher dimensional extensions of pytrees#

The derivative of a function that maps from a 1d array to a 1d array (usually called Jacobian) is a 2d matrix. If the 1d arrays are replaced by pytrees, we need a two dimensional extension of the pytrees. Below we will look at how JAX does this and why we cannot simply copy that solution.

The JAX solution#

Let’s look at an example. We first define a function in terms of 1d arrays and then in terms of pytrees and look at the Jacobian calculated by JAX in both cases:

import jax
import jax.numpy as jnp


def square(x):
    return x**2


x = jnp.array([1, 2, 3, 4, 5, 6.0])

jax.jacobian(square)(x)

DeviceArray([[ 2.,  0.,  0.,  0.,  0.,  0.],
             [ 0.,  4.,  0.,  0.,  0.,  0.],
             [ 0.,  0.,  6.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  8.,  0.,  0.],
             [ 0.,  0.,  0.,  0., 10.,  0.],
             [ 0.,  0.,  0.,  0.,  0., 12.]], dtype=float32)

def tree_square(x):
    out = {
        "c": x["a"] ** 2,
        "d": x["b"].flatten() ** 2,
    }

    return out


tree_x = {"a": jnp.array([1, 2.0]), "b": jnp.array([[3, 4], [5, 6.0]])}

jax.jacobian(tree_square)(tree_x)


Instead of showing the entire results, let’s just look at the resulting tree structure and array shapes:

{
    "c": {
        "a": (2, 2),
        "b": (2, 2, 2),
    },
    "d": {
        "a": (4, 2),
        "b": (4, 2, 2),
    },
}

The outputs for hessians have even deeper nesting and three dimensional arrays inside the nested dictionary. Similarly, we would get higher dimensional arrays if one of the original pytrees had already contained a 2d array.

Extending the JAX solution to estimagic-pytrees#

JAX pytrees can only contain arrays, whereas estimagic-pytrees may contain scalars, pandas.Series and pandas.DataFrames (with or without "value" column). Unfortunately, this poses non-trivial challenges for numerical derivatives because those data types have no natural extension in arbitrary dimensions.

Our solution needs to fulfill two requirements:

  1. Compatible with JAX in the sense that whenever a derivative can be calculated with JAX, it can also be calculated with estimagic and the result has the same structure.

  2. Compatible with the rest of estimagic in the sense that any function that can be optimized can also be differentiated. In the special case of differentiating with respect to a DataFrame, it also needs to be backwards compatible.

A solution that achieves this is to treat Series and DataFrames with "value" columns as 1d arrays and other DataFrames as 2d arrays, then proceed as in JAX and finally try to preserve as much index and column information as possible.

This leads to very natural results in the typical use cases with flat dicts of Series or params DataFrames both as inputs and outputs and is backwards compatible with everything that is supported already.

However, similar to JAX, not everything that is supported will also be a good idea. Predicting where a pandas object is preserved and where it will be replaced by an array might be hard for very nested pytrees. However, these rules are mainly defined to avoid hard limitations that have to be checked and documented. Users will learn to avoid too much complexity by avoiding complex pytrees as inputs and outputs at the same time.

To see this in action, let’s look at an example. We repeat the example from the JAX interface above with the following changes:

  1. The 1d numpy array in x["a"] is replaced by a DataFrame with "value" column

  2. The "d" entry in the output becomes a Series instead of a 1d numpy array.

def pd_tree_square(x):
    out = {
        "c": x["a"]["value"] ** 2,
        "d": pd.Series(x["b"].flatten() ** 2, index=list("jklm")),
    }

    return out


pd_tree_x = {
    "a": pd.DataFrame(data=[[1], [2]], index=["alpha", "beta"], columns=["value"]),
    "b": np.array([[3, 4], [5, 6]]),
}

pd_tree_square(pd_tree_x)

{
    "c":
        alpha    1
        beta     4
        dtype: int64,
    "d":
        j     9
        k    16
        l    25
        m    36
        dtype: int64,
}

The resulting shapes of the Jacobian will be the same as before. For all arrays with only two dimensions we can preserve some information from the Series and DataFrame indices. On the higher dimensional ones, this will be lost.

{
    "c": {
        "a": (2, 2),  # df with columns ["alpha", "beta"], index ["alpha", "beta"]
        "b": (2, 2, 2),  # numpy array without label information
    },
    "d": {
        "a": (4, 2),  # columns ["alpha", "beta"], index [0, 1, 2, 3]
        "b": (4, 2, 2),  # numpy array without label information
    },
}

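To get more intuition for the structure of the result, let’s add a few labels to the very first Jacobian. Rows are grouped by output entry ("c", "d") and columns by input entry ("a", "b"):

        |   a    |        b
--------+--------+----------------
   c    |  2   0 |  0   0   0   0
        |  0   4 |  0   0   0   0
--------+--------+----------------
   d    |  0   0 |  6   0   0   0
        |  0   0 |  0   8   0   0
        |  0   0 |  0   0  10   0
        |  0   0 |  0   0   0  12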

The indices ["j", "k", "l", "m"] unfortunately never made it into the result because they were only applied to elements that already came from a 2d array and thus always have a 3d Jacobian, i.e. the result entry ["c"]["b"] is a reshaped version of the upper right 2 by 4 array and the result entry ["d"]["b"] is a reshaped version of the lower right 4 by 4 array.


We use the following terminology to describe the implementation:

  • input_tree: The pytree containing parameters, i.e. inputs to the function that is differentiated.

  • output_tree: The pytree that is returned by the function being differentiated

  • derivative_tree: The pytree we want to generate, i.e. the pytree that would be returned by JAX jacobian.

  • flat_derivative: The matrix version of the derivative_tree

To simply reproduce the JAX behavior with pytrees of arrays, we could proceed in the following steps:

  • Create a modified function that maps from 1d array to 1d array

  • Calculate flat_derivative by taking numerical derivatives just as before

  • Calculate the shapes of all arrays in derivative_tree by concatenating the shapes of the cartesian product of flattened output_tree and input_tree

  • Calculate the 2d versions of those arrays by taking the product over elements in the shape tuple before concatenating.

  • Create a list of lists containing all arrays that will be in derivative_tree. The values are taken from flat_derivative, using the previously calculated shapes.

  • Call tree_unflatten on the inner lists with the treedef corresponding to input_tree.

  • Call tree_unflatten on the result of that with the treedef corresponding to output_tree.

To implement the extension to estimagic pytrees we would probably do exactly the same but have a bit more preparation and post-processing to do.
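The shape calculation and the splitting of flat_derivative into blocks can be sketched as follows, using the leaf shapes from the JAX example above. This is a stdlib-only illustration: the flat derivative is a list of rows, the final reshape of each 2d block to its full shape is left out, and `block_shapes` and `split_flat_derivative` are hypothetical helper names. Flat dictionaries of shapes stand in for the treedefs.

```python
from math import prod

# Leaf shapes from the JAX example: inputs "a" (2,) and "b" (2, 2),
# outputs "c" (2,) and "d" (4,).
input_shapes = {"a": (2,), "b": (2, 2)}
output_shapes = {"c": (2,), "d": (4,)}


def block_shapes(output_shapes, input_shapes):
    """Shape of each array in derivative_tree: output shape + input shape."""
    return {
        out_name: {in_name: out_shape + in_shape for in_name, in_shape in input_shapes.items()}
        for out_name, out_shape in output_shapes.items()
    }


def split_flat_derivative(flat_derivative, output_shapes, input_shapes):
    """Cut flat_derivative into one 2d block per (output, input) leaf pair."""
    blocks, row = {}, 0
    for out_name, out_shape in output_shapes.items():
        n_rows = prod(out_shape)
        blocks[out_name], col = {}, 0
        for in_name, in_shape in input_shapes.items():
            n_cols = prod(in_shape)
            blocks[out_name][in_name] = [
                r[col : col + n_cols] for r in flat_derivative[row : row + n_rows]
            ]
            col += n_cols
        row += n_rows
    return blocks


# Flat Jacobian of the elementwise square at x = [1, ..., 6]: diag(2 * x).
flat_derivative = [
    [2.0 * (i + 1) if i == j else 0.0 for j in range(6)] for i in range(6)
]
shapes = block_shapes(output_shapes, input_shapes)
blocks = split_flat_derivative(flat_derivative, output_shapes, input_shapes)
```

The computed `shapes` reproduce the tree of shapes shown for the JAX Jacobian, and each entry of `blocks` is the 2d version of the corresponding array.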

General aspects of pytrees in estimation functions#

Estimation summaries#

Currently, estimation summaries are DataFrames. The estimated parameters are in the "value" column. There are other columns with standard errors, p-values, significance stars and confidence intervals.

This is another form of higher dimensional extension of pytrees, where we need to add additional columns. There are two ways in which estimation summaries could be presented. I suggest we offer both. The first is more geared towards generating estimation tables and serving as actual summary to be looked at in a jupyter notebook. It is also backwards compatible and should thus be the default. The second is more geared towards further calculations. There will be utility functions to convert between the two.

Both formats will be explained using the params pytree from the optimization example.

Format 1: Everything becomes a DataFrame#

In this approach we do the following conversions:

  1. numpy arrays are flattened and converted to DataFrames with one column called "value". The index contains the original positions of elements.

  2. pandas.Series are converted to DataFrames. The index remains unchanged. The column is called "value".

  3. scalars become DataFrames with one row with index 0 and one column called "value".

  4. DataFrames without "value" column are stacked into a DataFrame with just one column called "value".

  5. DataFrames with "value" column are reduced to that column.

After these transformations, all numbers of the original pytree are stored in DataFrames with "value" column. Additional columns with standard errors and the like can then simply be assigned as before.

For more intuition, let’s see how this would look in an example. For simplicity we only add a column with stars and omit standard errors, p-values and confidence intervals. We use the same example as in the optimization section:

params = {
    "delta": 0.95,
    "utility": pd.DataFrame(
        [[0.5, 0]] * 3, index=["a", "b", "c"], columns=["value", "lower_bound"]
    ),
    "probs": np.array([[0.8, 0.2], [0.3, 0.7]]),
}

{
    "delta":
              value stars
        0     0.95   ***,
    "utility":
              value stars
        a     0.5    **
        b     0.5    **
        c     0.5    **,
    "probs":
              value stars
        0 0   0.8   ***
          1   0.2     *
        1 0   0.3    **
          1   0.7   ***,
}

Format 2: Dictionary of pytrees#

The second solution is a dictionary of pytrees. The keys are the columns of the current summary, but probably in plural, i.e. "values", "standard_errors", "p-values", …

Each value is a pytree with the exact same structure as params. If this pytree contains DataFrames with "value" column, only that column is updated, i.e. standard errors would be accessed via summary["standard_errors"]["my_df"]["value"].
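A minimal sketch of the second format, assuming that the flat estimates and flat standard errors are available internally. Nested lists stand in for arrays, the DataFrame entry is left out, and the standard errors are made-up numbers used purely for illustration.

```python
def flatten(tree):
    """Flatten nested dicts, lists and tuples into a flat list of leaves."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in flatten(value)]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(template, leaves):
    """Rebuild a tree with the structure of template from a flat sequence."""
    leaves = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {key: build(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)

    return build(template)


params = {"delta": 0.95, "probs": [[0.8, 0.2], [0.3, 0.7]]}

flat_values = flatten(params)
flat_standard_errors = [0.01, 0.05, 0.05, 0.04, 0.04]  # illustrative only

# Every entry of the summary has the same tree structure as params.
summary = {
    "values": unflatten(params, flat_values),
    "standard_errors": unflatten(params, flat_standard_errors),
}
```

The standard error of a parameter is then found at the same path as the parameter itself, e.g. `summary["standard_errors"]["probs"][1][0]`.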

Representation of covariance matrices#

A covariance matrix is a two dimensional extension of a params pytree. We could theoretically handle it exactly the same way as Jacobians. However, this would not be useful for statistical tests and visualization if it contains arrays with more than two dimensions (as the Jacobian example does).

We thus propose to have two possible formats in which covariance matrices can be returned:

  1. The pytree variant described in the above Jacobian example. This will be useful to look at sub-matrices of the full covariance matrix as long as the params pytree only contains one dimensional arrays, Series and DataFrames with "value" columns.

  2. A DataFrame containing the covariance matrix of the flattened parameter vector. The index and columns of the DataFrame can be constructed from the leaf_names function in pybaum. We could also trivially add a function there that constructs an index that is easier to work with for selecting elements and let the user choose between the two versions.

The function that maps from the flat version (which would be calculated internally) to the pytree version is the same as we need for numerical derivatives. The inverse of that function is probably not too difficult to implement and can also be useful for derivatives.
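The flat representation can be illustrated with a small sketch that generates one name per leaf and uses these names to label the rows and columns of a flat covariance matrix. The naming scheme (path components joined by underscores) is an assumption of this sketch and may differ from what pybaum's leaf_names produces; a dict of dicts stands in for the labeled DataFrame and the covariance values are illustrative.

```python
def leaf_names(tree, prefix=()):
    """One name per leaf, built from the path to that leaf."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return ["_".join(str(part) for part in prefix)]
    return [
        name
        for key, value in items
        for name in leaf_names(value, prefix + (key,))
    ]


params = {"delta": 0.95, "probs": [[0.8, 0.2], [0.3, 0.7]]}
names = leaf_names(params)

# Illustrative flat covariance matrix of the five flat parameters.
flat_cov = [[0.01 if i == j else 0.0 for j in range(5)] for i in range(5)]

# Label rows and columns with the leaf names.
labeled_cov = {
    row_name: dict(zip(names, row)) for row_name, row in zip(names, flat_cov)
}

# Select the sub-matrix that belongs to the "probs" entry.
probs_cov = {
    row_name: {c: v for c, v in row.items() if c.startswith("probs")}
    for row_name, row in labeled_cov.items()
    if row_name.startswith("probs")
}
```

Selecting by name prefix is one simple way to recover sub-matrices from the flat representation without converting to the nested pytree variant.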


Everything that can be used as params in optimization and differentiation can also be used as params in estimation. The registries used in pytree functions are identical.

ML specific aspects of pytrees#

The output of the log likelihood functions is a dictionary with the entries:

  • "value": a scalar float

  • "contributions": a 1d numpy array or pandas.Series

Moreover, there can be arbitrary additional entries.

The only change is that "contributions" can now be any estimagic pytree.

MSM specific aspects of pytrees#

Valid formats of empirical and simulated moments#

There are three types of moments in MSM estimation:

  • empirical moments

  • The output of simulate_moments

  • The output of calculate_moments, needed to get a moments covariance matrix.

We propose that moments can be stored as any valid estimagic pytree but of course all three types of moments have to be aligned, i.e. be stored in a tree of the same structure. We will raise an error if the trees do not have the same structure.
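The alignment check could look roughly like the following sketch. It compares only the nesting of lists, tuples and dicts; in estimagic the comparison would additionally depend on the registry in use. The function names are hypothetical.

```python
def tree_structure(tree):
    """Describe the nesting of a tree while ignoring the leaf values."""
    if isinstance(tree, dict):
        return ("dict", {key: tree_structure(value) for key, value in tree.items()})
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, [tree_structure(item) for item in tree])
    return "leaf"


def have_same_structure(first, second):
    """True if both trees have identical container types, keys and lengths."""
    return tree_structure(first) == tree_structure(second)


def check_moments_aligned(empirical_moments, simulated_moments):
    """Raise an error if the two moment trees are not aligned."""
    if not have_same_structure(empirical_moments, simulated_moments):
        raise ValueError(
            "empirical_moments and simulated moments must have the same "
            "tree structure."
        )


empirical_moments = {"mean_wage": 10.2, "shares": [0.2, 0.5, 0.3]}
simulated_moments = {"mean_wage": 9.8, "shares": [0.25, 0.45, 0.3]}
misaligned_moments = {"mean_wage": 9.8, "shares": [0.25, 0.75]}

check_moments_aligned(empirical_moments, simulated_moments)  # passes silently
```

Calling `check_moments_aligned(empirical_moments, misaligned_moments)` would raise a ValueError because the "shares" entries differ in length.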

This is a generalization of an interface that has already proven useful in respy, sid and other applications. In the future, the project specific implementations of flatten and unflatten functions could simply be deleted.

Representation of the weighting matrix and moments_cov#

The weighting matrix for MSM estimation is represented as a DataFrame in the same way as the flat representation of the covariance matrices. Of course, the conversion functions that work for covariance matrices would also work here, but it is highly unlikely that a different representation of a weighting matrix is ever needed.

Note that the user does not have to construct this weighting matrix manually. They can generate it using get_moments_cov and get_weighting_matrix, so they do not need any knowledge of how the flattening works.

Representation of sensitivity measures#

Sensitivity measures are similar to covariance matrices in the sense that they require a two dimensional extension of pytrees. The only difference is that for covariance matrices the two pytrees are the same (namely the params) and for sensitivity measures they are different (one is params, the other moments).

We therefore suggest using the same solution, i.e. to offer a flat representation in the form of a DataFrame, a pytree representation and functions to convert between the two.

Compatibility with estimation tables#

Estimation tables are constructed from estimation summaries. This continues to work for summaries where everything has been converted to DataFrames. Users will select individual DataFrames from a pytree of DataFrames, possibly concatenate or filter them and pass them to the estimation table function.

Compatibility with plotting functions#

The following functions are affected:

  • plot_univariate_effects

  • convergence_plot

  • lollipop_plot

  • derivative_plot

Most of them can be adjusted easily to the proposed changes. On all others we will simply raise errors and provide tutorials to work around the limitations.

Compatibility with Dashboard#

The main challenge for the dashboard is that pytrees have no natural multi-column extension and thus it becomes harder to specify a group or name column. However, these features have not been used very much anyway.

We propose to write a better automatic grouping and naming function for pytrees. That way it is simply not necessary to provide group and name columns and most of the users will get a better dashboard experience.

Rules of thumb for both should be:

  1. Only parameters where the start values have a similar magnitude can be in the same group, i.e. displayed in one lineplot.

  2. Parameters that are close to each other in the tree (i.e. have a common beginning in their leaf_name) should be in the same group.

  3. The plot title should subsume the common parts of the tree structure (i.e. the name we get from pybaum.leaf_names).

  4. Most line plots should have approximately 5 lines, none should have more than 8.
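A rough sketch of how such an automatic grouping could work is shown below. It is only one possible reading of the rules of thumb: the common beginning is approximated by the first component of an underscore-separated leaf name, "similar magnitude" is approximated by the order of magnitude of the start value, and only the upper limit of 8 lines is enforced (the target of about 5 lines is not). `group_parameters` is a hypothetical function.

```python
import math


def group_parameters(names, start_values, max_lines=8):
    """Group parameter names for line plots.

    Parameters share a group if their names start with the same component
    and their start values have the same order of magnitude.
    """
    groups = {}
    for name, value in zip(names, start_values):
        prefix = name.split("_")[0]
        magnitude = math.floor(math.log10(abs(value))) if value != 0 else 0
        groups.setdefault((prefix, magnitude), []).append(name)

    plots = []
    for (prefix, _), members in groups.items():
        # Split large groups so that no plot has more than max_lines lines.
        for start in range(0, len(members), max_lines):
            plots.append({"title": prefix, "names": members[start : start + max_lines]})
    return plots


names = ["delta", "probs_0_0", "probs_0_1", "probs_1_0", "probs_1_1", "scale_0"]
start_values = [0.95, 0.8, 0.2, 0.3, 0.7, 250.0]
plots = group_parameters(names, start_values)
```

In this example the four "probs" parameters end up in one plot, while "delta" and "scale_0" each get their own.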

Advanced options for functions that work with pytrees#

There are two arguments to tree_flatten and other pytree functions that determine which entries in a pytree are considered leaves and which containers, as well as how containers are flattened: registry and is_leaf. See the documentation of pybaum for details.

To allow for absolute flexibility, each function that works with pytrees needs to allow a user to pass in a registry and an is_leaf argument. If a function works with multiple pytrees (e.g. in estimate_msm the params are a pytree and empirical_moments are a pytree), it needs to allow users to pass in multiple registries and is_leaf functions (e.g. params_registry, params_is_leaf and moments_registry, moments_is_leaf).

However, we need only as many registries as there are different pytrees. For example since simulated_moments and empirical_moments always need to be pytrees with the same structure, they do not need separate registries and is_leaf functions.
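The effect of an is_leaf argument can be illustrated with a small sketch. It assumes that is_leaf is a callable that is evaluated on every node and stops the recursion when it returns True, as in JAX; the flatten function below is a simplified stand-in and not the pybaum implementation.

```python
def flatten(tree, is_leaf=None):
    """Flatten tree; nodes for which is_leaf returns True are not opened."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        children = list(tree.values())
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return [tree]
    return [leaf for child in children for leaf in flatten(child, is_leaf)]


moments = {"mean": (1.0, 2.0), "shares": [(0.2, 0.8), (0.5, 0.5)]}

# Default: tuples are containers, so the leaves are numbers.
number_leaves = flatten(moments)

# With is_leaf, tuples are treated as leaves and kept intact.
tuple_leaves = flatten(moments, is_leaf=lambda node: isinstance(node, tuple))
```

The same tree thus yields six numeric leaves by default and three tuple leaves when tuples are declared to be leaves.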

Compatibility with JAX autodiff#

While we allow for pytrees of arrays, numbers and DataFrames, JAX only allows pytrees of arrays and numbers for automatic differentiation.

If you want to use automatic differentiation with estimagic you will thus have to restrict yourself in the way you specify parameters.


Implementing all of the proposed changes will take time. Since all of them are fully backwards compatible, we do not have to finish all of them at once, even if that means that some areas of estimagic already support pytrees and others do not.

We suggest the following work packages and tentative order in which we tackle them:



Basic tree functions in pybaum


Conversion functions between flat and nested 2d representations of jacobians, covariance matrices, etc.

@janosg @timmens

Better name and group functions for dashboard

Pytree support for optimization


Pytree support for first_derivative


Pytree support for second_derivative


Pytree support in ML estimation

Pytree support in Bootstrap

Pytree support in MSM estimation