#' @importFrom spams spams.fistaTree spams.normalize
NULL
#' Hold-out error of a tree-regularised regression for one CV fold
#'
#' @description
#' Randomly partitions the rows of \code{X} into \code{nfolds} folds, holds
#' out fold \code{whichsubgroup}, fits a model with \code{spams.fistaTree()}
#' (\code{loss = "square"}, \code{regul = "tree-l2"}) on the remaining rows
#' and returns the prediction error on the held-out rows.
#'
#' @details
#' The training predictors are column-centred and rescaled with
#' \code{spams.normalize()}; the training response is standardised with
#' \code{scale()}. The held-out rows are transformed with the \emph{training}
#' column means and column L2 norms, and the predictions are mapped back to
#' the original scale of \code{Y} with the training mean and standard
#' deviation, so that predictions and \code{Y.test} are compared on the same
#' scale.
#'
#' The row shuffle is drawn with \code{sample.int()} on every call. Calls
#' with different \code{whichsubgroup} therefore only share one partition if
#' the caller resets the RNG seed to the same value before each call.
#'
#' The function needs a tree structure that \code{spams.fistaTree()} accepts.
#'
#' @param X A numeric matrix of predictors, one row per observation.
#' @param Y A numeric vector of responses with \code{nrow(X)} elements.
#' @param whichsubgroup Index (1 to \code{nfolds}) of the fold to hold out.
#' @param Lambda A number, passed to \code{spams.fistaTree()} as
#'   \code{lambda1}.
#' @param W0 Initial coefficients, passed unchanged to
#'   \code{spams.fistaTree()}.
#' @param tree Tree structure, passed unchanged to \code{spams.fistaTree()}.
#' @param nfolds Number of folds (default 5). Must be at least 2 and at most
#'   \code{nrow(X)}.
#'
#' @return The value of \code{MSE(yhat, Y.test)} for the held-out fold, with
#'   \code{yhat} on the original scale of \code{Y}.
#' @export
#'
#' @examples
#' # TODO: add a runnable example
CV.hier <- function(X, Y, whichsubgroup, Lambda, W0, tree, nfolds = 5L) {
  X <- as.matrix(X)
  n <- nrow(X)
  stopifnot(
    length(Y) == n,
    length(nfolds) == 1L, nfolds >= 2, nfolds <= n,
    length(whichsubgroup) == 1L, whichsubgroup %in% seq_len(nfolds)
  )

  # Shuffle the rows and deal them round-robin into the folds. rep_len()
  # handles row counts that are not a multiple of nfolds without a warning.
  index.shuffle <- sample.int(n)
  fold.id <- rep_len(seq_len(nfolds), n)
  test.idx <- index.shuffle[fold.id == whichsubgroup]

  # drop = FALSE keeps matrices even when a fold has a single row.
  X.train <- X[-test.idx, , drop = FALSE]
  X.test <- X[test.idx, , drop = FALSE]
  Y.train <- Y[-test.idx]
  Y.test <- Y[test.idx]

  # No intercept is fitted, so centre the training columns and rescale them
  # to unit L2 norm (not the same as unit variance).
  col.means <- colMeans(X.train)
  X.train <- sweep(X.train, 2L, col.means, "-")
  col.norms <- sqrt(colSums(X.train^2))
  col.norms[col.norms == 0] <- 1 # constant column: avoid dividing by zero
  X.train <- spams.normalize(X.train)

  # Standardise the response; scale() records the centre and scale used.
  Y.train.s <- scale(Y.train)

  # NOTE(review): the bare TRUE is matched positionally to the fifth formal
  # of spams.fistaTree() (presumably `return_optim_info` -- confirm), which
  # is why the coefficients are read from the first list element below.
  model.train <- spams.fistaTree(
    as.matrix(Y.train.s), X.train,
    W0 = W0, tree = tree, TRUE,
    lambda1 = Lambda,
    max_it = 200, L0 = 0.1, tol = 1e-5,
    intercept = FALSE, pos = FALSE, compute_gram = TRUE,
    loss = "square", regul = "tree-l2"
  )
  betahat <- model.train[[1]]

  # Put the held-out rows through the *training* transformation, then undo
  # the standardisation of the response so yhat is comparable with Y.test.
  X.test <- sweep(X.test, 2L, col.means, "-")
  X.test <- sweep(X.test, 2L, col.norms, "/")
  yhat <- (X.test %*% betahat) * attr(Y.train.s, "scaled:scale") +
    attr(Y.train.s, "scaled:center")

  MSE(yhat, Y.test)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.