Shifted Beta-Geometric Model with Cohorts and Covariates#

The Shifted Beta-Geometric (sBG) model was first introduced in “How to Project Customer Retention” by Fader & Hardie in 2007. It is ideal for predicting customer behavior in business cases involving contract renewals or recurring subscriptions, and the original model has been expanded in PyMC-Marketing to support multidimensional cohorts and covariates. In this notebook we will reproduce the research results, then proceed to a comprehensive example with EDA and additional predictive methods.

Setup Notebook#

import arviz as az
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sb
import xarray as xr
from dateutil.relativedelta import relativedelta
from pymc_extras.prior import Prior

from pymc_marketing import clv

# Plotting configuration
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
plt.rcParams["figure.constrained_layout.use"] = True

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
seed = sum(map(ord, "sBG Model"))
rng = np.random.default_rng(seed)

Load Data#

Data must be aggregated in the following format for model fitting:

  • customer_id is an index of unique identifiers for each customer

  • recency indicates the most recent time period a customer was still active

  • T is the maximum observed time period for a given cohort

  • cohort indicates the cohort assignment for each customer

For active customers, recency is equal to T, and all customers in a given cohort share the same value for T. If a customer cancelled their contract and restarted at a later date, a new customer_id must be assigned for the restart.
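The format rules above can be checked before fitting. The following is a minimal sketch in plain Python: `validate_sbg_rows` is a hypothetical helper, not part of PyMC-Marketing, and it assumes each row is a dict with the four columns listed above and that recency starts at 1, as in the sample data.

```python
from collections import defaultdict


def validate_sbg_rows(rows):
    """Return a list of format problems found in customer-level sBG rows.

    Hypothetical helper for illustration; each row is a dict with the keys
    customer_id, recency, T, and cohort.
    """
    problems = []
    seen_ids = set()
    cohort_T = defaultdict(set)

    for row in rows:
        cid = row["customer_id"]
        if cid in seen_ids:
            problems.append(f"duplicate customer_id {cid}")
        seen_ids.add(cid)

        # recency counts active periods, so it is assumed to lie in 1..T
        if not 1 <= row["recency"] <= row["T"]:
            problems.append(f"customer {cid}: recency outside 1..T")

        cohort_T[row["cohort"]].add(row["T"])

    # Every customer in a cohort must share the same T
    for cohort, t_values in cohort_T.items():
        if len(t_values) > 1:
            problems.append(f"cohort {cohort}: multiple T values {sorted(t_values)}")

    return problems


rows = [
    {"customer_id": 1, "recency": 1, "T": 8, "cohort": "highend"},  # churned early
    {"customer_id": 2, "recency": 8, "T": 8, "cohort": "highend"},  # still active
    {"customer_id": 3, "recency": 9, "T": 8, "cohort": "regular"},  # invalid row
]
print(validate_sbg_rows(rows))
```

Returning a list of problems rather than raising on the first one makes it easier to clean a large extract in a single pass.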

Sample data is available in the PyMC-Marketing repo. To see the code used to generate this data, refer to generate_sbg_data() in scripts/clv_data_generation.py in the repo.

cohort_df = pd.read_csv(
    "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/refs/heads/main/data/sbg_cohorts.csv"
)
cohort_df
customer_id recency T cohort
0 1 1 8 highend
1 2 1 8 highend
2 3 1 8 highend
3 4 1 8 highend
4 5 1 8 highend
... ... ... ... ...
1995 1996 8 8 regular
1996 1997 8 8 regular
1997 1998 8 8 regular
1998 1999 8 8 regular
1999 2000 8 8 regular

2000 rows × 4 columns

This dataset was generated from the first 8 time periods in Table 1 of the research paper, which provides survival rates for two types of customers (“Regular” and “Highend”) over 13 time periods:

# Data from research paper
research_data = pd.DataFrame(
    {
        "regular": [
            100.0,
            63.1,
            46.8,
            38.2,
            32.6,
            28.9,
            26.2,
            24.1,
            22.3,
            20.7,
            19.4,
            18.3,
            17.3,
        ],
        "highend": [
            100.0,
            86.9,
            74.3,
            65.3,
            59.3,
            55.1,
            51.7,
            49.1,
            46.8,
            44.5,
            42.7,
            40.9,
            39.4,
        ],
    }
)
research_data
regular highend
0 100.0 100.0
1 63.1 86.9
2 46.8 74.3
3 38.2 65.3
4 32.6 59.3
5 28.9 55.1
6 26.2 51.7
7 24.1 49.1
8 22.3 46.8
9 20.7 44.5
10 19.4 42.7
11 18.3 40.9
12 17.3 39.4

This is also a useful format for model evaluation. In survival analysis parlance, customers with recency==T are “right-censored”. If we fit a model to the first \(8\) time periods, we can test predictions on censored data over the remaining \(5\).
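To make the censoring concrete, here is a sketch of how a survival table maps to customer-level rows. It assumes 1,000 customers per segment and that survivors at the end of the window are recorded with recency equal to T. `survival_to_recency` is a hypothetical function written for this illustration; the actual generator in the repo may differ.

```python
def survival_to_recency(survival_pct, n_customers=1000):
    """Expand survival percentages for periods 0..T-1 into recency values.

    Illustrative sketch only; not the generator used in the repo.
    """
    T = len(survival_pct)
    alive = [round(n_customers * pct / 100) for pct in survival_pct]

    recency = []
    for t in range(1, T):
        churned = alive[t - 1] - alive[t]
        # Customers lost between t-1 and t were last active in period t
        recency.extend([t] * churned)

    # Everyone still active at the end of the window is right-censored at T
    recency.extend([T] * alive[-1])
    return recency


# First 8 rows of the Highend column shown above
highend_pct = [100.0, 86.9, 74.3, 65.3, 59.3, 55.1, 51.7, 49.1]
recency = survival_to_recency(highend_pct)
n_censored = sum(r == len(highend_pct) for r in recency)
print(len(recency), recency.count(1), n_censored)
```

Under these assumptions the Highend segment has 131 customers with recency 1 and 491 censored customers with recency equal to T, whose true lifetimes are unknown beyond the training window.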

# Utility function to aggregate model fit data for evaluation
def survival_rate_aggregation(customer_df: pd.DataFrame):
    """Aggregate customer-level sBG data into survival rates by cohort over time."""
    # Group by cohort to get total counts
    cohorts = customer_df["cohort"].unique()

    # Create a list to store results for each time period
    results = []

    # For each time period from 0 to T - 1 (0 through 7 in this case)
    for t in range(customer_df["T"].max()):
        row_data = {"T": t}

        for cohort in cohorts:
            cohort_data = customer_df[customer_df["cohort"] == cohort]
            total_customers = len(cohort_data)

            if t == 0:
                # At time 0, 100% retention
                retention_pct = 100.0
            else:
                # Count customers still active after period t (recency > t)
                survived = len(cohort_data[cohort_data["recency"] > t])
                retention_pct = (survived / total_customers) * 100

            row_data[cohort] = retention_pct

        results.append(row_data)

    return pd.DataFrame(results)


# Aggregate model fit data
df_actual = survival_rate_aggregation(cohort_df)

# Assign T column to research data and truncate to 8 periods
research_data["T"] = research_data.index
df_expected = research_data[["T", "highend", "regular"]].query("T<8").copy()

# Assert aggregated dataset is equivalent to the values in the research paper
pd.testing.assert_frame_equal(df_actual, df_expected)
plt.plot(research_data["regular"].values, marker="o", label="Regular Segment")
plt.plot(research_data["highend"].values, marker="o", label="Highend Segment")
plt.ylabel("% Customers Surviving")
plt.xlabel("# of Time Periods")
# Add vertical line separating train/test periods
plt.axvline(7, ls=":", color="k", label="Train/Test Split")
plt.legend()

# Highlight train/test regions
plt.axvspan(0, 7, alpha=0.05, color="blue", zorder=0)

# Label training and test periods
plt.text(4, 1, "Train Period", ha="center", fontsize=10, color="darkblue")
plt.text(10, 1, "Test Period", ha="center", fontsize=10, color="darkred")

plt.title("Survival Rates over Time by Customer Segment");

Because all customers began their contracts in the same time period, segment is a more appropriate term than cohort, but we can model both with the same functionality. Let’s proceed to modeling.

Model Fitting#

The sBG model has the following assumptions:

  1. Individual customer lifetime durations are characterized by the (shifted) Geometric distribution, with cancellation probability \(\theta\).

  2. Heterogeneity in \(\theta\) follows a Beta distribution with shape parameters \(\alpha\) and \(\beta\).

If we take the expectation across the distribution of \(\theta\), we can derive a likelihood function to estimate parameters \(\alpha\) and \(\beta\) for the customer population. For more details on the ShiftedBetaGeometric mixture distribution, please refer to the documentation.
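Taking that expectation gives closed forms in terms of the Beta function: \(P(T=t) = B(\alpha+1, \beta+t-1)/B(\alpha, \beta)\) and \(S(t) = B(\alpha, \beta+t)/B(\alpha, \beta)\). The sketch below evaluates them with the standard library and builds a log-likelihood for aggregated counts. The function names are hypothetical, and the sketch assumes a censored customer (recency equal to T) contributes \(S(T-1)\), which is consistent with how the sample data encodes survivors. It is not the library's implementation.

```python
from math import exp, lgamma


def log_beta(a, b):
    """Log of the Beta function B(a, b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def sbg_log_pmf(t, alpha, beta):
    """log P(T = t): the customer's lifetime is exactly t periods, t = 1, 2, ..."""
    return log_beta(alpha + 1, beta + t - 1) - log_beta(alpha, beta)


def sbg_log_survival(t, alpha, beta):
    """log S(t) = log P(T > t), for t = 0, 1, 2, ..."""
    return log_beta(alpha, beta + t) - log_beta(alpha, beta)


def sbg_log_likelihood(churn_counts, n_censored, alpha, beta):
    """Log-likelihood of aggregated data for one cohort.

    churn_counts[i] is the number of customers with recency i + 1 (churned);
    n_censored customers have recency == T, where T = len(churn_counts) + 1.
    Assumption: each censored customer contributes S(T - 1).
    """
    T = len(churn_counts) + 1
    ll = sum(
        n * sbg_log_pmf(t, alpha, beta)
        for t, n in enumerate(churn_counts, start=1)
    )
    return ll + n_censored * sbg_log_survival(T - 1, alpha, beta)


# Regular segment counts implied by the survival table, assuming 1,000 customers
regular_churn = [369, 163, 86, 56, 37, 27, 21]
regular_censored = 241

ll_paper = sbg_log_likelihood(regular_churn, regular_censored, 0.704, 1.182)
ll_other = sbg_log_likelihood(regular_churn, regular_censored, 1.0, 1.0)
print(f"S(1) = {exp(sbg_log_survival(1, 0.704, 1.182)):.4f}")
print(f"log-likelihood at paper estimates: {ll_paper:.1f}, at (1, 1): {ll_other:.1f}")
```

Working on the log scale with `lgamma` avoids overflow in the Beta function when the number of periods grows.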

The original frequentist model assumes a single cohort of customers who all started their contracts in the same time period. This requires fitting a separate model for each cohort. However, in PyMC-Marketing we can fit all cohorts in a single hierarchical Bayesian model!

Here are the parameter estimates from the research paper for an sBG model fit with the provided data using Maximum Likelihood Estimation (MLE):

# MLE estimates from the paper
mle_research_parameters = {
    # "cohort": [alpha, beta]
    "regular": [0.704, 1.182],
    "highend": [0.668, 3.806],
}

Reproduce Research Results with Cohorts#

Model Fitting with MAP#

The Bayesian equivalent of a frequentist MLE fit is Maximum a Posteriori (MAP) with “flat” priors. A flat prior assigns equal weight to every admissible parameter value, so it is appropriate when the user holds no prior beliefs about the parameters. Since \(\alpha\) and \(\beta\) must be positive, let’s configure our Bayesian ShiftedBetaGeoModel with HalfFlat priors:

sbg_map = clv.ShiftedBetaGeoModel(
    data=cohort_df,
    model_config={
        "alpha": Prior("HalfFlat", dims="cohort"),
        "beta": Prior("HalfFlat", dims="cohort"),
    },
)
sbg_map.fit(method="map")
sbg_map.fit_summary()

alpha[highend]    0.668
alpha[regular]    0.704
beta[highend]     3.806
beta[regular]     1.182
Name: value, dtype: float64

MAP parameter estimates are identical to those in the research. We can also use these parameters to recover the latent \(\theta\) dropout distributions and recreate Figure 6 in the research paper:

Visualize Latent Dropout Distributions#

# Extract alpha and beta from fit results
alpha = sbg_map.fit_result["alpha"]
beta = sbg_map.fit_result["beta"]

# Specify number of draws from latent theta distributions
n_samples = 4_000

cohorts = alpha.coords["cohort"].values
dropout_samples = np.array(
    [
        rng.beta(
            alpha.sel(cohort=c).values.item(),  # Use .item() to get scalar
            beta.sel(cohort=c).values.item(),  # Use .item() to get scalar
            size=n_samples,
        )
        for c in cohorts
    ]
).T  # Transpose to get (samples, cohorts) shape

# Create xarray DataArray with chain, draw, and cohort dimensions
dropout = xr.DataArray(
    dropout_samples[np.newaxis, :, :],
    dims=("chain", "draw", "cohort"),
    coords={
        "chain": [0],
        "draw": np.arange(n_samples),
        "cohort": cohorts,
    },
    name=r"f($\theta$)",
)

# Convert to InferenceData for plotting in ArviZ
dropout_idata = az.convert_to_inference_data(dropout)

axes = az.plot_forest(
    [dropout_idata.sel(cohort="highend"), dropout_idata.sel(cohort="regular")],
    model_names=["Highend Segment", "Regular Segment"],
    kind="ridgeplot",
    combined=True,
    colors=["C0", "C1"],
    hdi_prob=1,
    ridgeplot_alpha=0.5,
    ridgeplot_overlap=1e5,
    ridgeplot_truncate=True,
    ridgeplot_quantiles=[0.5],
    figsize=(5, 5),
)
axes[0].set_title("Dropout Heterogeneity by Cohort");

It is evident from the plots that dropout probabilities skew lower for customers in the Highend segment, who have a median contract cancellation probability of about 10% for each renewal period. In comparison, the Regular segment median exceeds 30%.

Models can be fit very quickly with MAP, but there are important caveats to consider when recreating the frequentist MLE approach:

  • Flat priors are slow to converge, and can be unstable in practice because no regularization is applied to model parameters during fitting.

  • MAP/MLE fits do not provide credibility intervals for predictions.

In the Bayesian paradigm, prior distributions are derived from business domain knowledge and experimentation. With priors we can regularize parameter estimates, reduce model fit times, and with MCMC sampling infer the posterior distributions, illustrating uncertainty in our parameter estimates as well as enabling prediction intervals.

Model Fitting with MCMC#

The default sampler in PyMC-Marketing is the No-U-Turn Sampler (NUTS), which uses gradients of the log-posterior to explore the parameter space efficiently. The default prior configuration for ShiftedBetaGeoModel works well for many use cases:

sbg_mcmc = clv.ShiftedBetaGeoModel(data=cohort_df)
sbg_mcmc.build_model()  # optional step to view model config and run prior predictive checks
sbg_mcmc.graphviz()
[Figure: Graphviz rendering of the model structure]

The default kappa and phi priors are pooling distributions that improve the speed & reliability of model fits, but can be omitted by specifying custom alpha and beta priors in the model_config. Note the number of cohorts and customers are also listed, and the dropout \(\theta\) distribution is censored just like the dataset.

Let’s fit an MCMC model and visualize the results:

sbg_mcmc.fit(method="mcmc", random_seed=rng)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [phi, kappa]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.
arviz.InferenceData
    • <xarray.Dataset> Size: 264kB
      Dimensions:  (chain: 4, draw: 1000, cohort: 2)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * cohort   (cohort) <U7 56B 'highend' 'regular'
      Data variables:
          phi      (chain, draw, cohort) float64 64kB 0.1377 0.3894 ... 0.1613 0.3723
          kappa    (chain, draw, cohort) float64 64kB 5.257 1.867 ... 3.852 1.749
          alpha    (chain, draw, cohort) float64 64kB 0.7238 0.7268 ... 0.6213 0.6509
          beta     (chain, draw, cohort) float64 64kB 4.533 1.14 3.895 ... 3.231 1.098
      Attributes:
          created_at:                 2025-12-16T11:10:26.842039+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1
          sampling_time:              6.430373191833496
          tuning_steps:               1000

    • <xarray.Dataset> Size: 528kB
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables: (12/18)
          lp                     (chain, draw) float64 32kB -3.3e+03 ... -3.299e+03
          perf_counter_start     (chain, draw) float64 32kB 1.659e+05 ... 1.659e+05
          perf_counter_diff      (chain, draw) float64 32kB 0.001657 ... 0.003421
          acceptance_rate        (chain, draw) float64 32kB 0.9945 0.9736 ... 0.9995
          n_steps                (chain, draw) float64 32kB 3.0 7.0 3.0 ... 3.0 7.0
          divergences            (chain, draw) int64 32kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          ...                     ...
          step_size_bar          (chain, draw) float64 32kB 0.6497 0.6497 ... 0.5986
          largest_eigval         (chain, draw) float64 32kB nan nan nan ... nan nan
          index_in_trajectory    (chain, draw) int64 32kB -3 3 -2 3 -4 ... 3 -2 -1 2 5
          reached_max_treedepth  (chain, draw) bool 4kB False False ... False False
          energy                 (chain, draw) float64 32kB 3.301e+03 ... 3.301e+03
          smallest_eigval        (chain, draw) float64 32kB nan nan nan ... nan nan
      Attributes:
          created_at:                 2025-12-16T11:10:26.851747+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1
          sampling_time:              6.430373191833496
          tuning_steps:               1000

    • <xarray.Dataset> Size: 32kB
      Dimensions:      (customer_id: 2000)
      Coordinates:
        * customer_id  (customer_id) int64 16kB 1 2 3 4 5 ... 1996 1997 1998 1999 2000
      Data variables:
          dropout      (customer_id) float64 16kB 1.0 1.0 1.0 1.0 ... 8.0 8.0 8.0 8.0
      Attributes:
          created_at:                 2025-12-16T11:10:26.854134+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

    • <xarray.Dataset> Size: 80kB
      Dimensions:      (index: 2000)
      Coordinates:
        * index        (index) int64 16kB 0 1 2 3 4 5 ... 1995 1996 1997 1998 1999
      Data variables:
          customer_id  (index) int64 16kB 1 2 3 4 5 6 ... 1996 1997 1998 1999 2000
          recency      (index) int64 16kB 1 1 1 1 1 1 1 1 1 1 ... 8 8 8 8 8 8 8 8 8 8
          T            (index) int64 16kB 8 8 8 8 8 8 8 8 8 8 ... 8 8 8 8 8 8 8 8 8 8
          cohort       (index) object 16kB 'highend' 'highend' ... 'regular' 'regular'

sbg_mcmc.fit_summary(var_names=["alpha", "beta"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha[highend] 0.670 0.112 0.477 0.874 0.002 0.003 2305.0 2512.0 1.0
alpha[regular] 0.703 0.065 0.591 0.829 0.001 0.001 3235.0 3186.0 1.0
beta[highend] 3.821 0.867 2.313 5.370 0.020 0.024 1897.0 2300.0 1.0
beta[regular] 1.181 0.152 0.904 1.464 0.003 0.002 2480.0 2659.0 1.0
az.plot_trace(sbg_mcmc.idata, var_names=["alpha", "beta"]);

Fit summaries and trace plots look good. Let’s compare the fitted posterior distributions to the scalar parameter estimates from the research:

for var_name in mle_research_parameters.keys():
    az.plot_posterior(
        sbg_mcmc.idata.sel(cohort=var_name),
        var_names=["alpha", "beta"],
        ref_val=mle_research_parameters[var_name],
        label="MCMC",
    )
    plt.gcf().suptitle(
        f"{var_name.upper()} Cohort Parameter Estimates", fontsize=18, fontweight="bold"
    )

Fitted posterior mean values align with the MLE values described in the research paper! MCMC sampling also gives us useful information about the uncertainty of the fits. Note how the mean values are within the 94% HDIs but not perfectly centered, indicating the posteriors are asymmetrical.
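To see why an off-center mean signals asymmetry, the sketch below computes a highest density interval as the narrowest window containing 94% of the sorted draws. This is a common approach for unimodal samples. The `hdi` function is a simplified stand-in written for this illustration, and the Gamma draws are a hypothetical right-skewed posterior, not output from the fitted model.

```python
import random


def hdi(samples, prob=0.94):
    """Narrowest interval containing `prob` of the samples (unimodal case)."""
    ordered = sorted(samples)
    n = len(ordered)
    n_in = int(round(prob * n))  # number of samples inside the interval
    # Slide a window of n_in samples and keep the narrowest one
    widths = [ordered[i + n_in - 1] - ordered[i] for i in range(n - n_in + 1)]
    start = widths.index(min(widths))
    return ordered[start], ordered[start + n_in - 1]


py_rng = random.Random(0)
# Right-skewed stand-in for a posterior; not taken from the fitted model
draws = [py_rng.gammavariate(2.0, 1.0) for _ in range(4000)]
mean = sum(draws) / len(draws)
low, high = hdi(draws)

print(f"mean = {mean:.2f}, 94% HDI = ({low:.2f}, {high:.2f})")
print(f"distance below mean = {mean - low:.2f}, above mean = {high - mean:.2f}")
```

For a symmetric posterior the two distances would be roughly equal. A longer upper distance indicates a right-skewed distribution.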

Model Evaluation with Cohorts#

Recall the model was fit to the first \(8\) time periods and the remaining \(5\) were withheld for testing. The survival_rate_aggregation() function can be used to summarize the full dataset for model evaluation.

# Aggregating model fit data to get the training period endpoint.
# (This would normally be run on the full dataset of train/test data, which is already aggregated in this case.)
df_fit_eval = survival_rate_aggregation(cohort_df)

# get T values of test and training data
test_T = len(research_data)
train_T = len(df_fit_eval)

# Create T time period array to run predictions on both train and test time periods
T_eval_range = np.arange(train_T * -1, test_T - train_T, 1)
T_eval_range
array([-8, -7, -6, -5, -4, -3, -2, -1,  0,  1,  2,  3,  4])

We are also running predictions on training (shown as negative) time periods to get the full context.

Survival Function#

The sBG survival function is the probability that customers within a given cohort are still active after a specified time period. It is called with ShiftedBetaGeoModel.expected_probability_alive():

expected_survival_rates = xr.concat(
    objs=[
        sbg_mcmc.expected_probability_alive(
            future_t=T,
        )
        for T in T_eval_range
    ],
    dim="T",
).transpose(..., "T")

Plotting credibility intervals for survival rates against observed data:

# plot predictions
for i, segment in enumerate(["regular", "highend"]):
    az.plot_hdi(
        range(test_T),
        expected_survival_rates.sel(cohort=segment).mean("cohort"),
        hdi_prob=0.94,
        color=f"C{i}",
        fill_kwargs={"alpha": 0.5, "label": f"{segment.capitalize()} Segment"},
    )

# plot observed
plt.plot(
    range(test_T),
    research_data["highend"] / 100,
    marker="o",
    color="k",
    label="Observed",
)
plt.plot(range(test_T), research_data["regular"] / 100, marker="o", color="k")

plt.ylabel("% Surviving Customers")
plt.xlabel("Time Periods")
plt.axvline(7, ls=":", color="k", label="Test data starting period")
plt.legend()
plt.suptitle("Survival Rates over Time by Segment", fontsize=18)
plt.title("94% HDI Intervals", fontsize=12);

Observed survival rates fall well within the 94% credibility intervals!
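As a rough cross-check that ignores posterior uncertainty, the survival curve can also be computed directly from the paper's point estimates. The sketch below uses the sBG retention rate \(r(t) = (\beta + t - 1)/(\alpha + \beta + t - 1)\) and builds \(S(t)\) as a cumulative product of those rates. `sbg_survival_curve` is a hypothetical helper for this illustration, not a library method.

```python
def sbg_survival_curve(alpha, beta, n_periods):
    """Return [S(0), S(1), ..., S(n_periods)] for the sBG model."""
    curve = [1.0]
    for t in range(1, n_periods + 1):
        retention = (beta + t - 1) / (alpha + beta + t - 1)
        curve.append(curve[-1] * retention)
    return curve


# MLE estimates and observed survival fractions quoted in this notebook
mle = {"regular": (0.704, 1.182), "highend": (0.668, 3.806)}
observed = {
    "regular": [1.0, 0.631, 0.468, 0.382, 0.326, 0.289, 0.262,
                0.241, 0.223, 0.207, 0.194, 0.183, 0.173],
    "highend": [1.0, 0.869, 0.743, 0.653, 0.593, 0.551, 0.517,
                0.491, 0.468, 0.445, 0.427, 0.409, 0.394],
}

max_holdout_error = {}
for segment, (alpha, beta) in mle.items():
    predicted = sbg_survival_curve(alpha, beta, n_periods=12)
    # Periods 8 through 12 were withheld from fitting
    errors = [abs(p - o) for p, o in zip(predicted[8:], observed[segment][8:])]
    max_holdout_error[segment] = max(errors)
    print(f"{segment}: max absolute holdout error = {max(errors):.3f}")
```

Point estimates give a single curve per segment, whereas the MCMC fit above yields a full interval around each projected survival rate.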

Retention Rate#

We can also predict the retention rate by cohort, which is defined as the proportion of customers active in period \(T-1\) who are still active in period \(T\):

# Run retention rate predictions
expected_retention_rates = xr.concat(
    objs=[
        sbg_mcmc.expected_retention_rate(
            future_t=T,
        )
        for T in T_eval_range[1:]  # omit starting time period (see below)
    ],
    dim="T",
).transpose(..., "T")

# Calculate observed retention rates by cohort.
# Initial start period does not have a retention rate, so retention array is 1 time period shorter than observed.
retention_rate_highend_obs = (
    research_data["highend"][1:].values / research_data["highend"][:-1].values
)
retention_rate_regular_obs = (
    research_data["regular"][1:].values / research_data["regular"][:-1].values
)
# Plot predictions
for i, segment in enumerate(["regular", "highend"]):
    az.plot_hdi(
        range(1, test_T),
        expected_retention_rates.sel(cohort=segment).mean("cohort"),
        hdi_prob=0.94,
        color=f"C{i}",
        fill_kwargs={"alpha": 0.5, "label": f"{segment.capitalize()} Segment"},
    )

# Plot observed
plt.plot(
    range(1, test_T),
    retention_rate_highend_obs,
    marker="o",
    color="k",
    label="Observed",
)
plt.plot(range(1, test_T), retention_rate_regular_obs, marker="o", color="k")

plt.ylabel("Retention Rate")
plt.xlabel("Time Periods")
plt.axvline(7, ls=":", color="k", label="Test data starting period")
plt.legend()
plt.suptitle("Retention Rates over Time by Segment", fontsize=18)
plt.title("94% HDI Intervals", fontsize=12);

Retention rate predictions fall within the 94% credibility intervals for the Regular customers. Highend customers are more erratic, but stabilize over time.

The plots also highlight an interesting implication from the model: Retention rates increase over time as high-risk customers drop out. This is a direct outcome of heterogeneity among customers. Heterogeneity within cohorts can be modeled further with the addition of covariates.
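This sorting effect can be reproduced with a small simulation. In the sketch below every customer has a constant churn probability drawn from a Beta distribution, yet the aggregate retention rate still rises over time because high-risk customers leave first. The helper name, the seed, and the population size are arbitrary choices for illustration; the shape parameters are the paper's Regular segment estimates.

```python
import random


def simulate_retention(alpha, beta, n_customers, horizon, seed=0):
    """Simulate heterogeneous churn and return retention rates r(1..horizon)."""
    sim_rng = random.Random(seed)
    alive = [0] * (horizon + 1)  # alive[t] = customers with lifetime > t

    for _ in range(n_customers):
        theta = sim_rng.betavariate(alpha, beta)  # this customer's churn probability
        lifetime = 1
        # At each renewal the customer churns with probability theta
        while lifetime <= horizon and sim_rng.random() >= theta:
            lifetime += 1
        for t in range(lifetime):
            alive[t] += 1

    return [alive[t] / alive[t - 1] for t in range(1, horizon + 1)]


alpha, beta = 0.704, 1.182
simulated = simulate_retention(alpha, beta, n_customers=20_000, horizon=8)
closed_form = [(beta + t - 1) / (alpha + beta + t - 1) for t in range(1, 9)]

for t, (sim, exact) in enumerate(zip(simulated, closed_form), start=1):
    print(f"t={t}: simulated={sim:.3f}, closed form={exact:.3f}")
```

No individual customer becomes more loyal in this simulation. The increase comes entirely from the changing mix of customers who remain.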

Reproduce Research Results with Covariates#

Two segments starting in the same time period can also be represented as a binary covariate. Let’s modify the fit data to do so:

# Create a covariate column to identify highend customers
cohort_df["highend_customer"] = np.where(cohort_df["cohort"] == "highend", 1, 0)
# Update cohort column to a single "population" cohort
covariate_df = cohort_df.assign(cohort="population")
covariate_df
customer_id recency T cohort highend_customer
0 1 1 8 population 1
1 2 1 8 population 1
2 3 1 8 population 1
3 4 1 8 population 1
4 5 1 8 population 1
... ... ... ... ... ...
1995 1996 8 8 population 0
1996 1997 8 8 population 0
1997 1998 8 8 population 0
1998 1999 8 8 population 0
1999 2000 8 8 population 0

2000 rows × 5 columns

Recall \(\alpha\) and \(\beta\) represent the shape parameters of the latent Beta dropout distribution for each cohort. To include time-invariant covariates in our model, we simply modify these parameters as follows:

\[\alpha = \alpha_0e^{-\gamma_1z}\]
\[\beta = \beta_0e^{-\gamma_2z}\]

Where \(\gamma_1\) and \(\gamma_2\) are coefficients capturing the effects of the \(z\) covariate arrays for each customer. Covariates can be one-hot encoded for discrete factors like segment or marketing channel, as well as continuous for variables like user ratings or contract renewal costs.

These additional parameters are automatically created when covariate column names are specified in the model_config:

sbg_covar = clv.ShiftedBetaGeoModel(
    data=covariate_df,
    model_config={
        "dropout_covariate_cols": ["highend_customer"],
    },
)
sbg_covar.build_model()
sbg_covar
Shifted Beta-Geometric
                      phi ~ Uniform(0, 1)
                    kappa ~ Pareto(1, 1)
dropout_coefficient_alpha ~ Normal(0, 1)
 dropout_coefficient_beta ~ Normal(0, 1)
              alpha_scale ~ Deterministic(f(kappa, phi))
               beta_scale ~ Deterministic(f(kappa, phi))
                    alpha ~ Deterministic(f(dropout_coefficient_alpha, kappa, phi))
                     beta ~ Deterministic(f(dropout_coefficient_beta, kappa, phi))
                  dropout ~ Censored(ShiftedBetaGeometric(alpha, beta), -inf, <constant>)

Model Fitting with DEMZ#

MCMC model fitting takes longer with covariates than with cohorts. A gradient-free sampler like Adaptive Differential Evolution Metropolis (DEMetropolisZ) can finish sooner than NUTS because each step is cheaper, but it requires more tuning steps and draws to converge.

sbg_covar.fit(
    method="demz", tune=2000, draws=3000, random_seed=rng
)  # 'demz' requires more tune/draws for convergence

# After fitting, remove redundant samples to reduce model size and increase predictive method processing speed
sbg_covar = sbg_covar.thin_fit_result(keep_every=3)
Multiprocess sampling (4 chains in 4 jobs)
DEMetropolisZ: [phi, kappa, dropout_coefficient_alpha, dropout_coefficient_beta]

Sampling 4 chains for 2_000 tune and 3_000 draw iterations (8_000 + 12_000 draws total) took 3 seconds.
az.plot_trace(
    sbg_covar.idata,
    var_names=[
        "alpha_scale",
        "beta_scale",
        "dropout_coefficient_alpha",
        "dropout_coefficient_beta",
    ],
);

alpha_scale and beta_scale are baseline posteriors for each cohort, and the dropout_coefficient posteriors represent covariate effects across the entire customer population.

In this simple example with a single cohort (“population”) and a binary covariate (“highend_customer”), we can see the mean impact of this covariate on alpha is near zero, but the impact on beta is much larger in magnitude.

Recall the mean of the Beta distribution:

\[E(\theta) = \frac{\alpha}{\alpha + \beta}\]

The negative dropout_coefficient_beta posterior indicates beta is larger for highend customers, which reduces the expected dropout probability.
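A quick numeric illustration of this effect, using the covariate equations \(\alpha = \alpha_0e^{-\gamma_1z}\) and \(\beta = \beta_0e^{-\gamma_2z}\) from earlier in this section. The baseline values and coefficients below are hypothetical numbers chosen for illustration, not estimates from the fitted model.

```python
from math import exp


def covariate_adjusted_params(alpha_0, beta_0, gamma_1, gamma_2, z):
    """Apply alpha = alpha_0 * exp(-gamma_1 * z) and beta = beta_0 * exp(-gamma_2 * z)."""
    return alpha_0 * exp(-gamma_1 * z), beta_0 * exp(-gamma_2 * z)


def expected_dropout(alpha, beta):
    """Mean of the Beta(alpha, beta) dropout distribution."""
    return alpha / (alpha + beta)


# Hypothetical baseline values and coefficients, chosen only for illustration
alpha_0, beta_0 = 0.7, 1.2
gamma_1, gamma_2 = 0.0, -1.15

for z, label in [(0, "regular"), (1, "highend")]:
    alpha, beta = covariate_adjusted_params(alpha_0, beta_0, gamma_1, gamma_2, z)
    mean_theta = expected_dropout(alpha, beta)
    print(f"{label}: alpha={alpha:.3f}, beta={beta:.3f}, E[theta]={mean_theta:.3f}")
```

With z = 0 the baseline parameters are returned unchanged. With z = 1 and a negative beta coefficient, beta grows while alpha stays fixed, so the expected dropout probability falls.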

Cohort and covariate parameters are broadcast together to create alpha and beta parameter pairs for every customer_id:

sbg_covar.fit_summary(var_names=["alpha", "beta"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha[1] 0.671 0.107 0.481 0.872 0.004 0.003 863.0 1106.0 1.01
alpha[2] 0.671 0.107 0.481 0.872 0.004 0.003 863.0 1106.0 1.01
alpha[3] 0.671 0.107 0.481 0.872 0.004 0.003 863.0 1106.0 1.01
alpha[4] 0.671 0.107 0.481 0.872 0.004 0.003 863.0 1106.0 1.01
alpha[5] 0.671 0.107 0.481 0.872 0.004 0.003 863.0 1106.0 1.01
... ... ... ... ... ... ... ... ... ...
beta[1996] 1.197 0.155 0.921 1.484 0.005 0.005 894.0 1096.0 1.00
beta[1997] 1.197 0.155 0.921 1.484 0.005 0.005 894.0 1096.0 1.00
beta[1998] 1.197 0.155 0.921 1.484 0.005 0.005 894.0 1096.0 1.00
beta[1999] 1.197 0.155 0.921 1.484 0.005 0.005 894.0 1096.0 1.00
beta[2000] 1.197 0.155 0.921 1.484 0.005 0.005 894.0 1096.0 1.00

4000 rows × 9 columns

Reproduce Research with Covariates#

To compare covariate alpha and beta estimates to the original research, we can use customer_id to extract parameters for known highend and regular customers:

covariate_customer_ids = {
    "highend": 1,
    "regular": 2000,
}

for cohort, customer_id in covariate_customer_ids.items():
    az.plot_posterior(
        sbg_covar.idata.sel(customer_id=customer_id),
        var_names=["alpha", "beta"],
        ref_val=mle_research_parameters[cohort],
        label="MCMC",
    )
    plt.gcf().suptitle(
        f"{cohort.upper()} Customer Parameter Estimates", fontsize=18, fontweight="bold"
    )

Parameter estimates are equivalent regardless of whether segments are specified by cohort or covariate!

End-to-End Example with Cohorts and Covariates#

Simulate Data#

Let’s expand the previous covariate dataframe to create a monthly cohort dataset that also includes the same covariates:

# Use covariate_df to generate 7 monthly cohorts
cohort_dfs = []

for month in range(1, 8):
    # Calculate observation period: January (month 1) has 8 periods, July (month 7) has 2 periods
    observation_periods = 9 - month  # 8 for month 1, 7 for month 2, ..., 2 for month 7
    month_cohort_name = f"2025-{month:02d}"

    # Copy the covariate data
    monthly_cohort_df = covariate_df.copy()

    # Truncate recency to the observation period
    # If a customer churned after the observation period, they appear alive at T
    monthly_cohort_df["recency"] = monthly_cohort_df["recency"].clip(
        upper=observation_periods
    )

    # Update T to the correct observation period for this cohort
    monthly_cohort_df["T"] = observation_periods

    # Add the time-based cohort (month they joined)
    monthly_cohort_df["cohort"] = month_cohort_name

    cohort_dfs.append(monthly_cohort_df)

# Combine all monthly cohorts
monthly_cohort_dataset = pd.concat(cohort_dfs, ignore_index=True)

# Recreate customer_id to be unique across all cohorts
monthly_cohort_dataset["customer_id"] = monthly_cohort_dataset.index + 1

# Reorder columns
monthly_cohort_dataset = monthly_cohort_dataset[
    [
        "customer_id",
        "recency",
        "T",
        "cohort",
        "highend_customer",
    ]
]

print(f"\nMonthly cohorts: {sorted(monthly_cohort_dataset['cohort'].unique())}")

monthly_cohort_dataset
Monthly cohorts: ['2025-01', '2025-02', '2025-03', '2025-04', '2025-05', '2025-06', '2025-07']
customer_id recency T cohort highend_customer
0 1 1 8 2025-01 1
1 2 1 8 2025-01 1
2 3 1 8 2025-01 1
3 4 1 8 2025-01 1
4 5 1 8 2025-01 1
... ... ... ... ... ...
13995 13996 2 2 2025-07 0
13996 13997 2 2 2025-07 0
13997 13998 2 2 2025-07 0
13998 13999 2 2 2025-07 0
13999 14000 2 2 2025-07 0

14000 rows × 5 columns

A cohort chart is useful to see how cohorts compare in size and retention, revealing seasonality and growth trends:

def pivot_sbg_cohort_data(customer_df: pd.DataFrame) -> pd.DataFrame:
    """Create a cohort chart from ShiftedBetaGeoModel modeling data.

    Transform modeling data into an upper triangular cohort chart
    with labeled time periods.

    Parameters
    ----------
    customer_df : pd.DataFrame
        DataFrame with columns: customer_id, recency, T, cohort
        Cohort should be in YYYY-MM format

    Returns
    -------
    pd.DataFrame
        Pivoted DataFrame with cohorts as index, calendar periods as columns

    """
    results = []

    for cohort in customer_df["cohort"].unique():
        cohort_df = customer_df[customer_df["cohort"] == cohort]
        max_t_cohort = cohort_df["T"].iloc[0]

        # Parse cohort date
        cohort_date = pd.to_datetime(cohort)

        # Calculate retention for each relative period (age)
        for age in range(max_t_cohort):
            # Calculate absolute calendar period (cohort_date + age months)
            period = cohort_date + relativedelta(months=age)
            period_str = period.strftime("%Y-%m")

            # Calculate number of surviving customers
            survived = (cohort_df["recency"] > age).sum().astype(np.int64)

            results.append(
                {
                    "cohort": cohort,
                    "period": period_str,
                    "cohort_age": age,
                    "retention": survived,
                }
            )

    # Pivot to get cohorts as rows, calendar periods as columns
    pivot_df = pd.DataFrame(results).pivot(
        index="cohort", columns="period", values="retention"
    )

    # Sort index (cohorts) and columns (periods) chronologically
    pivot_df = pivot_df.sort_index()
    pivot_df = pivot_df[sorted(pivot_df.columns)]

    return pivot_df


# Create pivoted data
cohort_pivot = pivot_sbg_cohort_data(monthly_cohort_dataset)

plt.rcParams["figure.constrained_layout.use"] = False

# Plot cohort chart as a heatmap
fig, ax = plt.subplots(figsize=(17, 9))

sb.heatmap(
    cohort_pivot,
    cmap="viridis_r",
    linewidths=0.2,
    linecolor="black",
    annot=True,
    fmt=",.0f",
    cbar_kws={"format": mtick.FuncFormatter(func=lambda y, _: f"{y:,.0f}")},
    ax=ax,
)

# Rotate y-axis labels to horizontal
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

ax.set_title("Cohort Customer Counts by Time Period")
ax.set_ylabel("Cohort")
ax.set_xlabel("Time Period")

plt.tight_layout()
plt.show()

In this example there are no external time-varying effects on retention because we simply shifted the survival curve one period forward for each cohort. In practice there will always be seasonality and events like holidays influencing retention. Since we are estimating a parameter set for each starting time period, we can control for time-varying factors!

# Base survival curve (same pattern for all cohorts)
base_survival = research_data[["regular", "highend"]].values.mean(axis=1) / 100

# Create figure
fig, ax = plt.subplots(figsize=(12, 7))

# Define cohort start dates
cohort_start_dates = sorted(monthly_cohort_dataset["cohort"].unique())

# Plot each cohort
for i, cohort_date in enumerate(cohort_start_dates):
    # Calculate how many periods this cohort has been observed
    # Later cohorts have fewer observed periods
    observed_periods = len(base_survival) - i

    # Get survival data for this cohort
    cohort_survival = base_survival[:observed_periods]

    # Create timeline for this cohort (absolute dates)
    cohort_start = pd.to_datetime(cohort_date)
    cohort_timeline = [
        cohort_start + relativedelta(months=j) for j in range(observed_periods)
    ]

    # Plot the survival curve
    ax.plot(
        cohort_timeline,
        cohort_survival,
        linewidth=2.5,
        marker="o",
        markersize=3,
        label=f"Cohort {i + 1} ({cohort_date})",
    )

# Add vertical line separating train/test periods
train_end_date = pd.to_datetime("2025-08")
ax.axvline(
    train_end_date,
    linestyle=":",
    color="red",
    linewidth=2,
    alpha=0.7,
    label="Train/Test Split",
)

# Add shaded regions for train and test periods
ax.axvspan(
    pd.to_datetime("2025-01"), train_end_date, alpha=0.05, color="blue", zorder=0
)
ax.text(
    pd.to_datetime("2025-04"),
    0.05,
    "Train Period",
    ha="center",
    fontsize=10,
    color="darkblue",
)

# Add test period label
ax.text(
    pd.to_datetime("2025-10"),
    0.05,
    "Test Period",
    ha="center",
    fontsize=10,
    color="darkred",
)

# Formatting
ax.set_xlabel("Timeline", fontsize=12, fontweight="bold")
ax.set_ylabel("Survival Percentage %", fontsize=12, fontweight="bold")
ax.set_title("Cohort Survival Curves", fontsize=14, fontweight="bold", color="red")

# Set y-axis limits
ax.set_ylim(0.25, 1.1)

# Format y-axis tick labels as percentages
ax.set_yticks([0.25, 0.5, 0.75, 1.0])
ax.set_yticklabels(["25%", "50%", "75%", "100%"])

# Format x-axis dates
ax.tick_params(axis="x", rotation=45)
fig.autofmt_xdate()

# Add legend
ax.legend(loc="upper right", framealpha=0.9)

# Add grid
ax.grid(True, alpha=0.3, linestyle="--")

plt.tight_layout()
plt.show()

These survival rates are the average of the Highend and Regular customer segments, which are indicated by covariates.

Cohort + Covariate Model Fitting and Diagnostics#

DEMetropolisZ works well for covariates, but convergence is more difficult when dealing with multiple cohorts. For large datasets with millions of customers and dozens of cohorts, MAP fits may be the more practical choice. Regardless of the fit method used, multidimensional models often require well-defined priors: priors that are too broad may lead to divergences, while priors that are too narrow can overfit. For more information on prior elicitation see the preliz library, as well as the Prior Configuration section of the Model Configuration Notebook.

For this last example we will use the default model config with the external nutpie sampler, which is written in Rust and can run on GPUs.

sbg_cohort = clv.ShiftedBetaGeoModel(
    data=monthly_cohort_dataset,
    model_config={
        "dropout_covariate_cols": ["highend_customer"],
    },
)
sbg_cohort.fit(method="mcmc", nuts_sampler="nutpie", random_seed=rng)

Sampler Progress

Total Chains: 4

Active Chains: 0

Finished Chains: 4

Sampling for 2 minutes

Estimated Time to Completion: now

Draws  Divergences  Step Size  Gradients/Draw
2000   0            0.45       15
2000   0            0.44       7
2000   0            0.42       15
2000   0            0.43       15
arviz.InferenceData
    • <xarray.Dataset> Size: 898MB
      Dimensions:                    (chain: 4, draw: 1000, phi_interval___dim_0: 7,
                                      kappa_interval___dim_0: 7,
                                      dropout_covariate: 1, cohort: 7,
                                      customer_id: 14000)
      Coordinates:
        * chain                      (chain) int64 32B 0 1 2 3
        * draw                       (draw) int64 8kB 0 1 2 3 4 ... 996 997 998 999
        * phi_interval___dim_0       (phi_interval___dim_0) int64 56B 0 1 2 3 4 5 6
        * kappa_interval___dim_0     (kappa_interval___dim_0) int64 56B 0 1 2 3 4 5 6
        * dropout_covariate          (dropout_covariate) object 8B 'highend_customer'
        * cohort                     (cohort) object 56B '2025-01' ... '2025-07'
        * customer_id                (customer_id) int64 112kB 1 2 3 ... 13999 14000
      Data variables:
          phi_interval__             (chain, draw, phi_interval___dim_0) float64 224kB ...
          kappa_interval__           (chain, draw, kappa_interval___dim_0) float64 224kB ...
          dropout_coefficient_alpha  (chain, draw, dropout_covariate) float64 32kB ...
          dropout_coefficient_beta   (chain, draw, dropout_covariate) float64 32kB ...
          phi                        (chain, draw, cohort) float64 224kB 0.3788 ......
          kappa                      (chain, draw, cohort) float64 224kB 1.828 ... ...
          alpha_scale                (chain, draw, cohort) float64 224kB 0.6925 ......
          beta_scale                 (chain, draw, cohort) float64 224kB 1.135 ... ...
          alpha                      (chain, draw, customer_id) float64 448MB 0.834...
          beta                       (chain, draw, customer_id) float64 448MB 4.792...
      Attributes:
          created_at:                 2025-12-16T10:36:01.660031+00:00
          arviz_version:              0.22.0
          inference_library:          nutpie
          inference_library_version:  0.15.2
          sampling_time:              108.42676997184753
          tuning_steps:               1000

    • <xarray.Dataset> Size: 336kB
      Dimensions:               (chain: 4, draw: 1000)
      Coordinates:
        * chain                 (chain) int64 32B 0 1 2 3
        * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables:
          depth                 (chain, draw) uint64 32kB 3 3 3 3 4 4 ... 4 4 4 4 4 3
          maxdepth_reached      (chain, draw) bool 4kB False False ... False False
          index_in_trajectory   (chain, draw) int64 32kB 6 -6 4 -3 9 ... 7 6 -8 -6 6
          logp                  (chain, draw) float64 32kB -1.704e+04 ... -1.704e+04
          energy                (chain, draw) float64 32kB 1.705e+04 ... 1.705e+04
          diverging             (chain, draw) bool 4kB False False ... False False
          energy_error          (chain, draw) float64 32kB 0.2241 -0.1556 ... -0.3027
          step_size             (chain, draw) float64 32kB 0.4522 0.4522 ... 0.4336
          step_size_bar         (chain, draw) float64 32kB 0.4522 0.4522 ... 0.4336
          mean_tree_accept      (chain, draw) float64 32kB 0.8101 0.8465 ... 0.9595
          mean_tree_accept_sym  (chain, draw) float64 32kB 0.8918 0.8972 ... 0.9227
          n_steps               (chain, draw) uint64 32kB 7 7 15 7 15 ... 15 31 15 15
      Attributes:
          created_at:     2025-12-16T10:36:01.619159+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 224kB
      Dimensions:      (customer_id: 14000)
      Coordinates:
        * customer_id  (customer_id) int64 112kB 1 2 3 4 5 ... 13997 13998 13999 14000
      Data variables:
          dropout      (customer_id) float64 112kB 1.0 1.0 1.0 1.0 ... 2.0 2.0 2.0 2.0
      Attributes:
          created_at:                 2025-12-16T10:36:01.659331+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

    • <xarray.Dataset> Size: 224kB
      Dimensions:            (customer_id: 14000, dropout_covariate: 1)
      Coordinates:
        * customer_id        (customer_id) int64 112kB 1 2 3 4 ... 13998 13999 14000
        * dropout_covariate  (dropout_covariate) <U16 64B 'highend_customer'
      Data variables:
          dropout_data       (customer_id, dropout_covariate) float64 112kB 1.0 ......
      Attributes:
          created_at:                 2025-12-16T10:36:01.655943+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

    • <xarray.Dataset> Size: 672kB
      Dimensions:           (index: 14000)
      Coordinates:
        * index             (index) int64 112kB 0 1 2 3 4 ... 13996 13997 13998 13999
      Data variables:
          customer_id       (index) int64 112kB 1 2 3 4 5 ... 13997 13998 13999 14000
          recency           (index) int64 112kB 1 1 1 1 1 1 1 1 1 ... 2 2 2 2 2 2 2 2
          T                 (index) int64 112kB 8 8 8 8 8 8 8 8 8 ... 2 2 2 2 2 2 2 2
          cohort            (index) object 112kB '2025-01' '2025-01' ... '2025-07'
          highend_customer  (index) int64 112kB 1 1 1 1 1 1 1 1 1 ... 0 0 0 0 0 0 0 0

    • <xarray.Dataset> Size: 898MB
      Dimensions:                    (chain: 4, draw: 1000, phi_interval___dim_0: 7,
                                      kappa_interval___dim_0: 7,
                                      dropout_covariate: 1, cohort: 7,
                                      customer_id: 14000)
      Coordinates:
        * chain                      (chain) int64 32B 0 1 2 3
        * draw                       (draw) int64 8kB 0 1 2 3 4 ... 996 997 998 999
        * phi_interval___dim_0       (phi_interval___dim_0) int64 56B 0 1 2 3 4 5 6
        * kappa_interval___dim_0     (kappa_interval___dim_0) int64 56B 0 1 2 3 4 5 6
        * dropout_covariate          (dropout_covariate) object 8B 'highend_customer'
        * cohort                     (cohort) object 56B '2025-01' ... '2025-07'
        * customer_id                (customer_id) int64 112kB 1 2 3 ... 13999 14000
      Data variables:
          phi_interval__             (chain, draw, phi_interval___dim_0) float64 224kB ...
          kappa_interval__           (chain, draw, kappa_interval___dim_0) float64 224kB ...
          dropout_coefficient_alpha  (chain, draw, dropout_covariate) float64 32kB ...
          dropout_coefficient_beta   (chain, draw, dropout_covariate) float64 32kB ...
          phi                        (chain, draw, cohort) float64 224kB 0.3648 ......
          kappa                      (chain, draw, cohort) float64 224kB 2.659 ... ...
          alpha_scale                (chain, draw, cohort) float64 224kB 0.9699 ......
          beta_scale                 (chain, draw, cohort) float64 224kB 1.689 ... ...
          alpha                      (chain, draw, customer_id) float64 448MB 0.54 ...
          beta                       (chain, draw, customer_id) float64 448MB 2.945...
      Attributes:
          created_at:     2025-12-16T10:36:01.615801+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 336kB
      Dimensions:               (chain: 4, draw: 1000)
      Coordinates:
        * chain                 (chain) int64 32B 0 1 2 3
        * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
      Data variables:
          depth                 (chain, draw) uint64 32kB 2 0 2 3 3 2 ... 3 4 3 3 3 3
          maxdepth_reached      (chain, draw) bool 4kB False False ... False False
          index_in_trajectory   (chain, draw) int64 32kB 1 0 -2 5 -3 ... -1 -5 -3 1 4
          logp                  (chain, draw) float64 32kB -1.881e+04 ... -1.704e+04
          energy                (chain, draw) float64 32kB 1.885e+04 ... 1.704e+04
          diverging             (chain, draw) bool 4kB False True ... False False
          energy_error          (chain, draw) float64 32kB -0.297 0.0 ... -0.3335
          step_size             (chain, draw) float64 32kB 1.439 0.2431 ... 0.4336
          step_size_bar         (chain, draw) float64 32kB 1.439 0.4998 ... 0.4336
          mean_tree_accept      (chain, draw) float64 32kB 1.0 0.0 ... 0.7465 0.8491
          mean_tree_accept_sym  (chain, draw) float64 32kB 0.7342 0.0 ... 0.8206
          n_steps               (chain, draw) uint64 32kB 3 1 3 7 7 3 ... 31 7 7 7 7
      Attributes:
          created_at:     2025-12-16T10:36:01.622916+00:00
          arviz_version:  0.22.0

sbg_cohort.fit_summary(
    var_names=[
        "alpha_scale",
        "beta_scale",
        "dropout_coefficient_alpha",
        "dropout_coefficient_beta",
    ]
)
                                              mean     sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha_scale[2025-01]                         0.670  0.057  0.567   0.776     0.001   0.001   3615.0   2966.0   1.0
alpha_scale[2025-02]                         0.700  0.066  0.576   0.824     0.001   0.001   3633.0   3542.0   1.0
alpha_scale[2025-03]                         0.738  0.074  0.604   0.876     0.001   0.001   3809.0   2978.0   1.0
alpha_scale[2025-04]                         0.795  0.094  0.627   0.966     0.001   0.002   4007.0   2671.0   1.0
alpha_scale[2025-05]                         0.854  0.137  0.614   1.101     0.002   0.002   4185.0   3264.0   1.0
alpha_scale[2025-06]                         1.007  0.299  0.586   1.545     0.005   0.009   4845.0   2831.0   1.0
alpha_scale[2025-07]                         3.749 34.215  0.341   5.671     0.600   7.920   5479.0   2460.0   1.0
beta_scale[2025-01]                          1.125  0.128  0.887   1.352     0.002   0.002   3390.0   2809.0   1.0
beta_scale[2025-02]                          1.177  0.146  0.918   1.452     0.003   0.003   2910.0   2730.0   1.0
beta_scale[2025-03]                          1.244  0.161  0.954   1.542     0.003   0.003   3331.0   2498.0   1.0
beta_scale[2025-04]                          1.342  0.196  0.997   1.704     0.003   0.003   3574.0   2620.0   1.0
beta_scale[2025-05]                          1.448  0.276  0.977   1.965     0.004   0.005   3706.0   3144.0   1.0
beta_scale[2025-06]                          1.739  0.566  0.912   2.725     0.009   0.018   4453.0   2754.0   1.0
beta_scale[2025-07]                          6.601 61.150  0.617  10.162     1.064  14.775   5644.0   2340.0   1.0
dropout_coefficient_alpha[highend_customer] -0.172  0.112 -0.368   0.051     0.003   0.002   1054.0   1416.0   1.0
dropout_coefficient_beta[highend_customer]  -1.428  0.132 -1.670  -1.175     0.004   0.003   1030.0   1484.0   1.0

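We can see from the fit summary that the covariate posteriors are similar to the previous model, but cohort parameters shift towards larger values (and higher dropout probabilities) in later cohorts as observable training data decreases. Uncertainty grows as well: the standard deviations for the 2025-07 cohort are orders of magnitude larger than those for 2025-01.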
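The fit output lists both `phi`/`kappa` and `alpha_scale`/`beta_scale` for each cohort. The first-draw values printed in the posterior group are consistent with a mean/concentration parameterization, where `alpha_scale = phi * kappa` and `beta_scale = (1 - phi) * kappa`. The sketch below checks that relationship against those printed values. The relationship is inferred from the printed output rather than taken from the library source, and `to_alpha_beta` is a hypothetical helper, not a PyMC-Marketing function.

```python
def to_alpha_beta(phi: float, kappa: float) -> tuple[float, float]:
    """Convert a Beta mean (phi) and concentration (kappa) into shape parameters."""
    return phi * kappa, (1.0 - phi) * kappa


# First-draw values for the 2025-01 cohort, copied from the printed posterior group
phi, kappa = 0.3788, 1.828
alpha_scale, beta_scale = to_alpha_beta(phi, kappa)

# The printed output shows alpha_scale 0.6925 and beta_scale 1.135 for the same draw;
# small differences come from the rounding of the displayed phi and kappa
print(f"alpha_scale = {alpha_scale:.4f}")
print(f"beta_scale  = {beta_scale:.4f}")

# The mean of a Beta(alpha, beta) distribution recovers phi, the average
# per-period dropout probability under this parameterization
print(f"mean dropout = {alpha_scale / (alpha_scale + beta_scale):.4f}")
```

Under this reading, `phi` controls the average dropout probability and `kappa` controls how concentrated customers are around that average, which is why a cohort with little data can show very large `alpha_scale` and `beta_scale` values without a large change in their ratio.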

Survival and Retention Plot Diagnostics#

base_survival = research_data[["regular", "highend"]].values.mean(axis=1) / 100
all_T = len(base_survival)

# Only customers still active at the end of the observation window are scored
active_customers = monthly_cohort_dataset.query("recency==T").copy()

for i, month in enumerate(cohort_start_dates):
    single_cohort_df = active_customers[active_customers["cohort"] == month].copy()

    train_offset = i - 8
    offset_range = range(train_offset, all_T - i + train_offset)

    # Run predictions
    expected_survival_rates = xr.concat(
        objs=[
            sbg_cohort.expected_probability_alive(
                data=single_cohort_df,
                future_t=T,
            )
            for T in offset_range
        ],
        dim="T",
    ).transpose(..., "T")

    # Plot predictions per cohort
    az.plot_hdi(
        offset_range,
        expected_survival_rates.sel(cohort=month).mean("cohort"),
        hdi_prob=0.94,
        color=f"C{i}",
        fill_kwargs={"alpha": 0.5, "label": f"{month} Cohort"},
    )
    # Plot observed survival curve, shifting T index for each cohort
    plt.plot(offset_range, base_survival[: all_T - i], marker="o", color="k")

# Plot the observed survival curve one more time to get a single label for the legend
plt.plot(
    offset_range, base_survival[: all_T - i], marker="o", color="k", label="Observed"
)
plt.axvline(0, ls=":", color="k", label="Test data starting period")
plt.legend()

plt.ylabel("Survival Rate")
plt.xlabel("Number of Time Periods")
plt.suptitle("Survival Rates over Time by Cohort", fontsize=18)
plt.title("94% HDI Intervals", fontsize=12);

There is clear bias in the survival plots over time. However, when plotted only for the current time period, an interesting story emerges:

active_customers = monthly_cohort_dataset.query("recency==T").copy()

_, axes = plt.subplots(
    nrows=7,
    ncols=1,
    figsize=(12, 14),
    layout="constrained",
)

axes = axes.flatten()

for i, month in enumerate(cohort_start_dates):
    active_monthly = active_customers[active_customers["cohort"] == month].copy()

    ax = axes[i]

    cohort_survival = sbg_cohort.expected_probability_alive(
        data=active_monthly,
        future_t=0,
    ).sel(cohort=month)

    az.plot_density(
        cohort_survival, hdi_prob=1, colors=f"C{i}", shade=0.3, bw=0.005, ax=ax
    )
    ax.axvline(x=base_survival[8 - i], color="k", linestyle="--", label="Observed")
    ax.set(
        title=f"{month} Cohort",
        xlim=[0.2, 0.8],
    )
plt.legend()
plt.gcf().suptitle(
    "Survival Rates per Cohort for the 2025-09 Time Period", fontsize=18
);

Survival estimates are bi-modal for all cohorts! This is due to earlier cohorts retaining an increasingly higher proportion of Highend customers over time. For later cohorts with shorter time durations, the observed survival rates shift towards Regular customers who cancel their contracts sooner. The survival curves would look much better if filtered by covariate segment.
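The composition shift described above can be illustrated with the sBG survival function from the original paper, \(S(t) = \prod_{j=1}^{t} \frac{\beta + j - 1}{\alpha + \beta + j - 1}\). The sketch below is a minimal illustration under stated assumptions: the segment parameters are hypothetical (not the fitted values), and the cohort is assumed to start with a 50/50 split between segments.

```python
def sbg_survival(alpha: float, beta: float, t: int) -> float:
    """sBG survival S(t): probability of still being active after t renewal opportunities."""
    s = 1.0
    for j in range(1, t + 1):
        s *= (beta + j - 1) / (alpha + beta + j - 1)
    return s


# Hypothetical Beta parameters per segment (illustrative only, not fitted values)
segments = {"regular": (0.7, 1.1), "highend": (0.8, 4.7)}


def highend_share(t: int) -> float:
    """Share of surviving customers that are Highend, assuming a 50/50 split at t=0."""
    s_reg = sbg_survival(*segments["regular"], t)
    s_hi = sbg_survival(*segments["highend"], t)
    return s_hi / (s_reg + s_hi)


for t in (0, 2, 8):
    print(f"t={t}: Highend share of survivors = {highend_share(t):.2f}")
```

Because Regular customers churn faster in this setup, the surviving population of an older cohort is weighted more heavily towards Highend customers than that of a younger cohort, so a single aggregate observed rate sits between the two segment-level modes.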

Note the consistency in the Highend customer densities. Information-sharing across cohorts is where hierarchical Bayesian models truly shine. The 2025-07 cohort would overfit with a dedicated frequentist model, but the Bayesian approach still yields useful estimates for a cohort consisting of only 2 time periods and 25% non-censored customers!

A similar bi-modal pattern emerges when we look at retention rates:

active_customers = monthly_cohort_dataset.query("recency==T").copy()
retention_rate_agg = base_survival[1:] / base_survival[:-1]

_, axes = plt.subplots(
    nrows=7,
    ncols=1,
    figsize=(12, 14),
    layout="constrained",
)

axes = axes.flatten()

for i, month in enumerate(cohort_start_dates):
    active_monthly = active_customers[active_customers["cohort"] == month].copy()

    ax = axes[i]
    cohort_retention = sbg_cohort.expected_retention_rate(
        data=active_monthly,
        future_t=0,
    ).sel(cohort=month)

    az.plot_density(
        cohort_retention, hdi_prob=1, colors=f"C{i}", shade=0.3, bw=0.005, ax=ax
    )
    ax.axvline(x=retention_rate_agg[7 - i], color="k", linestyle="--", label="Observed")
    ax.set(
        title=f"{month} Cohort",
        xlim=[0.70, 0.95],
    )
plt.legend()
plt.gcf().suptitle(
    "Retention Rates per Cohort for the 2025-09 Time Period", fontsize=18
);

Survival rates are good snapshots of model performance, but retention rates for the current time period are more useful in practice because they give an indication of which customers to prioritize now for marketing efforts.
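For intuition on why retention differs by tenure, recall the sBG retention rate from the original paper, \(r(t) = \frac{\beta + t - 1}{\alpha + \beta + t - 1}\). The following sketch evaluates it for a single set of Beta parameters; the parameter values are illustrative, `sbg_retention_rate` is a hypothetical helper, and how the library's `future_t` argument maps onto \(t\) is not covered here.

```python
def sbg_retention_rate(alpha: float, beta: float, t: int) -> float:
    """Probability that a customer active at the end of period t-1 renews in period t."""
    if t < 1:
        raise ValueError("t must be a positive integer")
    return (beta + t - 1) / (alpha + beta + t - 1)


# Illustrative Beta parameters (assumed, not fitted values)
alpha, beta = 0.8, 4.7

rates = [sbg_retention_rate(alpha, beta, t) for t in range(1, 9)]
for t, rate in enumerate(rates, start=1):
    print(f"period {t}: retention = {rate:.3f}")
```

Retention rises with tenure even though no individual customer changes behavior: customers with high dropout probabilities leave early, so the remaining population is increasingly made up of loyal customers.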

For the record, the survival curves are spot-on when the model is fit to the same dataset sans covariates. However, retention estimates would be unimodal, providing no insight into why observed rates drift further into the tail regions. This highlights the importance of customer heterogeneity as a modeling consideration, the advantages of doing so with covariates, and why both survival and retention rates should be inspected after fitting.

Discounted Residual Lifetime and Retention Elasticity#

These additional predictive methods were introduced in the follow-up research paper “Customer Base Valuation in a Contractual Setting: The Perils of Ignoring Heterogeneity” by Hardie & Fader in 2010.

Discounted Expected Residual Lifetime#

With ShiftedBetaGeoModel.expected_residual_lifetime(), we can estimate the average remaining number of time periods a cohort of customers will be active. A discount rate parameter is provided to calculate Net Present Value (NPV), and applying one is recommended as an industry best practice.
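To build intuition for how a per-period discount rate shrinks the value of future periods, the sketch below tabulates the standard compound discount factor \(1 / (1 + d)^t\). Per-period compounding is assumed here, and `discount_factor` is a hypothetical helper rather than part of the library.

```python
def discount_factor(discount_rate: float, periods: int) -> float:
    """Present value of one unit received `periods` time periods from now."""
    return 1.0 / (1.0 + discount_rate) ** periods


# Compare a few horizons at a 10% per-period discount rate
for t in (1, 3, 6, 9, 12):
    print(f"t={t:>2}: discount factor = {discount_factor(0.10, t):.3f}")
```

Because the rate applies per time period, a rate that looks modest on an annual basis becomes aggressive when the periods are months.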

Given how well the segment covariates explain the heterogeneity uncovered earlier, we will calculate DERL for each segment separately with a 10% discount rate. Note this is a very high discount rate for monthly periods: compounded per period, a payment received \(9\) time periods from now is worth only about 42% of its face value today!

discount_rate = 0.10

# filter pred data by covariate value
active_reg = active_customers.query("highend_customer==0").copy()
active_hi = active_customers.query("highend_customer==1").copy()

# run DERL predictions on both segments, and rename for plotting
derl_reg = sbg_cohort.expected_residual_lifetime(
    data=active_reg,
    discount_rate=discount_rate,
).rename("DERL")

derl_hi = sbg_cohort.expected_residual_lifetime(
    data=active_hi,
    discount_rate=discount_rate,
).rename("DERL")
ax = az.plot_forest(
    [derl_hi, derl_reg],
    model_names=["Highend Segment", "Regular Segment"],
    combined=True,
    figsize=(11.5, 5),
    colors=["C2", "C4"],
    ridgeplot_quantiles=[0.5],
)
ax[0].set_title("DERL by Cohort and Segment");

We can see the segment DERL estimates in the later cohorts are quite different and broadly distributed, but over time variance decreases and segments become more similar.
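As a rough cross-check on what DERL measures, the sketch below approximates it with a truncated sum: each future period contributes the probability of still being active, conditional on having survived \(n\) periods, discounted back to today. This is a simplified sketch under assumptions: the timing convention (first discounted period is \(n + 1\)), the truncation horizon, and the parameter values are all chosen for illustration, and the library's implementation may use a different convention.

```python
def discounted_residual_lifetime(
    alpha: float,
    beta: float,
    n: int,
    discount_rate: float,
    horizon: int = 500,
) -> float:
    """Approximate discounted expected remaining periods for a customer active after n periods."""
    total = 0.0
    conditional_survival = 1.0  # S(n + k) / S(n), built up one period at a time
    for k in range(1, horizon + 1):
        t = n + k
        # sBG retention rate for period t
        conditional_survival *= (beta + t - 1) / (alpha + beta + t - 1)
        total += conditional_survival / (1.0 + discount_rate) ** k
    return total


# Illustrative Beta parameters (assumed, not fitted values)
derl_n2 = discounted_residual_lifetime(0.8, 4.7, n=2, discount_rate=0.10)
derl_n8 = discounted_residual_lifetime(0.8, 4.7, n=8, discount_rate=0.10)
print(f"DERL after 2 periods: {derl_n2:.2f}")
print(f"DERL after 8 periods: {derl_n8:.2f}")
```

Longer-tenured customers have a larger value under the same parameters because their future retention rates are higher, and with a 10% discount rate no customer can exceed 10 discounted periods, which is the value of a customer who never churns.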

Discounted Expected Retention Elasticity#

With ShiftedBetaGeoModel.expected_retention_elasticity(), we can estimate the % increase in residual lifetime given a 1% increase in the retention rate. This is very useful for causal studies on customer lifetime durations. It is recommended to apply a discount rate for elasticity as well.

# run elasticity predictions on both segments
elastic_reg = sbg_cohort.expected_retention_elasticity(
    data=active_reg,
    discount_rate=discount_rate,
)

elastic_hi = sbg_cohort.expected_retention_elasticity(
    data=active_hi,
    discount_rate=discount_rate,
)
# Create figure
fig, axes = plt.subplots(3, 3, figsize=(11.5, 8), sharey=True)
axes = axes.flatten()

cohorts = sbg_cohort.cohorts

# Plot each cohort
for i, cohort in enumerate(cohorts):
    if i >= len(axes):
        break
    ax = axes[i]

    # Plot regular (left side) - lighter color
    az.plot_violin(
        elastic_reg.sel(cohort=cohort),
        side="left",
        ax=ax,
        shade=0.3,
        bw=0.1,
        shade_kwargs={"color": "#7FB3D5"},  # Light blue
        show=False,
    )

    # Plot highend (right side) - darker color
    az.plot_violin(
        elastic_hi.sel(cohort=cohort),
        side="right",
        ax=ax,
        shade=0.9,
        bw=0.1,
        shade_kwargs={"color": "#1A5276"},  # Dark blue
        show=False,
    )

    ax.set_title(cohort, fontweight="normal", fontsize=12)

    # Add y-label only to leftmost plots
    if i % 3 == 0:
        ax.set_ylabel("Elasticity", fontsize=12)
    else:
        ax.set_ylabel("")

# Add "Regular | Highend" suptitle only to top row
for i in range(min(3, len(cohorts))):
    axes[i].annotate(
        "Regular | Highend",
        xy=(0.5, 1.15),
        xycoords="axes fraction",
        ha="center",
        fontsize=14,
        fontweight="bold",
    )

# Hide unused subplots
for i in range(len(cohorts), len(axes)):
    axes[i].set_visible(False)

fig.tight_layout()
plt.show();

Retention elasticity is higher for earlier, long-running cohorts, but these retention rates also approach 95%. If retention were to be increased by just 5% for Regular customers in the 2025-07 cohort, their respective remaining lifetimes could be 25% longer!
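To see why elasticity climbs as retention approaches 1, consider a deliberately simplified case with a constant retention rate \(r\) and discount rate \(d\). The discounted residual lifetime is then the geometric series \(\sum_{k \ge 1} \left(\frac{r}{1+d}\right)^k = \frac{r}{1 + d - r}\), and its elasticity with respect to \(r\) is \(\frac{1 + d}{1 + d - r}\). The sketch below compares that closed form with a finite-difference estimate. It ignores heterogeneity entirely, so it is not the calculation performed by `expected_retention_elasticity()`, and both function names are hypothetical.

```python
def derl_constant_retention(r: float, d: float) -> float:
    """Discounted residual lifetime when every customer renews with probability r."""
    return r / (1.0 + d - r)


def elasticity_finite_difference(r: float, d: float, bump: float = 1e-6) -> float:
    """Percent change in DERL per percent change in r, estimated numerically."""
    base = derl_constant_retention(r, d)
    bumped = derl_constant_retention(r * (1.0 + bump), d)
    return (bumped / base - 1.0) / bump


def elasticity_closed_form(r: float, d: float) -> float:
    """Analytical elasticity for the constant-retention case."""
    return (1.0 + d) / (1.0 + d - r)


for r in (0.75, 0.85, 0.95):
    numeric = elasticity_finite_difference(r, 0.10)
    exact = elasticity_closed_form(r, 0.10)
    print(f"r={r:.2f}: finite difference = {numeric:.3f}, closed form = {exact:.3f}")
```

Even in this homogeneous simplification, the same relative improvement in retention buys more residual lifetime when retention is already high, which matches the direction of the pattern across cohorts described above.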

The xarray outputs we’ve been working with for the sBG predictive methods can also be converted to dataframes for downstream handling:

# filter dataset to only active customers
pred_data = monthly_cohort_dataset.query("recency==T")

# predict retention rate and convert to dataframe
pred_cohort_retention = sbg_cohort.expected_retention_rate(pred_data, future_t=0).mean(
    ("chain", "draw")
)
pred_cohort_retention.to_dataframe(name="retention").reset_index()
       cohort  customer_id  retention
0     2025-01          510   0.936442
1     2025-01          511   0.936442
2     2025-01          512   0.936442
3     2025-01          513   0.936442
4     2025-01          514   0.936442
...       ...          ...        ...
7011  2025-07        13996   0.749574
7012  2025-07        13997   0.749574
7013  2025-07        13998   0.749574
7014  2025-07        13999   0.749574
7015  2025-07        14000   0.749574

7016 rows × 3 columns

%load_ext watermark
%watermark -n -u -v -iv -w -p pymc,pytensor
Last updated: Tue Dec 16 2025

Python implementation: CPython
Python version       : 3.12.11
IPython version      : 9.4.0

pymc    : 5.25.1
pytensor: 2.31.7

pymc          : 5.25.1
pytensor      : 2.31.7
pandas        : 2.3.1
xarray        : 2025.7.1
pymc_extras   : 0.4.0
matplotlib    : 3.10.3
dateutil      : 2.9.0.post0
numpy         : 2.2.6
arviz         : 0.22.0
pymc_marketing: 0.17.0
seaborn       : 0.13.2

Watermark: 2.5.0