##`
## https://github.com/vspinu/polymode/issues/147#issuecomment-399745611
R.v.maj <- as.numeric(R.version$major)
R.v.min.1 <- as.numeric(strsplit(R.version$minor, "\\.")[[1]][1])
if(R.v.maj < 2 || (R.v.maj == 2 && R.v.min.1 < 15))
  stop("requires R >= 2.15", call.=FALSE)

suppressPackageStartupMessages(library(knitr))
opts_chunk$set(echo=TRUE, warning=TRUE, message=TRUE, cache=FALSE, fig.align="center")
opts_knit$set(progress=TRUE, verbose=TRUE)

Overview

Comparing models based on their predictions via cross-validation is useful in many fields. The caret package helps do exactly that. However, even though it natively supports many models, many more exist. Moreover, for models without tuning parameters, it is not straightforward to know how to use caret. This document aims to provide some examples.

This document also requires the following external packages to be available:

suppressPackageStartupMessages(library(caret))
suppressPackageStartupMessages(library(varbvs))
stopifnot(compareVersion("2.5.7",
                         as.character(packageVersion("varbvs")))
          != 1)
suppressPackageStartupMessages(library(BGLR))
suppressPackageStartupMessages(library(coda))
suppressPackageStartupMessages(library(rrBLUP))
suppressPackageStartupMessages(library(doParallel))
cl <- makePSOCKcluster(min(max(1, detectCores() - 1), 20)) # leave one core free, use at most 20
registerDoParallel(cl)
suppressPackageStartupMessages(library(rutilstimflutre))
stopifnot(compareVersion("0.169.2",
                         as.character(packageVersion("rutilstimflutre")))
          != 1)

This R chunk records the start time, so that the total time taken to execute the R code in this document can be reported at the end:

t0 <- proc.time()

Native example

data("mtcars")
trCtl <- trainControl(method="repeatedcv", number=10, repeats=5)
set.seed(1234) # ensure that the same resamples are used
fit <- train(form=mpg ~ hp, data=mtcars, method="lm", trControl=trCtl)
fit
names(fit)
fit$finalModel
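
As a sanity check, caret refits the selected model on the whole data set, so the coefficients of fit$finalModel should match those of a direct lm fit:

coef(fit$finalModel)
coef(lm(mpg ~ hp, data=mtcars))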

Simple linear regression

Data

set.seed(1234)
n <- 10^2
x <- rnorm(n)
mu <- 20
beta <- 0.5
epsilon <- rnorm(n)
y <- mu + beta * x + epsilon
dat <- data.frame(x=x, y=y)
plot(x, y)
abline(lm(y ~ x, data=dat), col="red")
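
The ordinary least squares estimates should be close to the simulated values:

coef(lm(y ~ x, data=dat)) # expect values close to mu=20 and beta=0.5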

Native lm

lmModelInfo <- getModelInfo(model="lm", regex=FALSE)[[1]]
names(lmModelInfo)
lmModelInfo$label
lmModelInfo$library
lmModelInfo$loop
lmModelInfo$type
lmModelInfo$parameters
lmModelInfo$grid
lmModelInfo$fit
lmModelInfo$predict
lmModelInfo$prob
lmModelInfo$predictors
lmModelInfo$tags
lmModelInfo$varImp
lmModelInfo$sort

Custom lm

Create inputs for the "method" of interest (https://topepo.github.io/caret/using-your-own-model-in-train.html#model-components):

caretFitLm <- function(x, y, wts, param, lev, last, weights, classProbs, ...) {
  dat <- if(is.data.frame(x)) x else as.data.frame(x)
  dat$.outcome <- y
  if(! is.null(wts)){
    if(param$intercept)
      out <- lm(.outcome ~ ., data = dat, weights = wts, ...)
    else
      out <- lm(.outcome ~ 0 + ., data = dat, weights = wts, ...)
  } else { # wts == NULL
    if(param$intercept)
      out <- lm(.outcome ~ ., data = dat, ...)
    else
      out <- lm(.outcome ~ 0 + ., data = dat, ...)
  }
  out
}
caretPredictLm <- function(modelFit, newdata, submodels=NULL) {
  if(! is.data.frame(newdata)) newdata <- as.data.frame(newdata)
  predict(modelFit, newdata)
}
# check (param must provide the "intercept" column; the unused arguments can be omitted):
# fitTmp <- caretFitLm(x=dat[, "x", drop=FALSE], y=dat$y, wts=NULL,
#                      param=data.frame(intercept=TRUE))
# predTmp <- caretPredictLm(fitTmp, dat[1:10, "x", drop=FALSE])
caretGridLm <- function(x, y, len=NULL, search="grid") {
  data.frame(intercept=TRUE)
}
caretParamLm <- data.frame(parameter="intercept", class="logical", label="intercept")
caretMethLm <- list(library="stats",
                    type="Regression",
                    parameters=caretParamLm,
                    grid=caretGridLm,
                    fit=caretFitLm,
                    predict=caretPredictLm,
                    prob=NULL,
                    sort=NULL)

Create inputs for the cross-validation (https://topepo.github.io/caret/model-training-and-tuning.html#the-traincontrol-function):

caretSummary <- function(data, lev=NULL, model=NULL){
  ## "data" contains at least the columns "obs" and "pred" (see ?caret::defaultSummary)
  data <- data[order(data$obs, decreasing=TRUE),] # sort from best (highest obs) to worst
  nb.inds <- nrow(data)
  idx.best50p <- floor(0.50 * nb.inds)
  idx.best25p <- floor(0.25 * nb.inds)
  coefOls <- as.numeric(coef(lm(pred ~ obs, data=data)))
  c(rmse=sqrt(mean((data$pred - data$obs)^2)),
    corP=cor(data$obs, data$pred, method="pearson"),
    corS=cor(data$obs, data$pred, method="spearman"),
    corP.best50p=cor(data$obs[1:idx.best50p],
                     data$pred[1:idx.best50p],
                     method="pearson"),
    corS.best50p=cor(data$obs[1:idx.best50p],
                     data$pred[1:idx.best50p],
                     method="spearman"),
    corP.best25p=cor(data$obs[1:idx.best25p],
                     data$pred[1:idx.best25p],
                     method="pearson"),
    corS.best25p=cor(data$obs[1:idx.best25p],
                     data$pred[1:idx.best25p],
                     method="spearman"),
    reg.intercept=coefOls[1],
    reg.slope=coefOls[2])
}
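
To check the contract of such a summary function, it can be called directly on fake data with the two columns ("obs" and "pred") that caret supplies during resampling:

toyDat <- data.frame(obs=rnorm(20), pred=rnorm(20)) # fake data, for illustration only
caretSummary(toyDat)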
caretTrainCtlLm <- trainControl(method="repeatedcv", number=5, repeats=5,
                                summaryFunction=caretSummary)

The cross-validation can now be performed:

set.seed(1234) # ensure that the same resamples are used
fit2 <- train(form=y ~ x, data=dat, method=caretMethLm, trControl=caretTrainCtlLm,
              metric="rmse", maximize=FALSE) # "rmse" is not a native metric name, so tell train() to minimize it
fit2
names(fit2)
fit2$results
dim(fit2$resample)
head(fit2$resample)
mean(fit2$resample$rmse)
sd(fit2$resample$rmse)
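
As with any caret model, the fitted object can then be used to predict new data:

predict(fit2, newdata=data.frame(x=c(-1, 0, 1)))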

Variable selection

Before writing custom methods, it is instructive to look at how caret itself specifies glmnet, a native model that performs variable selection:

glmnetModelInfo <- getModelInfo(model="glmnet", regex=FALSE)[[1]]
names(glmnetModelInfo)
glmnetModelInfo$label
glmnetModelInfo$library
glmnetModelInfo$type
glmnetModelInfo$parameters
glmnetModelInfo$grid
glmnetModelInfo$loop
glmnetModelInfo$fit
glmnetModelInfo$predict
glmnetModelInfo$prob
glmnetModelInfo$predictors
glmnetModelInfo$varImp
glmnetModelInfo$levels
glmnetModelInfo$tags
glmnetModelInfo$sort
glmnetModelInfo$trim

Data

set.seed(1234)
X <- simulGenosDose(nb.genos=10^3, nb.snps=10^3)
truth <- simulBvsr(Q=1, mu=0, X=X, pi=0.2, pve=0.75, sigma.a2=1)
sum(truth$gamma)
colnames(X)[truth$gamma == 1]
dat <- truth$dat
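
A quick check that the simulated data match the requested settings:

mean(truth$gamma) # proportion of causal SNPs, expected close to pi=0.2
dim(X)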

Custom varbvs

Make the inputs (the caretGridVarbvs, caretFitVarbvs and caretPredictVarbvs wrappers are assumed to be provided by the rutilstimflutre package loaded above):

caretMethVarbvs <- list(library="varbvs",
                        type="Regression",
                        parameters=data.frame(parameter="intercept",
                                              class="logical",
                                              label="intercept"),
                        grid=caretGridVarbvs,
                        fit=caretFitVarbvs,
                        predict=caretPredictVarbvs,
                        prob=NULL,
                        sort=NULL)
caretTrainCtlVarbvs <- trainControl(method="repeatedcv", number=5, repeats=5,
                                    summaryFunction=caretSummary,
                                    allowParallel=TRUE)
## TODO: use the "seeds" argument
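
A possible way to address this TODO (hedged sketch): for 5 x 5 repeated CV with a single combination of tuning parameters, trainControl() expects "seeds" to be a list of number*repeats + 1 elements, the first ones of length one (one model per resample) and the last one a single integer:

## sketch for the TODO above (not run):
## nbResamples <- 5 * 5
## cvSeeds <- c(lapply(seq(nbResamples), function(i) sample.int(10^4, 1)),
##              sample.int(10^4, 1))
## trainControl(method="repeatedcv", number=5, repeats=5, seeds=cvSeeds, ...)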

Run cross-validation:

set.seed(1234) # ensure that the same resamples are used
system.time(
  fit3 <- train(x=X, y=dat$response1,
                method=caretMethVarbvs,
                trControl=caretTrainCtlVarbvs,
                metric="rmse",
                Z=NULL, family="gaussian", verbose=FALSE))
fit3
fit3$results
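
Assuming caretFitVarbvs returns a plain varbvs fit object (a hedged check, since the wrapper's return value is not shown here), the selected SNPs can be compared with the simulated causal ones:

## assumes fit3$finalModel is a "varbvs" object:
print(summary(fit3$finalModel)) # top SNPs by posterior inclusion probability
colnames(X)[truth$gamma == 1]   # simulated causal SNPs, for comparison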

Custom BGLR

Make the inputs (the caretGridBglr, caretFitBglr and caretPredictBglr wrappers are assumed to come from rutilstimflutre):

# check:
# nb.iters=50 * 10^3; burn.in=5 * 10^3; thin=5
# fitTmp <- caretFitBglr(x=X, y=dat$response1,
#                        bglrEtaModel="BayesC", task.id="test-BGLR_",
#                        nb.iters=nb.iters, burn.in=burn.in, thin=thin,
#                        keep.samples=TRUE)
# plot(fitTmp$post.samples)
# raftery.diag(fitTmp$post.samples, q=0.5, r=0.05, s=0.9)
# geweke.diag(fitTmp$post.samples) # should not be too far from [-2;2]
# summary(fitTmp$post.samples)
# predTmp <- caretPredictBglr(fitTmp, X[1:10,])
caretMethBglr <- list(library=c("BGLR", "coda"),
                      type="Regression",
                      parameters=data.frame(parameter="intercept",
                                            class="logical",
                                            label="intercept"),
                      grid=caretGridBglr,
                      fit=caretFitBglr,
                      predict=caretPredictBglr,
                      prob=NULL,
                      sort=NULL)
caretTrainCtlBglr <- trainControl(method="repeatedcv", number=5, repeats=5,
                                  summaryFunction=caretSummary,
                                  allowParallel=TRUE)
## TODO: use the "seeds" argument (see the sketch in the varbvs section)

Run cross-validation:

set.seed(1234) # ensure that the same resamples are used
system.time(
  fit4 <- train(x=X, y=dat$response1,
                method=caretMethBglr,
                trControl=caretTrainCtlBglr,
                metric="rmse",
                ETA=list(list(X=X, model="BayesC")),
                saveAt="bglr4caret_",
                response_type="gaussian",
                # nIter=50 * 10^2, burnIn=5 * 10^2, thin=5, verbose=FALSE, # debug
                nIter=50 * 10^3, burnIn=5 * 10^3, thin=5, verbose=FALSE,
                keep.samples=FALSE))
fit4
fit4$results

Ridge regression

Custom rrBLUP

Make the inputs (the caretGridRrblup, caretFitRrblup and caretPredictRrblup wrappers are assumed to come from rutilstimflutre):

caretMethRrblup <- list(library=c("rrBLUP"),
                        type="Regression",
                        parameters=data.frame(parameter="intercept",
                                              class="logical",
                                              label="intercept"),
                        grid=caretGridRrblup,
                        fit=caretFitRrblup,
                        predict=caretPredictRrblup,
                        prob=NULL,
                        sort=NULL)
caretTrainCtlRrblup <- trainControl(method="repeatedcv", number=5, repeats=5,
                                    summaryFunction=caretSummary,
                                    allowParallel=TRUE)
## TODO: use the "seeds" argument (see the sketch in the varbvs section)

Run cross-validation:

set.seed(1234) # ensure that the same resamples are used
system.time(
    fit5 <- train(x=X, y=dat$response1,
                  method=caretMethRrblup,
                  trControl=caretTrainCtlRrblup,
                  metric="rmse"))
fit5
fit5$results
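
Because the same seed was set before each call to train(), the three genomic prediction models were evaluated on the same resamples, so their cross-validated metrics can be collected and compared with caret's resamples() function:

rs <- resamples(list(varbvs=fit3, BGLR=fit4, rrBLUP=fit5))
summary(rs, metric="rmse")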

Cleaning

stopCluster(cl)

Appendix

t1 <- proc.time(); t1 - t0
print(sessionInfo(), locale=FALSE)

