library(tabnet)
library(dplyr)
library(purrr)
library(rsample)
library(yardstick)
library(ggplot2)
library(patchwork)
set.seed(202402)
Interprestability is a lightweight evolution of the TabNet network design that provides, among other things, a stability score for the interpretation masks returned by the tabnet_explain() function.
In this vignette, we will try to improve the workflow on the Ames dataset introduced in the Training a Tabnet model from missing-values dataset vignette. The Interprestability score associated with tabnet_explain() results will help us select more stable models.
The Interprestability score is a metric for the stability of the interpretation masks between models: a score over 0.9 indicates very high stability, between 0.7 and 0.9 high stability, between 0.5 and 0.7 moderate stability, and between 0.3 and 0.5 low stability of the interpretation of the model.
The {tabnet} implementation compares the explainability parameters between the last 5 model checkpoints, so it is up to you to make those last 5 checkpoints a good proxy of the model.
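To keep those bands handy, here is a minimal sketch of a labelling helper; interprestability_label() is hypothetical, not part of {tabnet}:

# Hypothetical helper: map Interprestability scores onto the
# qualitative bands described above.
interprestability_label <- function(score) {
  cut(score,
      breaks = c(-Inf, 0.3, 0.5, 0.7, 0.9, Inf),
      labels = c("very low", "low", "moderate", "high", "very high"))
}
interprestability_label(c(0.95, 0.65))
#> [1] very high moderate
#> Levels: very low low moderate high very high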
Let’s experiment with this on a pretraining scenario on the ames dataset. We will work here with the ames_missing dataset, a transformation of the ames dataset done in the vignette Training a Tabnet model from missing-values dataset.
data("ames_missing", package = "tabnet")
ames_split <- initial_split(ames_missing, strata = Sale_Price, prop = 0.8)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
cat_emb_dim <- map_dbl(ames_train %>% select_if(is.factor),
                       ~max(log(nlevels(.x)) %>% floor, 1))
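As a quick check of that rule (a sketch, not part of the original workflow): a factor with 28 levels gets floor(log(28)) = 3 embedding dimensions, while a binary factor is lifted to the floor of 1.

# Worked example of the embedding-dimension rule above.
max(floor(log(28)), 1)
#> [1] 3
max(floor(log(2)), 1)
#> [1] 1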
tabnet_config <- tabnet_config(
  cat_emb_dim = cat_emb_dim, verbose = FALSE,
  early_stopping_patience = 12L, early_stopping_tolerance = 0.01,
  valid_split = 0.2
)
train_tabnet <- list(tabnet_fit(Sale_Price ~ ., data = ames_train,
                                epochs = 100, checkpoint_epoch = 101,
                                config = tabnet_config, learn_rate = 5e-2))
The differences of interpretabnet models are:
- the presence of a 3-layer MLP as an adaptation layer in between the tabnet steps,
- the encoder layer uses a Multibranch Weighted Linear-Unit, implemented as nn_mb_wlu(),
- the mask type switches to entmax.
You don’t have to know them in detail, as the handy interpretabnet_config() function comes as a replacement for tabnet_config() to switch them all at once.
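For comparison, the entmax mask alone can also be switched on a plain TabNet model; the sketch below assumes the mask_type argument of tabnet_config(), while the MLP adaptation layer and the nn_mb_wlu() encoder remain interpretabnet-specific.

# Entmax mask on a plain TabNet configuration (assumes the mask_type
# argument of tabnet_config(); the other two changes are only reachable
# through interpretabnet_config()).
entmax_config <- tabnet_config(mask_type = "entmax")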
For the training loop, you will notice that interpretabnet models require many more epochs to converge due to the MLP network.
interpretabnet_config <- interpretabnet_config(
  cat_emb_dim = cat_emb_dim, verbose = FALSE,
  early_stopping_patience = 12L, early_stopping_tolerance = 0.01,
  valid_split = 0.2, learn_rate = 5e-2, lr_scheduler = "step",
  lr_decay = .7, step_size = 5
)
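To see what this schedule does, assuming the "step" scheduler follows the usual torch lr_scheduler_step() semantics (multiply the rate by lr_decay every step_size epochs), the learning rate at a given epoch is 5e-2 * 0.7^(epoch %/% 5):

# Learning-rate decay sketch: the rate is multiplied by 0.7 every 5 epochs.
tibble(epoch = c(5, 10, 20, 50),
       lr = 5e-2 * 0.7^(epoch %/% 5))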
train_tabnet <- c(train_tabnet,
                  map(1:3, ~tabnet_fit(Sale_Price ~ ., data = ames_train,
                                       epochs = 150, checkpoint_epoch = 151,
                                       config = interpretabnet_config),
                      .progress = TRUE))
autoplot(train_tabnet[[1]]) +
autoplot(train_tabnet[[2]]) +
autoplot(train_tabnet[[3]]) +
autoplot(train_tabnet[[4]]) +
plot_layout(axes = "collect", guides = "collect") &
theme(legend.position = "bottom")
With a small learning-rate, we will extend each fitted model for a few epochs, checkpointing at every epoch, in order to measure the Interprestability.
models_checkpointed <- map(train_tabnet,
                           ~tabnet_fit(Sale_Price ~ ., data = ames_train,
                                       tabnet_model = .x, epochs = 6, valid_split = 0.2,
                                       checkpoint_epoch = 1, learn_rate = 1e-2),
                           .progress = TRUE)
We can now have a closer look at their training convergence plots:
autoplot(models_checkpointed[[1]]) +
autoplot(models_checkpointed[[2]]) +
autoplot(models_checkpointed[[3]]) +
autoplot(models_checkpointed[[4]]) +
plot_layout(axes = "collect", guides = "collect") &
theme(legend.position = "bottom")
explain_lst <- map(models_checkpointed, tabnet_explain, new_data = ames_train)
interprestability <- map_dbl(explain_lst, "interprestability")
interprestability
#> [1] 0.9947170 0.9948593 0.9942322 0.9847693
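All four models land in the very-high stability band. Those scores can now drive model selection; a minimal sketch:

# Keep the checkpointed model with the highest Interprestability score.
best_id <- which.max(interprestability)
best_model <- models_checkpointed[[best_id]]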
autoplot(explain_lst[[1]], quantile = 0.99) +
autoplot(explain_lst[[2]], quantile = 0.99) +
autoplot(explain_lst[[3]], quantile = 0.99) +
autoplot(explain_lst[[4]], quantile = 0.99) +
plot_layout(axes = "collect", guides = "collect") &
theme(legend.position = "bottom")
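Finally, as an extra step not in the original workflow, we could score the most stable model on the held-out test set with {yardstick}. This sketch assumes predict() on a tabnet fit returns a tibble with a .pred column:

# Sketch: RMSE of the most stable model on ames_test.
# `best_model` was selected above; predict() is assumed to return a
# tibble with a .pred column.
test_pred <- predict(best_model, ames_test)
rmse_vec(truth = ames_test$Sale_Price, estimate = test_pred$.pred)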