# Source file: R/bxmodel.R

# Defines functions: make_dich

# Doc header --------------------------------------------------------------

# author: "Jan van den Brand, PhD"
# email: jan.vandenbrand@kuleuven.be
# project: NSN19OK003
# funding: Dutch Kidney Foundation
# Topic: Data exploration

# Preliminaries  ------------------------------------------------------

reqlib <- c("foreign", "ggplot2", "lattice", "gridExtra", "tidyverse", "caret",
            "MASS",  "haven", "lubridate",
            "mice", "miceadds", "tableone", "cmprsk",
            "corrplot", "timeROC", "ggthemes", "survminer", "ggrepel")
# Attach all analysis packages; invisible() suppresses printing the list that
# lapply() returns.
invisible(lapply(reqlib, library, character.only = TRUE))
rm(reqlib)
# Install jmmm from GitHub only when it is not installed yet, so the script
# does not need network access and does not reinstall jmmm on every run.
if (!requireNamespace("jmmm", quietly = TRUE)) {
  devtools::install_github("JanvandenBrand/jmmm")
}
library(jmmm)
# Set handlers for progress bars
progressr::handlers(global = TRUE)
progressr::handlers("progress")
# set seed for reproducibility
set.seed(20201013)

source("R/edit.R")
source("R/dataviz.R")

# Global variables

# Prediction horizon in years (survival times are converted to years further on)
horizon <- 10
# Landmark times in years: every quarter year from 1 up to the horizon.
# NOTE: this global shadows the name `lm` (stats::lm); calls like lm(y ~ x)
# still find the function, but non-call uses of `lm` get this vector.
lm <- seq(from = 1, to = horizon, by = 0.25)

# cleaning ----------------------------------------------------------------

# Select biopsy data only
d_bx <- d_long %>% filter(!is.na(biopsy_id))
# Make numeric out of survival time (a non-numeric stime gives issues downstream)
d_bx <- d_bx %>% mutate(stime = as.numeric(stime))
# One 0/1 dummy column per primary kidney disease level (no intercept, so every
# level gets its own column), named after the factor levels.
# --> this results in a very big stacked data set with mostly zeros...
pkd <- stats::model.matrix(~d_bx$primary_kd + 0)
colnames(pkd) <- levels(d_bx$primary_kd)
d_bx <- cbind(d_bx, pkd)

# Missing data analysis.
# md.pattern() stores the number of rows sharing each missingness pattern in
# its row names. write_csv2() does not write row names, so copy them into an
# explicit column to keep the pattern counts in the exported file.
md_pattern <- md.pattern(d_bx)
md_pattern_df <- as.data.frame(md_pattern)
md_pattern_df$pattern_count <- rownames(md_pattern)
write_csv2(md_pattern_df, file = "output/mdpattern_dbx.csv")

# impute missing proteinuria
# Dry run (maxit = 0) only to obtain the default predictor matrix and methods
init_mice <- mice(d_bx, maxit = 0, predictorMatrix = quickpred(d_bx, mincor = 0.2), print = FALSE)
predmat <- init_mice$predictorMatrix
# Exclude identifiers and dates as predictors
predmat[, c("transnr", "eadnr", "txdate", "date")] <- 0
# Set 2-level predictive mean matching as the method for longitudinal variables
imp_method <- init_mice$method
imp_method[names(imp_method) %in% c("gfr", "nf_procr", "nf_protu")] <- "2l.pmm"
# set cluster variable (transnr) for 2 level imputations
predmat[, "transnr"] <- -2
imp <- mice(d_bx, method = imp_method, predictorMatrix = predmat, maxit = 10, m = 10)

# imputation diagnostics
# plot() and stripplot() on a mids object return lattice objects, which are
# only drawn when printed; print() explicitly so the PDF is not left empty
# when this script is run via source().
pdf(file = "plots/bx_model_imputation_dx.pdf")
print(plot(imp))
print(stripplot(imp))
dev.off()

# Use single imputation (the first of the m imputed data sets) to fill in
# missing values for proteinuria; complete() keeps the row order of d_bx.
d_imp <- complete(imp, action = 1, include = FALSE) %>%
  dplyr::select(transnr, date, nf_procr)
d_bx <- d_bx %>% mutate(
  nf_procr = coalesce(nf_procr, d_imp$nf_procr)
)

# Longitudinal outcomes: clinical markers plus acute and chronic Banff lesion
# scores (only outcomes with a prevalence > 0.1 are selected)
clin_out <- c("gfr", "nf_procr")
banff_acute <- c("t", "i", "v", "g", "ptcitis", "c4d_ptc")
banff_chron <- c("ci", "ct", "cv", "cg")
outcomes <- c(clin_out, banff_acute, banff_chron)
# Baseline covariates. "pretx_hla_abs" and "overall_pretx_dsa" are left out
# because they are collinear with the HLA mismatches. The primary kidney disease
# dummies are included with "vascular" as the reference category.
covars <- c("time", "stime", "repeat_tx",
            "donor_age", "donor_sex_m1",
            "rec_sex_m1", "rec_bmi_d0", "cit",
            "abdr_antigen_mismatches",
            "txyear",
            "congenital", "cystic", "diabetes", "gn", "other",
            "systemic", "tin")
# Keep only observations with full longitudinal outcome data
cc <- complete.cases(dplyr::select(d_bx, all_of(outcomes)))
d_bx_cc <- d_bx[cc, ]

# Turn the transplant id into character and the two-level factors into 0/1
# numerics (first factor level -> 0, second -> 1)
d_bx_cc <- d_bx_cc %>%
  mutate(
    transnr = as.character(transnr),
    across(c(repeat_tx, donor_sex_m1, rec_sex_m1,
             pretx_hla_abs, overall_pretx_dsa, induction),
           function(f) as.numeric(f) - 1)
  )
# Restrict to the id, event status, outcomes and covariates, and label the
# event codes
d_bx_cc <- d_bx_cc %>%
  dplyr::select(all_of(c("transnr", "event", outcomes, covars))) %>%
  mutate(event = factor(event, labels = c("censored", "graft failure", "death")))

# Data exploration  ---- 

## Overall survival: one row per transplant (the first observation; stime and
## event are presumably constant within a transplant) for the cumulative
## incidence curves
cuminc_data <- d_bx_cc %>%
  group_by(transnr) %>%
  summarize(across(.cols = c(stime, event), .fns = first)) %>%
  ungroup()
# `event` is a factor ("censored", "graft failure", "death"), so the censoring
# level must be given explicitly: cuminc()'s default cencode = 0 matches no
# observation and would estimate "censored" as if it were a competing event.
cuminc_fit <- with(cuminc_data,
                   cuminc(ftime = stime / 365.25, fstatus = event,
                          cencode = "censored"))
# Cumulative incidence of graft failure and death in a single panel (years)
ggcompetingrisks(cuminc_fit,
                 multiple_panels = FALSE,
                 conf.int = TRUE,
                 ylim = c(0, 1),
                 break.time.by = 1)

pairs <- make_pairs(outcomes = outcomes)
# Spaghettiplots: outcome trajectories over time, grouped by event status
plotlist <- lapply(outcomes, make_spaghetti_plots, data = d_bx_cc,
                   ftime = "time", id = "transnr", breaks = 12, by = "event")
do.call("grid.arrange", c(plotlist,
                          ncol = 3,
                          bottom = "Red: Censored, Green: Graft failure, Blue: Death"))
# Split the data by outcome for correlation plots (stime is still in days here)
df_fail <- d_bx_cc %>%
  filter(event == "graft failure" & stime / 365.25 < horizon)
df_nofail <- d_bx_cc %>%
  filter(stime / 365.25 >= horizon)
df_death <- d_bx_cc %>%
  filter(event == "death" & stime / 365.25 < horizon)

# Plots
# Categorical view of the Banff scores (every outcome after gfr and nf_procr)
# and the sex indicators, plus the covariates; both carry the event status
# that is used as the grouping variable.
d_factors <- d_bx_cc %>%
  dplyr::select(all_of(c(outcomes[3:length(outcomes)], "donor_sex_m1", "rec_sex_m1"))) %>%
  mutate(across(everything(), as.factor))
d_factors$event <- d_bx_cc$event
d_covars <- d_bx_cc %>% dplyr::select(all_of(c(covars, "event")))

# "event" is the grouping variable itself, so it is not plotted against itself.
# bar plots
lapply(setdiff(colnames(d_factors), "event"), make_bar_plot, data = d_factors, by = "event")
# density plots
lapply(setdiff(colnames(d_covars), "event"), make_density, data = d_covars, by = "event")

# Correlations between outcomes and covariates, separately per outcome group
cor_vars <- c(outcomes, covars)
cor_nofail <- df_nofail %>%
  dplyr::select(all_of(cor_vars)) %>%
  cor()
cor_fail <- df_fail %>%
  dplyr::select(all_of(cor_vars)) %>%
  cor()
cor_death <- df_death %>%
  dplyr::select(all_of(cor_vars)) %>%
  cor()

# Write one upper-triangle ellipse correlation plot to a 16 x 12 cm PNG.
# The device is closed through on.exit(), so a failing corrplot() call cannot
# leave the PNG device open and send later plots into this file.
save_corrplot_png <- function(cor_mat, file) {
  png(file, width = 16, height = 12, units = "cm", res = 300)
  on.exit(dev.off(), add = TRUE)
  corrplot(cor_mat,
           tl.pos = "td",
           tl.cex = 0.5,
           tl.col = "black",
           type = "upper",
           method = "ellipse",
           diag = TRUE)
  invisible(file)
}
save_corrplot_png(cor_nofail, "plots/correlation_nofail.png")
save_corrplot_png(cor_fail, "plots/correlation_fail.png")
save_corrplot_png(cor_death, "plots/correlation_death.png")

# Data transformation ----
# Make dichotomous variables out of the biopsy parameters
make_dich <- function(x) {
  # Dichotomise a score: values above zero become 1, all other values 0.
  # Missing values stay NA because ifelse() propagates NA from its test.
  is_positive <- x > 0
  ifelse(is_positive, 1, 0)
}
# Dichotomise the Banff scores (every outcome after gfr and nf_procr)
d_bx_cc <- d_bx_cc %>%
  mutate(across(all_of(outcomes[3:length(outcomes)]), make_dich))
# Store the means and sds of the outcomes for post processing (computed after
# dichotomising the Banff scores and before the scaling below).
# NOTE(review): nf_procr is modelled as scale(log(nf_procr + 1)) below, while
# these stats are on the raw scale -- confirm post-processing accounts for that.
d_bx_cc %>%
  summarize(across(all_of(outcomes), list(mean = mean, sd = sd))) %>%
  pivot_longer(cols = everything(),
               names_to = c("outcome", "stat"),
               names_pattern = "^(.*)_(.*)",
               values_to = "value") %>%
  write_csv(file = "output/summary_stats.csv")
# scale covariates; survival time goes from days to years and biopsy time from
# months to years
d_bx_cc <- d_bx_cc %>%
  mutate(stime = stime / 365.25,
         time = time / 12,
         txyear = as.numeric(scale(txyear - 2008)),
         abdr_antigen_mismatches = abdr_antigen_mismatches - 3,
         gfr = as.numeric(scale(gfr)),
         cit = as.numeric(scale(cit)),
         rec_bmi_d0 = as.numeric(scale(rec_bmi_d0)),
         donor_age = as.numeric(scale(donor_age)),
         nf_procr = as.numeric(scale(log(nf_procr + 1)))
  )
# Distribution of biopsy times before restricting the follow-up window.
# The axis text size is set through theme(): passed unnamed to
# scale_x_continuous() it would become the axis title (`name`).
ggplot(data = d_bx_cc, aes(x = time)) +
  geom_histogram(binwidth = 0.12) +
  scale_x_continuous(breaks = seq(0, 15, 1)) +
  scale_y_continuous() +
  theme_minimal() +
  theme(axis.text.x = element_text(size = 10)) +
  xlab("Follow-up time")
# Restrict analyses to stable transplants: biopsies after 0.5 years and up to
# and including 5.5 years
d_bx_cc <- d_bx_cc %>%
  filter(time > 0.5 & time <= 5.5)

# Initalize the models ----------------------------------------------------

# Outcomes retained for the models. Candidates dropped:
#   v score, cg      -- too uncommon overall
#   ptcitis, c4d     -- questionable
#   ct               -- ct = 0 is too uncommon overall; ct is highly correlated with ci
#   t                -- highly correlated with i
# (HLA antibodies and DSA are also highly correlated with each other.)
outcomes <- c("gfr", "nf_procr", "ci", "cv", "i")
# Outcome pairs from make_pairs(), used when stacking the data and fitting
pairs <- make_pairs(outcomes = outcomes)
model_columns <- c("transnr", "event", outcomes, covars)
d_model <- dplyr::select(d_bx_cc, all_of(model_columns))
model_info <- test_input_datatypes(data = d_model, pairs = pairs)

# Select events up to the prediction horizon (`horizon` years; stime is in
# years at this point), split the data by outcome and create train/test sets.

# Randomly assign a `prop` share of the *subjects* (not rows) to training.
# Every transplant contributes several biopsy rows, so sampling rows would put
# biopsies of the same transplant in both the training and the test data.
# Returns list(train = , test = ), keeping all rows of a subject together.
split_by_id <- function(data, id = "transnr", prop = 0.8) {
  ids <- unique(data[[id]])
  n_train <- floor(prop * length(ids))
  # sample.int() avoids sample()'s surprise behaviour for a length-1 input
  train_ids <- ids[sample.int(length(ids), size = n_train)]
  in_train <- data[[id]] %in% train_ids
  list(train = data[in_train, , drop = FALSE],
       test = data[!in_train, , drop = FALSE])
}

df_fail <- d_model %>%
  filter(event == "graft failure" & stime <= horizon)
df_fail_split <- split_by_id(df_fail)
df_fail_train <- df_fail_split$train
df_fail_test <- df_fail_split$test
# Non failures should not be 'at risk' for failure when censored.
# Therefore only select censored subjects who survived the entire period to the horizon
df_nofail <- d_model %>%
  filter(stime >= horizon) %>%
  mutate(stime = horizon)
df_nofail_split <- split_by_id(df_nofail)
df_nofail_train <- df_nofail_split$train
df_nofail_test <- df_nofail_split$test
df_death <- d_model %>%
  filter(event == "death" & stime <= horizon)
df_death_split <- split_by_id(df_death)
df_death_train <- df_death_split$train
df_death_test <- df_death_split$test

# Describe the training and test datasets
tableone_vars <- c(outcomes, covars, "event")
tableone_factors <- c(outcomes[-2:-1], "congenital", "cystic", "diabetes", "gn", "other", "systemic", "tin",
                      "repeat_tx", "donor_sex_m1", "rec_sex_m1", "abdr_antigen_mismatches", "pretx_hla_abs",
                      "overall_pretx_dsa")
tableone_nonnorm <- c("gfr", "nf_procr", "stime", "donor_age", "rec_bmi_d0", "cit")

# Build a "Table 1" for one data set and return it as a data frame whose
# `variable` column holds the printed row labels, followed by the formatted
# summaries. print(printToggle = FALSE) formats the table once without
# echoing it to the console.
tableone_as_df <- function(data,
                           vars = tableone_vars,
                           factor_vars = tableone_factors,
                           nonnormal = tableone_nonnorm) {
  tab <- CreateTableOne(data = data, vars = vars, factorVars = factor_vars)
  tab_matrix <- print(tab, nonnormal = nonnormal, printToggle = FALSE)
  data.frame(variable = rownames(tab_matrix), value = tab_matrix)
}

tableone_fail_train <- tableone_as_df(df_fail_train)
tableone_fail_test <- tableone_as_df(df_fail_test)
tableone_nofail_train <- tableone_as_df(df_nofail_train)
tableone_nofail_test <- tableone_as_df(df_nofail_test)
tableone_death_train <- tableone_as_df(df_death_train)
tableone_death_test <- tableone_as_df(df_death_test)

# Training and test columns side by side, matched on the row label
tableone_fail <- tableone_fail_train %>% left_join(tableone_fail_test, by = "variable")
tableone_nofail <- tableone_nofail_train %>% left_join(tableone_nofail_test, by = "variable")
tableone_death <- tableone_death_train %>% left_join(tableone_death_test, by = "variable")
write.csv(tableone_fail, file = "output/tableone_fail.csv")
write.csv(tableone_nofail, file = "output/tableone_nofail.csv")
write.csv(tableone_death, file = "output/tableone_death.csv")

# Clean-up
# rm(list=setdiff(ls(), c("df_fail", "df_fail_train", "df_fail_test", 
#                         "df_nofail", "df_nofail_train", "df_nofail_test",
#                         "df_death", "df_death_train", "df_death_test",
#                         "df_fail_stacked", "df_nofail_stacked", "df_death_stacked",
#                         "d_model", "model_info", "pairs", 
#                         "outcomes", "covars", "clin_out", "banff_acute", "banff_chron",
#                         "d_gf_test", "d_gf_train", "model_fail", "model_nofail",
#                         "re_samples_fail", "re_samples_nofail", "fixed_covars",
#                         "predictions_train",
#                         "fixed_formula_fail", "fixed_formula_nofail", "horizon")))
# source("R/mmm_functions.R")
# gc()

# Model fitting ----
# make the model formulas -- exclude primary diagnosis as some are rare (<5%).
# pretx_hla_abs and overall_pretx_dsa are not used: they are collinear with the
# HLA mismatches, were therefore left out of `covars`, and so are not columns
# of d_model (requesting them breaks stacking and the prediction formulas).
fixed_covars <- c("repeat_tx",
                  "donor_age", "donor_sex_m1",
                  "rec_sex_m1", "rec_bmi_d0", "cit",
                  "abdr_antigen_mismatches",
                  "txyear")
# Fail fast with a clear message if any model covariate is absent from the data
missing_covars <- setdiff(c("time", "stime", fixed_covars), colnames(d_model))
if (length(missing_covars) > 0) {
  stop("Model covariates missing from d_model: ",
       paste(missing_covars, collapse = ", "), call. = FALSE)
}

# Graft-failure submodel: stime enters as an extra fixed effect
fixed_fail <- make_fixed_formula(covars = c("time", "stime", fixed_covars))
random_fail <- make_random_formula(id = "transnr", covars = c("time"))
df_fail_stacked <- stack_data(data = df_fail_train,
                              id = "transnr",
                              pairs = pairs,
                              covars = c("time", "stime", fixed_covars))
test_compare_stacked_to_original_data(data = df_fail_train, stacked_data = df_fail_stacked, pairs = pairs)
model_fail <- mmm_model(fixed = fixed_fail,
                        random = random_fail,
                        id = "transnr",
                        data = df_fail_train,
                        stacked_data = df_fail_stacked,
                        pairs = pairs,
                        model_families = model_info,
                        iter_EM = 300,
                        iter_qN_outer = 300,
                        nAGQ = 7,
                        parallel_plan = "multisession",
                        ncores = 2,
                        penalized = TRUE)

# Non-failure submodel (stime is fixed at the horizon for these subjects)
fixed_nofail <- make_fixed_formula(covars = c("time", fixed_covars))
random_nofail <- make_random_formula(id = "transnr", covars = c("time"))
df_nofail_stacked <- stack_data(data = df_nofail_train,
                                id = "transnr",
                                pairs = pairs,
                                covars = c("time", fixed_covars))
test_compare_stacked_to_original_data(data = df_nofail_train, stacked_data = df_nofail_stacked, pairs = pairs)
model_nofail <- mmm_model(fixed = fixed_nofail,
                          random = random_nofail,
                          id = "transnr",
                          data = df_nofail_train,
                          stacked_data = df_nofail_stacked,
                          pairs = pairs,
                          model_families = model_info,
                          iter_EM = 300,
                          iter_qN_outer = 300,
                          nAGQ = 7,
                          parallel_plan = "multisession",
                          ncores = 3,
                          tol3 = 1e-7,
                          penalized = TRUE)

# Death submodel
fixed_death <- make_fixed_formula(covars = c("time", fixed_covars))
random_death <- make_random_formula(id = "transnr", covars = c("time"))
# pairs_death <- make_pairs(outcomes[-7])
df_death_stacked <- stack_data(data = df_death_train,
                               id = "transnr",
                               pairs = pairs,
                               covars = c("time", fixed_covars))
test_compare_stacked_to_original_data(data = df_death_train, stacked_data = df_death_stacked, pairs = pairs)
model_death <- mmm_model(fixed = fixed_death,
                         random = random_death,
                         id = "transnr",
                         data = df_death_train,
                         stacked_data = df_death_stacked,
                         pairs = pairs,
                         model_families = model_info,
                         iter_EM = 300,
                         iter_qN_outer = 100,
                         nAGQ = 7,
                         parallel_plan = "multisession",
                         ncores = 3,
                         penalized = TRUE)
# Report the models
write_csv(model_nofail$estimates, file = "output/model_nofail.csv")
write_csv(model_fail$estimates, file = "output/model_fail.csv")
# write_csv(model_death$estimates, file="output/model_death.csv")


# Model reporting ---------------------------------------------------------

# Combine the parameter estimates of both submodels
estimates <- bind_rows(
  model_nofail$estimates %>% mutate(model = "Survivors"),
  model_fail$estimates %>% mutate(model = "Failures")
)
# Two-sided p-values from the t statistics. The tail is computed as pt(-|t|)
# instead of 1 - pt(|t|): the subtraction cancels to exactly 0 once the tail
# probability drops below ~1e-16, while pt(-|t|) stays accurate.
# NOTE(review): df = 176 (survivors) and df = 32 (failures) are hard-coded,
# presumably from the training sample sizes -- TODO confirm / derive from data.
estimates <- estimates %>%
  mutate(
    p = case_when(
      model == "Survivors" ~ 2 * pt(-abs(parameter_estimate / parameter_std_err), df = 176),
      model == "Failures" ~ 2 * pt(-abs(parameter_estimate / parameter_std_err), df = 32)
    )
  )
# Floor p at machine epsilon so -log10(p) stays finite and within the 0-20 axis
estimates <- estimates %>%
  mutate(p = pmax(p, .Machine$double.eps))
# Keep the parameters whose names start with "b"
estimates <- estimates %>%
  filter(str_starts(parameter_name, "b"))

# construct volcano plot
volcano_plot <- ggplot(data = estimates,
                       aes(x = parameter_estimate,
                           y = -log10(p),
                           color = factor(model))) +
  geom_point() +
  # Label only very small p-values, without the short "b..._" name prefix
  geom_label_repel(aes(label = ifelse(-log10(p) > 10,
                                      str_remove(parameter_name, "^b(..|...)_"),
                                      "")),
                   max.overlaps = 30) +
  # NOTE(review): Bonferroni-style line at 0.05 / 165; 165 is presumably the
  # number of tested parameters -- TODO confirm against nrow(estimates).
  geom_hline(yintercept = -log10(0.05 / 165), lty = 2, color = "grey50") +
  theme_classic() +
  scale_x_continuous(limits = c(-9, 9),
                     breaks = seq(-9, 9, 3)) +
  scale_y_continuous(limits = c(0, 20),
                     breaks = seq(0, 20, 5)) +
  labs(title = "Volcano plot of parameter estimates in the MMM",
       y = "-log10(p)",
       x = "Parameter estimate") +
  guides(color = guide_legend(title = "Submodel"))
# device = "png" lets ggsave() pass the size in inches and the dpi to png()
ggsave(filename = "plots/volcano-plot-mmm.png",
       plot = volcano_plot,
       device = "png",
       width = 16,
       height = 12,
       units = "cm")

# correlation matrix of each submodel, written as an upper-triangle ellipse
# plot. on.exit() closes the PNG device even if corrplot() fails.
save_model_corrplot <- function(corr, file) {
  png(file, width = 16, height = 12, units = "cm", res = 300)
  on.exit(dev.off(), add = TRUE)
  corrplot(as.matrix(corr),
           tl.pos = "td",
           tl.cex = 1,
           tl.col = "black",
           type = "upper",
           method = "ellipse",
           diag = TRUE)
  invisible(file)
}
save_model_corrplot(model_nofail$corr, "plots/vcov_nofail.png")
save_model_corrplot(model_fail$corr, "plots/vcov_fail.png")


# Predictions -------------------------------------------------------------

# Draw 1000 samples per submodel from a zero-mean multivariate normal with
# covariance model_*$vcov; the zero mean vector is named after the vcov
# columns so the sampled matrix carries those column names.
mu <- setNames(numeric(length(colnames(model_nofail$vcov))),
               colnames(model_nofail$vcov))
re_samples_nofail <- mvrnorm(n = 1e3, mu = mu, Sigma = model_nofail$vcov)
mu <- setNames(numeric(length(colnames(model_fail$vcov))),
               colnames(model_fail$vcov))
re_samples_fail <- mvrnorm(n = 1e3, mu = mu, Sigma = model_fail$vcov)
# re_samples_death <- mvrnorm(1e3, mu=mu, Sigma=model_death$vcov)

# Predictions on training data: pool failures and non-failures, flag graft
# failure within the horizon, cap stime at the horizon and keep measurements
# taken up to that capped time.
d_gf_train <- df_fail_train %>%
  bind_rows(df_nofail_train) %>%
  mutate(
    failure = as.numeric(
      case_when(
        event == "graft failure" & stime <= horizon ~ 1,
        TRUE ~ 0
      )
    ),
    # failure above is evaluated on the uncapped stime (mutate runs in order)
    stime = pmin(stime, horizon)
  ) %>%
  filter(time <= stime) %>%
  arrange(transnr, time)
# sanity check
xtabs(~ event + failure, d_gf_train)
d_gf_train %>% dplyr::select(event, stime, time) %>% summary()

prior <- get_priors(data = d_gf_train,
                    time_failure = "stime",
                    failure = "failure",
                    horizon = horizon,
                    interval = 1/4)
.outcomes <- get_outcome_type(data = d_gf_train,
                              outcomes = outcomes)
# Parallel backend for the predictions. Namespace-qualified because the
# future package is not attached by the library() calls of this script.
future::plan(future::multisession, workers = 6)
fixed_formula_fail <- paste("~", paste(outcomes, collapse = " + "), "+ time + stime +", paste(fixed_covars, collapse = " + "))
fixed_formula_nofail <- paste("~", paste(outcomes, collapse = " + "), "+ time +", paste(fixed_covars, collapse = " + "))

# One set of dynamic predictions per landmark time in `lm`
predictions_train <- lapply(lm, function(l) {
  mmm_predictions(data = d_gf_train,
                  outcomes = .outcomes,
                  fixed_formula_nofail = fixed_formula_nofail,
                  random_formula_nofail = "~ time | transnr",
                  random_effects_nofail = re_samples_nofail,
                  parameters_nofail = model_nofail$estimates,
                  fixed_formula_fail = fixed_formula_fail,
                  random_formula_fail = "~ time | transnr",
                  random_effects_fail = re_samples_fail,
                  parameters_fail = model_fail$estimates,
                  time = "time",
                  failure = "failure",
                  failure_time = "stime",
                  prior = prior,
                  id = "transnr",
                  landmark = l,
                  horizon = horizon,
                  interval = 1/4)
})

# Predictions on test data.
# Pool the test failures and non-failures; `failure` flags graft failure within
# the horizon (evaluated before stime is capped at the horizon), and only
# measurements up to the capped time are kept, ordered per transplant.
d_gf_test <- bind_rows(df_fail_test, df_nofail_test)
d_gf_test <- mutate(d_gf_test,
                    failure = as.numeric(case_when(
                      event == "graft failure" & stime <= horizon ~ 1,
                      TRUE ~ 0
                    )),
                    stime = pmin(stime, horizon))
d_gf_test <- filter(d_gf_test, time <= stime)
d_gf_test <- arrange(d_gf_test, transnr, time)
# sanity check: event status versus failure flag, and the timing variables
xtabs(~ event + failure, data = d_gf_test)
summary(dplyr::select(d_gf_test, event, stime, time))

prior <- get_priors(data = d_gf_test,
                    time_failure = "stime",
                    failure = "failure",
                    horizon = horizon,
                    interval = 1/4)
.outcomes <- get_outcome_type(data = d_gf_test,
                              outcomes = outcomes)

# One set of dynamic predictions per landmark time in `lm`.
# FIX: the order of the random effects formula should be similar to the random
# effects sampling. Incorporate the sampling into the function call.
predictions_test <- lapply(lm, function(landmark_time) {
  mmm_predictions(data = d_gf_test,
                  outcomes = .outcomes,
                  fixed_formula_nofail = fixed_formula_nofail,
                  random_formula_nofail = "~ time | transnr",
                  random_effects_nofail = re_samples_nofail,
                  parameters_nofail = model_nofail$estimates,
                  fixed_formula_fail = fixed_formula_fail,
                  random_formula_fail = "~ time | transnr",
                  random_effects_fail = re_samples_fail,
                  parameters_fail = model_fail$estimates,
                  time = "time",
                  failure = "failure",
                  failure_time = "stime",
                  prior = prior,
                  id = "transnr",
                  landmark = landmark_time,
                  horizon = horizon,
                  interval = 1/4)
})

# Example dynamic predictions for test subject "2774", one panel per landmark.
# predictions_test holds one element per landmark in `lm`, so it is indexed by
# position: `lm` contains fractional times (1, 1.25, ...), and a fractional
# subscript such as [[1.25]] silently truncates to [[1]].
predictions_test_plots <- lapply(seq_along(lm), function(i) {
  plot_predictions(predictions = predictions_test[[i]],
                   outcomes = outcomes,
                   id = "transnr",
                   subject = "2774",
                   time = "time")
})
predictions_test_plots <- marrangeGrob(predictions_test_plots,
                                       ncol = 1,
                                       nrow = 2,
                                       top = paste("Predicted probability of graft failure before", horizon, "years follow-up"),
                                       bottom = "Follow-up time [years]")
pdf(paste0("plots/predicted-risk-of-gf-at-", horizon, ".pdf"),
    width = 15,
    height = 12)
# The arranged pages are only drawn when printed; print() explicitly so the
# PDF is filled even when this script is run via source().
print(predictions_test_plots)
dev.off()

# Model performance in training data -----------------------------------------------

# Discrimination in training data: one AUC row per landmark in `lm`
roc_auc_train <- lapply(seq_along(lm), function(l) {
  get_auc(
    predictions = predictions_train[[l]],
    prediction_landmark = lm[l],
    prediction_horizon = horizon,
    id = "transnr",
    failure_time = "stime",
    failure = "failure")
})
roc_auc_train <- do.call(rbind, roc_auc_train)
write.csv(roc_auc_train,
          file = paste0("output/rocauc-gf-train-at-", horizon, ".csv"))

# Calibration in training data at landmarks 1, 2 and 3 years
plot_calibration_train <- lapply(which(lm %in% 1:3), function(l) {
  plot_calibration(
    predictions = predictions_train[[l]],
    prediction_landmark = lm[l],
    prediction_horizon = horizon,
    id = "transnr",
    failure_time = "stime",
    failure = "failure")
})
plot_calibration_train <- do.call(rbind, plot_calibration_train)
plot_calibration_train <- ggplot(
  plot_calibration_train,
  aes(x = Pred, y = Obs)) +
  geom_smooth(
    aes(color = factor(landmark),
        fill = factor(landmark)),
    method = "loess",
    span = 0.3,
    alpha = 0.2) +
  geom_abline(intercept = 0, slope = 1, alpha = 0.5) +
  geom_rug(aes(x = Pred)) +
  scale_y_continuous(breaks = seq(0, 1, 0.2)) +
  scale_x_continuous(breaks = seq(0, 1, 0.2)) +
  coord_cartesian(xlim = c(0, 1),
                  ylim = c(0, 1)) +
  labs(title = "",
       caption = paste("Calibration plots for kidney graft failure at", horizon, "years follow-up in training data"),
       y = "Observed risks",
       x = "Predicted risks",
       fill = "Landmark") +
  guides(color = "none") +
  ggthemes::theme_tufte(ticks = TRUE)

# File name follows the same "...-at-<horizon>.pdf" pattern as the test plot
pdf(paste0("plots/calibration-gf-train-at-", horizon, ".pdf"),
    width = 7,
    height = 5.6)
# ggplot objects are only drawn when printed; print() explicitly so the PDF
# is not left empty when this script is run via source().
print(plot_calibration_train)
dev.off()

# Model performance in test data ------------------------------------------

# Discrimination in test data: one AUC row per landmark in `lm`
roc_auc_test <- lapply(seq_along(lm), function(l) {
  get_auc(
    predictions = predictions_test[[l]],
    prediction_landmark = lm[l],
    prediction_horizon = horizon,
    id = "transnr",
    failure_time = "stime",
    failure = "failure")
})
roc_auc_test <- do.call(rbind, roc_auc_test)
write.csv(roc_auc_test,
          file = paste0("output/rocauc-gf-test-at-", horizon, ".csv"))

# Calibration in test data at landmarks 1, 2 and 3 years
plot_calibration_test <- lapply(which(lm %in% 1:3), function(l) {
  plot_calibration(
    predictions = predictions_test[[l]],
    prediction_landmark = lm[l],
    prediction_horizon = horizon,
    id = "transnr",
    failure_time = "stime",
    failure = "failure")
})
plot_calibration_test <- do.call(rbind, plot_calibration_test)
plot_calibration_test <- ggplot(
  plot_calibration_test,
  aes(x = Pred, y = Obs)) +
  geom_smooth(
    aes(color = factor(landmark),
        fill = factor(landmark)),
    method = "loess",
    span = 0.3,
    alpha = 0.2) +
  geom_abline(intercept = 0, slope = 1, alpha = 0.5) +
  geom_rug(aes(x = Pred)) +
  scale_y_continuous(breaks = seq(0, 1, 0.2)) +
  scale_x_continuous(breaks = seq(0, 1, 0.2)) +
  coord_cartesian(xlim = c(0, 1),
                  ylim = c(0, 1)) +
  labs(title = "",
       caption = paste("Calibration plots for kidney graft failure at", horizon, "years follow-up in test data"),
       y = "Observed risks",
       x = "Predicted risks",
       fill = "Landmark") +
  guides(color = "none") +
  ggthemes::theme_tufte(ticks = TRUE)

pdf(paste0("plots/calibration-gf-test-at-", horizon, ".pdf"),
    width = 7,
    height = 5.6)
# ggplot objects are only drawn when printed; print() explicitly so the PDF
# is not left empty when this script is run via source().
print(plot_calibration_test)
dev.off()

# Plot AUCs: training and test AUC per landmark with 95% normal-approximation bands
roc_auc <- bind_rows(
  roc_auc_train %>% mutate(label = "train"),
  roc_auc_test %>% mutate(label = "test")
)
plot_roc_auc <- ggplot(roc_auc, aes(x = landmark, y = auc, color = factor(label), fill = factor(label))) +
  geom_ribbon(aes(ymin = auc - qnorm(0.975) * auc_se,
                  ymax = auc + qnorm(0.975) * auc_se), alpha = 0.2) +
  geom_line() +
  # Chance-level reference line. The constant is set as a parameter: mapped in
  # aes() it would draw one identical line per row of roc_auc.
  geom_hline(yintercept = 0.5, linetype = "dotted") +
  scale_y_continuous(breaks = seq(0, 1, 0.2)) +
  coord_cartesian(ylim = c(0, 1)) +
  theme_few() +
  guides(color = "none") +
  labs(title = paste("Discrimination performance at", horizon, "years follow-up"),
       y = "Area under the ROC curve",
       x = "Landmark time [years]",
       fill = "Dataset")
# device = "png" with dpi lets ggsave() pass size (in inches) and resolution to png()
ggsave(filename = paste0("plots/rocauc-at-", horizon, ".png"),
       plot = plot_roc_auc,
       device = "png",
       width = 15,
       height = 12,
       units = "cm",
       dpi = 300)
pdf(paste0("plots/rocauc-at-", horizon, ".pdf"),
    width = 15,
    height = 12)
# print() explicitly so the PDF is not left empty when run via source()
print(plot_roc_auc)
dev.off()
# JanvandenBrand/highdimjm documentation built on Dec. 18, 2021, 12:32 a.m.