R/LearnerSurvSVM.R

#' @title Survival Support Vector Machine Learner
#'
#' @name mlr_learners_surv.svm
#'
#' @description
#' A [mlr3proba::LearnerSurv] implementing svm from package
#'   \CRANpkg{survivalsvm}.
#' Calls [survivalsvm::survivalsvm()].
#'
#' @details
#' Four possible SVMs can be implemented, dependent on the `type` parameter. These correspond
#' to predicting the survival time via regression (`regression`), predicting a continuous rank
#' (`vanbelle1`, `vanbelle2`), or a hybrid of the two (`hybrid`).
#' Whichever `type` is chosen determines how the `crank` predict type is calculated,
#' but in any case all can be considered a valid continuous ranking.
#'
#' @templateVar id surv.svm
#' @template section_dictionary_learner
#'
#' @references
#' Belle VV, Pelckmans K, Huffel SV, Suykens JAK (2010).
#' “Improved performance on high-dimensional survival data by application of Survival-SVM.”
#' Bioinformatics, 27(1), 87–94.
#' doi: 10.1093/bioinformatics/btq617.
#'
#' Belle VV, Pelckmans K, Huffel SV, Suykens JA (2011).
#' “Support vector methods for survival analysis: a comparison between ranking and regression
#' approaches."
#' Artificial Intelligence in Medicine, 53(2), 107–118.
#' doi: 10.1016/j.artmed.2011.06.006.
#'
#' Shivaswamy, P. K., Chu, W., & Jansche, M. (2007).
#' A support vector approach to censored targets.
#' In Proceedings - IEEE International Conference on Data Mining, ICDM (pp. 655–660).
#' https://doi.org/10.1109/ICDM.2007.93
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerSurvSVM = R6Class("LearnerSurvSVM",
  inherit = LearnerSurv,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ParamSet$new(
        params = list(
          ParamFct$new(
            id = "type", default = "regression",
            levels = c("regression", "vanbelle1", "vanbelle2", "hybrid"),
            tags = "train"),
          ParamFct$new(
            id = "diff.meth", levels = c("makediff1", "makediff2", "makediff3"),
            tags = c("train")),
          ParamUty$new(id = "gamma.mu", tags = c("train", "required")),
          ParamFct$new(
            id = "opt.meth", default = "quadprog", levels = c("quadprog", "ipop"),
            tags = "train"),
          ParamFct$new(
            id = "kernel", default = "lin_kernel",
            levels = c("lin_kernel", "add_kernel", "rbf_kernel", "poly_kernel"),
            tags = "train"),
          ParamUty$new(id = "kernel.pars", tags = "train"),
          ParamInt$new(id = "sgf.sv", default = 5L, lower = 0L, tags = "train"),
          ParamInt$new(id = "sigf", default = 7L, lower = 0L, tags = "train"),
          ParamInt$new(id = "maxiter", default = 20L, lower = 0L, tags = "train"),
          ParamDbl$new(id = "margin", default = 0.05, lower = 0, tags = "train"),
          ParamDbl$new(id = "bound", default = 10, lower = 0, tags = "train"),
          ParamDbl$new(id = "eig.tol", default = 1e-06, lower = 0, tags = "train"),
          ParamDbl$new(id = "conv.tol", default = 1e-07, lower = 0, tags = "train"),
          ParamDbl$new(id = "posd.tol", default = 1e-08, lower = 0, tags = "train")
        )
      )

      ps$add_dep("diff.meth", "type", CondAnyOf$new(c("vanbelle1", "vanbelle2", "hybrid")))

      super$initialize(
        id = "surv.svm",
        packages = "survivalsvm",
        feature_types = c("integer", "numeric"),
        predict_types = c("crank", "response"),
        param_set = ps,
        man = "mlr3learners.survivalsvm::mlr_learners_surv.svm"
      )
    }
  ),

  private = list(
    .train = function(task) {
      with_package("survivalsvm", {
        mlr3misc::invoke(survivalsvm::survivalsvm,
          formula = task$formula(),
          data = task$data(),
          .args = self$param_set$get_values(tags = "train"))
      })
    },

    .predict = function(task) {
      fit = predict(self$model, newdata = task$data(cols = task$feature_names))
      crank = distr = lp = response = NULL

      if (is.null(self$param_set$values$type)) {
        crank = response = as.numeric(fit$predicted)
      } else if (self$param_set$values$type == "regression") {
        crank = response = as.numeric(fit$predicted)
      } else if (self$param_set$values$type == "vanbelle1") {
        crank = as.numeric(fit$predicted)
      } else if (self$param_set$values$type == "vanbelle2") {
        crank = as.numeric(fit$predicted)
      } else if (self$param_set$values$type == "hybrid") {
        crank = response = as.numeric(fit$predicted)
      }

      mlr3proba::PredictionSurv$new(task = task,
                                    crank = crank,
                                    response = response,
                                    lp = lp,
                                    distr = distr)
    }
  )
)
mlr3learners/mlr3learners.survivalsvm documentation built on Aug. 2, 2020, 9:40 a.m.