#' @title Cross-Validation for my_grad_descent() Function
#' @description Runs k-fold cross-validation for `my_grad_descent()` and
#'   reports the out-of-sample squared error of every fold.
#' @param form a two-sided formula object, e.g. y ~ x1 + x2
#' @param data a data frame
#' @param nfolds number of folds, default is 10
#' @param ... other arguments that can be passed to `my_grad_descent()`
#'
#' @return A list with two elements: `SSE`, a numeric vector holding the
#'   out-of-sample sum of squared residuals of each fold, and `MSE`, the
#'   total of `SSE` divided by the number of rows in `data`.
#'
#' @import rsample
#' @import foreach
#'
#' @examples
#' data(iris)
#' fit <- cv_my_grad_descent(
#'   form = Sepal.Length ~ ., data = iris, nfolds = 5, lambda = 0.0001
#' )
#' fit$SSE
#' fit$MSE
#' @export
cv_my_grad_descent <- function(form, data, nfolds = 10, ...) {
  folds <- rsample::vfold_cv(data, v = nfolds)
  splits <- folds$splits

  # One squared-error total per fold, filled in place.
  sse <- numeric(length(splits))

  for (i in seq_along(splits)) {
    train <- rsample::analysis(splits[[i]])
    test <- rsample::assessment(splits[[i]])

    fit <- my_grad_descent(form, data = train, ...)

    # Evaluate the formula's left-hand side inside the held-out rows rather
    # than indexing by its deparsed text: this yields a plain vector for both
    # data frames and tibbles, and also works when the response is an
    # expression such as log(y).
    observed <- eval(form[[2]], envir = test, enclos = environment(form))
    predicted <- as.vector(predict(fit, test))

    sse[i] <- sum((as.vector(observed) - predicted)^2)
  }

  # The folds partition `data`, so nrow(data) is the number of out-of-sample
  # predictions; dividing the pooled SSE by it gives a per-observation error.
  list(SSE = sse, MSE = sum(sse) / nrow(data))
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.