#' k-fold cross-validation for tsmvr
#'
#' Calculates the mean and standard deviation of errors
#' over a tsmvr k-fold cross-validation experiment.
#' The error for each base model is the normalized squared
#' error between the true response and the predicted response
#' on a given cross-validation set.
#'
#' @param X design matrix (n-by-p)
#' @param Y response matrix (n-by-q)
#' @param s1 regressor matrix sparsity (positive integer)
#' @param s2 covariance matrix sparsity (positive integer)
#' @param k number of k-folds (integer greater than 1)
#' @param pars list of algorithm parameters; output of \code{set_parameters}
#' @param quiet (logical)
#' @param seed set random seed (integer)
#'
#' @return A list of the mean and standard deviation of the errors
#' across the \code{K} folds.
#'
#' @note See also \code{\link{squared_error}},
#' \code{\link{k_folds}},
#' \code{\link{tsmvr_solve}}, and
#' \code{\link{set_parameters}}.
#'
#' @importFrom stats sd
#'
# #' @export
tsmvr_cv <- function(X, Y, s1, s2, pars, quiet = F, seed = NULL) {
stopifnot(
is.numeric(X), is.matrix(X),
is.numeric(Y), is.matrix(Y),
dim(X)[1] == dim(Y)[1],
is.numeric(s1), s1%%1 == 0, s1 > 0, s1 <= dim(X)[2] * dim(Y)[2],
is.numeric(s2), s2%%1 == 0, s2 > 0, s2 <= (dim(Y)[2])^2,
is.list(pars),
is.logical(quiet),
is.null(seed) || is.numeric(seed)
)
error <- rep(0, pars$k)
# Print header.
if (!quiet) {
cat("Solver mode ", pars$B_type, "-", pars$Omega_type, " with eta1 = ", pars$eta1, sep = "")
if (pars$Omega_type == 'min') cat(".\n", sep = '')
else cat(" and eta2 = ", pars$eta2, ".\n", sep = '')
cat("Fold\terror\t\ttime (s)\n")
}
# Compute folds.
fold_list <- k_folds(nrow(X), pars$k, seed)
# For each fold, solve and compute result.
tic <- Sys.time()
pars$quiet = T
for (i in 1:pars$k) {
X_train <- X[fold_list$train[[i]], ]
Y_train <- Y[fold_list$train[[i]], ]
X_val <- X[fold_list$val[[i]], ]
Y_val <- Y[fold_list$val[[i]], ]
B_hat <- tsmvr_solve(
X = X_train, Y = Y_train, s1 = s1, s2 = s2, pars = pars
)$B_hat
Y_pred <- X_val %*% B_hat
A = Y_val - Y_pred
error[i] <- sum(A*A)
toc <- (Sys.time() - tic)
# Print fold result to screen.
if (!quiet) {
cat(i, "\t", error[i], "\t", round(toc, 3), "\n", sep = "")
}
}
toc <- (Sys.time() - tic)
error_mean <- mean(error)
error_sd <- sd(error)
# Print final result to screen.
if (!quiet) {
cat("Fold mean = ", error_mean, "\n", sep = "")
cat("Fold sd = ", error_sd, "\n", sep = "")
}
# Compute results.
return(list(
error_mean = error_mean, error_sd = error_sd,
num_folds = pars$k, time = toc
))
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.