R/rwlsSolution.R

#' Inner Function: `.rwlsSolution()`
#' 
#' Generalized linear model regression with Lasso and Weighted Fusion penalties
#' for a binary outcome. Specifically, the coefficients at a given lambda_1
#' (controlling the Lasso penalty) and lambda_2 (controlling the Weighted
#' Fusion penalty) are iteratively updated with reweighted least squares. The
#' final result is determined by the convergence thresholds in `iter.control`.
#' 
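#' @details
#' A summary of the update implemented below (the logistic link is implied by
#' `family = "binomial"`): given the current coefficients \eqn{\beta}, the
#' linear predictor \eqn{\eta = X\beta} and fitted probabilities
#' \eqn{p = 1/(1 + e^{-\eta})} define the working response and weights
#' \deqn{u = \eta + (y - p)/\{p(1 - p)\}, \qquad v = p(1 - p),}
#' after which the rows of the design matrix and of \eqn{u} are scaled by
#' \eqn{\sqrt{v}} and passed, together with `X.app` and `Y.app`, to
#' `.ctnsSolution()` for the penalized least-squares update.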
#'
#' @noRd
#' @param data An object of class "WTsmth.data" as generated by prep() and
#'   post-processed by .expandData() according to the type of "weight" selected.
#' @param X.app A matrix object. The weight matrix A scaled by lambda_2.
#' @param Y.app A numeric vector. The response augmentation paired with `X.app`.
#' @param lambda1 A number. The lambda_1 value to be considered; the provided
#'   value will be transformed to 2^(lambda1).
#' @param iter.control A list object. Allows user to control iterative
#'   update procedure. Allowed elements are "max.iter", the maximum number
#'   of iterations; "tol.beta", the difference between consecutive beta
#'   updates below which the procedure is deemed converged; and "tol.loss",
#'   the absolute difference in consecutive loss updates below which the procedure
#'   is deemed converged.
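#'   For illustration only (these values are not package defaults), a complete
#'   specification could look like
#'   `list(max.iter = 50L, tol.beta = 1e-4, tol.loss = 1e-6)`.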
#'
#' @returns A numeric vector. The estimated coefficients.
#'
#' @include ctnsSolution.R helpful_tests.R linearPred.R loss.R probPred.R
#' @import glmnet
#' 
#' @keywords internal
.rwlsSolution <- function(data, X.app, Y.app, lambda1, iter.control) {
  
  stopifnot(
    "`data` must be a 'WTsth.data' object expanded with weight" = !missing(data) && 
      inherits(data, "WTsmth.data") && !is.null(data$XZ),
    "`lambda1 must be a scalar numeric" =
      !missing(lambda1) && .isNumericVector(lambda1, 1L),
    "`iter.control` must be a list; allowed elements are max.iter, tol.beta, and tol.loss" = 
      .isNamedList(iter.control, c("max.iter", "tol.beta", "tol.loss"))
  )
  
  
  ##initial beta values, default type.measure="deviance"
  fit_yb_init <- tryCatch(glmnet::glmnet(x = data$XZ, y = data$Y, family = "binomial"),
                          error = function(e){
                            stop("glmnet encountered errors\n\t", call. = FALSE, e$message)
                          })
  #initial coef at candidate lambda 1 values
  #iteration 0 (initialization)
  iter <- 0L
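  # coef.glmnet() returns a sparse one-column matrix of coefficients at the
  # single penalty value s = 2^lambda1 (intercept first); drop() simplifies it
  # to a plain numeric vector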
  beta_cur <- glmnet::coef.glmnet(fit_yb_init, s = 2^lambda1) |> drop()
  
  #iteratively update the coef
  
  #stop criteria
  #   1. coef change: initialize the change in beta values as max(abs(beta))
  maxdif <- max(abs(beta_cur))
  
  #2. loss change: negative log-likelihood; set the initial value to the current loss (on the original data set)
  loss <- .loss(X = data$XZ, Y = drop(data$Y), beta = beta_cur, family = "binomial")
  loss_dif <- loss
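  # loss_dif starts at the full initial loss so that the first pass through
  # the while loop is not skipped by the tol.loss criterion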
  
  while ((iter < iter.control$max.iter) && 
         (maxdif > iter.control$tol.beta) && 
         (loss_dif > iter.control$tol.loss)) {
    
    # prepare linear prediction and probability prediction for iteration update
    linear_pred <- .linearPred(data$XZ, beta_cur)
    prob_pred <- .probPred(linear_pred)
    
    # update u*, v*
    iter <- iter + 1L
    u <- linear_pred + {data$Y - prob_pred} / {prob_pred * {1.0 - prob_pred}}
    
    sqrt_v <- sqrt(prob_pred * {1.0 - prob_pred})
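    # scaling the rows of the design matrix and the working response by
    # sqrt(v) lets the reweighted least-squares step be solved as an ordinary
    # (unweighted) penalized least-squares problem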
    
    
    #update all predictors, including the intercept
    data$XZ_update <- cbind(1.0, data$XZ) * sqrt_v
    #augmentation performed in .ctnsSolution():
    #X_aug <- rbind(X_update, X.app)
    
    ## prepare the working response Y_update
    data$Y_update <- {u * sqrt_v} |> data.matrix()
    rownames(data$Y_update) <- rownames(data$XZ_update)
    
    data$XZ_update <- data$XZ_update[!is.na(data$Y_update), ]
    data$Y_update <- data$Y_update[!is.na(data$Y_update), ]
    #augmentation performed in .ctnsSolution():
    #Y_aug <- c(y_update, Y.app)
    
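    # .ctnsSolution() appends the augmentation rows (X.app, Y.app), which carry
    # the lambda_2-scaled fusion weights, and returns the updated coefficients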
    beta_next <- .ctnsSolution(data = data, X.app = X.app, Y.app = Y.app, lambda1 = lambda1)
    
    # stop criterion: change in beta
    beta_dif <- abs(beta_next - beta_cur)
    maxdif <- max(abs(beta_dif))
    
    # stop criterion: change in loss
    loss_next <- .loss(X = data$XZ, Y = drop(data$Y), beta = beta_next, family = "binomial")
    loss_dif <- abs(loss - loss_next)
    
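    # if the loss carried over from the previous step is infinite (e.g., a
    # fitted probability of exactly 0 or 1), loss_dif is not meaningful; roll
    # back the iteration counter, keep the previous coefficients, and stop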
    if (is.infinite(loss)) {
      loss_dif <- 0.0
      iter <- iter - 1L
      break
    }
    
    loss <- loss_next
    #prepare for next iteration
    beta_cur <- beta_next
  }
  if (iter == iter.control$max.iter) {
    warning("maximum iterations reached", call. = FALSE)
  }
  beta_cur
}
