cvfit: Cross-validation of a prediction model

View source: R/cvfit.R

cvfit    R Documentation

Cross-validation of a prediction model

Description

cvfit is a generic function for cross-validating a prediction model. The argument fun specifies the function that implements the model to be cross-validated.

Usage


cvfit(X, Y, fun, segm, printcv = FALSE, ...)

Arguments

X

An n x p matrix or data frame of reference (= training) observations, one per row.

Y

For quantitative responses: an n x q matrix or data frame, or a vector of length n, of reference (= training) responses. For qualitative responses: a vector of length n, or an n x 1 matrix, of reference (= training) responses (class membership).

fun

A function defining the prediction model to cross-validate.

segm

A list of the test segments. Typically the output of function segmkf or segmts.

printcv

Logical. If TRUE, fitting information is printed. Default is FALSE.

...

Optional arguments passed on to function fun.
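
To make the segm argument more concrete, here is a minimal base R sketch of a list of test segments. It assumes (this page does not confirm it) that segm holds one element per repetition, each element being a list of integer vectors of test row numbers. The helper make_kfold_segments is hypothetical and is not part of the package; in practice segmkf or segmts would be used.

```r
# Hypothetical helper, not part of the package: builds random K-fold test
# segments under the assumed layout list(repetition -> list(fold -> rows)).
make_kfold_segments <- function(n, K, nrep = 1) {
  lapply(seq_len(nrep), function(i) {
    folds <- sample(rep(seq_len(K), length.out = n))  # random fold label per row
    unname(split(seq_len(n), folds))                  # list of K index vectors
  })
}

set.seed(1)
segm_demo <- make_kfold_segments(n = 10, K = 5, nrep = 2)

length(segm_demo)             # number of repetitions
length(segm_demo[[1]])        # number of folds in the first repetition
sort(unlist(segm_demo[[1]]))  # each row is a test row exactly once per repetition
```

Within a repetition the folds partition the rows, so every observation is predicted once per repetition.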

Value

A list of outputs, such as:

y

Responses for the test data.

fit

Predictions for the test data.

r

Residuals for the test data.
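
The relationship between fun, segm and the returned y, fit and r can be illustrated with a small base R sketch. It is not the package's implementation: cv_sketch and lm_fun are hypothetical names, the model is an ordinary linear regression, and residuals are taken here as y - fit, which is an assumption. The (Xr, Yr, Xu, Yu) calling convention and the list(y, fit, r) return value mirror the ad hoc function shown in the Examples.

```r
# Hypothetical model function following the (Xr, Yr, Xu, Yu) convention:
# fit on the training rows, predict the test rows.
lm_fun <- function(Xr, Yr, Xu, Yu = NULL) {
  dat <- data.frame(y = Yr, Xr)
  fm <- lm(y ~ ., data = dat)
  fit <- as.numeric(predict(fm, newdata = data.frame(Xu)))
  r <- if (is.null(Yu)) rep(NA_real_, length(fit)) else Yu - fit  # assumed sign
  list(y = Yu, fit = fit, r = r)
}

# Hypothetical stand-in for a cross-validation loop (not the package's code).
cv_sketch <- function(X, Y, fun, segm, ...) {
  res <- list()
  for (i in seq_along(segm)) {          # repetitions
    for (j in seq_along(segm[[i]])) {   # test segments within a repetition
      s <- segm[[i]][[j]]               # row numbers of the test segment
      z <- fun(X[-s, , drop = FALSE], Y[-s], X[s, , drop = FALSE], Y[s], ...)
      res[[length(res) + 1]] <- data.frame(
        rep = i, segm = j, row = s, y = z$y, fit = z$fit, r = z$r
      )
    }
  }
  do.call(rbind, res)
}

# Simulated data: 30 observations, 2 predictors
set.seed(1)
X_demo <- matrix(rnorm(60), nrow = 30, ncol = 2,
                 dimnames = list(NULL, c("x1", "x2")))
y_demo <- 1 + 2 * X_demo[, 1] - X_demo[, 2] + rnorm(30, sd = 0.1)

# One repetition of 5 random folds of 6 rows each
folds_demo <- list(split(sample(seq_len(30)), rep(1:5, each = 6)))

out <- cv_sketch(X_demo, y_demo, lm_fun, folds_demo)
head(out)
mean(out$r^2)   # cross-validated mean squared error of prediction
```

Each test row is predicted by a model that never saw it during fitting, which is why the pooled residuals give an honest estimate of prediction error.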

Examples


data(datcass)

Xr <- datcass$Xr
yr <- datcass$yr

Xr <- detrend(Xr)
headm(Xr)

################ Cross-validation of a PLSR model

n <- nrow(Xr)
segm <- segmkf(n = n, K = 5, typ = "random", nrep = 3)      # = Repeated K-fold CV
#segm <- segmts(n = n, m = 40, nrep = 30)                   # = Repeated Test-set CV
fm <- cvfit(
  Xr, yr,
  fun = plsr,
  ncomp = 20,
  segm = segm,
  printcv = TRUE
  )
names(fm)
headm(fm$y)
headm(fm$fit)
headm(fm$r)

z <- mse(fm, ~ ncomp)
headm(z)
z[z$msep == min(z$msep), ]
plotmse(z)
plotmse(z, nam = "r2", ylim = c(0, 1))
abline(h = 1, lty = 2, col = "grey")

z <- mse(fm, ~ ncomp + rep)
headm(z)
plotmse(z, group = z$rep)
plotmse(z, group = z$rep, legend = FALSE, col = "grey")

################ Example of CV of an ad hoc predictive function (e.g. using locw)

fun <- function(Xr, Yr, Xu, Yu = NULL, k, ncomp, ncompdis) {
  
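  # Global PLS scores (ncompdis components) used to measure dissimilarities
  z <- pls(Xr, Yr, Xu, ncomp = ncompdis)
  # k nearest training neighbors of each test observation (Mahalanobis distance on the scores)
  resn <- getknn(z$Tr, z$Tu, k = k, diss = "mahalanobis")
  # Locally weighted PLSR fitted on each neighborhood
  fm <- locw(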
    Xr, Yr,
    Xu, Yu,
    listnn = resn$listnn,
    listw = lapply(resn$listd, wdist, h = 2),
    fun = plsr,
    algo = pls_kernel,
    ncomp = ncomp,
    print = TRUE
    )

  list(y = fm$y, fit = fm$fit, r = fm$r)
  
  }

n <- nrow(Xr)
segm <- segmkf(n = n, K = 5, typ = "random", nrep = 1)
fm <- cvfit(
  Xr, yr,
  fun = fun,
  k = 50,
  ncomp = 20,
  ncompdis = 10,
  segm = segm,
  printcv = TRUE
  )
headm(fm$fit)

z <- mse(fm, ~ k + ncomp)
headm(z)
z[z$msep == min(z$msep), ]
plotmse(z, group = z$k)


mlesnoff/rnirs documentation built on April 24, 2023, 4:17 a.m.