nn_prune_head.tabnet_fit | R Documentation |
Prune the last head_size layers of a tabnet network so that
the pruned module can be used as a sequential embedding module.
## S3 method for class 'tabnet_fit'
nn_prune_head(x, head_size)
## S3 method for class 'tabnet_pretrain'
nn_prune_head(x, head_size)
x | nn_network to prune |
head_size | number of nn_layers to prune, should be less than 2 |
a tabnet network with the top nn_layer removed
library(tabnet)
data("ames", package = "modeldata")
# predictors: every column except the outcome
x <- ames[,-which(names(ames) == "Sale_Price")]
y <- ames$Sale_Price
# pretrain a tabnet model on the ames dataset
ames_pretrain <- tabnet_pretrain(x, y, epochs = 2, checkpoint_epochs = 1)
# prune the last layer to get an embedding model
pruned_pretrain <- torch::nn_prune_head(ames_pretrain, 1)
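The usage above also lists a method for class 'tabnet_fit', which the example does not exercise. A minimal sketch of that case follows, assuming the tabnet, torch and modeldata packages are installed; the object names ames_fit and pruned_fit are illustrative only, and the two training epochs are just enough to produce a fitted object.

```r
library(tabnet)

data("ames", package = "modeldata")
# predictors: every column except the outcome
x <- ames[, names(ames) != "Sale_Price"]
y <- ames$Sale_Price

# fit a small supervised tabnet model (2 epochs, for illustration only)
ames_fit <- tabnet_fit(x, y, epochs = 2)

# prune the last layer of the fitted network
pruned_fit <- torch::nn_prune_head(ames_fit, 1)
```

As with the pretrained model, head_size = 1 removes only the top layer, which keeps the call within the limit stated for head_size above.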