R/utils.R

Defines functions automate_folds dist_ci get_covar get_time get_status format_digits sw bound01 bound prodlag foldsids get_fits

Documented in get_fits

#' Extract the Fitted Nuisance Parameter Models
#'
#' Collects the nuisance model fits stored on \code{metadata} into a single
#' named list.
#'
#' @param metadata An object of class "Survival", as returned by
#'   \code{\link{survrct}}.
#'
#' @seealso \code{\link{survrct}} for creating \code{metadata}.
#'
#' @return A list with the following components:
#'
#' \item{Hazard}{The fit for the outcome model.}
#' \item{Censoring}{The fit for the censoring model.}
#' \item{Treatment}{The fit for the treatment model.}
#'
#' @export
#'
#' @examples
#' surv <- survrct(Surv(days, event) ~ A + age + sex + dyspnea + bmi,
#'                 A ~ 1, data = c19.tte)
#' get_fits(surv)
get_fits <- function(metadata) {
  # All three fits live under the `nuisance` element of the object.
  fits <- metadata$nuisance
  list(
    Hazard = fits$hzrd_fit,
    Censoring = fits$cens_fit,
    Treatment = fits$trt_fit
  )
}

# Map each of `n` observations to the index of the cross-validation fold whose
# validation set contains it.
#
# n:  number of observations.
# id: cluster identifiers, forwarded to origami::make_folds() as `cluster_ids`.
# V:  number of folds, forwarded to origami::make_folds().
#
# Returns a numeric vector of length `n`. Positions that appear in no
# validation set keep the initial value 0.
foldsids <- function(n, id, V) {
  out <- rep(0, n)
  folds <- origami::make_folds(n, cluster_ids = id, V = V)
  # seq_along() rather than 1:length(): the latter iterates over c(1, 0) when
  # `folds` is empty and fails on folds[[1]].
  for (i in seq_along(folds)) {
    out[folds[[i]]$validation_set] <- i
  }
  out
}

# Lagged cumulative product: element i of the result is the product of
# x[1], ..., x[i - 1], with the first element fixed at 1.
prodlag <- function(x) {
  # Shift right by one position: prepend 1 and drop the last element.
  shifted <- c(1, x[-length(x)])
  cumprod(shifted)
}

# Truncate values from below: every element of `x` smaller than `lower` is
# set to `lower`. The result is returned as a plain numeric vector.
bound <- function(x, lower = 0.001) {
  floored <- replace(x, x < lower, lower)
  as.numeric(floored)
}

# Clamp values into [bound, 1 - bound]: elements below `bound` are raised to
# `bound`, elements above `1 - bound` are lowered to `1 - bound`. The result
# is returned as a plain numeric vector.
bound01 <- function(x, bound = 1e-10) {
  upper <- 1 - bound
  clamped <- replace(x, x < bound, bound)
  clamped <- replace(clamped, clamped > upper, upper)
  as.numeric(clamped)
}

# Shorthand for suppressWarnings(): evaluates `x` with warnings silenced and
# returns its value.
sw <- function(x) suppressWarnings(expr = x)

# Format `x` as character with exactly `n` decimal places: the value is
# coerced to double, rounded to `n` digits, and printed without padding.
format_digits <- function(x, n) {
  rounded <- round(as.double(x), digits = n)
  format(rounded, nsmall = n, trim = TRUE)
}

# Name of the second variable on the left-hand side of `formula`, e.g.
# "event" for Surv(days, event) ~ A.
get_status <- function(formula) {
  lhs_vars <- all.vars(formula[[2]])
  lhs_vars[[2]]
}

# Name of the first variable on the left-hand side of `formula`, e.g.
# "days" for Surv(days, event) ~ A.
get_time <- function(formula) {
  lhs_vars <- all.vars(formula[[2]])
  lhs_vars[[1]]
}

# Variable names on the right-hand side of `formula`, excluding any that
# appear in `target`.
get_covar <- function(formula, target) {
  rhs_vars <- all.vars(formula[[3]])
  setdiff(rhs_vars, target)
}

# Wald-type confidence limits, truncated to the unit interval.
#
# estim:     numeric vector of point estimates.
# std.error: numeric vector of standard errors; its length determines the
#            number of rows in the result.
# cv:        critical value multiplying the standard error (defaults to the
#            normal quantile for a two-sided 95% interval).
#
# Returns a numeric matrix with one row per standard error and two columns
# (lower limit, upper limit), each clamped to [0, 1]. Empty input yields a
# 0 x 2 matrix instead of a subscript error.
dist_ci <- function(estim, std.error, cv = qnorm(1 - 0.05 / 2)) {
  # Pair estimates with standard errors by position, as the element-wise
  # computation requires; surplus estimates are ignored, missing ones are NA.
  estim <- estim[seq_along(std.error)]
  lower <- estim + std.error * -cv
  upper <- estim + std.error * cv
  # matrix() discards any names carried by the inputs, so the result has no
  # dimnames.
  out <- matrix(c(lower, upper), ncol = 2)
  pmin(pmax(out, 0), 1)
}

# Pick the number of cross-validation folds from the sample size `n`:
#   n <= 100         -> 20
#   100 < n <= 500   -> 10
#   500 < n <= 1000  -> 5
#   n > 1000         -> 3
automate_folds <- function(n) {
  if (n > 1000) {
    3
  } else if (n > 500) {
    5
  } else if (n > 100) {
    10
  } else {
    20
  }
}
nt-williams/survrct documentation built on July 29, 2021, 7:46 a.m.