#' Fit a survival BART using probit regression on discretized time
#'
#' Pre-processes the inputs with \code{surv.pre.bart} and hands the resulting
#' \code{y.train} and \code{tx.train} to \code{\link{mc.gbart}} with
#' \code{type = "pbart"}.
#'
#' @param times Times at which censoring or exit event occurs
#' @param delta Event-is-observed indicator
#' @param x_train Design matrix (excludes times)
#' @param mc.cores Number of cores to use
#' @param ... Other parameters to pass to \code{\link{mc.gbart}}
#' @return pbart object (whatever \code{\link{mc.gbart}} returns)
#' @export
deconstructed_surv_bart <- function(times, delta, x_train, mc.cores = 4, ...) {
  # Build the training response / design pair expected by the probit fit.
  expanded <- surv.pre.bart(times = times, delta = delta, x.train = x_train)

  fit <- mc.gbart(
    x.train = expanded$tx.train,
    y.train = expanded$y.train,
    type = "pbart",
    mc.cores = mc.cores,
    ...
  )
  fit
}
#' Run a survival BART from a named list of inputs
#'
#' Wrapper for \code{\link{deconstructed_surv_bart}} that runs a survival BART
#' but leaves it in probit form for easier event time prediction.
#'
#' @param input_list Named list of inputs; must contain elements named
#'   \code{times}, \code{delta} and \code{x_train}
#'   (see \code{\link{deconstructed_surv_bart}})
#' @param mc.cores Number of cores on which to run
#' @param ... Further arguments forwarded to
#'   \code{\link{deconstructed_surv_bart}} (and from there to
#'   \code{\link{mc.gbart}})
#' @return Probit BART result
#' @export
surv_bart_from_input_list <- function(input_list, mc.cores = 4, ...) {
  # Fail early if any of the three required elements is missing.
  stopifnot(c("times", "delta", "x_train") %in% names(input_list))

  event_times <- input_list[["times"]]
  observed <- input_list[["delta"]]
  design <- input_list[["x_train"]]

  deconstructed_surv_bart(
    times = event_times,
    delta = observed,
    x_train = design,
    mc.cores = mc.cores,
    ...
  )
}
#' Use binary conditional probability coin flip results to assign an actual
#' event time to one person (or \code{Inf} if censored)
#'
#' @param id Scalar ID of the person to assign an event time to
#' @param id_rows Length-R label for which person the coin flips correspond to,
#'   where R is the number of person x risk interval combinations
#' @param coin_flips Length-R vector of 0/1 (or TRUE/FALSE) flips generated by
#'   the predicted (conditional) discrete hazard of that person having an event
#'   in that interval
#' @param time_labels Length-R vector of times which each flip corresponds to
#' @return Scalar event time for the person with ID \code{id}: the time label
#'   of the first of that person's rows (in row order) whose flip is a 1
#'   (i.e., an event), or \code{Inf} if none of the person's flips is a 1
#'   (including when the person has no rows at all)
#' @export
assign_time <- function(id, id_rows, coin_flips, time_labels) {
  this_id <- which(id_rows == id)
  zero_one_string <- coin_flips[this_id]
  this_id_time_seq <- time_labels[this_id]

  # No event in any of this person's intervals: treat as censored.
  if (sum(zero_one_string) == 0) {
    return(Inf)
  }

  # Index into this person's own time labels rather than the global sorted
  # time grid: a person's rows may start after the first grid point (e.g. when
  # earlier intervals were filtered out), so positions in the two do not align.
  this_id_time_seq[which.max(zero_one_string)]
}
#' Use a BART probit fit of a survival model to impute event times
#'
#' For every posterior draw, predicted probabilities for each
#' person x time row are turned into Bernoulli coin flips, and
#' \code{\link{assign_time}} converts each person's flips into an event time.
#'
#' @param pbart_fit BART probit fit; its \code{ndpost} element gives the number
#'   of posterior draws B
#' @param time_seq Sequence of times at which the survival data was split
#' @param x_new Design matrix for data to be predicted (should not contain
#'   times except as an added baseline covariate)
#' @param min_times Time at which a person was last known to have survived (only
#'   used to force congeniality with observed data). Either one value per row
#'   of \code{x_new} or a single value applied to everyone; only rows with time
#'   strictly greater than this are used
#' @param mc.cores Number of cores to use for prediction
#' @return N x B matrix of imputed event times (or \code{Inf}, if censored)
#' @export
predict_times_from_pbart <- function(pbart_fit, time_seq, x_new,
                                     min_times = rep(0, NROW(x_new)),
                                     mc.cores = 4) {
  Nt <- length(time_seq)
  N <- NROW(x_new)
  B <- pbart_fit$ndpost

  # A scalar lower bound applies to every person; anything else must match N.
  if (length(min_times) == 1L) {
    min_times <- rep(min_times, N)
  }
  stopifnot(length(min_times) == N)

  # surv.pre.bart needs an x.train argument; per the original author's note
  # its values are not used for tx.test. drop = FALSE keeps a one-column
  # design as a matrix.
  tx_test <- surv.pre.bart(times = time_seq,
                           delta = rep(1, Nt),
                           x.train = x_new[rep(1, Nt), , drop = FALSE],
                           x.test = x_new)$tx.test

  unique_ids <- seq_len(N)
  # Assumes tx.test holds Nt consecutive rows per person, in the row order of
  # x_new (as the original code does).
  id_rows <- rep(unique_ids, each = Nt)

  # Keep only intervals after each person's last known survival time.
  to_keep <- which(tx_test[, "t"] > rep(min_times, each = Nt))
  if (length(to_keep) == 0L) {
    # Nobody has any interval left to simulate: everyone is censored.
    return(matrix(Inf, nrow = N, ncol = B))
  }
  tx_test <- tx_test[to_keep, , drop = FALSE]
  id_rows <- id_rows[to_keep]
  time_labels <- tx_test[, "t"]

  pred <- predict(object = pbart_fit, newdata = tx_test, mc.cores = mc.cores)

  # One Bernoulli draw per (posterior draw, row) cell, done in a single
  # vectorized call instead of element by element.
  prob <- pred$prob.test
  coin_flips <- matrix(rbinom(n = length(prob), size = 1, prob = prob),
                       nrow = nrow(prob),
                       ncol = ncol(prob))

  event_times <- matrix(NA_real_, nrow = N, ncol = B)
  for (b in seq_len(B)) {
    event_times[, b] <- vapply(unique_ids,
                               FUN = assign_time,
                               FUN.VALUE = numeric(1),
                               id_rows = id_rows,
                               coin_flips = coin_flips[b, ],
                               time_labels = time_labels)
  }
  event_times
}
#' Use a single posterior draw of a BART probit survival fit to impute event
#' times
#'
#' @param b Index of the posterior draw to use; must satisfy
#'   \code{1 <= b <= pbart_fit$ndpost}
#' @param pbart_fit BART probit fit
#' @param time_seq Sequence of times at which the survival data was split
#' @param x_new Design matrix for data to be predicted (should not contain
#'   times except as an added baseline covariate). May be \code{NULL} when
#'   \code{tx_new} is supplied
#' @param tx_new Optional pre-built person x time test matrix with a column
#'   named \code{"t"}; when \code{NULL} it is built from \code{x_new} with
#'   \code{surv.pre.bart}
#' @param min_times Time at which a person was last known to have survived (only
#'   used to force congeniality with observed data). Either one value per
#'   person or a single value applied to everyone
#' @return Length-N vector of imputed event times (or \code{Inf}, if censored),
#'   where N is the number of people
#' @export
predict_times_from_pbart_b <- function(b, pbart_fit,
                                       time_seq, x_new = NULL,
                                       tx_new = NULL,
                                       min_times) {
  Nt <- length(time_seq)
  B <- pbart_fit$ndpost
  stopifnot(1 <= b, b <= B)

  if (is.null(tx_new)) {
    if (is.null(x_new)) {
      stop("Either `x_new` or `tx_new` must be supplied.", call. = FALSE)
    }
    # surv.pre.bart needs an x.train argument; per the original author's note
    # its values are not used for tx.test.
    tx_new <- surv.pre.bart(times = time_seq,
                            delta = rep(1, Nt),
                            x.train = x_new[rep(1, Nt), , drop = FALSE],
                            x.test = x_new)$tx.test
  }

  # When only tx_new is given, NROW(x_new) would be 0; recover the number of
  # people from tx_new instead (assumes Nt consecutive rows per person).
  N <- if (is.null(x_new)) nrow(tx_new) %/% Nt else NROW(x_new)

  if (length(min_times) == 1L) {
    min_times <- rep(min_times, N)
  }
  stopifnot(length(min_times) == N)

  unique_ids <- seq_len(N)
  id_rows <- rep(unique_ids, each = Nt)

  # Keep only intervals after each person's last known survival time.
  to_keep <- which(tx_new[, "t"] > rep(min_times, each = Nt))

  # People with no remaining intervals are "determined": they stay at Inf.
  res <- rep(Inf, N)
  if (length(to_keep) == 0L) {
    return(res)
  }

  tx_new <- tx_new[to_keep, , drop = FALSE]
  id_rows <- id_rows[to_keep]
  which_stoch <- which(unique_ids %in% id_rows)

  # predict_prob_b is defined outside this block; it is used here as the
  # per-row event probabilities for draw b -- TODO confirm its return shape.
  pred <- predict_prob_b(b = b, object = pbart_fit, newdata = tx_new)
  coin_flips <- rbinom(n = length(to_keep), size = 1, prob = pred)

  # vapply guarantees exactly one numeric time per stochastic person.
  res[which_stoch] <- vapply(which_stoch,
                             FUN = assign_time,
                             FUN.VALUE = numeric(1),
                             id_rows = id_rows,
                             coin_flips = coin_flips,
                             time_labels = tx_new[, "t"])
  res
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.