R/mmm_functions.R

Defines functions get_likelihood_profiles_fail get_likelihood_profiles_nofail get_outcome_type get_model_output get_covariance average_hessians get_random_start make_random_formula make_fixed_formula make_pairs make_spaghetti_plots make_density make_histogram make_bar_plot

# Doc header --------------------------------------------------------------

# Pairwise binomial mixed models
# Jan van den Brand, PhD
# jan.vandenbrand@kuleuven.be
# janvandenbrand@gmail.com
# Project: DKF19OK003

# The work here is based on Fieuws S. Mixed Models for Longitudinal Data 2006

# Preliminaries  ------------------------------------------------------

packages <- c("optimx", "ggplot2","lattice","MASS", "Matrix", "gridExtra", "tidyverse",
              "GLMMadaptive", "stringr", "flexsurv", "future.apply", "progressr", "viridis")
lapply(packages, library, character.only = TRUE)


# Data visualization ------------------------------------------------------
#' Bar chart for a categorical variable
#'
#' Draw a bar chart with the counts of one categorical variable in a
#' data.frame, optionally split into one panel per level of a grouping variable.
#'
#' @param data A data.frame.
#' @param var A character value with the name of the categorical variable in
#'   `data`. The name is looked up with `get()`, so it must be a single name.
#' @param by (optional) A character value giving the name of the variable to
#'   facet by (e.g. event status). With `NULL` (default) a single panel is drawn.
#'
#' @import ggplot2
#' @import gridExtra
#'
#' @return A ggplot object.
#'
#' @export
#' @examples
#' cars <- transform(mtcars, cyl = factor(cyl), am = factor(am))
#' make_bar_plot(data = cars, var = "cyl")
#' make_bar_plot(data = cars, var = "cyl", by = "am")
#'
make_bar_plot <- function(data, var, by = NULL) {
  # Validate the grouping variable before any plot object is built
  if (!is.null(by)) {
    stopifnot("by should be a character or NULL"=is.character(by))
  }
  # The bars are identical with and without facetting
  bars <- ggplot(data = data) +
    geom_bar(aes(x = get(var), fill = get(var))) +
    theme_minimal() +
    labs(x = "", y = "") +
    guides(fill = guide_legend(title = paste(var)))
  if (is.null(by)) {
    return(bars)
  }
  # One column of panels per level of the grouping variable
  bars + facet_grid(cols = vars(get(by)))
}

#' Histogram for a continuous variable
#'
#' Draw a 20-bin histogram of one numeric variable in a data.frame.
#'
#' @param data A data.frame.
#' @param var A character value with the name of the numeric variable in
#'   `data`. The name is looked up with `get()`, so it must be a single name.
#'
#' @import ggplot2
#' @import gridExtra
#'
#' @return A ggplot object.
#'
#' @export
make_histogram <- function(data, var) {
  axis_title <- paste("histogram of", var)
  histogram_layer <- geom_histogram(data = data, aes(x = get(var)), bins = 20)
  ggplot() +
    histogram_layer +
    theme_classic() +
    labs(x = axis_title)
}

#' Density plot for a continuous variable
#'
#' Draw a density plot of one numeric variable in a data.frame, with the option
#' to overlay one density per level of a grouping variable (e.g. the outcome).
#'
#' @param data A data.frame.
#' @param var A character value with the name of the numeric variable in
#'   `data`. The name is looked up with `get()`, so it must be a single name.
#' @param by (optional) A character value giving the name of the variable to
#'   group by (e.g. event status). It is mapped to the fill colour.
#'
#' @import ggplot2
#' @import gridExtra
#'
#' @return A ggplot object.
#'
#' @export
make_density <- function(data, var, by = NULL) {
  if (is.null(by)) {
    # Look the column up by name; mapping `var` directly would plot the
    # literal string instead of the data.
    density_layer <- geom_density(data = data, aes(x = get(var)), alpha = 0.5)
  } else {
    stopifnot("by should be a character or NULL"=is.character(by))
    density_layer <- geom_density(data = data,
                                  aes(x = get(var), fill = get(by)),
                                  alpha = 0.5)
  }
  # Build the plot as the last expression so it is returned visibly
  ggplot() +
    density_layer +
    theme_classic() +
    labs(x = paste("density plot of", var))
}


#' Spaghetti plots for longitudinal variables
#'
#' Make spaghetti plots for longitudinal variables in a data.frame: one thin
#' line per subject with a loess smoother on top.
#'
#' @param data A data.frame in long format.
#' @param ftime A character value giving the name of the (follow-up) time variable.
#' @param var A character value giving the name of the longitudinal variable.
#' @param id A character value giving the name of the variable with the subject identifier.
#' @param by A character value giving the name of the variable to group the trajectories by (e.g. event status)
#' @param breaks A numeric value with the major tick mark for the x-axis.
#'
#' @import dplyr
#' @import ggplot2
#' @import gridExtra
#'
#' @return A ggplot object.
make_spaghetti_plots <- function(data, ftime, var, id, by = NULL, breaks = 1) {
  if (is.null(by) ) {
    plot <- ggplot(data = data, 
                   aes(x = get(ftime), y = get(var), group = get(id))) +
      geom_jitter(size = 0.2, alpha = 0.3, width = 0.01, height = 0.01) +
      geom_line(size = 0.1, alpha = 0.2) +
      # Layers inherit the plot-level aesthetics, including group = id.
      # Override the group so one smoother is fitted to all subjects together
      # instead of one loess fit per subject.
      geom_smooth(data = data,
                  aes(x = get(ftime), y = get(var), group = 1),
                  method = 'loess') +
      scale_x_continuous(breaks = scales::breaks_width(breaks)) +
      theme_classic() +
      labs(x = "Follow-up time", y = paste(var)) 
  } else {
    plot <- ggplot(data = data, 
                   aes(x = get(ftime), y = get(var), group = get(id), color = as.factor(get(by)))) +
      geom_jitter(size = 0.2, alpha = 0.3, width = 0.01, height = 0.01) +
      geom_line(size = 0.1, alpha = 0.2) +
      # Group the smoother by the levels of `by` rather than by the inherited
      # subject id, so one trend line is drawn per level.
      geom_smooth(data = data, 
                  aes(x = get(ftime), y = get(var),
                      color = as.factor(get(by)),
                      group = as.factor(get(by))),
                  method = 'loess') +
      scale_x_continuous(breaks = scales::breaks_width(breaks)) +
      theme_classic() +
      labs(x = "Follow-up time", 
           y = paste(var),
           color = by) + 
      theme(legend.position="none")
  }
  plot
}

# Data checking and wrangling ------------------------------------------------

#' Make pairs
#' 
#' Take a vector with the names of the longitudinal outcome variables and make
#' all unique pairs. Pairs are ordered as (1,2), (1,3), ..., (2,3), ...
#' 
#' @param outcomes a character vector with the names of longitudinal outcomes
#' 
#' @returns a matrix of length(outcomes)*(length(outcomes)-1)/2 rows and 2
#'   columns with all unique pairs. With fewer than two outcomes there are no
#'   pairs and a matrix with zero rows is returned.
make_pairs <- function(outcomes) {
  # No pair can be formed from fewer than two outcomes
  if (length(outcomes) < 2) {
    return(matrix(outcomes[0], nrow = 0, ncol = 2))
  }
  # combn() returns one combination per column; transpose to one pair per row
  t(utils::combn(outcomes, 2))
}


#' Test input data types and return model families and indicator for binary outcomes
#' 
#' Takes a data.frame with longitudinal data in long format and checks if the
#' longitudinal outcomes of every pair are binary (coded 0/1) or numeric.
#' 
#' @param data A data.frame in long format
#' @param pairs A character matrix constructed by make_pairs function.
#'
#' @return A list with two elements, each of length nrow(pairs):
#' \describe{
#'   \item{families}{character vector with the custom model family: "binomial"
#'   (both binary), "normal" (both numeric) or "binary.normal" (one of each).}
#'   \item{indicators}{a list with, for family "binary.normal", a vector that
#'   marks the binary outcome with 1: c(1, 0) if the first outcome of the pair
#'   is binary, c(0, 1) if the second is. NULL for the other families.}
#' }
test_input_datatypes <- function (data, pairs) {
  # Classify a single outcome column as "binary" or "numeric"; anything else
  # is an error. Missing values are ignored (sort() drops them).
  classify_outcome <- function(name) {
    values <- unlist(data[paste(name)], use.names = FALSE)
    observed <- sort(unique(values))
    if (length(observed) == 2L && isTRUE(all(observed == c(0, 1)))) {
      "binary"
    } else if (is.numeric(values)) {
      "numeric"
    } else {
      stop("The outcomes variables do not seem to be binary or continuous")
    }
  }
  n_pairs <- nrow(pairs)
  # initialize return
  .families <- vector("character", length = n_pairs)
  .indicators <- vector("list", length = n_pairs)
  for (i in seq_len(n_pairs)) {
    type_y1 <- classify_outcome(pairs[i, 1])
    type_y2 <- classify_outcome(pairs[i, 2])
    if (type_y1 == "binary" && type_y2 == "binary") {
      .families[i] <- "binomial"
    } else if (type_y1 == "numeric" && type_y2 == "numeric") {
      .families[i] <- "normal"
    } else {
      .families[i] <- "binary.normal"
      # 1 marks the position of the binary outcome within the pair
      .indicators[[i]] <- if (type_y1 == "binary") c(1, 0) else c(0, 1)
    }
  }
  list("families" = .families, "indicators" = .indicators)
}

#' Return a list with stacked data sets for each of the outcome pairs
#' 
#' Takes a data.frame with longitudinal data to be used in the multivariate mixed model function.
#' Uses the pairwise combination created with the make_pairs() function to unroll the outcome matrix into a vector.
#' The longitudinal predictor and time invariant covariates are processed accordingly.
#' 
#' Every row of `data` becomes two consecutive rows in the stacked data: the
#' first holds outcome 1 of the pair in `Y` (and outcome 2 in `X`), the second
#' holds outcome 2 in `Y` (and outcome 1 in `X`). The intercepts and every
#' covariate are expanded into an `_Y1` and an `_Y2` column that are zero on
#' the rows of the other outcome.
#' 
#' @param data A data.frame with the longitudinal data in long format.
#' @param id A character value with the variable name of the subject identifier.
#'   The column it refers to must be a character vector.
#' @param pairs A character matrix with the pairs, returned from the make_pairs function
#' @param covars (optional) a character vector with the names of the covariate vectors.
#'   All covariates must be of class "numeric".
#' 
#' @import Matrix
#' @import dplyr
#' 
#' @export
#' 
#' @return a list of data.frames of length nrow(pairs).
stack_data <- function (data, id, pairs, covars = NULL) {
  stopifnot("Argument 'id' is missing. Use it to set subject level id." = is.character(id))
  stopifnot("'id' should refer to a character vector" = is.character(unlist(data[id])))
  if (!is.null(covars)) {
    stopifnot("'covars' should be a character vector." = is.character(covars))
    missing_covars <- setdiff(covars, names(data))
    if (length(missing_covars) > 0) {
      stop("stack_data cannot find covariate(s): '",
           paste(missing_covars, collapse = "', '"),
           "' in data.")
    }
    # covars should be numeric
    is_numeric_covar <- vapply(covars, function(name) {
      identical(class(data[[name]]), "numeric")
    }, logical(1))
    if (!all(is_numeric_covar)) {
      # collapse, so that several offending names stay readable
      stop("stack_data can only handle numeric data: '", 
           paste(covars[!is_numeric_covar], collapse = "', '"), 
           "' is not numeric.")
    }
  }
  ids <- rep(unlist(data[id]), each = 2)
  # The intercept and covariate columns do not depend on the pair, so they are
  # built once. kronecker() with a 2x2 identity interleaves an `_Y1` and an
  # `_Y2` copy of each column, zeroed on the rows of the other outcome.
  intercepts <- kronecker(rep(1, times = nrow(data)), diag(1, nrow = 2))
  covariate_block <- kronecker(as.matrix(data[covars]), diag(1, nrow = 2))
  design <- cbind(intercepts, covariate_block)
  var_names <- c(id, "Y", "X", "intercept_Y1", "intercept_Y2")
  if (!is.null(covars)) {
    # Same order as the kronecker columns: cov1_Y1, cov1_Y2, cov2_Y1, ...
    var_names <- c(var_names, paste0(rep(covars, each = 2), c("_Y1", "_Y2")))
  }
  stacked_data <- lapply(seq_len(nrow(pairs)), function(i) {
    first <- paste(pairs[i, 1])
    second <- paste(pairs[i, 2])
    # Unroll the two outcome columns row by row: y1[1], y2[1], y1[2], ...
    Y <- as.vector(t(as.matrix(data[c(first, second)])))
    ## The reversed ordering is intentional: X holds the *other* outcome
    X <- as.vector(t(as.matrix(data[c(second, first)])))
    .stacked_data <- cbind(data.frame(id = ids, Y = Y), X, design)
    names(.stacked_data) <- var_names
    .stacked_data
  })
  stacked_data
}

#' Test to compare stacked outcome data to original outcome data
#' 
#' For every pair, the rows of the stacked data that belong to the first
#' outcome (intercept_Y1 == 1) are compared to the original data: column `Y`
#' should equal the first outcome of the pair and column `X` the second.
#' 
#' @param data data.frame with original data in long format
#' @param stacked_data the list of data.frames returned by stack_data()
#' @param pairs A character matrix with the pairs, returned from the make_pairs function
#' 
#' @import dplyr
#' 
#' @return Called for its side effect: prints one assertion per outcome of
#'   every pair. All should end in TRUE. Otherwise something went wrong.
#'   Returns NULL invisibly.
test_compare_stacked_to_original_data <- function (data, stacked_data, pairs) {
  # Print the result of all.equal(); on a mismatch the printed line holds the
  # description of the difference instead of TRUE.
  report <- function(original, stacked, label) {
    print(paste('Assert: outcome vector matches the original data for', label, '=',
                all.equal(original, stacked, check.attributes = FALSE))
    )
  }
  for (i in seq_len(nrow(pairs))) {
    stacked <- stacked_data[[i]]
    first_outcome_rows <- stacked[["intercept_Y1"]] == 1
    # The vectors are compared directly, so a length mismatch is reported
    # rather than silently recycled.
    report(unlist(data[paste(pairs[i, 1])], use.names = FALSE),
           stacked[["Y"]][first_outcome_rows],
           pairs[i, 1])
    report(unlist(data[paste(pairs[i, 2])], use.names = FALSE),
           stacked[["X"]][first_outcome_rows],
           pairs[i, 2])
  }
  invisible(NULL)
}

# Fitting the MMM ------------------------------------------

#' Custom family for a pair with one binary and one continuous outcome
#'
#' Builds a "family" object whose log density treats the stacked outcome
#' vector as alternating binary (Bernoulli, logit) and continuous (normal)
#' observations.
#'
#' @param indicator A vector of length two that marks the binary outcome of the
#'   pair with 1 and the continuous outcome with 0, e.g. c(1, 0). It is
#'   repeated length(y)/2 times, so the stacked data must alternate the two
#'   outcomes row by row.
#'
#' @return An object of class "family" with an identity link and a `log_dens`
#'   function (arguments y, eta, mu_fun, phis, eta_zi) that returns a matrix
#'   shaped like `eta` with the log density contributions.
binary.normal <- function (indicator = stop("'indicator must be specified'")) {
  # Evaluate the argument now, so a missing indicator fails on construction;
  # log_dens finds .indicator through its enclosing environment.
  .indicator <- indicator
  stats <- make.link("identity")
  log_dens <- function (y, eta, mu_fun, phis, eta_zi) {
    ind <- rep(.indicator, times = length(y)/2)
    eta <- as.matrix(eta)
    out <- eta
    # linear mixed model; phis holds the log of the residual sd
    sigma <- exp(phis)
    out[ind == 0, ] <- dnorm(y[ind == 0], eta[ind == 0, ], sigma, log = TRUE)
    # mixed logistic model; plogis() is the inverse logit and, unlike
    # exp(eta) / (1 + exp(eta)), does not overflow to NaN for large eta
    probs <- plogis(eta[ind == 1, ])
    out[ind == 1,] <- dbinom(y[ind == 1], 1, probs, log = TRUE)
    out
  }
  structure(list(family = "user binary normal", link = stats$name, linkfun = stats$linkfun,
                 linkinv = stats$linkinv, log_dens = log_dens),
            class = "family")
}

#' Custom family for a pair with two continuous outcomes
#'
#' Builds a "family" object with an identity link whose log density is the
#' normal log density with mean `eta` and standard deviation `exp(phis)`.
#'
#' @return An object of class "family" with elements family, link, linkfun,
#'   linkinv and log_dens (arguments y, eta, mu_fun, phis, eta_zi).
normal <- function () {
  # Only referenced by the disabled score functions below
  env <- new.env(parent = .GlobalEnv)
  identity_link <- make.link("identity")
  log_dens <- function (y, eta, mu_fun, phis, eta_zi) {
    # linear mixed models; phis holds the log of the residual sd(s), which are
    # recycled along the observations
    dnorm(y, as.matrix(eta), exp(phis), log = TRUE)
  }
  # score_eta_fun <- function (y, mu, phis, eta_zi) {
  #   # the derivative of the log density w.r.t. mu
  #   sigma2 <- exp(phis)^2
  #   y_mu <- y - mu
  #   y_mu / sigma2
  # }
  # score_phis_fun <- function (y, mu, phis, eta_zi) {
  #   sigma <- exp(phis)
  #   y_mu <- y - mu
  #   sigma^{-2} / (1 + y_mu / sigma^2) - 1
  # }
  # environment(log_dens) <- environment(score_eta_fun) <- environment(score_phis_fun) <- env
  family_parts <- list(family = "user defined normal",
                       link = identity_link$name,
                       linkfun = identity_link$linkfun,
                       linkinv = identity_link$linkinv,
                       log_dens = log_dens)
  structure(family_parts, class = "family")
}

#' Generate fixed formula
#' 
#' Builds the fixed effects formula (as a character string) for a pairwise
#' model on stacked data: outcome specific intercepts, their interaction with
#' `X`, and an `_Y1` and an `_Y2` term for every covariate.
#' 
#' @param covars (optional) the covariate names; every element becomes one
#'   `_Y1` and one `_Y2` term.
#' @param ... not used.
#' 
#' @export
#' 
#' @return a character string with the fixed formula for the pairwise model
#'   fitting stage.
make_fixed_formula <- function(covars = NULL, ...) {
  # Append an outcome tag to every term, e.g. "age" -> "age_Y1"
  tag_terms <- function(terms, tag) {
    vapply(terms,
           function(term) paste(c(term, tag), collapse = "_"),
           character(1),
           USE.NAMES = FALSE)
  }
  intercepts <- paste0("intercept_", c("Y1", "Y2"))
  rhs <- c(intercepts, paste0(intercepts, ":X"))
  if (!is.null(covars)) {
    # all _Y1 terms first, then all _Y2 terms
    rhs <- c(rhs, tag_terms(covars, "Y1"), tag_terms(covars, "Y2"))
  }
  paste("Y ~ -1 +", paste(rhs, collapse = " + "))
}

#' Generate random effects formula
#' 
#' Builds the random effects formula (as a character string) for a pairwise
#' model on stacked data: outcome specific random intercepts and, optionally,
#' an `_Y1` and an `_Y2` random slope for every covariate.
#' 
#' @param id a character value with the name of the grouping (subject
#'   identifier) variable; it is placed after the `|`.
#' @param covars (optional) the names of the random slopes; every element
#'   becomes one `_Y1` and one `_Y2` term.
#' @param ... not used.
#' 
#' @export
#' 
#' @return a character string with the random formula for the pairwise model
#'   fitting stage.
make_random_formula <- function(id, covars = NULL, ...) {
  # Append an outcome tag to every term, e.g. "time" -> "time_Y1"
  tag_terms <- function(terms, tag) {
    vapply(terms,
           function(term) paste(c(term, tag), collapse = "_"),
           character(1),
           USE.NAMES = FALSE)
  }
  rhs <- paste0("intercept_", c("Y1", "Y2"))
  if (!is.null(covars)) {
    # all _Y1 terms first, then all _Y2 terms
    rhs <- c(rhs, tag_terms(covars, "Y1"), tag_terms(covars, "Y2"))
  }
  effects <- paste("~ -1 +", paste(rhs, collapse = " + "))
  paste(effects, "|", id)
}


#' Create random start values
#' 
#' Take a model info object created by test_input_datatypes and add random
#' start values. One set of values is drawn and shared by all pairs; only the
#' number of `phis` differs by family.
#' 
#' @param model_families a list with model info returned by test_input_datatypes
#' @param n_fixed the number of fixed parameters
#' @param n_random the number of random parameters
#' 
#' @import Matrix
#' 
#' @return the `model_families` list with an added element `start_values`: a
#' list with one entry per model family, each holding
#' \describe{
#' \item{betas}{n_fixed random start values for the fixed effects.}
#' \item{D}{an n_random x n_random start matrix for the random effects
#' covariance, obtained by passing a random matrix through nearPD().}
#' \item{phis}{(not for family "binomial") one start value for family
#' "binary.normal", two for family "normal".}
#' }
#' @export
get_random_start <- function(model_families, n_fixed, n_random) {
  # The random draws are made in this order: betas, D, phis
  betas <- rnorm(n_fixed, mean = 0, sd = 0.33)
  random_square <- matrix(runif(n_random^2), ncol = n_random, nrow = n_random)
  D <- as.matrix(nearPD(random_square)$mat)
  phis <- rnorm(2, mean = 0, sd = 0.33)
  # Start values for one family; NULL for a family that is not recognised
  start_for <- function(family) {
    if (family == "binomial") {
      return(list("betas" = betas, "D" = D))
    }
    if (family == "binary.normal") {
      return(list("betas" = betas, "D" = D, "phis" = phis[1]))
    }
    if (family == "normal") {
      return(list("betas" = betas, "D" = D, "phis" = phis))
    }
    NULL
  }
  model_families$start_values <- lapply(
    1:length(model_families$families),
    function(i) start_for(model_families$families[[i]])
  )
  model_families
}


# Fit a mixed_model for all the frames in stacked data and store the starting values
# optimParallel returns error:
# Error in FGgenerator(par = par, fn = fn, gr = gr, ..., lower = lower,  : is.vector(par) is not TRUE
# Therefore use optim.
# The model fitting within a pair is sequential, since optimParallel throws an error. 
# This is a limitation when the number of pairs is less than the number of available workers.
# FIX: add a check for the formula inputs. If a '+' sign is forgotten an opaque error message is returned.

#' Fit pairwise mixed models to get start values for the multivariate mixed model
#' 
#' Takes the stacked data and the user defined formulas for the fixed and random effects. Leverages the future.apply package to parallelize along pairs.
#' Using a different number of workers than pairs may result in problems during the optimization process. 
#' 
#' @param stacked_data a list with the stacked data returned by the stack_data() function.
#' @param fixed character string with the fixed model part. This is passed to as.formula().
#' @param random character string with the random model part. This is passed to as.formula().
#' @param pairs a character matrix with the pairs returned by the make_pairs() function.
#' @param model_families a list with family names and indicators returned by the test_input_datatypes() function.
#' @param user_initial_values (optional) a list with one element per pair; element i is passed as `initial_values` to mixed_model().
#' @param nAGQ passed to mixed_model().
#' @param iter_EM,iter_qN_outer,tol1,tol2,tol3 passed to mixed_model() through its `control` list.
#' @param ... not used.
#'
#' @import future.apply
#' @import GLMMadaptive
#'
#' @return a list of length nrow(pairs). Every element is a list with the
#'   fitted model (`fit`) and its estimates (`starting_values`: betas, D and,
#'   except for family "binomial", phis), used as starting values to determine
#'   the subject level contributions to the derivatives.
get_mmm_start_values <- function (stacked_data, 
                                  fixed, 
                                  random, 
                                  pairs, 
                                  model_families,
                                  user_initial_values = NULL,
                                  nAGQ = 11, 
                                  iter_EM = 30, 
                                  iter_qN_outer = 10, 
                                  tol1=1e-3,
                                  tol2=1e-4,
                                  tol3=1e-8,
                                  ...) {
  # Check the families before any model is fitted, so a wrong or missing family
  # fails immediately instead of inside a worker after other fits have run.
  known_families <- c("binary.normal", "normal", "binomial")
  pair_families <- as.character(unlist(model_families$families))[seq_len(nrow(pairs))]
  if (anyNA(pair_families) || !all(pair_families %in% known_families)) {
    stop("Model family not found. Has a list with family for every model been provided?")
  }
  prog <- progressr::progressor(along = pairs)
  out <- future_lapply(seq_len(nrow(pairs)), function(i, ...) {
    # amount = 0: the message is shown without advancing the progress bar
    prog(sprintf("Fitting pairwise model number %g of %g.", i, nrow(pairs)), 
         class = "sticky",
         amount = 0)
    if (model_families$families[i] == "binary.normal") {
      fit <- mixed_model(fixed = as.formula(fixed),
                         random = as.formula(random),
                         data = stacked_data[[i]],
                         family = binary.normal(indicator = model_families$indicators[[i]]),
                         initial_values = user_initial_values[[i]], 
                         nAGQ = nAGQ,
                         n_phis = 1,
                         control = list(optimizer = "optim",
                                        optim_method = "BFGS",
                                        numeric_deriv = "fd",
                                        iter_EM = iter_EM,
                                        iter_qN_outer = iter_qN_outer,
                                        tol1=tol1,
                                        tol2=tol2,
                                        tol3=tol3)
      )
      starting_values <- list(betas = fixef(fit), phis = fit$phis, D = fit$D)
    } else if (model_families$families[i] == "normal") {
      fit <- mixed_model(fixed = as.formula(fixed),
                         random = as.formula(random),
                         data = stacked_data[[i]],
                         family = normal(),
                         initial_values = user_initial_values[[i]], 
                         nAGQ = nAGQ,
                         n_phis = 2,
                         control = list(optimizer = "optim",
                                        optim_method = "BFGS",
                                        numeric_deriv = "fd",
                                        iter_EM = iter_EM,
                                        iter_qN_outer = iter_qN_outer,
                                        tol1=tol1,
                                        tol2=tol2,
                                        tol3=tol3)
      )
      starting_values <- list(betas = fixef(fit), phis = fit$phis, D = fit$D)
    } else if (model_families$families[i] == "binomial") {
      # The binomial family has no phis, so none are stored
      fit <- mixed_model(fixed = as.formula(fixed),
                         random = as.formula(random),
                         data = stacked_data[[i]],
                         family = binomial(),
                         initial_values = user_initial_values[[i]], 
                         nAGQ = nAGQ,
                         control = list(optimizer = "optim",
                                        optim_method = "BFGS",
                                        numeric_deriv = "fd",
                                        iter_EM = iter_EM,
                                        iter_qN_outer = iter_qN_outer,
                                        tol1=tol1,
                                        tol2=tol2,
                                        tol3=tol3)
      )
      starting_values <- list(betas = fixef(fit), D = fit$D)
    } else {
      # Defensive only: the families are already checked before fitting starts
      stop("Model family not found. Has a list with family for every model been provided?")
    }
    list("fit" = fit, "starting_values" = starting_values)  
  })
  out
}

#' Determine the subject level derivatives.
#' 
#' Run mixed models without iterating and return subject level derivatives
#'
#' This function uses the start values determined by the function get_mmm_start_values()  
#' to determine the subject level contributions to the derivatives. In turn these are used 
#' to average the parameters and standard errors for the full multivariate mixed model.
#'
#' Each subject is fitted together with an exact copy of itself (the optimizer
#' needs more than one subject); the Hessian and scores are halved to undo the
#' duplication. The subject id is stored with as.numeric(), so ids that are not
#' numbers stored as text become NA.
#'
#' @param stacked_data A list of data.frames returned by the stack_data() function
#' @param id The name of the column with subject ids in stacked_data
#' @param fixed A formula (as character string) for the fixed part of the mixed model. Passed to as.formula().
#' @param random A formula (as character string) for the random part of the mixed model. Passed to as.formula().
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @param model_families A list with the model families returned by the test_input_datatypes() function.
#' @param start_values A list (length == length(stacked_data)) with start values returned by the get_mmm_start_values() function.
#' @param nAGQ passed to mixed_model().
#'
#' @import future.apply
#' @import GLMMadaptive
#' @import dplyr
#'
#' @return A list of length nrow(pairs) each with two elements:
#' \describe{
#' \item{df_hessians}{the subject level Hessians, row-bound, with the subject id in the first column}
#' \item{df_gradients}{the subject level gradients, one row per subject, with the subject id in the first column}
#' }
get_mmm_derivatives <- function (stacked_data, id, fixed, random, pairs, model_families, start_values, nAGQ = 11) {
  stopifnot("id should be a character" = is.character(id))
  # Check the families up front; an unknown family would otherwise only show
  # up inside a worker as an undefined object.
  known_families <- c("binary.normal", "normal", "binomial")
  pair_families <- as.character(unlist(model_families$families))[seq_len(nrow(pairs))]
  if (anyNA(pair_families) || !all(pair_families %in% known_families)) {
    stop("Model family not found. Has a list with family for every model been provided?")
  }
  prog <- progressr::progressor(along = pairs)
  derivatives <- lapply(seq_len(nrow(pairs)), function(i, ...) {
    # amount = 0: the message is shown without advancing the progress bar
    prog(sprintf("Getting derivatives for pairwise model %g out of %g", i, nrow(pairs)),
         class = "sticky",
         amount = 0)
    pair_data <- stacked_data[[i]]
    # Determine the unique subjects once per pair rather than once per subject
    subject_ids <- unique(pair_data[id])
    subject_out <- future_lapply(seq_len(nrow(subject_ids)), function(j, ...) {
      subject_id <- subject_ids[j, 1]
      subject_data <- pair_data[pair_data[id] == subject_id, ]
      # Data copy to work around optim requiring n > 1 subjects
      subject_data_copy <- subject_data
      subject_data_copy[,id] <- paste0(j,"a")
      subject_data <- rbind(subject_data, subject_data_copy)
      numeric_id <- as.numeric(subject_id)
      # iter_EM = 0 and iter_qN_outer = 0: no iterations are requested, the
      # derivatives are evaluated at the supplied start values
      if (model_families$families[i] == "binary.normal") {
        fit <- mixed_model(fixed = as.formula(fixed),
                           random = as.formula(random),
                           data = subject_data,
                           family = binary.normal(indicator = model_families$indicators[[i]]),
                           n_phis = 1,
                           nAGQ = nAGQ,
                           initial_values = start_values[[i]]$starting_values,
                           control = list(optimizer = "optim",
                                          optim_method = "BFGS",
                                          numeric_deriv = "fd",
                                          iter_EM = 0,
                                          iter_qN_outer = 0)
        )
        # Store Hessian (undoing the work around)
        h <- cbind("id" = numeric_id,
                   fit$Hessian/2)
        g <- c("id" = numeric_id,
               colSums(fit$score_vect_contributions$score.betas/2),
               colSums(fit$score_vect_contributions$score.D/2),
               colSums(fit$score_vect_contributions$score.phis/2))
      } else if (model_families$families[i] == "normal") {
        fit <- mixed_model(fixed = as.formula(fixed),
                           random = as.formula(random),
                           data = subject_data,
                           family = normal(),
                           n_phis = 2,
                           nAGQ = nAGQ,
                           initial_values = start_values[[i]]$starting_values,
                           control = list(optimizer = "optim",
                                          optim_method = "BFGS",
                                          numeric_deriv = "fd",
                                          iter_EM = 0,
                                          iter_qN_outer = 0)
        )
        h <- cbind("id" = numeric_id,
                   fit$Hessian/2)
        g <- c("id" = numeric_id,
               colSums(fit$score_vect_contributions$score.betas/2),
               colSums(fit$score_vect_contributions$score.D/2),
               colSums(fit$score_vect_contributions$score.phis/2))
      } else if (model_families$families[i] == "binomial") {
        fit <- mixed_model(fixed = as.formula(fixed),
                           random = as.formula(random),
                           data = subject_data,
                           family = binomial(),
                           nAGQ = nAGQ,
                           initial_values = start_values[[i]]$starting_values,
                           control = list(optimizer = "optim",
                                          optim_method = "BFGS",
                                          numeric_deriv = "fd",
                                          iter_EM = 0,
                                          iter_qN_outer = 0)
        )
        # The binomial family has no phis, hence no score for them
        h <- cbind("id" = numeric_id,
                   fit$Hessian/2)
        g <- c("id" = numeric_id,
               colSums(fit$score_vect_contributions$score.betas/2),
               colSums(fit$score_vect_contributions$score.D/2))
      }
      # Label the gradient elements like the columns of the Hessian
      names(g) <- colnames(h)
      list("hessians" = h,
           "gradients" = g)
    })
    list(
      "df_hessians" = do.call(rbind, lapply(subject_out, function(l) l$hessians)),
      "df_gradients" = do.call(rbind, lapply(subject_out, function(l) l$gradients))
    )
  })
  derivatives
}

#' Take the parameter estimates, add labels and stack them into a data.frame
#' 
#' Extract the parameter estimates from the start_values and label them back to the names of the original data.
#' 
#' @param start_values The list of start_values (i.e. parameter estimates) returned by the get_mmm_start_values() function.
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @param model_families A list with the model families returned by the test_input_datatypes() function.
#' 
#' @return A data.frame with the labelled parameter estimates (columns
#'   parameter_label and parameter_estimate); for every pair the fixed effects,
#'   the lower triangle of D and the residual parameters are stacked.
stack_parameters <- function (start_values, pairs, model_families) {
  parameters <- lapply(seq_len(nrow(pairs)), function(i) {
    names_fixed <- names(start_values[[i]]$starting_values$betas)
    # Fixed effects of the first outcome: turn e.g. "intercept_Y1:X" into
    # "b<k>_<outcome 2>_<outcome 1>" and "age_Y1" into "b<k>_age_<outcome 1>".
    # NOTE(review): "X", "Y1" and "Y2" are matched as plain substrings, so a
    # covariate or outcome name containing them would presumably be relabelled
    # too; verify against the names in use.
    .parameter_label_y1 <- names_fixed[str_detect(names_fixed, "Y1")]
    .parameter_label_y1 <- str_replace(.parameter_label_y1, "intercept", "")
    if (any(str_detect(.parameter_label_y1, ":X")) == TRUE) {
      # put the X in front: "_Y1:X" -> "X_Y1"
      .parameter_label_y1[str_detect(.parameter_label_y1, "X")] <- 
        paste(rev(unlist(str_split(.parameter_label_y1[str_detect(.parameter_label_y1, "X")], ":"))), collapse = "") 
    } else {
      .parameter_label_y1[str_detect(.parameter_label_y1, "X")] <- 
        paste(unlist(str_split(.parameter_label_y1[str_detect(.parameter_label_y1, "X")], ":")), collapse = "") 
    }
    .parameter_label_y1 <- str_replace(.parameter_label_y1, "X", pairs[i, 2])
    .parameter_label_y1 <- str_replace(.parameter_label_y1, "Y1", pairs[i, 1])
    .parameter_label_y1 <- sapply(1:length(.parameter_label_y1), function(x) paste0("b", x - 1, "_", .parameter_label_y1[x]))
    .parameter_value_y1 <- start_values[[i]]$starting_values$betas[str_detect(names_fixed, "Y1")]
    # Fixed effects of the second outcome, labelled the same way with the
    # roles of the two outcomes swapped
    .parameter_label_y2 <- names_fixed[str_detect(names_fixed, "Y2")]
    .parameter_label_y2 <- str_replace(.parameter_label_y2, "intercept", "")
    if (any(str_detect(.parameter_label_y2, pattern = ":X")) == TRUE) {
      .parameter_label_y2[str_detect(.parameter_label_y2, "X")] <-
        paste(rev(unlist(str_split(.parameter_label_y2[str_detect(.parameter_label_y2, "X")], ":"))), collapse = "")
    } else {
      .parameter_label_y2[str_detect(.parameter_label_y2, "X")] <-
        paste(unlist(str_split(.parameter_label_y2[str_detect(.parameter_label_y2, "X")], ":")), collapse = "")
    }
    .parameter_label_y2 <- str_replace(.parameter_label_y2, "X", pairs[i, 1])
    .parameter_label_y2 <- str_replace(.parameter_label_y2, "Y2", pairs[i, 2])
    .parameter_label_y2 <- sapply(1:length(.parameter_label_y2), function(x) paste0("b", x - 1, "_", .parameter_label_y2[x]))
    .parameter_value_y2 <-  start_values[[i]]$starting_values$betas[str_detect(names_fixed, "Y2")]
    fixed <- data.frame(
      parameter_label = c(.parameter_label_y1, .parameter_label_y2),
      parameter_estimate = c(.parameter_value_y1, .parameter_value_y2), 
      row.names = NULL
    )
    # Set names and labels for and get values for random effects
    labels_random <- levels(interaction(
      rownames(start_values[[i]]$starting_values$D),
      colnames(start_values[[i]]$starting_values$D),
      sep = "_x_"))
    labels_random <- t(matrix(labels_random, 
                              ncol = ncol(start_values[[i]]$starting_values$D), 
                              byrow = FALSE))
    # variances on the diagonal, covariances below it; only the lower triangle
    # (including the diagonal) is kept
    diag(labels_random) <- paste0("var_",diag(labels_random))
    labels_random[lower.tri(labels_random)] <- paste0("cov_", labels_random[lower.tri(labels_random)])
    labels_random <- labels_random[lower.tri(labels_random, diag = TRUE)]
    labels_random <- str_replace_all(labels_random, "Y1", pairs[i,1])
    labels_random <- str_replace_all(labels_random, "Y2", pairs[i,2])
    values_random <- start_values[[i]]$starting_values$D[lower.tri(start_values[[i]]$starting_values$D, diag = TRUE)]
    random <- data.frame(
      parameter_label = labels_random,
      parameter_estimate = values_random,
      row.names = NULL
    )
    # Get residual errors
    resid_labels <- c()
    if (model_families$families[i] == "binomial") {
      # two binary outcomes: no residual parameters
      resid_labels <- NULL
    } else if (model_families$families[i] == "binary.normal") {
      # the indicator marks the binary outcome with 1, so the residual
      # parameter belongs to the outcome marked with 0
      if (all(model_families$indicators[[i]] == c(0, 1 ))) {
        resid_labels <- paste0("resid_", pairs[i, 1])
      } else {
        resid_labels <- paste0("resid_", pairs[i, 2])
      }
    } else if (model_families$families[i] == "normal") {
      resid_labels <- c(paste0("resid_", pairs[i, 1]),
                        paste0("resid_", pairs[i, 2]))
    }
    resid <- data.frame(
      parameter_label = resid_labels,
      parameter_estimate = start_values[[i]]$fit$phis,
      row.names = NULL
    )
    out <- bind_rows(fixed, random, resid)
    rownames(out) <- NULL
    out
  })
  do.call(rbind, parameters)
}

#' Stack gradients into a single data.frame
#'
#' Take the gradients returned by the get_mmm_derivatives() function and stack
#' these into a single data.frame for further processing.
#'
#' @param derivatives list with derivatives returned by the get_mmm_derivatives() function.
#'   Element `i` must carry a `df_gradients` table for the i-th pair in which the first
#'   column is the subject identifier and the remaining columns are the gradient components.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#'
#' @return A data.frame ordered by `id`. Rows that share an id keep their stacking
#'   order (pair by pair, gradient component by gradient component).
#'   NOTE(review): a single `gradient` column is only produced when `df_gradients[j, -1]`
#'   yields a plain vector (e.g. `df_gradients` is a matrix) - confirm against
#'   get_mmm_derivatives().
stack_gradients <- function (derivatives, data, id, pairs) {
  # The number of subjects is the same for every pair, so count it once.
  n_subjects <- nrow(unique(data[id]))
  gradients <- lapply(seq_len(nrow(pairs)), function(i) {
    pair_gradients <- derivatives[[i]]$df_gradients
    # One small data.frame per subject row: the id is recycled over the components.
    g <- lapply(seq_len(n_subjects), function(j) {
      data.frame(
        "id" = pair_gradients[j, 1],
        "gradient" = pair_gradients[j, -1],
        row.names = NULL
      )
    })
    do.call(rbind, g)
  })
  gradients <- do.call(rbind, gradients)
  gradients[order(gradients$id), ]
}

#' Create the helper 'A' matrix used to average the parameter estimates.
#'
#' Take parameter labels from the MMM object and return a p*q matrix where p is the
#' number of unique parameter labels and q is the total number of parameter labels
#' (including duplicates).
#'
#' @param parameter_labels A vector with the parameter labels returned by the
#'   stack_parameters() function (duplicates allowed).
#'
#' @return A p*q numeric matrix with a one where the row label equals the column
#'   label and zeros elsewhere. Rows are named by the unique labels (in order of first
#'   appearance) and columns by `parameter_labels`. Empty input gives a 0 x 0 matrix.
create_A_matrix <- function (parameter_labels) {
  labels <- as.character(parameter_labels)
  unique_labels <- unique(labels)
  A <- matrix(0,
              nrow = length(unique_labels),
              ncol = length(labels),
              dimnames = list(unique_labels, labels))
  if (length(labels) > 0) {
    # Column j belongs to exactly one row: the position of its label among the
    # unique labels. Index those (row, column) pairs directly instead of
    # comparing every row name with every column name.
    A[cbind(match(labels, unique_labels), seq_along(labels))] <- 1
  }
  A
}

#' Get cross products of the gradients
#'
#' Take the cross product of each subject specific gradient vector, sum these over the
#' subjects and divide by the number of subjects.
#'
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#'   Not used in the computation; kept for interface compatibility.
#' @param parameter_labels The parameter labels returned by the stack_parameters() function.
#' @param gradients A data frame with columns `id` and `gradient` returned by the
#'   stack_gradients() function.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#'
#' @return A square matrix, with `parameter_labels` as row and column names, holding
#'   the average over subjects of the gradient cross products.
get_crossproducts <- function (pairs, parameter_labels,  gradients, data, id) {
  n_par <- length(parameter_labels)
  cpgrad <- matrix(0, nrow = n_par, ncol = n_par,
                   dimnames = list(
                     parameter_labels,
                     parameter_labels
                   )
  )
  # The set of subjects is fixed, so derive it once instead of on every iteration.
  subject_ids <- unique(data[[id]])
  n_subjects <- length(subject_ids)
  if (n_subjects == 0) {
    stop("`data` does not contain any subjects.", call. = FALSE)
  }
  # Sum the subject level contributions: g_i %*% t(g_i) for every subject i.
  for (k in seq_along(subject_ids)) {
    # which() drops comparisons that evaluate to NA, as dplyr::filter() would.
    rows <- which(gradients[["id"]] == subject_ids[k])
    cpgrad <- cpgrad + tcrossprod(gradients[["gradient"]][rows])
  }
  # Average over the subjects
  cpgrad * (1 / n_subjects)
}

#' Average the Hessians
#'
#' Sum the subject level contributions to the Hessian of every pairwise model, place
#' the pairwise Hessians on the diagonal of one block matrix and divide by the number
#' of subjects.
#'
#' @param derivatives list returned by the get_mmm_derivatives() function. Element `j`
#'   must carry a `df_hessians` table for the j-th pair with a subject identifier
#'   column named `id` as its first column, followed by the Hessian columns.
#' @param pairs A character matrix with pairs returned by the make_pairs function.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#' @param A the A helper matrix returned by the create_A_matrix function; its column
#'   names label the rows and columns of the result.
#'
#' @import Matrix
#'
#' @return A dense block-diagonal matrix with the averaged MMM Hessian.
average_hessians <- function(derivatives, pairs, data, id, A) {
  # The set of subjects is identical for every pair; derive it once.
  subject_ids <- unique(data[[id]])
  n_subjects <- length(subject_ids)
  n_pairs <- nrow(pairs)
  if (n_pairs < 1) {
    stop("`pairs` must contain at least one pair.", call. = FALSE)
  }
  pair_hessians <- lapply(seq_len(n_pairs), function(j) {
    hess <- as.data.frame(derivatives[[j]]$df_hessians)
    if (!"id" %in% names(hess)) {
      stop("`df_hessians` must contain a subject identifier column named `id`.",
           call. = FALSE)
    }
    n_par <- ncol(hess) - 1
    hess_pair <- matrix(0, nrow = n_par, ncol = n_par)
    # Add up the subject level contributions to the pairwise Hessian
    for (k in seq_along(subject_ids)) {
      rows <- which(hess[["id"]] == subject_ids[k])
      # Drop the first (identifier) column; the remainder is the subject's Hessian
      hess_pair <- hess_pair + unname(as.matrix(hess[rows, -1, drop = FALSE]))
    }
    hess_pair
  })
  # Build the block-diagonal matrix in one go instead of regrowing it per pair
  if (n_pairs == 1) {
    Hessian <- pair_hessians[[1]]
  } else {
    Hessian <- as.matrix(bdiag(pair_hessians))
  }
  # Average the Hessian
  Hessian <- Hessian * (1 / n_subjects)
  dimnames(Hessian) <- list(
    colnames(A),
    colnames(A)
  )
  Hessian
}

#' Calculate the covariance matrix for the MMM and return it and the robust standard error
#'
#' Use a sandwich estimator to return the standard errors for the full multivariate mixed model.
#'
#' @param hessian The averaged Hessian matrix for the MMM returned by the average_hessians() function.
#' @param cpgradient The cross-product and average of the subject level gradients returned by the get_crossproducts() function.
#' @param parameters The parameter estimates returned by the stack_parameters() function.
#' @param A the A helper matrix returned by the create_A_matrix() function.
#' @param data The data.frame with original data in long format.
#' @param id A character value with the subject identifier in the original data.
#'
#' @import Matrix
#'
#' @return A data.frame with parameter names, averaged parameter estimates and robust
#'   standard errors for the multivariate mixed model.

# Covariance of sqrt(n) * (theta - theta_0), Fieuws 2006 eq. (4) p. 453
get_covariance <- function(hessian, cpgradient, parameters, A, data, id) {
  n_subjects <- nrow(unique(data[paste(id)]))
  # The generalized inverse is needed on both sides of the sandwich: compute it once.
  hessian_inv <- ginv(hessian)
  JKJ <- hessian_inv %*% cpgradient %*% hessian_inv
  dimnames(JKJ) <- list(
    colnames(A),
    colnames(A)
  )
  # Covariance of (theta - theta_0)
  cov_theta <- JKJ * (1 / n_subjects)
  # Scale every row of A to sum to one, so that A %*% x averages duplicated parameters
  A_avg <- A * (1 / rowSums(A))
  avg_par <- A_avg %*% parameters$parameter_estimate
  covrobust <- A_avg %*% cov_theta %*% t(A_avg)
  serobust <- sqrt(diag(covrobust))
  data.frame(
    "parameter_name" = rownames(A),
    "parameter_estimate" = as.vector(avg_par),
    "parameter_std_err" = serobust,
    row.names = NULL
  )
}

#' Create the full mixed model output
#'
#' Take the parameter estimates and return them together with the variance-covariance
#' and correlation matrices of the random effects.
#'
#' Parameters are grouped by their name: fixed effects start with "b", random effect
#' variances with "var", random effect covariances with "cov" and residual parameters
#' contain "resid". Residual parameters are exponentiated before they are returned;
#' their standard errors are transformed with the delta method
#' (se(exp(x)) = exp(x) * se(x)).
#'
#' @param estimates A data.frame with columns `parameter_name`, `parameter_estimate`
#'   and `parameter_std_err` returned by the get_covariance() function.
#'
#' @import Matrix
#'
#' @return A list of 3 elements including:
#' \describe{
#'   \item{estimates}{A data.frame with the parameter estimates and standard errors}
#'   \item{vcov}{The Variance-Covariance matrix for the random effects (base matrix)}
#'   \item{corr}{A correlation matrix for the random effects}
#'}
get_model_output <- function(estimates) {
  # Rows of `estimates` whose name matches `pattern`, renumbered from one
  select_rows <- function(pattern) {
    rows <- estimates[grepl(pattern, estimates$parameter_name), , drop = FALSE]
    rownames(rows) <- NULL
    rows
  }
  fixef <- select_rows("^b")
  varest <- select_rows("^var")
  covest <- select_rows("^cov")
  resid <- select_rows("resid")
  # Back-transform the residual parameters. exp() of a standard error is not a
  # standard error, so the delta method is used for the uncertainty.
  resid_sd <- exp(resid$parameter_estimate)
  resid$parameter_std_err <- resid_sd * resid$parameter_std_err
  resid$parameter_estimate <- resid_sd
  # Covariance matrix of the random effects
  re_names <- gsub("^.*_x_", "", unique(varest$parameter_name))
  est_D <- matrix(0, nrow = nrow(varest), ncol = nrow(varest),
                  dimnames = list(re_names, re_names))
  diag(est_D) <- varest$parameter_estimate
  # NOTE(review): the covariances are placed by position (column-major lower
  # triangle), not by label - confirm that stack_parameters() emits them in this order.
  est_D[lower.tri(est_D)] <- covest$parameter_estimate
  est_D[upper.tri(est_D)] <- t(est_D)[upper.tri(est_D)]
  # check if est_D is positive definite
  if (!all(eigen(est_D)$values > 0)) {
    # Compute the nearest positive definite matrix. nearPD() returns a Matrix
    # object; convert it back so the type of `vcov` does not depend on this branch.
    re_dimnames <- dimnames(est_D)
    est_D <- as.matrix(nearPD(est_D)$mat)
    dimnames(est_D) <- re_dimnames
    # Replace the variance and covariance estimates with the nearest PD values
    varest$parameter_estimate <- unname(diag(est_D))
    covest$parameter_estimate <- est_D[lower.tri(est_D)]
    # Inform the user.
    message("The Variance-Covariance matrix was not a positive definine. 
            Projecting to the closest positive definine. 
            Note that your results may be unexpected.")
  }
  est_Dcorr <- cov2cor(est_D)
  # outcome tables
  list("estimates" = rbind(fixef, varest, covest, resid),
       "vcov" = est_D,
       "corr" = est_Dcorr)
}

#' Fit a multivariate mixed model using the pairwise fitting approach
#'
#' Takes the stacked data and the user defined formulas for the fixed and random effects. Leverages the future.apply package to parallelize along pairs.
#' Using a different number of workers than pairs may result in problems during the optimization process. The output of the pairwise model fitting is combined
#' and any parameters that have been estimated multiple times are averaged. The procedure takes the individual level contributions to the Maximum Likelihood in
#' order to estimate the standard errors for the parameter estimates with a robust sandwich estimator.
#'
#' The future plan is always set back to sequential when the function exits, also
#' when one of the fitting steps fails.
#'
#' @param fixed character string with the fixed model part. This is passed to as.formula().
#' @param random character string with the random model part. This is passed to as.formula().
#' @param id The name of the column with subject ids. It must be a column of `data`.
#' @param data The data.frame with original data in long format.
#' @param stacked_data a list with the stacked data returned by the stack_data() function.
#' @param pairs a character matrix with the pairs returned by the make_pairs() function.
#' @param model_families a list with family names and indicators returned by the test_input_datatypes() function.
#' @param iter_EM (default = 100) EM iterations passed on to get_mmm_start_values().
#' @param iter_qN_outer (default = 10) outer Quasi Newton iterations passed on to get_mmm_start_values().
#' @param parallel_plan (default = "sequential") The parallel plan to be passed to future.apply(). Currently available: "sequential" and "multisession."
#' @param ncores number of workers for the multisession parallel plan. When NULL the
#'   number of detected cores minus one is used (capped at the number of pairs for the
#'   model fitting step). A user supplied value is used for both parallel steps.
#' @param ... additional arguments passed on to get_mmm_start_values().
#'
#' @import tidyverse
#' @import GLMMadaptive
#' @import Matrix
#' @import future.apply
#'
#' @return A list of 3 elements including:
#' \describe{
#'   \item{estimates}{A data.frame with the parameter estimates and standard errors}
#'   \item{vcov}{The Variance-Covariance matrix for the random effects}
#'   \item{corr}{A correlation matrix for the random effects}
#'}
mmm_model <- function (fixed, random, id, data, stacked_data, pairs, model_families, iter_EM = 100,
                       iter_qN_outer = 10, parallel_plan = "sequential", ncores = NULL,
                       ...) {
  if (!id %in% colnames(data)) {
    stop("Could not find `id` in colnames(`data`). Have you specified the correct subject id?")
  }
  # Reset the future plan on exit, also when a step below throws an error.
  on.exit(plan(sequential), add = TRUE)
  if (!parallel_plan %in% c("sequential", "multisession")) {
    warning("Unsupported `parallel_plan`: \"", parallel_plan,
            "\". The currently active future plan is used.", call. = FALSE)
  }
  # Default worker count: all detected cores but one, and never fewer than one.
  default_workers <- function() {
    detected <- parallel::detectCores()
    if (is.na(detected)) 1L else max(1L, detected - 1L)
  }
  if (parallel_plan == "sequential") {
    plan(sequential)
    message("Fitting pairwise model using ", parallel_plan, ".")
  } else if (parallel_plan == "multisession") {
    # There is no use for more workers than pairs when fitting the pairwise models.
    fit_workers <- if (is.null(ncores)) min(default_workers(), nrow(pairs)) else ncores
    plan(multisession, workers = fit_workers)
    message("Fitting pairwise model using ", parallel_plan, " on ", fit_workers, " cores.")
  }
  start_values <- get_mmm_start_values(stacked_data = stacked_data,
                                       fixed = as.formula(fixed),
                                       random = as.formula(random),
                                       pairs = pairs,
                                       model_families = model_families,
                                       iter_EM = iter_EM,
                                       iter_qN_outer = iter_qN_outer,
                                       ...
                                       )
  if (parallel_plan == "sequential") {
    plan(sequential)
    message("Retrieving derivatives using ", parallel_plan, ".")
  } else if (parallel_plan == "multisession") {
    # Respect a user supplied `ncores`; only fall back to the default when none was given.
    derivative_workers <- if (is.null(ncores)) default_workers() else ncores
    plan(multisession, workers = derivative_workers)
    message("Retrieving derivatives using ", parallel_plan, " on ", derivative_workers, " cores.")
  }
  derivatives <- get_mmm_derivatives(stacked_data = stacked_data,
                                     id = id,
                                     fixed = as.formula(fixed),
                                     random = as.formula(random),
                                     pairs = pairs,
                                     model_families = model_families,
                                     start_values = start_values)
  message("Compiling model output.")
  parameters <- stack_parameters(start_values = start_values,
                                 pairs = pairs,
                                 model_families = model_families)
  gradients <- stack_gradients(derivatives = derivatives,
                               data = data,
                               id = id,
                               pairs = pairs)
  A <- create_A_matrix(parameter_labels = parameters$parameter_label)
  cpgradient <- get_crossproducts( pairs = pairs,
                                   parameter_labels = parameters$parameter_label,
                                   gradients = gradients,
                                   data = data,
                                   id = id)
  hessian <- average_hessians(derivatives = derivatives,
                              pairs = pairs,
                              data = data,
                              id = id,
                              A = A)
  mmm_estimates <- get_covariance(hessian = hessian,
                                  cpgradient = cpgradient,
                                  parameters = parameters,
                                  A = A,
                                  data = data,
                                  id = id)
  get_model_output(estimates = mmm_estimates)
}


# Get predictions from the mmm ---------------------------------------------------

#' Return a data.frame with the outcome label and type
#'
#' An outcome is "binary" when its observed (non-missing) values are exactly 0 and 1,
#' and "continuous" when it is numeric (double or integer) otherwise.
#'
#' @param data original data
#' @param outcomes a character vector with the names of the longitudinal outcomes
#'
#' @return A data.frame with one row per outcome and the columns `outcome` and
#'   `outcome_type`. An error is raised for outcomes that are neither binary nor numeric.
get_outcome_type <- function(data, outcomes) {
  out <- lapply(outcomes, function(i) {
    values <- unlist(data[paste(i)])
    # sort() drops missing values, so only the observed levels are compared
    observed <- sort(unique(values))
    # Require exactly two levels before comparing: this avoids silent recycling
    # and keeps a column without any observed value from being labelled binary.
    if (length(observed) == 2L && all(observed == c(0, 1))) {
      outcome_type <- "binary"
    } else if (is.numeric(values)) {
      # is.numeric() accepts integer as well as double storage
      outcome_type <- "continuous"
    } else {
      stop(paste("The outcome", i, "does not appear to be binary or continuous"))
    }
    data.frame(outcome = i,
               outcome_type = outcome_type)
  })
  do.call(rbind, out)
}

# FIX: grepl for variable names results in poor abstraction and generalization. Parts of variable names may be similar or hold symbols used in regex. 
# FIX: models should be different for different longitudinal outcomes

#' Get Likelihood No Fail
#'
#' Get the parameter estimates from a MMM and combine with data to return contributions to the likelihood at observed time points for non failures
#'
#' @param data data.frame with original data in long format
#' @param outcomes a data.frame with the outcome labels in the first column and a column
#'   `outcome_type` ("continuous" or "binary"), as returned by get_outcome_type().
#' @param fixed_formula_nofail a character string with lme4 style formula for the fixed effects
#' @param random_formula_nofail a character string with lme4 style formula for the random effects
#' @param time a character string with the label for the follow-up time variable. The
#'   variable must be a column of `data` and a term in the fixed effects formula.
#' @param parameters_nofail a data.frame with the parameter labels and estimates returned by mmm_model()
#' @param random_effects_nofail The random effects sampled from a multivariate normal distribution (see MASS::mvrnorm).
#'   Rows are draws; column j is used for the j-th outcome.
#' @param id a character string with subject identifier
#' @param horizon the prediction horizon. Only observations with time <= horizon are used.
#' @param ... currently not used.
#'
#' @import future.apply
#' @import dplyr
#' @import Matrix
#'
#' @return A list with one data.frame per subject holding the columns `id`, `time`
#'   and `likelihoods` (the likelihood at each observed time point, averaged over the
#'   random effect draws).
get_likelihood_nofail <- function (data,
                                   outcomes,
                                   fixed_formula_nofail = NULL,
                                   random_formula_nofail = NULL,
                                   time,
                                   parameters_nofail,
                                   random_effects_nofail,
                                   id,
                                   horizon,
                                   ...) {
  # cut the fixed and random formulas into parts
  ff <- str_remove(fixed_formula_nofail, "~ ")
  ff <- str_split(ff, " \\+ ") %>% unlist()
  rf <- str_remove(random_formula_nofail, "~ ")
  rf <- str_remove(rf, paste(" \\|", id))
  rf <- str_split(rf, " \\+ ") %>% unlist()
  # Select the parameter estimates for fixed effects
  params <- parameters_nofail %>%
    dplyr::filter(grepl("^b", parameter_name)) %>%
    dplyr::select(parameter_name, parameter_estimate)
  # Non-failures only have a single pattern as they all survive past the max horizon.
  .s <- horizon
  subject_ids <- unlist(unique(data[id]))
  likelihood <- future_lapply(seq_along(subject_ids), function(i, ...) {
    .id <- subject_ids[i]
    # Rows of subject i observed up to the horizon. The columns are addressed by
    # their names: `id` and `time` hold column labels, not the columns themselves.
    keep <- which(data[[id]] == .id & data[[time]] <= .s)
    subject_data <- data[keep, , drop = FALSE]
    # Calculate the log-likelihoods for every outcome j
    ll <- lapply(seq_len(nrow(outcomes)), function (j) {
      # Get the fixed effect parameters
      betas <- params %>%
        dplyr::filter(grepl(paste0("_", outcomes[j, 1], "$"), parameter_name)) %>%
        dplyr::select(parameter_estimate, parameter_name)
      # adjust the order of the parameter estimates to match the fixed_formula argument
      betas <- betas %>%
        mutate(parameter_name = str_replace(parameter_name, "b0_", "intercept"))
      betas <- betas %>%
        mutate(parameter_name = str_remove(parameter_name, paste0("_", outcomes[j, 1], "$")))
      betas <- betas %>%
        mutate(parameter_name = str_remove(parameter_name, paste0("^(..|...)_")))
      betas <- betas[match(c("intercept", ff[-j]), betas$parameter_name), ] %>%
        dplyr::select(parameter_estimate)
      betas <- as.matrix(betas)
      if(!assertthat::noNA(betas)) {
        stop("Not all parameters have matching estimates. Have you correctly specified the the fixed_formula_nofail argument?")
      }
      # Get the residual
      resid <- parameters_nofail %>%
        dplyr::filter(grepl(paste0("resid", paste0("_", outcomes[j, 1],"$")), parameter_name)) %>%
        dplyr::select(parameter_estimate)
      resid <- as.numeric(resid)
      # Get the outcome vector for subject i
      y <- as.matrix(subject_data[outcomes[j, 1]])
      # Using bX assumes similar models for all longitudinal outcomes.
      .ff <- str_remove(ff, paste0("^", outcomes[j, 1], "$"))
      .ff <- .ff[.ff != ""]
      # Get the X matrix for subject i including all data prior to the horizon.
      X <- model.matrix(as.formula(paste("~ ", paste(.ff, collapse = " + "))), subject_data)
      # Get the linear predictor for the fixed effect
      xb <- X %*% betas
      # Get the Z matrix for subject i
      .rf <- str_remove(rf, paste0("^",outcomes[j, 1], "$"))
      .rf <- .rf[.rf != ""]
      Z <- model.matrix(as.formula(paste("~", paste(.rf, collapse = " + "))), subject_data)
      # log_lik is the log likelihood of the j-th outcome for the i-th subject at observed time points
      # for the k-th draw from the random effects.
      if (outcomes$outcome_type[j] == "continuous") {
        log_lik <- lapply(seq_len(nrow(random_effects_nofail)), function(k) {
          # get the linear predictor = X*beta + Z*bu
          zb <- rowSums(Z * random_effects_nofail[k, j])
          lp <- xb + zb
          # calculate the log-likelihood (normal density with sd = resid)
          -log(resid) -0.5 * log(2 * pi) - ( (y - lp)^2 / (2 * resid^2) )
        })
      } else if (outcomes$outcome_type[j] == "binary") {
        log_lik <- lapply(seq_len(nrow(random_effects_nofail)), function(k) {
          # lp = X*beta + Z*bu
          zb <- rowSums(Z * random_effects_nofail[k, j])
          lp <- xb + zb
          # plogis() is the inverse logit; unlike exp(lp) / (1 + exp(lp)) it does
          # not overflow to NaN for large linear predictors.
          p <- stats::plogis(lp)
          lik <- ifelse(y == 0, 1 - p, p)
          lik <- ifelse(lik > 1e-8, lik, 1e-10)
          log(lik)
        })
      }
      # log likelihood for j-th outcome at each observed time point (columns) and all RE draws (rows) for i-th subject
      list("ll" = t(do.call(cbind, log_lik)),
           ".id" = .id,
           "X" = X)
    })
    # sum over the outcomes at each observed time point (columns) and all RE draws (rows);
    # Reduce() also handles the case of a single outcome.
    sum_ll <- Reduce(`+`, lapply(ll, function(x) x$ll))
    # take the exp to obtain likelihood per draw (rows) and per observed time t (columns) for the i-th subject
    l <- exp(sum_ll)
    # take the mean of the likelihoods of each subject over the draws per observed time t
    l <- matrix(colMeans(l)) # this transposes the times to the rows
    data.frame(
      id = ll[[1]]$.id,
      time = ll[[1]]$X[, time],
      likelihoods = l,
      row.names = NULL
    )
  })
  likelihood
}

# FIX: the likelihoods do not change depending on the failure time...

#' Get Likelihood Fail
#'
#' Get the parameter estimates from a MMM and combine with data to return contributions to the likelihood at observed time points for failures,
#' for every candidate failure pattern between the landmark and the horizon.
#'
#' @param data data.frame with original data in long format
#' @param outcomes a data.frame with the outcome labels in the first column and a column
#'   `outcome_type` ("continuous" or "binary"), as returned by get_outcome_type().
#' @param fixed_formula_fail a character string with lme4 style formula for the fixed effects
#' @param random_formula_fail a character string with lme4 style formula for the random effects
#' @param time a character string with the label for the follow-up time variable. The
#'   variable must be a column of `data` and a term in the fixed effects formula.
#' @param failure_time a character string with the label for the failure time variable
#' @param parameters_fail a data.frame with the parameter labels and estimates returned by mmm_model()
#' @param random_effects_fail The random effects sampled from a multivariate normal distribution (see MASS::mvrnorm).
#'   Rows are draws; column j is used for the j-th outcome.
#' @param id a character string with subject identifier
#' @param landmark the landmark time until which data is available.
#' @param horizon the prediction horizon
#' @param interval the intervals at which predictions are to be made
#' @param ... currently not used.
#'
#' @import future.apply
#' @import dplyr
#' @import Matrix
#'
#' @return A list with one element per pattern in seq(landmark, horizon, by = interval).
#'   Each element is a list with one data.frame per subject holding the columns `id`,
#'   `time`, `likelihoods` and `pattern`.
get_likelihood_fail <- function (data,
                                 outcomes,
                                 fixed_formula_fail = NULL,
                                 random_formula_fail = NULL,
                                 time,
                                 failure_time,
                                 parameters_fail,
                                 random_effects_fail,
                                 id,
                                 landmark,
                                 horizon,
                                 interval,
                                 ...) {
  # Candidate failure patterns between the landmark and the horizon
  lm <- seq(from = landmark, to = horizon, by = interval)
  # cut the fixed and random formulas into parts
  ff <- str_remove(fixed_formula_fail, "~ ")
  ff <- str_split(ff, " \\+ ") %>% unlist()
  rf <- str_remove(random_formula_fail, "~ ")
  rf <- str_remove(rf, paste(" \\|", id))
  rf <- str_split(rf, " \\+ ") %>% unlist()
  # Select the parameter estimates for fixed effects
  params <- parameters_fail %>%
    dplyr::filter(grepl("^b", parameter_name)) %>%
    dplyr::select(parameter_name, parameter_estimate)
  subject_ids <- unlist(unique(data[id]))
  likelihood_pattern <- lapply(lm, function(s, ...) {
    # use midpoint rule.
    # NOTE(review): the offset is fixed at 0.5 irrespective of `interval`;
    # confirm whether interval / 2 was intended.
    .s <- (s + 0.5)
    likelihood <- future_lapply(seq_along(subject_ids), function(i, ...) {
      .id <- subject_ids[i]
      # Rows of subject i observed up to the pattern time. The columns are addressed
      # by their names: `id` and `time` hold column labels, not the columns themselves.
      keep <- which(data[[id]] == .id & data[[time]] <= .s)
      subject_data <- data[keep, , drop = FALSE]
      # Calculate the log-likelihoods for every outcome j
      ll <- lapply(seq_len(nrow(outcomes)), function (j) {
        # Get the fixed effect parameters
        betas <- params %>%
          dplyr::filter(grepl(paste0("_", outcomes[j, 1], "$"), parameter_name)) %>%
          dplyr::select(parameter_estimate, parameter_name)
        # adjust the order of the parameter estimates to match the fixed_formula argument
        betas <- betas %>%
          mutate(parameter_name = str_replace(parameter_name, "b0_", "intercept"))
        betas <- betas %>%
          mutate(parameter_name = str_remove(parameter_name, paste0("_", outcomes[j, 1], "$")))
        betas <- betas %>%
          mutate(parameter_name = str_remove(parameter_name, paste0("^(..|...)_")))
        betas <- betas[match(c("intercept", ff[-j]), betas$parameter_name), ] %>%
          dplyr::select(parameter_estimate)
        betas <- as.matrix(betas)
        if(!assertthat::noNA(betas)) {
          stop("Not all parameters have matching estimates. Have you correctly specified the the fixed_formula_fail argument?")
        }
        # Get the residual
        resid <- parameters_fail %>%
          dplyr::filter(grepl(paste0("resid", paste0("_", outcomes[j, 1],"$")), parameter_name)) %>%
          dplyr::select(parameter_estimate)
        resid <- as.numeric(resid)
        # Get the outcome vector for subject i
        y <- as.matrix(subject_data[outcomes[j, 1]])
        # bX assumes similar models for all longitudinal outcomes.
        .ff <- str_remove(ff, paste0("^", outcomes[j, 1], "$"))
        .ff <- .ff[.ff != ""]
        # Get the X matrix for subject i, using only data up to the pattern time
        X <- model.matrix(as.formula(paste("~ ", paste(.ff, collapse = " + "))), subject_data)
        # replace actual failure time with failure time for the prediction
        if (failure_time %in% colnames(X)) {
          X[ , failure_time] <- .s
        }
        # Get the linear predictor for the fixed effect
        xb <- X %*% betas
        # Get the Z matrix for subject i
        .rf <- str_remove(rf, paste0("^",outcomes[j, 1], "$"))
        .rf <- .rf[.rf != ""]
        Z <- model.matrix(as.formula(paste("~", paste(.rf, collapse = " + "))), subject_data)
        # replace actual failure time with failure time for the prediction
        if (failure_time %in% colnames(Z)) {
          Z[ , failure_time] <- .s
        }
        # log_lik is the log likelihood of the j-th outcome for the i-th subject at observed time points
        # for the k-th draw from the random effects.
        if (outcomes$outcome_type[j] == "continuous") {
          log_lik <- lapply(seq_len(nrow(random_effects_fail)), function(k) {
            # get the linear predictor = X*beta + Z*bu
            zb <- rowSums(Z * random_effects_fail[k, j])
            lp <- xb + zb
            # calculate the log-likelihood (normal density with sd = resid)
            -log(resid) -0.5 * log(2 * pi) - ( (y - lp)^2 / (2 * resid^2) )
          })
        } else if (outcomes$outcome_type[j] == "binary") {
          log_lik <- lapply(seq_len(nrow(random_effects_fail)), function(k) {
            # lp = X*beta + Z*bu
            zb <- rowSums(Z * random_effects_fail[k, j])
            lp <- xb + zb
            # plogis() is the inverse logit; unlike exp(lp) / (1 + exp(lp)) it does
            # not overflow to NaN for large linear predictors.
            p <- stats::plogis(lp)
            lik <- ifelse(y == 0, 1 - p, p)
            lik <- ifelse(lik > 1e-8, lik, 1e-10)
            log(lik)
          })
        }
        # log likelihood for j-th outcome at each observed time point (columns) and all RE draws (rows) for i-th subject
        list("ll" = t(do.call(cbind, log_lik)),
             ".id" = .id,
             "X" = X)
      })
      # sum over the outcomes at each observed time point (columns) and all RE draws (rows);
      # Reduce() also handles the case of a single outcome.
      sum_ll <- Reduce(`+`, lapply(ll, function(x) x$ll))
      # take the exp to obtain likelihood per draw (rows) and per observed time t (columns) for the i-th subject
      l <- exp(sum_ll)
      # take the mean of the likelihoods over the draws per observed time t
      l <- matrix(colMeans(l))
      data.frame(
        id = ll[[1]]$.id,
        time = ll[[1]]$X[, time],
        likelihoods = l,
        pattern = s,
        row.names = NULL
      )
    })
    likelihood
  })
  likelihood_pattern
}

#' Get likelihood profiles for non failures
#'
#' Generate a step function of the cumulative likelihood at interval times: for every
#' cut point between the landmark and the horizon, the cumulative likelihood at the
#' last observation at or before that cut point is carried forward.
#'
#' @param likelihood_nofail a list of likelihoods for each subject at the observed time points returned by get_likelihood_nofail().
#' @param landmark the prediction landmark time.
#' @param horizon the prediction horizon.
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals if horizon is in years)
#'
#' @import future.apply
#' @import dplyr
#'
#' @return A list with one data.frame per subject holding the columns `id`,
#'   `cutpoint1` and `likelihoods`. Cut points before the first observation of a
#'   subject are left out.
get_likelihood_profiles_nofail <- function(likelihood_nofail,
                                           landmark,
                                           horizon,
                                           interval) {
  cutpoints <- seq(from = landmark, to = horizon, by = interval)
  # Running product of the likelihoods, computed on the log scale
  cumulative <- future_lapply(likelihood_nofail, function(subject) {
    dplyr::mutate(subject, likelihoods = exp(cumsum(log(likelihoods))))
  })
  # For every cut point keep the most recent cumulative likelihood of the subject
  profiles <- future_lapply(cumulative, function(subject) {
    steps <- lapply(cutpoints, function(cutpoint) {
      observed <- dplyr::filter(subject, time <= cutpoint)
      if (nrow(observed) == 0) {
        # Nothing observed yet at this cut point: contributes no row
        return(NULL)
      }
      latest <- dplyr::filter(observed, time == max(time))
      dplyr::mutate(latest, cutpoint1 = cutpoint)
    })
    stacked <- do.call(rbind, steps)
    stacked <- dplyr::arrange(stacked, id)
    dplyr::select(stacked, id, cutpoint1, likelihoods)
  })
  profiles
}


#' Get likelihood profiles for failure
#'
#' Generate a step function of the likelihood at interval times t, for every
#' failure pattern that is still possible at t.
#'
#' @param likelihood_fail a list of lists: one element per failure pattern,
#'   each holding one data.frame per subject with at least the columns `id`,
#'   `time`, `likelihoods` and `pattern`.
#' @param landmark the prediction landmark time (first cutpoint of the grid).
#' @param horizon the prediction horizon (last cutpoint of the grid).
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals and horizon is in years)
#'
#' @import future.apply
#' @import dplyr
#'
#' @return A list with one element per subject `id`. Each element holds the
#'   columns `id`, `cutpoint1`, `likelihoods` and `cutpoint2` (the failure
#'   pattern).
get_likelihood_profiles_fail <- function(likelihood_fail,
                                         landmark,
                                         horizon,
                                         interval) {  
  grid <- seq(from = landmark, to = horizon, by = interval)
  # Running product of the likelihoods (on the log scale) within every
  # per-subject data.frame of every failure pattern.
  cumulated <- lapply(likelihood_fail, function(pattern_frames) {
    future_lapply(seq_along(pattern_frames), function(k) {
      mutate(pattern_frames[[k]],
             likelihoods = exp(cumsum(log(likelihoods))))
    })
  })
  # Stack subjects within a pattern, then the patterns, and regroup the rows
  # per subject. Only observations made before the pattern time are kept.
  per_pattern <- lapply(cumulated, data.table::rbindlist)
  stacked <- do.call(rbind, per_pattern)
  stacked <- dplyr::filter(stacked, time < pattern)
  per_subject <- split(stacked, by = "id")
  # The rows are still per observed time; reduce them to one row per cutpoint
  # and pattern by carrying the last observation forward.
  future_lapply(per_subject, function(subject_lik, ...) {
    steps <- lapply(grid, function(cutpoint) {
      eligible <- dplyr::filter(subject_lik,
                                time <= cutpoint,
                                pattern >= cutpoint)
      if (nrow(eligible) == 0) {
        return(NULL)
      }
      eligible %>%
        dplyr::filter(time == max(time)) %>%
        mutate(cutpoint1 = cutpoint)
    })
    combined <- do.call(rbind, steps)
    combined %>%
      arrange(id) %>%
      dplyr::select(id, cutpoint1, likelihoods, cutpoint2 = pattern)
  })
}

#' Get priors
#'
#' Call flexsurvreg() to fit an unconditional (intercept only) survival model
#' and derive the prior survival and failure probability of every interval.
#'
#' @param data a data.frame with the data for which to predict in the long format (input should be the same as likelihood_profiles).
#' @param time_failure a character with the failure time variable name.
#' @param failure a character with the failure indicator variable name.
#' @param horizon a 1L int with the prediction horizon.
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals and horizon is in years)
#' @param dist distribution passed on to flexsurvreg()
#'
#' @import flexsurv
#' @import dplyr
#'
#' @return a data.frame with one row per interval start `lm` (from 0 up to
#'   `horizon`) and the columns `prior_survival_interval` (probability of
#'   surviving the interval given survival up to `lm`) and
#'   `prior_fail_interval` (its complement).
get_priors <- function (data, time_failure, failure, horizon, interval, dist = "weibull") {  
  surv <- paste0("Surv(",time_failure, ",", failure, ")  ~ 1")
  sfit <- flexsurvreg(as.formula(surv), data = data, dist = dist)
  # Predicted failure times for a fine grid of cumulative probabilities.
  # The model has no covariates, so the prediction of the first row is used.
  pct <- seq(0, 1, by = 0.001)
  pred <- predict(sfit, type = "quantile", p = pct)$.pred[[1]]
  # One extra step beyond the horizon so that lead() below has a value for
  # the last interval.
  lm <- (0:(horizon*interval^-1+1))*interval
  # Prior cumulative failure probability at each grid time: the largest
  # probability whose predicted failure time does not exceed that time.
  prior_fail <- vapply(lm, function (i) {
    pred %>%
      dplyr::filter(.pred <= i) %>%
      dplyr::pull(.quantile) %>%
      max()
  }, numeric(1))
  prior <- data.frame(
    lm = lm,
    prior_fail = prior_fail,
    prior_survival = 1 - prior_fail,
    row.names = NULL
  )
  prior %>%
    mutate(prior_survival_interval = lead(prior_survival) / prior_survival,
           prior_fail_interval = 1 - prior_survival_interval) %>%
    dplyr::filter(lm <= horizon) %>%
    dplyr::select(lm, prior_survival_interval, prior_fail_interval)
}

#' Get predictions for non failure
#'
#' Relabel the columns of every non-failure likelihood profile so that the
#' likelihood column is called `f_nofail`.
#'
#' @param likelihood_profiles_nofail a list with the likelihood profiles for
#'   the non-failure model for each subject. Every element must have exactly
#'   three columns, in the order id, cutpoint, likelihood.
#'
#' @importFrom stats setNames
#'
#' @return a list of data.frames (one per subject) with the columns `id`,
#'   `cutpoint1` and `f_nofail`.
get_predictions_nofail <- function(likelihood_profiles_nofail) {
  # The names are assigned by position; the result is returned visibly.
  lapply(likelihood_profiles_nofail, setNames, c("id", "cutpoint1", "f_nofail"))
}

#' Get Predictions for Failure
#'
#' Weight the failure likelihood profiles by the prior interval failure
#' probability and accumulate them over the failure intervals of each landmark.
#' Implements equation 7.2 (p155) from Steffen's thesis.
#'
#' @param likelihood_profiles_fail a list with the likelihood profiles for the
#'   failure model for each subject (columns `id`, `cutpoint1`, `likelihoods`
#'   and `cutpoint2`).
#' @param prior a data frame with priors returned by get_priors()
#'
#' @import dplyr
#' @import future.apply
#'
#' @return a list of dataframes (columns `id`, `cutpoint1`, `cutpoint2`,
#'   `f_fail`) with the predictions for each subject up to the horizon.
get_predictions_fail <- function (likelihood_profiles_fail, prior) {
  future_lapply(likelihood_profiles_fail, function(profile) {
    # f_i(y_i^<=k | F_i = k + 0.5) * P(k <= F_i <= k +1)
    # The join is an exact match of `cutpoint1` against `prior$lm`.
    # NOTE(review): the prior is looked up by landmark (`cutpoint1`), so all
    # failure intervals of one landmark get the same interval probability;
    # confirm that this is what equation 7.2 intends.
    weighted <- left_join(profile, prior, by = c("cutpoint1" = "lm"))
    weighted <- mutate(
      weighted,
      f_fail_interval = likelihoods * (1 - prior_survival_interval)
    )
    weighted <- arrange(weighted, id, cutpoint1, cutpoint2)
    # Sigma_t=k^h [ f_i(y_i^<=k | F_i = k + 0.5) * P(k <= F_i <= k + 1)] / P(lm <= F_i <= h)
    accumulated <- weighted %>%
      group_by(cutpoint1) %>%
      mutate(
        f_fail = cumsum(f_fail_interval) /
          (1 - cumprod(prior_survival_interval))
      ) %>%
      ungroup()
    dplyr::select(accumulated, id, cutpoint1, cutpoint2, f_fail)
  })
}
  
#' Apply Bayes Rule to get posterior predictions
#'
#' Use the densities for failures and non-failure to get the posterior
#' probability of failure with Bayes Rule.
#'
#' @param f_fail a list with, per subject, the density for the failures
#'   (columns `id`, `cutpoint1`, `cutpoint2`, `f_fail`).
#' @param f_nofail a list with, per subject, the density for the non-failures
#'   (columns `id`, `cutpoint1`, `f_nofail`). It is paired with `f_fail` by
#'   position.
#' @param prior a data frame with the columns `lm` and
#'   `prior_survival_interval`, e.g. as returned by get_priors().
#'
#' @import purrr
#' @import dplyr
#' @import future.apply
#'
#' @return a list with, per subject, a data frame holding the columns `id`,
#'   `landmark`, `horizon` and `post_fail`.
get_posteriors <- function (f_fail, f_nofail, prior) {
  # Pair the failure and non-failure densities; every non-failure row is kept
  densities <- map2(f_fail, f_nofail, function(fail_i, nofail_i) {
    right_join(fail_i, nofail_i, by = c("id", "cutpoint1"))
  })
  future_lapply(densities, function(subject_density) {
    # get P(F>hor|F>=lm) from prior
    with_prior <- left_join(subject_density, prior, by = c("cutpoint1" = "lm"))
    posterior <- with_prior %>%
      group_by(cutpoint1) %>%
      mutate(
        prior_surv = cumprod(prior_survival_interval),
        prior_fail = 1 - prior_surv,
        post_fail = (f_fail * prior_fail) /
          (f_fail * prior_fail + f_nofail * prior_surv)
      ) %>%
      ungroup()
    dplyr::select(posterior, id, landmark = cutpoint1, horizon=cutpoint2, post_fail)
  })
}

#' Predictions from the multivariate mixed model using a discriminant analysis framework.
#'
#' Calculate the posterior predictions from the likelihood profiles and priors. Implements equation 7.1 (p154) from Steffen's thesis.
#'
#' @param data a data.frame with the data for which to predict in the long format (input should be the same as likelihood_profiles)
#' @param outcomes character vector with the outcome names, or a list /
#'   data.frame with an `outcome` element holding those names
#' @param fixed_formula_nofail formula for the fixed effects of the nofail model
#' @param random_formula_nofail formula for the random effects of the nofail model,
#' @param random_effects_nofail random effect estimates for the no fail model (see MASS::mvrnorm)
#' @param parameters_nofail parameter estimates from the non-failures model (see mmm_model)
#' @param fixed_formula_fail formula for the fixed effects of the fail model,
#' @param random_formula_fail formula for the random effects of the fail model,
#' @param random_effects_fail random effect estimates for the fail model (see MASS::mvrnorm),
#' @param parameters_fail parameter estimates from the failures model (see mmm_model)
#' @param time a character string with the time variable name in data
#' @param failure a character string with the failure variable name in data
#' @param failure_time a character string with the failure time variable in data
#' @param prior a data frame with priors returned by get_priors()
#' @param id a character string with the subject id
#' @param landmark a numeric value indicating the prediction landmark until which data is available for prediction.
#' @param horizon the prediction horizon (maximum follow-up time)
#' @param interval the intervals relative to the horizon (e.g. 1/12 for monthly intervals and horizon is in years)
#' @param ... currently not used.
#'
#' @import future.apply
#' @import dplyr
#' @import flexsurv
#'
#' @export
#'
#' @return a data.frame with one row per subject: the subject id, a list
#'   column `data` with the observed trajectory up to the landmark and a list
#'   column `prediction` with the posterior predictions.
mmm_predictions <- function( data,
                             outcomes,
                             fixed_formula_nofail,
                             random_formula_nofail,
                             random_effects_nofail,
                             parameters_nofail,
                             fixed_formula_fail,
                             random_formula_fail,
                             random_effects_fail,
                             parameters_fail,
                             time,
                             failure,
                             failure_time,
                             prior,
                             id,
                             landmark=0,
                             horizon, 
                             interval,
                             ...) {
  # Keep only the observations made up to the landmark. The column name is
  # injected with `!!` so that the filter works whatever the time column is
  # called (also when it is literally called "time").
  ids_before <- unique(data[[id]])
  data <- data %>% 
    dplyr::filter(.data[[!!time]] <= landmark)
  ids_after <- unique(data[[id]])
  n_dropped <- length(ids_before) - length(ids_after)
  if (n_dropped > 0) {
    message(paste("Dropped", n_dropped, "subjects without data prior to landmark"))
  }
  if (length(ids_after) == 0) {
    stop("There are no observations prior to the landmark. Have you set the argmument `landmark=` correcty")
  }
  likelihood_nofail <- get_likelihood_nofail(data = data,
                                             outcomes = outcomes,
                                             fixed_formula_nofail = fixed_formula_nofail,
                                             random_formula_nofail = random_formula_nofail,
                                             time = time,
                                             parameters_nofail = parameters_nofail,
                                             random_effects_nofail = random_effects_nofail,
                                             id = id,
                                             horizon = horizon)
  likelihood_fail <- get_likelihood_fail(data = data,
                                         outcomes = outcomes,
                                         fixed_formula_fail = fixed_formula_fail,
                                         random_formula_fail = random_formula_fail,
                                         time = time,
                                         failure_time = failure_time,
                                         parameters_fail = parameters_fail,
                                         random_effects_fail = random_effects_fail,
                                         id = id,
                                         landmark = landmark,
                                         interval = interval,
                                         horizon = horizon)
  likelihood_profiles_nofail <- get_likelihood_profiles_nofail(likelihood_nofail = likelihood_nofail,
                                                               landmark = landmark,
                                                               horizon = horizon, 
                                                               interval = interval) 
  likelihood_profiles_fail <- get_likelihood_profiles_fail(likelihood_fail = likelihood_fail,
                                                           landmark = landmark,
                                                           horizon = horizon,
                                                           interval = interval)
  f_nofail <- get_predictions_nofail(likelihood_profiles_nofail = likelihood_profiles_nofail)
  f_fail <- get_predictions_fail(likelihood_profiles_fail = likelihood_profiles_fail,
                                 prior = prior)
  predictions <- get_posteriors(prior = prior,
                                f_fail = f_fail,
                                f_nofail = f_nofail)
  # The posteriors carry a column that is literally called "id"; nest
  # everything else and rename to (<id>, prediction).
  predictions <- lapply(predictions, function(posterior) {
    nested <- nest(posterior, data = -all_of("id"))
    setNames(nested, c(id, "prediction"))
  })
  predictions <- do.call(rbind, predictions)
  # Observed trajectory per subject (data was already restricted to the
  # landmark above).
  outcome_names <- if (is.character(outcomes)) outcomes else outcomes$outcome
  data_trajectory <- data %>%
    dplyr::select(all_of(c(id, time, failure, failure_time, outcome_names))) %>%
    nest(data = -all_of(id))
  data_trajectory %>% left_join(predictions, by = id)
}


# Model performance -------------------------------------------------------


#' Calculate the Area under the Receiver Operating Characteristics Curve
#'
#' Use the timeROC package to calculate the ROC-AUC based on predictions
#' from the mmm_predictions function.
#'
#' @param predictions the predictions returned by the mmm_predictions function
#'   (one row per subject with the list columns `data` and `prediction`).
#' @param prediction_landmark a single numeric landmark time to select predictions
#' @param prediction_horizon the prediction horizon *at* which to calculate the ROC-AUC.
#' @param id a character value for the name of the vector of subject identifiers in the predictions data
#' @param failure_time a character value for the name of the vector with failure times in the predictions data
#' @param failure a character value for the name of the vector with the failure indicator in the predictions data
#'
#' @import timeROC
#' @import dplyr
#'
#' @details Uses the timeROC package to determine the area under the receiver operating characteristics curve.
#'
#' @return a data.frame with the columns `landmark`, `horizon`, `failures`,
#'   `survivors`, `censored`, `auc`, `auc_se` and `cuminc`; rows with a missing
#'   AUC are removed.
get_auc <- function (
  predictions,
  prediction_landmark,
  prediction_horizon,  
  id,
  failure_time,
  failure
  ) {
  # One row per subject with the observed failure time and status.
  outcome <- predictions %>%
    dplyr::select(all_of(c(id, "data"))) %>%
    unnest(data) %>%
    dplyr::select(all_of(c(id, failure_time, failure))) %>%
    group_by(across(all_of(id))) %>%
    summarize(across(everything(), first), .groups = "drop")
  # Posterior failure probability at the requested landmark and horizon.
  # near() instead of `==` because both are floating-point grid values.
  marker <- predictions %>%
    dplyr::select(all_of(c(id, "prediction"))) %>%
    unnest(prediction) %>%
    dplyr::filter(dplyr::near(landmark, prediction_landmark),
                  dplyr::near(horizon, prediction_horizon)) %>%
    dplyr::select(all_of(c(id, "post_fail")))
  # Match outcome and marker on the subject id rather than by row position.
  aligned <- inner_join(outcome, marker, by = id)
  assertthat::assert_that(
    all(
      nrow(aligned) == nrow(outcome),
      nrow(aligned) == nrow(marker)
    ),
    msg = paste("The length of the vectors", failure_time, failure,
                "and predictions do not match.")
  )
  roc <- timeROC(T = aligned[[failure_time]],
                 delta = aligned[[failure]],
                 marker = aligned[["post_fail"]],
                 cause = 1, 
                 weighting = "marginal",
                 times = prediction_horizon-0.001, # substract a small amount of time to get result
                 iid = TRUE) 
  roc <- cbind(roc$Stats, roc$AUC, roc$inference$vect_sd_1, roc$CumulativeIncidence)
  colnames(roc) <- c("failures", "survivors", "censored", "auc", "auc_se", "cuminc")
  roc <- as.data.frame(roc, row.names = FALSE) %>% mutate(landmark = prediction_landmark)
  roc %>% 
    dplyr::filter(!is.na(auc)) %>%
    mutate(horizon = prediction_horizon) %>%
    relocate(landmark, horizon, everything())
}

#' #' Plot calibration curve
#' #' 
#' #' Use riskRegression's Score.list() to determine calibration.
#' #' 
#' #' @param predictions Predictions from the MMM calculated with the mmm_predictions() function.
#' #' @param landmark The prediction landmark time.
#' #' @param horizon The prediction horizon time.
#' #' @param id The subject identifier in the predictions data.
#' #' 
#' #' @import riskRegression
#' #' @export
#' plot_calibration <- function(predictions,
#'                              landmark, 
#'                              horizon,
#'                              id) {
#'   data <- predictions[[landmark]] %>% 
#'     unnest(data) %>% 
#'     group_by(get(id)) %>% 
#'     dplyr::select(all_of(c(failure_time, failure))) %>%
#'     summarize_all(first)
#'   predict <- predictions[[landmark]] %>%
#'     unnest(prediction) %>%
#'     dplyr::select(-data) %>%
#'     dplyr::filter(time == horizon)
#'   score <- riskRegression::Score(list( "p_fail"=predict$post_fail ),
#'                                  formula=Surv(stime, failure) ~ 1,
#'                                  data=data, 
#'                                  conf.inf=FALSE,
#'                                  times=horizon,
#'                                  metrics=NULL, 
#'                                  plots="Calibration")
#'   plt_calibrate <- riskRegression::plotCalibration(score,
#'                                                    cens.method="local", 
#'                                                    round=TRUE,
#'                                                    rug=TRUE,
#'                                                    plot=FALSE)
#'   gplt_calibrate <- ggplot(plt_calibrate$plotFrames$p_fail, aes(x=Pred, y=Obs)) + 
#'     stat_smooth(method="loess", 
#'                 span=0.3,
#'                 color="black",
#'                 alpha=0.2) + 
#'     geom_abline(intercept=0, slope=1, alpha=0.5) +
#'     geom_rug(aes(x=Pred)) + 
#'     scale_y_continuous(limits=c(0, 1),
#'                        breaks=seq(0, 1, 0.2)) + 
#'     scale_x_continuous(limits=c(0, 1),
#'                        breaks=seq(0, 1, 0.2)) +
#'     labs(title="", 
#'          caption=paste("Predictions generated at", landmark[t],"years follow-up"),
#'          y="",
#'          x="") +
#'     theme_tufte(ticks=FALSE, 
#'                 base_family="WorkSans")
#' }

# Visualizations ----------------------------------------------------------


#' Plot predictions from the multivariate mixed model
#'
#' For one subject, plot the observed longitudinal outcomes next to the
#' predicted probability of failure at the earliest available landmark.
#'
#' @param predictions the predictions generated from mmm_predictions()
#' @param outcomes a character vector with the names of the longitudinal outcomes
#' @param id a character value with the name of the subject identifier
#' @param subject a single value with the id for the subject of interest
#' @param time a character value with the name of the time variable in the
#'   nested `data` column of `predictions`
#'
#' @import ggplot2
#' @import gridExtra
#' @import dplyr
#' @import scales
#'
#' @return plots for survival conditional on the longitudinal biomarker trajectory.
plot_predictions <- function (predictions,
                              outcomes,
                              id,
                              subject,
                              time) {
  # FIX: Return marker data to the orginal scale
  # check input
  if (length(subject) != 1) {
    stop("Set `subject` to select one subjects")  
  }

  # An outcome is binary when its observed (non-missing) values, over all
  # subjects, are exactly 0 and 1.
  all_markers <- unnest(predictions, data)
  is_binary <- vapply(outcomes, function(x) {
    observed <- sort(unique(all_markers[[x]]))
    length(observed) == 2 && all(observed == c(0, 1))
  }, logical(1))

  # Marker and prediction data of the requested subject. Names and values are
  # injected with `!!` so they cannot be confused with columns of the data.
  markers <- predictions %>%
    dplyr::select(all_of(c(id, "data"))) %>%
    dplyr::filter(.data[[!!id]] == !!subject) %>%
    unnest(data)
  pred <- predictions %>%
    dplyr::select(all_of(c(id, "prediction"))) %>%
    dplyr::filter(.data[[!!id]] == !!subject) %>%
    unnest(prediction)

  # set-up time axis
  max_x <- if (nrow(markers) == 0) 0 else max(markers[[time]])

  # Pivot the data for plotting
  markers_trajectory <- markers %>%
    pivot_longer(cols = all_of(outcomes), names_to = "outcome")
  markers_trajectory$is_binary <-
    unname(is_binary[markers_trajectory$outcome])

  # One panel of trajectories, faceted by outcome. Binary outcomes share a
  # fixed 0-1 axis; continuous outcomes get free scales. An empty selection
  # yields a blank placeholder because facet_wrap() cannot facet zero rows.
  trajectory_panel <- function(panel_data, binary) {
    if (nrow(panel_data) == 0) {
      return(ggplot() + theme_void())
    }
    if (binary) {
      y_scale <- scale_y_continuous(limits = c(0, 1),
                                    name = "longitudinal outcome")
      facets <- facet_wrap(vars(outcome))
    } else {
      y_scale <- scale_y_continuous(name = "longitudinal outcome")
      facets <- facet_wrap(vars(outcome), scales = "free")
    }
    ggplot(panel_data,
           aes(x = .data[[!!time]], y = value, col = outcome)) +
      geom_point() +
      geom_line(alpha = 0.8) +
      scale_x_continuous(name = "",
                         limits = c(0, ceiling(max_x)),
                         breaks = scales::breaks_width(0.5)) +
      y_scale +
      theme_bw() +  # FIX: Tufte theme later
      theme(legend.position = "none") +
      facets
  }
  binary_plot <- trajectory_panel(
    dplyr::filter(markers_trajectory, .data$is_binary),
    binary = TRUE
  )
  continuous_plot <- trajectory_panel(
    dplyr::filter(markers_trajectory, !.data$is_binary),
    binary = FALSE
  )

  # Plot the predictions
  surv_plot <- pred %>% 
    dplyr::filter(landmark == min(landmark)) %>% 
    ggplot(aes(x = horizon, y = post_fail)) + 
    geom_line(alpha = 0.8)  +
    scale_x_continuous(name = "",
                       breaks = scales::breaks_width(1)) +
    scale_y_continuous(name = "Probability of graft failure",
                       limits = c(0, 1),
                       breaks = scales::breaks_width(0.2),
                       position = "right") +
    theme_bw() # FIX: theme Tufte

  # Combine plots and return: trajectories stacked on the left, predictions
  # spanning both rows on the right.
  grid.arrange(
    grobs = list(binary_plot, continuous_plot, surv_plot),
    widths = c(1, 1),
    layout_matrix = rbind(
      c(1, 3),
      c(2, 3)
    )
  )
}

#' Plot the calibration for MMM
#'
#' Compute the data behind a calibration plot for a given landmark and horizon
#' with riskRegression::Score() and riskRegression::plotCalibration().
#'
#' @param predictions the predictions returned by the mmm_predictions function
#'   (one row per subject with the list columns `data` and `prediction`).
#' @param prediction_landmark a single numeric landmark time to select predictions
#' @param prediction_horizon the prediction horizon *at* which to assess the calibration.
#' @param id a character value for the name of the vector of subject identifiers in the predictions data
#' @param failure_time a character value for the name of the vector with failure times in the predictions data
#' @param failure a character value for the name of the vector with the failure indicator in the predictions data
#'
#' @import ggplot2
#' @import ggthemes
#' @import gridExtra
#' @import dplyr
#' @import scales
#' @import riskRegression
#'
#' @return the `plotFrames$p_fail` element produced by
#'   riskRegression::plotCalibration(), with a `landmark` column added.
#'
#' @export
plot_calibration <- function(
  predictions,
  prediction_landmark,
  prediction_horizon,
  id,
  failure_time,
  failure
  ) {
  # One row per subject with the observed failure time and status.
  outcome <- predictions %>%
    dplyr::select(all_of(c(id, "data"))) %>%
    unnest(data) %>%
    dplyr::select(all_of(c(id, failure_time, failure))) %>%
    group_by(across(all_of(id))) %>%
    summarize(across(everything(), first), .groups = "drop")
  # Posterior failure probability at the requested landmark and horizon.
  # near() instead of `==` because both are floating-point grid values.
  marker <- predictions %>%
    dplyr::select(all_of(c(id, "prediction"))) %>%
    unnest(prediction) %>%
    dplyr::filter(dplyr::near(landmark, prediction_landmark),
                  dplyr::near(horizon, prediction_horizon)) %>%
    dplyr::select(all_of(c(id, "post_fail")))
  # Match outcome and prediction on the subject id rather than by position.
  aligned <- inner_join(outcome, marker, by = id)
  if (nrow(aligned) != nrow(outcome) || nrow(aligned) != nrow(marker)) {
    stop("The subjects with outcome data and the subjects with predictions do not match.")
  }
  # Build the response from the supplied column names.
  surv_formula <- as.formula(
    paste0("Surv(", failure_time, ", ", failure, ") ~ 1")
  )
  score <- riskRegression::Score(
    list("p_fail" = aligned[["post_fail"]]),
    formula = surv_formula,
    data = dplyr::select(aligned, all_of(c(failure_time, failure))),
    conf.int = FALSE,
    times = prediction_horizon,
    metrics = NULL,
    plots = "Calibration")
  plt_calibrate <- riskRegression::plotCalibration(
    score,
    cens.method = "local",
    round = TRUE,
    rug = TRUE,
    plot = FALSE)
  plt_calibrate$plotFrames$p_fail %>% 
    dplyr::mutate(landmark = prediction_landmark)
}
# JanvandenBrand/highdimjm documentation built on Dec. 18, 2021, 12:32 a.m.