# File: R/RLearner_classif_ctree.R

# Constructor for the "classif.ctree" learner, which wraps party::ctree.
# The hyperparameter set is assembled first and then handed to
# makeRLearnerClassif together with the learner's metadata.
#' @export
makeRLearner.classif.ctree = function() {
  # Hyperparameters; trainLearner.classif.ctree forwards them to
  # party::ctree_control.
  ctree.par.set = makeParamSet(
    makeDiscreteLearnerParam(id = "teststat", default = "quad",
      values = c("quad", "max")),
    makeDiscreteLearnerParam(id = "testtype", default = "Bonferroni",
      values = c("Bonferroni", "MonteCarlo", "Univariate", "Teststatistic")),
    makeNumericLearnerParam(id = "mincriterion", default = 0.95,
      lower = 0, upper = 1),
    makeIntegerLearnerParam(id = "minsplit", default = 20L, lower = 1L),
    makeIntegerLearnerParam(id = "minbucket", default = 7L, lower = 1L),
    makeLogicalLearnerParam(id = "stump", default = FALSE),
    # Active only when testtype is "MonteCarlo".
    makeIntegerLearnerParam(id = "nresample", default = 9999L, lower = 1L,
      requires = quote(testtype == "MonteCarlo")),
    makeIntegerLearnerParam(id = "maxsurrogate", default = 0L, lower = 0L),
    makeIntegerLearnerParam(id = "mtry", default = 0L, lower = 0L),
    # Flagged as not tunable.
    makeLogicalLearnerParam(id = "savesplitstats", default = TRUE,
      tunable = FALSE),
    makeIntegerLearnerParam(id = "maxdepth", default = 0L, lower = 0L)
  )

  # Task/prediction capabilities advertised by this learner.
  learner.props = c("twoclass", "multiclass", "missings", "numerics",
    "factors", "ordered", "prob", "weights")

  makeRLearnerClassif(
    cl = "classif.ctree",
    package = "party",
    par.set = ctree.par.set,
    properties = learner.props,
    name = "Conditional Inference Trees",
    short.name = "ctree",
    note = "See `?ctree_control` for possible breakage for nominal features with missingness.",
    callees = c("ctree", "ctree_control")
  )
}

# Fits party::ctree on the rows of .task selected by .subset. The
# hyperparameter formals are turned into a ctree_control object via
# learnerArgsToControl, and optional observation weights are passed through.
#' @export
trainLearner.classif.ctree = function(.learner, .task, .subset, .weights = NULL, teststat, testtype,
  mincriterion, minsplit, minbucket, stump, nresample, maxsurrogate, mtry,
  savesplitstats, maxdepth, ...) {
  # NOTE(review): learnerArgsToControl is defined elsewhere and receives the
  # formals as bare symbols; it presumably inspects them (e.g. via missing())
  # in this frame, so keep this call directly in the function body.
  control.obj = learnerArgsToControl(party::ctree_control,
    teststat, testtype, mincriterion, minsplit, minbucket, stump,
    nresample, maxsurrogate, mtry, savesplitstats, maxdepth)
  task.formula = getTaskFormula(.task)
  party::ctree(
    formula = task.formula,
    data = getTaskData(.task, .subset),
    controls = control.obj,
    weights = .weights,
    ...
  )
}

# Prediction method for the "classif.ctree" learner.
# - predict.type "prob": party::treeresponse returns a list (presumably one
#   probability vector per observation in .newdata) that is row-bound into a
#   matrix whose columns are named after the response levels stored in the
#   fitted model for the task's target.
# - any other predict.type: delegates to predict() on the fitted model.
#' @export
predictLearner.classif.ctree = function(.learner, .model, .newdata, ...) {
  if (.learner$predict.type != "prob") {
    return(predict(.model$learner.model, newdata = .newdata, ...))
  }

  tree.fit = .model$learner.model
  prob.mat = do.call(rbind, party::treeresponse(tree.fit, newdata = .newdata, ...))
  # Drop any row names and label columns with the class levels of the target.
  rownames(prob.mat) = NULL
  target.name = .model$task.desc$target
  colnames(prob.mat) = tree.fit@responses@levels[[target.name]]
  prob.mat
}
# Source: Najah-lshanableh/R-data-mining2 documentation, built on May 6, 2019, 10:11 a.m.