# R/sl_fit_survival.R

# Defines functions initial_sl_fit

# Documented in initial_sl_fit

# tidyr::spread() below refers to the columns Q1Haz and G_dC by bare name;
# declaring them keeps R CMD check from reporting undefined global variables.
utils::globalVariables(names = c("Q1Haz", "G_dC"))


#' super learner fit for failure and censoring event
#'
#' Fits the initial nuisance estimates with the survtmle package: the
#' propensity score (\code{survtmle::estimateTreatment}), the censoring
#' mechanism (\code{survtmle::estimateCensoring}) and the failure hazard
#' (\code{survtmle::estimateHazards}), each with a SuperLearner library. The
#' long-format predictions (one row per subject and time point) are reshaped
#' to one row per subject and one column per time point and wrapped in
#' \code{survival_curve} objects.
#'
#' If the censoring fit errors (e.g. when there is almost no censoring), a
#' message is emitted and censoring is treated as absent (\code{G_dC = 1}).
#'
#' @param T_tilde vector of last follow up time
#' @param Delta vector of censoring indicator
#' @param A vector of treatment
#' @param W data.frame of baseline covariates
#' @param t_max the maximum time to estimate the survival probabilities
#' @param sl_failure SuperLearner library for failure event hazard
#' @param sl_censoring SuperLearner library for censoring event hazard
#' @param sl_treatment SuperLearner library for propensity score
#' @param gtol threshold for the fitted propensity scores
#'
#' @return A list with \code{density_failure_1} / \code{density_failure_0}
#'   (\code{survival_curve} objects built from the failure hazards under
#'   A = 1 / A = 0), \code{density_censor_1} / \code{density_censor_0}
#'   (\code{survival_curve} objects built from the censoring survival under
#'   A = 1 / A = 0) and \code{g1W} (the fitted propensity score column
#'   \code{g_1}). If the hazard fit reports a convergence failure, the string
#'   \code{"estimation convergence failure"} is returned instead.
#'
#' @importFrom SuperLearner SuperLearner
#' @importFrom survtmle estimateTreatment makeDataList estimateCensoring estimateHazards
#' @export
initial_sl_fit <- function(T_tilde,
                           Delta,
                           A,
                           W,
                           t_max,
                           sl_failure = c("SL.glm"),
                           sl_censoring = c("SL.glm"),
                           sl_treatment = c("SL.glm"),
                           gtol = 1e-3) {
  # translate to the argument names used inside survtmle
  ftime <- T_tilde
  ftype <- Delta
  adjustVars <- data.frame(W)
  t_0 <- t_max

  n <- length(ftime)
  dat <- data.frame(id = seq_len(n), ftime = ftime, ftype = ftype, trt = A)
  # data.frame() never returns NULL, so the covariates are always appended
  dat <- cbind(dat, adjustVars)

  # failure types (code 0 is censoring) and the two treatment arms
  allJ <- sort(unique(ftype[ftype != 0]))
  uniqtrt <- sort(0:1)
  ntrt <- length(uniqtrt)

  # propensity score; the returned data carries the fitted column g_1
  trtOut <- survtmle::estimateTreatment(
    dat = dat,
    ntrt = ntrt,
    uniqtrt = uniqtrt,
    adjustVars = adjustVars,
    SL.trt = sl_treatment,
    returnModels = TRUE,
    gtol = gtol
  )
  dat <- trtOut$dat

  # long data sets for estimation ("obs") and prediction under each arm
  dataList <- survtmle::makeDataList(
    dat = dat, J = allJ, ntrt = ntrt, uniqtrt = uniqtrt, t0 = t_0, bounds = NULL
  )

  # censoring: when there is almost no censoring the classification fails;
  # in that case censoring is treated as absent (G_dC = 1 everywhere)
  censOut <- tryCatch(
    survtmle::estimateCensoring(
      dataList = dataList,
      ntrt = ntrt,
      uniqtrt = uniqtrt,
      t0 = t_0,
      verbose = FALSE,
      adjustVars = adjustVars,
      SL.ctime = sl_censoring,
      glm.family = "binomial",
      returnModels = TRUE,
      gtol = gtol
    ),
    error = function(cond) {
      # keep the underlying reason visible instead of discarding it
      message("censoring sl error: ", conditionMessage(cond))
      NULL
    }
  )
  if (is.null(censOut)) {
    for (nm in c("obs", "0", "1")) {
      dataList[[nm]][, "G_dC"] <- 1
    }
  } else {
    dataList <- censOut$dataList
  }

  # cause-specific hazards of failure
  estOut <- survtmle::estimateHazards(
    dataList = dataList,
    J = allJ,
    verbose = FALSE,
    bounds = NULL,
    adjustVars = adjustVars,
    SL.ftime = sl_failure,
    glm.family = "binomial",
    returnModels = TRUE
  )
  dataList <- estOut$dataList

  # a failed hazard fit is detected via the "convergence failure" sentinel;
  # isTRUE() avoids an error if the comparison yields NA
  failed <- suppressWarnings(
    all(dataList[[1]] == "convergence failure")
  )
  if (isTRUE(failed)) {
    return("estimation convergence failure")
  }

  d1 <- dataList[["1"]]
  d0 <- dataList[["0"]]

  # reshape long (id, t) predictions to wide: one row per id, one column per t.
  # `$id <- NULL` keeps a data frame even when there is a single time column
  # (`[, -1]` would collapse it to a vector).
  haz1 <- tidyr::spread(d1[, c("id", "t", "Q1Haz")], t, Q1Haz)
  haz1$id <- NULL
  haz0 <- tidyr::spread(d0[, c("id", "t", "Q1Haz")], t, Q1Haz)
  haz0$id <- NULL

  # censoring survival S_{Ac}
  S_Ac_1 <- tidyr::spread(d1[, c("id", "t", "G_dC")], t, G_dC)
  S_Ac_1$id <- NULL
  S_Ac_0 <- tidyr::spread(d0[, c("id", "t", "G_dC")], t, G_dC)
  S_Ac_0$id <- NULL

  # time grid labelling the wide columns: the sorted distinct prediction
  # times (tidyr::spread orders the new columns by sorted key). Using
  # range(ftime) instead mislabels the columns whenever t_max != max(ftime)
  # or min(ftime) != 1.
  t_grid_1 <- sort(unique(d1$t))
  t_grid_0 <- sort(unique(d0$t))

  # g_1 may be a one-column matrix (hence the original [, 1]); a plain
  # vector is accepted as well
  g_1 <- dat$g_1

  list(
    density_failure_1 = survival_curve$new(t = t_grid_1, hazard = haz1),
    density_failure_0 = survival_curve$new(t = t_grid_0, hazard = haz0),
    density_censor_1 = survival_curve$new(t = t_grid_1, survival = S_Ac_1),
    density_censor_0 = survival_curve$new(t = t_grid_0, survival = S_Ac_0),
    g1W = if (is.matrix(g_1)) g_1[, 1] else g_1
  )
}
# wilsoncai1992/MOSS documentation built on June 1, 2020, 2:26 p.m.