#' @title A function fitting the OLS model
#' @description This is a function fitting the OLS model using gradient descent that calculates the penalty based on the out-of-sample accuracy. Here we use cross validation to calculate the out-of-sample accracy.
#' @param formula a formula of linear model
#' @param data_frame a data_frame
#' @param nfolds Default is 10. the number of folds for cross validation
#' @param contrasts Default is NULL. a list of constasts for factor variables
#' @param lambda Default is 0.0001. The speed of gradient descent
#' @param tolerence Default is 1e-20. The minimum difference between the old ssr and the update ssr.
#' @param beta1 Default is 1. The initial value of beta.
#' @param max_itr Default is 1e6. The maximum number of iterations
#' @import foreach
#' @import rsample
#' @examples
#' data(iris)
#' gradient_descent_OLS_cv(Sepal.Length ~ ., iris,nfolds=10)
#' @export
gradient_descent_OLS_cv<-function(formula,
data_frame,
nfolds=10,
contrasts= NULL,
lambda=0.0001,
tolerence=1e-20,
beta1=1,
max_itr=1e6){
folds <- vfold_cv(data_frame, v = nfolds)
y_name=as.character(formula)[2]
SSE <- foreach(fold = folds$splits, .combine = c) %do% {
fit <- gradient_descent_OLS(formula, data_frame = analysis(fold), contrasts = contrasts,lambda=lambda,tolerence=tolerence, beta1=beta1,max_itr=max_itr)
X<- model.matrix(formula,assessment(fold),contrasts=contrasts)
sum(as.vector(assessment(fold)[, y_name]-as.vector(fit$coefficients%*%t(X)))^2)
}
list(SSE = SSE, MSE = mean(SSE))
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.