# R/learner_tabpfn_regr_tabpfn.R

# Defines functions: unmarshal_model.tabpfn_model_marshaled, marshal_model.tabpfn_model

#' @title TabPFN Regression Learner
#' @author b-zhou
#' @name mlr_learners_regr.tabpfn
#'
#' @inherit mlr_learners_classif.tabpfn description
#'
#' @templateVar class LearnerRegrTabPFN
#' @template sections_tabpfn
#'
#' @section Custom mlr3 parameters:
#'
#' - `output_type` corresponds to the same argument of the `.predict()` method of the `TabPFNRegressor` class,
#'   but only supports the options `"mean"`, `"median"` and `"mode"`.
#'   The point predictions are stored as `$response` of the prediction object.
#'
#' - `categorical_features_indices` uses R indexing instead of zero-based Python indexing.
#'
#' - `device` must be a string.
#'   If set to `"auto"`, the behavior is the same as original.
#'   Otherwise, the string is passed as argument to `torch.device()` to create a device.
#'
#' - `inference_precision` must be `"auto"` or `"autocast"`.
#'   Passing `torch.dtype` is currently not supported.
#'
#' - `inference_config` is currently not supported.
#'
#' @templateVar id regr.tabpfn
#' @template learner
#'
#' @inherit mlr_learners_classif.tabpfn references
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerRegrTabPFN = R6Class("LearnerRegrTabPFN",
  inherit = LearnerRegr,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # The parameter set mirrors the constructor / predict arguments of the
      # python TabPFNRegressor class; deviations from the python API are
      # documented in the roxygen header ("Custom mlr3 parameters").
      ps = ps(
        # point-prediction type forwarded to TabPFNRegressor.predict()
        output_type = p_fct(
          c("mean", "median", "mode"),
          default = "mean",
          tags = "predict"
        ),
        n_estimators = p_int(lower = 1, default = 4, tags = "train"),
        categorical_features_indices = p_uty(tags = "train", custom_check = function(x) {
          # R indexing is used; shifted to zero-based python indexing in $.train
          check_integerish(x, lower = 1, any.missing = FALSE, min.len = 1)
        }),
        softmax_temperature = p_dbl(default = 0.9, lower = 0, tags = "train"),
        balance_probabilities = p_lgl(default = FALSE, tags = "train"),
        average_before_softmax = p_lgl(default = FALSE, tags = "train"),
        model_path = p_uty(default = "auto", tags = "train", custom_check = check_string),
        # must be a string; non-"auto" values are wrapped via torch$device() in $.train
        device = p_uty(default = "auto", tags = "train", custom_check = check_string),
        ignore_pretraining_limits = p_lgl(default = FALSE, tags = "train"),
        inference_precision = p_fct(
          # torch.dtype option is currently not supported
          c("auto", "autocast"),
          default = "auto",
          tags = "train"
        ),
        fit_mode = p_fct(
          c("low_memory", "fit_preprocessors", "fit_with_cache"),
          default = "fit_preprocessors",
          tags = "train"
        ),
        memory_saving_mode = p_uty(default = "auto", tags = "train", custom_check = function(x) {
          # accepts "auto", a single logical flag, or a single non-negative number
          if (identical(x, "auto") || test_flag(x) || test_number(x, lower = 0)) {
            TRUE
          } else {
            "Invalid value for memory_saving_mode. Must be 'auto', a TRUE/FALSE value, or a number > 0."
          }
        }),
        random_state = p_int(default = 0, tags = "train"),
        # -1 is the sklearn-style "use all cores" sentinel
        n_jobs = p_int(lower = 1, init = 1, special_vals = list(-1), tags = "train")
      )

      super$initialize(
        id = "regr.tabpfn",
        feature_types = c("integer", "numeric", "logical"),
        predict_types = c("response", "quantiles"),
        param_set = ps,
        properties = c("missings", "marshal"),
        label = "TabPFN Regressor",
        man = "mlr3extralearners::mlr_learners_regr.tabpfn"
      )

      # quantile used as $response when predict_type is "quantiles"
      self$quantile_response = 0.5
    },
    #' @description
    #' Marshal the learner's model.
    #' @param ... (any)\cr
    #'   Additional arguments passed to [`marshal_model()`].
    marshal = function(...) {
      mlr3::learner_marshal(.learner = self, ...)
    },
    #' @description
    #' Unmarshal the learner's model.
    #' @param ... (any)\cr
    #'   Additional arguments passed to [`unmarshal_model()`].
    unmarshal = function(...) {
      mlr3::learner_unmarshal(.learner = self, ...)
    }
  ),

  active = list(
    #' @field marshaled (`logical(1)`)\cr
    #' Whether the learner has been marshaled.
    marshaled = function() {
      mlr3::learner_marshaled(self)
    }
  ),

  private = list(
    # Fits a TabPFNRegressor via reticulate and returns a "tabpfn_model"
    # wrapper holding the fitted python object.
    .train = function(task) {
      # declare python requirements so reticulate can resolve an environment
      reticulate::py_require("torch")
      reticulate::py_require("tabpfn")
      tabpfn = reticulate::import("tabpfn")

      pars = self$param_set$get_values(tags = "train")

      # "auto" is passed through unchanged; any other string is turned into
      # a torch.device object as TabPFN expects
      if (!is.null(pars$device) && pars$device != "auto") {
        torch = reticulate::import("torch")
        pars$device = torch$device(pars$device)
      }

      # x is an (n_samples, n_features) array
      x = as.matrix(task$data(cols = task$feature_names))
      # force NaN to make conversion work,
      # otherwise reticulate will not convert NAs in logical and integer columns to
      # np.nan properly
      x[is.na(x)] = NaN
      # y is an (n_samples,) array
      y = task$truth()

      # convert categorical_features_indices to python indexing
      categ_indices = pars$categorical_features_indices
      if (!is.null(categ_indices)) {
        # lower bound (>= 1) is already enforced by the parameter's custom_check
        if (max(categ_indices) > ncol(x)) {
          stop("categorical_features_indices must not exceed number of features")
        }
        pars$categorical_features_indices = as.integer(categ_indices - 1)
      }

      regressor = mlr3misc::invoke(tabpfn$TabPFNRegressor, .args = pars)
      x_py = reticulate::r_to_py(x)
      y_py = reticulate::r_to_py(y)
      fitted = mlr3misc::invoke(regressor$fit, X = x_py, y = y_py)

      # wrapped in a classed list so marshal_model()/unmarshal_model() dispatch
      structure(list(fitted = fitted), class = "tabpfn_model")
    },

    # Predicts either point responses (output_type mean/median/mode) or a
    # matrix of quantiles, depending on the learner's predict_type.
    .predict = function(task) {
      model = self$model$fitted

      x = as.matrix(task$data(cols = task$feature_names))
      # NA -> NaN, same reason as in $.train
      x[is.na(x)] = NaN
      x_py = reticulate::r_to_py(x)

      if (self$predict_type == "response") {
        # forwards output_type (mean/median/mode) to the python predict method
        pars = self$param_set$get_values(tags = "predict")
        response = mlr3misc::invoke(model$predict, X = x_py, .args = pars)
        response = reticulate::py_to_r(response)
        list(response = response)
      } else {
        quantiles = mlr3misc::invoke(
          model$predict,
          X = x_py,
          output_type = "quantiles",
          # tabpfn requires a list object as quantiles, even in the case of a single quantile
          quantiles = as.list(self$quantiles)
        )
        # python returns one vector per quantile; bind them into an
        # (n_samples, n_quantiles) matrix
        quantiles = do.call(cbind, reticulate::py_to_r(quantiles))
        attr(quantiles, "probs") = self$quantiles
        attr(quantiles, "response") = self$quantile_response
        list(quantiles = quantiles)
      }
    }
  )
)

# register the learner in the mlr3extralearners dictionary under its id
.extralrns_dict$add("regr.tabpfn", LearnerRegrTabPFN)


#' @export
marshal_model.tabpfn_model = function(model, inplace = FALSE, ...) {
  # pickle ships with every python installation, so importing it is safe
  pickle = reticulate::import("pickle")
  # serialize the fitted python estimator into a python bytes object
  serialized = pickle$dumps(model$fitted)
  # explicitly convert the python bytes into an R raw vector so the
  # marshaled state is a plain R object
  serialized = as.raw(serialized)

  structure(
    list(
      marshaled = serialized,
      packages = "mlr3extralearners"
    ),
    class = c("tabpfn_model_marshaled", "marshaled")
  )
}

#' @export
unmarshal_model.tabpfn_model_marshaled = function(model, inplace = FALSE, ...) {
  # pickle ships with every python installation, so importing it is safe
  pickle = reticulate::import("pickle")
  # convert the raw vector back to python bytes, then restore the estimator
  restored = pickle$loads(reticulate::r_to_py(model$marshaled))
  structure(list(fitted = restored), class = "tabpfn_model")
}
# mlr-org/mlr3extralearners documentation built on June 11, 2025, 7:06 p.m.