MMMPlotSuite.prior_vs_posterior

MMMPlotSuite.prior_vs_posterior(var, plot_dim='channel', alphabetical_sort=True, dims=None, figsize=None)

Plot the prior vs posterior distribution for a variable across a dimension.

Creates KDE plots of the prior and posterior distributions, with the mean of each highlighted. Each subplot corresponds to one value of plot_dim (e.g., one channel). If the variable has additional dimensions, a grid of subplots is created, with one subplot per combination of coordinate values.
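As a rough illustration of what each subplot computes, the sketch below evaluates a Gaussian kernel density estimate (KDE) for prior and posterior samples (assumed already flattened over chain and draw) on a shared grid, and records both means. It is a stand-alone plain-Python sketch, not the library's implementation: the helper names `gaussian_kde` and `kde_panel`, the use of Scott's rule for the bandwidth, and the 200-point grid are all assumptions.

```python
import math
import statistics


def gaussian_kde(samples, grid):
    """Evaluate a 1-D Gaussian KDE of `samples` at every point of `grid`.

    Assumption: bandwidth from Scott's rule, h = stdev * n ** (-1/5).
    """
    n = len(samples)
    h = statistics.stdev(samples) * n ** (-1 / 5)
    norm = 1.0 / (n * h * math.sqrt(2 * math.pi))
    return [
        norm * sum(math.exp(-0.5 * ((x - s) / h) ** 2) for s in samples)
        for x in grid
    ]


def kde_panel(prior_samples, posterior_samples, n_points=200):
    """Hypothetical helper: the data behind one prior-vs-posterior subplot."""
    # Shared evaluation grid spanning both sample sets, so the curves overlay.
    lo = min(min(prior_samples), min(posterior_samples))
    hi = max(max(prior_samples), max(posterior_samples))
    step = (hi - lo) / (n_points - 1)
    grid = [lo + i * step for i in range(n_points)]
    return {
        "grid": grid,
        "prior_density": gaussian_kde(prior_samples, grid),
        "posterior_density": gaussian_kde(posterior_samples, grid),
        # The means are what a plot would highlight, e.g. with vertical lines.
        "prior_mean": statistics.fmean(prior_samples),
        "posterior_mean": statistics.fmean(posterior_samples),
    }
```

In an actual figure, each density list would be drawn against `grid` and each mean marked on the same axes.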

A variable that lacks plot_dim is treated as a scalar variable. If it has no dimensions other than chain and draw, a single subplot shows the overall prior vs posterior comparison; otherwise, one subplot is created for each combination of its remaining dimensions.
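One way to picture the subplot layout is as the Cartesian product of the variable's remaining coordinate values, after applying the optional dims filter. In the minimal sketch below, `enumerate_panels` and its plain-dict input (dimension name to coordinate values, with chain and draw already removed) are hypothetical stand-ins; treating a filter on a dimension the variable lacks as a no-op is also an assumption.

```python
from itertools import product


def enumerate_panels(var_coords, dims=None):
    """Hypothetical helper: one dict of coordinate selections per subplot."""
    coords = {dim: list(values) for dim, values in var_coords.items()}
    for dim, selection in (dims or {}).items():
        if dim not in coords:
            continue  # assumption: filters on absent dimensions are ignored
        # A single value and a list of values are both accepted.
        wanted = selection if isinstance(selection, list) else [selection]
        coords[dim] = [value for value in coords[dim] if value in wanted]
    names = list(coords)
    # product() of zero iterables yields one empty tuple, so a variable with no
    # dimensions besides chain/draw produces exactly one panel: [{}].
    return [
        dict(zip(names, combo))
        for combo in product(*(coords[name] for name in names))
    ]
```

For example, a variable with two channels and two geos yields four panels, while `dims={"geo": "US"}` reduces that to two.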

Parameters:
var : str

The name of the variable to plot (e.g., ‘adstock_alpha’, ‘lam’).

plot_dim : str, optional

The dimension to create subplots over. Default is “channel”. Each value in this dimension will get its own subplot showing prior vs posterior comparison. If the variable does not have this dimension, it is treated as a scalar variable.

alphabetical_sort : bool, optional

Whether to sort the plot_dim values alphabetically (True) or by the difference between the posterior and prior means (False), with the largest positive difference at the top. Default is True. Only applies when plot_dim exists in the variable.

dims : dict[str, str | int | list], optional

Dimension filters to apply, mapping a dimension name to a single coordinate value or a list of values. Example: {"geo": "US"}. If provided, only the selected slice(s) are plotted.

figsize : tuple[float, float], optional

The figure size as (width, height) in inches. If None, it is computed from the number of subplots.
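The ordering controlled by alphabetical_sort can be sketched as follows. `order_plot_dim` is a hypothetical helper that takes precomputed prior and posterior means per coordinate value; the first element of the result corresponds to the top subplot, assuming subplots are stacked from top to bottom.

```python
def order_plot_dim(values, prior_means, posterior_means, alphabetical_sort=True):
    """Hypothetical helper: order plot_dim values for display."""
    if alphabetical_sort:
        # Plain string sort (by code point, so case-sensitive).
        return sorted(values)
    # Largest positive (posterior mean - prior mean) difference first.
    return sorted(
        values,
        key=lambda v: posterior_means[v] - prior_means[v],
        reverse=True,
    )
```

Sorting on the signed difference, rather than its absolute value, puts parameters whose posterior shifted upward at the top and those that shifted downward at the bottom.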

Returns:
fig : matplotlib.figure.Figure

The Figure object containing the subplots.

axes : np.ndarray of matplotlib.axes.Axes

Array of Axes objects corresponding to each subplot.

Raises:
ValueError

If var is missing from the prior or the posterior, or if idata contains no prior or posterior data.
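The two error conditions can be sketched as a pre-flight check. Here `groups` is a hypothetical stand-in for idata, mapping a group name to the set of variable names it contains; the helper name and the error messages are illustrative, not the library's.

```python
def check_var_available(groups, var):
    """Hypothetical pre-flight check mirroring the Raises section."""
    # First condition: both groups must be present at all.
    for group in ("prior", "posterior"):
        if group not in groups:
            raise ValueError(f"No {group} data found in idata")
    # Second condition: the variable must appear in both groups.
    missing = [g for g in ("prior", "posterior") if var not in groups[g]]
    if missing:
        raise ValueError(f"Variable {var!r} not found in: {', '.join(missing)}")
```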

Examples

Plot prior vs posterior distribution of an adstock parameter:

mmm.plot.prior_vs_posterior(var="adstock_alpha", plot_dim="channel")

Plot a scalar variable (no channel dimension):

mmm.plot.prior_vs_posterior(var="intercept")

Plot with dimension filtering:

mmm.plot.prior_vs_posterior(
    var="lam", plot_dim="channel", dims={"geo": "US"}
)

Sort by the difference between the posterior and prior means (largest positive difference first):

mmm.plot.prior_vs_posterior(
    var="adstock_alpha", plot_dim="channel", alphabetical_sort=False
)