PPC Visualizations for Categorical Data

Palmer Penguins

Calibration plots for the easy task of identifying penguin species in the Palmer Penguins data set.
Author

Teemu Säilynoja

Published

January 30, 2023

Modified

June 14, 2023

Code
library("bayesplot")
library("cmdstanr")
library("ggplot2")
library("khroma")
library("quartoExtra")


# Source for the modified reliability plot
source("../../code/helpers.R")

theme_set(good_theme)
bayesplot_theme_set(good_theme)
color_scheme_set(scheme = c(unname(colour("vibrant")(7)[c(3,7,4,1,2,5)])))

scale_colour_discrete = scale_colour_vibrant
scale_fill_discrete = scale_fill_vibrant

SAVE_MODEL = TRUE

The data

Code
if (FALSE) {
  data("iris")
  X <- dplyr::select(na.omit(iris), -c("Species"))
  y <- as.numeric(iris$Species)
} else {
  library(palmerpenguins)
  data("penguins")
  X <- na.omit(penguins)[, c(3,4,5,6)]
  y <- as.factor(na.omit(penguins)$species)
}
Code
ggplot(X, aes(x = bill_length_mm, y = bill_depth_mm, colour = y)) +
  geom_point() +
  xlab("Bill length (mm)") +
  ylab("Bill depth (mm)") +
  labs(colour = "Species") +
  legend_move(position = "top")

The model

For the set of \(N\) observations of \(K\) features \(x_{k,n}\) and their corresponding species \(y_{n} \in \{1,\dots,D\}\), we perform the classification using the categorical logit classifier, that is, using the likelihood \[ p(y_n = s_j \mid \mathbf{w}) = \frac{\exp\left(\beta_{n,j}\right)}{\sum_{d=1}^D\exp\left(\beta_{n,d}\right)}, \] where \[ \beta_{n,j} = w_{0,j} + \sum_{k=1}^K w_{k,j}x_{k,n}. \] Note that the Stan program below uses different names: there `D` is the number of features and `N_classes` is the number of classes.
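To make the likelihood concrete, here is a minimal base-R sketch of how one observation's linear predictors are turned into class probabilities. The helper `softmax` and the logit values are made up for illustration and are not part of the fitted model.

```r
# Hypothetical helper: numerically stable softmax over a vector of logits.
softmax <- function(beta) {
  z <- exp(beta - max(beta))  # subtracting the max avoids overflow and leaves the result unchanged
  z / sum(z)
}

# Made-up logits for one observation and three classes (illustration only).
beta_n <- c(2.0, 0.5, -1.0)
p_n <- softmax(beta_n)

print(round(p_n, 3))
print(sum(p_n))  # the class probabilities sum to 1
```

Adding the same constant to every logit leaves the probabilities unchanged, which is why only differences between the class-specific linear predictors matter.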

Code
model <- cmdstan_model(stan_file = "../../code/stan-models/penguins_glm.stan")
model

  data {
    int N; // number of observations
    int D; // number of features
    int N_classes; // number of classes
    matrix [N, D] X; // observation data
    array[N] int <lower = 1, upper = N_classes> y; // target values {1,..., N_classes}
  }

  transformed data {
    // Normalize the data
    matrix[D + 1, N] X_stn;
    X_stn[D + 1, ] = rep_row_vector(1, N);
    for (d in 1:D) {
      X_stn[d,] = to_row_vector((X[, d] - mean(X[, d])) / sd(X[, d]));
    }
  }

  parameters {
    // Matrix of the linear weights.
    matrix[N_classes, D + 1] W;
  }

  transformed parameters {
    // Compute the linear transformations
    matrix[N_classes, N] logits;
    for (c in 1:N_classes) {
      logits[c, ] =  W[c, ] * X_stn;
    }
  }

  model {
    // Evaluate prior density
    for (d in 1:(D + 1)) {
      for (c in 1:N_classes) {
        target += normal_lpdf(W[c, d] | 0, 1);
      }
    }
    // Evaluate likelihood
    for (n in 1:N) {
      target += categorical_logit_lpmf(y[n] | logits[,n]);
    }
  }

  generated quantities {
    vector[N] yrep; //posterior predictive samples
    matrix[N,N_classes] lpd; //log posterior predictive densities

    for (n in 1:N) {
      yrep[n] = categorical_logit_rng(logits[,n]);
    }
    for (n in 1:N) {
      for (c in 1:N_classes) {
        lpd[n, c] = categorical_logit_lpmf(c | logits[,n]);
      }
    }
  }
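
The `transformed data` and `transformed parameters` blocks above can be mirrored in a few lines of base R, which may help when checking dimensions. This is only a sketch on a made-up toy matrix; `X_toy` and `W_toy` are placeholders, not the penguin data or fitted weights.

```r
# Toy feature matrix with N = 4 observations and D = 2 features (made-up values).
X_toy <- matrix(c(39.1, 18.7,
                  46.5, 17.9,
                  50.0, 15.3,
                  36.7, 19.3),
                ncol = 2, byrow = TRUE)

# Center and scale each column; scale() divides by the sample standard deviation.
X_std <- scale(X_toy)

# Transpose to features-by-observations and append a row of ones for the intercept.
X_stn <- rbind(t(X_std), 1)

# With W of size N_classes x (D + 1), the logits are W %*% X_stn.
W_toy <- matrix(0.1, nrow = 3, ncol = 3)  # placeholder weights, not fitted values
logits <- W_toy %*% X_stn
print(dim(logits))  # N_classes rows, N columns
```

Because the row of ones is appended last, the intercept weights sit in column `D + 1` of `W`.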

After loading the model, we pass the data and run the inference using HMC. With the cmdstanr defaults of 4 chains and 1000 post-warmup iterations per chain, we obtain 4000 draws from the posterior, together with posterior predictive samples (`yrep`) and log predictive densities (`lpd`) for each draw.

Code
fit <- model$sample(
  data = list(N = nrow(X),
              D = ncol(X),
              N_classes = length(unique(y)),
              X = X,
              y = as.numeric(y)),
  parallel_chains = 4,
  refresh = 0)
Running MCMC with 4 parallel chains...

Chain 4 finished in 6.0 seconds.
Chain 3 finished in 6.1 seconds.
Chain 2 finished in 6.3 seconds.
Chain 1 finished in 6.5 seconds.

All 4 chains finished successfully.
Mean chain execution time: 6.2 seconds.
Total execution time: 6.7 seconds.

The posterior distributions of the weights show reasonable separation between classes (the first index of \(W\)). The dimension with the most overlap seems to be the intercept term, \(W[j,5]\) below.

Code
mcmc_areas(fit$draws(variables = "W"))

The calibration

The common approach of plotting a bar chart of the observations overlaid with posterior predictive means and 95% intervals gives only a crude idea of how well the model predictions are calibrated.

Code
ppc_bars(as.numeric(y), fit$draws(variables = "yrep", format = "matrix")) +
  scale_x_continuous(breaks = 1:3, labels = levels(y))
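
To see why this check is crude, it helps to spell out what such a bar plot summarizes: the number of observations assigned to each class in every replicated data set, reduced to a point estimate and an interval per class. The sketch below uses synthetic stand-ins (`y_toy`, `yrep_toy`) rather than the fitted model, and is not how `ppc_bars` is implemented internally.

```r
set.seed(1)

# Synthetic stand-ins (not the fitted model): observed classes and a
# draws-by-observations matrix of replicated classes.
n_obs <- 60
n_draws <- 200
y_toy <- sample(1:3, n_obs, replace = TRUE)
yrep_toy <- matrix(sample(1:3, n_draws * n_obs, replace = TRUE), nrow = n_draws)

# Class counts in each replicated data set: one row per draw, one column per class.
counts <- t(apply(yrep_toy, 1, tabulate, nbins = 3))

# Summaries across draws: mean count and a central 95% interval per class.
summary_tab <- data.frame(
  class = 1:3,
  observed = tabulate(y_toy, nbins = 3),
  mean = colMeans(counts),
  lower = apply(counts, 2, quantile, probs = 0.025),
  upper = apply(counts, 2, quantile, probs = 0.975)
)
print(summary_tab)
```

Only the marginal class counts enter this summary, so a model could match them well while still assigning poorly calibrated probabilities to individual observations.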

The following three reliability diagrams show that, although the model is sometimes underconfident in its predicted probabilities, it separates the three species very clearly.

Code
plot_dotted_reliabilitydiag(
  x = colMeans(exp(fit$draws(
    variables = paste0("lpd[", 1:nrow(X), ",1]"),
    format = "matrix"))),
  y = as.numeric(y == levels(y)[1]),
  quantiles = 20) +
  labs(title = paste("Calibration:", levels(y)[1], "vs. Others"))
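
The call above averages on the probability scale, `colMeans(exp(...))`, so each observation gets its posterior mean predicted probability. A small sketch with made-up numbers shows how this differs from exponentiating the mean of the log values, which yields the geometric mean instead.

```r
# Made-up log predictive probabilities for one observation across four posterior draws.
lpd_draws <- log(c(0.9, 0.6, 0.3, 0.8))

p_arithmetic <- mean(exp(lpd_draws))  # posterior mean probability
p_geometric <- exp(mean(lpd_draws))   # geometric mean of the per-draw probabilities

print(c(arithmetic = p_arithmetic, geometric = p_geometric))
```

By Jensen's inequality the geometric mean can never exceed the arithmetic mean, so averaging on the log scale would systematically understate the predicted probabilities.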

Code
# Average over draws on the probability scale (mean of exp), as for the first species.
plot_dotted_reliabilitydiag(
  x = colMeans(exp(fit$draws(
    variables = paste0("lpd[", 1:nrow(X), ",2]"),
    format = "matrix"))),
  y = as.numeric(y == levels(y)[2]),
  quantiles = 20) +
  labs(title = paste("Calibration:", levels(y)[2], "vs. Others"))

Code
# Average over draws on the probability scale (mean of exp), as for the first species.
plot_dotted_reliabilitydiag(
  x = colMeans(exp(fit$draws(
    variables = paste0("lpd[", 1:nrow(X), ",3]"),
    format = "matrix"))),
  y = as.numeric(y == levels(y)[3]),
  quantiles = 20) +
  labs(title = paste("Calibration:", levels(y)[3], "vs. Others"))
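
For readers without access to the helper in `helpers.R`, the idea behind a one-vs-others reliability diagram can be illustrated with a much simpler binned summary. This sketch is not the method used by `plot_dotted_reliabilitydiag`; it uses synthetic predictions (`p_hat`) and outcomes (`y_bin`) that are calibrated by construction.

```r
set.seed(2)

# Synthetic predictions and outcomes (illustration only): each outcome is drawn
# with its stated probability, so the predictions are calibrated by construction.
p_hat <- runif(2000)
y_bin <- rbinom(2000, size = 1, prob = p_hat)

# Split the predictions into 10 equal-width bins and compare averages within each bin.
bins <- cut(p_hat, breaks = seq(0, 1, by = 0.1), include.lowest = TRUE)
calib <- data.frame(
  mean_predicted = as.vector(tapply(p_hat, bins, mean)),
  observed_freq = as.vector(tapply(y_bin, bins, mean)),
  n = as.vector(table(bins))
)
print(round(calib, 2))
```

For a calibrated model the two columns should agree up to sampling noise; observed frequencies that are more extreme than the predicted probabilities would indicate underconfidence.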