# R/cv_error.R
#
# Defines functions: cv_error
#
# Documented in: cv_error

#' Cross-validation error
#'
#' Calculates the k-fold cross-validation error.
#'
#' The rows of `X` are shuffled into `k` folds of (near-)equal size. For each
#' fold, a model is fitted with `ls()` on the remaining rows, predictions for
#' the held-out rows are formed as `X_test %*% coef`, and the fold is scored
#' with `rmse()`. The result is the mean of the `k` fold scores.
#'
#' Fold assignment is random; call `set.seed()` beforehand for reproducible
#' results.
#'
#' @param X design matrix, one row per observation
#' @param y output variable, one value per row of `X`
#' @param k number of folds; a single whole number between 2 and `nrow(X)`
#'
#' @return cross-validation error (mean of the per-fold errors)
#'
cv_error <- function(X, y, k) {
  n <- nrow(X)
  stopifnot(
    !is.null(n),
    NROW(y) == n,
    is.numeric(k), length(k) == 1L, !is.na(k),
    k == round(k), k >= 2, k <= n
  )

  # Balanced random fold labels: recycle 1..k to length n, then shuffle.
  # Unlike sampling labels with replacement, this guarantees that every fold
  # holds at least one observation (given k <= n), so no fold is ever empty.
  folds <- sample(rep(seq_len(k), length.out = n))

  # Error on the i-th fold. `folds`, `X` and `y` are picked up from the
  # enclosing function via lexical scoping.
  cv_fold <- function(i) {
    held_out <- folds == i
    # drop = FALSE keeps the subsets as matrices even when they have a single
    # row or column, so the fit and the matrix product below stay well-defined.
    test_X <- X[held_out, , drop = FALSE]
    test_y <- y[held_out]
    train_X <- X[!held_out, , drop = FALSE]
    train_y <- y[!held_out]
    # NOTE(review): `ls` and `rmse` are not defined in this file; `ls` must
    # resolve to the project's model-fitting function (it masks base::ls).
    model <- ls(train_X, train_y) # fit model
    beta <- model$coef # fitted coefficients
    pred <- test_X %*% beta # predictions for the held-out rows
    rmse(pred, test_y)
  }

  # Error of every fold, then the average across folds.
  fold_errors <- vapply(seq_len(k), cv_fold, numeric(1))
  mean(fold_errors)
}
# andreabecsek/portfolio documentation built on Jan. 2, 2020, 2:56 a.m.