# Pairwise binomial mixed models
# Jan van den Brand, PhD
# jan.vandenbrand@kuleuven.be
# janvandenbrand@gmail.com
# Project: DKF19OK003

# The work here is based on Fieuws S. Mixed Models for Longitudinal Data 2006

packages <- c("optimx", "ggplot2","lattice","MASS", "Matrix", "gridExtra", "tidyverse",
              "GLMMadaptive", "stringr", "flexsurv", "future.apply", "progressr", "viridis")
lapply(packages, library, character.only = TRUE)

#' Bar charts for categorical variables
#'
#' Make a bar chart of one categorical variable in a data.frame, optionally
#' faceted (in columns) by a grouping variable.
#' @param data A data.frame
#' @param var A character value with the name of the categorical variable in data.
#' @param by A character value giving the name of the variable to group by (e.g. event status)
#' @import ggplot2
#' @import gridExtra
#' @export
#' @return A ggplot object.
#' @examples
#' make_bar_plot(data = mtcars, var = "cyl", by = "am")
make_bar_plot <- function(data, var, by = NULL) {
  stopifnot("var should be a character" = is.character(var))
  # .data[[var]] looks the column up by name inside ggplot's data mask,
  # which is the documented way to map a column given as a string.
  bar_plot <- ggplot(data = data) +
    geom_bar(aes(x = .data[[var]], fill = .data[[var]])) +
    theme_minimal() +
    labs(x = "", y = "") +
    guides(fill = guide_legend(title = paste(var)))
  if (!is.null(by)) {
    stopifnot("by should be a character or NULL" = is.character(by))
    bar_plot <- bar_plot + facet_grid(cols = vars(.data[[by]]))
  }
  bar_plot
}

#' Histograms for continuous variables
#'
#' Make a histogram of one continuous variable in a data.frame.
#' @param data A data.frame
#' @param var A character value with the name of the numeric variable in data.
#' @param bins Number of histogram bins (default 20, as before).
#' @import ggplot2
#' @import gridExtra
#' @export
#' @return A ggplot object.
make_histogram <- function(data, var, bins = 20) {
  stopifnot("var should be a character" = is.character(var))
  ggplot(data = data, aes(x = .data[[var]])) +
    geom_histogram(bins = bins) +
    theme_classic() +
    labs(x = paste("histogram of", var))
}

#' Density plots for continuous variables
#'
#' Make a density plot of one continuous variable in a data.frame, optionally
#' with one filled density per level of a grouping variable (e.g. outcome).
#' @param data A data.frame
#' @param var A character value with the name of the numeric variable in data.
#' @param by A character value giving the name of the variable to group by (e.g. event status)
#' @import ggplot2
#' @import gridExtra
#' @export
#' @return A ggplot object.
make_density <- function(data, var, by = NULL) {
  stopifnot("var should be a character" = is.character(var))
  if (is.null(by)) {
    mapping <- aes(x = .data[[var]])
  } else {
    stopifnot("by should be a character or NULL" = is.character(by))
    # as.factor() so that a numeric 0/1 status still yields one density per group
    mapping <- aes(x = .data[[var]], fill = as.factor(.data[[by]]))
  }
  density_plot <- ggplot(data = data, mapping = mapping) +
    geom_density(alpha = 0.5) +
    theme_classic() +
    labs(x = paste("density plot of", var))
  if (!is.null(by)) {
    density_plot <- density_plot + labs(fill = by)
  }
  density_plot
}

#' Spaghetti plots for longitudinal variables
#'
#' Make spaghetti plots (one line per subject plus a loess trend) for a
#' longitudinal variable in a data.frame.
#' @param data A data.frame in long format.
#' @param ftime A character value giving the name of the time variable.
#' @param var A character value giving the name of the longitudinal variable.
#' @param id A character value giving the name of the variable with the subject identifier.
#' @param by A character value giving the name of the variable to group the trajectories by (e.g. event status)
#' @param breaks A numeric value with the major tick mark for the x-axis.
#' @import dplyr
#' @import ggplot2
#' @import gridExtra
#' @return A ggplot object.
make_spaghetti_plots <- function(data, ftime, var, id, by = NULL, breaks = 1) {
  if (is.null(by)) {
    trajectory_aes <- aes(x = .data[[ftime]], y = .data[[var]], group = .data[[id]])
    # Override the per-subject grouping so a single overall trend is fitted
    trend_aes <- aes(group = 1)
  } else {
    trajectory_aes <- aes(x = .data[[ftime]], y = .data[[var]], group = .data[[id]],
                          color = as.factor(.data[[by]]))
    # One trend per level of `by`, not one per subject
    trend_aes <- aes(group = as.factor(.data[[by]]))
  }
  spaghetti <- ggplot(data = data, mapping = trajectory_aes) +
    geom_jitter(size = 0.2, alpha = 0.3, width = 0.01, height = 0.01) +
    geom_line(size = 0.1, alpha = 0.2) +
    geom_smooth(mapping = trend_aes, method = 'loess') +
    scale_x_continuous(breaks = scales::breaks_width(breaks)) +
    theme_classic() +
    labs(x = "Follow-up time", y = paste(var))
  if (!is.null(by)) {
    spaghetti <- spaghetti + labs(color = by)
  }
  spaghetti
}

#' Make pairs
#'
#' Take a character vector with the names of the longitudinal outcome variables
#' and make all unique unordered pairs.
#' @param outcomes a character vector with the names of longitudinal outcomes
#'   (at least two).
#' @returns a matrix with length(outcomes)*(length(outcomes)-1)/2 rows and 2
#'   columns holding all unique pairs, ordered (1,2), (1,3), ..., (2,3), ...
make_pairs <- function(outcomes) {
  if (length(outcomes) < 2) {
    stop("make_pairs needs at least two outcomes to form a pair.", call. = FALSE)
  }
  # combn enumerates combinations in the same order as a nested i < j loop;
  # transpose so that each pair becomes one row.
  t(utils::combn(outcomes, 2))
}

#' Test input data types and return model families and indicator for binary outcomes
#'
#' Takes a data.frame with longitudinal data in long format and checks, for
#' every pair of outcomes, whether each outcome is binary (its non-missing
#' values are exactly 0 and 1) or numeric.
#' @param data A data.frame in long format
#' @param pairs A character matrix constructed by make_pairs function.
#' @return A list with two elements, each of length nrow(pairs):
#' \describe{
#'   \item{families}{character vector with the custom model family:
#'     "binomial", "binary.normal" or "normal".}
#'   \item{indicators}{a list; for "binary.normal" pairs c(1, 0) when the first
#'     outcome is binary and c(0, 1) when the second one is; NULL otherwise.}
#' }
test_input_datatypes <- function (data, pairs) {
  # setequal() compares value sets, so integer, double and logical 0/1 coding
  # are all recognised without length-recycling surprises.
  is_binary <- function(x) setequal(unique(x[!is.na(x)]), c(0, 1))
  n_pairs <- nrow(pairs)
  .families <- vector("character", length = n_pairs)
  .indicators <- vector("list", length = n_pairs)
  for (i in seq_len(n_pairs)) {
    y1 <- data[[pairs[i, 1]]]
    y2 <- data[[pairs[i, 2]]]
    y1_binary <- is_binary(y1)
    y2_binary <- is_binary(y2)
    if (y1_binary && y2_binary) {
      .families[i] <- "binomial"
    } else if (y1_binary && is.numeric(y2)) {
      .families[i] <- "binary.normal"
      .indicators[[i]] <- c(1, 0)
    } else if (is.numeric(y1) && y2_binary) {
      .families[i] <- "binary.normal"
      .indicators[[i]] <- c(0, 1)
    } else if (is.numeric(y1) && is.numeric(y2)) {
      .families[i] <- "normal"
    } else {
      # Previously a binary y1 paired with an unsupported y2 fell through silently
      stop("The outcomes variables do not seem to be binary or continuous: ",
           pairs[i, 1], ", ", pairs[i, 2], call. = FALSE)
    }
  }
  list("families" = .families, "indicators" = .indicators)
}

#' Return a list with stacked data sets for each of the outcome pairs
#'
#' Takes a data.frame with longitudinal data to be used in the multivariate
#' mixed model function. For every pair created with make_pairs(), the two
#' outcome columns are interleaved row by row into a single response `Y`
#' (y1[1], y2[1], y1[2], y2[2], ...). `X` holds the *other* outcome of the
#' pair, so each outcome is paired with its partner as predictor. Outcome
#' specific intercepts and time-invariant covariates are expanded with a
#' Kronecker product into `<name>_Y1` / `<name>_Y2` columns.
#' @param data A data.frame with the longitudinal data in long format.
#' @param pairs A character matrix with the pairs, returned from the make_pairs function
#' @param id A character value with the variable name of the subject identifier.
#' @param covars (optional) a character vector with the names of numeric covariates.
#' @import Matrix
#' @import dplyr
#' @export
#' @return a list of data.frames of length nrow(pairs), each with columns
#'   `<id>`, Y, X, intercept_Y1, intercept_Y2 and, per covariate,
#'   `<covar>_Y1`, `<covar>_Y2`.

stack_data <- function (data, id, pairs, covars = NULL) {
  stopifnot("Argument 'id' is missing. Use it to set subject level id." = is.character(id))
  stopifnot("'id' should refer to a character vector" = is.character(unlist(data[id])))
  missing_columns <- setdiff(as.vector(pairs), names(data))
  if (length(missing_columns) > 0) {
    stop("Outcome column(s) not found in data: ",
         paste(missing_columns, collapse = ", "), call. = FALSE)
  }
  if (!is.null(covars)) {
    stopifnot("'covars' should be a character vector." = is.character(covars))
    # is.numeric() accepts integer as well as double covariates
    is_num <- vapply(covars, function(v) is.numeric(data[[v]]), logical(1))
    if (!all(is_num)) {
      stop("stack_data can only handle numeric data: '",
           paste(covars[!is_num], collapse = "', '"),
           "' is not numeric.")
    }
  }
  # The design parts that do not depend on the pair are built once.
  n <- nrow(data)
  ids <- rep(data[[id]], each = 2)
  design <- kronecker(rep(1, times = n), diag(1, nrow = 2))
  var_names <- c(id, "Y", "X", "intercept_Y1", "intercept_Y2")
  if (!is.null(covars)) {
    covariates <- as.matrix(data[, covars, drop = FALSE])
    design <- cbind(design, kronecker(covariates, diag(1, nrow = 2)))
    # Column order of kronecker(X, I2) is covar1_Y1, covar1_Y2, covar2_Y1, ...
    var_names <- c(var_names,
                   as.vector(rbind(paste0(covars, "_Y1"), paste0(covars, "_Y2"))))
  }
  lapply(seq_len(nrow(pairs)), function(i) {
    y1 <- data[[pairs[i, 1]]]
    y2 <- data[[pairs[i, 2]]]
    # rbind + as.vector interleaves the two outcomes row by row
    Y <- as.vector(rbind(y1, y2))
    # The swapped order is intentional: the predictor is the partner outcome
    X <- as.vector(rbind(y2, y1))
    .stacked_data <- data.frame(ids, Y, X, design)
    names(.stacked_data) <- var_names
    .stacked_data
  })
}

#' Test to compare stacked outcome data to original outcome data
#'
#' For every pair, checks that the rows of the stacked `Y` flagged by
#' `intercept_Y1` reproduce the first outcome, and that the rows flagged by
#' `intercept_Y2` reproduce the second outcome.
#' @param data data.frame with original data in long format
#' @param stacked_data the list of data.frames returned by stack_data()
#' @param pairs A character matrix with the pairs, returned from the make_pairs function
#' @return Prints one assertion line per outcome per pair. All should end in
#'   TRUE, otherwise something went wrong. Returns NULL invisibly.
test_compare_stacked_to_original_data <- function (data, stacked_data, pairs) {
  for (i in seq_len(nrow(pairs))) {
    for (k in 1:2) {
      original <- data[[pairs[i, k]]]
      # Outcome k lives in the rows whose intercept_Y<k> dummy equals 1
      indicator <- stacked_data[[i]][[paste0("intercept_Y", k)]]
      stacked <- stacked_data[[i]]$Y[which(indicator == 1)]
      matches <- isTRUE(all.equal(original, stacked, check.attributes = FALSE))
      print(paste('Assert: outcome vector matches the original data for', pairs[i, k], '=',
                  matches))
    }
  }
  invisible(NULL)
}

#' Custom GLMMadaptive family for a pair of one binary and one continuous outcome
#'
#' The stacked response alternates between the two outcomes of a pair, so
#' `indicator` (length 2) marks which position is binary (1) and which is
#' continuous (0). Binary rows use a logistic log-density and continuous rows
#' a normal log-density with standard deviation exp(phis).
#' @param indicator numeric vector of length 2, c(1, 0) or c(0, 1), as returned
#'   by test_input_datatypes().
#' @return An object of class "family" with a `log_dens` function.
binary.normal <- function (indicator = stop("'indicator must be specified'")) {
  # The closure environment is a small env holding only .indicator, so the
  # family object stays light when it is serialised to parallel workers.
  env <- new.env(parent = .GlobalEnv)
  assign(".indicator", indicator, envir = env)
  stats <- make.link("identity")
  log_dens <- function (y, eta, mu_fun, phis, eta_zi) {
    ind <- rep(.indicator, times = length(y)/2)
    # eta may hold one column per quadrature point
    eta <- as.matrix(eta)
    out <- eta
    # linear mixed model part
    sigma <- exp(phis)
    out[ind == 0, ] <- dnorm(y[ind == 0], eta[ind == 0, , drop = FALSE], sigma, log = TRUE)
    # mixed logistic part: log P(y | eta) = log plogis(+eta) for y = 1 and
    # log plogis(-eta) for y = 0; log.p avoids the Inf/Inf = NaN overflow of
    # exp(eta) / (1 + exp(eta)) for large |eta|.
    sign_y <- 2 * y[ind == 1] - 1
    out[ind == 1, ] <- plogis(sign_y * eta[ind == 1, , drop = FALSE], log.p = TRUE)
    out
  }
  environment(log_dens) <- env
  structure(list(family = "user binary normal", link = stats$name, linkfun = stats$linkfun,
                 linkinv = stats$linkinv, log_dens = log_dens),
            class = "family")
}

#' Custom GLMMadaptive family for a pair of two continuous outcomes
#'
#' Normal log-density with identity link. With n_phis = 2 the two residual
#' standard deviations exp(phis) are recycled along the alternating stacked
#' response (y1, y2, y1, y2, ...), giving each outcome its own residual SD.
#' No analytic score functions are supplied; callers in this file set
#' numeric_deriv = "fd".
#' @return An object of class "family" with a `log_dens` function.
normal <- function () {
  stats <- make.link("identity")
  log_dens <- function (y, eta, mu_fun, phis, eta_zi) {
    sigma <- exp(phis)
    # eta may hold one column per quadrature point; dnorm keeps its dim
    dnorm(y, as.matrix(eta), sigma, log = TRUE)
  }
  structure(list(family = "user defined normal", link = stats$name, linkfun = stats$linkfun,
                 linkinv = stats$linkinv, log_dens = log_dens),
            class = "family")
}

#' Generate fixed formula
#'
#' Build the fixed-effects formula (as a character string) for the pairwise
#' models: outcome-specific intercepts, outcome-specific slopes for the partner
#' outcome `X`, and outcome-specific covariate effects.
#' @param covars (optional) a character vector (or list) with the covariate
#'   names; each element `v` contributes the terms `v_Y1` and `v_Y2`.
#' @param ... Currently unused.
#' @export
#' @return A single character string, e.g.
#'   "Y ~ -1 + intercept_Y1 + intercept_Y2 + intercept_Y1:X + intercept_Y2:X".
make_fixed_formula <- function(covars = NULL, ...) {
  intercepts <- c("intercept_Y1", "intercept_Y2")
  slopes <- paste0(intercepts, ":X")
  covariate_terms <- character(0)
  if (!is.null(covars)) {
    add_suffix <- function(suffix) {
      vapply(covars, function(v) paste(c(v, suffix), collapse = "_"), character(1),
             USE.NAMES = FALSE)
    }
    # All Y1 terms first, then all Y2 terms (same order as before)
    covariate_terms <- c(add_suffix("Y1"), add_suffix("Y2"))
  }
  paste("Y ~ -1 +", paste(c(intercepts, slopes, covariate_terms), collapse = " + "))
}

#' Generate random effects formula
#'
#' Build the random-effects formula (as a character string) for the pairwise
#' models: outcome-specific random intercepts plus optional outcome-specific
#' random slopes, grouped by the subject identifier.
#' @param id A character value with the name of the subject identifier column.
#' @param covars (optional) a character vector (or list) with the names of the
#'   random slopes; each element `v` contributes `v_Y1` and `v_Y2`.
#' @param ... Currently unused.
#' @export
#' @return A single character string, e.g. "~ -1 + intercept_Y1 + intercept_Y2 | id".
make_random_formula <- function(id, covars = NULL, ...) {
  terms <- c("intercept_Y1", "intercept_Y2")
  if (!is.null(covars)) {
    add_suffix <- function(suffix) {
      vapply(covars, function(v) paste(c(v, suffix), collapse = "_"), character(1),
             USE.NAMES = FALSE)
    }
    terms <- c(terms, add_suffix("Y1"), add_suffix("Y2"))
  }
  paste("~ -1 +", paste(terms, collapse = " + "), "|", id)
}

#' Create random start values
#'
#' Take a model info object created by test_input_datatypes and add random
#' start values for each pairwise model. The same draws are shared by all
#' pairs.
#' @param model_families a list with model info returned by test_input_datatypes
#' @param n_fixed the number of fixed parameters
#' @param n_random the number of random parameters
#' @import Matrix
#' @return `model_families` with an added element `start_values`: a list with,
#' per pair, `betas`, `D` (a symmetric positive-definite n_random x n_random
#' matrix) and, for "binary.normal" (one) and "normal" (two) families, `phis`.
#' @export
get_random_start <- function(model_families, n_fixed, n_random) {
  betas <- rnorm(n_fixed, mean = 0, sd = 0.33)
  # A covariance start value must be symmetric positive definite. crossprod(M)
  # is symmetric positive semi-definite; adding the identity makes it definite.
  M <- matrix(runif(n_random^2), ncol = n_random, nrow = n_random)
  D <- crossprod(M) + diag(n_random)
  phis <- rnorm(2, mean = 0, sd = 0.33)
  model_families$start_values <- lapply(model_families$families, function(family) {
    switch(family,
      "binomial" = list("betas" = betas, "D" = D),
      "binary.normal" = list("betas" = betas, "D" = D, "phis" = phis[1]),
      "normal" = list("betas" = betas, "D" = D, "phis" = phis),
      stop("Unknown model family: '", family, "'.", call. = FALSE)
    )
  })
  model_families
}

# Fit a mixed_model for each data.frame in the stacked data and store the starting values.
# optimParallel returns the error:
# Error in FGgenerator(par = par, fn = fn, gr = gr, ..., lower = lower,  : is.vector(par) is not TRUE
# Therefore optim is used instead.
# Model fitting within a pair is sequential, because optimParallel throws the error above.
# This is a limitation when the number of pairs is smaller than the number of available workers.
# FIX: add a check for the formula inputs. If a '+' sign is forgotten, an opaque error message is returned.

#' Fit pairwise mixed models to get start values for the multivariate mixed model
#' Takes the stacked data and the user defined formulas for the fixed and random effects. Leverages the future.apply package to parallelize along pairs.
#' Using a different number of workers than pairs may result in problems during the optimization process. 
#' @param stacked_data a list with the stacked data returned by the stack_data() function.
#' @param user_initial_values (optional) a list, indexed like the pairs, with initial values passed to mixed_model(initial_values = ).
#' @param nAGQ number of adaptive Gauss-Hermite quadrature points passed to mixed_model().
#' @param iter_EM number of EM iterations passed in mixed_model(control = ).
#' @param iter_qN_outer number of outer quasi-Newton iterations passed in mixed_model(control = ).
#' @param fixed character string with the fixed model part. This is passed to as.formula().
#' @param random character string with the random model part. This is passed to as.formula().
#' @param pairs a character matrix with the pairs returned by the make_pairs() function.
#' @param model_families a list with family names and indicators returned by the test_input_datatypes() function.
#' NOTE(review): fixed, random, pairs and model_families are not formal arguments in the
#' signature below; the body resolves them from the enclosing environment (presumably the
#' global environment), and anything passed through `...` is not used to supply them. Confirm intent.
#' @import future.apply
#' @import GLMMadaptive
#' @return a list (one element per pair) with `fit` (the mixed_model fit) and `starting_values`
#' (betas, D and, for the custom families, phis) used to determine the subject level contributions to the derivatives.

get_mmm_start_values <- function (stacked_data, 
                                  user_initial_values = NULL,
                                  nAGQ = 11, 
                                  iter_EM = 30, 
                                  iter_qN_outer = 10, 
                                  ...) {
  # NOTE(review): several closing parentheses/braces appear to be missing from this
  # block (e.g. after each control = list(...)); confirm against version control.
  # NOTE(review): along = pairs gives one step per matrix cell (2 * nrow(pairs)), and the
  # calls below use amount = 0, so the bar only shows messages and never advances.
  prog <- progressr::progressor(along = pairs)
  # One pairwise model per row of `pairs`, fitted in parallel via future.apply
  out <- future_lapply(1:nrow(pairs), function(i, ...) {
    prog(sprintf("Fitting pairwise model number %g of %g.", i, nrow(pairs)), 
         class = "sticky",
         amount = 0)
    # Mixed binary/continuous pair: custom family, one residual SD (n_phis = 1)
    if (model_families$families[i] == "binary.normal") {
      fit <- mixed_model(fixed = as.formula(fixed),
                         random = as.formula(random),
                         data = stacked_data[[i]],
                         family = binary.normal(indicator = model_families$indicators[[i]]),
                         initial_values = user_initial_values[[i]], 
                         nAGQ = nAGQ,
                         n_phis = 1,
                         control = list(optimizer = "optim",
                                        optim_method = "BFGS",
                                        numeric_deriv = "fd",
                                        iter_EM = iter_EM,
                                        iter_qN_outer = iter_qN_outer,
      starting_values <- list(betas = fixef(fit), phis = fit$phis, D = fit$D)
    # Two continuous outcomes: custom normal family, one residual SD per outcome (n_phis = 2)
    } else if (model_families$families[i] == "normal") {
      fit <- mixed_model(fixed = as.formula(fixed),
                         random = as.formula(random),
                         data = stacked_data[[i]],
                         family = normal(),
                         initial_values = user_initial_values[[i]], 
                         nAGQ = nAGQ,
                         n_phis = 2,
                         control = list(optimizer = "optim",
                                        optim_method = "BFGS",
                                        numeric_deriv = "fd",
                                        iter_EM = iter_EM,
                                        iter_qN_outer = iter_qN_outer,
      starting_values <- list(betas = fixef(fit), phis = fit$phis, D = fit$D)
    # Two binary outcomes: built-in binomial family, no dispersion parameters
    } else if (model_families$families[i] == "binomial") {
      fit <- mixed_model(fixed = as.formula(fixed),
                         random = as.formula(random),
                         data = stacked_data[[i]],
                         family = binomial(),
                         initial_values = user_initial_values[[i]], 
                         nAGQ = nAGQ,
                         control = list(optimizer = "optim",
                                        optim_method = "BFGS",
                                        numeric_deriv = "fd",
                                        iter_EM = iter_EM,
                                        iter_qN_outer = iter_qN_outer,
      starting_values <- list(betas = fixef(fit), D = fit$D)
    } else {
      stop("Model family not found. Has a list with family for every model been provided?") # Fix: this should be evaluated before fitting is started.
    list("fit" = fit, "starting_values" = starting_values)  

#' Determine the subject level derivatives.
#' Run mixed models without iterating and return subject level derivatives
#' This function uses the start values determined by the function get_mmm_start_values()  
#' to determine the subject level contributions to the derivatives. In turn these are used 
#' to average the parameters and standard errors for the full multivariate mixed model.
#' @param id The name of the column with subject ids in stacked_data
#' @param stacked_data A list of data.frames returned by the stack_data() function
#' @param fixed A formula (as character string) for the fixed part of the mixed model. Passed to as.formula().
#' @param random A formula (as character string) for the random part of the mixed model. Passed to as.formula().
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @param model_families A list with the model families returned by the test_input_datatypes() function.
#' @param start_values A list (length == length(stacked_data)) with start values returned by the get_mmm_start_values() function.
#' @param nAGQ number of adaptive Gauss-Hermite quadrature points passed to mixed_model().
#' @import future.apply
#' @import GLMMadaptive
#' @import dplyr
#' @return A list of length nrow(pairs) each with two elements:
#' \describe{
#' \item{df_hessians}{a list with data.frames storing the subject level Hessians}
#' \item{df_gradients}{a list with data.frames storing the subject level gradients}
#' }

get_mmm_derivatives <- function (stacked_data, id, fixed, random, pairs, model_families, start_values, nAGQ = 11) {
  # NOTE(review): several lines of this block appear to be missing (closing parentheses of
  # control = list(...)/mixed_model(...), the second argument of each h/g expression, the
  # closing braces and the enclosing list( of the return value); confirm against version control.
  stopifnot("id should be a character" = is.character(id))
  prog <- progressr::progressor(along = pairs)
  # Pairs are processed sequentially; subjects within a pair are processed in parallel
  derivatives <- lapply(1:nrow(pairs), function(i, ...) {
    # message("Getting derivatives for pairwise model: ",i, "out of ", nrow(pairs))
    prog(sprintf("Getting derivatives for pairwise model %g out of %g", i, nrow(pairs)),
         class = "sticky",
         amount = 0)
    subject_out <- future_lapply(1:nrow(unique(stacked_data[[i]][id])), function(j, ...) {
      # Rows of pair i that belong to the j-th unique subject
      subject_data <- stacked_data[[i]][stacked_data[[i]][id] == unique(stacked_data[[i]][id])[j,1],]
      # Data copy to work around optim requiring n > 1 subjects
      subject_data_copy <- subject_data
      # The copy gets a distinct id so it is treated as a second subject
      subject_data_copy[,id] <- paste0(j,"a")
      subject_data <- rbind(subject_data, subject_data_copy)
      # iter_EM = 0 and iter_qN_outer = 0: no optimisation iterations, so the derivatives
      # are evaluated at the pairwise start values rather than at a re-fitted optimum
      if (model_families$families[i] == "binary.normal") {
        fit <- mixed_model(fixed = as.formula(fixed),
                           random = as.formula(random),
                           data = subject_data,
                           family = binary.normal(indicator = model_families$indicators[[i]]),
                           n_phis = 1,
                           nAGQ = nAGQ,
                           initial_values = start_values[[i]]$starting_values,
                           control = list(optimizer = "optim",
                                          optim_method = "BFGS",
                                          numeric_deriv = "fd",
                                          iter_EM = 0,
                                          iter_qN_outer = 0)
        # Store Hessian (undoing the work around)
        # NOTE(review): stack_data() requires a character id; as.numeric() gives NA for
        # non-numeric id strings (e.g. "P01"), so such ids are lost here. Confirm ids are numeric strings.
        h <- cbind("id" = as.numeric(unique(stacked_data[[i]][id])[j, 1]),
        g <- c("id" = as.numeric(unique(stacked_data[[i]][id])[j, 1]),
      } else if (model_families$families[i] == "normal") {
        fit <- mixed_model(fixed = as.formula(fixed),
                           random = as.formula(random),
                           data = subject_data,
                           family = normal(),
                           n_phis = 2,
                           nAGQ = nAGQ,
                           initial_values = start_values[[i]]$starting_values,
                           control = list(optimizer = "optim",
                                          optim_method = "BFGS",
                                          numeric_deriv = "fd",
                                          iter_EM = 0,
                                          iter_qN_outer = 0)
        h <- cbind("id" = as.numeric(unique(stacked_data[[i]][id])[j, 1]),
        g <- c("id" = as.numeric(unique(stacked_data[[i]][id])[j, 1]),
      } else if (model_families$families[i] == "binomial") {
        fit <- mixed_model(fixed = as.formula(fixed),
                           random = as.formula(random),
                           data = subject_data,
                           family = binomial(),
                           nAGQ = nAGQ,
                           initial_values = start_values[[i]]$starting_values,
                           control = list(optimizer = "optim",
                                          optim_method = "BFGS",
                                          numeric_deriv = "fd",
                                          iter_EM = 0,
                                          iter_qN_outer = 0)
        h <- cbind("id" = as.numeric(unique(stacked_data[[i]][id])[j, 1]),
        g <- c("id" = as.numeric(unique(stacked_data[[i]][id])[j, 1]),
      # Gradient entries take the Hessian's column names so both share parameter labels
      names(g) <- colnames(h)
      list("hessians" = h,
           "gradients" = g)
      # Row-bind the per-subject results of pair i into one table each
      "df_hessians" = do.call(rbind, lapply(subject_out, function(l) l$hessians)),
      "df_gradients" = do.call(rbind, lapply(subject_out, function(l) l$gradients))

#' Take the parameter estimates, add labels and stack them into a data.frame
#' Extract the parameter estimates from the start_values and label them back to the names of the original data.
#' @param start_values The list of start_values (i.e. parameter estimates) returned by the get_mmm_start_values() function.
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @param model_families A list with the model families returned by the test_input_datatypes() function.
#' @return A data.frame with the labelled parameter estimates (columns parameter_label and
#' parameter_estimate): fixed effects, then random-effect (co)variances, then residual errors, per pair.

stack_parameters <- function (start_values, pairs, model_families) {
  # NOTE(review): closing braces/parentheses and the interaction() arguments (after
  # "labels_random <- levels(interaction(") appear to be missing from this block; confirm
  # against version control.
  parameters <- lapply(1:nrow(pairs), function(i) {
    names_fixed <- names(start_values[[i]]$starting_values$betas)
    # Fixed effects of the first outcome: names containing "Y1", relabelled with the
    # original outcome names (Y1 -> pairs[i, 1], X -> the partner outcome pairs[i, 2])
    .parameter_label_y1 <- names_fixed[str_detect(names_fixed, "Y1")]
    .parameter_label_y1 <- str_replace(.parameter_label_y1, "intercept", "")
    # NOTE(review): paste(..., collapse = "") merges ALL labels matching "X" into one string;
    # this assumes exactly one such label (e.g. a covariate name containing "X" would break it).
    if (any(str_detect(.parameter_label_y1, ":X")) == TRUE) {
      .parameter_label_y1[str_detect(.parameter_label_y1, "X")] <- 
        paste(rev(unlist(str_split(.parameter_label_y1[str_detect(.parameter_label_y1, "X")], ":"))), collapse = "") 
    } else {
    .parameter_label_y1[str_detect(.parameter_label_y1, "X")] <- 
      paste(unlist(str_split(.parameter_label_y1[str_detect(.parameter_label_y1, "X")], ":")), collapse = "") 
    .parameter_label_y1 <- str_replace(.parameter_label_y1, "X", pairs[i, 2])
    .parameter_label_y1 <- str_replace(.parameter_label_y1, "Y1", pairs[i, 1])
    # Prefix with b0_, b1_, ... in order of appearance
    .parameter_label_y1 <- sapply(1:length(.parameter_label_y1), function(x) paste0("b", x - 1, "_", .parameter_label_y1[x]))
    .parameter_value_y1 <- start_values[[i]]$starting_values$betas[str_detect(names_fixed, "Y1")]
    # Same relabelling for the second outcome (Y2 -> pairs[i, 2], X -> pairs[i, 1])
    .parameter_label_y2 <- names_fixed[str_detect(names_fixed, "Y2")]
    .parameter_label_y2 <- str_replace(.parameter_label_y2, "intercept", "")
    if (any(str_detect(.parameter_label_y2, pattern = ":X")) == TRUE) {
    .parameter_label_y2[str_detect(.parameter_label_y2, "X")] <-
      paste(rev(unlist(str_split(.parameter_label_y2[str_detect(.parameter_label_y2, "X")], ":"))), collapse = "")
    } else {
    .parameter_label_y2[str_detect(.parameter_label_y2, "X")] <-
      paste(unlist(str_split(.parameter_label_y2[str_detect(.parameter_label_y2, "X")], ":")), collapse = "")
    .parameter_label_y2 <- str_replace(.parameter_label_y2, "X", pairs[i, 1])
    .parameter_label_y2 <- str_replace(.parameter_label_y2, "Y2", pairs[i, 2])
    .parameter_label_y2 <- sapply(1:length(.parameter_label_y2), function(x) paste0("b", x - 1, "_", .parameter_label_y2[x]))
    .parameter_value_y2 <-  start_values[[i]]$starting_values$betas[str_detect(names_fixed, "Y2")]
    fixed <- data.frame(
      parameter_label = c(.parameter_label_y1, .parameter_label_y2),
      parameter_estimate = c(.parameter_value_y1, .parameter_value_y2), 
      row.names = NULL
    # Set names and labels for and get values for random effects
    labels_random <- levels(interaction(
      sep = "_x_"))
    labels_random <- t(matrix(labels_random, 
                              ncol = ncol(start_values[[i]]$starting_values$D), 
                              byrow = FALSE))
    # Diagonal entries are variances, lower-triangle entries covariances
    diag(labels_random) <- paste0("var_",diag(labels_random))
    labels_random[lower.tri(labels_random)] <- paste0("cov_", labels_random[lower.tri(labels_random)])
    # Keep the lower triangle (incl. diagonal), in the same order as values_random below
    labels_random <- labels_random[lower.tri(labels_random, diag = TRUE)]
    labels_random <- str_replace_all(labels_random, "Y1", pairs[i,1])
    labels_random <- str_replace_all(labels_random, "Y2", pairs[i,2])
    values_random <- start_values[[i]]$starting_values$D[lower.tri(start_values[[i]]$starting_values$D, diag = TRUE)]
    random <- data.frame(
      parameter_label = labels_random,
      parameter_estimate = values_random,
      row.names = NULL
    # Get residual errors
    resid_labels <- c()
    # NOTE(review): test_input_datatypes() produces "binomial", never "binary"; this branch
    # never matches, but resid_labels is already NULL so binomial pairs get no residual rows anyway.
    if (model_families$families[i] == "binary") {
      resid_labels <- NULL
    } else if (model_families$families[i] == "binary.normal") {
      # indicator c(0, 1): first outcome is continuous, so it owns the residual SD
      if (all(model_families$indicators[[i]] == c(0, 1 ))) {
        resid_labels <- paste0("resid_", pairs[i, 1])
      } else {
        resid_labels <- paste0("resid_", pairs[i, 2])
    } else if (model_families$families[i] == "normal") {
      resid_labels <- c(paste0("resid_", pairs[i, 1]),
                        paste0("resid_", pairs[i, 2]))
    # NOTE(review): the residual estimates are the fitted phis (log-SD scale in the custom families)
    resid <- data.frame(
      parameter_label = resid_labels,
      parameter_estimate = start_values[[i]]$fit$phis,
      row.names = NULL
    out <- bind_rows(fixed, random, resid)
    rownames(out) <- NULL
  do.call(rbind, parameters)

#' Stack gradients into a single data.frame
#'
#' Take the gradients returned by the get_mmm_derivatives function and stack
#' them into a single long data.frame for further processing. Rows are ordered
#' by pair, then subject, then parameter.
#' @param derivatives list with derivatives returned by the get_mmm_derivatives()
#'   function; `df_gradients` holds the subject id in column 1 and one gradient
#'   column per parameter.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @return A data.frame with columns `id` and `gradient` holding the subject
#'   level gradients for every pairwise model.

stack_gradients <- function (derivatives, data, id, pairs) {
  n_subjects <- nrow(unique(data[id]))
  per_pair <- lapply(seq_len(nrow(pairs)), function(i) {
    pair_gradients <- derivatives[[i]]$df_gradients
    per_subject <- lapply(seq_len(n_subjects), function(j) {
      # One row per parameter; the subject id is recycled along them
      data.frame(
        "id" = pair_gradients[j, 1],
        "gradient" = pair_gradients[j, -1],
        row.names = NULL
      )
    })
    do.call(rbind, per_subject)
  })
  do.call(rbind, per_pair)
}

#' Create the helper 'A' matrix used to average the parameter estimates.
#'
#' Take parameter labels from the MMM object and return a p*q indicator matrix
#' where p is the number of unique parameters and q is the total number of
#' estimated parameters (including duplicates across pairs).
#' @param parameter_labels A character vector with parameter labels (the
#'   parameter_label column returned by the stack_parameters() function).
#' @return A p*q numeric matrix with rownames = unique labels and colnames = all
#'   labels; entry [i, j] is 1 when pairwise estimate j estimates unique
#'   parameter i and 0 otherwise.

create_A_matrix <- function (parameter_labels) {
  unique_labels <- unique(parameter_labels)
  # Vectorised equivalent of comparing every row label with every column label
  A <- outer(unique_labels, parameter_labels, FUN = "==") * 1
  dimnames(A) <- list(unique_labels, parameter_labels)
  A
}

#' Get cross products of the gradients
#' Get the cross product of the subject-specific gradient vectors, sum them over subjects and take the average.
#' @param pairs A character matrix with pairs returned by the make_pairs function (not used in the computation).
#' @param parameter_labels The parameter labels returned by the stack_parameters() function.
#' @param gradients A data frame with columns `id` and `gradient` returned by the stack_gradients() function.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#' @return A length(parameter_labels) x length(parameter_labels) matrix: the average over subjects of g_i %*% t(g_i),
#'   where g_i is the stacked gradient vector of subject i.
get_crossproducts <- function (pairs, parameter_labels, gradients, data, id) {
  n_par <- length(parameter_labels)
  subject_ids <- unique(data[[id]])
  cpgrad <- matrix(0, nrow = n_par, ncol = n_par,
                   dimnames = list(parameter_labels, parameter_labels))
  # Sum all subject-level contributions; each subject's gradient vector must have length n_par.
  for (i in seq_along(subject_ids)) {
    g <- gradients$gradient[which(gradients$id == subject_ids[i])]
    cpgrad <- cpgrad + tcrossprod(g)
  }
  # Average over subjects.
  cpgrad / length(subject_ids)
}

#' Average the Hessians
#' Build the block-diagonal multivariate mixed model Hessian from the pairwise Hessians and average it over subjects.
#' @param derivatives list with derivatives returned by get_mmm_derivatives(); element j holds `df_hessians`, whose
#'   first column (named `id`) is the subject id and whose remaining columns are that subject's Hessian contributions for pair j.
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#' @param A the A helper matrix returned by the create_A_matrix function; its colnames label the Hessian.
#' @import Matrix
#' @return A dense matrix: the block-diagonal (one block per pair) Hessian averaged over subjects.
average_hessians <- function(derivatives, pairs, data, id, A) {
  subject_ids <- unique(data[[id]])
  # One block per pair: the sum of the subject-level contributions to that pair's Hessian.
  blocks <- lapply(seq_len(nrow(pairs)), function(j) {
    hess <- as.data.frame(derivatives[[j]]$df_hessians)
    n_par <- ncol(hess) - 1
    hess_pair <- matrix(0, nrow = n_par, ncol = n_par)
    for (i in seq_along(subject_ids)) {
      rows <- which(hess[["id"]] == subject_ids[i])
      hess_pair <- hess_pair + as.matrix(hess[rows, -1, drop = FALSE])
    }
    hess_pair
  })
  # Assemble all blocks at once instead of re-growing the matrix for every pair.
  Hessian <- as.matrix(Matrix::bdiag(blocks)) / length(subject_ids)
  dimnames(Hessian) <- list(colnames(A), colnames(A))
  Hessian
}

#' Calculate the covariance matrix for the MMM and return it and the robust standard error
#' Use a sandwich estimator to return the standard errors for the full multivariate mixed model.
#' @param hessian The averaged (dense) Hessian matrix for the MMM returned by the average_hessians() function.
#' @param cpgradient The cross-product and average of the subject level gradients returned by the get_crossproducts() function.
#' @param parameters The parameter estimates returned by the stack_parameters() function.
#' @param A the A helper matrix returned by the create_A_matrix() function.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#' @import MASS
#' @return A data.frame with parameter_name, parameter_estimate (averaged over pairs) and parameter_std_err (robust).
get_covariance <- function(hessian, cpgradient, parameters, A, data, id) {
  n_subjects <- nrow(unique(data[paste(id)]))
  # Covariance of sqrt(n) * (theta - theta_0): J^-1 K J^-1 (Fieuws 2006, eq. 4, p. 453).
  # The generalized inverse is computed once and reused on both sides.
  hessian_inv <- ginv(hessian)
  JKJ <- hessian_inv %*% cpgradient %*% hessian_inv
  dimnames(JKJ) <- list(colnames(A), colnames(A))
  # Covariance of (theta - theta_0).
  cov <- JKJ / n_subjects
  # Row-normalize A so that A %*% theta averages duplicated estimates.
  A <- A / rowSums(A)
  avg_par <- A %*% parameters$parameter_estimate
  covrobust <- A %*% cov %*% t(A)
  serobust <- sqrt(diag(covrobust))
  data.frame(
    "parameter_name" = rownames(A),
    "parameter_estimate" = as.vector(avg_par),
    "parameter_std_err" = unname(serobust),
    row.names = NULL
  )
}

#' Create the full mixed model output
#' Take the random effect estimates and return vcov and correlation matrices
#' @param estimates A data.frame with estimates returned by the get_covariance() function.
#' @import Matrix
#' @return A list of 3 elements including:
#' \describe{
#'   \item{estimates}{A data.frame with the parameter estimates and standard errors}
#'   \item{vcov}{The Variance-Covariance matrix for the random effects}
#'   \item{corr}{A correlation matrix for the random effects}
#' }
get_model_output <- function(estimates) {
  # Rows of `estimates` whose parameter_name matches `pattern`, with row names reset.
  select_rows <- function(pattern) {
    rows <- estimates[grepl(pattern, estimates$parameter_name), , drop = FALSE]
    rownames(rows) <- NULL
    rows
  }
  fixef <- select_rows("^b")
  varest <- select_rows("^var")
  covest <- select_rows("^cov")
  resid <- select_rows("resid")
  # Residual parameters are on the log scale: back-transform the estimate and use the delta method
  # for its standard error, SE[exp(x)] ~ exp(x) * SE[x] (exp(SE) is not a standard error).
  resid$parameter_estimate <- exp(resid$parameter_estimate)
  resid$parameter_std_err <- resid$parameter_estimate * resid$parameter_std_err
  # Random-effects covariance matrix: variances on the diagonal, covariances filled column-wise into the lower triangle.
  re_names <- gsub("^.*_x_", "", unique(varest$parameter_name))
  est_D <- matrix(0, nrow = nrow(varest), ncol = nrow(varest),
                  dimnames = list(re_names, re_names))
  diag(est_D) <- varest$parameter_estimate
  est_D[lower.tri(est_D)] <- covest$parameter_estimate
  est_D[upper.tri(est_D)] <- t(est_D)[upper.tri(est_D)]
  eigenvalues <- eigen(est_D, symmetric = TRUE, only.values = TRUE)$values
  if (!all(eigenvalues > 0)) {
    # Project onto the nearest positive definite matrix; nearPD() returns a Matrix object, so convert back.
    est_D <- as.matrix(nearPD(est_D)$mat)
    # Replace the variance and covariance estimates with the nearest PD values.
    varest$parameter_estimate <- diag(est_D)
    covest$parameter_estimate <- est_D[lower.tri(est_D)]
    message("The Variance-Covariance matrix was not positive definite. Projecting to the nearest positive definite matrix. Note that your results may be unexpected.")
  }
  est_Dcorr <- cov2cor(est_D)
  list("estimates" = rbind(fixef, varest, covest, resid),
       "vcov" = est_D,
       "corr" = est_Dcorr)
}

#' Fit a multivariate mixed model using the pairwise fitting approach
#' Takes the stacked data and the user defined formulas for the fixed and random effects. Leverages the future.apply package to parallelize along pairs.
#' Using a different number of workers than pairs may result in problems during the optimization process. The output of the pairwise model fitting is combined
#' and any parameters that have been estimated multiple times are averaged. The procedure takes the individual level contributions to the Maximum Likelihood in
#' order to estimate the standard errors for the parameter estimates with a robust sandwich estimator.
#' @param fixed character string with the fixed model part. This is passed to as.formula().
#' @param random character string with the random model part. This is passed to as.formula().
#' @param id The name of the column with subject ids in `data`.
#' @param data The data.frame with the original data in long format; must contain the `id` column.
#' @param stacked_data a list with the stacked data returned by the stack_data() function.
#' @param pairs a character matrix with the pairs returned by the make_pairs() function.
#' @param model_families a list with family names and indicators returned by the test_input_datatypes() function.
#' @param iter_EM (default = 100) EM iterations passed to GLMMadaptive::mixed_model().
#' @param iter_qN_outer (default = 10) outer Quasi Newton iterations passed to GLMMadaptive::mixed_model().
#' @param parallel_plan (default = "sequential") The future plan used while fitting. Either "sequential" or "multisession".
#'   For "multisession" the previously active plan is restored when the function exits.
#' @param ncores number of workers for the multisession plan. When NULL, model fitting uses
#'   min(available cores - 1, number of pairs) workers and derivative retrieval uses available cores - 1.
#' @param ... additional arguments forwarded to get_mmm_start_values() (intended for GLMMadaptive::mixed_model()).
#' @import tidyverse
#' @import GLMMadaptive
#' @import Matrix
#' @import future.apply
#' @return A list of 3 elements including:
#' \describe{
#'   \item{estimates}{A data.frame with the parameter estimates and standard errors}
#'   \item{vcov}{The Variance-Covariance matrix for the random effects}
#'   \item{corr}{A correlation matrix for the random effects}
#' }
mmm_model <- function (fixed, random, id, data, stacked_data, pairs, model_families, iter_EM = 100,
                       iter_qN_outer = 10, parallel_plan = "sequential", ncores = NULL,
                       ...) {
  if (!id %in% colnames(data)) {
    stop("Could not find `id` in colnames(`data`). Have you specified the correct subject id?")
  }
  if (!(length(parallel_plan) == 1 && parallel_plan %in% c("sequential", "multisession"))) {
    stop("`parallel_plan` should be either \"sequential\" or \"multisession\".")
  }
  # detectCores() may return NA; always keep at least one worker.
  available_cores <- max(1L, parallel::detectCores() - 1L, na.rm = TRUE)
  # Pairwise model fitting.
  if (parallel_plan == "sequential") {
    message("Fitting pairwise model using ", parallel_plan, ".")
  } else {
    fit_cores <- if (is.null(ncores)) min(available_cores, nrow(pairs)) else ncores
    # plan() returns the previous plan; restore it on exit (also on error).
    old_plan <- plan(multisession, workers = fit_cores)
    on.exit(plan(old_plan), add = TRUE)
    message("Fitting pairwise model using ", parallel_plan, " on ", fit_cores, " cores.")
  }
  start_values <- get_mmm_start_values(stacked_data = stacked_data,
                                       fixed = as.formula(fixed),
                                       random = as.formula(random),
                                       pairs = pairs,
                                       model_families = model_families,
                                       iter_EM = iter_EM,
                                       iter_qN_outer = iter_qN_outer,
                                       ...)
  # Subject-level derivatives; a user-supplied ncores is honoured here as well.
  if (parallel_plan == "sequential") {
    message("Retrieving derivatives using ", parallel_plan, ".")
  } else {
    deriv_cores <- if (is.null(ncores)) available_cores else ncores
    plan(multisession, workers = deriv_cores)
    message("Retrieving derivatives using ", parallel_plan, " on ", deriv_cores, " cores.")
  }
  derivatives <- get_mmm_derivatives(stacked_data = stacked_data,
                                     id = id,
                                     fixed = as.formula(fixed),
                                     random = as.formula(random),
                                     pairs = pairs,
                                     model_families = model_families,
                                     start_values = start_values)
  message("Compiling model output.")
  parameters <- stack_parameters(start_values = start_values,
                                 pairs = pairs,
                                 model_families = model_families)
  gradients <- stack_gradients(derivatives = derivatives,
                               data = data,
                               id = id,
                               pairs = pairs)
  A <- create_A_matrix(parameter_labels = parameters$parameter_label)
  cpgradient <- get_crossproducts(pairs = pairs,
                                  parameter_labels = parameters$parameter_label,
                                  gradients = gradients,
                                  data = data,
                                  id = id)
  hessian <- average_hessians(derivatives = derivatives,
                              pairs = pairs,
                              data = data,
                              id = id,
                              A = A)
  mmm_estimates <- get_covariance(hessian = hessian,
                                  cpgradient = cpgradient,
                                  parameters = parameters,
                                  A = A,
                                  data = data,
                                  id = id)
  get_model_output(estimates = mmm_estimates)
}

#' Return a data.frame with the outcome label and type
#' An outcome whose observed (non-missing) values are exactly 0 and 1 is "binary"; any other numeric
#' outcome (double or integer) is "continuous". Anything else is an error.
#' @param data original data
#' @param outcomes a character vector with the names of the longitudinal outcomes
#' @return A data.frame with columns `outcome` and `outcome_type`, one row per outcome.
get_outcome_type <- function(data, outcomes) {
  out <- lapply(outcomes, function(i) {
    values <- unlist(data[paste(i)])
    # sort() drops NAs; require exactly the two levels 0 and 1 (no recycling of unequal lengths).
    observed <- sort(unique(values))
    if (length(observed) == 2 && all(observed == c(0, 1))) {
      outcome_type <- "binary"
    } else if (is.numeric(values)) {
      outcome_type <- "continuous"
    } else {
      stop(paste("The outcome", i, "does not appear to be binary or continuous"))
    }
    data.frame(outcome = i,
               outcome_type = outcome_type)
  })
  do.call(rbind, out)
}

# FIX: matching variable names with grepl() generalizes poorly: parts of variable names may overlap, or contain characters that are special in regular expressions.
# FIX: models should be allowed to differ between longitudinal outcomes.

#' Get Likelihood No Fail
#' Get the parameter estimates from a MMM and combine with data to return contributions to the likelihood at observed time points for non-failures.
#' For every subject and every outcome, the log-likelihood is evaluated for each random-effect draw, summed over outcomes,
#' exponentiated and averaged over the draws (Monte Carlo integration over the random effects).
#' @param outcomes a data.frame with columns outcome and outcome_type (as returned by get_outcome_type()); one row per longitudinal outcome.
#' @param data data.frame with original data in long format
#' @param fixed_formula_nofail a character string with lme4 style formula for the fixed effects
#' @param random_formula_nofail a character string with lme4 style formula for the random effects
#' @param time a character string with the label for the follow-up time variable (used as a column name of the fixed-effects model matrix)
#' @param parameters_nofail a data.frame with the parameter labels and estimates returned by mmm_model()
#' @param random_effects_nofail The random effects sampled from a multivariate normal distribution (see MASS::mvrnorm); one row per draw, column j used for outcome j
#' @param id a character string with subject identifier
#' @param horizon the prediction horizon; only observations with time <= horizon are used
#' @import future.apply
#' @import dplyr 
#' @import Matrix
#' @return A list with one element per subject holding the likelihood contributions at every observed time point.

get_likelihood_nofail <- function (data,
                                   fixed_formula_nofail = NULL,
                                   random_formula_nofail = NULL,
                                   ...) {
  # NOTE(review): outcomes, id, time, horizon, parameters_nofail and random_effects_nofail are used
  # below but are not formal arguments in the visible signature; values passed through `...` are
  # not bound to these names, so they are looked up in the enclosing environment -- confirm.
  # cut the fixed and random formulas into parts
  ff <- str_remove(fixed_formula_nofail, "~ ")
  ff <- str_split(ff, " \\+ ") %>% unlist()
  rf <- str_remove(random_formula_nofail, "~ ")
  rf <- str_remove(rf, paste(" \\|", id))
  rf <- str_split(rf, " \\+ ") %>% unlist()
  # Select the parameter estimates for fixed effects
  params <- parameters_nofail %>%
    dplyr::filter(grepl("^b", parameter_name)) %>%
    dplyr::select(parameter_name, parameter_estimate)
  # Initialize lists to store the log-likelihoods by outcome for each individual i, at every time point
  # Loop over the patterns. Note that non-failures only have a single pattern as they all survive past the max horizon.
  .s <- horizon
    # Calculate the log-likelihoods for every outcome j
  likelihood <- future_lapply( 1:nrow(unique(data[id])), function(i, ...) {
  # i <- 1
  # Calculate the log-likelihoods for every outcome j
    ll <- lapply(1:nrow(outcomes), function (j) {
    # j <- 1
      # Get the fixed effect parameters
      betas <- params %>%
        dplyr::filter(grepl(paste0("_", outcomes[j, 1], "$"), parameter_name)) %>% 
        dplyr::select(parameter_estimate, parameter_name)
      # adjust the order of the parameter estimates to match the fixed_formula argument
      betas <- betas %>% 
        mutate(parameter_name = str_replace(parameter_name, "b0_", "intercept"))
      betas <- betas %>% 
        mutate(parameter_name = str_remove(parameter_name, paste0("_", outcomes[j, 1], "$")))
      betas <- betas %>% 
        mutate(parameter_name = str_remove(parameter_name, paste0("^(..|...)_")))
      # NOTE(review): ff[-j] drops a formula term by position, whereas .ff below drops the outcome's own
      # name by value; the two agree only if the j-th term of ff is outcomes[j, 1] -- confirm.
      betas <- betas[match(c("intercept", ff[-j]), betas$parameter_name), ] %>%
      betas <- as.matrix(betas)
      if(!assertthat::noNA(betas)) {
        stop("Not all parameters have matching estimates. Have you correctly specified the the fixed_formula_nofail argument?")
      # Get the residual
      resid <- parameters_nofail %>%
        dplyr::filter(grepl(paste0("resid", paste0("_", outcomes[j, 1],"$")), parameter_name)) %>%
      resid <- as.numeric(resid)
      # get the log likelihood for the i-th subject
      # log_likelihood <- vector("list" , length = nrow(unique(data[id])))
      .id <- unlist(unique(data[id]))[i]
      # Get the outcome vector for subject i
      # NOTE(review): inside dplyr::filter, `time` resolves to a data column literally named "time"
      # (data masking), not to the `time` argument used further below.
      y <- data %>%
        dplyr::filter(get(id) == .id & time <= .s) %>%
        dplyr::select(all_of(outcomes[j, 1]))
      y <- as.matrix(y)
      # Using bX assumes similar models for all longitudinal outcomes. 
      .ff <- str_remove(ff, paste0("^", outcomes[j, 1], "$"))
      .ff <- .ff[.ff != ""]
      # Get the X matrix for subject i including all data prior to the horizon.
      X <- data %>%
        dplyr::filter(get(id) == .id & time <= .s)
      X <- model.matrix(as.formula(paste("~ ", paste(.ff, collapse = " + "))), X)
      # Get the linear predictor for the fixed effect
      xb <- X %*% betas
      # Get the Z matrix for subject i
      .rf <- str_remove(rf, paste0("^",outcomes[j, 1], "$"))
      .rf <- .rf[.rf != ""]
      # Select all data prior to the prediction horizon.
      Z <- data %>%
        dplyr::filter(get(id) == .id & time <= .s)
      Z <- model.matrix(as.formula(paste("~", paste(.rf, collapse = " + "))), Z)
      # log_lik is the log likelihood of the j-th outcome for the i-th subject at observed time points
      # for the k-th draw from the random effects.
      # NOTE(review): rowSums(Z * random_effects_*[k, j]) multiplies every column of Z by the same scalar
      # draw, so all random-effect terms of outcome j share one value -- confirm this is intended when Z
      # has more than an intercept column.
      if (outcomes$outcome_type[j] == "continuous") {
        # for (k in 1:nrow(random_effect)) {
        log_lik <- lapply(1:nrow(random_effects_nofail), function(k) {  
          # get the linear predictor = X*beta + Z*bu
          zb <- rowSums(Z * random_effects_nofail[k, j])
          lp <- xb + zb
          # Normal log-density with standard deviation `resid` (resid is taken as-is from parameters_nofail).
          # calculate the log-likelihood
          -log(resid) -0.5 * log(2 * pi) - ( (y - lp)^2 / (2 * resid^2) )
      } else if (outcomes$outcome_type[j] == "binary") {
        log_lik <- lapply(1:nrow(random_effects_nofail), function(k) {
          # lp = X*beta + Z*bu
          zb <- rowSums(Z * random_effects_nofail[k, j])
          lp <- xb + zb
          # Inverse logit. NOTE(review): exp(lp) overflows to Inf/Inf = NaN for large lp; plogis(lp) is stable.
          p <- exp(lp) / (1 + exp(lp))
          lik <- ifelse(y == 0, 1 - p, p)
          # Floor tiny likelihoods so a later log() stays finite.
          lik <- ifelse(lik > 1e-8, lik, 1e-10)
      # log likelihood for j-th outcome at each observed time point (columns) and all RE draws (rows) for i-th subject  
      list("ll" = t(do.call(cbind, log_lik)),
           ".id" = .id,
           "X" = X)
    # return the sum over the outcomes at each observed time point (columns) and all RE  draws (rows) for i-the subject
    # NOTE(review): with a single outcome 2:length(ll) is c(2, 1) and ll[[2]] fails; seq_along(ll)[-1] would be safe.
    sum_ll <- ll[[1]]$ll
    for (j in 2:length(ll)) {
      sum_ll <- sum_ll + ll[[j]]$ll
    # take the exp to obtain likelihood per draw (rows) and per observed time t (columns) for the i-th subject
    l <- exp(sum_ll)
    # take the mean of the likelihoods of each subject over the draws per observed time t
    # The time column below is read from the fixed-effects model matrix X, so `time` must be a term of the fixed formula.
    l <- matrix(colMeans(l)) # this transposes the times to the rows
      id = ll[[1]]$.id,
      time = ll[[1]]$X[, time],
      likelihoods = l,
      row.names = NULL

# FIX: the likelihoods do not change depending on the failure time...

#' Get Likelihood Fail
#' Get the parameter estimates from a MMM and combine with data to return contributions to the likelihood at observed time points for failures.
#' For every failure-time pattern s on the grid seq(landmark, horizon, by = interval), data up to s + 0.5 are used and any
#' failure-time column in the design matrices is replaced by s + 0.5. Per subject, the log-likelihoods are summed over outcomes,
#' exponentiated and averaged over the random-effect draws.
#' @param outcomes a data.frame with columns outcome and outcome_type (as returned by get_outcome_type()); one row per longitudinal outcome.
#' @param data data.frame with original data in long format
#' @param fixed_formula_fail a character string with lme4 style formula for the fixed effects
#' @param random_formula_fail a character string with lme4 style formula for the random effects
#' @param time a character string with the label for the follow-up time variable (used as a column name of the fixed-effects model matrix)
#' @param failure_time a character string with the label for the failure time variable
#' @param parameters_fail a data.frame with the parameter labels and estimates returned by mmm_model()
#' @param random_effects_fail The random effects sampled from a multivariate normal distribution (see MASS::mvrnorm); one row per draw, column j used for outcome j
#' @param id a character string with subject identifier
#' @param horizon the prediction horizon 
#' @param interval the intervals at which predictions are to be made
#' @param landmark the landmark time until which data is available.
#' @import future.apply
#' @import dplyr 
#' @import Matrix
#' @return A list with one element per failure-time pattern, each holding the per-subject likelihood contributions at every observed time point.
get_likelihood_fail <- function (data,
                                 fixed_formula_fail = NULL,
                                 random_formula_fail = NULL,
                                 ...) {
  # NOTE(review): outcomes, id, time, failure_time, landmark, horizon, interval, parameters_fail and
  # random_effects_fail are used below but are not formal arguments in the visible signature; values
  # passed through `...` are not bound to these names -- confirm how they are supplied.
  # Grid of failure-time patterns from the landmark to the horizon in steps of `interval`.
  lm <- seq(from = landmark, to = horizon, by = interval)
  # cut the fixed and random formulas into parts
  ff <- str_remove(fixed_formula_fail, "~ ")
  ff <- str_split(ff, " \\+ ") %>% unlist()
  rf <- str_remove(random_formula_fail, "~ ")
  rf <- str_remove(rf, paste(" \\|", id))
  rf <- str_split(rf, " \\+ ") %>% unlist()
  # Select the parameter estimates for fixed effects
  params <- parameters_fail %>%
    dplyr::filter(grepl("^b", parameter_name)) %>%
    dplyr::select(parameter_name, parameter_estimate)
  likelihood_pattern <- lapply(lm, function(s, ...) {
  # s <- 1
  # use midpoint rule.
  # NOTE(review): s + 0.5 is the interval midpoint only when interval == 1; otherwise it would be s + interval / 2 -- confirm.
    .s <- (s + 0.5) 
    likelihood <- future_lapply(1:nrow(unique(data[id])), function(i, ...) {
    # i <- 1
    # Calculate the log-likelihoods for every outcome j
      ll <- lapply(1:nrow(outcomes), function (j) {
      # j <- 5
        # Get the fixed effect parameters
        betas <- params %>%
          dplyr::filter(grepl(paste0("_", outcomes[j, 1], "$"), parameter_name)) %>% 
          dplyr::select(parameter_estimate, parameter_name)
        # adjust the order of the parameter estimates to match the fixed_formula argument
        betas <- betas %>% 
          mutate(parameter_name = str_replace(parameter_name, "b0_", "intercept"))
        betas <- betas %>% 
          mutate(parameter_name = str_remove(parameter_name, paste0("_", outcomes[j, 1], "$")))
        betas <- betas %>% 
          mutate(parameter_name = str_remove(parameter_name, paste0("^(..|...)_")))
        # NOTE(review): ff[-j] drops a formula term by position, whereas .ff below drops the outcome's own
        # name by value; the two agree only if the j-th term of ff is outcomes[j, 1] -- confirm.
        betas <- betas[match(c("intercept", ff[-j]), betas$parameter_name), ] %>%
        betas <- as.matrix(betas)
        # NOTE(review): the error message below mentions fixed_formula_nofail although this function takes fixed_formula_fail.
        if(!assertthat::noNA(betas)) {
          stop("Not all parameters have matching estimates. Have you correctly specified the the fixed_formula_nofail argument?")
        # Get the residual
        resid <- parameters_fail %>%
          dplyr::filter(grepl(paste0("resid", paste0("_", outcomes[j, 1],"$")), parameter_name)) %>%
        resid <- as.numeric(resid)
        # get the log likelihood for the i-th subject
        # log_likelihood <- vector("list" , length = nrow(unique(data[id])))
        .id <- unlist(unique(data[id]))[i]
        # Get the outcome vector for subject i
        # NOTE(review): inside dplyr::filter, `time` resolves to a data column literally named "time"
        # (data masking), not to the `time` argument used further below.
        y <- data %>%
          dplyr::filter(get(id) == .id & time <= .s) %>%
          dplyr::select(all_of(outcomes[j, 1]))
        y <- as.matrix(y)
        # bX assumes similar models for all longitudinal outcomes. 
        .ff <- str_remove(ff, paste0("^", outcomes[j, 1], "$"))
        .ff <- .ff[.ff != ""]
        # Get the X matrix for subject i, using only data up to the pattern midpoint .s
        X <- data %>%
          dplyr::filter(get(id) == .id & time <= .s)
        X <- model.matrix(as.formula(paste("~ ", paste(.ff, collapse = " + "))), X)
        # replace actual failure time with failure time for the prediction
        if (failure_time %in% colnames(X)) {
          X[ , failure_time] <- .s
        # Get the linear predictor for the fixed effect
        xb <- X %*% betas
        # Get the Z matrix for subject i
        .rf <- str_remove(rf, paste0("^",outcomes[j, 1], "$"))
        .rf <- .rf[.rf != ""]
        # Select only data up to the pattern midpoint .s
        Z <- data %>%
          dplyr::filter(get(id) == .id & time <= .s)
        Z <- model.matrix(as.formula(paste("~", paste(.rf, collapse = " + "))), Z)
        # replace actual failure time with failure time for the prediction
        if (failure_time %in% colnames(Z)) {
          Z[ , failure_time] <- .s  
        # log_lik is the log likelihood of the j-th outcome for the i-th subject at observed time points
        # for the k-th draw from the random effects.
        # NOTE(review): rowSums(Z * random_effects_fail[k, j]) multiplies every column of Z by the same
        # scalar draw -- confirm this is intended when Z has more than an intercept column.
        if (outcomes$outcome_type[j] == "continuous") {
          # for (k in 1:nrow(random_effect)) {
          log_lik <- lapply(1:nrow(random_effects_fail), function(k) {  
            # get the linear predictor = X*beta + Z*bu
            zb <- rowSums(Z * random_effects_fail[k, j])
            lp <- xb + zb
            # Normal log-density with standard deviation `resid` (resid is taken as-is from parameters_fail).
            # calculate the log-likelihood
            -log(resid) -0.5 * log(2 * pi) - ( (y - lp)^2 / (2 * resid^2) )
        } else if (outcomes$outcome_type[j] == "binary") {
          log_lik <- lapply(1:nrow(random_effects_fail), function(k) {
            # lp = X*beta + Z*bu
            zb <- rowSums(Z * random_effects_fail[k, j])
            lp <- xb + zb
            # Inverse logit. NOTE(review): exp(lp) overflows to Inf/Inf = NaN for large lp; plogis(lp) is stable.
            p <- exp(lp) / (1 + exp(lp))
            lik <- ifelse(y == 0, 1 - p, p)
            # Floor tiny likelihoods so a later log() stays finite.
            lik <- ifelse(lik > 1e-8, lik, 1e-10)
        # log likelihood for j-th outcome at each observed time point (columns) and all RE draws (rows) for i-th subject  
        list("ll" = t(do.call(cbind, log_lik)),
             ".id" = .id,
             "X" = X)
      # return the sum over the outcomes at each observed time point (columns) and all RE  draws (rows) for i-the subject
      # NOTE(review): with a single outcome 2:length(ll) is c(2, 1) and ll[[2]] fails; seq_along(ll)[-1] would be safe.
      sum_ll <- ll[[1]]$ll
      for (j in 2:length(ll)) {
        sum_ll <- sum_ll + ll[[j]]$ll
      # take the exp to obtain likelihood per draw (rows) and per observed time t (columns) for the i-th subject
      l <- exp(sum_ll)
      # take the mean of the likelihoods of each over the draws per observed time t
      # The time column below is read from the fixed-effects model matrix X, so `time` must be a term of the fixed formula.
      l <- matrix(colMeans(l))
        id = ll[[1]]$.id,
        time = ll[[1]]$X[, time],
        likelihoods = l,
        pattern = s,
        row.names = NULL

#' Get likelihood profiles for non failures
#' Generate a step function of the likelihood at interval times: for each subject the likelihoods are
#' cumulated over observed times (cumulative product), and for each grid time t in seq(landmark, horizon, by = interval)
#' the value at the last observed time <= t is kept.
#' @param likelihood_nofail a list of likelihoods for each subject at the observed time points returned by get_likelihood_nofail();
#'   each element needs columns id, time and likelihoods.
#' @param landmark the first grid time (used below but not a formal argument in this signature -- looked up in the enclosing environment).
#' @param horizon the prediction horizon (used below but not a formal argument in this signature -- looked up in the enclosing environment).
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals if horizon is in years)
#' @import future.apply
#' @import dplyr
#' @return A list with one element per subject: a data.frame with columns id, cutpoint1 (grid time) and likelihoods.
get_likelihood_profiles_nofail <- function(likelihood_nofail,
                                           interval) {
  lm <- seq(from = landmark, to = horizon, by = interval)
  # exp(cumsum(log(x))) is the running product of the per-time likelihoods.
  obs_l <- future_lapply(likelihood_nofail, function(x) {
    x %>% mutate(likelihoods = exp(cumsum(log(likelihoods))))
  # select the last row and append to build a data.frame of cumulative likelihoods per interval
  # NOTE(review): grid times with no observation at or before t produce no row (the else branch below is commented out).
  Xnew <- future_lapply(obs_l, function(l) {
    temp <- lapply(lm, function(t) {
      temp2 <- l %>% dplyr::filter(time <= t)
        if (nrow(temp2) > 0) {
          temp2 <- temp2 %>% 
            dplyr::filter(time == max(time)) %>% 
            dplyr::mutate(cutpoint1 = t)
      # else {
      #     temp2 <- data.frame(
      #       id = l$id[1],
      #       time = 0,
      #       likelihoods = 1, 
      #       cutpoint1 = t)
      #   }
    do.call(rbind, temp) %>% 
      arrange(id) %>% 
      dplyr::select(id, cutpoint1, likelihoods)

#' Get likelihood profiles for failure 
#' Generate a step function of the likelihood at interval times t for every failure-time pattern: likelihoods are
#' cumulated over observed times (cumulative product), observations at or after the pattern time are dropped, and for
#' each grid time t the value at the last observed time <= t is kept for patterns >= t.
#' @param likelihood_fail a list (one element per pattern) of lists (one element per subject) of likelihoods at the observed time points,
#'   as returned by get_likelihood_fail(); each data.frame needs columns id, time, likelihoods and pattern.
#' @param landmark the prediction landmark time (used below but not a formal argument in this signature -- looked up in the enclosing environment).
#' @param horizon the prediction horizon (used below but not a formal argument in this signature -- looked up in the enclosing environment).
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals and horizon is in years)
#' @import future.apply
#' @import dplyr
#' @return A list with one element per subject: a data.frame with columns id, cutpoint1 (grid time), likelihoods and cutpoint2 (pattern).
get_likelihood_profiles_fail <- function(likelihood_fail,
                                         interval) {  
  lm <- seq(from = landmark, to = horizon, by = interval)
  # likelihood_fail is a list of lists, one element per pattern on the grid lm.
  # Each of the lists contains a data.frame per patient whose likelihoods are cumulated (running product via exp(cumsum(log()))).
  list_obs_l <- lapply(likelihood_fail, function(l) {
    future_lapply(seq_along(l), function (x) {
      l[[x]] %>%
        mutate(likelihoods = exp(cumsum(log(likelihoods))))
  # Combine the likelihood patterns for each patient
  # rbindlist() yields data.tables, so split(., by = "id") dispatches to data.table's split method.
  obs_l <- lapply(list_obs_l, data.table::rbindlist) %>%
    do.call(rbind, .) %>%
    dplyr::filter(time < pattern) %>%
    split(., by = "id")
  # We now have the patterns per observed time. We need the patterns per possible landmark time (cutpoint 1).
  # dplyr::select the last row and append to build a data.frame of cumulative likelihoods per interval
  Xnew <- future_lapply(obs_l, function(l, ...) {
    temp <- lapply(lm, function(t) {
      temp2 <- l %>% 
          time <= t,
          pattern >= t
      if (nrow(temp2) > 0) {
        temp2 <- temp2 %>% 
          dplyr::filter(time == max(time)) %>% 
          mutate(cutpoint1 = t)
    do.call(rbind, temp) %>% 
      arrange(id) %>% 
      dplyr::select(id, cutpoint1, likelihoods, cutpoint2 = pattern)

#' Get priors
#' Call flexsurvreg() to fit an unconditional survival model and obtain prior failure probabilities on a time grid.
#' @param data a data.frame with the data for which to predict in the long format (input should be the same as likelihood_profiles).
#' @param time_failure a character with the failure time variable name.
#' @param failure a character with the failure indicator variable name.
#' @param horizon the prediction horizon.
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals and horizon is in years)
#' @param dist distribution passed to flexsurvreg() (default "weibull").
#' @import flexsurv
#' @return a data.frame with columns lm, prior_survival_interval and prior_fail_interval: for every grid time lm
#'   from 0 to horizon (step interval), the prior probability of surviving / failing during [lm, lm + interval)
#'   given survival up to lm.
get_priors <- function (data, time_failure, failure, horizon, interval, dist = "weibull") {
  surv <- paste0("Surv(",time_failure, ",", failure, ")  ~ 1")
  sfit_wb <- flexsurvreg(as.formula(surv), data = data, dist = dist)
  # Predicted failure-time quantiles on a fine probability grid: a lookup table from time (.pred)
  # to cumulative failure probability (.quantile).
  pct <- seq(0, 1, by = 0.001)
  pred_wb <- predict(sfit_wb, type = "quantile", p = pct)$.pred[[1]]
  # round() guards against horizon / interval evaluating just below an integer (e.g. 5 / (1/12)),
  # which would silently drop a grid point when used as the upper bound of `:`.
  # One extra grid point beyond the horizon is needed for the conditional survival of the last interval.
  n_steps <- round(horizon / interval)
  lm <- (0:(n_steps + 1)) * interval
  # Prior cumulative failure probability at each grid time: the largest quantile level whose
  # predicted failure time does not exceed it.
  prior_fail <- vapply(lm, function(t) {
    max(pred_wb$.quantile[which(pred_wb$.pred <= t)])
  }, numeric(1))
  prior_survival <- 1 - prior_fail
  # Conditional survival over [lm, lm + interval); the last grid point has no successor (NA).
  prior_survival_interval <- c(prior_survival[-1], NA) / prior_survival
  prior <- data.frame(
    lm = lm,
    prior_survival_interval = prior_survival_interval,
    prior_fail_interval = 1 - prior_survival_interval,
    row.names = NULL
  )
  # Small tolerance so that a grid point equal to horizon up to rounding error is kept.
  prior[prior$lm <= horizon + sqrt(.Machine$double.eps) * max(1, abs(horizon)), , drop = FALSE]
}

#' Get predictions for non failure
#' Prepare the non-failure likelihood profiles for the posterior computation.
#' NOTE(review): only the first statement of the body is visible here; the remainder of the computation could not be reviewed.
#' @param likelihood_profiles_nofail a list with the likelihood profiles for the non-failure model for each subject,
#'   as returned by get_likelihood_profiles_nofail() (three columns: id, cutpoint1, likelihoods).
#' @import dplyr
#' @return a list of dataframes with the predictions for each subject at every interval up to the horizon
get_predictions_nofail <- function(likelihood_profiles_nofail) {
  # Rename the three profile columns to id, cutpoint1 and f_nofail.
  # NOTE(review): despite the variable name, the values are (cumulative) likelihoods, not negative log-likelihoods.
  negloglik_nofail <- lapply(likelihood_profiles_nofail, setNames, c("id", "cutpoint1", "f_nofail"))

#' Get predictions for failure
#' Combine the failure-model likelihood profiles with the prior interval
#' probabilities and return, per subject and landmark, the prior-weighted
#' cumulative likelihood of failure up to each horizon cut point.
#' Implements equation 7.2 (p155) from Steffen's thesis.
#' @param likelihood_profiles_fail a list with the likelihood profiles for the failure model for each subject
#'   (data frames with at least the columns id, cutpoint1, cutpoint2 and likelihoods)
#' @param prior a data frame with priors returned by get_priors()
#' @import dplyr
#' @import future.apply
#' @return a list of data frames with columns id, cutpoint1, cutpoint2 and f_fail.
get_predictions_fail <- function (likelihood_profiles_fail, prior) {
  future_lapply(likelihood_profiles_fail, function(profile) {
    profile %>%
      # Attach the prior interval probabilities, i.e.
      # f_i(y_i^<=k | F_i = k + 0.5) * P(k <= F_i <= k + 1)
      # NOTE(review): the join key is the landmark (cutpoint1 = lm). If prior has
      # one row per lm (as get_priors() returns), prior_survival_interval is
      # constant within a cutpoint1 group and the cumprod() below is its k-th
      # power; confirm the key should not be the interval cut point (cutpoint2).
      left_join(prior, by = c("cutpoint1" = "lm")) %>%
      mutate(f_fail_interval = likelihoods * (1 - prior_survival_interval)) %>%
      arrange(id, cutpoint1, cutpoint2) %>%
      # Group per subject as well as per landmark so cumulative sums never run
      # across subjects when a list element holds more than one id
      group_by(id, cutpoint1) %>%
      # Sigma_t=k^h [ f_i(y_i^<=k | F_i = k + 0.5) * P(k <= F_i <= k + 1)] / P(lm <= F_i <= h)
      mutate(f_fail = cumsum(f_fail_interval) / (1 - cumprod(prior_survival_interval))) %>%
      ungroup() %>%
      dplyr::select(id, cutpoint1, cutpoint2, f_fail)
  })
}
#' Apply Bayes rule to get posterior predictions
#' Combine the failure and non-failure densities with the prior probability of
#' failure to obtain the posterior probability of failure.
#' @param f_fail a list of data frames from get_predictions_fail() (columns id, cutpoint1, cutpoint2, f_fail)
#' @param f_nofail a list of data frames from get_predictions_nofail() (columns id, cutpoint1, f_nofail)
#' @param prior a data frame with priors returned by get_priors()
#' @import purrr
#' @import dplyr
#' @import future.apply
#' @return a list of data frames with columns id, landmark, horizon and post_fail
get_posteriors <- function (f_fail, f_nofail, prior) {
  # Pair the failure and non-failure densities element by element
  post <- map2(f_fail, f_nofail, right_join, by = c("id", "cutpoint1"))
  future_lapply(post, function(x) {
    x %>%
      # P(F > hor | F >= lm) from the prior
      left_join(prior, by = c("cutpoint1" = "lm")) %>%
      # cumprod() depends on row order: make it explicit and keep it per subject
      arrange(id, cutpoint1, cutpoint2) %>%
      group_by(id, cutpoint1) %>%
      mutate(prior_surv = cumprod(prior_survival_interval),
             prior_fail = 1 - prior_surv,
             # Bayes rule: P(fail | y) = f_fail * P(fail) / (f_fail * P(fail) + f_nofail * P(survive))
             post_fail = (f_fail * prior_fail) / (f_fail * prior_fail + f_nofail * prior_surv)) %>%
      ungroup() %>%
      dplyr::select(id, landmark = cutpoint1, horizon = cutpoint2, post_fail)
  })
}

#' Predictions from the multivariate mixed model using a discriminant analysis framework.
#' Calculate the posterior predictions from the likelihood profiles and priors. Implements equation 7.1 (p154) from Steffen's thesis.
#' @param data a data.frame with the data for which to predict in the long format (input should be the same as likelihood_profiles)
#' @param outcomes the longitudinal outcomes; the outcome variable names are read from `outcomes$outcome`
#' @param id a character string with the name of the subject id variable
#' @param fixed_formula_nofail formula for the fixed effects of the non-failure model
#' @param random_formula_nofail formula for the random effects of the non-failure model
#' @param random_effects_nofail random effect estimates for the non-failure model (see MASS::mvrnorm)
#' @param parameters_nofail parameter estimates from the non-failure model (see mmm_model)
#' @param fixed_formula_fail formula for the fixed effects of the failure model
#' @param random_formula_fail formula for the random effects of the failure model
#' @param random_effects_fail random effect estimates for the failure model (see MASS::mvrnorm)
#' @param parameters_fail parameter estimates from the failure model (see mmm_model)
#' @param time a character string with the time variable name in data
#' @param failure a character string with the failure variable name in data
#' @param failure_time a character string with the failure time variable in data
#' @param landmark a numeric value indicating the prediction landmark until which data is available for prediction.
#' @param horizon the prediction horizon (maximum follow-up time)
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals and horizon is in years)
#' @import future.apply
#' @import dplyr
#' @import flexsurv
#' @export
#' @return a data.frame with one row per subject: a nested `data` column with the
#'   observations up to the landmark and a nested `prediction` column with the
#'   posterior failure probabilities (landmark, horizon, post_fail) from get_posteriors().

mmm_predictions <- function( data,
                             ...) {
  # NOTE(review): the body uses outcomes, id, time, failure, failure_time,
  # landmark, horizon, interval, prior and the *_nofail / *_fail model objects,
  # but none of them are formal parameters. `...` does not bind names, so they are
  # resolved from the enclosing (defining) environment. Restore the signature to
  # match the @param list above.
  # Subjects present before filtering, used to report how many are dropped
  n1 <- unique(data[id])
  # Keep only the observations up to the landmark.
  # NOTE(review): all_of() is a tidyselect helper and does not yield column values
  # inside filter(). Depending on the tidyselect version it errors or returns the
  # string held in `time`, so the time column is never compared.
  # `.data[[time]] <= landmark` is presumably intended.
  data <- data %>% 
    dplyr::filter(all_of(time) <= landmark)
  n2 <- distinct(data[id])
  if (nrow(n1) != nrow(n2)) {
    dropped <- setdiff(n1, n2) %>% unlist
    message(paste("Dropped", length(dropped), "subjects without data prior to landmark"))
  # NOTE(review): the closing braces of this if-block and of the next one appear to
  # be missing, so everything below is currently nested inside them.
  if (nrow(n2) == 0) {
    stop("There are no observations prior to the landmark. Have you set the argmument `landmark=` correcty")
  # Per-subject likelihoods of the observed markers under the non-failure model
  likelihood_nofail <- get_likelihood_nofail(data = data,
                                             outcomes = outcomes,
                                             fixed_formula_nofail = fixed_formula_nofail,
                                             random_formula_nofail = random_formula_nofail,
                                             time = time,
                                             parameters_nofail = parameters_nofail,
                                             random_effects_nofail = random_effects_nofail,
                                             id = id,
                                             horizon = horizon)
  # ... and under the failure model, evaluated on the landmark/interval grid
  likelihood_fail <- get_likelihood_fail(data = data,
                                         outcomes = outcomes,
                                         fixed_formula_fail = fixed_formula_fail,
                                         random_formula_fail = random_formula_fail,
                                         time = time,
                                         failure_time = failure_time,
                                         parameters_fail = parameters_fail,
                                         random_effects_fail = random_effects_fail,
                                         id = id,
                                         landmark = landmark,
                                         interval = interval,
                                         horizon = horizon)
  # Likelihood profiles between landmark and horizon
  likelihood_profiles_nofail <- get_likelihood_profiles_nofail(likelihood_nofail = likelihood_nofail,
                                                               landmark = landmark,
                                                               horizon = horizon, 
                                                               interval = interval) 
  likelihood_profiles_fail <- get_likelihood_profiles_fail(likelihood_fail = likelihood_fail,
                                                           landmark = landmark,
                                                           horizon = horizon,
                                                           interval = interval)
  # Densities for non-failure and failure, then posteriors via Bayes rule.
  # NOTE(review): `prior` is not created in this function; presumably it should come
  # from get_priors() (cf. @import flexsurv). Confirm.
  f_nofail <- get_predictions_nofail(likelihood_profiles_nofail = likelihood_profiles_nofail)
  f_fail <- get_predictions_fail(likelihood_profiles_fail = likelihood_profiles_fail,
                                 prior = prior)
  predictions <- get_posteriors(prior = prior,
                                f_fail = f_fail,
                                f_nofail = f_nofail)
  # One row per subject with the predictions nested; the literal `id` column
  # produced by get_posteriors() is renamed to the subject id variable name
  predictions <- lapply(predictions, nest, data = -id)
  predictions <- lapply(predictions, setNames, c(id, "prediction"))
  predictions <- do.call(rbind, predictions)
  # Observed trajectory up to the landmark, nested per subject.
  # NOTE(review): group_by(all_of(id)) and filter(all_of(time) <= landmark) have the
  # same tidyselect-in-data-mask problem as above. group_by(.data[[id]]) and
  # .data[[time]] are presumably intended.
  data_trajectory <- data %>%
    group_by(all_of(id)) %>%
    dplyr::filter(all_of(time) <= landmark) %>%
    ungroup() %>%
    dplyr::select(all_of(c(id, time, failure, failure_time, outcomes$outcome))) %>%
    nest(data = -id)
  predictions <- data_trajectory %>% left_join(predictions, by = id)
  # NOTE(review): a data.frame is always a list, so this branch never runs. The
  # closing braces and the final return value appear to be missing.
  if (!is.list(predictions)) { 
    predictions = list(predictions)

#' Calculate the Area under the Receiver Operating Characteristics Curve
#' Use the timeROC package to calculate the ROC-AUC based on predictions
#' from the mmm_predictions function.
#' @param predictions the predictions returned by the mmm_predictions function
#'   (nested `data` and `prediction` columns)
#' @param prediction_landmark a single landmark time used to select predictions (compared with `==`)
#' @param prediction_horizon the prediction horizon *at* which to calculate the ROC-AUC.
#' @param id a character value for the name of the vector of subject identifiers in the predictions data
#' @param failure a character value for the name of the failure indicator in the predictions data
#'   (passed to timeROC as `delta`, with events coded as cause 1)
#' @param failure_time a character value for the name of the vector with failure times in the predictions data
#' @import timeROC
#' @import dplyr
#' @details Uses the timeROC package to determine the area under the receiver operating characteristics curve. 
#' @return a data.frame with columns landmark, horizon, failures, survivors, censored,
#'   auc, auc_se and cuminc; rows with a missing AUC are dropped.

get_auc <- function (
  ) {
    # NOTE(review): the formal parameters (predictions, prediction_landmark,
    # prediction_horizon, id, failure, failure_time) are missing from the signature;
    # as written they are resolved from the enclosing environment.
    # One row per subject with the observed failure time and failure indicator
    data <- predictions %>%  # This assumes that predictions are always stored in a list of data tables
      unnest(data) %>% 
      group_by(get(id)) %>% 
      dplyr::select(all_of(c(failure_time, failure))) %>%
      summarize(across(everything(),first)) %>%
    # NOTE(review): the pipe on the previous line is not terminated and runs into the
    # next assignment; the final step of this chain appears to be missing.
    # Per-subject predicted failure probability at the requested landmark and horizon
    predict <- predictions %>%
      unnest(prediction) %>%
      group_by(get(id)) %>%
      dplyr::select(-data) %>%
      dplyr::filter(landmark == prediction_landmark,
             horizon == prediction_horizon) %>%
    # NOTE(review): same unterminated pipe as above.
    time <- data %>% dplyr::select(all_of(failure_time)) %>% unlist()
    delta <- data %>% dplyr::select(all_of(failure)) %>% unlist()
    marker <- predict$post_fail
    # NOTE(review): `msg` is never used. It looks like the remains of a check that
    # time, delta and marker have equal length, whose enclosing call was lost.
      msg = paste("The length of the vectors", failure_time, failure,    
                  "and predictions do not match.")
    # iid = TRUE so that roc$inference$vect_sd_1 (the AUC standard error used
    # below) is computed
    roc <- timeROC(T = time,
                   delta = delta,
                   marker = marker,
                   cause = 1, 
                   weighting = "marginal",
                   times = prediction_horizon-0.001, # subtract a small amount of time to get result
                   iid = TRUE) 
    # Collect the summary statistics into a single data.frame row per time point
    roc <- cbind(roc$Stats, roc$AUC, roc$inference$vect_sd_1, roc$CumulativeIncidence)
    colnames(roc) <- c("failures", "survivors", "censored", "auc", "auc_se", "cuminc")
    roc <- as.data.frame(roc, row.names = FALSE) %>% mutate(landmark = prediction_landmark)
  # Keep rows with an estimated AUC and put landmark/horizon first.
  # NOTE(review): the closing brace of the function appears to be missing.
  roc %>% 
    dplyr::filter(!is.na(auc)) %>%
      mutate(horizon = prediction_horizon) %>%
    relocate(landmark, horizon, everything())

#' #' Plot calibration curve
#' #' 
#' #' Use riskRegression's Score.list() to determine calibration.
#' #' 
#' #' @param predictions Predictions from the MMM calculated with the mmm_predictions() function.
#' #' @param landmark The prediction landmark time.
#' #' @param horizon The prediction horizon time.
#' #' @param id The subject identifier in the predictions data.
#' #' 
#' #' @import riskRegression
#' #' @export
#' plot_calibration <- function(predictions,
#'                              landmark, 
#'                              horizon,
#'                              id) {
#'   data <- predictions[[landmark]] %>% 
#'     unnest(data) %>% 
#'     group_by(get(id)) %>% 
#'     dplyr::select(all_of(c(failure_time, failure))) %>%
#'     summarize_all(first)
#'   predict <- predictions[[landmark]] %>%
#'     unnest(prediction) %>%
#'     dplyr::select(-data) %>%
#'     dplyr::filter(time == horizon)
#'   score <- riskRegression::Score(list( "p_fail"=predict$post_fail ),
#'                                  formula=Surv(stime, failure) ~ 1,
#'                                  data=data, 
#'                                  conf.inf=FALSE,
#'                                  times=horizon,
#'                                  metrics=NULL, 
#'                                  plots="Calibration")
#'   plt_calibrate <- riskRegression::plotCalibration(score,
#'                                                    cens.method="local", 
#'                                                    round=TRUE,
#'                                                    rug=TRUE,
#'                                                    plot=FALSE)
#'   gplt_calibrate <- ggplot(plt_calibrate$plotFrames$p_fail, aes(x=Pred, y=Obs)) + 
#'     stat_smooth(method="loess", 
#'                 span=0.3,
#'                 color="black",
#'                 alpha=0.2) + 
#'     geom_abline(intercept=0, slope=1, alpha=0.5) +
#'     geom_rug(aes(x=Pred)) + 
#'     scale_y_continuous(limits=c(0, 1),
#'                        breaks=seq(0, 1, 0.2)) + 
#'     scale_x_continuous(limits=c(0, 1),
#'                        breaks=seq(0, 1, 0.2)) +
#'     labs(title="", 
#'          caption=paste("Predictions generated at", landmark[t],"years follow-up"),
#'          y="",
#'          x="") +
#'     theme_tufte(ticks=FALSE, 
#'                 base_family="WorkSans")
#' }

#' Plot predictions from the multivariate mixed model
#' For a given subject, plot the longitudinal biomarker evolutions and the
#' probability of failure conditional on survival.
#' @param predictions the predictions generated from mmm_predictions()
#' @param outcomes a character vector with the names of the longitudinal outcomes
#' @param id a character value with the name of the subject identifier
#' @param subject a character value with id for subject of interest
#' @param time a character value with the name of the time variable in the marker data
#' @details Only the earliest landmark present in the predictions (min(landmark)) is plotted.
#' @import ggplot2
#' @import gridExtra
#' @import dplyr
#' @import scales
#' @return plots for survival conditional on the longitudinal biomarker trajectory.
plot_predictions <- function (predictions,
                              time) {
  # NOTE(review): outcomes, id and subject are used below but are not formal
  # parameters; they are resolved from the enclosing environment.
  # FIX: Return marker data to the original scale
  # check input
  if((length(subject) == 1) == FALSE) {
    stop("Set `subject` to select one subjects")  
  # NOTE(review): the closing brace of this if-block appears to be missing.
  # detect binary outcomes
  # NOTE(review): `sort(...) == c(0, 1)` is element-wise, so each outcome yields a
  # logical vector (possibly with a recycling warning) rather than one TRUE/FALSE.
  # sapply() then returns a matrix or list, and is_binary[outcome] below does not
  # index it by name. A single logical per outcome (e.g. via setequal()) is presumably
  # intended. The closing `})` of sapply() also appears to be missing.
  is_binary <- sapply(outcomes, function(x, ...) {
      predictions %>% 
        unnest(data) %>% 
        distinct(get(x)) %>%
        unlist() %>%
        sort() == c(0, 1)
  # Apply filter at subject level
  # if (is.null(subject)) { 
    # # Select marker data
    # markers <- predictions %>% 
    #   dplyr::select(all_of(id), data) %>%
    #   unnest(data) 
    # # Select prediction data
    # pred <- predictions %>% 
    #   dplyr::select(all_of(id), prediction) %>%
    #   unnest(prediction)
  # } else { 
    # Select marker data
    # NOTE(review): both chains below end in an unterminated `%>%`. Presumably
    # unnest(data) / unnest(prediction) were the final steps (cf. the commented
    # branch above).
    markers <- predictions %>%
      dplyr::select(all_of(id), data) %>%
      filter(get(id) == subject) %>%
    # Select prediction data
    pred <- predictions %>%
      dplyr::select(all_of(id), prediction) %>%
      filter(get(id) == subject) %>%
  # }
  # set-up time axis
  # NOTE(review): the else-branch chain below is unterminated and the block is not
  # closed; presumably it should end in the maximum observed time.
  if (nrow(markers) == 0) {
    max_x <- 0
  } else {
    max_x <- markers %>%
      dplyr::select(all_of(time)) %>%
  # Pivot the data for plotting
  markers_trajectory <- markers %>%
    pivot_longer(cols = outcomes, names_to ="outcome")
  markers_trajectory <- markers_trajectory %>% 
    mutate(is_binary = is_binary[markers_trajectory$outcome])
  # Filter out a single subject
  # subjects <- markers_trajectory %>% 
  #   dplyr::select(all_of(id)) %>% 
  #   unique() %>% 
  #   unlist()
  # If binary == TRUE
  # NOTE(review): in aes(x=time) below, `time` is a column lookup first and otherwise
  # the string argument; `.data[[time]]` is presumably intended unless the column is
  # literally named "time".
  binary_plot <- markers_trajectory %>% 
        # NOTE(review): the opening `filter(` of this call appears to be missing.
        is_binary == TRUE
    ) %>% 
    ggplot(aes(x=time, y=value, col=outcome)) +
    geom_point() + 
    geom_line(alpha=0.8) +
                       # NOTE(review): the opening `scale_x_continuous(name = "",`
                       # line appears to be missing (cf. continuous_plot below).
                       limits=c(0, ceiling(max_x)),
                       breaks=scales::breaks_width(0.5)) +
      # NOTE(review): the opening `scale_y_continuous(` appears to be missing.
      name = "longitudinal outcome") +
    theme_bw() +  # FIX: Tufte theme later
    theme(legend.position = "none") +
  # NOTE(review): the trailing `+` above runs into the next assignment.
  # If binary == FALSE
  continuous_plot <- markers_trajectory %>% 
        # NOTE(review): the opening `filter(` of this call appears to be missing.
        is_binary == FALSE
    ) %>%  
    ggplot(aes(x=time, y=value, col=outcome)) + 
    geom_point() + 
    geom_line(alpha = 0.8) + 
    scale_x_continuous(name = "",
                       limits = c(0, ceiling(max_x)),
                       breaks = scales::breaks_width(0.5)
    ) + 
    scale_y_continuous(name = "longitudinal outcome") +
    theme_bw() +  # FIX: Tufte theme later
    theme(legend.position = "none") +
  # NOTE(review): the trailing `+` above runs into the next assignment.
  # Plot the predicted failure probability by horizon at the earliest landmark
  surv_plot <- pred %>% 
    filter(landmark==min(landmark)) %>% 
    ggplot(aes(x=horizon,y=post_fail)) + 
    geom_line(alpha = 0.8)  +
    scale_x_continuous(name = "",
                       breaks = scales::breaks_width(1)) +
    scale_y_continuous(name = "Probability of graft failure",
                       limits = c(0, 1),
                       breaks = scales::breaks_width(0.2),
                       position = "right") +
    theme_bw() # FIX: theme Tufte
  # Combine plots and return
  # NOTE(review): the call that takes `grobs=` (presumably a gridExtra arrange
  # function) and the closing braces of the function appear to be missing.
    grobs=list(binary_plot, continuous_plot, surv_plot),

#' Plot the calibration for MMM
#' Create a grobs with calibration plot for a given landmark and horizon.
#' @param predictions predictions data from the mmm_predictions function.
#' @param prediction_landmark a numeric vector with the landmark times to select predictions
#' @param prediction_horizon the prediction horizon *at* which to calculate the ROC-AUC.
#' @param id a character value for the name of the vector of subject identifiers in the predictions data
#' @param failure_time a character value for the name of the vector with failure times in the predictions data
#' @param predictions the predictions returned by the mmm_predictions function
#' @import ggplot2
#' @import ggthemes
#' @import gridExtra
#' @import dplyr
#' @import scales
#' @import riskRegression
#' @return a list of grobs
#' @export
plot_calibration <- function(
  ) {
  data <- predictions %>%  # This assumes that predictions are always stored in a list of data tables
    unnest(data) %>% 
    group_by(get(id)) %>% 
    dplyr::select(all_of(c(failure_time, failure))) %>%
    summarize(across(everything(),first)) %>%
  predict <- predictions %>%
    unnest(prediction) %>%
    group_by(get(id)) %>%
    dplyr::select(-data) %>%
    dplyr::filter(landmark == prediction_landmark,
                  horizon == prediction_horizon) %>%
  score <- riskRegression::Score(
    list( "p_fail"=predict$post_fail ),
    formula=Surv(stime, failure) ~ 1,
  plt_calibrate <- riskRegression::plotCalibration(
  plt_calibrate$plotFrames$p_fail %>% 
