R/tune_tree_interval.R

Defines functions tune.tree.interval

Documented in tune.tree.interval

#' mixlasso
#' @title Wrapper function for tree-lasso objects.
#' @description
#' Wrapper function for tree-lasso objects used by the \code{epsgo} function.
#' For a fixed set of penalty-factor ratios (\code{parms}) it rescales the
#' predictor blocks, evaluates every value of \code{lambda} by K-fold
#' cross-validation (mean squared prediction error per response entry) and
#' reports the best \code{lambda} together with its cross-validated error.
#' This function is mainly used within the function \code{epsgo}.
#'
#' @importFrom parallel detectCores makeCluster	stopCluster
#' @importFrom foreach foreach %dopar%
#' @importFrom doParallel registerDoParallel
#'
#' @param parms numeric vector of tuning parameters (reported back as
#'   \code{ipf}). It is rounded to two digits; entry \code{i} divides the
#'   columns of data source \code{i + 1} (see \code{p}). The first data source
#'   is left unscaled.
#' @param x,y \code{x} is a numeric matrix where each row refers to a sample
#'   and each column refers to a predictor; \code{y} is the numeric response
#'   matrix with one row per sample.
#' @param lambda A user supplied lambda sequence (required).
#' @param nfolds number of cross-validation folds, default 5. Only used to
#'   generate \code{foldid} when that is not supplied.
#' @param foldid an optional vector of values between 1 and \code{nfolds}
#'   identifying what fold each observation is in. If supplied, \code{nfolds}
#'   is not used for the fold assignment.
#' @param tree.parm tree structure passed unchanged to \code{tree.lasso}
#'   (required).
#' @param num.nonpen number of predictors forced to be estimated (i.e.,
#'   nonpenalization). These are assumed to be the leading columns of \code{x}.
#' @param seed random seed used for the fold assignment.
#' @param intercept should intercept(s) be fitted (default=\code{TRUE}) or set
#'   to zero (\code{FALSE}).
#' @param standardize.response standardization for the response variables.
#'   Default: \code{FALSE}.
#' @param p vector with the number of predictors from each data source.
#' @param verbose print the middle search information, default is
#'   \code{FALSE}.
#' @param parallel If \code{TRUE}, use parallel foreach to fit each lambda. If
#'   \code{c(TRUE,TRUE)}, use parallel foreach to fit each lambda and each
#'   fold.
#' @param search.path save the visited points, default is \code{FALSE}.
#' @param threshold threshold for estimated coefficients of the tree-lasso
#'   models; passed unchanged to \code{tree.lasso} (required).
#' @param maxiter passed unchanged to \code{tree.lasso} (required).
#' @param ... currently not used.
#' @return A list with \code{q.val}, the minimum mean cross-validated error
#'   over the lambda sequence, and \code{model}, a list holding the optimal
#'   \code{lambda}, \code{ipf} (the rounded \code{parms}) and \code{nfolds}.
#'   When \code{search.path = TRUE}, \code{model} also contains
#'   \code{search.cvm}, the lambda sequence followed by the matching
#'   cross-validated errors.
#' @export
tune.tree.interval<-function(parms, x, y,
                                lambda = NULL, 
                                nfolds = 5,
                                foldid=NULL,
                                tree.parm=tree.parm,
                                num.nonpen = 0,
                                seed=12345, 
                                intercept=TRUE,
                                standardize.response=FALSE,
                                p=NULL,
                                verbose=FALSE,
                                parallel=FALSE,
                                search.path=FALSE,
                                threshold=threshold,
                                maxiter=maxiter,
                                ...){

  # 0. validate the inputs ##############################################################

  # The self-referential defaults make these arguments effectively mandatory;
  # fail early with a readable message instead of a promise-recursion error.
  if (missing(tree.parm)) stop("Argument 'tree.parm' must be supplied.")
  if (missing(threshold)) stop("Argument 'threshold' must be supplied.")
  if (missing(maxiter)) stop("Argument 'maxiter' must be supplied.")
  if (is.null(lambda)) stop("No given lambda sequence!")

  # 1. decode the parameters ############################################################

  parms <- round(parms, digits = 2)
  if (verbose) print(paste("IPF=", paste(as.character(parms), collapse = ":")))

  if (standardize.response) y <- scale(y)[, ]
  # scale(y)[, ] collapses a one-column response to a plain vector; restore
  # the matrix shape so the row subsetting below keeps working.
  if (is.null(dim(y))) y <- matrix(y, ncol = 1)

  set.seed(seed)
  if (is.null(foldid)) {
    foldid <- sample(rep(seq_len(nfolds), length.out = nrow(x)))
  }
  n_folds <- max(foldid)

  #=========
  # using augmented data
  #=========
  lambda <- matrix(lambda, ncol = 1)
  n_lambda <- length(lambda)

  # Divide the columns of every data source after the first one by its factor.
  x2 <- x
  block_end <- cumsum(p)
  for (i in seq_len(max(length(p) - 1L, 0L))) {
    cols <- num.nonpen + (block_end[i] + 1):block_end[i + 1]
    x2[, cols] <- x2[, cols] / parms[i]
  }

  # Mean squared prediction error on one held-out fold for one lambda.
  cv_fold <- function(fold, la) {
    is_test <- foldid == fold
    # drop = FALSE keeps single-row folds / single-column responses as matrices
    x_test <- x2[is_test, , drop = FALSE]
    y_test <- y[is_test, , drop = FALSE]
    fit <- tree.lasso(x = x2[!is_test, , drop = FALSE],
                      y = y[!is_test, , drop = FALSE],
                      lambda = la, tree.parm = tree.parm,
                      num.nonpen = num.nonpen, intercept = intercept,
                      threshold = threshold, maxiter = maxiter)
    # Prepend a column of ones only when an intercept row is part of Beta.
    design <- if (intercept) cbind(rep(1, nrow(x_test)), x_test) else x_test
    sum((y_test - design %*% fit$Beta)^2) / prod(dim(y_test))
  }

  # Cross-validated error of one lambda, averaged over all folds.
  cv_lambda <- function(la) {
    mean(vapply(seq_len(n_folds), cv_fold, numeric(1), la = la))
  }

  par_level <- sum(parallel)
  if (par_level > 0) {
    # detectCores() may return NA or 1; always keep at least one worker.
    avail <- detectCores() - 1
    if (is.na(avail) || avail < 1) avail <- 1

    if (par_level > 1) {
      # One task per (lambda, fold) pair, distributed over the available cores.
      grid <- cbind(rep(lambda, times = n_folds),
                    rep(seq_len(n_folds), each = n_lambda))
      n_tasks <- nrow(grid)
      cl <- makeCluster(min(n_tasks, avail))
      # Release the workers even if a task fails.
      on.exit(stopCluster(cl), add = TRUE)
      #clusterEvalQ(cl, library(mixLasso))
      registerDoParallel(cl)
      fold_err <- foreach(i = seq_len(n_tasks), .combine = c,
                          .packages = c('base', 'Matrix')) %dopar% {
        cv_fold(grid[i, 2], grid[i, 1])
      }
      # Rows index lambda, columns index folds.
      cvm0 <- rowMeans(matrix(fold_err, ncol = n_folds))
    } else {
      # One task per lambda; foreach schedules all of them over the workers.
      cl <- makeCluster(min(n_lambda, 16, avail))
      on.exit(stopCluster(cl), add = TRUE)
      registerDoParallel(cl)
      cvm0 <- foreach(i = seq_len(n_lambda), .combine = c,
                      .packages = c('base', 'Matrix')) %dopar% {
        cv_lambda(lambda[i])
      }
    }
  } else {
    cvm0 <- vapply(as.vector(lambda), cv_lambda, numeric(1))
  }
  #=========

  # q.val = mean cross-validated error over the folds at the best lambda
  model <- list(lambda = lambda[which.min(cvm0)], ipf = parms, nfolds = nfolds)
  if (search.path) model$search.cvm <- c(lambda, cvm0)

  list(q.val = min(cvm0), model = model)
}
zhizuio/mixlasso documentation built on March 21, 2022, 1:07 a.m.