# Source code for pymc_marketing.mmm.additive_effect

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Additive effects for the multidimensional Marketing Mix Model.

Example of a custom additive effect
-----------------------------------

1. Custom negative-effect component (added as a MuEffect)

.. code-block:: python

    import numpy as np
    import pandas as pd
    import pymc as pm
    from pymc_extras.prior import create_dim_handler

    # A simple custom effect that penalizes certain dates/segments with a
    # negative-only coefficient. This is not a "control" in the MMM sense, so
    # give it a different name/prefix to avoid clashing with built-in controls.
    class PenaltyEffect:
        '''Example MuEffect that applies a negative coefficient to a user-specified pattern.
        '''

        def __init__(self, name: str, penalty_provider):
            self.name = name
            self.penalty_provider = penalty_provider

        def create_data(self, mmm):
            # Produce penalty values aligned with model dates (and optional extra dims)
            dates = pd.to_datetime(mmm.model.coords["date"])
            penalty = self.penalty_provider(dates)
            pm.Data(f"{self.name}_penalty", penalty, dims=("date", *mmm.dims))

        def create_effect(self, mmm):
            model = mmm.model
            penalty = model[f"{self.name}_penalty"]  # dims: (date, *mmm.dims)

            # Negative-only coefficient per extra dims, broadcast over date
            coef = pm.TruncatedNormal(f"{self.name}_coef", mu=-0.5, sigma=0.05, lower=-1.0, upper=0.0, dims=mmm.dims)

            dim_handler = create_dim_handler(("date", *mmm.dims))
            effect = pm.Deterministic(
                f"{self.name}_effect_contribution",
                dim_handler(coef, mmm.dims) * penalty,
                dims=("date", *mmm.dims),
            )
            return effect  # Must have dims ("date", *mmm.dims)

        def set_data(self, mmm, model, X):
            # Update to future dates during posterior predictive
            dates = pd.to_datetime(model.coords["date"])
            penalty = self.penalty_provider(dates)
            pm.set_data({f"{self.name}_penalty": penalty}, model=model)

    # Usage
    # -----
    # Example weekend penalty (Sat/Sun = 1, else 0), applied per geo if present
    weekend_penalty = PenaltyEffect(
        name="brand_penalty",
        penalty_provider=lambda dates: pd.Series(dates)
        .dt.dayofweek.isin([5, 6])
        .astype(float)
        .to_numpy()[:, None]  # if mmm.dims == ("geo",), broadcast over geo
    )

    # Build your MMM as usual (with channels, etc.), then add the effect before build/fit:
    # mmm = MMM(...)
    # mmm.mu_effects.append(weekend_penalty)
    # mmm.build_model(X, y)
    # mmm.fit(X, y, ...)
    # At prediction time, the effect updates itself via set_data.

How it works
------------
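- Mu effects follow a simple protocol: ``create_data(mmm)``, ``create_effect(mmm)``,
  and ``set_data(mmm, model, X)``. The built-in effects subclass the Pydantic-based
  ``MuEffect`` base class, which is what allows an effect to be serialized and
  deserialized when an MMM is saved or loaded; the plain class in the example is
  only a minimal sketch of the protocol.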
- During ``MMM.build_model(...)``, each effect’s ``create_data`` is called first to
  introduce any needed ``pm.Data``. Then ``create_effect`` must return a tensor with
  dims ("date", *mmm.dims) that is added additively to the model mean.
- During posterior predictive, ``set_data`` is called with the cloned PyMC model
  and the new coordinates; update any ``pm.Data`` you created using ``pm.set_data``.

Tips for custom components
--------------------------
- Use unique variable prefixes to avoid name clashes with built-in pieces like
  controls. Do not call your component "control"; choose a distinct name/prefix.
- Follow the patterns used by the provided effects in this module (e.g.,
  `FourierEffect`, `LinearTrendEffect`, `EventAdditiveEffect`):
  - In `create_data`, derive and register any required inputs into the model.
  - In `create_effect`, construct PyTensor expressions and return a contribution
    with dims ("date", *mmm.dims). If you need broadcasting, use
    `pymc_extras.prior.create_dim_handler` as shown above.
  - In `set_data`, update the data variables when dates/dims change.
"""

from abc import ABC, abstractmethod
from typing import Annotated, Any, Protocol

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, PlainValidator, WithJsonSchema
from pymc_extras.prior import create_dim_handler
from pytensor import tensor as pt

from pymc_marketing.mmm.events import EventEffect, days_from_reference
from pymc_marketing.mmm.fourier import FourierBase
from pymc_marketing.mmm.linear_trend import LinearTrend
from pymc_marketing.mmm.utils import create_index


class Model(Protocol):
    """Structural type for the MMM object handed to every mu effect.

    Only the two attributes below are relied upon by the effects in this module.
    """

    @property
    def dims(self) -> tuple[str, ...]:
        """Dimensions of the MMM target in addition to the date dimension."""
        ...

    @property
    def model(self) -> pm.Model:
        """The PyMC model the effect should add its variables to."""
        ...


class MuEffect(ABC, BaseModel):
    """Base class for any additive contribution to the MMM mean (``mu``).

    Concrete effects are Pydantic models: deriving from this class is what makes
    them serializable and deserializable when an MMM is saved and loaded. A
    subclass implements three hooks, called in this order by the MMM:
    ``create_data`` and ``create_effect`` while building the model, and
    ``set_data`` when predicting on new data.
    """

    @abstractmethod
    def create_data(self, mmm: Model) -> None:
        """Register the data variables this effect needs in ``mmm.model``."""
        ...

    @abstractmethod
    def create_effect(self, mmm: Model) -> pt.TensorVariable:
        """Build and return the additive contribution of this effect."""
        ...

    @abstractmethod
    def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
        """Refresh this effect's data variables in ``model`` for new predictions."""
        ...


class FourierEffect(MuEffect):
    """Fourier seasonality additive effect for MMM.

    Parameters
    ----------
    fourier : FourierBase
        The Fourier seasonality component. Its ``prefix`` names the model
        variables and must also be one of ``fourier.prior.dims`` (the dim that
        is summed out to obtain the seasonal contribution).
    date_dim_name : str
        Name of the date dimension in the model. Default is "date".
    """

    fourier: InstanceOf[FourierBase]
    date_dim_name: str = Field("date")

    def create_data(self, mmm: Model) -> None:
        """Create the required data in the model.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.
        """
        model = mmm.model
        dates = pd.to_datetime(model.coords[self.date_dim_name])
        # Per-date values produced by FourierBase._get_days_in_period; they are
        # the input of the Fourier basis.
        pm.Data(
            f"{self.fourier.prefix}_day",
            self.fourier._get_days_in_period(dates).to_numpy(),
            dims=self.date_dim_name,
        )

    def create_effect(self, mmm: Model) -> pt.TensorVariable:
        """Create the Fourier effect in the model.

        Two deterministics are registered: ``{prefix}_components`` (the unsummed
        basis contributions, including the Fourier-mode dim) and
        ``{prefix}_contribution`` (summed over the Fourier-mode dim, with dims
        ordered as ``(date, *mmm.dims)`` restricted to the prior's dims).

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.

        Returns
        -------
        pt.TensorVariable
            The Fourier contribution, broadcast to ``(date, *mmm.dims)``.

        Raises
        ------
        ValueError
            If the Fourier prior has a dimension (other than the Fourier-mode
            dim) that is not one of the MMM dims.
        """
        model = mmm.model
        prefix = self.fourier.prefix
        day_data = model[f"{prefix}_day"]

        # Store the unsummed basis components (including the internal Fourier
        # mode dim) so users can inspect individual sine/cos contributions.
        def create_deterministic(x: pt.TensorVariable) -> None:
            pm.Deterministic(
                f"{prefix}_components",
                x,
                dims=(self.date_dim_name, *self.fourier.prior.dims),
            )

        self.fourier.apply(day_data, result_callback=create_deterministic)

        components_var = model[f"{prefix}_components"]
        component_dims = tuple(model.named_vars_to_dims[components_var.name])

        # Collapse the Fourier-mode axis; the remaining axes keep the order they
        # had in the components deterministic (i.e. the prior's dim order).
        prefix_axis = component_dims.index(prefix)
        collapsed = components_var.sum(axis=prefix_axis)
        collapsed_dims = tuple(
            dim for axis, dim in enumerate(component_dims) if axis != prefix_axis
        )

        foreign_dims = [
            dim
            for dim in collapsed_dims
            if dim != self.date_dim_name and dim not in mmm.dims
        ]
        if foreign_dims:
            raise ValueError(
                f"Fourier prior dims {foreign_dims} are not part of the MMM dims "
                f"{mmm.dims}."
            )

        # Label the contribution in MMM dim order. The prior may list its dims in
        # a different order, so reorder the axes explicitly instead of assuming
        # the orders agree.
        fourier_dims = (
            self.date_dim_name,
            *(dim for dim in mmm.dims if dim in collapsed_dims),
        )
        axis_order = [collapsed_dims.index(dim) for dim in fourier_dims]
        if axis_order != list(range(len(axis_order))):
            collapsed = collapsed.transpose(*axis_order)

        fourier_contribution = pm.Deterministic(
            f"{prefix}_contribution",
            collapsed,
            dims=fourier_dims,
        )

        # Broadcast to the full MMM dims ordering.
        dim_handler = create_dim_handler((self.date_dim_name, *mmm.dims))
        return dim_handler(fourier_contribution, fourier_dims)

    def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
        """Set the data for new predictions.

        The new dates are read from ``model.coords``; ``X`` is not used.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.
        model : pm.Model
            The PyMC model.
        X : xr.Dataset
            The dataset for prediction.
        """
        new_dates = pd.to_datetime(model.coords[self.date_dim_name])
        new_data = {
            f"{self.fourier.prefix}_day": self.fourier._get_days_in_period(
                new_dates
            ).to_numpy()
        }
        pm.set_data(new_data=new_data, model=model)


# Pydantic-compatible ``pd.Timestamp`` field type: inputs are converted with the
# ``pd.Timestamp`` constructor, and the JSON schema describes the value as a
# string in ``date-time`` format ("date-time" is a JSON Schema format, not a
# type, so the type must be "string").
_Timestamp = Annotated[
    pd.Timestamp,
    PlainValidator(lambda x: pd.Timestamp(x)),
    WithJsonSchema({"type": "string", "format": "date-time"}),
]


class LinearTrendEffect(MuEffect):
    """Wrapper for LinearTrend to use with MMM's MuEffect protocol.

    This class adapts the LinearTrend component to be used as an additive effect in the MMM model.

    Parameters
    ----------
    trend : LinearTrend
        The LinearTrend instance to wrap.
    prefix : str
        The prefix to use for variables in the model.
    date_dim_name : str
        The name of the date dimension in the model. Default is "date".

    Attributes
    ----------
    linear_trend_first_date : pd.Timestamp or None
        First model date, recorded by ``create_data``. ``set_data`` measures new
        dates as days elapsed since this date.

    Examples
    --------
    Out of sample predictions:

    .. note::

        No new changepoints are used for the out of sample predictions. The trend
        effect is linearly extrapolated from the last changepoint.

    .. plot::
        :include-source: True
        :context: reset

        import pandas as pd
        import numpy as np
        import matplotlib.pyplot as plt

        import pymc as pm

        from pymc_marketing.mmm.linear_trend import LinearTrend
        from pymc_marketing.mmm.additive_effect import LinearTrendEffect

        seed = sum(map(ord, "LinearTrend out of sample"))
        rng = np.random.default_rng(seed)


        class MockMMM:
            pass


        dates = pd.date_range("2025-01-01", periods=52, freq="W")
        coords = {"date": dates}
        model = pm.Model(coords=coords)

        mock_mmm = MockMMM()
        mock_mmm.dims = ()
        mock_mmm.model = model

        effect = LinearTrendEffect(
            trend=LinearTrend(n_changepoints=8),
            prefix="trend",
        )

        with mock_mmm.model:
            effect.create_data(mock_mmm)
            pm.Deterministic(
                "effect",
                effect.create_effect(mock_mmm),
                dims="date",
            )

            idata = pm.sample_prior_predictive(random_seed=rng)

        idata["posterior"] = idata.prior

        n_new = 10 + 1
        new_dates = pd.date_range(
            dates.max(),
            periods=n_new,
            freq="W",
        )
        with mock_mmm.model:
            mock_mmm.model.set_dim("date", n_new, new_dates)

            effect.set_data(mock_mmm, mock_mmm.model, None)

            pm.sample_posterior_predictive(
                idata,
                var_names=["effect"],
                random_seed=rng,
                extend_inferencedata=True,
            )

        draw = rng.choice(range(idata.posterior.sizes["draw"]))
        sel = dict(chain=0, draw=draw)

        before = idata.posterior.effect.sel(sel).to_series()
        after = idata.posterior_predictive.effect.sel(sel).to_series()

        ax = before.plot(color="C0")
        after.plot(color="C0", linestyle="dashed", ax=ax)
        plt.show()

    """

    trend: InstanceOf[LinearTrend]
    prefix: str
    date_dim_name: str = Field("date")
    linear_trend_first_date: _Timestamp | None = Field(None, init=False)

    def _days_since_first_date(self, dates: pd.DatetimeIndex):
        """Return float days elapsed between ``linear_trend_first_date`` and ``dates``."""
        return (dates - self.linear_trend_first_date).days.astype(float)

    def create_data(self, mmm: Model) -> None:
        """Create the required data in the model.

        Records the first model date and registers ``{prefix}_t``, the number
        of days elapsed since that date. The values are NOT scaled here; the
        scaling happens in ``create_effect``.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.
        """
        model: pm.Model = mmm.model
        dates = pd.to_datetime(model.coords[self.date_dim_name])
        # Remembered so that set_data can place new dates on the same time axis.
        self.linear_trend_first_date = dates[0]
        pm.Data(
            f"{self.prefix}_t",
            self._days_since_first_date(dates),
            dims=self.date_dim_name,
        )

    def create_effect(self, mmm: Model) -> pt.TensorVariable:
        """Create the trend effect in the model.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.

        Returns
        -------
        pt.TensorVariable
            The trend effect, broadcast to ``(date, *mmm.dims)``.
        """
        model: pm.Model = mmm.model
        t = model[f"{self.prefix}_t"]
        # t_max is evaluated to a plain number at build time, so data later
        # supplied through set_data is divided by the same constant. Dates past
        # the build-time range (assuming ascending dates) therefore map to
        # values above 1 and the trend is extrapolated.
        t_max = t.max().eval()
        t = t / t_max if t_max > 0 else t

        trend_effect = self.trend.apply(t)

        trend_dims = (self.date_dim_name, *self.trend.dims)  # type: ignore
        trend_non_broadcastable_dims = (
            self.date_dim_name,
            *self.trend.non_broadcastable_dims,
        )
        # create_index is expected to select the trend down to its
        # non-broadcastable dims before it is stored -- see
        # pymc_marketing.mmm.utils.create_index.
        trend_effect = pm.Deterministic(
            f"{self.prefix}_effect_contribution",
            trend_effect[create_index(trend_dims, trend_non_broadcastable_dims)],
            dims=trend_non_broadcastable_dims,
        )

        dim_handler = create_dim_handler((self.date_dim_name, *mmm.dims))
        return dim_handler(trend_effect, trend_non_broadcastable_dims)

    def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
        """Set the data for new predictions.

        The new dates are read from ``model.coords``; ``X`` is not used.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.
        model : pm.Model
            The PyMC model.
        X : xr.Dataset
            The dataset for prediction.

        Raises
        ------
        ValueError
            If ``create_data`` has not been called, so no first date is known.
        """
        if self.linear_trend_first_date is None:
            raise ValueError(
                "linear_trend_first_date is not set; call create_data before set_data."
            )
        new_dates = pd.to_datetime(model.coords[self.date_dim_name])
        # Raw days since the build-time first date; create_effect's graph applies
        # the build-time scaling.
        pm.set_data(
            {f"{self.prefix}_t": self._days_since_first_date(new_dates)},
            model=model,
        )


class EventAdditiveEffect(MuEffect):
    """Event effect class for the MMM.

    Parameters
    ----------
    df_events : pd.DataFrame
        The DataFrame containing the event data.

        * `name`: name of the event. Used as the model coordinates.
        * `start_date`: start date of the event
        * `end_date`: end date of the event

    prefix : str
        The prefix to use for the event effect and associated variables.
    effect : EventEffect
        The event effect to apply.
    reference_date : str
        The arbitrary reference date to calculate distance from events in days.
        Default is "2025-01-01". All ``EventAdditiveEffect`` instances added to
        the same model share one ``days`` data variable, so they must use the
        same reference date.
    date_dim_name : str
        The name of the date dimension in the model. Default is "date".
    """

    df_events: InstanceOf[pd.DataFrame]
    prefix: str
    effect: EventEffect
    reference_date: str = "2025-01-01"
    date_dim_name: str = "date"

    def model_post_init(self, context: Any, /) -> None:
        """Validate ``df_events`` and align the effect's basis prefix.

        Raises
        ------
        ValueError
            If any of the ``start_date``, ``end_date`` or ``name`` columns is
            missing from ``df_events``.
        """
        required_columns = {"start_date", "end_date", "name"}
        if missing_columns := required_columns.difference(self.df_events.columns):
            raise ValueError(f"Columns {missing_columns} are missing in df_events.")
        # NOTE(review): this mutates the passed EventEffect's basis in place; the
        # same EventEffect instance reused across effects would share one prefix.
        self.effect.basis.prefix = self.prefix

    @property
    def start_dates(self) -> pd.Series:
        """The start dates of the events."""
        return pd.to_datetime(self.df_events["start_date"])

    @property
    def end_dates(self) -> pd.Series:
        """The end dates of the events."""
        return pd.to_datetime(self.df_events["end_date"])

    def create_data(self, mmm: Model) -> None:
        """Create the required data in the model.

        Adds the event coordinate named ``prefix`` and the data variables
        ``{prefix}_start_diff`` / ``{prefix}_end_diff``. It also adds a ``days``
        variable shared by all event effects, or checks it against an existing one.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.

        Raises
        ------
        ValueError
            If the model already holds a ``days`` variable whose values differ
            from the ones computed with this effect's ``reference_date``.
        """
        model: pm.Model = mmm.model
        model_dates = pd.to_datetime(model.coords[self.date_dim_name])

        model.add_coord(self.prefix, self.df_events["name"].to_numpy())

        days = days_from_reference(model_dates, self.reference_date)
        if "days" not in model:
            pm.Data("days", days, dims=self.date_dim_name)
        else:
            # Event distances are computed as days - start_diff, which is only
            # correct when both terms use the same reference date. Refuse to
            # silently reuse a "days" variable built from another reference.
            get_existing = getattr(model["days"], "get_value", None)
            if get_existing is not None and not np.array_equal(get_existing(), days):
                raise ValueError(
                    "The model already contains a 'days' variable computed with a "
                    f"different reference date than {self.reference_date!r}. All "
                    "EventAdditiveEffect instances in one model must share the same "
                    "reference_date."
                )

        pm.Data(
            f"{self.prefix}_start_diff",
            days_from_reference(self.start_dates, self.reference_date),
            dims=self.prefix,
        )
        pm.Data(
            f"{self.prefix}_end_diff",
            days_from_reference(self.end_dates, self.reference_date),
            dims=self.prefix,
        )

    def create_effect(self, mmm: Model) -> pt.TensorVariable:
        """Create the event effect in the model.

        Parameters
        ----------
        mmm : MMM
            The MMM model instance.

        Returns
        -------
        pt.TensorVariable
            The total event effect (summed over events), broadcast to
            ``(date, *mmm.dims)``.
        """
        model: pm.Model = mmm.model

        # Pairwise (date, event) offsets from each event's start and end.
        start_ref = model["days"][:, None] - model[f"{self.prefix}_start_diff"]
        end_ref = model["days"][:, None] - model[f"{self.prefix}_end_diff"]

        def create_basis_matrix(start_ref, end_ref):
            # 0 while the date lies within [start, end]; otherwise the signed
            # offset to the nearer boundary (negative before the start, positive
            # after the end, assuming start_date <= end_date).
            return pt.where(
                (start_ref >= 0) & (end_ref <= 0),
                0,
                pt.where(pt.abs(start_ref) < pt.abs(end_ref), start_ref, end_ref),
            )

        X = create_basis_matrix(start_ref, end_ref)
        event_effect = self.effect.apply(X, name=self.prefix)

        # Assumes event_effect is laid out as (date, event) -- TODO confirm for
        # EventEffect instances with extra dims.
        total_effect = pm.Deterministic(
            f"{self.prefix}_total_effect",
            event_effect.sum(axis=1),
            dims=self.date_dim_name,
        )

        dim_handler = create_dim_handler((self.date_dim_name, *mmm.dims))
        return dim_handler(total_effect, self.date_dim_name)

    def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
        """Set the data for new predictions.

        Only ``days`` depends on the model dates; the event start/end offsets
        are fixed by ``df_events``. ``X`` is not used.
        """
        new_dates = pd.to_datetime(model.coords[self.date_dim_name])
        new_data = {
            "days": days_from_reference(new_dates, self.reference_date),
        }
        pm.set_data(new_data=new_data, model=model)