R/SL.randomForest.R

Defines functions: predict.SL.randomForest, SL.randomForest

Documented in: predict.SL.randomForest, SL.randomForest

# randomForest{randomForest}

SL.randomForest <- function(Y, X, newX, family, mtry = ifelse(family$family == "gaussian",
                            max(floor(ncol(X)/3), 1), floor(sqrt(ncol(X)))), ntree = 1000,
                            nodesize = ifelse(family$family == "gaussian", 5, 1),
                            maxnodes = NULL,
                            importance = FALSE, ...) {
  # SuperLearner wrapper around randomForest::randomForest().
  #
  # Fits a forest on (X, Y) with `newX` passed as `xtest`, so predictions for
  # newX are produced during fitting; the forest is kept (keep.forest = TRUE)
  # so predict.SL.randomForest() can score new data later.
  #
  # Arguments:
  #   Y          outcome; used as-is for "gaussian", as.factor(Y) for "binomial".
  #   X          training predictors (`data` of a `Y ~ .` formula for
  #              "gaussian", `x` for "binomial").
  #   newX       predictors to predict on (passed as `xtest`).
  #   family     object whose `$family` must be "gaussian" or "binomial".
  #   mtry, ntree, nodesize, maxnodes, importance
  #              passed straight to randomForest::randomForest().
  #   ...        accepted but ignored (not forwarded to randomForest).
  #
  # Returns a list with:
  #   pred  predictions for newX: `fit.rf$test$predicted` for "gaussian";
  #         column 2 of `fit.rf$test$votes` for "binomial" (with
  #         randomForest's default norm.votes = TRUE this is the share of
  #         trees voting for the second level of as.factor(Y)).
  #   fit   list(object = <randomForest fit>) with class "SL.randomForest".
  #
  # Errors if family$family is not exactly one of "gaussian"/"binomial".
  family_name <- family$family

  # Fail fast: an unsupported family used to fall through both branches and
  # surface as the misleading error "object 'pred' not found".
  if (length(family_name) != 1L || !family_name %in% c("gaussian", "binomial")) {
    stop("SL.randomForest only supports the 'gaussian' and 'binomial' ",
         "families, not '", paste(family_name, collapse = ", "), "'",
         call. = FALSE)
  }

  .SL.require('randomForest')

  if (family_name == "gaussian") {
    # Formula interface: `Y` is looked up in `data = X` first, then in this
    # function's environment.
    # NOTE(review): a column named "Y" in X would silently be used as the
    # response instead of the Y argument.
    fit.rf <- randomForest::randomForest(Y ~ ., data = X, ntree = ntree,
                                         xtest = newX, keep.forest = TRUE,
                                         mtry = mtry, nodesize = nodesize,
                                         maxnodes = maxnodes,
                                         importance = importance)
    pred <- fit.rf$test$predicted
  } else {
    # Classification forest: the factor response selects classification mode.
    fit.rf <- randomForest::randomForest(y = as.factor(Y), x = X, ntree = ntree,
                                         xtest = newX, keep.forest = TRUE,
                                         mtry = mtry, nodesize = nodesize,
                                         maxnodes = maxnodes,
                                         importance = importance)
    pred <- fit.rf$test$votes[, 2]
  }

  fit <- list(object = fit.rf)
  class(fit) <- "SL.randomForest"
  list(pred = pred, fit = fit)
}

predict.SL.randomForest <- function(object, newdata, family, ...) {
  # Predict method for fits produced by SL.randomForest().
  #
  # Arguments:
  #   object   an "SL.randomForest" fit; `object$object` is the randomForest model.
  #   newdata  predictors to score.
  #   family   object whose `$family` must be "gaussian" or "binomial".
  #   ...      accepted but ignored.
  #
  # Returns predict(..., type = 'response') for "gaussian". For "binomial" it
  # returns column 2 of predict(..., type = 'vote'); with randomForest's
  # default norm.votes = TRUE this is the share of trees voting for the second
  # outcome level.
  #
  # Errors if family$family is not exactly one of "gaussian"/"binomial".
  family_name <- family$family

  # Fail fast: an unsupported family used to fall through both branches and
  # surface as the misleading error "object 'pred' not found".
  if (length(family_name) != 1L || !family_name %in% c("gaussian", "binomial")) {
    stop("predict.SL.randomForest only supports the 'gaussian' and 'binomial' ",
         "families, not '", paste(family_name, collapse = ", "), "'",
         call. = FALSE)
  }

  .SL.require('randomForest')

  if (family_name == "gaussian") {
    predict(object$object, newdata = newdata, type = 'response')
  } else {
    predict(object$object, newdata = newdata, type = 'vote')[, 2]
  }
}

Try the SuperLearner package in your browser

Any scripts or data that you put into this service are public.

SuperLearner documentation built on May 29, 2024, 5:25 a.m.