Calibration of Ordinal Posterior Predictions

Author

Teemu Säilynoja

Published

February 4, 2023

Modified

June 14, 2023

Imports & options
library("bayesplot")
library("cmdstanr")
library("ggplot2")
library("khroma")
library("quartoExtra")


# Source for the modified reliability plot
source("../../code/helpers.R")

good_theme <- bayesplot::theme_default(base_family = "Sans") + theme(
  axis.text = element_text(colour = "#666666", size = 12),
  axis.ticks = element_line(colour = "#666666"),
  title = element_text(colour = "#666666", size = 16),
  plot.subtitle = element_text(colour = "#666666", size = 14),
  legend.text = element_text(colour = "#666666", size = 12),
  legend.title = element_text(colour = "#666666", size = 14),
  axis.line = element_line(colour = "#666666"))

theme_set(good_theme)
bayesplot_theme_set(good_theme)
color_scheme_set(scheme = c(unname(colour("vibrant")(7)[c(3,2,5,4,1,6)])))

scale_colour_discrete <- scale_colour_vibrant
scale_fill_discrete <- scale_fill_vibrant


# darkmode_theme_set(
#     dark = ggthemes::theme_stata(scheme = "s1rcolor"),
#     light = ggthemes::theme_stata(scheme = "s1color")
# )


SEED <- 236543
set.seed(SEED)
SAVE_FITS <- TRUE

This notebook highlights posterior predictive visualizations when the posterior predictive distribution is ordinal.

As shown below, the ordinal nature of the predictions allows us to use the cumulative posterior predictive mass function to assess the calibration of the posterior.
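
To make the idea concrete, here is a minimal sketch with toy values (not the data, models, or helpers used below): for an ordinal outcome, the predicted probability of the event y <= k is the sum of the first k class probabilities, and calibration means these predictions match the observed frequencies of the event.

# Toy sketch: cumulative predictive mass for an ordinal outcome.
# `p` is an N-by-K matrix of predicted class probabilities, `y` the classes.
p <- matrix(c(.6, .3, .1,
              .2, .5, .3,
              .1, .3, .6), nrow = 3, byrow = TRUE)
y <- c(1, 2, 3)
k <- 2
pred_cum <- rowSums(p[, 1:k, drop = FALSE]) # predicted P(y <= k)
obs_cum  <- as.numeric(y <= k)              # observed indicator of y <= k
# A calibrated model matches the frequency of obs_cum to the mean of
# pred_cum within any group of observations with similar predictions.
cbind(pred_cum, obs_cum)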

Data set

Below, I use synthetic data generated as observations from \(K\) Gaussians with varying means.

\[\begin{align} x_n &\sim \mathcal N\!\left(k_n,\, 0.5^2\right), &\text{for } n \in \{1,\dots,N\},\\ k_n &\sim \text{Categorical}(\theta), &\\ \theta_k &= \frac 1 K, &\text{for } k \in \{1, \dots, K\}. \end{align}\]

Data generation
K <- 5
N <- 1500
sigma <- .5
# Sample a class for each observation, then draw x around the class mean
c <- sample(1:K, N, replace = TRUE)
x <- rnorm(N, c, sigma)
standata_gmm <- list(K = K,
                     N = N,
                     x = x,
                     y = c,
                     sigma = sigma)
Code
ggplot(data.frame(standata_gmm)) +
  geom_density(aes(x = x,
                   colour = as.factor(y),
                   fill = as.factor(y),
                   group = as.factor(y)), alpha = .5) +
  legend_none() +
  xlab("") + ylab("") + ggtitle(paste("The data", sep = ""))

Model

The data are fit with two models, both structured to first standardize the data and then fit a Gaussian mixture model with K = 5 mixture components with known variances. In the first model, I omit scaling the standard deviation of the mixture components when the data are standardized; the second model applies the scaling correctly.
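
To see why the scaling matters, a small sketch (reusing the `x` and `sigma` objects from the data generation chunk above): standardizing divides the data by sd(x), so a known component standard deviation has to be divided by the same factor to remain valid on the standardized scale.

# Sketch: standardizing divides x by sd(x), so the known component
# standard deviation must also be divided by sd(x) to stay on the
# standardized scale.
c(scaled   = sigma / sd(x), # what the corrected model (model 2) uses
  unscaled = sigma)         # what model 1 keeps; too wide by a factor of sd(x)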

Read model code
gmm <- cmdstan_model("../../code/stan-models/gmm_classifier.stan")
gmm
data {
  int<lower=2> K;                   // Number of classes
  int<lower=0> N;                   // Total number of observations
  array[N] int<lower=1, upper=K> y; // Target classes
  vector[N] x;                      // Observed predictor values
  real<lower=0> sigma;              // User supplied standard deviation
  int correct_sigma;                // How to handle sigma, see below.
}

transformed data{
  vector[N] x_st;
  real Sigma;

  // Standardize data
  x_st = (x - mean(x)) / sd(x);

  // Maybe remember to scale sigma accordingly
  if (correct_sigma == 1) {
    Sigma = sigma / sd(x);
  } else {
    Sigma = sigma;
  }
}

parameters {
  // Inferred means.
  ordered[K] c;
  simplex[K] p_c;
}

model {
  // Prior
  c ~ normal(0,1);
  p_c ~ dirichlet(rep_vector(1,K));

  // Likelihood
  for (n in 1:N) {
    target += normal_lpdf(x_st[n] | c[y[n]], Sigma);
  }
}

generated quantities {
  // Posterior predictive sample
  vector[N] yrep;
  // For each observation, posterior predictive mass of classes
  array[N] vector[K] ppm;

  for (n in 1:N) {
    for (k in 1:K) {
      ppm[n, k] = normal_lpdf(x_st[n] | c[k], Sigma);
    }
    ppm[n, ] = softmax(ppm[n, ]);
    yrep[n] = categorical_rng(ppm[n, ]);
  }
}
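
In the generated quantities block, the class probabilities are obtained by applying softmax to the per-component log densities, which (under equal component weights) is just the normalized densities. A small R sketch of the same identity, with illustrative values (`x0` and `mu` are not taken from the fitted model):

# Sketch: softmax of log densities equals the normalized densities,
# i.e. P(class = k | x) under equal mixture weights.
softmax <- function(v) exp(v - max(v)) / sum(exp(v - max(v)))
x0 <- 0.2                                    # one standardized observation
mu <- c(-1, 0, 1)                            # illustrative component means
log_dens <- dnorm(x0, mu, 0.5, log = TRUE)   # normal_lpdf(x0 | mu[k], 0.5)
softmax(log_dens)
dnorm(x0, mu, 0.5) / sum(dnorm(x0, mu, 0.5)) # same values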
Run inference using Stan
fit_1 <- tryCatch(
  expr = {readRDS(paste("../../code/stan-models/fits/gmm_classifier_1_",SEED,".RDS", sep=""))},
  error = function(e) {
    fit <- gmm$sample(data = c(standata_gmm, list(correct_sigma = 0)),
                      parallel_chains = 4,
                      refresh = 0,
                      seed = SEED,
                      show_messages = FALSE)
    if (SAVE_FITS) {fit$save_object(
      paste("../../code/stan-models/fits/gmm_classifier_1_",SEED,".RDS", sep=""))}
    return(fit)},
  finally = {message("Finished model 1.")})

fit_2 <- tryCatch(
  expr = readRDS(paste("../../code/stan-models/fits/gmm_classifier_2_",SEED,".RDS", sep="")),
  error = function(e) {
    fit <- gmm$sample(data = c(standata_gmm, list(correct_sigma = 1)),
               parallel_chains = 4,
               refresh = 0,
               seed = SEED,
               show_messages = FALSE)
    if (SAVE_FITS) {fit$save_object(
      paste("../../code/stan-models/fits/gmm_classifier_2_",SEED,".RDS", sep=""))}
    return(fit)},
  finally = {message("Finished model 2.")})

For the reliability diagrams, we compute the mean posterior predictive masses for each observation.

Code
# N-by-K matrices of posterior mean class probabilities
# (rows = observations, columns = classes)
p_1 <- matrix(colMeans(fit_1$draws(variables = "ppm", format = "matrix")), ncol = K)
p_2 <- matrix(colMeans(fit_2$draws(variables = "ppm", format = "matrix")), ncol = K)

PPC

The PPC bar plots don’t show any major differences between the models.

Code
ppc_bars(y = as.numeric(c),
         yrep = fit_1$draws(variables = "yrep",format = "matrix")) +
  ggtitle("Model 1") +
  theme(legend.position = "bottom")

Code
ppc_bars(y = as.numeric(c),
         yrep = fit_2$draws(variables = "yrep", format = "matrix")) +
  ggtitle("Model 2") +
  theme(legend.position = "bottom") 

When we inspect the reliability diagrams for the cumulative classification probabilities, we observe considerable underconfidence in the first model. This is caused by omitting to scale the standard deviation of the mixture components when standardizing the data: on the standardized scale the components of the first model are wider than they should be, so its predicted class probabilities are flatter than the observed frequencies warrant.

:::{layout-ncol="2"}

Code
for (k in 1:(K - 1)) {
  p1 <- plot_dotted_reliabilitydiag(
    y = as.numeric(c <= k),
    x = if (k != 1) pmin(1, rowSums(p_1[, 1:k])) else p_1[, k],
    quantiles = K * N / 100,
    dot_scale = .5) +
    ggtitle(paste("Model 1: P(y <= ", k, ")", sep=""),
            subtitle = "1 dot = 100 observations")
  
  print(p1)
}

Code
for (k in 1:(K - 1)) {
  p2 <- plot_dotted_reliabilitydiag(
    y = as.numeric(c <= k),
    x = if (k != 1) pmin(1, rowSums(p_2[, 1:k])) else p_2[, k],
    quantiles = K * N / 100,
    dot_scale = .5) +
    ggtitle(paste("Model 2: P(y <= ", k, ")", sep=""),
            subtitle = "1 dot = 100 observations")
  
  print(p2)
}

:::