MMM.sample_saturation_curve
- MMM.sample_saturation_curve(max_value=1.0, num_points=100, num_samples=500, random_state=None, original_scale=True, idata=None)
Sample saturation curves from posterior parameters.
This method samples the saturation transformation curves using posterior parameters from the fitted model. It allows visualization of the diminishing returns relationship between media spend and contribution.
- Parameters:
- max_value
float, optional Maximum value for the curve x-axis, in scaled space (consistent with model internals). By default 1.0. This represents the maximum spend level in scaled units. To convert from original scale, divide by the channel scale:
max_scaled = original_max / mmm.data.get_channel_scale().mean()
- num_points
int, optional Number of points between 0 and max_value to evaluate the curve at. By default 100. Higher values give smoother curves but take longer.
- num_samples
int or None, optional Number of posterior samples to use for generating curves. By default 500. Samples are drawn randomly from the full posterior (across all chains and draws). Using fewer samples speeds up computation and reduces memory usage while still capturing posterior uncertainty. If None, all posterior samples are used without subsampling.
- random_state
int, np.random.Generator, or None, optional Random state for reproducible subsampling. Can be an integer seed, a numpy Generator instance, or None for non-reproducible sampling. Only used when num_samples is not None and less than the total number of available samples.
- original_scale
bool, optional Whether to return curve y-values in original scale. If True (default), y-axis values (contribution) are multiplied by target_scale to convert from scaled to original units. If False, values remain in scaled space as used internally by the model. Note that x-axis values always remain in scaled space, consistent with the max_value parameter.
- idata
az.InferenceData or None, optional InferenceData to sample from. If None (default), uses self.idata. This allows sampling curves from different posterior distributions, such as from a different model or a subset of samples.
- Returns:
xr.DataArray Sampled saturation curves with dimensions:
- Simple model: (x, channel, sample)
- Panel model: (x, *custom_dims, channel, sample)
The “sample” dimension indexes the posterior samples used. The “x” coordinate represents spend levels in scaled space (consistent with max_value). Y-values are in original scale when original_scale=True, otherwise in scaled space.
- Raises:
ValueError If called before the model is fitted (idata doesn’t exist) and no idata is provided.
ValueError If original_scale=True but the scale factors are not found in constant_data.
Notes
The max_value parameter is always in scaled space, consistent with how the model operates internally. This matches the pattern of other MMM methods.
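Because max_value and the returned x coordinate are both in scaled units, a spend level in original units has to be converted before the call and converted back for axis labels. The following is a minimal sketch of that arithmetic, not the library's implementation: the `channel_scale` array is a hypothetical stand-in for the values returned by `mmm.data.get_channel_scale()`, and the evenly spaced grid built with `np.linspace` is an assumption about how the points between 0 and max_value are laid out.

```python
import numpy as np

# Hypothetical per-channel scale factors, standing in for the values
# that mmm.data.get_channel_scale() would return on a fitted model.
channel_scale = np.array([8000.0, 12000.0])
mean_scale = float(channel_scale.mean())

# Spend level in original units, converted to scaled units for max_value.
max_original = 10_000.0
max_scaled = max_original / mean_scale

# Assumed evaluation grid in scaled space (what the "x" coordinate holds).
x_scaled = np.linspace(0.0, max_scaled, 5)

# Multiply by the same factor to label the axis in original units.
x_original = x_scaled * mean_scale
```

Dividing by the mean scale mirrors the conversion shown for max_value; with channels on very different scales, a per-channel conversion may be more appropriate.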
For panel models, curves are generated for each combination of custom dimensions (e.g., each country) and channel.
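For a panel model, a single curve can be pulled out by selecting on the custom dimension and the channel by name. The sketch below uses a synthetic array in place of a real result; the `country` dimension and all coordinate labels are hypothetical. Selecting by name keeps the code independent of the order in which dimensions are returned.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for a panel-model result. The "country" dimension
# and every coordinate label here are hypothetical.
rng = np.random.default_rng(0)
curves = xr.DataArray(
    rng.random((50, 20, 2, 3)),
    dims=("sample", "x", "country", "channel"),
    coords={
        "x": np.linspace(0.0, 1.0, 20),
        "country": ["US", "DE"],
        "channel": ["tv", "radio", "search"],
    },
)

# Label-based selection drops the selected dimensions and leaves
# one curve per posterior sample.
us_tv = curves.sel(country="US", channel="tv")
```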
The returned array includes a “sample” dimension for uncertainty quantification. Use .mean(dim='sample') for point estimates and .quantile() for credible intervals.
Posterior samples are drawn randomly without replacement when num_samples is less than the total available samples; otherwise all samples are used.
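As an illustration of the summaries mentioned above, the sketch below reduces a synthetic array over its “sample” dimension. The array only mimics the shape of a simple-model result, and the 3% and 97% levels are an arbitrary choice that gives an equal-tailed 94% interval (not a highest-density interval).

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the array returned by sample_saturation_curve().
rng = np.random.default_rng(1)
curves = xr.DataArray(
    rng.random((200, 10, 2)),
    dims=("sample", "x", "channel"),
    coords={"x": np.linspace(0.0, 1.0, 10), "channel": ["tv", "radio"]},
)

# Point estimate: average over posterior samples at each x and channel.
mean_curve = curves.mean(dim="sample")

# Equal-tailed interval: lower and upper quantiles over the samples.
bounds = curves.quantile([0.03, 0.97], dim="sample")
lower = bounds.isel(quantile=0)
upper = bounds.isel(quantile=1)
```

`mean_curve`, `lower`, and `upper` share the remaining dimensions, so they can be plotted against the `x` coordinate channel by channel.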
Examples
Sample curves with default parameters (original scale):
>>> curves = mmm.sample_saturation_curve()
>>> curves.dims
('sample', 'x', 'channel')
Sample curves using all posterior samples:
>>> curves_all = mmm.sample_saturation_curve(num_samples=None)
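For contrast with num_samples=None, the subsampling rule described for num_samples and random_state can be pictured as a draw without replacement over the flattened (chain, draw) index. This is a minimal sketch under assumptions, not the library's internal code: the posterior size is hypothetical, and seeding `np.random.default_rng` with an integer is assumed to play the role of random_state.

```python
import numpy as np

# Hypothetical posterior size: 4 chains of 1000 draws each.
n_chains, n_draws = 4, 1000
total = n_chains * n_draws

num_samples = 500
rng = np.random.default_rng(42)  # assumed analogue of random_state=42

if num_samples is not None and num_samples < total:
    # Subsample: each posterior draw can be picked at most once.
    idx = rng.choice(total, size=num_samples, replace=False)
else:
    # num_samples is None or not smaller than the posterior: use every draw.
    idx = np.arange(total)
```

Reusing the same seed reproduces the same indices, which is the point of passing random_state.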
Sample curves in scaled space:
>>> curves_scaled = mmm.sample_saturation_curve(original_scale=False)
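The relationship between the scaled and original-scale outputs, as described for original_scale, amounts to multiplying the y-values by the target scale while leaving the x coordinate alone. The sketch below shows this on a toy array; `target_scale` is a hypothetical scalar and the saturating shape is made up for illustration.

```python
import numpy as np
import xarray as xr

target_scale = 50_000.0  # hypothetical scalar target scale

# Toy curves in scaled space: three identical "samples" of a saturating shape.
x = np.linspace(0.0, 1.0, 5)
curves_scaled = xr.DataArray(
    np.tile(x / (x + 0.5), (3, 1)),
    dims=("sample", "x"),
    coords={"x": x},
)

# Only the y-values are rescaled; the "x" coordinate stays in scaled space.
curves_original = curves_scaled * target_scale
```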
Sample curves with custom max value and reproducible sampling:
>>> channel_scale = mmm.data.get_channel_scale()
>>> max_original = 10000  # $10,000
>>> max_scaled = max_original / float(channel_scale.mean())
>>> curves = mmm.sample_saturation_curve(
...     max_value=max_scaled, num_points=200, num_samples=1000, random_state=42
... )
Sample curves from a different InferenceData:
>>> import arviz as az
>>> external_idata = az.from_netcdf("other_model.nc")
>>> curves = mmm.sample_saturation_curve(idata=external_idata)