## https://github.com/vspinu/polymode/issues/147#issuecomment-399745611

## Refuse to run on R versions older than 2.15: compare the major version
## first, then the first component of the minor version.
R.v.maj <- as.numeric(R.version$major)
R.v.min.1 <- as.numeric(strsplit(R.version$minor, "\\.")[[1]][1])
if (R.v.maj < 2 || (R.v.maj == 2 && R.v.min.1 < 15)) {
  stop("requires R >= 2.15", call.=FALSE)
}

## Load knitr quietly, then set the chunk-level and document-level options.
suppressPackageStartupMessages(library(knitr))
opts_chunk$set(echo=TRUE,
               warning=TRUE,
               message=TRUE,
               cache=FALSE,
               fig.align="center")
opts_knit$set(progress=TRUE,
              verbose=TRUE)
Comparing models based on their predictions via cross-validation is useful in many fields.
The caret package helps to do exactly that.
However, although it natively incorporates many models, many more exist.
Moreover, for the models without tuning parameters, it is not straightforward to know how to use caret.
This document aims at providing some examples.
This document will also require external packages to be available:
## Load the packages needed by this document, silencing their start-up
## messages. Each stopifnot() fails when the installed package is older than
## the version string given (compareVersion() returns 1 only when its first
## argument is the later version).
suppressPackageStartupMessages(library(caret))
suppressPackageStartupMessages(library(varbvs))
stopifnot(compareVersion("2.5.7",
                         as.character(packageVersion("varbvs"))) != 1)
suppressPackageStartupMessages(library(BGLR))
suppressPackageStartupMessages(library(coda))
suppressPackageStartupMessages(library(rrBLUP))
suppressPackageStartupMessages(library(doParallel))

## Start a PSOCK cluster with all detected cores but one, at least 1 and at
## most 20 workers, and register it as the parallel backend.
## detectCores() returns NA when the number of cores cannot be determined;
## na.rm=TRUE makes max() fall back to a single worker in that case instead
## of handing NA to makePSOCKcluster().
cl <- makePSOCKcluster(min(max(1, detectCores() - 1, na.rm=TRUE), 20))
registerDoParallel(cl)

suppressPackageStartupMessages(library(rutilstimflutre))
stopifnot(compareVersion("0.169.2",
                         as.character(packageVersion("rutilstimflutre"))) != 1)
This R chunk is used to assess how much time it takes to execute the R code in this document until the end:
## Record the start time; it is subtracted from a second proc.time() reading
## at the end of the document to report the total execution time.
t0 <- proc.time()
## Cross-validate the linear regression mpg ~ hp on the mtcars data set with
## caret's built-in "lm" method: 10-fold cross-validation repeated 5 times.
data("mtcars")
trCtl <- trainControl(method="repeatedcv",
                      number=10,
                      repeats=5)

set.seed(1234) # ensure that the same resamples are used
fit <- train(form=mpg ~ hp,
             data=mtcars,
             method="lm",
             trControl=trCtl)

## Inspect the result: summary, available components, and the final model.
fit
names(fit)
fit$finalModel
## Simulate n observations from the model y = mu + beta * x + epsilon, where
## x and epsilon are independent standard normal draws.
set.seed(1234)
n <- 10^2
mu <- 20
beta <- 0.5

## The order of the two rnorm() calls determines the simulated values:
## x is drawn first, epsilon second.
x <- rnorm(n)
epsilon <- rnorm(n)
y <- mu + beta * x + epsilon
dat <- data.frame(x=x, y=y)

## Scatter plot of the simulated data with the least-squares line in red.
plot(x, y)
abline(lm(y ~ x, data=dat), col="red")
lm
## Retrieve caret's definition of the "lm" method (exact name match, no
## regular expression) and print each of its components in turn.
lmModelInfo <- getModelInfo(model="lm", regex=FALSE)[[1]]
names(lmModelInfo)

## Descriptive components.
lmModelInfo$label
lmModelInfo$library
lmModelInfo$loop
lmModelInfo$type
lmModelInfo$parameters

## Functional components.
lmModelInfo$grid
lmModelInfo$fit
lmModelInfo$predict
lmModelInfo$prob
lmModelInfo$predictors
lmModelInfo$tags
lmModelInfo$varImp
lmModelInfo$sort
lm
Create inputs for the "method" of interest (https://topepo.github.io/caret/using-your-own-model-in-train.html#model-components):
## Fit function of the custom caret method: ordinary least squares via lm().
##
## x: predictors (data frame, or anything as.data.frame() can convert)
## y: outcome vector, stored in the column ".outcome"
## wts: case weights, or NULL for an unweighted fit
## param: tuning parameters; only param$intercept (logical) is used
## lev, last, weights, classProbs: accepted as part of the signature but
##   not used in the body
## ...: forwarded to lm()
##
## Returns the fitted "lm" object.
caretFitLm <- function(x, y, wts, param, lev, last, weights, classProbs, ...) {
  dat <- if (is.data.frame(x)) x else as.data.frame(x)
  dat$.outcome <- y

  ## Unweighted fits, with or without intercept.
  if (is.null(wts)) {
    if (param$intercept) {
      return(lm(.outcome ~ ., data = dat, ...))
    }
    return(lm(.outcome ~ 0 + ., data = dat, ...))
  }

  ## Weighted fits, with or without intercept.
  if (param$intercept) {
    return(lm(.outcome ~ ., data = dat, weights = wts, ...))
  }
  lm(.outcome ~ 0 + ., data = dat, weights = wts, ...)
}

## Predict function of the custom caret method: converts newdata to a data
## frame when needed and returns the predictions of the fitted model.
## submodels is accepted but not used.
caretPredictLm <- function(modelFit, newdata, submodels=NULL) {
  newdata <- if (is.data.frame(newdata)) newdata else as.data.frame(newdata)
  predict(modelFit, newdata)
}

# check (predictors given as data frames so that column names match):
# fitTmp <- caretFitLm(x=dat["x"], y=dat$y, wts=NULL,
#                      param=data.frame(intercept=TRUE))
# predTmp <- caretPredictLm(fitTmp, dat[1:10, "x", drop=FALSE])

## Grid function of the custom caret method: a single candidate, with an
## intercept. All arguments are ignored.
caretGridLm <- function(x, y, len=NULL, search="grid") {
  data.frame(intercept=TRUE)
}

## Description of the single tuning parameter.
caretParamLm <- data.frame(parameter="intercept",
                           class="logical",
                           label="intercept")

## The custom method itself, to be passed as "method" to caret's train().
caretMethLm <- list(library="stats",
                    type="Regression",
                    parameters=caretParamLm,
                    grid=caretGridLm,
                    fit=caretFitLm,
                    predict=caretPredictLm,
                    prob=NULL,
                    sort=NULL)
Create inputs for the cross-validation (https://topepo.github.io/caret/model-training-and-tuning.html#the-traincontrol-function):
## Summary function for caret's trainControl().
##
## data: data frame with the columns "obs" (observed) and "pred" (predicted)
## lev, model: accepted as part of the signature but not used in the body
##
## Returns a named numeric vector with:
## - rmse: root mean squared error
## - corP, corS: Pearson and Spearman correlations between obs and pred
## - corP.best50p, corS.best50p: same, on the top half of the rows once
##   sorted by decreasing observed value
## - corP.best25p, corS.best25p: same, on the top quarter
## - reg.intercept, reg.slope: coefficients of the regression pred ~ obs
caretSummary <- function(data, lev=NULL, model=NULL) {
  data <- data[order(data$obs, decreasing=TRUE), ] # sort best -> worse
  nb.inds <- nrow(data)
  obsSorted <- data$obs
  predSorted <- data$pred

  ## Correlation restricted to the first floor(prop * nb.inds) sorted rows.
  ## NOTE: 1:k is kept on purpose; when k is 0 it equals c(1, 0), which
  ## selects the first row only.
  corTop <- function(prop, method) {
    top <- 1:floor(prop * nb.inds)
    cor(obsSorted[top], predSorted[top], method=method)
  }

  coefOls <- as.numeric(coef(lm(pred ~ obs, data=data)))

  c(rmse=sqrt(mean((predSorted - obsSorted)^2)),
    corP=cor(obsSorted, predSorted, method="pearson"),
    corS=cor(obsSorted, predSorted, method="spearman"),
    corP.best50p=corTop(0.50, "pearson"),
    corS.best50p=corTop(0.50, "spearman"),
    corP.best25p=corTop(0.25, "pearson"),
    corS.best25p=corTop(0.25, "spearman"),
    reg.intercept=coefOls[1],
    reg.slope=coefOls[2])
}

## 5-fold cross-validation repeated 5 times, summarised by caretSummary().
caretTrainCtlLm <- trainControl(method="repeatedcv",
                                number=5,
                                repeats=5,
                                summaryFunction=caretSummary)
The cross-validation can now be performed:
## Cross-validate the custom lm method on the simulated data, using "rmse"
## (as computed by the custom summary function) as the metric.
set.seed(1234) # ensure that the same resamples are used
fit2 <- train(form=y ~ x,
              data=dat,
              method=caretMethLm,
              trControl=caretTrainCtlLm,
              metric="rmse")

## Overall result and its components.
fit2
names(fit2)
fit2$results

## Per-resample metrics, then the mean and standard deviation of the RMSE.
dim(fit2$resample)
head(fit2$resample)
mean(fit2$resample$rmse)
sd(fit2$resample$rmse)
## Retrieve caret's definition of the "glmnet" method (exact name match, no
## regular expression) and print each of its components in turn.
glmnetModelInfo <- getModelInfo(model="glmnet", regex=FALSE)[[1]]
names(glmnetModelInfo)

## Descriptive components.
glmnetModelInfo$label
glmnetModelInfo$library
glmnetModelInfo$type
glmnetModelInfo$parameters

## Functional components.
glmnetModelInfo$grid
glmnetModelInfo$loop
glmnetModelInfo$fit
glmnetModelInfo$predict
glmnetModelInfo$prob
glmnetModelInfo$predictors
glmnetModelInfo$varImp
glmnetModelInfo$levels
glmnetModelInfo$tags
glmnetModelInfo$sort
glmnetModelInfo$trim
## Simulate a matrix X for 10^3 genotypes and 10^3 SNPs, then a response
## from it with simulBvsr().
## NOTE(review): simulGenosDose() and simulBvsr() are not defined in this
## document; presumably they come from rutilstimflutre -- TODO confirm.
set.seed(1234)
X <- simulGenosDose(nb.genos=10^3,
                    nb.snps=10^3)
truth <- simulBvsr(Q=1,
                   mu=0,
                   X=X,
                   pi=0.2,
                   pve=0.75,
                   sigma.a2=1)

## Number, then names, of the columns of X for which gamma equals 1.
sum(truth$gamma)
colnames(X)[truth$gamma == 1]

dat <- truth$dat
varbvs
Make the inputs:
## Custom caret method wrapping varbvs.
## NOTE(review): caretGridVarbvs, caretFitVarbvs and caretPredictVarbvs are
## not defined in this document; presumably they are provided by
## rutilstimflutre -- TODO confirm.
caretMethVarbvs <- list(library="varbvs",
                        type="Regression",
                        parameters=data.frame(parameter="intercept",
                                              class="logical",
                                              label="intercept"),
                        grid=caretGridVarbvs,
                        fit=caretFitVarbvs,
                        predict=caretPredictVarbvs,
                        prob=NULL,
                        sort=NULL)

## 5-fold cross-validation repeated 5 times, summarised by caretSummary(),
## with parallel execution allowed.
caretTrainCtlVarbvs <- trainControl(method="repeatedcv",
                                    number=5,
                                    repeats=5,
                                    summaryFunction=caretSummary,
                                    allowParallel=TRUE)
## TODO: use the "seeds" argument
Run cross-validation:
## Cross-validate the varbvs method on the simulated data and time the run.
## Z, family and verbose are not train() arguments in their own right;
## presumably they are forwarded to the fit function -- TODO confirm.
set.seed(1234) # ensure that the same resamples are used
system.time(
  fit3 <- train(x=X,
                y=dat$response1,
                method=caretMethVarbvs,
                trControl=caretTrainCtlVarbvs,
                metric="rmse",
                Z=NULL,
                family="gaussian",
                verbose=FALSE))

fit3
fit3$results
BGLR
Make the inputs:
# check:
# nb.iters=50 * 10^3; burn.in=5 * 10^3; thin=5
# fitTmp <- caretFitBglr(x=X, y=dat$response1,
#                        bglrEtaModel="BayesC", task.id="test-BGLR_",
#                        nb.iters=nb.iters, burn.in=burn.in, thin=thin,
#                        keep.samples=TRUE)
# plot(fitTmp$post.samples)
# raftery.diag(fitTmp$post.samples, q=0.5, r=0.05, s=0.9)
# geweke.diag(fitTmp$post.samples) # should not be too far from [-2;2]
# summary(fitTmp$post.samples)
# predTmp <- caretPredictBglr(fitTmp, X[1:10,])

## Custom caret method wrapping BGLR.
## NOTE(review): caretGridBglr, caretFitBglr and caretPredictBglr are not
## defined in this document; presumably they are provided by
## rutilstimflutre -- TODO confirm.
caretMethBglr <- list(library=c("BGLR", "coda"),
                      type="Regression",
                      parameters=data.frame(parameter="intercept",
                                            class="logical",
                                            label="intercept"),
                      grid=caretGridBglr,
                      fit=caretFitBglr,
                      predict=caretPredictBglr,
                      prob=NULL,
                      sort=NULL)

## 5-fold cross-validation repeated 5 times, summarised by caretSummary(),
## with parallel execution allowed.
caretTrainCtlBglr <- trainControl(method="repeatedcv",
                                  number=5,
                                  repeats=5,
                                  summaryFunction=caretSummary,
                                  allowParallel=TRUE)
## TODO: use the "seeds" argument
Run cross-validation:
## Cross-validate the BGLR method (BayesC) on the simulated data and time
## the run. The commented-out line holds shorter settings kept for debugging.
set.seed(1234) # ensure that the same resamples are used
system.time(
  fit4 <- train(x=X,
                y=dat$response1,
                method=caretMethBglr,
                trControl=caretTrainCtlBglr,
                metric="rmse",
                ETA=list(list(X=X, model="BayesC")),
                saveAt="bglr4caret_",
                response_type="gaussian",
                # nIter=50 * 10^2, burnIn=5 * 10^2, thin=5, verbose=FALSE, # debug
                nIter=50 * 10^3,
                burnIn=5 * 10^3,
                thin=5,
                verbose=FALSE,
                keep.samples=FALSE))

fit4
fit4$results
rrBLUP
Make the inputs:
## Custom caret method wrapping rrBLUP.
## NOTE(review): caretGridRrblup, caretFitRrblup and caretPredictRrblup are
## not defined in this document; presumably they are provided by
## rutilstimflutre -- TODO confirm.
caretMethRrblup <- list(library=c("rrBLUP"),
                        type="Regression",
                        parameters=data.frame(parameter="intercept",
                                              class="logical",
                                              label="intercept"),
                        grid=caretGridRrblup,
                        fit=caretFitRrblup,
                        predict=caretPredictRrblup,
                        prob=NULL,
                        sort=NULL)

## 5-fold cross-validation repeated 5 times, summarised by caretSummary(),
## with parallel execution allowed.
caretTrainCtlRrblup <- trainControl(method="repeatedcv",
                                    number=5,
                                    repeats=5,
                                    summaryFunction=caretSummary,
                                    allowParallel=TRUE)
## TODO: use the "seeds" argument
Run cross-validation:
## Cross-validate the rrBLUP method on the simulated data and time the run.
set.seed(1234) # ensure that the same resamples are used
system.time(
  fit5 <- train(x=X,
                y=dat$response1,
                method=caretMethRrblup,
                trControl=caretTrainCtlRrblup,
                metric="rmse"))

fit5
fit5$results
## Shut down the cluster `cl` that was created with makePSOCKcluster() and
## registered with registerDoParallel() when the packages were loaded.
stopCluster(cl)
## Time spent since t0 was recorded.
t1 <- proc.time()
t1 - t0

## Session information, printed without the locale.
print(sessionInfo(), locale=FALSE)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.