#' @title TabNet for Regression
#'
#' @aliases mlr_learners_regr.TabNet
#' @format [R6::R6Class] inheriting from [mlr3::LearnerRegr].
#'
#' @description
#' A [mlr3::LearnerRegr] for a TabNet classification model using the pytorch-tabnet Python package.
#'
#' @references
#' Arik, Sercan O.; Pfister, Tomas (2019).
#' TabNet: Attentive Interpretable Tabular Learning
#' arXiv:1908.07442
#' \url{https://arxiv.org/abs/1908.07442}
#'
#' @export
LearnerRegrTabNet <- R6Class("LearnerRegrTabNet",
inherit = LearnerRegr, # Adapt the name to your learner. For regression learners inherit = LearnerRegr.
public = list(
initialize = function() {
ps <- ParamSet$new( # parameter set using the paradox package
params = list(
ParamInt$new(id = "ntree", default = 500L, lower = 1L, tags = c("train", "predict")),
ParamInt$new(id = "mtry", lower = 1L, tags = "train"),
ParamLgl$new(id = "replace", default = TRUE, tags = "train"),
ParamUty$new(id = "classwt", default = NULL, tags = "train"), # lower = 0
ParamUty$new(id = "cutoff", tags = "train"),
ParamUty$new(id = "strata", tags = "train"),
ParamUty$new(id = "sampsize", tags = "train"),
ParamInt$new(id = "nodesize", default = 1L, lower = 1L, tags = "train"),
ParamInt$new(id = "maxnodes", lower = 1L, tags = "train"),
ParamFct$new(id = "importance", default = "none", levels = c("accuracy", "gini", "none"), tag = "train"), # importance is a logical value in the randomForest package.
ParamLgl$new(id = "localImp", default = FALSE, tags = "train"),
ParamLgl$new(id = "proximity", default = FALSE, tags = "train"),
ParamLgl$new(id = "oob.prox", tags = "train"),
ParamLgl$new(id = "norm.votes", default = TRUE, tags = "train"),
ParamLgl$new(id = "do.trace", default = FALSE, tags = "train"),
ParamLgl$new(id = "keep.forest", default = TRUE, tags = "train"),
ParamLgl$new(id = "keep.inbag", default = FALSE, tags = "train")
)
)
ps$values <- list(importance = "none") # Change the defaults. We set this here, because the default is FALSE in the randomForest package.
super$initialize(
# see the mlr3book for a description: https://mlr3book.mlr-org.com/extending-mlr3.html
id = "classif.randomForest",
packages = "randomForest",
feature_types = c("numeric", "factor", "ordered"),
predict_types = c("response", "prob"),
param_set = ps,
properties = c("weights", "twoclass", "multiclass", "importance", "oob_error")
)
},
train_internal = function(task) {
pars <- self$param_set$get_values(tags = "train")
# Setting the importance value to logical
if (pars[["importance"]] != "none") {
pars[["importance"]] <- TRUE
} else {
pars[["importance"]] <- FALSE
}
# Get formula, data, classwt, cutoff for the randomForest
f <- task$formula() # the formula is available in the task
data <- task$data() # the data is avail
levs <- levels(data[[task$target_names]])
n <- length(levs)
if (!"cutoff" %in% names(pars)) {
cutoff <- rep(1 / n, n)
}
if ("classwt" %in% names(pars)) {
classwt <- pars[["classwt"]]
if (is.numeric(classwt) && length(classwt) == n && is.null(names(classwt))) {
names(classwt) <- levs
}
} else {
classwt <- NULL
}
if (is.numeric(cutoff) && length(cutoff) == n && is.null(names(cutoff))) {
names(cutoff) <- levs
}
invoke(randomForest::randomForest, formula = f, data = data, classwt = classwt, cutoff = cutoff, .args = pars) # use the mlr3misc::invoke function (it's similar to do.call())
},
predict_internal = function(task) {
pars <- self$param_set$get_values(tags = "predict") # get parameters with tag "predict"
newdata <- task$data(cols = task$feature_names) # get newdata
type <- ifelse(self$predict_type == "response", "response", "prob") # this is for the randomForest package
p <- invoke(predict, self$model,
newdata = newdata,
type = type, .args = pars
)
# Return a prediction object with PredictionClassif$new() or PredictionRegr$new()
if (self$predict_type == "response") {
PredictionRegr$new(task = task, response = p)
} else {
PredictionRegr$new(task = task, prob = p)
}
},
# Add method for importance, if learner supports that.
# It must return a sorted (decreasing) numerical, named vector.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
imp <- data.frame(self$model$importance)
pars <- self$param_set$get_values()
if (pars[["importance"]] == "accuracy") {
x <- setNames(imp[["MeanDecreaseAccuracy"]], rownames(imp))
return(sort(x, decreasing = TRUE))
}
if (pars[["importance"]] == "gini") {
x <- setNames(imp[["MeanDecreaseGini"]], rownames(imp))
return(sort(x, decreasing = TRUE))
}
if (pars[["importance"]] == "none") {
return(message("importance was set to 'none'. No importance available."))
}
},
# Add method for oob_error, if learner supports that.
oob_error = function() {
mean(self$model$err.rate[, 1])
}
)
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.