R/surv_bart.R

Defines functions deconstructed_surv_bart surv_bart_from_input_list assign_time predict_times_from_pbart predict_times_from_pbart_b

#' Fit a survival BART using probit regression on discretized time
#'
#' The survival data are first expanded by
#' \code{\link[BART]{surv.pre.bart}} into binary responses (\code{y.train})
#' and time-augmented covariates (\code{tx.train}); a probit BART
#' (\code{type = "pbart"}) is then fit to them with
#' \code{\link[BART]{mc.gbart}}.
#'
#' @param times Times at which censoring or the exit event occurs
#' @param delta Event-is-observed indicator
#' @param x_train Design matrix (excludes times)
#' @param mc.cores Number of cores to use
#' @param ... Other parameters passed on to \code{\link[BART]{mc.gbart}}
#' @return pbart object returned by \code{mc.gbart}
#' @export
deconstructed_surv_bart <- function(times, delta, x_train,
                                    mc.cores = 4, ...) {
  discretized <- surv.pre.bart(times = times, delta = delta, x.train = x_train)
  mc.gbart(
    y.train = discretized$y.train,
    x.train = discretized$tx.train,
    type = "pbart",
    mc.cores = mc.cores,
    ...
  )
}



#' Run a survival BART from a named list of inputs
#'
#' Convenience wrapper for \code{\link{deconstructed_surv_bart}}: it unpacks
#' \code{times}, \code{delta} and \code{x_train} from \code{input_list} and
#' returns the fit still in probit form, which makes event-time prediction
#' easier.
#'
#' @param input_list Named list containing at least \code{times},
#' \code{delta} and \code{x_train} (see \code{\link{deconstructed_surv_bart}})
#' @param mc.cores Number of cores on which to run
#' @param ... Parameters passed on to \code{\link[BART]{mc.gbart}}
#' @return Probit BART result
#' @export
surv_bart_from_input_list <- function(input_list, mc.cores = 4, ...) {
  stopifnot(c("times", "delta", "x_train") %in% names(input_list))
  obs_times <- input_list[["times"]]
  obs_delta <- input_list[["delta"]]
  covariates <- input_list[["x_train"]]
  deconstructed_surv_bart(
    times = obs_times,
    delta = obs_delta,
    x_train = covariates,
    mc.cores = mc.cores,
    ...
  )
}



#' Use binary conditional-probability coin flips to assign an event time to a
#' single person (or \code{Inf} if no event occurs)
#'
#' @param id Scalar ID of the person whose event time is assigned
#' @param id_rows Length-R vector labelling which person each coin flip belongs
#' to, where R is the number of person x risk interval rows
#' @param coin_flips Length-R vector of 0/1 (or FALSE/TRUE) draws from the
#' predicted (conditional) discrete hazard of an event in each interval, e.g.
#' one row \code{coin_flips[b, ]} of the B x R matrix of flips
#' @param time_labels Length-R vector giving the interval time of each flip
#' @return Scalar event time for person \code{id}: the earliest of that
#' person's own interval times whose flip is 1 (an event), or \code{Inf} if
#' none of that person's flips is 1 (including when the person has no rows)
#' @export
assign_time <- function(id, id_rows, coin_flips, time_labels) {
  person_rows <- which(id_rows == id)
  person_flips <- coin_flips[person_rows]
  person_times <- time_labels[person_rows]
  # Look the time up among this person's own rows. Indexing the global sorted
  # time grid by a position within the person's rows gives the wrong time
  # whenever earlier intervals were dropped for this person (e.g. by min_times).
  event_rows <- person_flips == 1
  if (!any(event_rows)) {
    return(Inf)
  }
  min(person_times[event_rows])
}



#' Use a BART probit fit of a survival model to impute event times
#'
#' For every person and every posterior draw, one Bernoulli coin is flipped
#' per remaining risk interval using the predicted conditional probability;
#' the person's event time is the earliest interval whose coin comes up 1
#' (see \code{\link{assign_time}}).
#'
#' @param pbart_fit BART probit fit
#' @param time_seq Sequence of times at which the survival data was split
#' @param x_new Design matrix for data to be predicted (should not contain
#' times except as an added baseline covariate)
#' @param min_times Time at which a person was last known to have survived (only
#' used to force congeniality with observed data); only intervals with time
#' strictly greater than this are simulated
#' @param mc.cores Number of cores to use for prediction
#' @return N x B matrix of imputed event times (\code{Inf} when no event is
#' drawn in any remaining interval)
#' @export
predict_times_from_pbart <- function(pbart_fit, time_seq, x_new,
                                     min_times = rep(0, NROW(x_new)),
                                     mc.cores = 4) {
  Nt <- length(time_seq)
  # surv.pre.bart uses the unique sorted times, so each person gets K rows
  K <- length(unique(time_seq))
  N <- NROW(x_new)
  B <- pbart_fit$ndpost
  tx_test <- surv.pre.bart(times = time_seq,
                           delta = rep(1, Nt),
                           # x.train is required but its output is unused;
                           # drop = FALSE keeps a one-column x_new a matrix
                           x.train = x_new[rep(1, Nt), , drop = FALSE],
                           x.test  = x_new)$tx.test
  id_rows <- rep(seq_len(N), each = K)
  to_keep <- which(tx_test[, "t"] > rep(min_times, each = K))
  # Nobody has an interval left after min_times: no event can be drawn
  if (length(to_keep) == 0) {
    return(matrix(Inf, nrow = N, ncol = B))
  }
  tx_test <- tx_test[to_keep, , drop = FALSE]
  id_rows <- id_rows[to_keep]
  time_labels <- tx_test[, "t"]

  pred <- predict(object = pbart_fit, newdata = tx_test, mc.cores = mc.cores)
  probs <- pred$prob.test
  # One Bernoulli draw per (posterior draw, person-interval) cell, vectorized
  coin_flips <- matrix(rbinom(length(probs), size = 1, prob = probs),
                       nrow = nrow(probs), ncol = ncol(probs))

  event_times <- matrix(NA_real_, nrow = N, ncol = B)
  for (b in seq_len(B)) {
    event_times[, b] <- vapply(seq_len(N),
                               FUN = assign_time,
                               FUN.VALUE = numeric(1),
                               id_rows = id_rows,
                               coin_flips = coin_flips[b, ],
                               time_labels = time_labels)
  }
  event_times
}


#' Use a single posterior draw of a BART probit survival fit to impute event
#' times
#'
#' @param b Index of the posterior draw to use (between 1 and
#' \code{pbart_fit$ndpost})
#' @param pbart_fit BART probit fit
#' @param time_seq Sequence of times at which the survival data was split
#' @param x_new Design matrix for data to be predicted (should not contain
#' times except as an added baseline covariate). Used to build \code{tx_new}
#' when that is \code{NULL}.
#' @param tx_new Optional precomputed person x interval matrix (as in
#' \code{surv.pre.bart(...)$tx.test}) with a \code{"t"} column. When given
#' without \code{x_new}, the number of people is inferred from its rows.
#' @param min_times Time at which a person was last known to have survived (only
#' used to force congeniality with observed data); defaults to 0 for everyone
#' @return Length-N vector of imputed event times (\code{Inf} when no event is
#' drawn, including people with no interval after their \code{min_times})
#' @export
predict_times_from_pbart_b <- function(b, pbart_fit,
                                       time_seq, x_new = NULL,
                                       tx_new = NULL,
                                       min_times = rep(0, N)) {
  stopifnot(!is.null(x_new) || !is.null(tx_new))
  Nt <- length(time_seq)
  # surv.pre.bart uses the unique sorted times, so each person gets K rows
  K <- length(unique(time_seq))
  N <- if (is.null(x_new)) NROW(tx_new) %/% K else NROW(x_new)
  B <- pbart_fit$ndpost
  stopifnot(1 <= b, b <= B)
  if (is.null(tx_new)) {
    tx_new <- surv.pre.bart(times = time_seq,
                            delta = rep(1, Nt),
                            # x.train is required but its output is unused;
                            # drop = FALSE keeps a one-column x_new a matrix
                            x.train = x_new[rep(1, Nt), , drop = FALSE],
                            x.test  = x_new)$tx.test
  }

  id_rows <- rep(seq_len(N), each = K)
  to_keep <- which(tx_new[, "t"] > rep(min_times, each = K))
  # People with no interval left after min_times can never have an event
  res <- rep(Inf, N)
  if (length(to_keep) == 0) {
    return(res)
  }
  tx_new <- tx_new[to_keep, , drop = FALSE]
  id_rows <- id_rows[to_keep]
  which_stoch <- which(seq_len(N) %in% id_rows)

  pred <- predict_prob_b(b = b, object = pbart_fit, newdata = tx_new)
  coin_flips <- rbinom(n = length(to_keep), size = 1, prob = pred)
  res[which_stoch] <- vapply(which_stoch,
                             FUN = assign_time,
                             FUN.VALUE = numeric(1),
                             id_rows = id_rows,
                             coin_flips = coin_flips,
                             time_labels = tx_new[, "t"])
  res
}
lcomm/mstatebart documentation built on May 7, 2019, 8:22 a.m.