# R/predict.R

# Defines functions calculate_predictions calculate_survival calculate_cum_inc check_newdata predict.TE_msm

# Documented in calculate_predictions check_newdata predict.TE_msm

#' Predict marginal cumulative incidences with confidence intervals for a target trial population
#'
#' This function predicts the marginal cumulative incidences when a target trial population receives either the
#' treatment or non-treatment at baseline (for an intention-to-treat analysis) or either sustained treatment or
#' sustained non-treatment (for a per-protocol analysis). The difference between these cumulative incidences is the
#' estimated causal effect of treatment. Currently, the `predict` function only provides marginal intention-to-treat and
#' per-protocol effects, therefore it is only valid when `estimand_type = "ITT"` or `estimand_type = "PP"`. A warning
#' is issued when `estimand_type` is `"As-Treated"`.
#'
#' @param object Object from [trial_msm()] or [initiators()].
#' @param newdata Baseline trial data that characterise the target trial population that marginal cumulative incidences
#'   or survival probabilities are predicted for.  `newdata` must have the same columns and formats of variables as in
#'   the fitted marginal structural model specified in [trial_msm()] or [initiators()]. If `newdata` contains rows with
#'   `followup_time > 0` these will be removed. If `newdata` is missing, the baseline rows (`followup_time == 0`) of
#'   the data stored in the fitted model are used.
#' @param type Specify cumulative incidences or survival probabilities to be predicted. Either cumulative incidence
#'   (`"cum_inc"`) or survival probability (`"survival"`).
#' @param predict_times Specify the follow-up visits/times where the marginal cumulative incidences or survival
#'   probabilities are predicted. Must be non-negative whole numbers. Survival is accumulated as a cumulative product
#'   over `predict_times` in the order given, so this would normally be a consecutive sequence starting at 0, as in
#'   the example below.
#' @param conf_int Construct the point-wise 95-percent confidence intervals of cumulative incidences for the target
#'   trial population under treatment and non-treatment and their differences by simulating the parameters in the
#'   marginal structural model from a multivariate normal distribution with the mean equal to the marginal structural
#'   model parameter estimates and the variance equal to the estimated robust covariance matrix.
#' @param samples Number of samples used to construct the simulation-based confidence intervals.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return A named list of three data frames (`assigned_treatment_0`, `assigned_treatment_1` and `difference`)
#'   containing the cumulative incidences (or survival probabilities) for each of the assigned treatment options
#'   (non-treatment and treatment) and the difference between them. Each data frame has a `followup_time` column and
#'   an estimate column named after `type` (with suffix `_diff` for the difference). When `conf_int = TRUE` the
#'   columns `2.5%` and `97.5%` hold the point-wise interval limits.
#' @export
#' @importFrom stats .checkMFClasses coef delete.response model.frame model.matrix terms setNames
#' @examples
#' # If necessary set the number of `data.table` threads
#' data.table::setDTthreads(2)
#'
#' data("te_model_ex")
#' predicted_ci <- predict(te_model_ex, predict_times = 0:30, samples = 10)
#'
#' # Plot the cumulative incidence curves under treatment and non-treatment
#' plot(predicted_ci[[1]]$followup_time, predicted_ci[[1]]$cum_inc,
#'   type = "l",
#'   xlab = "Follow-up Time", ylab = "Cumulative Incidence",
#'   ylim = c(0, 0.7)
#' )
#' lines(predicted_ci[[1]]$followup_time, predicted_ci[[1]]$`2.5%`, lty = 2)
#' lines(predicted_ci[[1]]$followup_time, predicted_ci[[1]]$`97.5%`, lty = 2)
#'
#' lines(predicted_ci[[2]]$followup_time, predicted_ci[[2]]$cum_inc, type = "l", col = 2)
#' lines(predicted_ci[[2]]$followup_time, predicted_ci[[2]]$`2.5%`, lty = 2, col = 2)
#' lines(predicted_ci[[2]]$followup_time, predicted_ci[[2]]$`97.5%`, lty = 2, col = 2)
#' legend("topleft", title = "Assigned Treatment", legend = c("0", "1"), col = 1:2, lty = 1)
#'
#' # Plot the difference in cumulative incidence over follow up
#' plot(predicted_ci[[3]]$followup_time, predicted_ci[[3]]$cum_inc_diff,
#'   type = "l",
#'   xlab = "Follow-up Time", ylab = "Difference in Cumulative Incidence",
#'   ylim = c(0.0, 0.5)
#' )
#' lines(predicted_ci[[3]]$followup_time, predicted_ci[[3]]$`2.5%`, lty = 2)
#' lines(predicted_ci[[3]]$followup_time, predicted_ci[[3]]$`97.5%`, lty = 2)
#'
predict.TE_msm <- function(object,
                           newdata,
                           predict_times,
                           conf_int = TRUE,
                           samples = 100,
                           type = c("cum_inc", "survival"),
                           ...) {
  # isTRUE() keeps the condition a scalar logical even when `estimand_type` is NULL, NA or not length one.
  if (isTRUE(object$args$estimand_type == "As-Treated")) {
    warning("As-Treated estimands are not currently supported by this predict method. Results may be unexpected.")
  }

  assert_class(object$model, "glm")
  model <- object$model
  type <- match.arg(type)
  assert_integerish(predict_times, lower = 0, min.len = 1)
  assert_flag(conf_int)
  assert_int(samples, lower = 1)

  # Row 1 holds the point estimates; the simulated coefficient vectors (if requested) follow below it.
  coefs_mat <- matrix(coef(model), nrow = 1)
  if (conf_int) {
    assert_matrix(object$robust$matrix, nrows = ncol(coefs_mat), ncols = ncol(coefs_mat))
    coefs_mat <- rbind(coefs_mat, mvtnorm::rmvnorm(n = samples, mean = coef(model), sigma = object$robust$matrix))
  }

  # `newdata` is forwarded as-is so that check_newdata() can detect a missing argument itself.
  newdata <- check_newdata(newdata, model, predict_times)

  # `type` has been validated by match.arg(), so exactly one branch applies.
  pred_fun <- switch(type,
    survival = calculate_survival,
    cum_inc = calculate_cum_inc
  )

  # One matrix per treatment value: rows are follow-up times, columns are the rows of `coefs_mat`.
  pred_list <- calculate_predictions(
    newdata = newdata,
    model = model,
    treatment_values = c(assigned_treatment_0 = 0, assigned_treatment_1 = 1),
    pred_fun = pred_fun,
    coefs_mat = coefs_mat,
    matrix_n_col = length(predict_times)
  )

  pred_list$difference <- pred_list$assigned_treatment_1 - pred_list$assigned_treatment_0

  mapply(
    pred_matrix = pred_list,
    col_names = paste0(type, c("", "", "_diff")),
    SIMPLIFY = FALSE,
    FUN = function(pred_matrix, col_names) {
      if (conf_int) {
        # Point-wise limits: quantiles across all columns of each follow-up time's row.
        quantiles <- apply(pred_matrix, 1, stats::quantile, probs = c(0.025, 0.975))
        setNames(
          data.frame(predict_times, pred_matrix[, 1], quantiles[1, ], quantiles[2, ]),
          c("followup_time", col_names, "2.5%", "97.5%")
        )
      } else {
        setNames(data.frame(predict_times, pred_matrix[, 1]), nm = c("followup_time", col_names))
      }
    }
  )
}



#' Check Data used for Prediction
#'
#' Keeps only the model variables and the baseline rows (`followup_time == 0`) of `newdata`, or of `model$data` when
#' `newdata` is missing. For user-supplied `newdata`, the column attributes are compared with those of the fitting
#' data; on a mismatch a warning is issued and a fix is attempted by row-binding onto an empty copy of the fitting
#' data. If the attributes still differ an error is raised that lists the differences. Finally the baseline rows are
#' repeated once for each element of `predict_times`.
#'
#' @param newdata new data to predict, or missing.
#' @param model glm model object. `model$data` is subset with `with = FALSE`, so it is presumably a `data.table`.
#' @param predict_times times to predict to add to resulting newdata.
#' @keywords internal
#' @return A `newdata` data.frame containing one copy of the baseline rows per element of `predict_times`, with
#'   `followup_time` set to the corresponding prediction time.
check_newdata <- function(newdata, model, predict_times) {
  required_vars <- setdiff(all.vars(model$formula), "outcome")
  if (missing(newdata)) {
    newdata <- model$data[, required_vars, with = FALSE]
    newdata <- newdata[newdata$followup_time == 0, ]
  } else {
    assert_data_frame(newdata, min.rows = 1)
    assert_names(colnames(newdata), must.include = required_vars)
    newdata <- data.table(newdata)[, required_vars, with = FALSE]
    newdata <- newdata[newdata$followup_time == 0, ]
    col_attr_model <- lapply(model$data[, required_vars, with = FALSE], attributes)
    col_attr_newdata <- lapply(newdata, attributes)
    if (!isTRUE(all.equal(col_attr_model, col_attr_newdata))) {
      warning("Attributes of newdata do not match data used for fitting. Attempting to fix.")
      newdata <- rbind(model$data[0, required_vars, with = FALSE], newdata)
      # all.equal() returns TRUE or a character vector describing the differences, never FALSE,
      # so the result has to be tested with isTRUE() rather than negated directly.
      fixed <- all.equal(col_attr_model, lapply(newdata, attributes))
      if (!isTRUE(fixed)) {
        stop("Attributes do not match.\n", paste(fixed, collapse = "\n"))
      }
    }
  }

  # Stack one copy of the baseline rows per prediction time, blocks ordered as in `predict_times`.
  n_baseline <- nrow(newdata)
  newdata <- newdata[rep(seq_len(n_baseline), times = length(predict_times)), ]
  newdata$followup_time <- rep(predict_times, each = n_baseline)

  newdata
}


#' Calculate Cumulative Incidence and Survival
#'
#' The cumulative incidence is obtained as one minus the value returned by `calculate_survival()` for the same
#' matrix. Both results are checked for monotonicity before being returned.
#'
#' @param p_mat Numeric probability matrix with one row per subject and one column per follow-up time.
#'
#' @return A vector containing the cumulative incidence or survival values.
#' @noRd
#' @examples
#' surv_prob <- matrix(
#'   c(
#'     0.1, 0.1, 0.1,
#'     0.5, 0.2, 0.1
#'   ),
#'   nrow = 2,
#'   byrow = TRUE
#' )
#' TrialEmulation:::calculate_cum_inc(surv_prob)
calculate_cum_inc <- function(p_mat) {
  assert_matrix(p_mat, mode = "numeric")
  # Complement of the survival values computed from the same probabilities.
  survival <- calculate_survival(p_mat)
  cum_inc <- 1 - survival
  assert_monotonic(cum_inc)
  cum_inc
}

#' @rdname calculate_cum_inc
#' @noRd
#' @keywords internal
#' @examples
#' TrialEmulation:::calculate_survival(matrix(c(0.1, 0.5, 0.1, 0.2, 0.1, 0.1), nrow = 2))
calculate_survival <- function(p_mat) {
  assert_matrix(p_mat, mode = "numeric")
  # Cumulative product of (1 - p) along each subject's row: one row per follow-up time, one column per subject.
  cum_surv <- apply(1 - p_mat, 1, cumprod)
  # apply() simplifies to a plain vector when every row yields a single value (only one follow-up time),
  # which rowMeans() rejects. Restore the times x subjects layout in that case.
  if (is.null(dim(cum_surv))) {
    dim(cum_surv) <- c(ncol(p_mat), nrow(p_mat))
  }
  # Average over subjects at each follow-up time.
  result <- rowMeans(cum_surv)
  assert_monotonic(result, increasing = FALSE)
  result
}



#' Calculate and transform predictions
#'
#' For every value in `treatment_values`, the `assigned_treatment` column of the model frame is overwritten with that
#' value, the design matrix is rebuilt, and predictions on the response scale are computed for each row of
#' `coefs_mat`. Each prediction vector is reshaped into a matrix with `matrix_n_col` columns and passed to `pred_fun`.
#'
#' @param model GLM object
#' @param newdata New data to predict outcome
#' @param treatment_values Named vector of value to insert into `assigned_treatment` column
#' @param pred_fun Function to transform prediction matrix
#' @param coefs_mat Matrix of coefficients, one coefficient vector per row, with columns matching the columns of
#'  the model matrix.
#' @param matrix_n_col Expected number of column after prediction.
#'
#' @return A list with one element per entry of `treatment_values` (names are kept). Each element is a matrix of
#'  transformed predicted values whose number of columns corresponds to the number of rows of `coefs_mat`.
#' @keywords internal
calculate_predictions <- function(newdata, model, treatment_values, pred_fun, coefs_mat, matrix_n_col) {
  rhs_terms <- delete.response(terms(model))
  frame <- model.frame(rhs_terms, newdata, xlev = model$xlevels)

  # Check the variable classes in `newdata` against those recorded at fit time, when available.
  expected_classes <- attr(rhs_terms, "dataClasses")
  if (!is.null(expected_classes)) {
    .checkMFClasses(expected_classes, frame)
  }

  inverse_link <- model$family$linkinv

  # Transformed predictions for one treatment value, one column per row of `coefs_mat`.
  predict_under_treatment <- function(treatment_value) {
    frame$assigned_treatment <- treatment_value
    design <- model.matrix(rhs_terms, frame, contrasts.arg = model$contrasts)

    transform_one <- function(row_index) {
      beta <- t(coefs_mat[row_index, , drop = FALSE])
      response <- inverse_link(design %*% beta)
      pred_fun(matrix(response, ncol = matrix_n_col))
    }

    transformed <- lapply(seq_len(nrow(coefs_mat)), transform_one)
    matrix(unlist(transformed), ncol = length(transformed))
  }

  lapply(treatment_values, predict_under_treatment)
}

# Try the TrialEmulation package in your browser

# Any scripts or data that you put into this service are public.

# TrialEmulation documentation built on Sept. 11, 2024, 9:06 p.m.