R/ecpc.R


###ecpc method to learn from co-data:
#Fit a generalised linear (linear, logistic) or Cox model, penalised with adaptive multi-group penalties.
#The method combines empirical Bayes estimation of the group hyperparameters with an extra level of shrinkage
#to handle various types of co-data, including overlapping groups, hierarchical groups and continuous co-data.

ecpc <- function(Y,X,
                 Z=NULL,paraPen=NULL,paraCon=NULL,intrcpt.bam=TRUE,bam.method="ML",
                 groupsets=NULL,groupsets.grouplvl=NULL,hypershrinkage=NULL, #co-data former
                 unpen=NULL,intrcpt=TRUE,model=c("linear", "logistic", "cox"),
                 postselection="elnet,dense",maxsel=10,
                 lambda=NULL,fold=10,sigmasq=NaN,w=NULL,
                 nsplits=100,weights=TRUE,profplotRSS=FALSE,
                 Y2=NULL,X2=NULL,compare=TRUE,
                 mu=FALSE,normalise=FALSE,silent=FALSE,
                 datablocks=NULL,est_beta_method=c("glmnet","multiridge")#,standardise_Y=FALSE
                 #nIt=1,betaold=NaN
                 ){
  #-1. Description input --------------------------------------------------------------------------
  #
  #Data and co-data:
  # Y: nx1 vector with response data
  # X: nxp matrix with observed data
  # Z: pxG matrix with co-data on the p covariates 
  # groupsets: list of m elements, each element one co-data group set
  #            with each group set a list of groups containing the indices of covariates in that group 
  # groupsets.grouplvl: (optional) group sets on the group level, e.g. defining hierarchical groups;
  #            list of m elements (corresponding to groupsets), with NULL if there is no structure on the group level, or 
  #            with a list of groups containing the indices of groups of covariates in that group
  #
  #Model:
  # hypershrinkage: vector of m strings indicating the shrinkage type used in the extra level of shrinkage for each of the m group sets, 
  #                either of the form "type" 
  #                or "type1,type2", in which type1 is used to select groups and type2 to estimate the selected group parameters
  # unpen: vector with indices of covariates that should not be penalised (TD: adapt .mlestlin)
  # intrcpt: TRUE/FALSE to use/not use intercept (always unpenalised)
  # model: type of response Y (linear, logistic or Cox)
  # postselection: FALSE if no parsimonious model is needed, or a string with the type of posterior selection method;
  #               one of "elnet,dense", "elnet,sparse", "BRmarginal,dense", "BRmarginal,sparse" or "DSS"
  # maxsel: vector with maximum number of penalised covariates to be selected (additional to all unpenalised covariates)
  #
  #Global hyperparameter estimation:
  #NOTE: tau^2_{global}=sigma^2/lambda (linear) or 1/lambda (logistic,Cox)
  # lambda: -numeric value for lambda used to compute the initial beta on which the EB estimates are based, or
  #         -"ML" or "CV" if the (approximate) ML criterion or cross-validation should be used to estimate the overall lambda
  # fold: number of folds used in inner CV if tau^2_{global} is estimated with CV
  # sigmasq: (linear model) given variance of Y~N(X*beta,sd=sqrt(sigmasq)), else estimated
  # w: (optional) mx1 vector to fix group set weights at given value.
  #
  #Local hyperparameter estimation:
  # nsplits: number of splits used in RSS criterion in the extra level of shrinkage
  # weights: TRUE/FALSE to use weights in ridge hypershrinkage to correct for group size
  #
  #Optional settings:
  # Y2,X2: (optional) independent data set for performance check 
  # compare: -TRUE if ecpc is to be compared with ordinary ridge (glmnet), with the same lambda. 
  #          -if "CV" or "ML", a (possibly) different lambda is used in glmnet ("ML"/"CV" in lambda specifies which method is used for the ecpc lambda)
  #          -for logistic/cox: if "MoM", a CV approximation is used as initial value + MoM for the moment iteration on the whole group
  # silent: set to TRUE to suppress output messages
  #
  #Experimental settings:
  # mu: TRUE/FALSE to include/exclude group prior means (default FALSE)
  # normalise: TRUE if group variances should be normalised to sum to 1 (default FALSE)
  # nIt: number of times ecpc is iterated (default 1)
  # betaold: if beta is known from a similar, previous study, it can be used as weights for the group mean beta_old*mu (default unknown)
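  #
  #Minimal usage sketch (illustrative only; simulated data with hypothetical sizes):
  # set.seed(1)
  # n <- 50; p <- 100
  # X <- matrix(rnorm(n*p), n, p)
  # Y <- rnorm(n)
  # groupsets <- list(list(1:50, 51:100)) #one group set with two groups of covariates
  # fit <- ecpc(Y, X, groupsets=groupsets, model="linear")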

  nIt=1;
  betaold=NaN
  gammaForm=FALSE #co-data with bam
  minlam <- 0
  if(!all(est_beta_method %in% c("glmnet", "multiridge"))){
    warning("Estimation method for betas should be either glmnet or multiridge, set to multiridge")
    est_beta_method <- "multiridge"
  }
  if(length(est_beta_method)>1){
    est_beta_method <- "multiridge"
  }

  #-1.1 Check variable input is as expected--------------------------------------------------
  #check response Y, missings not allowed
  assert(checkVector(Y, any.missing=FALSE), checkMatrix(Y, any.missing=FALSE), 
         checkArray(Y, max.d=2, any.missing=FALSE))
  assert(checkLogical(Y), checkFactor(Y, n.levels=2), checkNumeric(Y))
  
  #check observed data X, missings not allowed
  assert(checkMatrix(X, any.missing=FALSE), checkArray(X, max.d=2, any.missing=FALSE))
  assertNumeric(X)
  
  #check whether dimensions Y and X match
  if(is.null(dim(Y))){ #Y is a vector or factor
    if(length(Y)!=dim(X)[1]) stop("Length of vector Y should equal number of rows in X")
  }else{
    if(dim(Y)[1]!=dim(X)[1]) stop("Number of rows in Y and X should be the same")
  }
  
  #check co-data provided in either Z or groupsets and check format
  if(!is.null(Z)&!is.null(groupsets)){
    stop("Provide co-data either in Z or in groupsets, both not possible")
  }else if(is.null(Z)&is.null(groupsets)){
    print("No co-data provided. Regular ridge is computed corresponding to an 
          intercept only co-data model.")
    groupsets <- list(list(1:p))
  }else if(!is.null(Z)){
    #check if Z is provided as list of (sparse) matrices/arrays
    assertList(Z, types=c("vector", "matrix", "array", "dgCMatrix"), min.len=1, any.missing=FALSE)
    for(g in 1:length(Z)){
      assert(checkNumeric(Z[[g]], any.missing=FALSE), class(Z[[g]])[1]=="dgCMatrix")
      if(is.vector(Z[[g]])) Z[[g]] <- matrix(Z[[g]],length(Z[[g]]),1)
      #check if number of rows of Z match number of columns X
      if(dim(Z[[g]])[1]!=dim(X)[2]){
        stop(paste("The number of rows of co-data matrix",g, "should equal the 
                   number of columns of X, including (missing) co-data values for 
                   unpenalised variables."))
      } 
      #check number of co-data variables < variables
      if(dim(Z[[g]])[2] > dim(X)[2]){
        stop("Number of co-data variables in Z should be smaller than number of 
             variables in X")
      } 
      
      #check paraPen 
      assertList(paraPen, types="list", null.ok = TRUE) #should be NULL or list
      if(!is.null(paraPen)){
        #check if names match Zi, i=1,2,..
        assertSubset(names(paraPen), paste("Z", 1:length(Z), sep=""), empty.ok=FALSE)
        for(g in 1:length(Z)){
          nameZg <- paste("Z",g,sep="")
          if(nameZg%in%names(paraPen)){
            assertList(paraPen[[nameZg]]) #should be named list
            #paraPen for mgcv may contain L, rank, sp, Si with i a number 1,2,3,..
            assertSubset(names(paraPen[[nameZg]]), 
                         c("L", "rank", "sp", paste("S",1:length(paraPen[[nameZg]]), sep="")))
            #elements Si, i=1,2,.. should be matrices and match number of columns of Zg
            for(nameSg in paste("S",1:length(paraPen[[nameZg]]), sep="")){
              if(nameSg%in%names(paraPen[[nameZg]])){
                #check whether Sg is a matrix or 2-dimensional array
                assert(checkMatrix(paraPen[[nameZg]][[nameSg]]),
                       checkArray(paraPen[[nameZg]][[nameSg]], d=2))
                #check square matrix and dimension match co-data matrix
                if(dim(Z[[g]])[2]!=dim(paraPen[[nameZg]][[nameSg]])[1] |
                   dim(Z[[g]])[2]!=dim(paraPen[[nameZg]][[nameSg]])[2]){
                  stop(paste("Dimensions of the square penalty matrix",nameSg,
                             "should be equal to the number of columns in 
                             co-data matrix",g))
                }
              }
            }
          }
        }
      } 
      
      #check paraCon
      assertList(paraCon, types="list", null.ok = TRUE) #should be NULL or list
      if(!is.null(paraCon)){
        #check if names match Zi, i=1,2,..
        assertSubset(names(paraCon), paste("Z", 1:length(Z), sep=""), empty.ok=FALSE)
        for(g in 1:length(Z)){
          nameZg <- paste("Z",g,sep="")
          if(nameZg%in%names(paraCon)){
            assertList(paraCon[[nameZg]], types=c("vector", "matrix", "array")) #should be named list
            #paraCon may contain elements ("M.ineq" and "b.ineq") and/or ("M.eq" and "b.eq")
            namesparaCon <- names(paraCon[[nameZg]])
            assertSubset(namesparaCon, c("M.ineq", "b.ineq", "M.eq", "b.eq"))
            if( ("M.ineq"%in%namesparaCon)&!("b.ineq"%in%namesparaCon) |
                !("M.ineq"%in%namesparaCon)&("b.ineq"%in%namesparaCon)){
              stop("Neither/both M.ineq and b.ineq should be provided in paraCon")
            }
            if( ("M.eq"%in%namesparaCon)&!("b.eq"%in%namesparaCon) |
                !("M.eq"%in%namesparaCon)&("b.eq"%in%namesparaCon)){
              stop("Neither/both M.eq and b.eq should be provided in paraCon")
            }
          }
        }
      }
      
      #check intercept term used for Z; intrcpt.bam
      assertLogical(intrcpt.bam)
      
      #check type of method used for Z; bam.method
      assertSubset(bam.method, c("GCV.Cp", "GACV.Cp", "REML", "P-REML", "ML", "P-ML", "fREML"))
    }
  }else{ #!is.null(groupsets)
    #check groupsets is a list of lists of vectors of integers (covariate indices)
    assertList(groupsets, types=c("list"), min.len=1)
    for(g in 1:length(groupsets)){
      assertList(groupsets[[g]], types="integerish", null.ok = TRUE)
    }
    
    #check groupsets.grouplvl (NULL or list)
    assertList(groupsets.grouplvl, types=c("list","null"), null.ok = TRUE)
    if(length(groupsets.grouplvl)>0){
      #length should equal the number of group sets
      if(length(groupsets.grouplvl)!=length(groupsets)){
        stop("groupsets.grouplvl should be either NULL, or a list with at least one 
           element not equal to NULL, with the length of the list
           matching the length of the list provided in groupsets")
      }
      #elements in the group set on the group level should contain integers
      for(g in 1:length(groupsets.grouplvl)){
        assertList(groupsets.grouplvl[[g]], types="integerish", null.ok = TRUE)
      }
    }
    
    #check hypershrinkage
    assertVector(hypershrinkage, null.ok = TRUE)
    if(is.null(hypershrinkage)){
      hypershrinkage<-rep("ridge", length(groupsets))
    }
    if(length(hypershrinkage)>0){
      if(length(hypershrinkage)!=length(groupsets)){
        stop("Number of elements in hypershrinkage should match that of groupsets")
      }
      for(g in 1:length(hypershrinkage)){
        assertString(hypershrinkage[g]) #should be string
        assertSubset(unlist(strsplit(hypershrinkage[g], split=',')),
                     c("none","ridge","lasso","hierLasso")) #should be combination of these types
      }
    }
  }
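  #Example of the two co-data formats (illustrative, for p=4 covariates):
  # groupsets <- list(list(1:2, 3:4))  #grouped co-data: one group set with two groups
  # Z <- list(matrix(rnorm(4), 4, 1))  #continuous co-data: one pxG matrix per co-data source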
  
  #check unpen
  assertIntegerish(unpen, null.ok=TRUE)
  
  #check intrcpt
  assertLogical(intrcpt, len = 1)
  
  #check model
  assert(checkSubset(model, c("linear", "logistic", "cox")),
         class(model)[1]=="family")
  
  #check postselection
  assertScalar(postselection)
  assert(postselection==FALSE,
         checkSubset(postselection,c( "elnet,dense", "elnet,sparse", 
                     "BRmarginal,dense", "BRmarginal,sparse", "DSS")))
  
  #check maxsel
  assertIntegerish(maxsel, lower=1, upper=dim(X)[2]-1) #must be integers and fewer than number of variables
  
  #check lambda
  assertScalar(lambda, null.ok=TRUE, na.ok = TRUE)
  assert(checkNumeric(lambda, null.ok=TRUE),
         checkString(lambda))
  if(testString(lambda)){
    assert(grepl("ML",lambda), grepl("CV", lambda))
  }
  
  #check fold
  assertIntegerish(fold, lower=2, upper = dim(X)[1])
  
  #check sigmasq
  assertNumeric(sigmasq, lower=0, len=1, null.ok=TRUE)
  
  #check w
  assertNumeric(w, lower=0, len=ifelse(!is.null(Z), length(Z), length(groupsets)), 
                null.ok = TRUE)
  
  #check nsplits
  assertIntegerish(nsplits, lower=1)
  
  #check weights
  assertLogical(weights, len=1)
  
  #check profplotRSS
  assertLogical(profplotRSS, len=1)
  
  #check test response Y2
  assert(checkVector(Y2, null.ok=TRUE), checkMatrix(Y2, null.ok=TRUE), 
         checkArray(Y2, max.d=2, null.ok=TRUE))
  assert(checkLogical(Y2, null.ok=TRUE), checkFactor(Y2, n.levels=2, null.ok=TRUE), 
         checkNumeric(Y2, null.ok=TRUE))
  
  #check test observed data X2
  assert(checkMatrix(X2, null.ok=TRUE), checkArray(X2, max.d=2, null.ok=TRUE))
  assertNumeric(X2, null.ok=TRUE)
  
  #check whether dimensions Y2 and X2 match
  if(!is.null(Y2)){
    if(is.null(dim(Y2))){ #Y2 is a vector or factor
      if(length(Y2)!=dim(X2)[1]) stop("Length of vector Y2 should equal number of rows in X2")
    }else{
      if(dim(Y2)[1]!=dim(X2)[1]) stop("Number of rows in Y2 and X2 should be the same")
    }
  }
  #check whether number of variables in test data and training data match
  if(!is.null(X2)){
    if(dim(X2)[2]!=dim(X)[2]){
      stop("Number of columns in test data X2 should match that of training data X")
    }
  }
  
  #check compare
  assertScalar(compare)
  assert(checkLogical(compare),
         checkString(compare))
  if(testString(compare)){
    assert(grepl("ML",compare), grepl("CV", compare))
  }
  
  #check mu
  assertLogical(mu, len=1)
  
  #check normalise
  assertLogical(normalise, len=1)
  
  #check silent
  assertLogical(silent, len=1)
  
  #check datablocks
  assertList(datablocks, types="integerish", null.ok=TRUE)
  
  #check est_beta_method
  assertSubset(est_beta_method, c("glmnet", "multiridge"))
  
  #Save input colnames/rownames to return in output
  colnamesX <- colnames(X)
  if(!is.null(Z)){
    colnamesZ <- names(Z)
    codataNames <- unlist(lapply(1:length(Z),function(x){rep(paste("Z",x,sep=""),dim(Z[[x]])[2])}))
    codataSource <- unlist(lapply(1:length(Z),function(x){rep(x,dim(Z[[x]])[2])}))
    if(!is.null(names(Z))){
      codataNames <- unlist(lapply(1:length(Z),function(x){rep(names(Z)[x],dim(Z[[x]])[2])}))
    }
    codatavarNames <- unlist(lapply(Z,function(x){
      if(is.null(colnames(x))) return(1:dim(x)[2])
      colnames(x)
    }))
    namesZ <- paste(codataNames,codatavarNames,sep=".")
  }else{
    colnamesZ <- names(groupsets)
    codataNames <- unlist(lapply(1:length(groupsets),function(x){rep(paste("Z",x,sep=""),length(groupsets[[x]]))}))
    codataSource <- unlist(lapply(1:length(groupsets),function(x){rep(x,length(groupsets[[x]]))}))
    if(!is.null(names(groupsets))){
      codataNames <- unlist(lapply(1:length(groupsets),function(x){rep(names(groupsets)[x],length(groupsets[[x]]))}))
    }
    codatavarNames <- unlist(lapply(groupsets,function(x){
      if(is.null(colnames(x))) return(1:length(x))
      colnames(x)
    }))
    namesZ <- paste(codataNames,codatavarNames,sep=".")
  }

  
  #-2. Set-up variables ---------------------------------------------------------------------------
  n <- dim(X)[1] #number of samples
  p <- dim(X)[2] #number of covariates 
  
  if(!is.null(X2)) n2<-dim(X2)[1] #number of samples in independent data set x2 if given
  multi <- FALSE; if(!is.null(datablocks)) multi <- TRUE #use multiple global tau, one for each data block
  if(multi==FALSE) datablocks <- list((1:p)[!((1:p)%in%unpen)])

  cont_codata <- FALSE; if(!is.null(Z)) cont_codata <- TRUE
  
  if(class(model)[1]=="family"){
    #use glmnet package and cross-validation to compute initial global lambda and beta estimates
    fml <- model
    model <- "family"
    est_beta_method <- "glmnet"
    multi <- FALSE
    if(!is.numeric(lambda)) lambda <- "CV_glmnet"
  } 
  if(length(model)>1){
    if(all(is.element(Y,c(0,1))) || is.factor(Y)){
      model <- "logistic" 
    } else if(all(is.numeric(Y)) & !(is.matrix(Y) && dim(Y)[2]>1)){
      model <- "linear"
    }else{
      model <- "cox"
    }
  }
  levelsY<-NaN
  if(is.null(lambda)||is.nan(lambda)) lambda <- ifelse(model=="linear","ML","CV")
  if(model=="logistic"){
    levelsY<-cbind(c(0,1),c(0,1))
    if(lambda=="ML"){
      lambda <- "CV"
      if(!silent) print("For logistic model, use CV for overall tau.")
    }
    if(!all(is.element(Y,c(0,1)))){
      oldLevelsY<-levels(Y)
      levels(Y)<-c("0","1")
      Y<-as.numeric(Y)-1
      levelsY<-cbind(oldLevelsY,c(0,1))
      colnames(levelsY)<-c("Old level names","New level names")
      if(!is.null(Y2)){
        levels(Y2)<-c("0","1")
        Y2<-as.numeric(Y2)-1
      }
      #if(!silent) print("Y is put in 0/1 format, see levelsY in output for new names")
      print("Y is put in 0/1 format:")
      print(levelsY)
    }
  }
  if(model=='cox'){
    intrcpt <- FALSE
    if(length(lambda)==0) lambda <- "CV"
    if(lambda=="ML"){
      if(!silent) print("For cox model, no ML approximation for overall tau available. Use CV instead.")
      lambda <- "CV"
    }
    if(class(Y)[1]!="Surv"){
      Y <- survival::Surv(Y)
    }
  }
  switch(model,
         'linear'={
           Y <- c(Y) #make numeric vector 
           fml <- 'gaussian'
           sd_y <- sqrt(var(Y)*(n-1)/n)[1]
           # if(standardise_Y){
           #   Y <- Y/sd_y
           #   sd_y_former <- sd_y
           #   sd_y <- 1
           # }
         },
         'logistic'={
           Y <- c(Y)
           fml <- 'binomial'
           sd_y <- 1 #do not standardise y in logistic setting
           #sd_y_former <- sd_y
         },
         'cox'={
           fml <- 'cox'
           sd_y <- 1 #do not standardise y in cox regression setting
           #sd_y_former <- sd_y
         },
         "family"={
           sd_y <- 1
           if(fml$family%in%c("gaussian")) sd_y <- sqrt(var(Y)*(n-1)/n)[1]
         }
  )
  mutrgt<-0
  if(mu==FALSE){mu <- 0}else{
    if(cont_codata){
      warning("Co-data provided in Z instead of groupsets. This option does not 
               yet support inclusion of prior means. Prior means set to 0.")
      mu <- 0 
    }else{
      mu <- NaN
    }
  } 
  tauglobal<-NaN
  tausq<-NaN
  hyperlambdas<-c(NaN,NaN)
  
  # check whether unpenalised covariates are not in partition
  # and set penalty.factor of unpenalised covariates to 0 for glmnet
  penfctr <- rep(1,p) #factor=1 for penalised covariates
  if(length(unpen)>0){
    penfctr[unpen] <- 0 #factor=0 for unpenalised covariates
  }
  
  #-2.1 Variables describing groups and partition(s) =================================================
  if(!cont_codata){ #settings when co-data is provided in groupsets
    #remove any unpenalised covariates from group sets
    if(length(unpen)>0){
      if(any(unlist(groupsets)%in%unpen)){
        warning("Unpenalised covariates removed from group set")
        for(i in 1:length(groupsets)){
          for(j in rev(seq_along(groupsets[[i]]))){ #iterate in reverse: setting an element to NULL shifts later indices
            if(all(groupsets[[i]][[j]]%in%unpen)){
              groupsets[[i]][[j]] <- NULL #remove whole group if all covariates unpenalised
            }else{
              groupsets[[i]][[j]] <- groupsets[[i]][[j]][!(groupsets[[i]][[j]]%in%unpen)]
            }
          }
        }
      }
    }
    
    G <- sapply(groupsets,length) #1xm vector with G_i, number of groups in partition i
    m <- length(G) #number of partitions
    if(any(grepl("hierLasso",hypershrinkage))){
      if(length(groupsets.grouplvl)==0){
        stop("Group set on group level for hierarchical groups is missing")
      }
    }
    indGrpsGlobal <- list(1:G[1]) #global group index in case we have multiple partitions
    if(m>1){
      for(i in 2:m){
        indGrpsGlobal[[i]] <- (sum(G[1:(i-1)])+1):sum(G[1:i])
      }
    }
    Kg <- lapply(groupsets,function(x)(sapply(x,length))) #m-list with G_i vector of group sizes in partition i
    #ind1<-ind
    
    #ind <- (matrix(1,G,1)%*%ind)==(1:G)#sparse matrix with ij element TRUE if jth element in group i, otherwise FALSE
    i<-unlist(sapply(1:sum(G),function(x){rep(x,unlist(Kg)[x])}))
    j<-unlist(unlist(groupsets))
    ind <- Matrix::sparseMatrix(i,j,x=1) #sparse matrix with ij element 1 if jth element in group i (global index), otherwise 0
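    #Illustrative example: for groupsets <- list(list(1:2, 2:3)) and p=3, 'ind' is
    #  1 1 0
    #  0 1 1
    #i.e. covariate 2 belongs to both (overlapping) groups.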
    
    Ik <- lapply(1:m,function(i){
      x<-rep(0,sum(G))
      x[(sum(G[1:i-1])+1):sum(G[1:i])]<-1
      as.vector(x%*%ind)}) #list for each partition with px1 vector with number of groups beta_k is in
    #sparse matrix with ij element 1/Ij if beta_j in group i
    
    #make co-data matrix Z (Zt transpose of Z as in paper, with co-data matrices stacked for multiple groupsets)
    Zt<-ind; 
    if(G[1]>1){
      Zt[1:G[1],]<-Matrix::t(Matrix::t(ind[1:G[1],])/apply(ind[1:G[1],],2,sum))
    }
    if(m>1){
      for(i in 2:m){
        if(G[i]>1){
          Zt[indGrpsGlobal[[i]],]<-Matrix::t(Matrix::t(ind[indGrpsGlobal[[i]],])/
                                               apply(ind[indGrpsGlobal[[i]],],2,sum))
        }
      }
    }
    if(dim(Zt)[2]<p) Zt <- cbind(Zt,matrix(rep(NaN,(p-dim(Zt)[2])*sum(G)),c(sum(G),p-dim(Zt)[2])))
    if(length(G)==1 && G==1){
      PenGrps <- matrix(sum(Zt^2),c(1,1))
    }else{
      PenGrps <- as.matrix(Zt[,!((1:p)%in%unpen)]%*%Matrix::t(Zt[,!((1:p)%in%unpen)])) #penalty matrix groups
    } 
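    #Continuing the illustrative example above: each column of 'ind' is divided by the
    #number of groups that covariate is in, so Zt becomes
    #  1 0.5 0
    #  0 0.5 1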
    
  }else{ #settings when co-data is provided in list Z
    m <- length(Z)
    names(Z) <- paste("Z",1:m,sep="")
    for(i in 1:m){
      if(length(unpen)>0) Z[[i]][unpen,] <- NaN
    }
    G <- sapply(Z,function(x)dim(x)[2]) #1xm vector with G_i, number of variables in co-data source i
    
    indGrpsGlobal <- list(1:G[1]) #global group index in case we have multiple partitions
    if(m>1){
      for(i in 2:m){
        indGrpsGlobal[[i]] <- (sum(G[1:(i-1)])+1):sum(G[1:i])
      }
    }
    Zt <- t(Z[[1]])
    if(m>1){
      for(i in 2:m){
        Zt <- rbind(Zt,Matrix::t(Z[[i]]))
      }
    }
    Kg <- list(apply(Zt,1,function(x)(sum(!is.na(x))))) #m-list with G_i vector of group sizes in partition i
    
    if(is.null(hypershrinkage)){
      hypershrinkage <- rep("none",m)
      for(i in 1:m){
        bool.Pen <- names(Z)[i]%in%names(paraPen) #boolean indicating whether or not co-data weights should be penalised
        bool.Con <- names(Z)[i]%in%names(paraCon) #boolean indicating whether or not co-data weights should be constrained
        tempMat <- cbind(c("none","ridge"), c("none+constraints","ridge+constraints"))
        hypershrinkage[i] <- tempMat[1+bool.Pen, 1+bool.Con]
      }
      if(all(!grepl("constraints",hypershrinkage)) & sum(G)>1) hypershrinkage <- "mgcv"
    }else if(!all(hypershrinkage%in%c("none","ridge","mgcv","none+constraints","ridge+constraints"))){
      stop("Hypershrinkage should be one of none, ridge, none+constraints, ridge+constraints or mgcv.
              For co-data provided as matrix, hypershrinkage can set automatically.")
    }
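    #Illustrative example: with Z=list(Z1,Z2) and paraCon given for Z2 only, hypershrinkage
    #becomes c("none","none+constraints"); with no paraPen and no paraCon at all (and more
    #than one co-data variable in total), "mgcv" is used instead.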
    normalise <- FALSE
    
    Ik <- lapply(1:m,function(x) rep(1,p))
    PenGrps <- as.matrix(Zt[,!((1:p)%in%unpen)]%*%Matrix::t(Zt[,!((1:p)%in%unpen)])) #penalty matrix groups
  } 
  

  #-2.2 Weight variables for extra shrinkage on group parameters =====================================
  # Compute weights and corresponding weight matrix
  #Note: for logistic regression, a different weight matrix W is defined below
  if(weights){
    weights <- unlist(Kg)
  }else{
    weights <- rep(1,sum(G))
  }
  if(length(weights)==1){Wminhalf<-1/sqrt(weights)}
  else{Wminhalf <- diag(1/sqrt(weights))} #W^{-0.5} (element-wise), with W diagonal matrix with weights for prior parameters
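  #Illustrative example: with two groups of sizes 10 and 40, weights=c(10,40) and
  #Wminhalf=diag(1/sqrt(c(10,40))), weighting the group parameters by group size in the
  #ridge hypershrinkage.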

  #-3. The ecpc method possibly using extra shrinkage ---------------------------------------------
  #-3.1 Set-up variables =======================================================================================
  #-3.1.1 Variables adapted for intercept ####################################################################
  # copy X: add column with ones for intercept if included in the model
  if(intrcpt){
    Xc <- cbind(X,rep(1,n))
    unpen<-c(unpen,p+1) #add index of last column for intercept to unpenalised covariates 
  } else{
    Xc <- X
  }
  if(model%in%c("logistic","cox","family")){
    Xcinit<-Xc
  }
  
  #-3.1.2 Variables used in initialisation of beta (and afterwards) #########################################
  muhat <- array(mu,c(sum(G),nIt+1)); #group means, optional: fixed at mu if given
  gammatilde <- array(tausq,c(sum(G),nIt+1)) #group variances before truncating negative tau to 0, optional: fixed at tausq if given (tausq 1 value for all groups)
  gamma <- array(max(0,tausq),c(sum(G),nIt+1)) #group variances truncated at 0 for negative values, optional: fixed at tausq if given
  gamma0 <- 0
  gamma0tilde <- 0
  colnames(muhat)<-paste("Itr",0:nIt,sep="")
  colnames(gammatilde)<-paste("Itr",0:nIt,sep="")
  colnames(gamma)<-paste("Itr",0:nIt,sep="")
  tempRow <- unlist(lapply(1:length(G),function(x){paste("Z",x,".",1:G[x],sep="")}))
  rownames(muhat)<-tempRow;  rownames(gammatilde)<-tempRow;  rownames(gamma)<-tempRow
  
  
  weightsMu <- array(NaN,c(sum(G),nIt+1))
  if(is.nan(mu)){
    partWeightsMu <- array(1,c(m,nIt+1))
    partWeightsMuG<- array(1,c(sum(G),nIt+1))
  }else{
    partWeightsMu <- array(NaN,c(m,nIt+1))
    partWeightsMuG<- array(NaN,c(sum(G),nIt+1))
  }
  if(is.nan(tausq)){
    partWeightsTau <-array(1,c(m,nIt+1))
    partWeightsTauG<-array(1,c(sum(G),nIt+1))
  }else{
    partWeightsTau <-array(NaN,c(m,nIt+1))
    partWeightsTauG<-array(NaN,c(sum(G),nIt+1))
  }
  
  
  #-3.1.3 Variables used in iterations #######################################################################
  if(!is.null(X2)){
    if(model=="cox") YpredGR <- array(NaN,c(n2,nIt+1))
    else YpredGR <- array(NaN,c(n2,nIt+1))
    MSEecpc<-rep(NaN,nIt+1)
  } else { 
      YpredGR<-NaN
      MSEecpc<-NaN
      if(!is.nan(compare)){
        Ypredridge<-NaN
        MSEridge<-NaN
      }
  }
  ind0 <- c(); #keep track of index of groups which have 0 variance
  indnot0 <- 1:sum(G) #and keep track of index of groups which have variance > 0
  lambdashat<-array(NaN,c(2,nIt+1,m)) #hyperpenalties for extra level of shrinkage
  colnames(lambdashat)<-paste("Itr",0:nIt,sep="")
  rownames(lambdashat)<-c("PenMu","PenTau")
  lambdashat[1,,]<-hyperlambdas[1]; lambdashat[2,,]<-hyperlambdas[2] #optional: fix extra hyperpenalty if given 
  
  #-3.2 Initial tau and beta ========================================================================================
  if(!silent) print("Estimate global tau^2 (equiv. global ridge penalty lambda)")
  intrcptGLM <- intrcpt

  #initial tau given
  if(!is.nan(tausq)){
    lambda <- 1/tausq
    if(model=="linear" | (model=="family"&!is.nan(sigmasq))) lambda <- sigmasq/tausq
    if(!is.nan(compare) & compare!=FALSE){ #compare not false
      lambdaridge <- 1/tausq
    }
  }
  if(is.numeric(lambda) & compare!=FALSE) lambdaridge <- lambda
  
  datablockNo <- rep(1,p) #in case only one data type
  if(multi!=FALSE){
    if(!is.null(datablocks)){
      datablockNo <- c(unlist(lapply(1:length(datablocks),function(x){rep(x,length(datablocks[[x]]))}))) #p-dimensional vector with datablock number
    }else{
      datablocks <- list(1:p)
    }
    datablockNo[(1:p)%in%unpen]<-NA
    
    #compute multi-lambda; from multiridge package demo:
    #datablocks: list with each element a data type containing indices of covariates with that data type 
    Xbl <- multiridge::createXblocks(lapply(datablocks,function(ind) X[,intersect(ind,ind[!(ind%in%unpen)])]))
    XXbl <- multiridge::createXXblocks(lapply(datablocks,function(ind) X[,intersect(ind,ind[!(ind%in%unpen)])]))
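    #Illustrative example: for two data types, say p=100 covariates of which 1:60 are
    #methylation and 61:100 expression, datablocks <- list(1:60, 61:100); Xbl then holds
    #X[,1:60] and X[,61:100], and XXbl the corresponding crossproducts X_b %*% t(X_b)
    #used for the efficient n-dimensional computations below.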
    
    #Find initial lambda: fast CV per data block, separately using SVD. CV is done using the penalized package
    if(!is.numeric(lambda)){
      if(sum((1:p)%in%unpen)>0){
        capture.output({cvperblock <- multiridge::fastCV2(Xbl,Y=Y,kfold=fold,fixedfolds = FALSE,
                                                          X1=X[,(1:p)%in%unpen],intercept=intrcpt)})
      }else{
        capture.output({cvperblock <- multiridge::fastCV2(Xbl,Y=Y,kfold=fold,fixedfolds = FALSE,
                                                          intercept=intrcpt)})
      }
      lambdas <- cvperblock$lambdas
      lambdas[lambdas==Inf] <- 10^6
      
      #Find joint lambdas:
      if(length(lambdas)>1){
        leftout <- multiridge::CVfolds(Y=Y,kfold=fold,nrepeat=3,fixedfolds = FALSE) #Create (repeated) CV-splits of the data
        if(sum((1:p)%in%unpen)>0){
          capture.output({jointlambdas <- multiridge::optLambdasWrap(penaltiesinit=lambdas, XXblocks=XXbl,Y=Y,folds=leftout,
                                     X1=X[,(1:p)%in%unpen],intercept=intrcpt,
                                     score=ifelse(model == "linear", "mse", "loglik"),model=model)})
        }else{
          capture.output({jointlambdas <- multiridge::optLambdasWrap(penaltiesinit=lambdas, XXblocks=XXbl,Y=Y,folds=leftout,
                                     intercept=intrcpt,
                                     score=ifelse(model == "linear", "mse", "loglik"),model=model)})
        }
        
        lambda <- jointlambdas$optpen
      }else{
        lambda <- lambdas
      }
      
    }else{
      if(length(lambda)==1) lambdas <- rep(lambda, max(datablockNo, na.rm=TRUE)) #na.rm: unpenalised covariates have NA block number
    }
    
    lambdap <- rep(0,p)
    lambdap[!((1:p)%in%unpen)] <- lambda[datablockNo[!((1:p)%in%unpen)]]
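    #Illustrative example: with datablocks as above and lambda=c(5,20), lambdap is 5 for
    #covariates 1:60 and 20 for covariates 61:100 (and 0 for unpenalised covariates).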

    sigmahat <- 1 #sigma not in model for logistic/cox: set to 1
    if(model=="linear"){
      if(!is.nan(sigmasq)) sigmahat <- sigmasq
      else{
        XtDinvX <- multiridge::SigmaFromBlocks(XXblocks = XXbl,lambda)
        if(length(unpen)>0){
          Xunpen <- X[,(1:p)%in%unpen]
          if(intrcpt) Xunpen <- cbind(Xunpen,rep(1,n))
          if(intrcpt && length(unpen)==1){
            betaunpenML <- sum(Y)/n
          }else{
            temp <- solve(XtDinvX+diag(rep(1,n)),Xunpen)
            betaunpenML <- solve(t(Xunpen)%*%temp , t(temp)%*%Y)
          }
          sigmahat <- c(t(Y-Xunpen%*%betaunpenML)%*%solve(XtDinvX+diag(rep(1,n)),Y-Xunpen%*%betaunpenML)/n)
        }else{
          sigmahat <- c(t(Y)%*%solve(XtDinvX+diag(rep(1,n)),Y)/n)
        }
        
        #sigmahat <- 1/n * (y-X[unpen] %*% solve(X%*%Lambda^{-1}%*%t(X) + diag(1,n))
        # lambdaoverall <- exp(mean(log(lambdap[lambdap!=0])))
        # Xacc <- X
        # Xacc[,!((1:p)%in%unpen)] <- as.matrix(X[,!((1:p)%in%unpen)] %*% 
        #                                         sparseMatrix(i=1:sum(!((1:p)%in%unpen)),j=1:sum(!((1:p)%in%unpen)),
        #                                                  x=c(1/sqrt(lambdap[lambdap!=0]/lambdaoverall))))
        # #Use ML for sigma estimation and/or initial lambda (tausq) estimate and/or mu
        # par <- .mlestlin(Y,Xacc,lambda=lambdaoverall,sigmasq=NaN,mu=0,tausq=NaN) #use maximum marginal likelihood
        # sigmahat <- par[2] #sigma could be optimised with CV in the end if not known
      }
    }
    muhat[,1] <- 0
    gamma[,1] <- rep(1,sum(G))
    tauglobal<- sigmahat/lambda #set global group variance
    mutrgt <- 0 #TD: with offset
    
    #Compute betas
    XXT <- multiridge::SigmaFromBlocks(XXbl,penalties=lambda) #create nxn Sigma matrix = sum_b [(lambda_b)^{-1} X_b %*% t(X_b)]
    if(model!="cox"){
      if(sum((1:p)%in%unpen)>0){
        fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt,X1=X[,(1:p)%in%unpen]) #Fit. fit$etas contains the n linear predictors
      }else{
        fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt) #Fit. fit$etas contains the n linear predictors
      }
    }else{
      if(sum((1:p)%in%unpen)>0){
        fit <- multiridge::IWLSCoxridge(XXT,Y=Y, model=model,X1=X[,(1:p)%in%unpen]) #Fit. fit$etas contains the n linear predictors
      }else{
        fit <- multiridge::IWLSCoxridge(XXT,Y=Y) #Fit. fit$etas contains the n linear predictors
      }
    }
    
    betas <- multiridge::betasout(fit, Xblocks=Xbl, penalties=lambda) #Find betas.
    intrcptinit <- c(betas[[1]][1]) #intercept
    betasinit <- rep(0,p) 
    betasinit[(1:p)%in%unpen] <- betas[[1]][-1] #unpenalised variables
    for(i in 1:length(datablocks)){
      betasinit[datablocks[[i]][!(datablocks[[i]]%in%unpen)]] <- betas[[1+i]]
    }  
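    #Note: multiridge::betasout returns a list with element 1 the intercept and unpenalised
    #coefficients and elements 2,3,... the penalised coefficients per data block; these are
    #mapped back to their original covariate positions above.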
    muinitp <- rep(0,p) #TD: with offset for mu
    
    rm(betas)
    #compare multiridge
    if(compare!=FALSE){
      lambdaridge <- lambda
    }
  }else{
    switch(model,
           'linear'={
             #Use Cross-validation to compute initial lambda (tausq)
             if((!is.nan(compare) & grepl("CV",compare)) | grepl("CV",lambda)){
               #use glmnet to do CV; computationally more expensive but other optimising criteria possible
               if(grepl("glmnet",lambda)){ 
                 lambdaGLM<-glmnet::cv.glmnet(X,Y,nfolds=fold,alpha=0,family=fml,
                                              standardize = FALSE,intercept=intrcpt,
                                              penalty.factor=penfctr,keep=TRUE) #alpha=0 for ridge
               }#else if(grepl("penalized",lambda)){ #use penalized to do CV
               #   Xsvd <- svd(X[,penfctr!=0])
               #   XF <- X[,penfctr!=0]%*% Xsvd$v
               #   Xunpen <- cbind(X[,penfctr==0]) #if empty vector, no unpenalised and no intercept
               #   if(intrcpt){
               #     Xunpen <- cbind(X[,penfctr==0],rep(1,n))
               #   }
               #   if(length(unpen)==0){
               #     ol1 <- penalized::optL2(Y,penalized=XF,fold=fold,trace=FALSE)
               #   }else{
               #     ol1 <- penalized::optL2(Y,penalized=XF, unpenalized =Xunpen,fold=fold,trace=FALSE)
               #   }
               #   
               #   #ol2 <- penalized::optL2(Y,penalized=X[,penfctr!=0],unpenalized=Xunpen, fold=ol1$fold ) #gives same result, but the first is much faster for large p
               #   itr2<-1
               #   while((ol1$lambda>10^12 | ol1$lambda<10^-5 ) & itr2 < 10){
               #     if(length(unpen)==0){
               #       ol1 <- penalized::optL2(Y,penalized=XF,fold=fold,trace=FALSE)
               #     }else{
               #       ol1 <- penalized::optL2(Y,penalized=XF, unpenalized =Xunpen,fold=fold,trace=FALSE)
               #     }
               #     itr2 <- itr2 + 1
               #   } 
               #   if(itr2==10 & ol1$lambda>10^12){
               #     if(ol1$lambda>10^10){
               #       ol1$lambda <- 10^12
               #       warning("Cross-validated global penalty lambda was >10^12 and set to 10^12")
               #     }
               #     if(ol1$lambda<10^-5){
               #       ol1$lambda <- 1
               #       warning("Cross-validated global penalty lambda was <10-5 and set to 1")
               #     }
               #   }
               #}
               else{ #do CV with fastCV2 from multiridge package
                 if(length(setdiff(unpen,p+1))==0){
                   Xbl <- X%*%t(X)
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,intercept=intrcpt,
                                                        fixedfolds=FALSE,model=model,kfold=fold)})
                 }else{
                   Xbl <- X[,penfctr!=0]%*%t(X[,penfctr!=0])
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,intercept=intrcpt,
                                    fixedfolds=FALSE,model=model,X1=X[,penfctr==0],kfold=fold)})
                 }
               }
               
               if((!is.nan(compare) & grepl("CV",compare)) | (!is.nan(compare) & compare==TRUE)){
                 # lambdaridge<-lambdaGLM$lambda.min/sd_y*n #using glmnet
                 if(grepl("glmnet",lambda)) lambdaridge <- lambdaGLM$lambda.min/sd_y*n #fitted lambda
                 #else if(grepl("penalized",lambda)) lambdaridge <- ol1$lambda
                 else lambdaridge <- fastCVfit$lambdas
               } 
               if(grepl("CV",lambda)){
                 if(grepl("glmnet",lambda)) lambda <- lambdaGLM$lambda.min/sd_y*n #using glmnet
                 #else if(grepl("penalized",lambda)) lambda <- ol1$lambda #using penalized
                 else lambda <- fastCVfit$lambdas
                 sigmahat <- sigmasq
                 muhat[,1] <- mutrgt
                 gamma[,1] <- 1/lambda
                 tauglobal<- 1/lambda
               } 
             }
             #Use ML for sigma estimation and/or initial lambda (tausq) estimate and/or mu
             if(is.nan(sigmasq) | (!is.nan(compare) & grepl("ML",compare)) | grepl("ML",lambda) | is.nan(mutrgt)){
               #Estimate sigma^2, lambda and initial estimate for tau^2 (all betas in one group), mu=0 by default
               if(grepl("ML",lambda)){ lambda <- NaN}
               Xrowsum <- apply(X,1,sum)
               
               XXt <- X[,penfctr!=0]%*%t(X[,penfctr!=0])
               
               Xunpen <- NULL #if empty vector, no unpenalised and no intercept
               if(sum(penfctr==0)>0){
                 Xunpen <- X[,penfctr==0]
               }
               par <- .mlestlin(Y=Y,XXt=XXt,Xrowsum=Xrowsum,
                                intrcpt=FALSE,Xunpen=NULL,  #Ignore unpenalised/intercept for initialisation
                                lambda=lambda,sigmasq=sigmasq,mu=mutrgt,tausq=tausq) #use maximum marginal likelihood

               lambda <- par[1] 
               sigmahat <- par[2] #sigma could be optimised with CV in the end if not known
               muhat[,1] <- par[3] 
               gamma[,1] <- par[4]
               
               tauglobal<- par[4] #set target group variance (overall variance if all covariates in one group)
               mutrgt <- par[3] #set target group mean (overall mean if all covariates in one group), 0 by default
               
               if((!is.nan(compare) & grepl("ML",compare)) | (!is.nan(compare) & compare==TRUE)) lambdaridge<- par[1]
             }
             
             #initial estimate for beta
             lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
             lambdap[(1:p)%in%unpen] <- 0
             
             if(cont_codata){ 
               muinitp <- rep(0,p)
             }else{
               muinitp <- as.vector(c(muhat[,1])%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p (0 for unpenalised covariates) 
               muinitp[(1:p)%in%unpen] <- 0
             }
             if(est_beta_method=="glmnet"){
               glmGRtrgt <- glmnet::glmnet(X,Y,alpha=0,
                                           #lambda = lambda/n*sd_y,
                                           family=fml,
                                           offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], 
                                           intercept = intrcpt, standardize = FALSE,
                                           penalty.factor=penfctr)
               #minlam <- min(glmGRtrgt$lambda)*n/sd_y
               if(lambda < minlam){
                 warning("Estimated lambda value found too small, set to minimum to for better numerical performance")
                 lambda <- minlam
                 lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
                 lambdap[(1:p)%in%unpen] <- 0
                 
                 Xunpen <- NULL #if empty vector, no unpenalised and no intercept
                 if(sum(penfctr==0)>0){
                   Xunpen <- X[,penfctr==0]
                 }
                 
                 #re-estimate sigma and tau_global for new lambda value
                 Xrowsum <- apply(X,1,sum)
                 XXt <- X[,penfctr!=0]%*%t(X[,penfctr!=0])
                 par <- .mlestlin(Y=Y,XXt=XXt,Xrowsum=Xrowsum,
                                  intrcpt=intrcpt,Xunpen=Xunpen, #TD: adapt for intercept and Xunpen
                                  lambda=lambda,sigmasq=NaN,mu=mutrgt,tausq=tausq) #use maximum marginal likelihood
                 sigmahat <- par[2] #sigma could be optimised with CV in the end if not known
                 gamma[,1] <- par[4]
                 tauglobal<- par[4] #set target group variance (overall variance if all covariates in one group)
               }
               #betasinit <- as.vector(glmGRtrgt$beta)
               betasinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10, exact=TRUE,
                                 x=X,y=Y,
                                 family=fml,
                                 offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], 
                                 intercept = intrcpt,
                                 penalty.factor=penfctr)[-1]
               betasinit[!((1:p)%in%unpen)] <- betasinit[!((1:p)%in%unpen)] + muinitp[!((1:p)%in%unpen)]
               #intrcptinit <- glmGRtrgt$a0
               intrcptinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10, exact=TRUE,
                                   x=X,y=Y,
                                   family=fml,
                                   offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], 
                                   intercept = intrcpt,
                                   penalty.factor=penfctr)[1]
             }else{ #use multiridge package
               XXbl <- list(X[,penfctr!=0]%*%t(X[,penfctr!=0]))
               #Compute betas
               XXT <- multiridge::SigmaFromBlocks(XXbl,penalties=lambda) #create nxn Sigma matrix = sum_b [(lambda_b)^{-1} X_b %*% t(X_b)]
               if(sum((1:p)%in%unpen)>0){
                 fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt,X1=X[,(1:p)%in%unpen]) #Fit. fit$etas contains the n linear predictors
               }else{
                 fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt) #Fit. fit$etas contains the n linear predictors
               }
               
               betas <- multiridge::betasout(fit, Xblocks=list(X[,penfctr!=0]), penalties=lambda) #Find betas.
               intrcptinit <- c(betas[[1]][1]) #intercept
               betasinit <- rep(0,p) 
               betasinit[(1:p)%in%unpen] <- betas[[1]][-1] #unpenalised variables
               betasinit[!((1:p)%in%unpen)] <- betas[[2]]
               rm(betas)
             }
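             #Note: glmnet scales the ridge penalty by the number of observations, hence
             #the conversion s=lambda/n*sd_y between the ecpc-scale lambda and glmnet's s
             #in the coef() calls above.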
             
           },
           'logistic'={
             #Use Cross-validation to compute initial lambda (tausq)
             if((!is.nan(compare) & grepl("CV",compare)) | grepl("CV",lambda)){
               #use glmnet to do CV; computationally more expensive but other optimising criteria possible
               if(grepl("glmnet",lambda)){ 
                 lambdaGLM<-glmnet::cv.glmnet(X,Y,nfolds=fold,alpha=0,family=fml,
                                      standardize = FALSE,intercept=intrcpt,
                                      penalty.factor=penfctr,keep=TRUE) #alpha=0 for ridge
               }#else if(grepl("penalized",lambda)){ #use penalized to do CV
                 # Xsvd <- svd(X[,penfctr!=0])
                 # XF <- X[,penfctr!=0]%*% Xsvd$v
                 # if(intrcpt){
                 #   Xunpen <- cbind(X[,penfctr==0],rep(1,n))
                 # }else{
                 #   Xunpen <- cbind(X[,penfctr==0]) #if empty vector, no unpenalised and no intercept
                 # }
                 # 
                 # ol1 <- penalized::optL2(Y,penalized=XF, unpenalized =Xunpen,fold=fold,trace=FALSE,minlambda2=10^-6)
                 # #ol2 <- penalized::optL2(Y,penalized=X[,penfctr!=0],unpenalized=Xunpen, fold=ol1$fold ) #gives same result, but the first is much faster for large p
                 # itr2<-1
                 # while((ol1$lambda>10^12 | ol1$lambda<10^-5 ) & itr2 < 10){
                 #   ol1 <- penalized::optL2(Y,penalized=XF, unpenalized =Xunpen,fold=fold,trace=FALSE,minlambda2=10^-6)
                 #   itr2 <- itr2 + 1
                 #   # if(ol1$lambda>10^12){
                 #   #   ol1$lambda <- 10^2
                 #   #   warning("Cross-validated global penalty lambda was >10^12 and set to 100")
                 #   # }
                 # } 
                 # if(itr2==10 & (ol1$lambda>10^10 | ol1$lambda<10^-5 )){
                 #   if(ol1$lambda>10^10){
                 #     ol1$lambda <- 10^12
                 #     warning("Cross-validated global penalty lambda was >10^12 and set to 10^12")
                 #   }
                 #   if(ol1$lambda<10^-5){
                 #     ol1$lambda <- 1
                 #     warning("Cross-validated global penalty lambda was <10-5 and set to 1")
                 #   }
                 # }
                 # 
               #}
               else{ #do CV with fastCV2 from multiridge package
                 if(length(setdiff(unpen,p+1))==0){
                   Xbl <- X%*%t(X)
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,intercept=intrcpt,
                                                    fixedfolds=FALSE,model=model,kfold=fold)})
                 }else{
                   Xbl <- X[,penfctr!=0]%*%t(X[,penfctr!=0])
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,intercept=intrcpt,
                                                    fixedfolds=FALSE,model=model,X1=X[,penfctr==0],kfold=fold)})
                 }
               }
               
               if((!is.nan(compare) & grepl("CV",compare)) | (!is.nan(compare) & compare==TRUE)){
                 if(grepl("glmnet",lambda)) lambdaridge <- lambdaGLM$lambda.min/sd_y*n #fitted lambda
                 #else if(grepl("penalized",lambda)) lambdaridge <- ol1$lambda
                 else lambdaridge <- fastCVfit$lambdas
               } 
               if(grepl("CV",lambda)){
                 if(grepl("glmnet",lambda)) lambda <- lambdaGLM$lambda.min/sd_y*n #using glmnet
                 #else if(grepl("penalized",lambda)) lambda <- ol1$lambda #using penalized
                 else lambda <- fastCVfit$lambdas
               } 
               #print(lambda)
             }
             gamma[,1] <- 1/lambda
             tauglobal <- 1/lambda
             sigmahat <- 1 #sigma not in model for logistic: set to 1
             muhat[,1] <- mu #use initial mean 0 in logistic setting
             mutrgt <- mutrgt #default: 0
             
             #initial estimate for beta
             lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
             lambdap[(1:p)%in%unpen] <- 0
             
             if(cont_codata){ 
               muinitp <- rep(0,p)
             }else{
               muinitp <- as.vector(c(muhat[,1])%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p (0 for unpenalised covariates) 
               muinitp[(1:p)%in%unpen] <- 0
             }
             if(est_beta_method=="glmnet"){
               glmGRtrgt <- glmnet::glmnet(X,Y,alpha=0,
                                           #lambda = lambda/n*sd_y,
                                           family=fml,
                                           offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], intercept = intrcpt, standardize = FALSE,
                                           penalty.factor=penfctr)
               #minlam <- min(glmGRtrgt$lambda)*n/sd_y
               if(lambda < minlam){
                 warning("Estimated lambda value found too small, set to minimum for better numerical performance")
                 lambda <- minlam
                 lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
                 lambdap[(1:p)%in%unpen] <- 0
                 
                 #re-estimate tau_global for new lambda value
                 gamma[,1] <- 1/lambda
                 tauglobal <- 1/lambda
               }
               
               #betasinit <- as.vector(glmGRtrgt$beta)
               betasinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10, exact=TRUE,
                                 x=X,y=Y,
                                 family=fml,
                                 offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], intercept = intrcpt,
                                 penalty.factor=penfctr)[-1]
               betasinit[!((1:p)%in%unpen)] <- betasinit[!((1:p)%in%unpen)] + muinitp[!((1:p)%in%unpen)]
               #intrcptinit <- glmGRtrgt$a0
               intrcptinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10,exact=TRUE,
                                   x=X,y=Y,
                                   family=fml,
                                   offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], intercept = intrcpt,
                                   penalty.factor=penfctr)[1]
             }else{ #use multiridge package
               XXbl <- list(X[,penfctr!=0]%*%t(X[,penfctr!=0]))
               #Compute betas
               XXT <- multiridge::SigmaFromBlocks(XXbl,penalties=lambda) #create nxn Sigma matrix = sum_b [(lambda_b)^{-1} X_b %*% t(X_b)]
               if(sum((1:p)%in%unpen)>0){
                 fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt,X1=X[,(1:p)%in%unpen]) #Fit. fit$etas contains the n linear predictors
               }else{
                 fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt) #Fit. fit$etas contains the n linear predictors
               }
               
               betas <- multiridge::betasout(fit, Xblocks=list(X[,penfctr!=0]), penalties=lambda) #Find betas.
               intrcptinit <- c(betas[[1]][1]) #intercept
               betasinit <- rep(0,p) 
               betasinit[(1:p)%in%unpen] <- betas[[1]][-1] #unpenalised variables
               betasinit[!((1:p)%in%unpen)] <- betas[[2]]
               rm(betas)
             }
             
           },
           'cox'={
             #Cross-validation lambda
             if((!is.nan(compare) & grepl("CV",compare)) | grepl("CV",lambda)){
               if(grepl("glmnet",lambda)){ 
                 lambdaGLM<-glmnet::cv.glmnet(X,as.matrix(Y),nfolds=fold,alpha=0,family=fml,
                                      standardize = FALSE,
                                      penalty.factor=penfctr,keep=TRUE) #alpha=0 for ridge
               }#else if(grepl("penalized",lambda)){ #use penalized to do CV
               #   Xsvd <- svd(X[,penfctr!=0])
               #   XF <- X[,penfctr!=0]%*% Xsvd$v
               #   Xunpen <- cbind(X[,penfctr==0]) #if empty vector, no unpenalised and no intercept
               #   
               #   ol1 <- penalized::optL2(survival::Surv(Y[,1],Y[,2]),penalized=XF, unpenalized =Xunpen,fold=fold,trace=FALSE,minlambda2=10^-6)
               #   #ol2 <- penalized::optL2(Y,penalized=X[,penfctr!=0],unpenalized=Xunpen, fold=ol1$fold ) #gives same result, but the first is much faster for large p
               #   itr2<-1
               #   while((ol1$lambda>10^10 | ol1$lambda<10^-5 ) & itr2 < 10){
               #     ol1 <- penalized::optL2(survival::Surv(Y[,1],Y[,2]),penalized=XF, unpenalized =Xunpen,fold=fold,trace=FALSE,minlambda2=10^-6)
               #     itr2 <- itr2 + 1
               #     # if(ol1$lambda>10^12){
               #     #   ol1$lambda <- 10^2
               #     #   warning("Cross-validated global penalty lambda was >10^12 and set to 100")
               #     # }
               #   } 
               #   if(itr2==10 & (ol1$lambda>10^10 | ol1$lambda<10^-5 )){
               #     if(ol1$lambda>10^10){
               #       ol1$lambda <- 10^12
               #       warning("Cross-validated global penalty lambda was >10^12 and set to 10^12")
               #     }
               #     if(ol1$lambda<10^-5){
               #       ol1$lambda <- 1
               #       warning("Cross-validated global penalty lambda was <10-5 and set to 1")
               #     }
               #   }
               # }
               else{ #do CV with fastCV2 from multiridge package
                 if(length(setdiff(unpen,p+1))==0){
                   Xbl <- X%*%t(X)
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,
                                        fixedfolds=FALSE,model=model,kfold=fold)})
                 }else{
                   Xbl <- X[,penfctr!=0]%*%t(X[,penfctr!=0])
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,
                                        fixedfolds=FALSE,model=model,X1=X[,penfctr==0],kfold=fold)})
                 }
               }
               
               if((!is.nan(compare) & grepl("CV",compare))| (!is.nan(compare) & compare==TRUE)){
                 if(grepl("glmnet",lambda)) lambdaridge <- lambdaGLM$lambda.min/sd_y*n/2 #fitted lambda
                 #else if(grepl("penalized",lambda)) lambdaridge <- ol1$lambda
                 else lambdaridge <- fastCVfit$lambdas
               } 
               if(grepl("CV",lambda)){
                 if(grepl("glmnet",lambda)) lambda <- lambdaGLM$lambda.min/sd_y*n/2 #fitted lambda
                 #else if(grepl("penalized",lambda)) lambda <- ol1$lambda
                 else lambda <- fastCVfit$lambdas
               } 
             }
             sigmahat <- 1 #sigma not in model for cox: set to 1
             muhat[,1] <- 0 #use initial mean 0 in cox setting
             gamma[,1] <- 1/lambda
             mutrgt <- 0
             tauglobal <- 1/lambda
             
             #initial estimate for beta
             lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
             lambdap[(1:p)%in%unpen] <- 0
             if(cont_codata){ 
               muinitp <- rep(0,p)
             }else{
               muinitp <- as.vector(c(muhat[,1])%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p (0 for unpenalised covariates) 
               muinitp[(1:p)%in%unpen] <- 0
             }
             if(est_beta_method=="glmnet"){
               glmGRtrgt <- glmnet::glmnet(X,Y,alpha=0,
                                           #lambda = 2*lambda/n*sd_y,
                                           family=fml,
                                           offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], standardize = FALSE,
                                           penalty.factor=penfctr)
               #minlam <- min(glmGRtrgt$lambda)*n/sd_y/2
               if(lambda < minlam){
                 warning("Estimated lambda value found too small, set to minimum to for better numerical performance")
                 lambda <- minlam
                 lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
                 lambdap[(1:p)%in%unpen] <- 0
                 
                 #re-estimate tau_global for new lambda value
                 gamma[,1] <- 1/lambda
                 tauglobal <- 1/lambda
               }
               
               intrcptinit <- NULL #NULL for Cox
               #betasinit <- as.vector(glmGRtrgt$beta)
               betasinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10,exact=TRUE,
                                 x=X,y=Y,
                                 family=fml,
                                 offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)],
                                 penalty.factor=penfctr)
               betasinit[!((1:p)%in%unpen)] <- betasinit[!((1:p)%in%unpen)] + muinitp[!((1:p)%in%unpen)]
               #intrcptinit <- glmGRtrgt$a0
             }else{
               XXbl <- list(X[,penfctr!=0]%*%t(X[,penfctr!=0]))
               #Compute betas
               XXT <- multiridge::SigmaFromBlocks(XXbl,penalties=lambda) #create nxn Sigma matrix = sum_b [(lambda_b)^{-1} X_b %*% t(X_b)]
               if(sum((1:p)%in%unpen)>0){
                 fit <- multiridge::IWLSCoxridge(XXT,Y=Y, model=model,X1=X[,(1:p)%in%unpen]) #Fit. fit$etas contains the n linear predictors
               }else{
                 fit <- multiridge::IWLSCoxridge(XXT,Y=Y) #Fit. fit$etas contains the n linear predictors
               }
               
               betas <- multiridge::betasout(fit, Xblocks=list(X[,penfctr!=0]), penalties=lambda) #Find betas.
               intrcptinit <- c(betas[[1]][1]) #intercept
               betasinit <- rep(0,p) 
               betasinit[(1:p)%in%unpen] <- betas[[1]][-1] #unpenalised variables
               betasinit[!((1:p)%in%unpen)] <- betas[[2]]
               rm(betas)
             }
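              #NOTE: a minimal, commented sketch of the multiridge route above on toy
              #survival data; the names X_toy, Y_toy are hypothetical and only illustrate
              #the SigmaFromBlocks -> IWLSCoxridge -> betasout pipeline used here:
              # n_toy <- 50; p_toy <- 100
              # X_toy <- matrix(rnorm(n_toy*p_toy), n_toy, p_toy)
              # Y_toy <- survival::Surv(rexp(n_toy), rbinom(n_toy, 1, 0.7))
              # XXT_toy <- multiridge::SigmaFromBlocks(list(X_toy %*% t(X_toy)), penalties=10)
              # fit_toy <- multiridge::IWLSCoxridge(XXT_toy, Y=Y_toy)
              # beta_toy <- multiridge::betasout(fit_toy, Xblocks=list(X_toy), penalties=10)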
           },
           'family'={
             #Use Cross-validation to compute initial lambda (tausq)
             if((!is.nan(compare) & grepl("CV",compare)) | grepl("CV",lambda)){
               #use glmnet to do CV; computationally more expensive but other optimising criteria possible
               if(grepl("glmnet",lambda)){ 
                 lambdaGLM<-glmnet::cv.glmnet(X,Y,nfolds=fold,alpha=0,family=fml,
                                              standardize = FALSE,intercept=intrcpt,
                                              penalty.factor=penfctr,keep=TRUE) #alpha=0 for ridge
               }
               else{ #do CV with fastCV2 from multiridge package
                 if(length(setdiff(unpen,p+1))==0){
                   Xbl <- X%*%t(X)
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,intercept=intrcpt,
                                                                    fixedfolds=FALSE,model=model,kfold=fold)})
                 }else{
                   Xbl <- X[,penfctr!=0]%*%t(X[,penfctr!=0])
                   capture.output({fastCVfit <- multiridge::fastCV2(XXblocks=list(Xbl),Y=Y,intercept=intrcpt,
                                                                    fixedfolds=FALSE,model=model,X1=X[,penfctr==0],kfold=fold)})
                 }
               }
               
               if((!is.nan(compare) & grepl("CV",compare)) | (!is.nan(compare) & compare==TRUE)){
                 if(grepl("glmnet",lambda)) lambdaridge <- lambdaGLM$lambda.min/sd_y*n #fitted lambda
                 #else if(grepl("penalized",lambda)) lambdaridge <- ol1$lambda
                 else lambdaridge <- fastCVfit$lambdas
               } 
               if(grepl("CV",lambda)){
                 if(grepl("glmnet",lambda)) lambda <- lambdaGLM$lambda.min/sd_y*n #using glmnet
                 #else if(grepl("penalized",lambda)) lambda <- ol1$lambda #using penalized
                 else lambda <- fastCVfit$lambdas
               } 
             }
             gamma[,1] <- 1/lambda
             tauglobal <- 1/lambda
              sigmahat <- 1 #sigma not in model for general GLM families: set to 1
              muhat[,1] <- mu #use initial prior mean mu (default 0) in the GLM setting
             mutrgt <- mutrgt #default: 0
             
             #initial estimate for beta
             lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
             lambdap[(1:p)%in%unpen] <- 0
             
             if(cont_codata){ 
               muinitp <- rep(0,p)
             }else{
               muinitp <- as.vector(c(muhat[,1])%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p (0 for unpenalised covariates) 
               muinitp[(1:p)%in%unpen] <- 0
             }
             if(est_beta_method=="glmnet"){
               glmGRtrgt <- glmnet::glmnet(X,Y,alpha=0,
                                           #lambda = lambda/n*sd_y,
                                           family=fml,
                                           offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], intercept = intrcpt, standardize = FALSE,
                                           penalty.factor=penfctr)
                minlam <- min(glmGRtrgt$lambda)*n/sd_y #smallest lambda on glmnet's path, on the internal penalty scale
                if(lambda < minlam){
                  warning("Estimated lambda value too small; set to the path minimum for better numerical performance")
                 lambda <- minlam
                 lambdap <- rep(lambda,p) #px1 vector with penalty for each beta_k, k=1,..,p
                 lambdap[(1:p)%in%unpen] <- 0
                 
                 #re-estimate tau_global for new lambda value
                 gamma[,1] <- 1/lambda
                 tauglobal <- 1/lambda
               }
               
               #betasinit <- as.vector(glmGRtrgt$beta)
               betasinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10, exact=TRUE,
                                 x=X,y=Y,
                                 family=fml,
                                 offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], intercept = intrcpt,
                                 penalty.factor=penfctr)[-1]
               betasinit[!((1:p)%in%unpen)] <- betasinit[!((1:p)%in%unpen)] + muinitp[!((1:p)%in%unpen)]
               #intrcptinit <- glmGRtrgt$a0
               intrcptinit <- coef(glmGRtrgt,s=lambda/n*sd_y,thresh = 10^-10,exact=TRUE,
                                   x=X,y=Y,
                                   family=fml,
                                   offset = X[,!((1:p)%in%unpen)] %*% muinitp[!((1:p)%in%unpen)], intercept = intrcpt,
                                   penalty.factor=penfctr)[1]
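                #NOTE: coef(..., exact=TRUE) refits the glmnet model at the requested s,
                #which is why x, y, offset and penalty.factor are resupplied above; with
                #exact=FALSE the coefficients would instead be interpolated from the
                #stored lambda path.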
              } #the multiridge package does not yet support general GLM families
             # }else{ #use multiridge package
             #   XXbl <- list(X[,penfctr!=0]%*%t(X[,penfctr!=0]))
             #   #Compute betas
             #   XXT <- multiridge::SigmaFromBlocks(XXbl,penalties=lambda) #create nxn Sigma matrix = sum_b [lambda_b)^{-1} X_b %*% t(X_b)]
             #   if(sum((1:p)%in%unpen)>0){
             #     fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt,X1=X[,(1:p)%in%unpen]) #Fit. fit$etas contains the n linear predictors
             #   }else{
             #     fit <- multiridge::IWLSridge(XXT,Y=Y, model=model,intercept=intrcpt) #Fit. fit$etas contains the n linear predictors
             #   }
             #   
             #   betas <- multiridge::betasout(fit, Xblocks=list(X[,penfctr!=0]), penalties=lambda) #Find betas.
             #   intrcptinit <- c(betas[[1]][1]) #intercept
             #   betasinit <- rep(0,p) 
             #   betasinit[(1:p)%in%unpen] <- betas[[1]][-1] #unpenalised variables
             #   betasinit[!((1:p)%in%unpen)] <- betas[[2]]
             #   rm(betas)
             # }
             
           }
    )
  }
  
  
  #-3.3 Start iterations (usually just one iteration) ========================================================================================
  Itr<-1
  while(Itr<=nIt){
    #-3.3.1 Compute penalty matrix and weight matrix for logistic #############################################
    #copy penalty parameter matrix ridge: add 0 for unpenalised intercept if included
    if(intrcpt | intrcptGLM){
      #Deltac <- diag(c(lambdap,0))
      Deltac <- Matrix::sparseMatrix(i=1:(length(lambdap)+1),j=1:(length(lambdap)+1),x=c(lambdap,0))
      if(model=="logistic"){
        #Deltac<-2*Deltac
        #reweight Xc for logistic model
        expminXb<-exp(-Xcinit%*%c(betasinit,intrcptinit))
        Pinit<-1/(1+expminXb)
        W<-diag(c(sqrt(Pinit*(1-Pinit))))
        Xc<-W%*%Xcinit
      }else if(model=="family"){
        lp <- Xcinit%*%c(betasinit,intrcptinit) #linear predictor
        meansY <- fml$linkinv(lp) #mu=E(Y)
        W <- diag(c(sqrt(fml$variance(meansY))))
        Xc<-W%*%Xcinit
      }
    }else{
      #Deltac <- diag(c(lambdap))
      Deltac <- Matrix::sparseMatrix(i=1:length(lambdap),j=1:length(lambdap),x=c(lambdap))
      if(model=="logistic"){
        #Deltac<-2*Deltac
        #reweight Xc for logistic model
        expminXb<-exp(-Xcinit%*%c(betasinit))
        Pinit<-1/(1+expminXb)
        W<-diag(c(sqrt(Pinit*(1-Pinit))))
        Xc<-W%*%Xcinit
      }else if(model=="cox"){
        #Deltac<-2*Deltac
        #reweight Xc for cox model
        expXb<-exp(Xcinit%*%c(betasinit))
         h0 <- sapply(1:length(Y[,1]),function(i){Y[i,2]/sum(expXb[Y[,1]>=Y[i,1]])}) #Breslow estimate of the baseline hazard increment at each observed time
        H0 <- sapply(Y[,1],function(Ti){sum(h0[Y[,1]<=Ti])})
        
        W <- diag(c(sqrt(H0*expXb)))
        Xc<-W%*%Xcinit
      }else if(model=="family"){
        lp <- Xcinit%*%c(betasinit) #linear predictor
        meansY <- fml$linkinv(lp) #mu=E(Y)
        W <- diag(c(sqrt(fml$variance(meansY)))) #square root of variance matrix
        Xc<-W%*%Xcinit
      }
    }
    if(model%in%c("logistic","cox","family") && all(W==0)){
      if(!silent) print("Overfitting: only 0 in weight matrix W")
      if(!silent) print(paste("Iterating stopped after",Itr-1,"iterations",sep=" "))
      break;
    }
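    #NOTE: the reweighting above is the usual IWLS linearisation: W is the square root
    #of Var(Y|eta) at the current linear predictor, so the weighted design W%*%Xcinit
    #turns the GLM moment equations into (approximately) linear-model ones.
    #A minimal commented check of the logistic variance (toy objects):
    # eta_toy <- c(-1, 0, 2)
    # p_toy <- 1/(1+exp(-eta_toy))
    # all.equal(p_toy*(1-p_toy), binomial()$variance(p_toy)) #TRUE
    #For Cox, h0 and H0 above are the Breslow estimates of the baseline hazard
    #increments and the cumulative baseline hazard.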
    
    #NOTE: unpenalised covariates other than the intercept are not yet supported in glmnet here
    
    #-3.3.2 Compute matrices needed for MoM ###################################################################
    # XtXD <- t(Xc)%*%Xc+Deltac
    # XtXDinv <- solve(XtXD)
    # L<-XtXDinv %*% t(Xc)
    pen <- (1:dim(Xc)[2])[!(1:dim(Xc)[2]%in%unpen)] #index covariates to be penalised
    if(p>n){
      if(length(unpen)>0){
        P1<-diag(1,n) - Xc[,unpen]%*%solve(t(Xc[,unpen])%*%Xc[,unpen],t(Xc[,unpen])) #compute orthogonal projection matrix
        eigP1<-eigen(P1,symmetric = TRUE)
        if(!all(round(eigP1$values,digits=0)%in%c(0,1))){
          warning("Check unpenalised covariates")
        }
        eigP1$values<-pmax(0,eigP1$values) #set eigenvalues that are small negative due to numerical errors to 0
        CP1 <- eigP1$vectors %*% diag(sqrt(eigP1$values))
        Xpen <- as.matrix(t(CP1)%*%Xc[,pen]%*%Matrix::sparseMatrix(i=1:length(pen),j=1:length(pen),x=Matrix::diag(Deltac)[pen]^(-0.5)))
      } else{
        pen <- 1:p
        CP1<- diag(1,n)
        Xpen <- as.matrix(Xc[,pen]%*%Matrix::sparseMatrix(i=1:length(pen),j=1:length(pen),x=Matrix::diag(Deltac)[pen]^(-0.5)))
      }
      svdX<-svd(Xpen) #X=UDV^T=RV^T
      svdXR<-svdX$u%*%diag(svdX$d) #R=UD
      L2 <- as.matrix(Matrix::sparseMatrix(i=1:length(pen),j=1:length(pen),
                 x=Matrix::diag(Deltac)[pen]^(-0.5))%*%svdX$v%*%solve(t(svdXR)%*%
                                                    svdXR+diag(1,n),t(svdXR)%*%t(CP1)))
      L<-array(0,c(p+intrcpt,n))
      L[pen,]<-L2 #compute only elements corresponding to penalised covariates
      
      R<-Xc
      V2<-sigmahat*apply(L2,1,function(x){sum(x^2)})
      #V2<-apply(L2,1,function(x){sum(x^2)})
      V<-rep(NaN,p+intrcpt)
      V[pen]<-V2 #variance beta ridge estimator
      zeroV <- which(V==0)
      
      # #should be same as:
      # XtXD <- t(Xc)%*%Xc+Deltac
      # XtXDinv <- solve(XtXD) #inverting pxp matrix really slow, use SVD instead
      # L1<-XtXDinv %*% t(Xc)
      # L1 <- solve(XtXD,t(Xc))
      # R<-Xc
      # V1<-sigmahat*apply(L,1,function(x){sum(x^2)})
      # #same as: V <- sigmahat*diag(L%*%R %*% XtXDinv) 
    }else{ #p<=n
      XtXD <- Matrix::as.matrix(t(Xc)%*%Xc+Deltac)
      XtXDinv <- solve(XtXD) #inverting pxp matrix
      L<-XtXDinv %*% t(Xc)
      R<-Xc
      #C<-L%*%R
      V<-sigmahat*apply(L,1,function(x){sum(x^2)})
      zeroV <- which(V==0)
      #same as: V3 <- sigmahat*diag(L%*%R %*% XtXDinv)
    }
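    #NOTE: the p>n branch above avoids inverting the pxp matrix t(Xc)%*%Xc+Deltac by
    #using the SVD identity (with X.d = X %*% Delta^(-1/2) = U D V^T):
    #  solve(t(X)%*%X + Delta) %*% t(X) = Delta^(-1/2) %*% V %*% solve(D^2+I) %*% D %*% t(U)
    #A minimal commented check on toy data (toy names, no unpenalised covariates):
    # n_toy <- 5; p_toy <- 20
    # X_toy <- matrix(rnorm(n_toy*p_toy), n_toy, p_toy); d_toy <- runif(p_toy, 1, 2)
    # s_toy <- svd(X_toy %*% diag(1/sqrt(d_toy)))
    # L_svd <- diag(1/sqrt(d_toy)) %*% s_toy$v %*%
    #   solve(diag(s_toy$d^2) + diag(n_toy), diag(s_toy$d) %*% t(s_toy$u))
    # L_dir <- solve(t(X_toy) %*% X_toy + diag(d_toy), t(X_toy))
    # all.equal(L_svd, L_dir) #TRUE up to numerical error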
 
    #-3.3.3 Update group parameters ###########################################################################
    if(nIt>1){
      if(!silent) print(paste("Compute group penalty estimates, iteration ",Itr,"out of maximal ",nIt," iterations."))
    }
    ### Function Method of Moments to compute group weights for (possibly) multiple parameters 
    MoM <- function(Partitions,hypershrinkage=NaN,groupsets.grouplvl=NaN,
                    fixWeightsMu=NaN,fixWeightsTau=NaN,pars){
      #Partitions: vector with index of partitions
      #fixWeightsMu,fixWeightsTau: when fixed group weights for different partitions/co-data are given, 
      #                            the MoM-function will calculate partition/co-data weights (without extra shrinkage)
      #extract parts of global variables for local copy
      if(length(Partitions)<m | m==1){
        Zt <- Zt[unlist(indGrpsGlobal[Partitions]),,drop=FALSE] #same subsetting for grouped and continuous co-data
        
        if(missing(pars)){
          ind0 <- which(unlist(indGrpsGlobal[Partitions])%in%ind0)
          indnot0 <- which(unlist(indGrpsGlobal[Partitions])%in%indnot0)
        }else{
          indnot0 <- pars[[1]]
          ind0 <- pars[[2]]
        }
        if(length(dim(Wminhalf))>1){
          Wminhalf <- Wminhalf[unlist(indGrpsGlobal[Partitions]),unlist(indGrpsGlobal[Partitions]),drop=FALSE]
          # initgamma <- as.vector(ind[unlist(indGrpsGlobal[Partitions]),]%*%betasinit^2 / 
          #                           apply(ind[unlist(indGrpsGlobal[Partitions]),],1,sum))
          initgamma <- rep(1,length(indnot0))
        }else initgamma <- 1 #tauglobal
        G <- G[Partitions] #number of groups in these partitions
        PenGrps <- PenGrps[unlist(indGrpsGlobal[Partitions]),unlist(indGrpsGlobal[Partitions]),drop=FALSE]
        eqfun <- function(gamma,b,A,lam)  return(sum(t(Zt[indnot0,,drop=FALSE])%*%gamma)/length(pen) ) #equality constraint for average prior variance
      }
      #keep local copies of variables to return
      muhat <- muhat[unlist(indGrpsGlobal[Partitions]),Itr]
      gammatilde <- rep(0,sum(G))
      gamma <- gamma[unlist(indGrpsGlobal[Partitions]),Itr]
      lambdashat<-lambdashat[,Itr+1,Partitions] 
      
      #if two shrinkage penalties are given, first is used to select groups, second to shrink estimates
      if(!is.nan(hypershrinkage)){
        temp <- strsplit(hypershrinkage,",") 
        hypershrinkage <- temp[[1]][1]  
        ExtraShrinkage2 <- temp[[1]][-1]
        if(!cont_codata){
          if(grepl("none",hypershrinkage)){
            if(length(Partitions)==1){
              if(!silent) print(paste("Group set ",Partitions,": estimate group weights, hypershrinkage type: ",hypershrinkage,sep=""))
            }
          }else{
            if(!silent) print(paste("Group set ",Partitions,": estimate hyperlambda for ",hypershrinkage," hypershrinkage",sep=""))
          }
        }else{
          if(grepl("none",hypershrinkage)){
            if(length(Partitions)==1){
              if(!silent) print(paste("Co-data matrix ",Partitions,
                                      ": estimate weights, hypershrinkage type: ",
                                      hypershrinkage,sep=""))
            }
          }else if(hypershrinkage=="mgcv"){
            if(!silent) print(paste("Estimate co-data weights and (if included) hyperpenalties with mgcv",sep=""))
          }else{
            if(length(Partitions)==1){
              if(!silent) print(paste("Co-data matrix ",Partitions,": estimate hyperlambda for ",hypershrinkage," hypershrinkage",sep=""))
            }
          }
        }
        
      }
      if(!cont_codata){ #Set up the MoM equations as a fast GxG linear system in the case of group sets
        if(length(G)==1 && G==1){ #single group: nothing to shrink (cont_codata is FALSE in this branch)
          lambdashat <- c(0,0)
          muhat <- mutrgt
          weightsMu <- NaN
          if(is.nan(tausq)){
            gamma <- 1
            gammatilde <- gamma
          }
        }else if(!grepl("none",hypershrinkage) & !all(G==1) & length(indnot0)>1){ 
          #-3.3.3|1 With extra shrinkage -----------------------------------------------------------------
          # Use random splits to penalise having too many groups:
          # minimise the split RSS over lambda1, lambda2 to find optimal penalties for shrinkage on the group level.
          # Each group is randomly split in half, nsplits times; INDin contains one half of each split
          INDin <- lapply(1:m,function(prt){ 
            if(!(prt%in%Partitions)){return(NULL)}else{
              replicate(nsplits,lapply(groupsets[[prt]],function(x){sample(x,floor(length(x)/2),replace=FALSE)}),simplify=FALSE)  
            }
          }) #keep list of m elements such that index same as groupsets
          INDout <- lapply(1:m,function(i){ #for each partition
            if(!(i%in%Partitions)){return(NULL)}else{
              lapply(INDin[[i]],function(indin){ #for each split
                lapply(1:length(groupsets[[i]]),function(x){groupsets[[i]][[x]][!(groupsets[[i]][[x]]%in%indin[[x]])]})})
            }
          })
          #INDin[[Partitions]] <- lapply(groupsets[Partitions],function(prt){
          #  replicate(nsplits,lapply(prt,function(x){sample(x,floor(length(x)/2),replace=FALSE)}),simplify=FALSE)
          #})
          #INDout <- lapply(Partitions,function(i){ #for each partition
          #  lapply(INDin[[i]],function(indin){ #for each split
          #    lapply(1:G[i],function(x){groupsets[[i]][[x]][!(groupsets[[i]][[x]]%in%indin[[x]])]})})
          #})
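          #NOTE: a minimal commented sketch of one random half-split as used above
          #(g_toy is a hypothetical toy group):
          # g_toy <- 1:10
          # indin_toy <- sample(g_toy, floor(length(g_toy)/2)) #one half, as in INDin
          # indout_toy <- g_toy[!(g_toy %in% indin_toy)]       #complement, as in INDout
          #The hyperpenalty below is tuned by fitting on the in-halves and evaluating
          #the RSS on the out-halves, averaged over the nsplits random splits.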
          
          #-3.3.3|1.1 EB estimate group means ============================================================
          muhatp <-as.vector(rep(mu,sum(G))%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p
          muhatp[(1:p)%in%unpen] <- 0
          
          weightsMu <- rep(NaN,sum(G))
          if(is.nan(mu)){
            if(is.nan(lambdashat[1])){
              #-3.3.3|1.1.1 Compute linear system for whole partition ####################################
              A.mu <- matrix(unlist(
                lapply(Partitions,function(i){ #for each partition
                  sapply(1:length(Kg[[i]]),function(j){ #for each group
                    #compute row with gamma_{xy}
                    x<-groupsets[[i]][[j]]
                    unlist(sapply(Partitions,function(prt){sapply(groupsets[[prt]],function(y){sum(L[x,]%*%t(t(R[,y])/Ik[[prt]][y]))/Kg[[i]][j]})}))
                  }, simplify="array")
                })
              ),c(sum(G),sum(G)),byrow=TRUE) #reshape to matrix of size sum(G)xsum(G)
              Bmu <- unlist(
                lapply(Partitions,function(i){ #for each partition
                  sapply(1:length(Kg[[i]]),function(j){ #for each group
                    x<-groupsets[[i]][[j]]
                    sum(betasinit[x]-muinitp[x]+L[x,]%*%(R[,pen]%*%muinitp[pen]))/Kg[[i]][j]
                  })
                })
              )
              
              sdA.mu <- c(apply(A.mu,2,function(x){sd(x,na.rm=TRUE)}))
              A.mu<-A.mu%*% diag(1/sdA.mu) #normalise columns
              
              #-3.3.3|1.1.2 For each split, compute linear system ########################################
              mutrgtG <- mutrgt
              if(length(mutrgt)==1){ mutrgtG <- rep(mutrgt,sum(G))}
              mutrgtG<-diag(sdA.mu)%*%mutrgtG
              
              #in-part
              A.muin <- lapply(1:nsplits,function(split){
                matrix(unlist(
                  lapply(Partitions,function(i){ #for each partition
                    sapply(1:length(Kg[[i]]),function(j){ #for each group
                      x<-INDin[[i]][[split]][[j]]
                      #compute row with gamma_{xy}
                      unlist(sapply(Partitions,function(prt){sapply(groupsets[[prt]],function(y){sum(L[x,]%*%t(t(R[,y])/Ik[[prt]][y]))/Kg[[i]][j]})}))
                    }, simplify="array")
                  })
                ),c(sum(G),sum(G)),byrow=TRUE) %*% diag(1/sdA.mu) #reshape to matrix of size sum(G)xsum(G)
              })
              #rhs vector
              Bmuin <- lapply(1:nsplits,function(split){unlist(
                lapply(Partitions,function(i){ #for each partition
                  sapply(1:length(Kg[[i]]),function(j){ #for each group
                    x<-INDin[[i]][[split]][[j]]
                    sum(betasinit[x]-muinitp[x]+L[x,]%*%(R[,pen]%*%muinitp[pen]))/Kg[[i]][j]
                  })
                })
              )
              })
              
              #weight matrix
              A.muinAcc <- lapply(1:nsplits,function(i){
                A.muinAcc <- A.muin[[i]] %*% Wminhalf #weight matrix 
              })
              
              #out-part: use A.mu_{out}=A.mu-A.mu_{in}, B_{out}=B-B_{in}
              A.muout <- lapply(1:nsplits,function(split){
                A.mu-A.muin[[split]]
              })
              Bmuout <- lapply(1:nsplits,function(split){
                Bmu-Bmuin[[split]]
              })
              
              
              #-3.3.3|1.1.3 Define function RSSlambdamu, ################################################
              # using the extra shrinkage penalty function corresponding to parameter hypershrinkage
              rangelambda1 <- c(-100,100)
              switch(hypershrinkage,
                     "ridge"={
                       #standard deviation needed for glmnet
                       sd_Bmuin<- lapply(1:nsplits,function(i){
                         if(length(ind0)>0){
                           sd_Bmuin <- sqrt(var(Bmuin[[i]][indnot0]- 
                                                  as.matrix(A.muin[[i]][indnot0,ind0],c(length(c(indnot0,ind0))))%*%muhat[ind0] -
                                                  A.muin[[i]][indnot0,indnot0] %*% mutrgtG[indnot0])*(length(indnot0)-1)/length(indnot0))[1]
                         }else{
                           sd_Bmuin <- sqrt(var(Bmuin[[i]][indnot0]- 
                                                  A.muin[[i]][indnot0,indnot0] %*% mutrgtG[indnot0])*(length(indnot0)-1)/length(indnot0))[1]
                         }
                       })
                       RSSlambdamu <- function(lambda1){
                         #Ridge estimates for given lambda
                         lambda1<-exp(lambda1)
                         
                         ### Estimate group means in-part for given lambda1
                         muhatin <- lapply(1:nsplits,function(i){
                           #ridge estimate for group means
                           muhatin <- rep(NaN,sum(G))
                           muhatin[ind0]<-muhat[ind0] #groups with variance 0 keep same prior parameters
                           if(length(ind0)>0){
                             glmMuin <- glmnet::glmnet(A.muinAcc[[i]][indnot0,indnot0],Bmuin[[i]][indnot0]- 
                                                         as.matrix(A.muinAcc[[i]][indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0],
                                                       alpha=0,
                                                       #lambda = 2*lambda1/length(indnot0)*sd_Bmuin[[i]],
                                                       family="gaussian",
                                                       offset = A.muin[[i]][indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                           }else{
                             glmMuin <- glmnet::glmnet(A.muinAcc[[i]][indnot0,indnot0],Bmuin[[i]][indnot0],
                                                       alpha=0,
                                                       #lambda = 2*lambda1/length(indnot0)*sd_Bmuin[[i]],
                                                       family="gaussian",
                                                       offset = A.muin[[i]][indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                           }
                           #muhatin[indnot0] <- Wminhalf[indnot0,indnot0] %*% as.vector(glmMuin$beta) + mutrgtG[indnot0]
                           muhatin[indnot0] <- Wminhalf[indnot0,indnot0] %*% 
                             coef(glmMuin,s=2*lambda1/length(indnot0)*sd_Bmuin[[i]])[-1] + mutrgtG[indnot0]
                           return(muhatin)
                         }
                         ) #group estimate for mu_in
                         
                         ### Compute RSS on left-out part
                         A.muoutmuin <- lapply(1:nsplits,function(split){A.muout[[split]][indnot0,]%*%muhatin[[split]]})
                         RSSmu <- sum(sapply(1:nsplits,function(split){sum((A.muoutmuin[[split]]-Bmuout[[split]][indnot0])^2)/nsplits}))
                         return(RSSmu)
                       }
                     },
                     "lasso"={
                       ### Fit glmnet for global range of lambda
                       fitMu <- lapply(1:nsplits,function(i){
                         if(length(ind0)>0){
                           glmMuin <- glmnet::glmnet(A.muinAcc[[i]][indnot0,indnot0],Bmuin[[i]][indnot0]- 
                                                       as.matrix(A.muinAcc[[i]][indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0],
                                                     alpha=1,family="gaussian",
                                                     offset = A.muin[[i]][indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE,
                                                     thresh = 1e-10)
                         }else{
                           glmMuin <- glmnet::glmnet(A.muinAcc[[i]][indnot0,indnot0],Bmuin[[i]][indnot0],
                                                     alpha=1,family="gaussian",
                                                     offset = A.muin[[i]][indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE,
                                                     thresh = 1e-10)
                         }
                       })
                       
                       RSSlambdamu <- function(lambda1){
                          #Lasso estimates for given lambda1
                         lambda1<-exp(lambda1)
                         
                         ### Estimate group means in-part for given lambda1
                         muhatin <- lapply(1:nsplits,function(i){
                            #lasso estimate for group means
                           muhatin <- rep(NaN,sum(G))
                           muhatin[ind0]<-muhat[ind0] #groups with variance 0 keep same prior parameters
                           coefMu<- coef(fitMu[[i]], s = lambda1, exact = FALSE)[-1,]
                           muhatin[indnot0] <- Wminhalf[indnot0,indnot0] %*% as.vector(coefMu) + mutrgtG[indnot0]
                           return(muhatin)
                         }
                         ) #group estimate for mu_in
                         
                         ### Compute RSS on left-out part
                         A.muoutmuin <- lapply(1:nsplits,function(split){A.muout[[split]][indnot0,]%*%muhatin[[split]]})
                         RSSmu <- sum(sapply(1:nsplits,function(split){sum((A.muoutmuin[[split]]-Bmuout[[split]][indnot0])^2)/nsplits}))
                         return(RSSmu)
                       }
                     },
                     "hierLasso"={
                       #TD: acc or not?
                       #Hierarchical overlapping group estimates for given lambda
                       #no target for mu (shrunk to 0)
                        #A.muxtnd <- lapply(A.muinAcc,function(X){return(X[,unlist(groupsets.grouplvl)])}) #duplicate columns to create artificial non-overlapping groups
                        A.muxtnd <- lapply(A.muin,function(X){return(X[,unlist(groupsets.grouplvl)])}) #duplicate columns to create artificial non-overlapping groups
                       #create new group indices for Axtnd
                       Kg2 <- c(1,sapply(groupsets.grouplvl,length)) #group sizes on group level (1 added to easily compute hier. group numbers)
                       G2 <- length(Kg2)-1
                       groupxtnd <- lapply(2:length(Kg2),function(i){sum(Kg2[1:(i-1)]):(sum(Kg2[1:i])-1)}) #list of indices in each group
                       groupxtnd2 <- unlist(sapply(1:G2,function(x){rep(x,Kg2[x+1])})) #vector with group number
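                        #NOTE: this is the latent overlapping-group trick: duplicating
                        #columns makes the groups disjoint so gglasso's non-overlapping
                        #group lasso applies; duplicated contributions are summed back per
                        #original group below. A minimal commented sketch (toy hierarchy):
                        # grp_toy <- list(1, c(1,2)) #group 1 nested in group 2
                        # unlist(grp_toy)            #c(1,1,2): column 1 gets duplicated
                        # Kg2_toy <- c(1, sapply(grp_toy, length))
                        # unlist(sapply(1:2, function(x) rep(x, Kg2_toy[x+1]))) #c(1,2,2)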
                       
                       ### Fit gglasso for global range of lambda
                       fit1<-lapply(1:nsplits,function(i){
                         gglasso::gglasso(x=A.muxtnd[[i]],y=Bmuin[[i]],group = groupxtnd2, loss="ls", 
                                          intercept = FALSE, pf = rep(1,G2))
                       })
                       rangelambda1 <- log(range(sapply(fit1,function(i){range(i$lambda)})))
                       
                       RSSlambdamu <- function(lambda1){
                         lambda1<-exp(lambda1)
                         
                          ### Estimate group means for given lambda1 (no target: mu shrunk to 0)
                         muhatin <- lapply(1:nsplits,function(i){
                           vtilde <- coef(fit1[[i]],s=lambda1)[-1]
                           v<-lapply(groupxtnd,function(g){
                             x<-rep(0,G)
                             x[unlist(groupsets.grouplvl)[g]]<-x[unlist(groupsets.grouplvl)[g]]+vtilde[g]
                             return(x)
                           })
                           muhatin <- Wminhalf %*% c(apply(array(unlist(v),c(G,G2)),1,sum))
                           return(muhatin)
                         })
                         
                         ### Compute MSE on left-out part
                         A.muoutmuin <- lapply(1:nsplits,function(split){A.muout[[split]]%*%muhatin[[split]]})
                         RSSmu <- sum(sapply(1:nsplits,function(split){sum((A.muoutmuin[[split]]-Bmuout[[split]])^2)/nsplits}))
                         return(RSSmu)
                         
                         # lamb<-seq(exp(rangelambda1[1]),exp(rangelambda1[2]),diff(exp(rangelambda1))/200)
                         # RSS<-sapply(log(lamb),RSSlambdamu)
                         # plot(lamb,RSS)
                       }
                     }
              )
              
              #First find optimal lambda_1
              lambda1<- optimise(RSSlambdamu,rangelambda1)
              lambdashat[1] <- exp(lambda1$minimum)
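               #optimise() searches on log scale over rangelambda1 (RSSlambdamu
               #exponentiates its argument), hence the exp() when storing the optimum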
            }
            
            #-3.3.3|1.1.4 Compute group mean estimates for optimised hyperpenalty lambda #################
            if(lambdashat[1]==0){
              #groups with zero group variance already in muhat
              if(length(ind0)>0){ #only update groups with positive group variance
                muhat[indnot0] <- solve(A.mu[indnot0,indnot0],Bmu[indnot0]-
                                          as.matrix(A.mu[indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0])
              }else{
                muhat[indnot0] <- solve(A.mu[indnot0,indnot0],Bmu[indnot0])
              }
              muhat[indnot0] <- diag(1/sdA.mu[indnot0]) %*% muhat[indnot0] #restore sd columns mu
            }else{
              #-3.3.3|1.1.5 Compute mu for given hyperpenalty  ###########################################
              switch(hypershrinkage,
                     "ridge"={
                       A.muAcc <- A.mu %*% Wminhalf
                       if(length(ind0)>0){
                         sd_Bmu <- sqrt(var(Bmu[indnot0] - as.matrix(A.mu[indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0]
                                            -as.matrix(A.mu[indnot0,indnot0],c(length(indnot0),length(ind0))) %*% mutrgtG[indnot0])*(length(indnot0)-1)/length(indnot0))[1]
                         #ridge estimate for group means
                         glmMu <- glmnet::glmnet(A.muAcc[indnot0,indnot0],Bmu[indnot0]-
                                                   as.matrix(A.muAcc[indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0],alpha=0,
                                                 #lambda = 2*lambdashat[1]/length(indnot0)*sd_Bmu,
                                                 family="gaussian",
                                                 offset = A.mu[indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                       }else{
                         sd_Bmu <- sqrt(var(Bmu[indnot0]
                                            -as.matrix(A.mu[indnot0,indnot0],c(length(indnot0),length(ind0))) %*% mutrgtG[indnot0])*(length(indnot0)-1)/length(indnot0))[1]
                         #ridge estimate for group means
                         glmMu <- glmnet::glmnet(A.muAcc[indnot0,indnot0],Bmu[indnot0],alpha=0,
                                                 #lambda = 2*lambdashat[1]/length(indnot0)*sd_Bmu,
                                                 family="gaussian",
                                                 offset = A.mu[indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                       }
                       #groups with variance 0 keep same prior parameters, update other groups
                       #muhat[indnot0] <- Wminhalf[indnot0,indnot0] %*% as.vector(glmMu$beta) + mutrgtG[indnot0]
                       muhat[indnot0] <- Wminhalf[indnot0,indnot0] %*% 
                         coef(glmMu, s=2*lambdashat[1]/length(indnot0)*sd_Bmu)[-1] + mutrgtG[indnot0]
                       muhat[indnot0] <- diag(1/sdA.mu[indnot0]) %*% muhat[indnot0] #restore sd columns A.mu
                       
                     },
                     "lasso"={
                       A.muAcc <- A.mu %*% Wminhalf
                        #lasso estimate for group means
                       if(length(ind0)>0){
                         glmMu <- glmnet::glmnet(A.muAcc[indnot0,indnot0],Bmu[indnot0]-
                                                   as.matrix(A.muAcc[indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0],
                                                 alpha=1,family="gaussian",
                                                 offset = A.mu[indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                       }else{
                         glmMu <- glmnet::glmnet(A.muAcc[indnot0,indnot0],Bmu[indnot0],
                                                 alpha=1,family="gaussian",
                                                 offset = A.mu[indnot0,indnot0] %*% mutrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                       }
                       coefMu <- coef(glmMu,s=lambdashat[1])
                       #groups with variance 0 keep same prior parameters, update other groups
                       muhat[indnot0] <- Wminhalf[indnot0,indnot0] %*% as.vector(coefMu[-1,]) + mutrgtG[indnot0]
                       muhat[indnot0] <- diag(1/sdA.mu[indnot0]) %*% muhat[indnot0] #restore sd columns A.mu
                     },
                     "hierLasso"={
                       #Hierarchical overlapping group estimates for given lambda
                       #no target for mu (shrunk to 0)
                       #A.muAcc <- A.mu %*% Wminhalf
                        #A.muxtnd <- A.muAcc[,unlist(groupsets.grouplvl)] #duplicate columns to create artificial non-overlapping groups
                        A.muxtnd <- A.mu[,unlist(groupsets.grouplvl)] #duplicate columns to create artificial non-overlapping groups
                       #create new group indices for Axtnd
                       Kg2 <- c(1,sapply(groupsets.grouplvl,length)) #group sizes on group level (1 added to easily compute hier. group numbers)
                       G2 <- length(Kg2)-1
                       groupxtnd <- lapply(2:length(Kg2),function(i){sum(Kg2[1:(i-1)]):(sum(Kg2[1:i])-1)}) #list of indices in each group
                       groupxtnd2 <- unlist(sapply(1:G2,function(x){rep(x,Kg2[x+1])})) #vector with group number
                       
                       #Hierarchical group lasso estimate for group variances
                       fit1<-gglasso::gglasso(x=A.muxtnd,y=Bmu,group = groupxtnd2, loss="ls", 
                                              intercept = FALSE, pf = rep(1,G2),lambda=lambdashat[1])
                       vtilde <- coef(fit1,s=lambdashat[1])[-1]
                       v<-lapply(groupxtnd,function(g){
                         x<-rep(0,G)
                         x[unlist(groupsets.grouplvl)[g]]<-x[unlist(groupsets.grouplvl)[g]]+vtilde[g]
                         return(x)
                       })
                       muhat <- Wminhalf %*% c(apply(array(unlist(v),c(G,G2)),1,sum))
                       muhat <- diag(1/sdA.mu) %*% muhat #restore sd columns A
                     })
            }
            
            weightsMu <- muhat*p/sum(as.vector(c(muhat)%*%Zt))
            muhatp <-as.vector(c(muhat)%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p
            muhatp[(1:p)%in%unpen] <- 0
            # if(normalise){ #TRUE by default
            #   C<-mutrgt*p/sum(muhatp)
            #   muhat[,Itr+1]<-muhat[,Itr+1]*C
            #   muhatp <-as.vector(c(muhat[,Itr+1])%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p
            # }
            
            #should be same as:
            #muhat2 <- solve(t(A.mu)%*%A.mu+lambdashat[1]*diag(weights),t(A.mu)%*%Bmu+lambdashat[1]*diag(weights)%*%rep(mutrgt,G))
            #muhat2 <- solve(t(A.mu)%*%diag(c(Kg))%*%A.mu+lambdashat[1]*diag(1,G),t(A.mu)%*%diag(c(Kg))%*%Bmu+lambdashat[1]*diag(1,G)%*%rep(mutrgt,G))
          }
          
          #-3.3.3|1.2 EB estimate group variances =========================================================
          gamma <- rep(1,sum(G))
          if(is.nan(tausq)){
            if(is.nan(lambdashat[2])){
              #-3.3.3|1.2.1 Compute linear system for whole partition #####################################
              Btau <- unlist(
                lapply(Partitions,function(i){ #for each partition
                  sapply(1:length(Kg[[i]]),function(j){ #for each group
                    if(j%in%ind0) return(NaN)
                     #compute the element of Btau for this group
                     x<-groupsets[[i]][[j]]
                     x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                    #sum(pmax(0,(betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%
                    #      (muhatp[pen]-muinitp[pen])))^2)/V[x]-1),na.rm=TRUE)/Kg[[i]][j]
                    sum((betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%
                                                               (muhatp[pen]-muinitp[pen])))^2)/V[x]-1,na.rm=TRUE)/Kg[[i]][j]
                  })
                })
              )
              A <- matrix(unlist(
                lapply(Partitions,function(i){ #for each partition
                  sapply(1:length(Kg[[i]]),function(j){ #for each group
                    if(j%in%ind0) return(rep(NaN,sum(G)))
                     x<-groupsets[[i]][[j]]
                     x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                     #compute row with gamma_{xy}
                     unlist(sapply(Partitions,function(prt){sapply(groupsets[[prt]],function(y){
                       y<-setdiff(y,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                      sum(t(c(1/V[x])*L[x,])%*%L[x,]*(R[,y]%*%(t(R[,y])/c(Ik[[prt]][y])*c(tauglobal[datablockNo[y]]))),na.rm=TRUE)/Kg[[i]][j]
                    })}))
                  }, simplify="array")
                })
              ),c(sum(G),sum(G)),byrow=TRUE) #reshape to matrix of size sum(G)xsum(G)
              
              constA <- 1 #mean(diag(A),na.rm=TRUE)
              Btau <- Btau/constA
              A <- A/constA
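               #NOTE: each row j of the system averages a moment equation over group j:
               #Btau_j collects the empirical values of (betasinit_k^2 - mhat_k^2)/V[k] - 1,
               #whose expectation is linear in the group variance multipliers gamma with
               #coefficients collected in A, so gamma (approximately) solves
               #A %*% gamma = Btau. This reading follows the code above and is meant as
               #orientation, not as a formal derivation.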
              
              #-3.3.3|1.2.2 For each split, compute linear system #########################################
              gammatrgtG <- rep(1,sum(G))
              gammatrgtG[ind0]<-0 
              
              #in-part
              flag <- TRUE; itr2 <- 1
              indNewSplits <- 1:nsplits; 
              Btauin <- list(); Btauout <- list()
              while(flag & itr2 <= 50){
                Btauin[indNewSplits] <- lapply(indNewSplits,function(split){unlist(
                  lapply(Partitions,function(i){ #for each partition
                    sapply(1:length(Kg[[i]]),function(j){ #for each group
                      if(j%in%ind0) return(NaN)
                      #compute the element of Btauin for this group
                      x<-INDin[[i]][[split]][[j]]
                      x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                      #sum(pmax(0,(betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%(muhatp[pen]-muinitp[pen])))^2)/V[x]-1),na.rm=TRUE)/Kg[[i]][j]
                      sum((betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%(muhatp[pen]-muinitp[pen])))^2)/V[x]-1,na.rm=TRUE)/Kg[[i]][j]
                    })
                  })
                )/constA
                })
                
                 #check split: at least two of the non-NaN elements of Btauin (i.e. of the selected groups) should be nonzero
                checkSplit <- sapply(Btauin,function(b){sum(b[!is.nan(b)]!=0)>=2 }) 
                if(all(checkSplit)){ #all splits are fine
                  flag <- FALSE
                }else{ 
                  itr2 <- itr2 + 1
                  indNewSplits <- which(!checkSplit) #index of splits that have to be resampled
                  #resample split
                  INDin[[Partitions]][indNewSplits] <- replicate(length(indNewSplits),lapply(groupsets[[Partitions]],
                                                                                             function(x){sample(x,floor(length(x)/2),replace=FALSE)}),simplify=FALSE)  
                  INDout[[Partitions]][indNewSplits] <- lapply(INDin[[Partitions]][indNewSplits],function(indin){ #for each split
                    lapply(1:length(groupsets[[Partitions]]),
                           function(x){groupsets[[Partitions]][[x]][!(groupsets[[Partitions]][[x]]%in%indin[[x]])]})})
                }
              }
              if(itr2==51) warning("No valid random splits found in 50 resampling attempts; check the group sets")
              
              Ain <- lapply(1:nsplits,function(split){
                matrix(unlist(
                  lapply(Partitions,function(i){ #for each partition
                    sapply(1:length(Kg[[i]]),function(j){ #for each group
                      if(j%in%ind0) return(rep(NaN,sum(G)))
                       x<-INDin[[i]][[split]][[j]]
                       x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                       #compute row with gamma_{xy}
                       unlist(sapply(Partitions,function(prt){sapply(groupsets[[prt]],function(y){
                         y<-setdiff(y,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                        sum(t(c(1/V[x])*L[x,])%*%L[x,]*(R[,y]%*%(t(R[,y])/c(Ik[[prt]][y])*c(tauglobal[datablockNo[y]]))),na.rm=TRUE)/Kg[[i]][j]
                      })}))
                    }, simplify="array")
                  })
                ),c(sum(G),sum(G)),byrow=TRUE)/constA #reshape to matrix of size sum(G)xsum(G)
              })
              #weight matrix
              AinAcc <- lapply(1:nsplits,function(i){
                AinAcc <- Ain[[i]] %*% Wminhalf #weight matrix 
              })
              
              Btauout <- lapply(1:nsplits,function(split){unlist(
                lapply(Partitions,function(i){ #for each partition
                  sapply(1:length(Kg[[i]]),function(j){ #for each group
                    if(j%in%ind0) return(NaN)
                     #compute the element of Btauout for this group
                     x<-INDout[[i]][[split]][[j]]
                     x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (will be set to 0 anyway)
                    #sum(pmax(0,(betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%(muhatp[pen]-muinitp[pen])))^2)/V[x]-1),na.rm=TRUE)/Kg[[i]][j]
                    sum((betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%(muhatp[pen]-muinitp[pen])))^2)/V[x]-1,na.rm=TRUE)/Kg[[i]][j]
                  })
                })
              )/constA
              })
              # Btauout <- lapply(1:nsplits,function(split){
              #   Btau - Btauin[[split]]
              # })
              Aout <- lapply(1:nsplits,function(split){
                A - Ain[[split]]
              })
              
              #-3.3.3|1.2.3 Define function RSSlambdatau, #################################################
              # using the extra shrinkage penalty function corresponding to parameter hypershrinkage
              rangelambda2 <- c(10^-5,10^6)
              switch(hypershrinkage,
                     "ridge"={
                       gammatrgtG[indnot0] <- 1
                       meanWhalf <- mean(diag(Wminhalf)^-1)
                       #standard deviation needed for glmnet
                       sd_Btauin<- lapply(1:nsplits,function(i){
                         sd_Btauin <- sqrt(var(Btauin[[i]][indnot0] - Ain[[i]][indnot0,indnot0] %*% gammatrgtG[indnot0])*(length(indnot0)-1)/length(indnot0))[1]
                       })
                       
                       #function to compute tau for linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                         sd_Btau <- sqrt(var(Btau[indnot0] - A[indnot0,indnot0] %*% gammatrgtG[indnot0])*(length(indnot0)-1)/length(indnot0))[1]
                         
                         gammas <- rep(0,G)
                         Aacc <- A %*% (Wminhalf * meanWhalf)
                         #ridge estimate for group variances
                         glmTau <- glmnet::glmnet(Aacc[indnot0,indnot0],Btau[indnot0],alpha=0,
                                                  family="gaussian",
                                                  offset = Aacc[indnot0,indnot0] %*% gammatrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                         coefTau <- coef(glmTau,s=lambda2,exact=TRUE,
                                         x=Aacc[indnot0,indnot0],y=Btau[indnot0],
                                         offset = Aacc[indnot0,indnot0] %*% gammatrgtG[indnot0])[-1,]
                         gammas[indnot0] <- (Wminhalf[indnot0,indnot0]*meanWhalf) %*% as.vector(coefTau) + gammatrgtG[indnot0] 
                         #gammas <- pmax(gammas,0) #truncate at 0
                         return(gammas)
                       }
                       
                        ### Fit glmnet ridge for global range of lambda
                       fitTau <- lapply(1:nsplits,function(i){
                         glmTauin <- glmnet::glmnet(AinAcc[[i]][indnot0,indnot0],Btauin[[i]][indnot0],
                                                    alpha=0,family="gaussian",
                                                    offset = Ain[[i]][indnot0,indnot0] %*% gammatrgtG[indnot0],
                                                    intercept = FALSE, standardize = FALSE,
                                                    thresh=1e-6)
                       })
                       rangelambda2 <- range(sapply(fitTau,function(x)range(x$lambda)))
                       rangelambda2[1] <- rangelambda2[1]/100
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         #Ridge estimates for given lambda
                          ### Estimate prior gammas for given lambda2 and target gammatrgtG
                         gammain <- lapply(1:nsplits,function(i){
                           gammain <- rep(NaN,sum(G))
                           gammain[ind0] <- 0
                           #ridge estimate for group variances
                           # glmTauin <- glmnet::glmnet(AinAcc[[i]][indnot0,indnot0],Btauin[[i]][indnot0],alpha=0,
                           #                    lambda = 2*lambda2/length(indnot0)*sd_Btauin[[i]],family="gaussian",
                           #                    offset = Ain[[i]][indnot0,indnot0] %*% gammatrgtG[indnot0], intercept = FALSE, standardize = FALSE)
                           #gammain[indnot0] <- Wminhalf[indnot0,indnot0] %*% as.vector(glmTauin$beta) + gammatrgtG[indnot0] 
                           coefTau <- coef(fitTau[[i]],s=lambda2,exact=TRUE,
                                           x=AinAcc[[i]][indnot0,indnot0],y=Btauin[[i]][indnot0],
                                           offset = Ain[[i]][indnot0,indnot0] %*% gammatrgtG[indnot0])[-1,]
                           gammain[indnot0] <- Wminhalf[indnot0,indnot0] %*% as.vector(coefTau) + gammatrgtG[indnot0] 
                           #gammain <- pmax(gammain,0)
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "ridgeGAM"={
                       gammatrgtG[indnot0] <- 1
                       
                       #function to compute tau for linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                         gammas <- rep(0,G)
                         
                         #ridge estimate for group variances with fused penalty for overlapping groups
                         dat<-list(Y=Btau[indnot0],X=A[indnot0,indnot0])
                         gamTau <- mgcv::gam(Y~ 0 + X, data=dat,
                                             family="gaussian",offset = A[indnot0,indnot0] %*% gammatrgtG[indnot0],
                                             paraPen=list(X=list(S1=PenGrps,sp=lambda2)))
                         gammas[indnot0] <- gamTau$coefficients + gammatrgtG[indnot0]
                         
                         return(gammas)
                       }
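                        #NOTE: with a fixed smoothing parameter sp, mgcv::gam solves the
                        #penalised least-squares problem
                        #  ||Y - X%*%b||^2 + lambda2 * t(b) %*% PenGrps %*% b,
                        #i.e. a generalised ridge whose penalty matrix PenGrps fuses
                        #overlapping groups; lambda2 itself is tuned by the split-based
                        #RSS criterion below rather than by gam's own criterion.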
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         
                          ### Estimate prior gammas for given lambda2 and target gammatrgtG
                         gammain <- lapply(1:nsplits,function(i){
                           gammain <- rep(NaN,sum(G))
                           gammain[ind0] <- 0
                           #ridge estimate for group variances
                           dat<-list(Y=Btauin[[i]][indnot0],X=Ain[[i]][indnot0,indnot0])
                           gamTauin <- mgcv::gam(Y~ 0 + X, data=dat,
                                                 family="gaussian",offset = Ain[[i]][indnot0,indnot0] %*% gammatrgtG[indnot0],
                                                 paraPen=list(X=list(S1=PenGrps,sp=lambda2)))
                           gammain[indnot0] <- gamTauin$coefficients + gammatrgtG[indnot0]
                           
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "lasso"={#TD: adapt lasso&hierLasso, check other hyperpenalties on inclusion function gammas
                       meanWhalf <- mean(diag(Wminhalf)^-1)
                       ### Fit glmnet lasso for global range of lambda
                       fitTau <- lapply(1:nsplits,function(i){
                         glmTauin <- glmnet::glmnet(AinAcc[[i]][indnot0,indnot0]*meanWhalf,Btauin[[i]][indnot0],
                                                    alpha=1,family="gaussian",
                                                    intercept = FALSE, standardize = FALSE,
                                                    thresh=1e-6)
                       })
                       rangelambda2 <- range(sapply(fitTau,function(x)range(x$lambda)))
                       
                       #function to compute tau for linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                         gammas <- rep(0,G)
                         Aacc <- A %*% (Wminhalf * meanWhalf)
                         glmTau <- glmnet::glmnet(Aacc[indnot0,indnot0],Btau[indnot0],
                                                  alpha=1,family="gaussian",
                                                  intercept = FALSE, standardize = FALSE)
                         coefTau <- coef(glmTau,s=lambda2,exact=TRUE,
                                         x=Aacc[indnot0,indnot0],y=Btau[indnot0])
                         gammas[indnot0] <- (Wminhalf[indnot0,indnot0]*meanWhalf) %*% as.vector(coefTau[-1,])
                         return(gammas)
                       }
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                          #Lasso estimates for given lambda2
                          ### Estimate prior gammas for given lambda2 (shrunk to 0, no target)
                         gammain <- lapply(1:nsplits,function(i){
                           gammain <- rep(NaN,sum(G))
                           gammain[ind0] <- 0
                            #lasso estimate for group variances
                           coefTau<- coef(fitTau[[i]], s = lambda2, exact = TRUE,
                                          x=AinAcc[[i]][indnot0,indnot0]*meanWhalf,y=Btauin[[i]][indnot0])[-1,]
                           gammain[indnot0] <- (Wminhalf[indnot0,indnot0]*meanWhalf) %*% as.vector(coefTau)
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "hierLasso"={
                       maxit_gglasso <- 1e+04
                       #Hierarchical overlapping group estimates for given lambda
                       #no target for tau (shrunk to 0)
                       #remove groups that are already set to 0
                       if(length(groupsets.grouplvl)!=length(indnot0)){
                         INDgrps2 <- lapply(groupsets.grouplvl[indnot0],function(x){x[x%in%indnot0]})
                       }else{
                         INDgrps2 <- groupsets.grouplvl
                       }
                       
                        #Axtnd <- lapply(AinAcc,function(A){return(A[indnot0,unlist(INDgrps2),drop=FALSE])}) #duplicate columns to create artificial non-overlapping groups
                        Axtnd <- lapply(Ain,function(A){return(A[indnot0,unlist(INDgrps2),drop=FALSE])}) #duplicate columns to create artificial non-overlapping groups
                       #create new group indices for Axtnd
                       Kg2 <- c(1,sapply(INDgrps2,length)) #group sizes on group level (1 added to easily compute hier. group numbers)
                       G2 <- length(Kg2)-1
                       groupxtnd <- lapply(2:length(Kg2),function(i){sum(Kg2[1:(i-1)]):(sum(Kg2[1:i])-1)}) #list of indices in each group
                       groupxtnd2 <- unlist(sapply(1:G2,function(x){rep(x,Kg2[x+1])})) #vector with group number
                       
                       ### Fit gglasso for global range of lambda
                       fit2<-lapply(1:nsplits,function(i){
                          capture.output({temp <- gglasso::gglasso(x=Axtnd[[i]],y=Btauin[[i]][indnot0],
                                                   group = groupxtnd2, loss="ls",
                                                   intercept = FALSE, pf = rep(1,G2),maxit = maxit_gglasso)})
                         return(temp)
                       })
                       rangelambda2 <- range(sapply(fit2,function(i){range(i$lambda)}), na.rm=TRUE)
                       
                        #function to compute gamma for the full linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                          gammas <- rep(0,sum(G))
                          Axtnd <- A[indnot0,unlist(INDgrps2),drop=FALSE] #extend matrix to create artificial non-overlapping groups
                         
                         #create new group indices for Axtnd
                         Kg2 <- c(1,sapply(INDgrps2,length)) #group sizes on group level (1 added to easily compute hier. group numbers)
                         G2 <- length(Kg2)-1
                         groupxtnd <- lapply(2:length(Kg2),function(i){sum(Kg2[1:(i-1)]):(sum(Kg2[1:i])-1)}) #list of indices in each group
                         groupxtnd2 <- unlist(sapply(1:G2,function(x){rep(x,Kg2[x+1])})) #vector with group number
                         
                         #Hierarchical group lasso estimate for group variances
                         fit2<-gglasso::gglasso(x=Axtnd,y=Btau[indnot0],group = groupxtnd2, loss="ls",
                                                intercept = FALSE, pf = rep(1,G2),maxit = maxit_gglasso)
                         gamma <- rep(0,sum(G))
                         vtilde <- try(coef(fit2,s=lambda2)[-1],silent=TRUE)
                          if(inherits(vtilde,"try-error")) return(gamma) #return 0 vector
                         v<-lapply(groupxtnd,function(g){
                           x<-rep(0,sum(G))
                           x[unlist(INDgrps2)[g]]<-x[unlist(INDgrps2)[g]]+vtilde[g]
                           return(x)
                         })
                         #gammatilde[indnot0] <- Wminhalf[indnot0,indnot0] %*% c(apply(array(unlist(v),c(sum(G),G2)),1,sum))[indnot0]
                         gammas[indnot0] <- c(apply(array(unlist(v),c(sum(G),G2)),1,sum))[indnot0]
                         return(gammas)
                       }
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         ### Estimate prior gammas for given lambda2 (and mutrgt=0)
                         gammain <- lapply(1:nsplits,function(i){
                           gammain <- rep(0,sum(G))
                           vtilde <- try(coef(fit2[[i]],s=lambda2)[-1],silent=TRUE)
                            if(inherits(vtilde,"try-error")) return(gammain) #return 0 vector
                           v<-lapply(groupxtnd,function(g){
                             x<-rep(0,sum(G))
                             x[unlist(INDgrps2)[g]]<-x[unlist(INDgrps2)[g]]+vtilde[g]
                             return(x)
                           })
                           #gammain[indnot0] <- Wminhalf[indnot0,indnot0] %*% c(apply(array(unlist(v),c(sum(G),G2)),1,sum))[indnot0]
                           gammain[indnot0] <- c(apply(array(unlist(v),c(sum(G),G2)),1,sum))[indnot0]
                           gammain[gammain<0] <- 0
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,indnot0]%*%gammain[[split]][indnot0]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "ridge+positive"={
                       meanWhalf <- mean(diag(Wminhalf)^-1)
                       trgt <- 1/diag(Wminhalf)/meanWhalf
                       initgamma <- diag(Wminhalf)^(-1)/meanWhalf
                       
                        #define function for MSE penalised with a ridge penalty towards target trgt
                       penMSE <- function(gamma,b,A,lam){ 
                         return(sum((b-A%*%gamma)^2) + lam*sum((gamma-trgt[indnot0])^2)) }
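                        # Toy illustration (hypothetical data, not executed) of the
                        # constrained fit used below: Rsolnp::solnp minimises penMSE
                        # subject to the lower bound gamma >= 0, e.g.
                        # trgttoy <- c(1,1)
                        # ftoy <- function(gamma,b,A,lam){
                        #   sum((b-A%*%gamma)^2) + lam*sum((gamma-trgttoy)^2) }
                        # Rsolnp::solnp(par=c(1,1), fun=ftoy, b=c(2,-1), A=diag(2),
                        #               lam=0.5, LB=c(0,0), control=list(trace=0))$pars
                        # #approximately c(1.67, 0); the second coordinate hits the bound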
                       
                       #function to compute tau for linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                          gammas <- rep(0,sum(G))
                         Aacc <- A %*% (Wminhalf * meanWhalf)
                         initSelected <- initgamma[indnot0]
                         fitTau <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btau[indnot0],
                                                 A=Aacc[indnot0,indnot0], lam=lambda2,
                                                 LB = rep(0,length(indnot0)),control=list(trace=0))
                         gammas[indnot0] <- (Wminhalf[indnot0,indnot0] *meanWhalf) %*%as.vector(fitTau$pars)
                         return(gammas)
                       }
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         #Ridge estimates for given lambda
                         ### Estimate prior gammas for given lambda2 and mutrgt
                         gammain <- lapply(1:nsplits,function(i){
                           gammain <- rep(NaN,sum(G))
                           gammain[ind0] <- 0
                           
                           Ainacc <- Ain[[i]]%*%(Wminhalf * meanWhalf)
                           fitTauin <- Rsolnp::solnp(par = initgamma[indnot0], fun=penMSE, b=Btauin[[i]][indnot0],
                                                     A=Ainacc[indnot0,indnot0],lam=lambda2,
                                                     LB = rep(0,length(indnot0)),control=list(trace=0))
                           gammain[indnot0] <- (Wminhalf[indnot0,indnot0] * meanWhalf) %*% as.vector(fitTauin$pars)
                           return(gammain)
                         })
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "ridgeGAM+positive"={
                       trgt <- 1
                       
                        #define function for MSE penalised with a generalised ridge penalty (penalty matrix PenGrps) towards target 1
                       penMSE <- function(gamma,b,A,lam){ 
                         return(sum((b-A%*%gamma)^2) + 
                                  lam*sum((gamma-trgt)%*%PenGrps%*%(gamma-trgt))) }
                       
                        gammas <- function(lambda2){
                          gammas <- rep(0,sum(G))
                          initSelected <- initgamma[indnot0]
                          fitTau <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btau[indnot0], 
                                                  A=A[indnot0,indnot0], lam=lambda2,
                                                  LB = rep(0,length(indnot0)),control=list(trace=0))
                          gammas[indnot0] <- as.vector(fitTau$pars)
                          return(gammas)
                        }
                       
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         #Ridge estimates for given lambda
                         ### Estimate prior gammas for given lambda2 and mutrgt
                         gammain <- lapply(1:nsplits,function(i){
                           gammain <- rep(0,sum(G))
                           
                           initSelected <- initgamma[indnot0]
                           fitTauin <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btauin[[i]][indnot0], 
                                                     A=Ain[[i]][indnot0,indnot0], lam=lambda2,
                                                     LB = rep(0,length(indnot0)),control=list(trace=0))
                           gammain[indnot0] <- as.vector(fitTauin$pars)
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "invgamma+mean1"={
                       meanWhalf <- mean(diag(Wminhalf)^-1)
                       
                       #define function for MSE penalised with inverse gamma prior with mean 1
                       initgamma <- diag(Wminhalf)^(-1)/meanWhalf
                       #MSE penalised by inverse gamma penalty
                       penMSE <- function(gamma,b,A,lam){ 
                         Kg <- diag(Wminhalf)^(-2) #group sizes
                         alphaIG <- pmax(1,2 + (lam-1/min(Kg))*Kg) #alpha in range [1,infty)
                         betaIG <- pmax(0,sqrt(Kg)* (1+(lam-1/min(Kg))* Kg) / meanWhalf) #beta in range [0,infty)
                         
                         minlogLikeInvGamma <- (alphaIG[indnot0] + 1)*log(gamma) + betaIG[indnot0]/gamma
                         return(sum((b-A%*%gamma)^2) + sum(minlogLikeInvGamma) ) }
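                        # Note: when lam is large enough that the pmax floors above are
                        # inactive, the inverse-gamma mean betaIG/(alphaIG-1) equals
                        # sqrt(Kg)/meanWhalf, the same group-size-adjusted target as in
                        # the "ridge+positive" branch; lam then only controls how
                        # strongly gamma is pulled towards that mean.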
                       
                       #function to compute tau for linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                          gammas <- rep(0,sum(G))
                         Aacc <- A %*% (Wminhalf * meanWhalf)
                         initSelected <- initgamma[indnot0]
                         fitTau <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btau[indnot0],
                                                 A=Aacc[indnot0,indnot0], lam=lambda2,
                                                 LB = rep(0,length(indnot0)),control=list(trace=0))
                         gammas[indnot0] <- (Wminhalf[indnot0,indnot0] *meanWhalf) %*%as.vector(fitTau$pars)
                         return(gammas)
                       }
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         ### Estimate prior gammas for given lambda2
                         gammain <- lapply(1:nsplits,function(i){
                           Ainacc <- Ain[[i]]%*%(Wminhalf * meanWhalf)
                           gammain <- rep(NaN,sum(G))
                           gammain[ind0] <- 0
                           initSelected <- initgamma[indnot0]
                           fitTauin <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btauin[[i]][indnot0],
                                                     A=Ainacc[indnot0,indnot0],lam=lambda2,
                                                     LB = rep(0,length(indnot0)),control=list(trace=0))
                           gammain[indnot0] <- (Wminhalf[indnot0,indnot0] *meanWhalf)%*%as.vector(fitTauin$pars)
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "invgamma+mode1"={
                       meanWhalf <- mean(diag(Wminhalf)^-1)
                       
                        #define function for MSE penalised with inverse gamma prior with mode 1
                       initgamma <- diag(Wminhalf)^(-1)/meanWhalf
                       
                       #MSE penalised by inverse gamma penalty
                       penMSE <- function(gamma,b,A,lam){ 
                         Kg <- diag(Wminhalf)^(-2) #group sizes
                         prmsIG<-.prmsIGMode1(lam,Kg)
                         alphaIG <- prmsIG[[1]]
                         betaIG<-prmsIG[[2]] * diag(Wminhalf)^(-1)/meanWhalf
                         
                         minlogLikeInvGamma <- (alphaIG[indnot0] + 1)*log(gamma) + betaIG[indnot0]/gamma
                         return(sum((b-A%*%gamma)^2) + sum(minlogLikeInvGamma) ) }
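                        # Note: assuming .prmsIGMode1 returns (alpha, beta) pairs for
                        # which the inverse-gamma mode beta/(alpha+1) equals 1 (as its
                        # name suggests), rescaling beta by diag(Wminhalf)^(-1)/meanWhalf
                        # (= sqrt(Kg)/meanWhalf) places the prior mode at the same
                        # group-size-adjusted target as in the branches above.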
                       
                       #function to compute tau for linear system given a hyperpenalty lambda2
                       gammas <- function(lambda2){
                          gammas <- rep(0,sum(G))
                         Aacc <- A %*% (Wminhalf * meanWhalf)
                         initSelected <- initgamma[indnot0]
                         fitTau <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btau[indnot0],
                                                 A=Aacc[indnot0,indnot0], lam=lambda2,
                                                 LB = rep(0,length(indnot0)),control=list(trace=0))
                         gammas[indnot0] <- (Wminhalf[indnot0,indnot0] *meanWhalf) %*%as.vector(fitTau$pars)
                         return(gammas)
                       }
                       
                       #function to compute the Residual Sum of Squares on the splits given lambda2
                       RSSlambdatau <- function(lambda2){
                         lambda2 <- exp(lambda2)
                         ### Estimate prior gammas for given lambda2
                         gammain <- lapply(1:nsplits,function(i){
                           Ainacc <- Ain[[i]]%*%(Wminhalf * meanWhalf)
                           gammain <- rep(NaN,sum(G))
                           gammain[ind0] <- 0
                           initSelected <- initgamma[indnot0]
                           fitTauin <- Rsolnp::solnp(par = initSelected, fun=penMSE, b=Btauin[[i]][indnot0],
                                                     A=Ainacc[indnot0,indnot0],lam=lambda2,
                                                     LB = rep(0,length(indnot0)),control=list(trace=0))
                           gammain[indnot0] <- (Wminhalf[indnot0,indnot0] *meanWhalf)%*%as.vector(fitTauin$pars)
                           return(gammain)
                         })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     },
                     "gamma+positive"={
                        #define function for MSE penalised with a spike at 0 plus a gamma prior with mean 1
                       spike<-0.001
                       penMSE <- function(gamma,b,A,lam){ 
                         #logLikeInvGamma <- sapply(gamma,function(x){logdinvgamma(x,alp=lam,bet=lam-1)})
                         logLikeGamma <- sapply(gamma,function(x){
                           #return(lam*log(lam)-log(gamma(lam))+(lam-1)*log(x)-lam*x)
                           if(x==0) return(log(spike))
                           return(log(1-spike)+(lam-1)*log(x)-lam*x)
                         })
                         return(sum((b-A%*%gamma)^2) - sum(logLikeGamma) ) }
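                        # logLikeGamma is, up to the normalising constant dropped from
                        # the commented density inside penMSE, the log-density of a
                        # spike-and-slab prior: a point mass at 0 with probability
                        # `spike`, else a Gamma(shape=lam, rate=lam) slab with mean 1;
                        # the spike lets whole groups be selected out, while larger lam
                        # concentrates the slab around 1.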
                       #eqfun2 <- function(gamma,b,A,lam)  return(sum(t(Zt[indnot0,])%*%(Wminhalf[indnot0,indnot0]%*%gamma))/length(pen) ) #equality constraint for average prior variance
                       
                        RSSlambdatau <- function(lambda2){
                          lambda2 <- exp(lambda2)
                          
                          ### Estimate prior gammas for given lambda2
                          gammain <- lapply(1:nsplits,function(i){
                            gammain <- rep(NaN,sum(G))
                            gammain[ind0] <- 0
                            fitTauin <- Rsolnp::solnp(par = initgamma, fun=penMSE, b=Btauin[[i]][indnot0], 
                                                      A=Ain[[i]][indnot0,indnot0],lam=lambda2,
                                                      LB = rep(0,G), eqfun=eqfun, eqB = 1,control=list(trace=0))
                            gammain[indnot0] <- as.vector(fitTauin$pars)
                            return(gammain)
                          })
                         
                         ### Compute MSE on left-out part
                         Aouttauin <- lapply(1:nsplits,function(split){Aout[[split]][indnot0,]%*%gammain[[split]]})
                         RSStau <- sum(sapply(1:nsplits,function(split){sum((Aouttauin[[split]]-Btauout[[split]][indnot0])^2)/nsplits}))
                         return(RSStau)
                       }
                     }
              )
              
              #find optimal lambda_2 given muhat
              
              lambda2 <- optim(mean(log(rangelambda2)),RSSlambdatau,method="Brent",
                               lower = log(rangelambda2[1]),upper = log(rangelambda2[2]))
              lambdashat[2] <- exp(lambda2$par)
              # lambda2 <- optimise(RSSlambdatau,rangelambda2) #regular optimiser can get stuck in flat region
              # lambdashat[2] <- lambda2$minimum
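               # Toy illustration (hypothetical, not executed) of the Brent line
               # search on log-scale used above:
               # f <- function(loglam) (exp(loglam)-3)^2
               # opt <- optim(0, f, method="Brent", lower=log(1e-5), upper=log(1e5))
               # exp(opt$par) #approximately 3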
              
              if(profplotRSS){ #profile plot lambda vs RSS
                lambdas <- 10^seq(-5,6,length.out=30)
                if(all(lambdas>rangelambda2[2] | lambdas<rangelambda2[1])){
                   lambdas <- 10^seq(log10(rangelambda2[1]),log10(rangelambda2[2]),length.out=30)
                } 
                FRSS <- sapply(log(lambdas),RSSlambdatau)
                profPlot <- plot(log10(lambdas),FRSS,xlab="hyperlambda (log10-scale)",ylab="RSS",
                                 main=paste("Group set ",Partitions,", ",hypershrinkage," hypershrinkage",sep=""))
                abline(v=log10(lambdashat[2]),col="red")
                abline(v=log10(rangelambda2[1]),col="blue",lty=2)
                abline(v=log10(rangelambda2[2]),col="blue",lty=2)
                if(!silent) print(paste("Estimated hyperlambda: ",lambdashat[2],sep=""))
              }
              
              
            }
            
            #-3.3.3|1.2.4 Compute group variance estimates for optimised hyperpenalty lambda ##############
            if(length(ExtraShrinkage2)==0){
              if(!silent) print(paste("Estimate group weights of group set ",Partitions,sep=""))
            }
            if(lambdashat[2]==0){
              gammatilde[indnot0] <- solve(A[indnot0,indnot0],Btau[indnot0])
              gamma <- pmax(0,gammatilde) #set negative tau to 0
            }else{
              gammatilde <- gammas(lambdashat[2])
              gamma <- pmax(0,gammatilde)
              
              if(length(ExtraShrinkage2)>0){
                if(!silent) print(paste("Select groups of group set ",Partitions,sep=""))
                 if(all(gammatilde==0)){ #none selected: fall back to equal weights
                   gamma <- rep(0,sum(G))
                  gamma[indnot0] <- 1
                }else if(sum(gammatilde[indnot0]!=0)==1){ #just one selected
                  gamma <- gammatilde
                  Cnorm <- p/sum(c(gamma)%*%Zt)
                  gamma <- gamma*Cnorm
                }else{
                  #2.
                  output <- MoM(Partitions,hypershrinkage=ExtraShrinkage2,
                                pars=list(indnot0=which(gammatilde!=0),ind0=which(gammatilde==0)))
                  return(output)
                }
              }
              
            }
            if(normalise){
              Cnorm <- p/sum(c(gamma)%*%Zt)
              gammatilde <- gammatilde*Cnorm
              gamma <- gamma*Cnorm
            }
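             # Normalisation rescales the weights so that the average prior variance
             # weight over the p covariates is 1, i.e. sum(c(gamma)%*%Zt) == p after
             # rescaling (Zt being the sum(G) x p group-membership co-data matrix).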
            
             if(any(is.nan(gamma))){warning("NaN in group variance")}
          }
        }else{ 
          #-3.3.3|2 Without extra shrinkage---------------------------------------------------------------
          lambdashat <- c(0,0) 
          
          #-3.3.3|2.1 EB estimate group means ============================================================
          muhatp <-as.vector(rep(mu,sum(G))%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p
          muhatp[(1:p)%in%unpen] <- 0
          weightsMu <- rep(NaN,sum(G))
          if(!is.nan(mu)){
            muhat<-rep(mu,length(muhat))
          }else{
            if(all(is.nan(betaold))){
              betaold<-rep(1,p) #used as weights
            }
            A.mu <- matrix(unlist(
              lapply(Partitions,function(i){ #for each partition
                sapply(1:length(Kg[[i]]),function(j){ #for each group
                  #compute row with gamma_{xy}
                  x<-groupsets[[i]][[j]]
                  unlist(sapply(Partitions,function(prt){sapply(groupsets[[prt]],function(y){sum(L[x,]%*%t(t(R[,y])/Ik[[prt]][y])%*%diag(betaold[y]))/Kg[[i]][j]})}))
                }, simplify="array")
              })
            ),c(sum(G),sum(G)),byrow=TRUE) #reshape to matrix of size sum(G)xsum(G)
            Bmu <- unlist(
              lapply(Partitions,function(i){ #for each partition
                sapply(1:length(Kg[[i]]),function(j){ #for each group
                  x<-groupsets[[i]][[j]]
                  sum(betasinit[x]-muinitp[x]+L[x,]%*%(R[,pen]%*%muinitp[pen]))/Kg[[i]][j]
                })
              })
            )  
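             # Method-of-moments system for the group means: entry (g,h) of A.mu is
             # the average, over covariates in group g, of the contribution of group
             # h's prior mean to the initial estimates, and Bmu holds the matching
             # averaged deviations of the initial estimates from the initial prior
             # mean; muhat below solves A.mu %*% muhat = Bmu on the active groups.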
            if(any(is.nan(fixWeightsMu))){ #compute group means for specific partition
              #correct for fixed group means corresponding to groups with variance 0
              if(length(ind0)>0){
                muhat[indnot0] <- solve(A.mu[indnot0,indnot0],Bmu[indnot0]- 
                                          as.matrix(A.mu[indnot0,ind0],c(length(indnot0),length(ind0)))%*%muhat[ind0])
              }else{
                muhat[indnot0] <- solve(A.mu[indnot0,indnot0],Bmu[indnot0])
              }
              #muhat[ind0,Itr+1] <- muhat[ind0,Itr] #means of groups with variance 0 stay the same
              muhatp <-as.vector(c(muhat)%*%Zt)*betaold #px1 vector with estimated prior mean for beta_k, k=1,..,p
              muhatp[(1:p)%in%unpen] <- 0
              weightsMu <- muhat*p/sum(as.vector(c(muhat)%*%Zt))
            }else{ #compute partition weights/co-data weights
              weightsPart <- sqrt(G[indGrpsGlobal[Partitions]])
              weightMatrixMu <- matrix(rep(0,sum(G)*length(G)),sum(G),length(G))
              for(i in 1:length(G)){
                weightMatrixMu[indGrpsGlobal[[Partitions[i]]],i] <- fixWeightsMu[indGrpsGlobal[[Partitions[i]]]]
              }
              if(!all(round(fixWeightsMu,10)==1)){ #group weights not all 1: use equal co-data weights
                weightsMu <- rep(1/length(Partitions),length(Partitions)) #partition/co-data weights
                muhat<-weightMatrixMu%*%weightsMu #group weights multiplied with partition/co-data weights
              }else{
                A.mutilde <- A.mu%*%weightMatrixMu%*%diag(weightsPart)
                muhat <- solve(t(A.mutilde)%*%A.mutilde,t(A.mutilde)%*%c(Bmu)) / weightsPart
                muhat<- pmax(0,muhat)
                weightsMu <- muhat/sum(muhat) #partition/co-data weights
                muhat<-weightMatrixMu%*%weightsMu #group weights multiplied with partition/co-data weights
              }
            }
            
            
            # if(normalise){ #TRUE by default
            #   C<-mutrgt*p/sum(muhatp)
            #   muhat[,Itr+1]<-muhat[,Itr+1]*C
            #   muhatp <-as.vector(c(muhat[,Itr+1])%*%Zt)*betaold #px1 vector with estimated prior mean for beta_k, k=1,..,p
            # }
            ## Should be same as:
            # A.mu <- ind %*% C %*% diag(betaold) %*% t(ind) /c(Kg)
            # Bmu <- ind %*% (betasinit - Cacc%*%rep(muinit,p)) /c(Kg) #betasinit depend on initial mutrgt
            # muhat <- solve(A.mu,Bmu)
          }
          
          #-3.3.3|2.2 EB estimate group variances ========================================================
          if(!is.nan(tausq)){
            gamma <- rep(1,length(gamma))
          }else{
            Btau <- unlist(
              lapply(Partitions,function(i){ #for each partition
                sapply(1:length(Kg[[i]]),function(j){ #for each group
                  #compute row with gamma_{xy}
                  x<-groupsets[[i]][[j]]
                  x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (they are set to 0 anyway)
                  #sum(pmax(0,(betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%(muhatp[pen]-muinitp[pen])))^2)/V[x]-1),na.rm=TRUE)/Kg[[i]][j]
                  sum((betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,pen]%*%(muhatp[pen]-muinitp[pen])))^2)/V[x]-1,na.rm=TRUE)/Kg[[i]][j]
                })
              })
            )
            A <- matrix(unlist(
              lapply(Partitions,function(i){ #for each partition
                sapply(1:length(Kg[[i]]),function(j){ #for each group
                  #compute row with gamma_{xy}
                  x<-groupsets[[i]][[j]]
                  x<-setdiff(x,zeroV) #ad-hoc fix: remove covariates with 0 variance (they are set to 0 anyway)
                  #compute row with gamma_{xy}
                  unlist(sapply(Partitions,function(prt){sapply(groupsets[[prt]],function(y){
                    y<-setdiff(y,zeroV)
                    sum(t(c(1/V[x])*L[x,])%*%L[x,]*(R[,y]%*%(t(R[,y])/c(Ik[[prt]][y])*c(tauglobal[datablockNo[y]]))),na.rm=TRUE)/Kg[[i]][j]
                  })}))
                }, simplify="array")
              })
            ),c(sum(G),sum(G)),byrow=TRUE) #reshape to matrix of size sum(G)xsum(G)
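            # Method-of-moments system for the group variances: Btau[g] averages, over
            # covariates in group g, the excess of the squared initial estimates over
            # their prior-implied second moment, and A[g,h] gives the corresponding
            # contribution of group h's variance weight; the branches below solve
            # A %*% gamma = Btau, with or without a non-negativity constraint.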
            
            if(any(is.nan(fixWeightsTau))){
              if(!cont_codata){
                if(grepl("positive",hypershrinkage)){
                  # penMSE <- function(gamma,b,A,lam) return(sum((b-A%*%gamma)^2)) 
                  # #Aacc <- A%*%Wminhalf
                  # gamma <- rep(0,G)
                  # fitTau <- Rsolnp::solnp(par = rep(1,length(indnot0)), fun=penMSE, b=Btau[indnot0],
                  #                         A=A[indnot0,indnot0],
                  #                         LB = rep(0,length(indnot0)),control=list(trace=0))
                  # gamma[indnot0] <- as.vector(fitTau$pars)
                  # gammatilde <- gamma
                  
                  gamma <- rep(0,G)
                  fitTau <- nnls::nnls(A[indnot0,indnot0],Btau[indnot0])
                  gamma[indnot0] <- as.vector(fitTau$x)
                  gammatilde <- gamma
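                   # Toy illustration (hypothetical data, not executed):
                   # nnls::nnls solves min ||A x - b||^2 subject to x >= 0, e.g.
                   # nnls::nnls(diag(2), c(2,-1))$x gives c(2, 0): the negative
                   # unconstrained coordinate is truncated at the bound.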
                }else{
                   gamma <- rep(0,sum(G))
                  gammatilde <- solve(t(A[indnot0,indnot0])%*%A[indnot0,indnot0],
                                      t(A[indnot0,indnot0])%*%Btau[indnot0])
                  gamma <- pmax(0,gammatilde)
                  
                  if(normalise){
                    Cnorm <- p/sum(c(gamma)%*%Zt[,pen,drop=FALSE])
                    gamma<-gamma*Cnorm
                  }
                }
              }else{
                if(grepl("positive",hypershrinkage)){
                  # penMSE <- function(gamma,b,A,lam) return(sum((b-A%*%gamma)^2))
                  # #Aacc <- A%*%Wminhalf
                  # 
                  # gamma <- rep(0,G)
                  # fitTau <- Rsolnp::solnp(par = rep(1,length(indnot0)), fun=penMSE, b=Btau,
                  #                         A=A[,indnot0,drop=FALSE],
                  #                         LB = rep(0,length(indnot0)),control=list(trace=0))
                  # gamma[indnot0] <- as.vector(fitTau$pars)
                  # gammatilde <- gamma
                  
                  gamma <- rep(0,G)
                  fitTau <- nnls::nnls(A[,indnot0,drop=FALSE],Btau)
                  gamma[indnot0] <- as.vector(fitTau$x)
                  gammatilde <- gamma
                }else{
                   gamma <- rep(0,sum(G))
                  gammatilde <- solve(t(A[,indnot0,drop=FALSE])%*%A[,indnot0,drop=FALSE],
                                      t(A[,indnot0,drop=FALSE])%*%Btau)
                  gamma <- gammatilde
                  #gamma <- pmax(0,gammatilde)
                  # if(normalise){
                  #   Cnorm <- p/sum(c(gamma)%*%Zt[,pen])
                  #   gamma<-gamma*Cnorm
                  # }
                }
              }
              
              if(any(is.nan(gamma))){warning("NaN in group variance")}
            }else{ #compute partition weights/co-data weights
              if(!silent) print("Estimate group set weights")
              weightsPart <- sqrt(G[Partitions])
              weightMatrixTau <- matrix(rep(0,sum(G)*length(G)),sum(G),length(G))
              for(i in 1:length(G)){
                weightMatrixTau[indGrpsGlobal[[Partitions[i]]],i] <- fixWeightsTau[indGrpsGlobal[[Partitions[i]]]]
              }
                if(all(round(fixWeightsTau,10)==1)){ #all group weights (nearly) 1: give each group set equal weight
                gamma <- rep(1/length(Partitions),length(Partitions)) #partition/co-data weights
              }else{
                if(any(partWeightsTau[,Itr]==0)){
                  set0 <- unlist(indGrpsGlobal[which(partWeightsTau[,Itr]==0)])
                  ind0 <- union(ind0,set0)
                  indnot0 <- setdiff(indnot0,set0)
                }
                
                Atilde <- A[indnot0,indnot0]%*%weightMatrixTau[indnot0,partWeightsTau[,Itr]!=0] 
                
                #Three options to solve for partition weights (use only one):
                #Solve for tau and truncate negative values to 0
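                 # The next lines compute absolute cosine similarities between the
                 # covariate-level weight profiles (t(Zt)%*%weightMatrixTau) of the
                 # group sets; off-diagonal values near 1 flag (nearly) collinear
                 # group sets, for which the unconstrained least-squares solve below
                 # is ill-conditioned and constrained optimisation is used instead.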
                cosangle<-t(as.matrix(t(Zt[,pen,drop=FALSE])%*%weightMatrixTau))%*%as.matrix(t(Zt[,pen,drop=FALSE])%*%weightMatrixTau)
                cosangle<-abs(t(cosangle/sqrt(diag(cosangle)))/sqrt(diag(cosangle)))
                cosangle <- cosangle-diag(rep(1,m))
                if(any(cosangle>0.999)){
                  indAngle <- which(cosangle>0.999,arr.ind=TRUE)
                  indAngle <- indAngle[indAngle[,1]<indAngle[,2],,drop=FALSE]
                  print(paste("Estimated group weights for group sets",indAngle[,1],
                              "and",indAngle[,2],"found to be similar"))
                  print("Switch to constrained optimisation such that group set weights >0 for stability")
                  
                }
                
                gammatilde<-rep(0,m)
                temp<-try(solve(t(Atilde)%*%Atilde,t(Atilde)%*%c(Btau[indnot0])),silent=TRUE)
                 if(inherits(temp,"try-error") | any(cosangle>0.999)){
                   #Solve for tau>=0 with convex optimisation package
                   D<-length(G)
                   w <- CVXR::Variable(D)
                   objective <- CVXR::Minimize(sum((Atilde%*%w-c(Btau[indnot0]))^2))
                   constraint1 <- diag(rep(1,D))%*%w >= 0
                   problem <- CVXR::Problem(objective,constraints = list(constraint1))
                   result <- CVXR::solve(problem)
                  gamma <- c(result$getValue(w))
                  gammatilde <- gamma
                  gamma <- pmax(0,gammatilde) #correct round-off errors: partition/co-data weights
                  if(all(is.na(gamma))){
                    #infeasible CVXR problem; remove one of similar group sets and give equal weight
                    indremove <- unique(indAngle[,2]) #index of group sets to be removed
                    indmatch <- sapply(indremove,function(k){ #index to map removed back to matches
                      indmatch <- indAngle[which(indAngle[,2]==k),1]
                      indmatch <- setdiff(indmatch,indremove)[1] #take first in case multiple matches
                      return(indmatch)
                    })
                    temp<-try(solve(t(Atilde[,-indremove])%*%Atilde[,-indremove],
                                    t(Atilde[,-indremove])%*%c(Btau[indnot0])),silent=TRUE)
                     if(inherits(temp,"try-error")){
                       #Solve for tau>=0 with convex optimisation package
                       D<-length(G)-length(indremove)
                       w <- CVXR::Variable(D)
                       objective <- CVXR::Minimize(sum((Atilde[,-indremove]%*%w-c(Btau[indnot0]))^2))
                       constraint1 <- diag(rep(1,D))%*%w >= 0
                       problem <- CVXR::Problem(objective,constraints = list(constraint1))
                       result <- CVXR::solve(problem)
                      gamma <- rep(0,length(G))
                      gamma[-indremove] <- c(result$getValue(w))
                      gamma[indremove] <- gamma[indmatch]
                      gammatilde <- gamma
                      gamma <- pmax(0,gammatilde) #correct round-off
                    }else{
                      gammatilde<-rep(0,m)
                      gammatilde[-indremove] <- temp
                       gammatilde[indremove] <- gammatilde[indmatch]
                       gamma <- pmax(0,gammatilde)
                    }
                  }
                }else{
                  gammatilde[partWeightsTau[,Itr]!=0] <- temp
                  gamma <- pmax(0,gammatilde)
                  #temp<-optim(gamma,function(x){sum((Atilde%*%x-c(Btau[indnot0]))^2)},lower=rep(0,length(gamma)),method="L-BFGS-B")
                  #gamma<-temp$par 
                }
                
                
                if(normalise){
                  gammatilde <- gammatilde/sum(gamma)
                  gamma <- gamma/sum(gamma)
                }
              }
            }
          }
        }
      }else{ #Set up MoM equations using pxG co-data matrix 
        #-3.3.3|2 With co-data provided in Z matrix-----------------------------------------------------
        #-3.3.3|2.1 EB estimate group means ============================================================
        #Not yet supported
        muhatp <-as.vector(rep(mu,sum(G))%*%Zt) #px1 vector with estimated prior mean for beta_k, k=1,..,p
        muhatp[(1:p)%in%unpen] <- 0
        weightsMu <- rep(NaN,sum(G))
        if(!is.nan(mu)){
          muhat<-rep(mu,length(muhat))
        }else{
           if(cont_codata) stop("Not implemented for prior means, provide co-data in groupsets")
        }
        
        #-3.3.3|2.2 EB estimate group variances ========================================================
        if(!is.nan(tausq)){
          gamma <- rep(1,length(gamma))
        }else{
          #-3.3.3|2.2.1 Compute linear system for whole partition #####################################
          x<-pen
          x<-setdiff(x,zeroV)
          Btau <- ((betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,x]%*%(muhatp[x]-muinitp[x])))^2)/V[x]-1)
          #Btau <- pmax(0,((betasinit[x]^2-(muinitp[x]+L[x,]%*%(R[,x]%*%(muhatp[x]-muinitp[x])))^2)/V[x]-1))
 
          Ln2 <- try(t(apply(t(c(1/V[x])*L[x,,drop=FALSE]), 2, rep, n) *
                         apply(L[x,,drop=FALSE], 1, rep, each=n)), silent=TRUE) #pxn^2 matrix
          Rn2 <- try(matrix(R[,x,drop=FALSE]%*%(t(apply(R[,x,drop=FALSE] , 2,