R/conf_int_funcs.R

Defines functions compute_CI TNSurv

Documented in compute_CI TNSurv

#' Survival function of truncated normal distribution.
#'
#' @keywords internal
#' @details
#' Let \eqn{X} be a normal random variable with mean \code{mu}, standardised
#' in the code by \code{sqrt(nu_norm * sig)}.
#' This function returns \eqn{P(X \ge v^{T}y | X \in S)}, where \eqn{S} is the
#' union of the rows of \code{truncation} whose \code{contained} equals 1.
#' Log-sum-exp operations are used to avoid underflows in the upper tail
#' probability of the truncated normal distribution.
#'
#' @param truncation A data frame of truncation with 3 columns: min_mean,
#'  max_mean, and contained. Each row corresponds to a truncation interval with
#'  lower and upper limits specified by min_mean and max_mean, respectively.
#'  In addition, contained is binary-valued and indicates whether this interval
#'  is included in the final truncation set.
#' @param vTy Numeric; the value to evaluate the survival function on
#' @param nu_norm Numeric; first factor under the square root of the scale
#' @param sig Numeric; second factor under the square root of the scale
#' @param mu Numeric; mean of untruncated distribution, default to 0
#' @return A single numeric probability. Returns 0 when the ratio of the two
#'  log-masses is \code{NaN}, which includes the case where no interval is
#'  contained (or \code{truncation} has no rows).
#' @examples
#' truncation <- data.frame(matrix(c(-2, 1, 0, 2), byrow = TRUE, ncol = 2))
#' colnames(truncation) <- c("min_mean", "max_mean")
#' truncation$contained <- 1
#' TNSurv(truncation, vTy = 0, nu_norm = 1, sig = 1)
#' @export
TNSurv <- function(truncation, vTy, nu_norm, sig, mu = 0) {
  # NOTE(review): the scale is sqrt(nu_norm * sig), i.e. `sig` enters as a
  # variance-like factor; confirm callers pass sigma^2 rather than sigma.
  sd_x <- sqrt(nu_norm * sig)

  log_num <- -Inf # log of the truncated mass lying at or above vTy
  log_den <- -Inf # log of the total mass of the truncation set

  # seq_len() (not 1:n) so that a zero-row `truncation` skips the loop
  # instead of indexing rows 1 and 0.
  for (i in seq_len(nrow(truncation))) {
    cur_interval <- truncation[i, ]
    if (cur_interval$contained == 1) {
      log_upper <- pnorm((cur_interval$max_mean - mu) / sd_x, log.p = TRUE)
      log_lower <- pnorm((cur_interval$min_mean - mu) / sd_x, log.p = TRUE)
      log_den <- log_sum_exp(log_den, log_subtract(log_upper, log_lower))

      # Only the part of the interval at or above vTy feeds the numerator.
      if (cur_interval$max_mean >= vTy) {
        log_cut <- pnorm(
          (max(cur_interval$min_mean, vTy) - mu) / sd_x,
          log.p = TRUE
        )
        log_num <- log_sum_exp(log_num, log_subtract(log_upper, log_cut))
      }
    }
  }

  p <- exp(log_num - log_den)
  # -Inf - -Inf (no mass anywhere) yields NaN; report that as probability 0.
  if (is.nan(p)) {
    p <- 0
  }
  p
}

# ----- computing the confidence intervals -----

# return selective ci for v^T mu (i.e. a single parameter)
#' Compute selective confidence intervals
#'
#' This function computes an equal-tailed (1-alpha) selective confidence
#' interval for the mean of a truncated normal distribution.
#'
#' @keywords internal
#'
#' @param vTy Numeric; the observed statistic. It is passed to \code{TNSurv}
#' as the evaluation point and \code{sum(vTy)} is the centre of the search
#' (presumably the scalar \eqn{v^T y} for contrast vector v and response y).
#' @param vTv Numeric; enters the search scale as \code{sqrt(vTv)} and is
#' passed to \code{TNSurv} as \code{nu_norm} (presumably \eqn{v^T v}).
#' @param sigma The known noise level, passed to \code{TNSurv} as \code{sig}.
#' If unknown, we recommend a conservative estimate. This argument has no
#' default and must be supplied.
#' @param truncation The truncation set: a two-column object (lower limit,
#' upper limit per row) that \code{data.frame()} can coerce. Every row is
#' treated as contained in the truncation set.
#' @param alpha The significance level. Default to 0.05
#' @param steps_lim The maximum number of expansion steps the search will take
#' to initialize the LCB and UCB, default to 50.
#' @return This function returns a vector of lower and upper confidence limits.
#' A limit is \code{-Inf} (lower) or \code{Inf} (upper) when root finding fails.
#'
#' @export
compute_CI <- function(vTy, vTv, sigma, truncation, alpha = 0.05, steps_lim = 50) {
  ### Conservative guess
  # NOTE(review): the search scale is sigma * sqrt(vTv) whereas TNSurv
  # standardises by sqrt(vTv * sigma); confirm which convention is intended.
  scale <- sigma * sqrt(vTv)
  q <- sum(vTy) / scale
  center <- q * scale

  # reshape the truncation set into the data frame layout TNSurv expects
  truncation <- data.frame(truncation)
  colnames(truncation) <- c("min_mean", "max_mean")
  truncation$contained <- 1

  # survival function as a function of the candidate mean x
  fun <- function(x) {
    TNSurv(truncation, vTy, vTv, sigma, mu = x)
  }
  # L: fun.L(L) = 0
  fun.L <- function(x) {
    fun(x) - (alpha / 2)
  }
  # U: fun.U(U) = 0
  fun.U <- function(x) {
    fun(x) - (1 - alpha / 2)
  }

  # find the starting point (x1, x2) such that
  # fun.L(x1), fun.U(x1) <= 0 AND fun.L(x2), fun.U(x2) >= 0.
  # i.e. fun(x1) <= alpha/2 AND fun(x2) >= 1-alpha/2.

  # find x1 s.t. fun(x1) <= alpha/2
  # what we know:
  # fun is monotonically increasing;
  # fun(x) = NaN if x too small;
  # fun(x) > alpha/2 if x too big.
  # so we can do a modified bisection search to find x1.
  step <- 0
  x1.up <- center + scale
  x1 <- center - 3 * scale
  f1 <- fun(x1)
  while (step <= steps_lim) {
    if (is.na(f1)) { # x1 is too small
      x1.mid <- (x1 + x1.up) / 2
      # This branch does not advance `step`; stop once bisection can no
      # longer move x1, otherwise the loop would never terminate.
      if (is.na(x1.mid) || x1.mid == x1) {
        break
      }
      x1 <- x1.mid
      f1 <- fun(x1)
    } else if (f1 > alpha / 2) { # x1 is too big
      x1.up <- x1
      x1 <- x1 - 3 * 1.4^step
      f1 <- fun(x1)
      step <- step + 1
    } else { # fun(x1) <= alpha/2, done
      break
    }
  }

  # find x2 s.t. fun(x2) >= 1 - alpha/2
  # what we know:
  # fun is monotonically increasing;
  # fun(x) = NaN if x too big;
  # fun(x) < 1 - alpha/2 if x too small.
  # again can do a modified bisection search to find x2.
  step <- 0
  x2 <- center + 3 * scale
  x2.lo <- center - scale
  f2 <- fun(x2)
  while (step <= steps_lim) {
    if (is.na(f2)) { # x2 is too big
      x2.mid <- (x2 + x2.lo) / 2
      # Same stall guard as for x1.
      if (is.na(x2.mid) || x2.mid == x2) {
        break
      }
      x2 <- x2.mid
      f2 <- fun(x2)
    } else if (f2 < 1 - alpha / 2) { # x2 is too small
      x2.lo <- x2
      x2 <- x2 + 3 * 1.4^step
      f2 <- fun(x2)
      step <- step + 1
    } else { # fun(x2) >= 1 - alpha/2, done
      break
    }
  }

  # if the above search does not work, set up a grid search
  # for starting points
  if (is.na(f1) || (f1 > alpha / 2) || is.na(f2) || (f2 < 1 - alpha / 2)) {
    # NOTE(review): seq() uses unit spacing here, so the number of grid
    # points depends on `scale`; confirm this is intended.
    grid <- seq(from = center - 1000 * scale, to = center + 1000 * scale)
    value <- vapply(grid, fun, numeric(1))
    # want max x1: fun(x1) <= alpha/2
    ind1 <- rev(which(value <= alpha / 2))[1]
    x1 <- grid[ind1]
    # want min x2: fun(x2) >= 1-alpha/2
    ind2 <- which(value >= 1 - alpha / 2)[1]
    x2 <- grid[ind2]
  }

  # if the above fails, then either x1, x2 = NA, so uniroot() will throw error,
  # in which case we set (-Inf, Inf) as the CI

  # we know the functions are increasing, hence extendInt = "upX"
  L <- tryCatch(
    stats::uniroot(fun.L, c(x1, x2), extendInt = "upX", tol = 1e-4)$root,
    error = function(e) -Inf
  )

  U <- tryCatch(
    stats::uniroot(fun.U, c(x1, x2), extendInt = "upX", tol = 1e-4)$root,
    error = function(e) Inf
  )

  c(L, U)
}
yiqunchen/PGInference documentation built on March 20, 2022, 11:51 p.m.