R/cv-my-grad-descent.R

#' @title Cross-Validation for my_grad_descent() Function
#' @description Performs k-fold cross-validation for `my_grad_descent()` and returns the out-of-sample sum of squared errors (SSE) for each held-out fold, together with their mean.
#' @param form a formula object, e.g. y ~ x1 + x2
#' @param data a data frame
#' @param nfolds number of folds, default is 10
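#' @param ... further arguments passed on to `my_grad_descent()`
#' @return A list with two components: `SSE`, a numeric vector holding the sum of squared prediction errors on each assessment (held-out) fold, and `MSE`, the mean of those per-fold SSE values.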
#'
#' @import rsample
#' @import foreach
#'
#' @examples
#' data(iris)
#' cv <- cv_my_grad_descent(form = Sepal.Length ~ ., data = iris, nfolds = 5, lambda = 0.0001)
#' cv$SSE
#' cv$MSE
#' @export
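cv_my_grad_descent <- function(form, data, nfolds = 10, ...){

  # Name of the response variable (left-hand side of the formula).
  response <- as.character(form)[2]
  # Randomly split the data into `nfolds` analysis/assessment pairs.
  folds <- vfold_cv(data, v = nfolds)
  # For each fold: fit on the analysis set, then sum the squared prediction
  # errors on the held-out assessment set. `[[` extracts the response as a
  # plain vector for both data frames and tibbles.
  SSE <- foreach(fold = folds$splits, .combine = c) %do% {
    fit <- my_grad_descent(form, data = analysis(fold), ...)
    test_data <- assessment(fold)
    sum((test_data[[response]] - as.vector(predict(fit, test_data)))^2)
  }
  # MSE is the mean of the per-fold SSEs, not a per-observation mean.
  list(SSE = SSE, MSE = mean(SSE))
}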
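
# A minimal base-R sketch of the same k-fold procedure, for illustration
# only. Assumptions: stats::lm() stands in for my_grad_descent(), folds are
# assigned with sample() instead of rsample::vfold_cv(), and the helper name
# kfold_sse() is hypothetical (it is not part of this package). It also
# makes explicit that `MSE = mean(SSE)` above averages the per-fold SSEs;
# dividing the total SSE by nrow(data) instead gives the per-observation
# mean squared error (`pooled_MSE` below).
kfold_sse <- function(form, data, nfolds = 10) {
  # Assign each row to one of `nfolds` folds at random (sizes differ by at most 1).
  fold_id <- sample(rep_len(seq_len(nfolds), nrow(data)))
  response <- all.vars(form)[1]
  sse <- vapply(seq_len(nfolds), function(k) {
    train <- data[fold_id != k, , drop = FALSE]
    test <- data[fold_id == k, , drop = FALSE]
    fit <- lm(form, data = train)  # stand-in for my_grad_descent()
    sum((test[[response]] - predict(fit, newdata = test))^2)
  }, numeric(1))
  # Every row is held out exactly once, so sum(sse) is the total held-out SSE.
  list(SSE = sse, MSE = mean(sse), pooled_MSE = sum(sse) / nrow(data))
}

# Usage (not run): iris has 150 rows, so 5 folds hold 30 rows each and
# MSE comes out as 30 times pooled_MSE.
# set.seed(1)
# res <- kfold_sse(Sepal.Length ~ ., iris, nfolds = 5)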
tqchen07/bis557 documentation built on Dec. 21, 2020, 3:06 a.m.