InterpreTabNet

library(tabnet)
library(dplyr)
library(purrr)
library(rsample)
library(yardstick)
library(ggplot2)
library(patchwork)
set.seed(202402)

Interprestability

InterpreTabNet is a lightweight evolution of the TabNet network design that provides, among other features, a stability score for the interpretation masks returned by the tabnet_explain() function.

In this vignette, we will try to improve the workflow on the Ames dataset introduced in the Training a TabNet model from a missing-values dataset vignette.

The interprestability score associated with tabnet_explain() results will help us select more stable models.

The interprestability score is a metric for the stability of the masks between models: a score over 0.9 indicates very high stability, between 0.7 and 0.9 high stability, between 0.5 and 0.7 moderate stability, and between 0.3 and 0.5 low stability of the model's interpretation.
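These thresholds can be encoded in a small helper. The function below is not part of {tabnet} (the name stability_label() and the "very low" label for scores under 0.3 are my own additions), but it makes the categories easy to reuse:

```r
# Helper (not part of {tabnet}): map an interprestability score to the
# stability categories described above; intervals are right-closed.
stability_label <- function(score) {
  cut(score,
      breaks = c(-Inf, 0.3, 0.5, 0.7, 0.9, Inf),
      labels = c("very low", "low", "moderate", "high", "very high"))
}
stability_label(c(0.95, 0.6, 0.41))
```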

The {tabnet} implementation compares the explainability masks across the last 5 model checkpoints, so it is up to you to make those last 5 checkpoints a good proxy of the model.
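The exact formula is internal to {tabnet}, but one way to picture a checkpoint-based stability score is the mean pairwise cosine similarity of per-feature importance vectors across checkpoints. The sketch below (function name mask_stability() and the cosine choice are assumptions) is a base-R illustration, not the actual implementation:

```r
# Hypothetical sketch, not the {tabnet} internals: compare per-feature
# importance vectors from the last checkpoints pairwise and average
# their cosine similarity.
mask_stability <- function(importances) {
  # importances: a list of numeric vectors, one per checkpoint
  cosine <- function(a, b) sum(a * b) / (sqrt(sum(a^2)) * sqrt(sum(b^2)))
  pairs <- combn(length(importances), 2)
  mean(apply(pairs, 2, function(ij)
    cosine(importances[[ij[1]]], importances[[ij[2]]])))
}
# Identical masks across checkpoints give a score of 1.
mask_stability(list(c(.5, .3, .2), c(.5, .3, .2), c(.5, .3, .2)))
```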

Let’s experiment with this in a training scenario on the ames dataset:

Interprestability on ames_missing

We will work here with the ames_missing dataset, a transformation of the ames dataset performed in the vignette Training a TabNet model from a missing-values dataset.

data("ames_missing", package = "tabnet")
ames_split <- initial_split(ames_missing, strata = Sale_Price, prop = 0.8)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

A classical TabNet model

# embedding dimension for each categorical predictor: max(floor(log(n_levels)), 1)
cat_emb_dim <- map_dbl(ames_train %>% select_if(is.factor), ~max(log(nlevels(.x)) %>% floor, 1))
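As a quick sanity check of this heuristic (the factor names and level counts below are illustrative only, base R): every factor gets at least a 1-dimensional embedding, and roughly log(n_levels) dimensions otherwise.

```r
# Toy check of the embedding-dimension heuristic with made-up level
# counts: max(floor(log(n_levels)), 1) per factor.
n_levels <- c(Neighborhood = 28, Street = 2, Alley = 3)
pmax(floor(log(n_levels)), 1)
# Neighborhood gets 3 dimensions; Street and Alley fall back to 1.
```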

tabnet_config <- tabnet_config( cat_emb_dim = cat_emb_dim, verbose = FALSE,
                                early_stopping_patience = 12L, early_stopping_tolerance = 0.01,
                                valid_split = 0.2)

# checkpoint_epoch > epochs, so no intermediate checkpoint is saved during this run
train_tabnet <- list(tabnet_fit(Sale_Price ~., data = ames_train,
                                     epochs = 100, checkpoint_epoch = 101,
                                     config = tabnet_config, learn_rate = 5e-2))

3 InterpreTabNet models

InterpreTabNet models differ from classical TabNet models in three ways:

- a 3-layer MLP is added as an adaptation layer in between the TabNet steps,
- the encoder layer uses a Multibranch Weighted Linear-Unit, implemented as nn_mb_wlu(),
- the mask type switches to entmax.
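The mask type is the most visible of those changes: the entmax family generalizes softmax toward sparse outputs. As a base-R sketch (not the {tabnet} implementation, which works on torch tensors), here is sparsemax, the α = 2 member of that family, zeroing out weak logits while the result still sums to 1:

```r
# Sparsemax (entmax with alpha = 2), following Martins & Astudillo (2016):
# project the logits onto the probability simplex, producing exact zeros.
sparsemax <- function(z) {
  z_sorted <- sort(z, decreasing = TRUE)
  k <- seq_along(z)
  cssv <- cumsum(z_sorted)
  # support size: largest k with 1 + k * z_(k) > cumulative sum up to k
  k_z <- max(k[1 + k * z_sorted > cssv])
  tau <- (cssv[k_z] - 1) / k_z
  pmax(z - tau, 0)
}
sparsemax(c(2, 1, 0.1))
# The weakest logits are clipped to exactly 0, unlike softmax.
```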

You don’t have to configure these individually, as the handy interpretabnet_config() function comes as a drop-in replacement for tabnet_config() to switch them all at once.

For the training loop, you will notice that InterpreTabNet models require many more epochs to converge due to the MLP network.

interpretabnet_config <- interpretabnet_config( cat_emb_dim = cat_emb_dim, verbose = FALSE,
                                early_stopping_patience = 12L, early_stopping_tolerance = 0.01,
                                valid_split = 0.2, learn_rate = 5e-2, lr_scheduler = "step", 
                                lr_decay = .7, step_size = 5)


train_tabnet <- c(train_tabnet,
                  map(1:3, ~tabnet_fit(Sale_Price ~., data = ames_train,
                                     epochs = 150, checkpoint_epoch = 151,
                                     config = interpretabnet_config),
                    .progress = TRUE)
)

Initial diagnostic of model training

autoplot(train_tabnet[[1]]) + 
  autoplot(train_tabnet[[2]]) + 
  autoplot(train_tabnet[[3]]) + 
  autoplot(train_tabnet[[4]]) + 
  plot_layout(axes = "collect", guides = "collect") &
  theme(legend.position = "bottom")

Adding 5 checkpoints to each model training

With a small learning rate, we will extend each fitted model for a few epochs, with a checkpoint at every epoch, in order to measure the interprestability.


# resume each fit from its previous state, checkpointing at every epoch
models_checkpointed <- map(train_tabnet, ~tabnet_fit(Sale_Price ~., data = ames_train,
                                     tabnet_model = .x, epochs = 6, valid_split = 0.2,
                                     checkpoint_epoch = 1, learn_rate = 1e-2),
                           .progress = TRUE)

We can now have a closer look at their training convergence plot:

autoplot(models_checkpointed[[1]]) + 
  autoplot(models_checkpointed[[2]]) + 
  autoplot(models_checkpointed[[3]]) + 
  autoplot(models_checkpointed[[4]]) + 
  plot_layout(axes = "collect", guides = "collect") &
  theme(legend.position = "bottom")

Evolution of the interprestability score along the checkpoints

explain_lst <- map(models_checkpointed, tabnet_explain, new_data = ames_train)
interprestability <- map_dbl(explain_lst, "interprestability")
interprestability
#> [1] 0.9947170 0.9948593 0.9942322 0.9847693

Plot the 4 different models


autoplot(explain_lst[[1]], quantile = 0.99) + 
  autoplot(explain_lst[[2]], quantile = 0.99) + 
  autoplot(explain_lst[[3]], quantile = 0.99) + 
  autoplot(explain_lst[[4]], quantile = 0.99) + 
  plot_layout(axes = "collect", guides = "collect") &
  theme(legend.position = "bottom")