R/operation_on_susiF_obj.R

Defines functions which_dummy_cs.susiF which_dummy_cs update_residual_variance.susiF update_residual_variance update_pi.susiF update_pi update_KL.susiF update_KL update_cal_credible_band.susiF update_cal_credible_band update_cal_fit_func.susiF update_cal_fit_func update_cal_indf.susiF update_cal_indf update_cal_cs.susiF update_cal_cs update_susiF_obj.susiF update_susiF_obj update_alpha_hist.susiF update_alpha_hist update_alpha.susiF update_alpha test_stop_cond.susiF test_stop_cond rename_format_output out_prep.susiF out_prep name_cs.susiF name_cs merge_effect.susiF merge_effect init_susiF_obj greedy_backfit.susiF greedy_backfit get_alpha.susiF get_alpha get_post_F2.susiF get_post_F2 get_post_F.susiF get_post_F get_G_prior.susiF get_G_prior get_ER2 get_lBF.susiF get_lBF get_pi.susiF get_pi get_fitted_effect.susiF get_fitted_effect expand_susiF_obj estimate_residual_variance.susiF estimate_residual_variance discard_cs.susiF discard_cs check_cs.susiF check_cs change_fit cal_partial_resid.susiF cal_partial_resid

Documented in cal_partial_resid cal_partial_resid.susiF change_fit check_cs check_cs.susiF discard_cs discard_cs.susiF estimate_residual_variance estimate_residual_variance.susiF get_ER2 get_fitted_effect get_fitted_effect.susiF get_G_prior get_G_prior.susiF greedy_backfit greedy_backfit.susiF init_susiF_obj merge_effect merge_effect.susiF out_prep out_prep.susiF test_stop_cond test_stop_cond.susiF update_alpha update_alpha_hist update_alpha_hist.susiF update_alpha.susiF update_cal_cs update_cal_cs.susiF update_cal_fit_func update_cal_fit_func.susiF update_KL update_KL.susiF update_residual_variance update_residual_variance.susiF which_dummy_cs which_dummy_cs.susiF

################################## Operations on susiF object ############################


#' @title Compute partial residual for effect l
#
#' @param   obj a susiF  or EBmvFR object  
#
#' @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
#' @param X matrix of covariates
#
#' @param D matrix of wavelet D coefficients from the original input data (Y)
#
#' @param C vector of wavelet scaling coefficient from the original input data (Y)
#
#' @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#
#' @return a matrix of size N by size J of partial residuals
#
#' @export
#' 
#' @keywords internal
cal_partial_resid <- function(obj, l, X, D, C, indx_lst, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("cal_partial_resid")
}

#' @rdname cal_partial_resid
#
#' @method cal_partial_resid susiF
#
#' @export cal_partial_resid.susiF
#
#' @export
#' 
#' @keywords internal


cal_partial_resid.susiF <- function(obj, l, X, D, C, indx_lst, ...)
{
  # Partial residuals in the wavelet domain.
  #
  # The fitted contribution X %*% (alpha_k * fitted_wc_k) of a set of effects
  # is subtracted from the detail coefficients D and from the scaling
  # coefficients C, and the two residual pieces are bound column-wise.
  #
  # Effects that get subtracted:
  #   * L > 1 : every effect except number (l %% L) + 1
  #   * else  : effect 1
  n_effect <- obj$L
  if (n_effect > 1) {
    others <- (1:n_effect)[-((l %% n_effect) + 1)]
  } else {
    others <- 1
  }

  # `scaling_col(k)` returns the column index(es) of obj$fitted_wc[[k]] used
  # for the C part; all remaining columns are used for the D part.
  partial_resid <- function(scaling_col) {
    fit_D <- lapply(others, function(k) {
      X %*% sweep(obj$fitted_wc[[k]][, -scaling_col(k)],
                  1,
                  obj$alpha[[k]],
                  "*")
    })
    update_D <- D - Reduce("+", fit_D)

    fit_C <- lapply(others, function(k) {
      X %*% (obj$fitted_wc[[k]][, scaling_col(k)] * obj$alpha[[k]])
    })
    update_C <- C - Reduce("+", fit_C)

    cbind(update_D, update_C)
  }

  # Per-scale prior: the C columns are given by the last entry of indx_lst.
  if (inherits(get_G_prior(obj), "mixture_normal_per_scale")) {
    update_Y <- partial_resid(function(k) indx_lst[[length(indx_lst)]])
  }

  # Single mixture prior: the C column is the last column of fitted_wc[[k]].
  if (inherits(get_G_prior(obj), "mixture_normal")) {
    update_Y <- partial_resid(function(k) dim(obj$fitted_wc[[k]])[2])
  }

  return(update_Y)
}


#' @title Change postprocessing used in susiF object
#' @param   obj a fitted susiF  object  
#
#' @param Y functional phenotype used in the original fit
#'
#' @param X matrix of  covariates  used in the original fit
#' @param to name of the type of postprocessing you wish to change the susiF  object  to be fitted with
#' current option "TI" or "hMM"
#' @param verbose If \code{verbose = TRUE}, the algorithm's progress,
#' and a summary of the optimization settings are printed on the
#' console.
#' @param filter_cs logical, if TRUE filters the credible set (removing low-purity)
#' cs and cs with estimated prior equal to 0). Set as TRUE by default.
#' @param max_scale numeric, define the maximum of wavelet coefficients used in the analysis (2^max_scale).
#'        Set to 10 by default.
#' @param filter.number see documentation of wd from wavethresh package
#'
#' @param family see documentation of wd from wavethresh package
#' @export

change_fit <- function(obj,
                       Y,
                       X,
                       to = "TI",
                       verbose = TRUE,
                       max_scale = 10,
                       filter_cs = TRUE,
                       filter.number = 10,
                       family = "DaubLeAsymm"
) {
  # Re-run the output preparation of a fitted object with another
  # post-processing method (`to`), using the data of the original fit.
  #
  # Steps: drop constant columns of X, remap Y with remap_data(), scale X and
  # Y with colScale(), build the wavelet index list and call out_prep().
  # Returns whatever out_prep() returns.
  post_processing <- to

  # Positions of the columns of Y: stored ones if present, else 1..ncol(Y).
  if (is.null(obj$pos)) {
    pos <- 1:dim(Y)[2]
  } else {
    pos <- obj$pos
  }

  # Column names are recorded before constant columns are removed; both the
  # names and the removed indices are forwarded to out_prep().
  names_colX <- colnames(X)
  tidx <- which(apply(X, 2, var) == 0)
  if (length(tidx) > 0) {
    warning(paste("Some of the columns of X are constants, we removed", length(tidx), "columns"))
    # drop = FALSE keeps X a matrix even when a single column remains.
    X <- X[, -tidx, drop = FALSE]
  }

  # Fix: the original passed the misspelled, undefined symbol `vebose`.
  map_data <- remap_data(Y         = Y,
                         pos       = pos,
                         verbose   = verbose,
                         max_scale = max_scale)

  outing_grid <- map_data$outing_grid
  Y           <- map_data$Y
  X           <- colScale(X)

  indx_lst <- gen_wavelet_indx(log2(length(outing_grid)))

  Y <- colScale(Y)
  # Y is multiplied back by its "scaled:scale" attribute before being handed
  # to out_prep(), i.e. the column scaling applied by colScale() is undone.
  out <- out_prep(obj             = obj,
                  Y               = sweep(Y, 2, attr(Y, "scaled:scale"), "*"),
                  X               = X,
                  indx_lst        = indx_lst,
                  filter_cs       = filter_cs,
                  outing_grid     = outing_grid,
                  filter.number   = filter.number,
                  family          = family,
                  post_processing = post_processing,
                  tidx            = tidx,
                  names_colX      = names_colX,
                  pos             = pos
  )
  return(out)
}


#' @title Check purity credible sets
#
#' @param obj a susif object defined by init_susiF_obj function
#' @param min_purity minimal purity within a CS
#' @param X matrix of covariates
#
#' @return an obj without "dummy" credible sets
#
#' @export
#' @keywords internal

check_cs <- function(obj, min_purity = 0.5, X, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("check_cs")
}

#' @rdname check_cs
#
#' @method check_cs susiF
#
#' @export check_cs.susiF
#
#' @export
#' @keywords internal


check_cs.susiF <- function(obj, min_purity = 0.5, X, ...)
{
  # Find the "dummy" credible sets reported by which_dummy_cs() and, when
  # there are any, remove them with discard_cs() in output-preparation mode.
  to_drop <- which_dummy_cs(obj, min_purity = min_purity, X)

  if (length(to_drop) > 0) {
    obj <- discard_cs(obj, cs = to_drop, out_prep = TRUE)
  }

  obj
}


#' @title Discard credible sets
#
#' @param obj a susif object defined by init_susiF_obj function
#
#' @param cs vector of integer containing the credible sets to discard
#
#' @param out_prep logical, if set to true perform cleaning for final output
#
#' @return a obj without "dummy" credible sets
#
#' @export
#' @keywords internal

discard_cs <- function(obj, cs, out_prep, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("discard_cs")
}

#' @rdname discard_cs
#
#' @method discard_cs susiF
#
#' @export discard_cs.susiF
#
#' @export
#' @keywords internal

discard_cs.susiF <- function(obj, cs, out_prep=FALSE,  ...)
{
  # Remove the effects listed in `cs` from the per-effect components of `obj`.
  #
  # obj      : susiF object (list with per-effect lists alpha, lBF, fitted_wc,
  #            fitted_wc2, cs, est_sd, est_pi, cred_band, lfsr_wc, ...).
  # cs       : integer indices of the effects to discard.
  # out_prep : TRUE  -> also drop the matching entries of fitted_func;
  #            FALSE -> drop the matching KL entries, set ELBO to -Inf and
  #                     flag greedy_backfit_update.
  #
  # Effect 1 is never removed from the lists: when requested it is reset to a
  # null effect (uniform alpha, zero lBF / coefficients, CS = all covariates).
  # Returns the updated object.

  # Asked to drop as many effects as exist: spare the first requested one.
  if (length(cs) == obj$L) {
    cs <- cs[-1]
    if (length(cs) == 0) {
      return(obj)
    }
  }

  if (1 %in% cs) {
    n_cov <- length(obj$alpha[[1]])
    # Fix: the original used length(obj$alpha[1]) (always 1), which turned
    # alpha[[1]] into a scalar instead of a uniform vector over covariates.
    obj$alpha[[1]]       <- rep(1 / n_cov, n_cov)
    obj$lBF[[1]]         <- 0 * obj$lBF[[1]]
    obj$fitted_wc[[1]]   <- 0 * obj$fitted_wc[[1]]
    obj$fitted_wc2[[1]]  <- 0 * obj$fitted_wc2[[1]]
    # Fix: only the first credible set is reset; the original replaced the
    # whole list obj$cs by an integer vector.
    obj$cs[[1]]          <- seq_len(n_cov)
    obj$fitted_func[[1]] <- 0 * obj$fitted_func[[1]]
    # Effect 1 has been handled in place; it must not be removed below.
    cs <- cs[cs != 1]
  }

  if (length(cs) > 0) {
    obj$alpha       <- obj$alpha[-cs]
    obj$lBF         <- obj$lBF[-cs]
    obj$fitted_wc   <- obj$fitted_wc[-cs]
    obj$fitted_wc2  <- obj$fitted_wc2[-cs]
    obj$cs          <- obj$cs[-cs]
    if (out_prep) {
      obj$fitted_func <- obj$fitted_func[-cs]
    } else {
      obj$greedy_backfit_update <- TRUE
      obj$KL                    <- obj$KL[-cs]
      obj$ELBO                  <- -Inf
    }

    obj$est_sd      <- obj$est_sd[-cs]
    obj$est_pi      <- obj$est_pi[-cs]
    obj$cred_band   <- obj$cred_band[-cs]
    obj$lfsr_wc     <- obj$lfsr_wc[-cs]
    obj$L           <- obj$L - length(cs)
  }

  return(obj)
}

#' @title Update residual variance
#
#' @param  obj a susiF object defined by init_susiF_obj function
#
#' @param Y wavelet transformed  functional phenotype, matrix of size N by size J.
#
#' @param X matrix of size N by p
#
#
#' @return estimated residual variance
#
#' @export
#' @keywords internal
estimate_residual_variance <- function(obj, Y, X, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("estimate_residual_variance")
}

#' @rdname estimate_residual_variance
#
#' @method estimate_residual_variance susiF
#
#' @export estimate_residual_variance.susiF
#
#' @export
#
estimate_residual_variance.susiF <- function(obj, Y, X, ...)
{
  # get_ER2() divided by the number of entries of Y (rows times columns).
  n_entries <- prod(dim(Y))
  (1 / n_entries) * get_ER2(obj, Y, X)
}

# @title Expand obj by adding L_extra effect
#
# @param obj a obj
#
# @param L_extra numeric a number of effect to add
#
# @return a obj a L_extra effect. Note the the number of effect of the obj cannot exceed the number the user upper bound
# @export
# @keywords internal

expand_susiF_obj <- function(obj,L_extra)
{
  # Append up to `L_extra` empty effects to `obj`.
  #
  # The resulting number of effects never exceeds obj$L_max nor obj$P. When
  # no effect can be added the object is returned untouched. (The original
  # looped over (L_old + 1):obj$L even when that sequence counted downwards,
  # which zeroed an existing effect and appended a phantom one.)
  #
  # New effects get zero coefficients (fitted_wc), unit fitted_wc2, zero
  # alpha, an empty CS, NA log Bayes factors, a zero credible band, and a
  # copy of effect 1's est_pi / est_sd.
  # After an expansion KL is reset to NA, ELBO is cleared, n_expand is
  # incremented and greedy_backfit_update is set to TRUE.
  L_old <- obj$L

  # Room left below the user-specified upper bound (never negative).
  room    <- max(obj$L_max - L_old, 0)
  L_extra <- min(L_extra, room)

  L_new <- min(L_old + L_extra, obj$P)
  if (L_new <= L_old) {
    return(obj)
  }

  obj$L <- L_new
  for (l in (L_old + 1):L_new)
  {
    obj$fitted_wc[[l]]   <- 0 * obj$fitted_wc[[1]]
    obj$fitted_wc2[[l]]  <- 0 * obj$fitted_wc2[[1]] + 1
    obj$alpha[[l]]       <- rep(0, length(obj$alpha[[1]]))
    obj$cs[[l]]          <- list()
    obj$est_pi[[l]]      <- obj$est_pi[[1]]
    obj$est_sd[[l]]      <- obj$est_sd[[1]]
    obj$lBF[[l]]         <- rep(NA, length(obj$lBF[[1]]))
    obj$cred_band[[l]]   <- matrix(0, ncol = ncol(obj$cred_band[[1]]), nrow = 2)
  }

  # Loop-invariant resets, hoisted out of the loop.
  obj$KL   <- rep(NA, obj$L)
  obj$ELBO <- c()   # assigning NULL: obj$ELBO reads back as NULL afterwards

  obj$n_expand <- obj$n_expand + 1
  obj$greedy_backfit_update <- TRUE
  return(obj)
}

#' Access fitted effect
#'
#' Retrieves the estimated effect from a fitted model.
#' @title Access fitted effect l
#' @param obj A fitted object.
#' @param l Effect of interest.
#' @param cred_band Logical, default set to `FALSE`. If `TRUE`, also returns the credible band.
#' @param alpha Numerical, defined as 1 - conf_level   set to obtain  0.99 confidence level by default.
#' @param ... Other arguments.
#'
#' @return The fitted effect.
#'
#' @export
get_fitted_effect <- function(obj, l, cred_band, alpha, ...)
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_fitted_effect")

#' @rdname get_fitted_effect
#' @method get_fitted_effect susiF
#'
#' @export
get_fitted_effect.susiF <- function(obj,
                                    l = 1,
                                    cred_band = FALSE,
                                    alpha = 0.01,
                                    ...) {
  # Return the stored fitted function of effect `l`; with cred_band = TRUE
  # and a stored variance, also return a two-row band (row 1 = upper,
  # row 2 = lower) at normal quantile 1 - alpha / 2.

  # No band requested: plain fitted function.
  if (!cred_band) {
    return(unlist(obj$fitted_func[[l]]))
  }

  # Band requested but no variance stored for this effect.
  if (is.null(obj$fitted_var[[l]])) {
    warning("credible band option not available for post processing = HMM or none")
    return(unlist(obj$fitted_func[[l]]))
  }

  effect     <- obj$fitted_func[[l]]
  half_width <- qnorm(1 - alpha / 2) * sqrt(obj$fitted_var[[l]])

  # Start from a zeroed copy of the stored band so its shape is reused.
  band <- 0 * obj$cred_band[[l]]
  band[1, ] <- effect + half_width
  band[2, ] <- effect - half_width

  list(effect    = effect,
       cred_band = band)
}

#
# @title Access susiF mixture proportion of effect l
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
# @return a vector of  proportion
#
# @export
#
get_pi <- function(obj, l, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_pi")
}

# @rdname get_pi
#
# @method get_pi susiF
#
# @export get_pi.susiF
#

#' @export
#' @keywords internal


get_pi.susiF <- function(obj, l, ...)
{
  # Stored mixture proportions (obj$est_pi) of effect `l`, after checking
  # that `l` lies between 1 and the number of stored entries.
  n_stored <- length(obj$est_pi)

  if (l > n_stored) {
    stop("Error trying to access mixture proportion")
  }
  if (l < 1) {
    stop("Error l should be larger ")
  }

  obj$est_pi[[l]]
}





# @title Access susiF log Bayes factors of effect l
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
# @return a vector of log Bayes Factors
#
# @export
#
get_lBF <- function(obj, l, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_lBF")
}

# @rdname get_lBF
#
# @method get_lBF susiF
#
# @export get_lBF.susiF
#
#' @export
#' @keywords internal

get_lBF.susiF <- function(obj, l,...)
{
  # Stored log Bayes factors (obj$lBF) of effect `l`.
  #
  # obj : susiF object; obj$L is the number of effects.
  # l   : effect index, must satisfy 1 <= l <= obj$L.
  # Stops with an error when `l` is out of range.

  if (l > obj$L)
  {
    # Fix: the original message was copied from get_pi() and mentioned the
    # mixture proportion instead of the log Bayes factors.
    stop("Error trying to access log Bayes factors: l is larger than the number of effects")
  }
  if (l < 1)
  {
    stop("Error l should be larger than or equal to 1")
  }
  out <- obj$lBF[[l]]
  return(out)
}










#' @title Compute expected sum of squares
#
#' @param obj a susiF object defined by init_susiF_obj function
#
#' @param Y wavelet transformed  functional phenotype, matrix of size N by size J.
#
#' @param X matrix of size N by p
#
#
#' @return estimated residual variance
#' @export
get_ER2 <- function(obj, Y, X, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_ER2")
}


#' @rdname get_ER2
#
#' @method get_ER2 susiF
#
#' @export get_ER2.susiF
#
#' @export
#' @keywords internal

get_ER2.susiF = function(obj, Y, X, ...) {
  # Combines three sums built from the posterior moments of the effects:
  #   sum(t(R) %*% R) - sum(t(postF) %*% postF) + sum(postF2)
  # with R = Y - X %*% postF.
  # NOTE(review): sum() runs over every entry of the cross-product matrices,
  # not only their diagonal -- confirm this is the intended quantity.
  post_mean   <- get_post_F(obj)    # summed posterior mean over effects
  post_second <- get_post_F2(obj)   # summed posterior second moment

  resid_ss <- sum(t((Y - X %*% post_mean)) %*% (Y - X %*% post_mean))
  mean_ss  <- sum(t(post_mean) %*% post_mean)

  return(resid_ss - mean_ss + sum(post_second))
}



#
#' @title Access susiF internal prior
#
#' @param obj a susiF object defined by init_susiF_obj function
#
#' @return G_prior object
#
#' @export
#' @keywords internal
#
get_G_prior <- function(obj, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_G_prior")
}



#' @rdname get_G_prior
#
#' @method get_G_prior susiF
#
#' @export get_G_prior.susiF
#
#' @export
#' @keywords internal
get_G_prior.susiF <- function(obj, ...)
{
  # Accessor for the prior stored in the object.
  obj$G_prior
}





# @title Compute posterior mean of the fitted effect
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param l , optional effect to update
#
# @return  A J by T matrix of posterior wavelet coefficient,
# \item{if l  missing}{return sum of the  effect  posterior mean }
# \item{if l not missing}{return effect specific posterior mean}
get_post_F <- function(obj, l, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_post_F")
}

# @rdname get_post_F
#
# @method get_post_F susiF
#
# @export get_post_F.susiF
#
#' @export
#' @keywords internal

get_post_F.susiF <- function(obj, l, ...)
{
  # alpha-weighted fitted wavelet coefficients: for one effect when `l` is
  # supplied, otherwise summed over effects 1..obj$L.
  weighted_wc <- function(k) obj$alpha[[k]] * obj$fitted_wc[[k]]

  if (!missing(l)) {
    return(weighted_wc(l))
  }

  Reduce("+", lapply(1:obj$L, FUN = weighted_wc))
}



# @title Compute posterior second moment
#
# @param obj a susiF object defined by init_susiF_obj function
# @param l , optional effect to update
#
# @return  A J by T matrix of posterior wavelet coefficient,
# \item{if l  missing}{return sum of the effects  posterior second moment }
# \item{if l not missing}{return effect specific posterior second moment}

get_post_F2 <- function(obj, l, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_post_F2")
}

# @rdname get_post_F2
#
# @method get_post_F2 susiF
#
# @export get_post_F2.susiF
#
#' @export
#' @keywords internal

get_post_F2.susiF <- function(obj, l, ...)
{
  # alpha * (fitted_wc2 + fitted_wc^2): for one effect when `l` is supplied,
  # otherwise summed over effects 1..obj$L.
  second_moment <- function(k) {
    obj$alpha[[k]] * (obj$fitted_wc2[[k]] + obj$fitted_wc[[k]]^2)
  }

  if (!missing(l)) {
    return(second_moment(l))
  }

  Reduce("+", lapply(1:obj$L, FUN = second_moment))
}



# @title Update alpha  susiF mixture proportion of effect l
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
#
# @return susiF object
#
# @export
#
#
get_alpha <- function(obj, l, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("get_alpha")
}

# @rdname get_alpha
#
# @method get_alpha susiF
#
# @export get_alpha.susiF
#
#' @export
#' @keywords internal

get_alpha.susiF <- function(obj, l, ...)
{
  # Accessor for the stored alpha vector of effect `l`.
  obj$alpha[[l]]
}

#' @title Update  susiF via greedy search or backfit
#
#' @param obj a susiF object defined by init_susiF_obj function
#' @param X matrix of size n by p contains the covariates
#' @param min_purity minimum purity for estimated credible sets
#' @param verbose If \code{verbose = TRUE}, the algorithm's progress,
# and a summary of the optimization settings, are printed to the
# console.
#' @param cov_lev the desired level of converage
#' @return susiF object
#
#' @export
#' @keywords internal
#
#
greedy_backfit <- function(obj, verbose, cov_lev, X, min_purity, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("greedy_backfit")
}

#' @rdname greedy_backfit
#
#' @method greedy_backfit susiF
#
#' @export greedy_backfit.susiF
#
#' @export
#' @keywords internal
greedy_backfit.susiF <- function(obj,
                                 verbose,
                                 cov_lev,
                                 X,
                                 min_purity, ...)
{
  # One step of the greedy search / backfitting schedule.
  #
  # Depending on the flags obj$greedy and obj$backfit and on the "dummy"
  # credible sets reported by which_dummy_cs(), the object is either pruned
  # (discard_cs), merged (merge_effect on pairs of credible sets whose
  # correlation exceeds 0.99) or expanded (expand_susiF_obj).

  # Merge pairs of credible sets with correlation above 0.99. Extra
  # arguments are forwarded to merge_effect().
  merge_correlated <- function(fit, ...) {
    if (length(fit$cs) > 1) {
      cs_cor <- cal_cor_cs(fit, X)$cs_cor
      pairs  <- which(cs_cor > 0.99, arr.ind = TRUE)
      pairs  <- pairs[-which(pairs[, 1] == pairs[, 2]), ]   # drop the diagonal
      if (!(dim(pairs)[1] == 0)) {
        pairs <- pairs[which(pairs[, 1] < pairs[, 2]), ]   # keep one ordering
        fit   <- merge_effect(fit, pairs, ...)
      }
    }
    fit
  }

  # Progress report printed when the procedure stops.
  report_done <- function(fit) {
    if (verbose) {
      print(paste("Discarding ", (fit$L_max - fit$L), " effects"))
      print("Greedy search and backfitting done")
    }
  }

  obj <- update_alpha_hist(obj)
  if (!(obj$greedy) & !(obj$backfit)) {
    return(obj)
  }

  obj <- update_cal_cs(obj, cov_lev = cov_lev)
  dummy.cs <- which_dummy_cs(obj,
                             min_purity  = min_purity,
                             median_crit = TRUE,
                             X           = X)

  # Backfitting with dummy credible sets: prune them and stop the greedy part.
  if (obj$backfit & (length(dummy.cs) > 0)) {
    obj$greedy <- FALSE
    if (length(dummy.cs) == obj$L) {
      # Every effect is flagged: keep the first flagged one, stop backfitting.
      dummy.cs    <- dummy.cs[-1]
      obj$backfit <- FALSE
    }
    if (length(dummy.cs) > 0) {
      L_before <- obj$L
      obj <- discard_cs(obj,
                        cs       = dummy.cs,
                        out_prep = FALSE)
      obj <- merge_correlated(obj)
      if (verbose) {
        print(paste("Discarding ", (L_before - obj$L), " effects"))
      }
    } else {
      obj$backfit <- FALSE
    }
    return(obj)
  }

  # More effects than allowed: trim the surplus and stop the greedy search.
  if ((obj$L > obj$L_max)) {
    obj$greedy <- FALSE
    obj <- discard_cs(obj,
                      cs       = (obj$L_max + 1):obj$L,
                      out_prep = FALSE)
    obj <- merge_correlated(obj)
    report_done(obj)
  }

  if (length(dummy.cs) == 0 & !(obj$greedy)) {
    obj$backfit <- FALSE
  }

  # Both phases finished: final merge and clean-up of the alpha history.
  if (!(obj$greedy) & !(obj$backfit)) {
    obj <- merge_correlated(obj)
    report_done(obj)
    obj <- update_alpha_hist(obj, discard = TRUE)
    obj$greedy_backfit_update <- FALSE
    return(obj)
  }

  # Greedy phase without dummy credible sets: add effects, at most 7 per
  # step and never beyond L_max.
  if (obj$greedy & (length(dummy.cs) == 0)) {
    room    <- obj$L_max - obj$L
    n_extra <- min(ifelse(room > 0, room, 0), 7)

    if (n_extra == 0) {
      obj <- merge_correlated(obj)
      report_done(obj)
      obj <- update_alpha_hist(obj, discard = TRUE)
      obj$greedy_backfit_update <- FALSE
      obj$backfit <- FALSE
      obj$greedy  <- FALSE
      return(obj)
    }

    if (verbose) {
      print(paste("Adding ", n_extra, " extra effects"))
    }

    # Merge without discarding before the expansion.
    obj <- merge_correlated(obj, discard = FALSE)
    obj <- expand_susiF_obj(obj, L_extra = n_extra)
    return(obj)
  }

}




#' @title Initialize a susiF object using regression coefficients
#' @details  Initialize a susiF object using regression coefficients
#
#' @param L_max upper bound on the number of non zero coefficients An L-vector containing the indices of the
#   nonzero coefficients.
#
#' @param G_prior prior object defined by init_prior function
#
#' @param Y Matrix of outcomes
#
#' @param X matrix of covariatess
#
#' @param L_start number of effect to start with
#
#' @param greedy logical, if TRUE allow greedy search
#
#' @param backfit logical, if TRUE allow backfitting
#' @param tol_null_prior threshold to consider prior to be null. If the estimated weight on the point mass at zero is larger than 1-tol_null_prior
#' then set prior weight on point mass to be 1. In the mixture normal this corresponds to removing the effect. In the mixutre per scale prior this corresponds
#' to setting the prior of a given scale to at point mass at 0.
#' @param cov_lev numeric between 0 and 1, corresponding to the
#' expected level of coverage of the CS if not specified, set to 0.95
#' @param \dots Other arguments.
#
# @export
# @return A list with the following elements
# \item{fitted_wc}{ list of length L, each element contains the fitted wavelet coefficients of effect l}
# \item{fitted_wc2}{list of length L, each element contains the variance of the fitted wavelet coefficients of effect l}
# \item{alpha_hist}{ history of the fitted alpha value}
# \item{N}{ number of indidivual in the study}
# \item{sigma2}{residual variance}
# \item{n_wac}{number of wavelet coefficients}
# \item{ind_fitted_func}{fitted curves of each individual }
# \item{cs}{credible set}
# \item{pip}{Posterior inclusion probabilites}
# \item{G_prior}{a G_prior of the same class as the input G_prior, used for internal calculation}
# \item{lBF}{ log Bayes factor for the different effect}
# \item{KL}{ the KL divergence for the different effect}
# \item{ELBO}{ The evidence lower bound}
# \item{lfsr_wc}{Local fasle sign rate of the fitted wavelet coefficients}
#' @export
#
init_susiF_obj <- function(L_max,
                           G_prior,
                           Y,
                           X,
                           L_start,
                           greedy,
                           backfit,
                           tol_null_prior = 0.001,
                           cov_lev = 0.95,
                           ...)
{
  # Build an empty susiF object with `L_start` effects.
  #
  # Each effect starts with zero wavelet coefficients, unit fitted_wc2, a
  # uniform alpha over the columns of X, an empty credible set, the prior's
  # pi / sd (get_pi_G_prior, get_sd_G_prior), NA log Bayes factors and a
  # zero 2-row credible band. sigma2 is the mean column variance of Y.
  n_ind  <- dim(Y)[1]
  n_coef <- dim(Y)[2]
  n_cov  <- dim(X)[2]

  # Per-effect containers, filled below.
  fitted_wc  <- list()
  fitted_wc2 <- list()
  alpha      <- list()
  cs         <- list()
  cred_band  <- list()
  est_pi     <- list()
  est_sd     <- list()
  lBF        <- list()

  for (l in 1:L_start)
  {
    fitted_wc[[l]]  <- matrix(0, nrow = n_cov, ncol = n_coef)
    fitted_wc2[[l]] <- matrix(1, nrow = n_cov, ncol = n_coef)
    alpha[[l]]      <- rep(1 / n_cov, n_cov)
    cs[[l]]         <- list()
    est_pi[[l]]     <- get_pi_G_prior(G_prior)
    est_sd[[l]]     <- get_sd_G_prior(G_prior)
    lBF[[l]]        <- rep(NA, ncol(X))
    cred_band[[l]]  <- matrix(0, ncol = n_coef, nrow = 2)
  }

  structure(
    list(fitted_wc             = fitted_wc,
         fitted_wc2            = fitted_wc2,
         lBF                   = lBF,
         KL                    = rep(NA, L_start),
         cred_band             = cred_band,
         ELBO                  = c(),
         ind_fitted_func       = matrix(0, nrow = n_ind, ncol = n_coef),
         G_prior               = G_prior,
         alpha_hist            = list(),
         N                     = n_ind,
         n_wac                 = n_coef,
         sigma2                = mean(apply(Y, 2, var)),
         P                     = n_cov,
         alpha                 = alpha,
         cs                    = cs,
         pip                   = rep(0, n_cov),
         est_pi                = est_pi,
         est_sd                = est_sd,
         L                     = L_start,
         L_max                 = L_max,
         csd_X                 = attr(X, "scaled:scale"),
         n_expand              = 0,       # number of greedy expansions so far
         greedy                = greedy,
         backfit               = backfit,
         greedy_backfit_update = FALSE,
         d                     = attr(X, "d"),
         lfsr_wc               = list(),
         tol_null_prior        = tol_null_prior,
         cov_lev               = cov_lev),
    class = "susiF"
  )
}



#' @title Merging effect function
#
#' @param obj a susiF object defined by init_susiF_obj function
#
#' @param tl see  \code{\link{greedy_backfit}}
#
#' @param discard logical, if set to TRUE allow discarding redundant effect
#
#
#
#' @return  a susiF object
#' @export
#' @keywords internal
merge_effect <- function(obj, tl, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("merge_effect")
}

#' @rdname merge_effect
#
#' @method merge_effect susiF
#
#' @export merge_effect.susiF
#
#' @export
#' @keywords internal

merge_effect.susiF <- function( obj, tl, discard=TRUE,  ...){
  # Merge pairs of effects and optionally discard the absorbed ones.
  #
  # tl      : either a vector c(target, donor) or a two-column matrix with one
  #           (target, donor) pair per row.
  # discard : if TRUE, the donor effects are removed with discard_cs().
  #
  # For each pair the donor's fitted_wc is set to zero, then the donor's
  # fitted_wc and fitted_wc2 are added to the target's.
  # NOTE(review): because the donor's fitted_wc is zeroed BEFORE the addition,
  # the target's fitted_wc is left unchanged and only fitted_wc2 accumulates.
  # The commented-out lines below zero the donor AFTER the addition instead --
  # confirm which order is intended.




  if(is.vector( tl)){
    # Single pair: tl[1] is the target effect, tl[2] the donor.
    #print( tl)
    obj$fitted_wc[[tl[ 2]]] <- 0* obj$fitted_wc[[tl[  2]]]
    obj$fitted_wc[[tl[  1]]] <- obj$fitted_wc[[tl[  1]]] +   obj$fitted_wc[[tl[ 2]]]
    obj$fitted_wc2[[tl[ 1]]] <- obj$fitted_wc2[[tl[  1]]] +   obj$fitted_wc2[[tl[  2]]]
    #obj$fitted_wc[[tl[  2]]] <- 0* obj$fitted_wc[[tl[ 2]]]
    tindx <-  tl[  2]
  }else{
    # Several pairs: process them in decreasing order of (target, donor).
    tl <- tl[order(tl[,1], tl[,2], decreasing = TRUE),]
    #print( tl)
    # tindx collects the donors already merged; the leading 0 is a
    # placeholder removed after the loop.
    tindx <- c(0)
    for ( o in 1:dim(tl)[1]){

      # Each donor is merged at most once.
      if ( tl[o, 2]%!in%tindx){
        obj$fitted_wc[[tl[o, 2]]] <- 0* obj$fitted_wc[[tl[o, 2]]]
        obj$fitted_wc[[tl[o, 1]]] <-obj$fitted_wc[[tl[o, 1]]] +   obj$fitted_wc[[tl[o, 2]]]
        obj$fitted_wc2[[tl[o, 1]]] <-obj$fitted_wc2[[tl[o, 1]]] +   obj$fitted_wc2[[tl[o, 2]]]
       # obj$fitted_wc[[tl[o, 2]]] <- 0* obj$fitted_wc[[tl[o, 2]]]
        tindx <- c(tindx, tl[o, 2])
      }

    }

    tindx <- tindx[-1]

  }
  # Drop the absorbed effects (not in output-preparation mode).
  if(discard){
   obj<-  discard_cs(obj,cs=tindx, out_prep=FALSE)
  }

  return( obj)
}



# @title Updates CS names for output
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param X matrix of size N by p

name_cs <- function(obj, X, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("name_cs")
}

# @rdname name_cs
#
# @method name_cs susiF
#
# @export name_cs.susiF
#
#' @export
#' @keywords internal

name_cs.susiF <- function(obj,X,...){
  # Label the members of each credible set with the column names of X.
  #
  # obj : object whose obj$cs is a list of integer vectors indexing columns
  #       of X.
  # X   : covariate matrix. Names are applied only when X has one column
  #       name per column; otherwise obj is returned unchanged.
  # Returns obj with names(obj$cs[[l]]) set for every credible set.

  col_names <- colnames(X)
  if (length(col_names) == ncol(X)) {

    # seq_along() is safe when obj$cs is empty (1:length() would yield
    # c(1, 0) and fail with a subscript error).
    for (l in seq_along(obj$cs)) {
      names(obj$cs[[l]]) <- col_names[obj$cs[[l]]]
    }

  }
  return(obj)
}


#' @title Preparing output of main susiF function
#
#' @param  obj a susiF object defined by init_susiF_obj function
#
#' @param Y functional phenotype, matrix of size N by size J. The underlying algorithm uses wavelets, which assume that J is a power of 2 (J = 2^K for some integer K). If J is not a power of 2, susiF internally remaps the data onto a grid whose length is a power of 2
#
#' @param X matrix of size N by p
#
#' @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#
#' @param filter_cs logical, if TRUE filter the credible set (removing low purity cs and cs with estimated prior equal to 0)
#
#' @param outing_grid grid use to fit fsusie
#' 
#' @param pos the original position of the Y column
#' @return susiF object
#
#' @export
#' @keywords internal
out_prep <- function(obj, Y, X, indx_lst, outing_grid, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("out_prep")
}

#' @rdname out_prep
#
#' @method out_prep susiF
#
#' @export out_prep.susiF
#
#' @export
#' @keywords internal

out_prep.susiF <- function(obj ,
                           Y,
                           X,
                           indx_lst,
                           outing_grid,
                           filter_cs,
                           
                           filter.number = 10,
                           family =  "DaubLeAsymm",
                           post_processing="TI",
                           tidx =NULL ,
                           names_colX =NULL,
                           pos,
                           ...)
{
  # Final formatting of a fitted susiF object: PIPs, named credible sets,
  # optional CS filtering, fitted curves, individual fitted curves (skipped
  # for the "HMM" post-processing), relabelling, and the two grids.

  # PIPs first, then attach covariate names to the credible sets.
  obj <- name_cs(update_cal_pip(obj), X)

  # Optional filtering of the credible sets (purity threshold fixed at 0.5).
  if (filter_cs) {
    obj <- check_cs(obj, min_purity = 0.5, X = X)
  }

  # Fitted effect curves, computed according to the post-processing method.
  obj <- update_cal_fit_func(obj,
                             Y               = Y,
                             X               = X,
                             indx_lst        = indx_lst,
                             post_processing = post_processing,
                             filter.number   = filter.number,
                             family          = family)

  # Individual fitted curves are not computed for "HMM".
  if (post_processing != "HMM") {
    use_TI <- post_processing %in% c('TI', 'smash')
    obj <- update_cal_indf(obj      = obj,
                           Y        = Y,
                           X        = X,
                           indx_lst = indx_lst,
                           TI       = use_TI)
  }

  # Relabel alpha / PIP / fitted functions with the covariate names.
  obj <- rename_format_output(obj        = obj,
                              names_colX = names_colX,
                              tidx       = tidx)

  # Store the grid used for fitting and the original positions.
  obj$outing_grid   <- outing_grid
  obj$original_grid <- pos
  obj
}





rename_format_output <- function(obj, names_colX, tidx, ...){
  # Attach covariate names to obj$alpha and obj$pip, name the fitted
  # functions, and recompute the credible sets.
  #
  # obj        : susiF object.
  # names_colX : covariate names; when NULL the object is returned untouched.
  # tidx       : indices (within names_colX) with no alpha/PIP entry. When
  #              non-empty, alpha and PIP are expanded to
  #              length(names_colX), with zeros at these positions.
  #
  # Returns the updated susiF object.

  if (is.null(names_colX)) {
    return(obj)
  }

  expand <- length(tidx) > 0
  n_cov  <- length(names_colX)

  # seq_along() so an object without credible sets is handled gracefully.
  for (l in seq_along(obj$cs)) {
    if (expand) {
      full_alpha        <- rep(0, n_cov)
      full_alpha[-tidx] <- obj$alpha[[l]]
      obj$alpha[[l]]    <- full_alpha
    }
    names(obj$alpha[[l]])     <- names_colX
    names(obj$fitted_func)[l] <- paste0("fitted_function_effect_", l)
  }

  if (expand) {
    full_pip        <- rep(0, n_cov)
    full_pip[-tidx] <- obj$pip
    obj$pip         <- full_pip
  }
  names(obj$pip) <- names_colX

  # Credible sets are recomputed so they refer to the relabelled alpha.
  obj <- update_cal_cs(obj, cov_lev = obj$cov_lev)
  return(obj)
}






#' @title Check tolerance for stopping criterion
#
#' @param obj a susiF object defined by init_susiF_obj function
#' @param check numeric, running value of the convergence criterion tested to decide when to exit the while loop
#' @param cal_obj logical, if set to TRUE compute ELBO
#
#' @param X matrix of covariates
#
#' @param D matrix of wavelet D coefficients from the original input data (Y)
#
#' @param C vector of wavelet scaling coefficient from the original input data (Y)
#
#' @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#
#' @return a susiF object whose \code{check} element holds the updated convergence criterion
#' @export
#' @keywords internal
test_stop_cond <- function(obj,
                           check,
                           cal_obj,
                           Y,
                           X,
                           D,
                           C,
                           indx_lst
                           ,...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("test_stop_cond")
}

#' @rdname test_stop_cond
#
#' @method test_stop_cond susiF
#
#' @export test_stop_cond.susiF
#
#' @export
#' @keywords internal
test_stop_cond.susiF <- function(obj, check, cal_obj, Y, X, D, C, indx_lst,...)
{
  # Update obj$check, the quantity tested by the outer fitting loop.
  #  - L == 1                      : check is set to 0.
  #  - just after a greedy/backfit : the incoming `check` is stored as is.
  #  - cal_obj = TRUE              : |last ELBO increment| (needs >= 2 ELBOs).
  #  - cal_obj = FALSE             : sum of absolute changes in alpha between
  #                                  the last two history entries, / nrow(X).

  # Single-effect model.
  if (obj$L == 1) {
    obj$check <- 0
    return(obj)
  }

  # The set of effects was just modified: keep the incoming value.
  if (obj$greedy_backfit_update) {
    obj$check <- check
    return(obj)
  }

  # ---- ELBO-based criterion ------------------------------------------------
  if (cal_obj) {
    obj <- update_KL(obj,
                     X,
                     D = D,
                     C = C, indx_lst)

    new_elbo <- get_objective(obj      = obj,
                              Y        = Y,
                              X        = X,
                              D        = D,
                              C        = C,
                              indx_lst = indx_lst)
    obj <- update_ELBO(obj, new_elbo)

    n_elbo <- length(obj$ELBO)
    if (n_elbo > 1) {
      # absolute value of the most recent ELBO increment
      check <- abs(diff(obj$ELBO)[(n_elbo - 1)])
    }
    obj$check <- check
    return(obj)
  }

  # ---- alpha-based criterion (no ELBO) -------------------------------------
  n_hist <- length(obj$alpha_hist)
  if (!(n_hist > 1)) {
    obj$check <- check
    return(obj)
  }

  alpha_now  <- do.call(rbind, obj$alpha_hist[[n_hist]])
  # may hold more than L rows: alpha is recorded before effects are discarded
  alpha_now  <- alpha_now[1:obj$L, ]
  alpha_prev <- do.call(rbind, obj$alpha_hist[[(n_hist - 1)]])

  # Number of effects changed between the two iterations: leave obj untouched
  # (obj$check keeps whatever value it already had).
  if (!(obj$L == nrow(alpha_prev))) {
    return(obj)
  }
  alpha_prev <- alpha_prev[1:obj$L, ]

  # NOTE(review): L == 1 already returned at the top, and both matrices were
  # just restricted to L rows, so the three checks below look unreachable.
  # They are kept to preserve the original control flow exactly.
  if (obj$L == 1) {
    alpha_prev <- alpha_prev[1, ]
  }
  if (nrow(alpha_now) > nrow(alpha_prev)) {
    obj$check <- 1
    return(obj)
  }
  if (nrow(alpha_prev) > nrow(alpha_now)) {
    alpha_prev <- alpha_prev[1:obj$L, ]
  }

  obj$check <- sum(abs(alpha_now - alpha_prev)) / nrow(X)
  return(obj)
}





#' @title Update alpha   susiF mixture proportion of effect l
#
#' @param obj a susiF object defined by init_susiF_obj function
#
#' @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
#' @param alpha  vector of p alpha values summing up to one
#
#' @return susiF object
#
#' @export
#' @keywords internal

update_alpha <- function(obj, l, alpha, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_alpha")
}


#' @rdname update_alpha
#
#' @method update_alpha susiF
#
#' @export update_alpha.susiF
#
#' @export
#' @keywords internal
update_alpha.susiF <-  function(obj, l, alpha,... )
{
  # Store `alpha` as the alpha entry of effect `l` and return the object.
  alpha_by_effect      <- obj$alpha
  alpha_by_effect[[l]] <- alpha
  obj$alpha            <- alpha_by_effect
  obj
}

#' @title Update alpha_hist   susiF object
#
#' @param obj a susiF object defined by init_susiF_obj function
#
#' @param  discard logical set to FALSE by default, if true remove element of history longer than L
#
#' @return susiF object
#
#' @export
#' @keywords internal

update_alpha_hist <- function(obj, discard, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_alpha_hist")
}


#' @rdname update_alpha_hist
#
#' @method update_alpha_hist susiF
#
#' @export update_alpha_hist.susiF
#
#' @export
#' @keywords internal
update_alpha_hist.susiF <-  function(obj , discard=FALSE,... )
{
  # Maintain the history of alpha values stored in obj$alpha_hist.
  #
  # discard = FALSE : append the current obj$alpha as a new history entry.
  # discard = TRUE  : truncate the latest history entry to its first obj$L
  #                   elements when it holds more than obj$L of them.
  #
  # Returns the updated susiF object.

  n_hist <- length(obj$alpha_hist)

  if (!discard) {
    obj$alpha_hist[[n_hist + 1]] <- obj$alpha
    return(obj)
  }

  # Nothing recorded yet, so there is nothing to truncate.
  if (n_hist == 0) {
    return(obj)
  }

  if (length(obj$alpha_hist[[n_hist]]) > obj$L) {
    obj$alpha_hist[[n_hist]] <- obj$alpha_hist[[n_hist]][seq_len(obj$L)]
  }

  return(obj)
}


# @title Update  susiF object using the output of EM_pi
#
# @description Update  susiF object using the output of EM_pi
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
# @param EM_pi an object of the class "EM_pi" generated by the function \code{\link{EM_pi}}
#
# @param Bhat matrix of estimated regression coefficients
#
# @param Shat  matrix of estimated standard errors
#
# @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#
# @param lowc_wc list of wavelet coefficients that exhibit too little variance
#
# @param cal_wc_lsfr logical, set to TRUE if detailed output needed
#
# @return susiF object
#
# @export


update_susiF_obj  <- function(obj, l,
                              EM_pi, 
                              Bhat,
                              Shat,
                              indx_lst,
                              lowc_wc=NULL, 
                              cal_wc_lsfr=FALSE,
                              df=NULL,
                              cov_lev=0.95,
                              e=0.001,
                              ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_susiF_obj")
}

# @rdname update_susiF_obj
#
# @method update_susiF_obj susiF
#
# @export update_susiF_obj.susiF
#
#' @export
#' @keywords internal

update_susiF_obj.susiF <- function(obj,
                                   l,
                                   EM_pi,
                                   Bhat,
                                   Shat,
                                   indx_lst,
                                   lowc_wc=NULL,
                                   cal_wc_lsfr=FALSE,
                                   df=NULL,
                                   cov_lev=0.95,
                                   e=0.001,
                                   ...)
{
  # Refresh effect `l` of a susiF object from the output of EM_pi: mixture
  # proportions, prior, posterior first/second moments of the wavelet
  # coefficients, alpha, log Bayes factors and the credible set.

  # ---- argument checks ----------------------------------------------------
  if (l > length(obj$est_pi)) {
    stop("Error trying to access mixture proportion")
  }
  if (l < 1) {
    stop("Error l should be larger ")
  }
  if ("EM_pi" %!in% class(EM_pi)) {
    stop("Error EM_pi should be of the class EM_pi")
  }

  # ---- mixture proportions and prior --------------------------------------
  obj <- update_pi(obj = obj,
                   l   = l,
                   tpi = EM_pi$tpi_k)
  obj$G_prior <- update_prior(get_G_prior(obj), EM_pi$tpi_k)

  # ---- posterior moments of the wavelet coefficients ----------------------
  obj$fitted_wc[[l]] <- post_mat_mean(get_G_prior(obj),
                                      Bhat,
                                      Shat,
                                      lBF      = EM_pi$lBF,
                                      indx_lst = indx_lst,
                                      lowc_wc  = lowc_wc,
                                      e        = e)

  post_sd <- post_mat_sd(get_G_prior(obj),
                         Bhat,
                         Shat,
                         lBF      = EM_pi$lBF,
                         indx_lst = indx_lst,
                         lowc_wc  = lowc_wc,
                         e        = e)
  obj$fitted_wc2[[l]] <- post_sd^2

  # NOTE(review): the result of this call is never used afterwards; the call
  # is kept so the sequence of operations is unchanged.
  G_prior <- update_prior(G_prior = get_G_prior(obj),
                          tpi     = EM_pi$tpi_k)

  # ---- alpha, log Bayes factors, credible set -----------------------------
  zeta <- cal_zeta(EM_pi$lBF)
  obj  <- update_alpha(obj   = obj,
                       l     = l,
                       alpha = zeta)
  obj  <- update_lBF(obj = obj,
                     l   = l,
                     lBF = EM_pi$lBF)
  obj  <- update_cal_cs(obj = obj, cov_lev = cov_lev, l = l)
  obj
}

#' @title Update susiF by computing PiP
#
#' @param obj a susiF object defined by  init_susiF_obj  function
#' @return susiF object
#' @export
#' @keywords internal

update_cal_pip <- function(obj, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_cal_pip")
}

#' @rdname update_cal_pip
#
#' @method update_cal_pip susiF
#
#' @export update_cal_pip.susiF
#
#' @export
#' @keywords internal

update_cal_pip.susiF <- function (obj,...)
{
  # Posterior inclusion probability of each covariate:
  # pip_j = 1 - prod_l (1 - alpha_l[j]), stored in obj$pip.

  if (any(is.na(unlist(obj$alpha)))) {
    stop("Error: some alpha value not updated, please update alpha value first")
  }

  # One row per effect, holding 1 - alpha for that effect.
  not_selected <- lapply(1:obj$L, function(l) 1 - obj$alpha[[l]])
  not_selected <- do.call(rbind, not_selected)

  obj$pip <- 1 - apply(not_selected, 2, prod)
  obj
}

#' @title Update susiF by computing credible sets
#
#' @param obj a susiF object defined by  init_susiF_obj  function
#
#' @param cov_lev numeric between 0 and 1, corresponding to the expected level of coverage of the cs if not specified set to 0.95
#
#' @return susiF object
#
#' @export
#' @keywords internal
#'
update_cal_cs <- function(obj, cov_lev=0.95, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_cal_cs")
}

#' @rdname update_cal_cs
#
#' @method update_cal_cs susiF
#
#' @export update_cal_cs.susiF
#
#' @export
#'
#' @keywords internal
#'
update_cal_cs.susiF <- function(obj, cov_lev=0.95, l,...)
{
  # Compute credible sets from the alpha vectors of a susiF object.
  #
  # obj     : susiF object.
  # cov_lev : single numeric, requested coverage of the credible set.
  # l       : optional effect index. When supplied only obj$cs[[l]] is
  #           recomputed (entries are left unnamed); otherwise all obj$L
  #           credible sets are recomputed and named after names(alpha).
  #
  # A credible set holds the covariates ranked by decreasing alpha, up to and
  # including the first one at which the cumulative alpha exceeds cov_lev.
  #
  # Returns the updated susiF object.

  if (!is.numeric(cov_lev) || length(cov_lev) != 1 || is.na(cov_lev)) {
    stop("Error: cov_lev should be a single numeric value")
  }

  # Indices of the smallest set whose cumulative alpha exceeds cov_lev.
  # If the cumulative sum never exceeds cov_lev (e.g. cov_lev = 1), every
  # covariate is included instead of failing on an empty `which()`.
  smallest_cs <- function(alpha) {
    ranked  <- order(alpha, decreasing = TRUE)
    reached <- which(cumsum(alpha[ranked]) > cov_lev)
    n_keep  <- if (length(reached) > 0) reached[1] else length(alpha)
    ranked[seq_len(n_keep)]
  }

  if (!missing(l)) {
    if (anyNA(unlist(obj$alpha[[l]]))) {
      stop("Error: some alpha value not updated, please update alpha value first")
    }
    obj$cs[[l]] <- smallest_cs(obj$alpha[[l]])
    return(obj)
  }

  if (anyNA(unlist(obj$alpha))) {
    stop("Error: some alpha value not updated, please update alpha value first")
  }

  for (k in seq_len(obj$L)) {
    alpha_k     <- obj$alpha[[k]]
    cs_k        <- smallest_cs(alpha_k)
    names(cs_k) <- names(alpha_k)[cs_k]
    obj$cs[[k]] <- cs_k
  }

  return(obj)
}

#@title Update susiF by computing predicted curves
#
#@param obj a susiF object defined by init_susiF_obj function
#@param Y functional phenotype, matrix of size N by size J. The underlying algorithm uses wavelets, which assume that J is a power of 2 (J = 2^K for some integer K). If J is not a power of 2, susiF internally remaps the data onto a grid whose length is a power of 2
#@param X matrix of size N by p
#@param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#@return susiF object
#@export
#

update_cal_indf <- function(obj, Y, X, indx_lst, TI=FALSE, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_cal_indf")
}

# @rdname update_cal_indf
#
# @method update_cal_indf susiF
#
# @export update_cal_indf.susiF
#
#
# @export
#
#' @importFrom wavethresh wr
#' @importFrom wavethresh wd
update_cal_indf.susiF <- function(obj, Y, X, indx_lst, TI=FALSE,...)
{
  # Fill obj$ind_fitted_func with one fitted curve per row of Y.
  #
  # TI = TRUE  : column means of Y (attribute "scaled:center") plus, for each
  #              effect, the lead covariate (largest alpha) times the fitted
  #              function of that effect, rescaled by the "scaled:scale"
  #              attribute of X for that covariate.
  # TI = FALSE : column means of Y plus, for each effect, the inverse wavelet
  #              transform of the alpha-weighted wavelet coefficients.

  center_Y <- attr(Y, "scaled:center")

  if (TI) {
    baseline <- matrix(center_Y,
                       byrow = TRUE,
                       nrow  = nrow(Y),
                       ncol  = ncol(Y))

    effect_curves <- lapply(1:length(obj$alpha), function(l) {
      lead <- which.max(obj$alpha[[l]])
      matrix(X[, lead], ncol = 1) %*% t(obj$fitted_func[[l]]) * (attr(X, "scaled:scale")[lead])
    })

    obj$ind_fitted_func <- baseline + Reduce("+", effect_curves)
    return(obj)
  }

  if (sum(is.na(unlist(obj$alpha)))) {
    stop("Error: some alpha value not updated, please update alpha value first")
  }

  # Dummy wd object: only its D and last C entries are overwritten below.
  wd_obj <- wavethresh::wd(rep(0, obj$n_wac))

  if (inherits(get_G_prior(obj), "mixture_normal_per_scale")) {
    # column of fitted_wc taken from the last element of indx_lst
    c_col <- indx_lst[[length(indx_lst)]]
    for (i in 1:obj$N) {
      obj$ind_fitted_func[i, ] <- center_Y
      for (l in 1:obj$L) {
        weights <- obj$alpha[[l]] * X[i, ]
        wd_obj$D                   <- weights %*% obj$fitted_wc[[l]][, -c_col]
        wd_obj$C[length(wd_obj$C)] <- weights %*% obj$fitted_wc[[l]][, c_col]
        # back to the original space
        obj$ind_fitted_func[i, ] <- obj$ind_fitted_func[i, ] + wavethresh::wr(wd_obj)
      }
    }
  }

  if (inherits(get_G_prior(obj), "mixture_normal")) {
    for (i in 1:obj$N) {
      obj$ind_fitted_func[i, ] <- center_Y
      for (l in 1:obj$L) {
        weights  <- obj$alpha[[l]] * X[i, ]
        # here the last column of fitted_wc is used for the C entry
        last_col <- ncol(obj$fitted_wc[[l]])
        wd_obj$D                   <- weights %*% obj$fitted_wc[[l]][, -last_col]
        wd_obj$C[length(wd_obj$C)] <- weights %*% obj$fitted_wc[[l]][, last_col]
        # back to the original space
        obj$ind_fitted_func[i, ] <- obj$ind_fitted_func[i, ] + wavethresh::wr(wd_obj)
      }
    }
  }

  obj
}



#' @title Update susiF by computing posterior curves
#
#' @param  obj a susiF/EBmvFR object defined by init_susiF_obj function
#' @param Y  functional phenotype, matrix of size N by size J. The underlying algorithm uses wavelets, which assume that J is a power of 2 (J = 2^K for some integer K). If J is not a power of 2, susiF internally remaps the data onto a grid whose length is a power of 2
#
#' @param X matrix of size n by p in
#
#' @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#' 
#'@param filter.number see documentation of wd from wavethresh package
#'@param family see documentation of wd from wavethresh package
#' @return susiF object
#
#' @export
#' @keywords internal
update_cal_fit_func <- function(obj, indx_lst, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_cal_fit_func")
}

#' @rdname update_cal_fit_func
#
#' @method update_cal_fit_func susiF
#
#' @export update_cal_fit_func.susiF
#
#' @importFrom wavethresh wr
#
#' @importFrom wavethresh wd
#
#' @export
#' @keywords internal

update_cal_fit_func.susiF <- function(obj,
                                      indx_lst,
                                      Y,
                                      X,
                                      post_processing="TI",
                                      filter.number = 10,
                                      family = "DaubLeAsymm" ,...)
{
  # Compute the fitted effect curves (obj$fitted_func) of a susiF object.
  #
  # obj             : susiF object.
  # indx_lst        : wavelet index list; its last element locates the column
  #                   of fitted_wc used for the C entry ("none", per-scale).
  # Y, X            : response and covariate matrices.
  # post_processing : one of "TI", "HMM", "smash", "none".
  # filter.number, family : wavelet settings, forwarded to smash_regression.
  #
  # Returns the updated susiF object.

  if (sum(is.na(unlist(obj$alpha)))) {
    stop("Error: some alpha value not updated, please update alpha value first")
  }

  # A single effect that is flagged as dummy gets a flat zero curve.
  dummy_cs <- which_dummy_cs(obj, min_purity = 0.5, X)
  if (obj$L == 1 && 1 %in% dummy_cs) {
    obj$fitted_func[[1]] <- rep(0, ncol(Y))
    return(obj)
  }

  if (post_processing == "TI") {
    # NOTE(review): this branch always uses filter.number = 1 and
    # family = "DaubExPhase"; the function's own filter.number / family
    # arguments are not forwarded here -- confirm this is intended.
    obj <- TI_regression(obj = obj,
                         Y   = Y,
                         X   = X,
                         filter.number = 1,
                         family = "DaubExPhase")
  }

  if (post_processing == "HMM") {
    obj <- HMM_regression(obj = obj,
                          Y   = Y,
                          X   = X)
    # the credible band is cleared after the HMM fit
    obj$cred_band <- NULL
  }

  if (post_processing == "smash") {
    obj <- smash_regression(obj = obj,
                            Y   = Y,
                            X   = X,
                            filter.number = filter.number,
                            family = family)
  }

  if (post_processing == "none") {
    # Dummy wd object: only its D and last C entries are overwritten below.
    temp <- wavethresh::wd(rep(0, obj$n_wac))

    if (inherits(get_G_prior(obj), "mixture_normal_per_scale")) {
      c_col <- indx_lst[[length(indx_lst)]]
      for (l in 1:obj$L) {
        # alpha-weighted coefficients, each row divided by obj$csd_X
        temp$D <- (obj$alpha[[l]]) %*% sweep(obj$fitted_wc[[l]][, -c_col],
                                             1,
                                             1 / (obj$csd_X), "*")
        temp$C[length(temp$C)] <- (obj$alpha[[l]]) %*%
          (obj$fitted_wc[[l]][, c_col] * (1 / (obj$csd_X)))
        obj$fitted_func[[l]] <- wavethresh::wr(temp)
      }
    }

    if (inherits(get_G_prior(obj), "mixture_normal")) {
      for (l in 1:obj$L) {
        last_col <- dim(obj$fitted_wc[[l]])[2]
        temp$D <- (obj$alpha[[l]]) %*% sweep(obj$fitted_wc[[l]][, -last_col],
                                             1,
                                             1 / (obj$csd_X), "*")
        temp$C[length(temp$C)] <- (obj$alpha[[l]]) %*%
          (obj$fitted_wc[[l]][, last_col] * (1 / (obj$csd_X)))
        # namespace-qualified, consistent with the per-scale branch above
        obj$fitted_func[[l]] <- wavethresh::wr(temp)
      }
    }
  }

  return(obj)
}


# @title Update susiF by computing credible band for posterior curves
#
# @param obj a susiF object defined by init_susiF_obj function
#
# @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#
# @return susiF object
#
# @export
update_cal_credible_band <- function(obj, indx_lst, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_cal_credible_band")
}

# @rdname  update_cal_credible_band
#
# @method  update_cal_credible_band susiF
#
# @export  update_cal_credible_band.susiF
#
# @importFrom wavethresh wr
#
# @importFrom wavethresh wd
# @importFrom wavethresh GenW
#
#' @export
#' @keywords internal

update_cal_credible_band.susiF <- function(obj, indx_lst,...)
{
  # For every effect, store in obj$cred_band[[l]] a two-row matrix
  # (rows "up" and "low") equal to the fitted function +/- 3 * sqrt(v), where
  # v is the diagonal of W diag(alpha %*% fitted_wc2) W' and W is the matrix
  # returned by wavethresh::GenW.

  if (sum(is.na(unlist(obj$alpha)))) {
    stop("Error: some alpha value not updated, please update alpha value first")
  }

  # NOTE(review): this object is not used below; the call is kept so the
  # function behaves exactly as before (including any error it may raise).
  temp <- wavethresh::wd(rep(0, obj$n_wac))

  for (l in 1:obj$L) {
    wc_var <- obj$fitted_wc2[[l]]
    W      <- wavethresh::GenW(n = ncol(wc_var), filter.number = 10, family = "DaubLeAsymm")
    v      <- diag(W %*% diag(c(obj$alpha[[l]] %*% wc_var)) %*% t(W))

    half_width <- 3 * sqrt(v)
    obj$cred_band[[l]] <- rbind(up  = obj$fitted_func[[l]] + half_width,
                                low = obj$fitted_func[[l]] - half_width)
  }

  obj
}




#' @title Update susiF lfsr effect
#'
#' @param obj a susiF object defined by init_susiF_obj function
#'
#' @return susiF object
#'
#' @export
#'
#' @keywords internal
#'
update_lfsr_effect <- function(obj, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_lfsr_effect")
}

#' @rdname update_lfsr_effect
#
#' @method update_lfsr_effect susiF
#
#' @export update_lfsr_effect.susiF
#
#' @export
#'
#' @keywords internal

update_lfsr_effect.susiF  <- function  (obj ,...){
  # For each credible set, keep the smallest value found in the matching
  # entry of obj$lfsr_wc; the results are stored as a list in obj$lfsr.
  smallest_lfsr <- function(l) {
    min(obj$lfsr_wc[[l]])
  }
  obj$lfsr <- lapply(1:length(obj$cs), smallest_lfsr)
  obj
}



#' @title Update susiF log Bayes factor
#
#' @param obj a susiF object defined by init_susiF_obj function
#' @param l effect to update
#' @param lBF vector of length p, containing the updated log Bayes factors
#' @return susiF object
#' @export
#' @keywords internal


update_lBF <- function(obj, l, lBF, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_lBF")
}

#' @rdname update_lBF
#
#' @method update_lBF susiF
#
#' @export update_lBF.susiF
#
#' @export
#' @keywords internal


update_lBF.susiF <- function    (obj,l, lBF,...)
{
  # Store the log Bayes factors of effect `l`; `l` may not exceed obj$L.
  if (l > obj$L) {
    stop("Error: trying to update more effects than the number of specified effect")
  }

  lBF_by_effect      <- obj$lBF
  lBF_by_effect[[l]] <- lBF
  obj$lBF            <- lBF_by_effect
  obj
}




#'@title Update susiF local False Sign Rate
#
#'@param obj a susiF object defined by init_susiF_obj function
#
#'@param l effect to update
#
#' @param Bhat  matrix pxJ regression coefficient, Bhat[j,t] corresponds to regression coefficient of Y[,t] on X[,j]
#
#' @param Shat matrix pxJ standard error, Shat[j,t] corresponds to standard error of the regression coefficient of Y[,t] on X[,j]
#
#'@param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#'@return susiF object
#'@export
#' @keywords internal


update_lfsr <- function(obj, l, Bhat, Shat, indx_lst, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_lfsr")
}

#' @title Update susiF lfsr.
#'
#' @rdname update_lfsr
#
#' @method update_lfsr susiF
#
#' @export update_lfsr.susiF
#
#' @export
#
update_lfsr.susiF <- function (obj, l, Bhat, Shat, indx_lst, ...) {
  # Step 1: cal_clfsr() on the current prior, Bhat, Shat and indx_lst.
  cond_lfsr <- cal_clfsr(get_G_prior(obj), Bhat, Shat, indx_lst)

  # Step 2: combine with the alpha of effect `l` via cal_lfsr() and store the
  # result in obj$lfsr_wc[[l]].
  obj$lfsr_wc[[l]] <- cal_lfsr(cond_lfsr, obj$alpha[[l]])
  obj
}

#' @title Update susiF ELBO
#
#' @param obj a susiF object defined by init_susiF_obj function
#'
#' @param  ELBO new ELBO value
#'
#' @return susiF object
#'
#' @export
#'
#' @keywords internal
#'
update_ELBO <- function(obj, ELBO, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_ELBO")
}

#' @rdname update_ELBO
#
#' @method update_ELBO susiF
#
#' @export update_ELBO.susiF
#
#' @export
#' @keywords internal


update_ELBO.susiF <- function    (obj,ELBO,...)
{
  # Append the new ELBO value to the values already stored in obj$ELBO.
  elbo_trace <- c(obj$ELBO, ELBO)
  obj$ELBO   <- elbo_trace
  obj
}

#' @title Compute KL divergence effect l
#'  @param obj a susiF object
#' @param X matrix of covariates
#
#' @param D matrix of wavelet D coefficients from the original input data (Y)
#
#' @param C vector of wavelet scaling coefficient from the original input data (Y)
#
#' @param indx_lst list generated by gen_wavelet_indx for the given level of resolution
#
#' @return susiF object
#' @export
#' @keywords internal

update_KL <- function(obj, X, D, C, indx_lst, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_KL")
}



#' @rdname update_KL
#
#' @method update_KL susiF
#
#' @export update_KL.susiF
#
#' @export
#' @keywords internal
#

update_KL.susiF <- function(obj,  X, D, C , indx_lst,...)
{
  # Call cal_KL_l() once per effect (1..obj$L) and store the concatenated
  # results in obj$KL.
  kl_of_effect <- function(l) {
    cal_KL_l(obj      = obj,
             l        = l,
             X        = X,
             D        = D,
             C        = C,
             indx_lst = indx_lst)
  }

  obj$KL <- do.call(c, lapply(1:obj$L, FUN = kl_of_effect))
  obj
}





# @title Update mixture proportion of susiF mixture proportions of effect l
#
# @param obj a susif object defined by init_susiF_obj function
#
# @param l integer larger or equal to 1. Corresponds to the effect to be accessed
#
# @param tpi an object of the class "pi_mixture_normal" or "pi_mixture_normal_per_scale"
#
# @return susiF object
#
# @export

update_pi <- function(obj, l, tpi, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_pi")
}

# @rdname update_pi
#
# @method update_pi susiF
#
# @export update_pi.susiF
#
#' @export
#' @keywords internal

update_pi.susiF <- function( obj, l, tpi,...)
{
  # Store the mixture proportions `tpi` as obj$est_pi[[l]].
  #
  # obj : susiF object.
  # l   : effect index, between 1 and length(obj$est_pi).
  # tpi : object inheriting from "pi_mixture_normal" or
  #       "pi_mixture_normal_per_scale".
  #
  # Returns the updated susiF object; stops on an invalid `l` or `tpi`.

  if( l > length(obj$est_pi))
  {
    stop("Error trying to access mixture proportion")
  }
  if( l < 1)
  {
    stop("Error l should be larger ")
  }
  # inherits() yields a single TRUE/FALSE even when `tpi` carries several
  # classes, so the condition of `if` is always of length one.
  if( !inherits(tpi, c("pi_mixture_normal" , "pi_mixture_normal_per_scale")))
  {
    stop("Error tpi should be of one of the follwoing class:\n
          pi_mixture_normal \n pi_mixture_normal_per_scale")
  }
  obj$est_pi[[l]] <- tpi

  return(obj)
}





#' @title Update residual variance
#' @description  See title
#' @param  obj a susiF object
#' @param sigma2 the new value for residual variance
#' @export
#' @keywords internal

update_residual_variance <- function(obj, sigma2, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("update_residual_variance")
}

#' @rdname update_residual_variance
#
#' @method update_residual_variance susiF
#
#' @export update_residual_variance.susiF
#
#' @export
#' @keywords internal

update_residual_variance.susiF <- function( obj,sigma2,...)
{
  # Replace the residual variance stored in the object by `sigma2`.
  obj[["sigma2"]] <- sigma2
  obj
}



#
#' @title Return which credible sets are  dummy
#
#' @param obj a susif object defined by init_susiF_obj function
#' @param min_purity minimal purity within a CS
#' @param X matrix of covariates
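#' @param median_crit logical; if TRUE, the purity of a CS is judged by the median absolute pairwise correlation within the CS (applied only to CSs with at least 5 variables) instead of the minimum absolute correlation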
#
#' @return a vector of indices corresponding to the dummy effects (NULL if there are none)
#
#' @export
#' @keywords internal
which_dummy_cs <- function(obj, min_purity=0.5, X, median_crit=FALSE, ...) {
  # S3 generic: the method is selected from the class of `obj`.
  UseMethod("which_dummy_cs")
}



#' @rdname which_dummy_cs
#
#' @method which_dummy_cs susiF
#
#' @export which_dummy_cs.susiF
#' @export
#' @keywords internal
which_dummy_cs.susiF <- function(obj, min_purity=0.5,X,median_crit=FALSE,...){
  # Return the indices of the "dummy" effects of a susiF object.
  #
  # obj         : susiF object (uses L, cs, est_pi, tol_null_prior, G_prior).
  # min_purity  : purity threshold applied to credible sets with more than
  #               one variable.
  # X           : covariate matrix used to compute correlations.
  # median_crit : if TRUE, purity is the median absolute pairwise correlation
  #               and credible sets with fewer than 5 variables always pass;
  #               otherwise purity is the minimum absolute correlation.
  #
  # An effect is flagged when
  #   * its credible set has a single variable and the weight on the first
  #     mixture component exceeds 1 - obj$tol_null_prior, or
  #   * its credible set has several variables and either the purity is
  #     below min_purity or the first-component weight test below is met.
  #
  # Returns a vector of effect indices, or NULL when none is flagged (also
  # NULL, invisibly, when obj$G_prior has neither supported class).

  # TRUE when credible set l fails the purity criterion. The caller's
  # min_purity is honoured (it used to be overridden by a fixed 0.5).
  low_purity <- function(l) {
    cs_l <- obj$cs[[l]]
    if (median_crit) {
      if (length(cs_l) < 5) {
        return(FALSE)
      }
      cor_l <- cor(X[, cs_l])
      return(median(abs(cor_l[lower.tri(cor_l, diag = FALSE)])) < min_purity)
    }
    min(abs(cor(X[, cs_l]))) < min_purity
  }

  # Prior-specific pieces: how the weight on the first mixture component is
  # read, and the test applied to it for multi-variable credible sets.
  if (inherits(obj$G_prior, "mixture_normal")) {
    null_weight <- function(l) obj$est_pi[[l]][1]
    # NOTE(review): exact equality with 1 here, whereas every other test in
    # this function uses the tolerance obj$tol_null_prior -- confirm.
    null_multi  <- function(w) w == 1
  } else if (inherits(obj$G_prior, "mixture_normal_per_scale")) {
    # first-component weight averaged over the elements of est_pi[[l]]
    null_weight <- function(l) mean(sapply(obj$est_pi[[l]], "[[", 1))
    null_multi  <- function(w) w > 1 - obj$tol_null_prior
  } else {
    return(invisible(NULL))
  }

  dummy.cs <- c()
  for (l in seq_len(obj$L)) {
    if (length(obj$cs[[l]]) == 1) {
      is_dummy <- null_weight(l) > 1 - obj$tol_null_prior
    } else {
      # purity first; the prior weight is only examined when purity passes
      is_dummy <- if (low_purity(l)) TRUE else null_multi(null_weight(l))
    }
    if (is_dummy) {
      dummy.cs <- c(dummy.cs, l)
    }
  }

  return(dummy.cs)
}
stephenslab/susiF.alpha documentation built on March 1, 2025, 4:28 p.m.