# R/imputation-helpers.R
#
# Defines functions: pool_diffm_preds, pool_diffm, setup_mstate, set_mi_methods, get_predictor_mats

##********************************##
## Helpers for Imputation methods ##
##********************************##


# Set-up matrices/methods -------------------------------------------------


#' Build the predictor matrices used by MICE and smcfcs
#'
#' (Helper specific to the simulation study.) Every matrix is square with one
#' row and one column per variable in \code{dat}; all entries are 0 except in
#' row \code{"X"}, where the columns used as predictors of \code{X} are set
#' to 1.
#'
#' @param dat Dataset as generated by generate_dat(). The code below requires
#' columns named X, Z, ev1, eps, H1, H2, H1_Z and H2_Z.
#'
#' @return Named list of four prediction matrices: "CH1", "CH12", "CH12_int"
#' and "smcfcs"
#'
#' @noRd
get_predictor_mats <- function(dat) {

  var_names <- names(dat)
  n_vars <- ncol(dat)

  # All-zero square matrix in which only row "X" flags its predictors
  build_mat <- function(predictors) {
    pred_mat <- matrix(
      data = 0,
      nrow = n_vars,
      ncol = n_vars,
      dimnames = list(var_names, var_names)
    )
    pred_mat["X", predictors] <- 1
    pred_mat
  }

  list(
    # Ch1 model: Z, ev1 and H1 as predictors
    "CH1" = build_mat(c("Z", "ev1", "H1")),

    # Ch12 model: Z, eps, H1 and H2 as predictors
    "CH12" = build_mat(c("Z", "eps", "H1", "H2")),

    # Ch12 model extended with the H x Z interaction columns
    "CH12_int" = build_mat(c("Z", "eps", "H1", "H2", "H1_Z", "H2_Z")),

    # smcfcs: only Z (per the original author, same as the smcfcs defaults)
    "smcfcs" = build_mat("Z")
  )
}


#' Set methods vector for mice
#'
#' Chooses an imputation method for every variable in \code{var_names_miss}
#' based on its type: ordered factors get "polr", numeric variables get
#' \code{cont_method}, unordered factors with more than two levels get
#' "polyreg" and unordered two-level factors get "logreg". For
#' \code{imp_type = "smcfcs"}, "polyreg" and "polr" are renamed to "mlogit"
#' and "podds". All other columns of \code{dat} get the empty string.
#'
#' @param dat A dataframe to be imputed
#' @param var_names_miss Vector of characters names for variables with missing
#' data (may be empty)
#' @param imp_type Imputation type, either "mice" or "smcfcs"
#' @param cont_method Method use to impute continuous covariates, default is
#' "norm" but can for example change to "pmm"
#'
#' @return Named character vector with one element per column of \code{dat}
#'
#' @noRd
set_mi_methods <- function(dat,
                           var_names_miss,
                           imp_type = "mice",
                           cont_method = "norm") {

  # data.frames (including data.table/tibble) can be used directly; anything
  # else is converted to a data.table, as before
  if (!is.data.frame(dat)) dat <- data.table::data.table(dat)

  # Fail early with a readable message if a requested variable is absent
  if (is.character(var_names_miss)) {
    unknown_vars <- setdiff(var_names_miss, names(dat))
    if (length(unknown_vars) > 0) {
      stop(
        "Variables not found in `dat`: ",
        paste(unknown_vars, collapse = ", "),
        call. = FALSE
      )
    }
  }

  # Set up methods vector
  n_vars_miss <- length(var_names_miss)
  meths_miss <- stats::setNames(character(n_vars_miss), var_names_miss)

  # Columns to classify
  cols_miss <- lapply(var_names_miss, function(nm) dat[[nm]])

  # Get indicators: vapply() keeps these logical/integer vectors even when
  # there are no variables (sapply() would return list() and break the
  # subsetting below)
  ordered_ind <- vapply(cols_miss, is.ordered, logical(1))
  contin_ind <- vapply(cols_miss, is.numeric, logical(1))
  n_levels <- vapply(
    cols_miss,
    function(col) length(levels(col)),
    integer(1)
  )
  unordered_ind <- n_levels > 2 & !ordered_ind
  binary_ind <- n_levels == 2 & !ordered_ind

  # Set up methods
  meths_miss[ordered_ind] <- "polr"
  meths_miss[contin_ind] <- cont_method
  meths_miss[unordered_ind] <- "polyreg"
  meths_miss[binary_ind] <- "logreg"

  # Adjust if for smcfcs, which uses different names for these methods
  if (imp_type == "smcfcs") {
    meths_miss[meths_miss == "polyreg"] <- "mlogit"
    meths_miss[meths_miss == "polr"] <- "podds"
  }

  # Make global vec: "" for every column that is not imputed
  meths <- stats::setNames(character(ncol(dat)), names(dat))
  meths[var_names_miss] <- meths_miss

  meths
}


# Analysis method ---------------------------------------------------------


#' Set-up mstate model (pre-predicting) for simulated datasets
#'
#' Helper function for simulation study - it sets-up and run the
#' cause-specific Cox models using mstate. The data are put in long format
#' with \code{mstate::msprep()} (time column "t", status columns "ev1" and
#' "ev2"), the covariates X and Z are expanded into transition-specific
#' covariates, and a single stratified Cox model is fitted for both
#' transitions.
#'
#' @inheritParams get_predictor_mats
#'
#' @return Long cox model
#'
#' @noRd
setup_mstate <- function(dat) {

  # Set up transition matrix: two competing events, labelled "Rel" and "NRM"
  tmat <- mstate::trans.comprisk(K = 2, names = c("Rel", "NRM"))
  covs <- c("X", "Z")

  # Long format
  dat_msprepped <- mstate::msprep(
    time = c(NA, "t", "t"),
    status = c(NA, "ev1", "ev2"),
    data = dat,
    trans = tmat,
    keep = covs
  )

  # Use longnames if X has more than two levels (ordcat - per the original
  # author this case is irrelevant now)
  multilevel_x <- length(levels(dat$X)) > 2

  # Expand covariates into transition-specific columns
  dat_expanded <- mstate::expand.covs(
    dat_msprepped, covs, append = TRUE, longnames = multilevel_x
  )

  # Fit long cox model (both transitions). The formulas are kept literal so
  # the call stored in the fitted model is unchanged.
  if (multilevel_x) {

    # NOTE(review): hard-codes the expanded names Xinterm.* / Xhigh.*, i.e.
    # presumably assumes X has non-reference levels "interm" and "high" -
    # confirm against generate_dat()
    cox_long <- survival::coxph(
      Surv(time, status) ~
        Xinterm.1 + Xhigh.1 + Z.1 + # Trans == 1
        Xinterm.2 + Xhigh.2 + Z.2 + # Trans == 2
        strata(trans), # Separate baseline hazards
      data = dat_expanded
    )

  } else {

    cox_long <- survival::coxph(
      Surv(time, status) ~
        X.1 + Z.1 + # Trans == 1
        X.2 + Z.2 + # Trans == 2
        strata(trans), # Separate baseline hazards
      data = dat_expanded
    )

  }

  return(cox_long)
}


# Pooling -----------------------------------------------------------------


#' Pool for multiple m - regression coefficients
#'
#' Pools across first \code{n_imp} (vector) imputed datasets: for every
#' \code{m} in \code{n_imp}, the first \code{m} elements of
#' \code{mods_impdats} are pooled with \code{mice::pool()} and summarised
#' with confidence intervals.
#'
#' @param mods_impdats List of models fitted on the imputed datasets
#' @param n_imp Vector of m imputations to extract
#' @param analy String label to attach
#'
#' @return Data frame with the pooled summaries for all \code{m} stacked
#' row-wise; the "term" column is renamed to "var" (as character), columns
#' "m" and "analy" are added, and "statistic" and "df" are dropped.
#'
#' @noRd
pool_diffm <- function(mods_impdats,
                       n_imp,
                       analy) {

  purrr::map_dfr(n_imp, function(m) {

    # seq_len() rather than 1:m, so that m = 0 does not silently select the
    # first model (1:0 is c(1, 0)).
    # NOTE(review): dfcom = 999999 presumably stands in for a very large
    # complete-data degrees of freedom - confirm.
    pooled <- mice::pool(mods_impdats[seq_len(m)], dfcom = 999999)

    summary(pooled, conf.int = TRUE) %>%
      dplyr::rename(var = "term") %>%
      dplyr::mutate(var = as.character(.data$var)) %>%
      dplyr::mutate(m = m,
                    analy = analy) %>%
      # Strings instead of .data$...: using .data inside selections is
      # deprecated in tidyselect
      dplyr::select(-"statistic", -"df")
  })
}


#' Pool for multiple m - for probabilities
#'
#' For every \code{m} in \code{n_imp}, stacks the first \code{m} elements of
#' \code{preds_impdats}, averages the predicted state probabilities within
#' each state / X-Z combination / time point / true value, and adds the
#' squared error against the true value.
#'
#' @inheritParams pool_diffm
#' @param preds_impdats List of data frames with predictions, one per imputed
#' dataset. The code below uses columns pstate1 to pstate3, true_pstate2 to
#' true_pstate1 (as a column range), X, Z and times.
#'
#' @return Data frame with columns state, combo-X_Z, times, true, m, p_pool,
#' analy and sq_err, stacked row-wise over all \code{m}.
#'
#' @noRd
pool_diffm_preds <- function(preds_impdats,
                             n_imp,
                             analy) {

  pooled_preds <- purrr::map_dfr(n_imp, function(m) {

    # Take first m imputed datasets; seq_len() rather than 1:m, so that
    # m = 0 does not silently select the first dataset (1:0 is c(1, 0))
    dplyr::bind_rows(preds_impdats[seq_len(m)], .id = "imp_num") %>%

      # Make long format (pstate and true) to make pooling easier.
      # Column ranges are given as strings: using .data inside selections
      # is deprecated in tidyselect.
      tidyr::pivot_longer(
        cols = "pstate1":"pstate3",
        names_to = "state_est",
        values_to = "prob"
      ) %>%
      tidyr::pivot_longer(
        cols = "true_pstate2":"true_pstate1",
        names_to = "state_true",
        values_to = "true"
      ) %>%

      # Make single state variable: the two pivots pair every estimated
      # state with every true state, only matching pairs get a label and
      # the rest (NA) are dropped
      tidyr::unite(
        "state",
        "state_est",
        "state_true"
      ) %>%
      dplyr::mutate(
        state = dplyr::case_when(
          stringr::str_detect(.data$state, "pstate1_true_pstate1") ~ "1",
          stringr::str_detect(.data$state, "pstate2_true_pstate2") ~ "2",
          stringr::str_detect(.data$state, "pstate3_true_pstate3") ~ "3"
        )
      ) %>%
      dplyr::filter(!is.na(.data$state)) %>%
      tidyr::unite("combo-X_Z", "X", "Z", sep = "_X-Z_") %>%

      # Rubins rules - only point estimate
      dplyr::group_by(
        .data$state,
        .data$`combo-X_Z`,
        .data$times,
        .data$true
      ) %>%

      # No transformation before pooling; m is the number of rows in the
      # group (presumably one per imputed dataset - confirm)
      dplyr::summarise(m = dplyr::n(), p_pool = mean(.data$prob)) %>%
      dplyr::mutate(analy = analy) %>%
      dplyr::ungroup() %>%

      # Add Squared error
      dplyr::mutate(sq_err = (.data$p_pool - .data$true)^2)
  })

  return(pooled_preds)
}
# survival-lumc/CauseSpecCovarMI documentation built on June 16, 2022, 9:51 a.m.