R/compute_baseline.R

Defines functions compute_baseline_for_sample_size compute_baseline

Documented in compute_baseline compute_baseline_for_sample_size

#' Compute baseline processes.
#'
#' Compute parameters to build baseline processes.
#'
#' A single trivial baseline is returned when
#' \code{log(1 / alpha) <= v_min * psi_star(delta_lower)} or when
#' \code{delta_lower == delta_upper}. Otherwise the threshold \code{g_alpha}
#' is found with \code{stats::uniroot()}, the number of non-trivial baselines
#' \code{k_alpha} is the minimizer over \code{k = 1, ..., k_max} of
#' \code{log(k) - g_alpha * ratio^(-1 / k)} (with
#' \code{ratio = psi_star(delta_upper) / psi_star(delta_lower)}), and the
#' lambdas are placed on the grid \code{psi_star(delta) = psi_star(delta_upper) / eta_alpha^k}.
#'
#' @param alpha ARL parameter in (0,1)
#' @param delta_lower Lower bound of target Delta. It must be positive and smaller than or equal to \code{delta_upper}.
#' @param delta_upper Upper bound of target Delta. It must be positive and larger than or equal to \code{delta_lower}.
#' @param psi_fn_list A list of R functions that computes psi and psi_star functions. Can be generated by \code{generate_sub_G_fn()} or counterparts for sub_B and sub_E.
#' @param v_min A lower bound of v function in the baseline process. Default is \code{1}.
#' @param k_max Positive integer to determine the maximum number of baselines. Default is \code{200}.
#' @param tol Tolerance of root-finding, positive numeric, passed to \code{stats::uniroot()}. Default is 1e-10.
#'
#' @return A named list with elements \code{alpha}, \code{delta_lower},
#'   \code{delta_upper}, \code{lambda} (baseline parameters), \code{omega}
#'   (mixing weights normalized to sum to one), \code{g_alpha} (threshold),
#'   \code{k_alpha} (number of non-trivial baselines), \code{eta_alpha}
#'   (grid spacing), \code{w} (sum of the unnormalized weights) and
#'   \code{psi_fn_list}.
#' @export
#'
compute_baseline <- function(alpha,
                             delta_lower,
                             delta_upper,
                             psi_fn_list = generate_sub_G_fn(),
                             v_min = 1,
                             k_max = 200,
                             tol = 1e-10) {
  # Type checks.
  # Both bounds must hold (`&&`, not `|`); isTRUE() turns NA comparisons into
  # the validation error instead of an "if: missing value" crash.
  if (!(is.numeric(alpha) && length(alpha) == 1L &&
        isTRUE(alpha > 0 && alpha < 1))) {
    stop("alpha must be a number in (0,1).")
  }
  if (!isTRUE(delta_lower > 0 && delta_upper >= delta_lower)) {
    stop("delta_lower and delta_upper must be positive with delta_lower <= delta_upper.")
  }

  psi_star <- psi_fn_list$psi_star
  psi_star_div <- psi_fn_list$psi_star_div
  psi_star_inv <- psi_fn_list$psi_star_inv

  if (abs(psi_star(0)) > 1e-8) {
    stop("psi_star must be zero at x = 0.")
  }
  if (abs(psi_star_div(0)) > 1e-8) {
    stop("psi_star_div must be zero at x = 0.")
  }
  if (!isTRUE(v_min >= 0)) {
    stop("v_min must be non-negative.")
  }
  k_max_raw <- k_max
  k_max <- as.integer(k_max_raw)
  if (!is.na(k_max) && k_max != k_max_raw) {
    warning("k_max is converted to an integer.")
  }
  if (is.na(k_max) || k_max < 1L) {
    stop("k_max must be larger than or equal to 1.")
  }

  # Compute constants
  log_one_over_alpha <- log(1 / alpha)
  d_l <- psi_star(delta_lower)
  d_u <- psi_star(delta_upper)
  ratio <- d_u / d_l
  lambda_l <- psi_star_div(delta_lower)
  lambda_u <- psi_star_div(delta_upper)

  # If delta_lower is small enough or equal to delta_upper,
  # return the trivial single baseline.
  if (log_one_over_alpha <= v_min * d_l ||
      delta_lower == delta_upper) {
    return(list(
      alpha = alpha,
      delta_lower = delta_lower,
      delta_upper = delta_upper,
      lambda = lambda_l,
      omega = 1,
      g_alpha = log_one_over_alpha,
      k_alpha = 0,
      eta_alpha = 1,
      w = alpha,
      psi_fn_list = psi_fn_list
    ))
  }

  # log f(g) = min over k = 1..k_max of { log(k) - g * ratio^(-1/k) }.
  # The k-dependent factors do not depend on g, so they are computed once
  # rather than on every uniroot() evaluation.
  k_vec <- seq_len(k_max)
  log_k <- log(k_vec)
  ratio_pow <- ratio^(-1 / k_vec)
  log_f_terms <- function(g) log_k - g * ratio_pow
  log_f <- function(g) min(log_f_terms(g))
  log_f_with_exp <- function(g) logSumExpTrick(c(-g, log_f(g)))

  # Compute the threshold g_alpha.
  # If the plain equation already has a root below v_min * d_u, search there;
  # otherwise include the extra exp(-g) term and search above v_min * d_u.
  if (log_f(v_min * d_u) <= -log_one_over_alpha) {
    root_out <- stats::uniroot(
      function(g) log_f(g) + log_one_over_alpha,
      interval = c(log_one_over_alpha, v_min * d_u),
      tol = tol
    )
  } else {
    root_out <- stats::uniroot(
      function(g) log_f_with_exp(g) + log_one_over_alpha,
      interval = c(v_min * d_u, ratio * log(2 / alpha)),
      tol = tol
    )
  }
  g_alpha <- root_out$root

  # Number of non-trivial baselines k_alpha and the spacing parameter eta_alpha
  k_alpha <- which.min(log_f_terms(g_alpha))
  eta_alpha <- ratio^(1 / k_alpha)

  # Intermediate lambdas on the grid psi_star(delta) = d_u / eta_alpha^k,
  # k = 1, ..., k_alpha - 1 (empty when k_alpha == 1).
  lambda_mid <- vapply(
    seq_len(k_alpha - 1L),
    function(k) psi_star_div(psi_star_inv(d_u / eta_alpha^k)),
    numeric(1)
  )

  # Compute lambdas and (unnormalized) mixing weights.
  # When g_alpha exceeds v_min * d_u, an extra baseline at lambda_u is added
  # with weight exp(-g_alpha).
  if (g_alpha > v_min * d_u) {
    lambda_vec <- c(lambda_u, lambda_mid, lambda_l)
    omega_vec <- c(exp(-g_alpha), rep(exp(-g_alpha / eta_alpha), k_alpha))
  } else {
    lambda_vec <- c(lambda_mid, lambda_l)
    omega_vec <- rep(exp(-g_alpha / eta_alpha), k_alpha)
  }

  # Normalize weights
  w <- sum(omega_vec)

  list(
    alpha = alpha,
    delta_lower = delta_lower,
    delta_upper = delta_upper,
    lambda = lambda_vec,
    omega = omega_vec / w,
    g_alpha = g_alpha,
    k_alpha = k_alpha,
    eta_alpha = eta_alpha,
    w = w,
    psi_fn_list = psi_fn_list
  )
}

#' Compute baseline parameters given target variance process bounds.
#'
#' Given target variance process bounds for confidence sequences, compute baseline parameters.
#'
#' The variance bounds are mapped to Delta bounds through
#' \code{psi_star_inv(g_alpha / v)}, and the result of
#' \code{compute_baseline()} on those Delta bounds is returned. If
#' \code{v_lower} is below \code{v_min}, it is raised to \code{v_min} with a
#' warning.
#'
#' @param v_upper Upper bound of the target variance process bound.
#' @param v_lower Lower bound of the target variance process bound.
#' @param skip_g_alpha If TRUE, g_alpha is not computed and log(1/alpha) is used instead.
#' @inheritParams compute_baseline
#'
#' @return A list of 1. Parameters of baseline processes, 2. Mixing weights, 3. Auxiliary values for computation,
#'   as returned by \code{compute_baseline()}.
#' @export
#'
compute_baseline_for_sample_size <- function(alpha,
                                             v_upper,
                                             v_lower,
                                             psi_fn_list = generate_sub_G_fn(),
                                             skip_g_alpha = TRUE,
                                             v_min = 1,
                                             k_max = 200,
                                             tol = 1e-10) {
  if (!isTRUE(v_lower > 0 && v_upper >= v_lower)) {
    stop("v_lower and v_upper must be positive with v_lower <= v_upper.")
  }

  if (v_lower < v_min) {
    warning("v_lower is lower than v_min. v_min will be used instead of v_lower.")
    v_lower <- v_min
  }
  # After clamping, v_lower may exceed v_upper; fail clearly here rather than
  # with a delta-ordering error from compute_baseline().
  if (v_upper < v_lower) {
    stop("v_upper must be larger than or equal to v_min.")
  }

  psi_star_inv <- psi_fn_list$psi_star_inv
  log_one_over_alpha <- log(1 / alpha)

  if (skip_g_alpha || v_lower == v_upper) {
    # Skipped on request, or trivial case (single variance bound)
    g_alpha <- log_one_over_alpha
  } else {
    # Run a preliminary baseline computation on the Delta bounds implied by
    # log(1/alpha) and reuse its threshold g_alpha.
    baseline_init <- compute_baseline(
      alpha = alpha,
      delta_lower = psi_star_inv(log_one_over_alpha / v_upper),
      delta_upper = psi_star_inv(log_one_over_alpha / v_lower),
      psi_fn_list = psi_fn_list,
      v_min = v_min,
      k_max = k_max,
      tol = tol
    )
    g_alpha <- baseline_init$g_alpha
  }

  # Compute Delta bounds and the baseline parameters
  compute_baseline(
    alpha = alpha,
    delta_lower = psi_star_inv(g_alpha / v_upper),
    delta_upper = psi_star_inv(g_alpha / v_lower),
    psi_fn_list = psi_fn_list,
    v_min = v_min,
    k_max = k_max,
    tol = tol
  )
}

Try the stcpR6 package in your browser

Any scripts or data that you put into this service are public.

stcpR6 documentation built on Oct. 8, 2024, 9:07 a.m.