R/SL.rpart.R

Defines functions predict.SL.rpart SL.rpart

Documented in predict.SL.rpart SL.rpart

# rpart {rpart}
# Fit a single rpart tree on (Y, X) and predict on newX.
#
# Arguments:
#   Y          outcome vector; placed in a data frame next to X and used as
#              the response of the formula `Y ~ .`.
#   X          training predictors (combined with Y via data.frame()).
#   newX       predictors to generate predictions for.
#   family     object with a `$family` string; "gaussian" fits an "anova"
#              tree, "binomial" fits a "class" tree. Anything else is
#              rejected with an error.
#   obsWeights observation weights, passed to rpart as `weights`.
#   cp, minsplit, xval, maxdepth, minbucket
#              forwarded unchanged to rpart::rpart.control().
#   ...        accepted for interface compatibility; not used here.
#
# Returns a list with
#   pred  predictions for newX (for "binomial", the second column of the
#         matrix returned by predict() on the class tree), and
#   fit   list(object = <rpart fit>) carrying class "SL.rpart".
SL.rpart <- function(Y, X, newX, family, obsWeights, cp = 0.01, minsplit = 20, xval = 0L, maxdepth = 30, minbucket = round(minsplit/3), ...) {
  .SL.require('rpart')

  # Reject unsupported families up front. Previously such a call fell through
  # both branches and died later with "object 'fit.rpart' not found".
  if (!isTRUE(family$family %in% c("gaussian", "binomial"))) {
    stop(
      "SL.rpart only supports the gaussian and binomial families; got ",
      paste(deparse(family$family), collapse = ""),
      call. = FALSE
    )
  }

  # Both families share the same tuning parameters and training data.
  control <- rpart::rpart.control(
    cp = cp,
    minsplit = minsplit,
    xval = xval,
    maxdepth = maxdepth,
    minbucket = minbucket
  )
  train <- data.frame(Y, X)

  if (family$family == "gaussian") {
    fit.rpart <- rpart::rpart(
      Y ~ .,
      data = train,
      control = control,
      method = "anova",
      weights = obsWeights
    )
    pred <- predict(fit.rpart, newdata = newX)
  } else {
    fit.rpart <- rpart::rpart(
      Y ~ .,
      data = train,
      control = control,
      method = "class",
      weights = obsWeights
    )
    # For a class tree predict() returns one column per class; keep the
    # second one (NOTE(review): presumably the probability of Y == 1 when Y
    # is coded 0/1 -- confirm against callers).
    pred <- predict(fit.rpart, newdata = newX)[, 2]
  }

  fit <- structure(list(object = fit.rpart), class = "SL.rpart")
  list(pred = pred, fit = fit)
}

# 
# Predict from a fit produced by SL.rpart.
#
# Arguments:
#   object   the "SL.rpart" fit; the underlying model is read from
#            `object$object`.
#   newdata  data to predict on, passed to predict() as `newdata`.
#   family   object with a `$family` string: "gaussian" returns the
#            predict() result as is, "binomial" returns its second column.
#   ...      accepted for interface compatibility; not used here.
#
# Stops with an informative error for any other family (previously this
# surfaced as "object 'pred' not found").
predict.SL.rpart <- function(object, newdata, family, ...) {
  .SL.require('rpart')

  family_name <- family$family
  if (!(is.character(family_name) && length(family_name) == 1L)) {
    stop(
      "predict.SL.rpart needs `family$family` to be a single string; got ",
      paste(deparse(family_name), collapse = ""),
      call. = FALSE
    )
  }

  switch(
    family_name,
    gaussian = predict(object$object, newdata = newdata),
    binomial = predict(object$object, newdata = newdata)[, 2],
    stop(
      "predict.SL.rpart only supports the gaussian and binomial families; got ",
      deparse(family_name),
      call. = FALSE
    )
  )
}

Try the SuperLearner package in your browser

Any scripts or data that you put into this service are public.

SuperLearner documentation built on May 29, 2024, 5:25 a.m.