R/CVHier.R

Defines functions CV.hier

Documented in CV.hier

#' @importFrom spams spams.fistaTree spams.normalize
NULL

#' Cross-validation error of tree-structured sparse regression on one fold
#'
#' @description Partitions the rows of `X` into `nfolds` folds, fits
#'   `spams::spams.fistaTree()` (square loss, `regul = "tree-l2"`, no
#'   intercept) on every fold except `whichsubgroup`, and returns the
#'   prediction error on the held-out fold. Requires the spams package.
#'
#'   Because the fit has no intercept, the training data are standardised
#'   first. Every column of the training design is centred and rescaled to
#'   unit L2 norm (not unit variance). The training response is centred and
#'   scaled to unit variance with [scale()]. The held-out rows are
#'   transformed with the *training* column means and norms. The
#'   standardised predictions are then mapped back to the original scale of
#'   `Y` before the error is computed; the training mean of `Y` plays the
#'   role of the intercept.
#'
#' @param X a numeric matrix of predictors, one row per observation.
#' @param Y a numeric response vector with one entry per row of `X`.
#' @param whichsubgroup a single integer in `1:nfolds`: the fold held out
#'   for testing.
#' @param Lambda the regularisation strength, passed to
#'   `spams.fistaTree()` as `lambda1`.
#' @param W0 the initial coefficient matrix passed to `spams.fistaTree()`.
#'   It is presumably `ncol(X)` x 1, e.g. `matrix(0, ncol(X), 1)`
#'   (confirm).
#' @param tree the tree-structure list passed unchanged to
#'   `spams.fistaTree()`; see its documentation for the required format.
#' @param nfolds the number of folds (default `5`).
#' @param foldid optional vector of fold labels in `1:nfolds`, one per row of
#'   `X`. When `NULL` (the default), the rows are shuffled with [sample()] on
#'   every call. Calls for different `whichsubgroup` then share the same
#'   partition only if the RNG seed is reset to the same value before each
#'   call. Supply `foldid` to guarantee that the held-out folds partition
#'   the data.
#'
#' @return The value of `MSE(yhat, Y.test)` for the held-out fold, where
#'   `yhat` is on the original scale of `Y`.
#' @export
#'
#' @examples
#' \dontrun{
#' foldid <- sample(rep_len(1:5, nrow(X)))
#' W0 <- matrix(0, ncol(X), 1)
#' cv.err <- vapply(1:5, function(k) {
#'   CV.hier(X, Y, k, Lambda = 0.1, W0 = W0, tree = tree, foldid = foldid)
#' }, numeric(1))
#' }
CV.hier <- function(X, Y, whichsubgroup, Lambda, W0, tree,
                    nfolds = 5L, foldid = NULL) {
  X <- as.matrix(X)
  n <- nrow(X)
  stopifnot(
    length(Y) == n,
    length(whichsubgroup) == 1L,
    whichsubgroup %in% seq_len(nfolds)
  )

  if (is.null(foldid)) {
    if (n < nfolds) {
      stop("need at least `nfolds` rows in `X`", call. = FALSE)
    }
    # Shuffle the rows, then deal them round-robin into the folds. rep_len()
    # assigns folds without the recycling warning split() emits when n is
    # not a multiple of nfolds.
    index.shuffle <- sample(n, size = n, replace = FALSE)
    foldid <- integer(n)
    foldid[index.shuffle] <- rep_len(seq_len(nfolds), n)
  }
  stopifnot(length(foldid) == n)

  test.idx <- which(foldid == whichsubgroup)
  # x[-integer(0), ] selects no rows, so an empty fold must be rejected.
  if (length(test.idx) == 0L || length(test.idx) == n) {
    stop("fold ", whichsubgroup, " must be non-empty and leave training rows",
         call. = FALSE)
  }

  # drop = FALSE keeps single-row / single-column subsets as matrices.
  X.train <- X[-test.idx, , drop = FALSE]
  X.test <- X[test.idx, , drop = FALSE]
  Y.train <- Y[-test.idx]
  Y.test <- Y[test.idx]

  # Centre each training column, then rescale it to unit L2 norm. Columns
  # with (near-)zero norm, e.g. constant columns, are left unscaled so we
  # never divide by ~0.
  x.center <- colMeans(X.train)
  X.train.s <- sweep(X.train, 2L, x.center, "-")
  x.norm <- sqrt(colSums(X.train.s^2))
  x.norm[x.norm <= 1e-10] <- 1
  X.train.s <- sweep(X.train.s, 2L, x.norm, "/")

  # Centre and scale the training response. Keep the statistics so that
  # predictions can be mapped back to the original scale of Y.
  Y.train.s <- scale(Y.train)
  y.center <- attr(Y.train.s, "scaled:center")
  y.scale <- attr(Y.train.s, "scaled:scale")

  model.train <- spams.fistaTree(
    as.matrix(Y.train.s), X.train.s, W0 = W0, tree = tree,
    # With TRUE the result is a list whose first element is the coefficient
    # matrix (extracted below).
    return_optim_info = TRUE,
    lambda1 = Lambda,
    max_it = 200, L0 = 0.1, tol = 1e-5,
    intercept = FALSE, pos = FALSE, compute_gram = TRUE,
    loss = "square", regul = "tree-l2"
  )
  betahat <- model.train[[1]]

  # Standardise the held-out rows with the *training* statistics, predict on
  # the standardised scale, then undo the response standardisation.
  X.test.s <- sweep(sweep(X.test, 2L, x.center, "-"), 2L, x.norm, "/")
  yhat <- (X.test.s %*% betahat) * y.scale + y.center

  MSE(yhat, Y.test)
}
yymmhaha/PackPaper1 documentation built on May 24, 2019, 8:55 a.m.