MMMPlotSuite.prior_vs_posterior
- MMMPlotSuite.prior_vs_posterior(var, plot_dim='channel', alphabetical_sort=True, dims=None, figsize=None)
Plot the prior vs posterior distribution for a variable across a dimension.
Creates KDE plots of the prior and posterior distributions, with the mean of each highlighted. Each subplot corresponds to one value of plot_dim (e.g., one channel). If the variable has additional dimensions, a grid of subplots is created with one subplot per combination of coordinate values.
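To make "KDE plot with the mean highlighted" concrete, here is a minimal, stdlib-only sketch of what one subplot summarizes: a Gaussian kernel density estimate of the prior draws and of the posterior draws on a shared grid, plus the two means that would be marked. The helpers `kde_1d` and `prior_vs_posterior_curves` are hypothetical and are not part of MMMPlotSuite; the source does not say which KDE routine or bandwidth the library uses, so Scott's rule here is an assumption.

```python
import math
import statistics


def kde_1d(samples, grid, bandwidth=None):
    """Evaluate a Gaussian kernel density estimate of `samples` at each point of `grid`."""
    n = len(samples)
    if bandwidth is None:
        # Assumed bandwidth: Scott's rule of thumb in one dimension, h = std * n ** (-1/5).
        bandwidth = statistics.stdev(samples) * n ** (-1 / 5)
    norm = 1.0 / (n * bandwidth * math.sqrt(2 * math.pi))
    # Sum one Gaussian kernel per sample at every grid point.
    return [
        norm * sum(math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in samples)
        for x in grid
    ]


def prior_vs_posterior_curves(prior, posterior, n_grid=200):
    """Hypothetical helper: both density curves on a shared grid plus the two means."""
    # The shared grid spans the range of both sets of draws so the curves overlay.
    lo = min(min(prior), min(posterior))
    hi = max(max(prior), max(posterior))
    step = (hi - lo) / (n_grid - 1)
    grid = [lo + i * step for i in range(n_grid)]
    return {
        "grid": grid,
        "prior_density": kde_1d(prior, grid),
        "posterior_density": kde_1d(posterior, grid),
        "prior_mean": statistics.fmean(prior),
        "posterior_mean": statistics.fmean(posterior),
    }
```

Here `prior` and `posterior` are flat lists of draws (chain and draw pooled). In a real figure the two curves could be drawn with matplotlib's `Axes.plot` and the means marked with `Axes.axvline`.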
For scalar variables (those without the specified plot_dim), a single subplot shows the overall prior vs posterior comparison. If such a variable has dimensions other than chain and draw, one subplot is created for each combination of those dimensions instead.
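The panel layout described above amounts to a Cartesian product over the variable's non-sampling dimensions, optionally narrowed by the `dims` filter documented under Parameters. The sketch below is a hypothetical helper (`panel_keys` is not a library function) that enumerates one key per subplot from plain lists of coordinate values; it assumes unknown filter keys are simply ignored, which may differ from the library's behavior.

```python
from itertools import product


def panel_keys(coords, dims=None):
    """Return one {dimension: value} dict per subplot.

    coords: dimension name -> list of coordinate values, excluding 'chain' and 'draw'.
    dims:   optional filter such as {"geo": "US"}; a value may be one label or a list.
    """
    dims = dims or {}
    selected = {}
    for name, values in coords.items():
        if name in dims:
            # Normalize a single label to a list, then keep only the requested labels.
            wanted = dims[name] if isinstance(dims[name], list) else [dims[name]]
            values = [v for v in values if v in wanted]
        selected[name] = values
    names = list(selected)
    if not names:
        # Scalar variable with no extra dimensions: one overall panel.
        return [{}]
    # One panel per combination of the remaining coordinate values.
    return [dict(zip(names, combo)) for combo in product(*(selected[n] for n in names))]
```

With `coords = {"channel": ["tv", "radio"], "geo": ["US", "UK", "DE"]}` this yields six panels, or two panels when called with `dims={"geo": "US"}`.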
- Parameters:
  - var (str): The name of the variable to plot (e.g., 'adstock_alpha', 'lam').
  - plot_dim (str, optional): The dimension to create subplots over. Default is "channel". Each value in this dimension gets its own subplot showing the prior vs posterior comparison. If the variable does not have this dimension, it is treated as a scalar variable.
  - alphabetical_sort (bool, optional): Whether to sort the plot_dim values alphabetically (True) or by the difference between the posterior and prior means (False), with the largest positive difference at the top. Default is True. Only applies when plot_dim exists in the variable.
  - dims (dict[str, str | int | list], optional): Dimension filters to apply, e.g. {"geo": "US"}. If provided, only the selected slice(s) are plotted.
  - figsize (tuple[float, float], optional): The size of the figure. If None, it is calculated from the number of subplots.
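The two orderings offered by alphabetical_sort can be sketched in a few lines. This is an illustration under the documented rule (alphabetical, or largest positive posterior minus prior mean first), not the library's code; `order_panels` and its input layout are hypothetical.

```python
from statistics import fmean


def order_panels(samples, alphabetical_sort=True):
    """Return plot_dim labels in top-to-bottom subplot order.

    samples: hypothetical mapping label -> (prior_draws, posterior_draws).
    """
    if alphabetical_sort:
        return sorted(samples)
    # Signed update: posterior mean minus prior mean, largest positive first.
    return sorted(
        samples,
        key=lambda label: fmean(samples[label][1]) - fmean(samples[label][0]),
        reverse=True,
    )
```

Note that the ordering uses the signed difference, so a channel whose mean drops sharply goes to the bottom rather than the top.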
- Returns:
  - fig (matplotlib.figure.Figure): The Figure object containing the subplots.
  - axes (np.ndarray of matplotlib.axes.Axes): Array of Axes objects corresponding to each subplot.
- Raises:
  - ValueError: If var is not found in both the prior and the posterior, or if no prior or posterior data is found in idata.
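The two error conditions can be pictured as a pre-flight check. The sketch below assumes plain dicts standing in for the InferenceData prior and posterior groups, and reads "not found in both" as "missing from either group", since both are needed for the comparison. `check_prior_posterior` and its messages are hypothetical, not the library's.

```python
def check_prior_posterior(groups, var):
    """Raise ValueError under the two documented conditions.

    groups: hypothetical stand-in for idata, e.g. {"prior": {...}, "posterior": {...}},
            where each group maps variable names to their draws.
    """
    prior = groups.get("prior")
    posterior = groups.get("posterior")
    if not prior or not posterior:
        # Illustrative message; the library's wording may differ.
        raise ValueError("No prior or posterior data found in idata.")
    if var not in prior or var not in posterior:
        raise ValueError(f"Variable '{var}' not found in both prior and posterior.")
```

In practice this means sampling the prior predictive as well as fitting the model before calling the plot.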
Examples
Plot prior vs posterior distribution of an adstock parameter:
mmm.plot.prior_vs_posterior(var="adstock_alpha", plot_dim="channel")
Plot a scalar variable (no channel dimension):
mmm.plot.prior_vs_posterior(var="intercept")
Plot with dimension filtering:
mmm.plot.prior_vs_posterior(
    var="lam",
    plot_dim="channel",
    dims={"geo": "US"},
)
Sort by the signed update (largest positive posterior minus prior mean difference first):
mmm.plot.prior_vs_posterior(
    var="adstock_alpha",
    plot_dim="channel",
    alphabetical_sort=False,
)