R/utils_plot_survHE.R

Defines functions plot_ggplot_expertsurv make_data_surv make_surv_curve_plot

#' Plot the survival curves using \code{ggplot2}
#'
#' Builds the \code{ggplot2} graph of the survival curves for one or more
#' \code{expertsurv} objects. It can optionally overlay the Kaplan Meier
#' estimate, an annotation of the observed vs extrapolated period and the
#' densities of the expert opinions.
#'
#' @param exArgs list of extra options passed to \code{plot.survHE}. Every
#' element inheriting from class \code{expertsurv} is treated as a model
#' object to plot. The first such object supplies the default times, the
#' annotation cutoff and the expert-opinion inputs. The remaining named
#' elements are options:
#' \code{t} = the vector of times at which to compute the curves (default =
#' the sorted unique times of the KM estimate of the first object)
#' \code{nsim} = the number of simulations used to compute the curves
#' (default = 1)
#' \code{mods} = the indices of the models to plot, counted across all the
#' objects (default = all of them)
#' \code{newdata} = a list (of lists) providing the values for the relevant
#' covariates. If NULL, then will use the mean values for the covariates if
#' at least one is a continuous variable, or the combination of the
#' categorical covariates
#' \code{xlab} = a string with the label for the x-axis (default = "Time")
#' \code{ylab} = a string with the label for the y-axis (default = "Survival")
#' \code{main} = a string with the title of the graph
#' \code{lab.profile} = a (vector of) string(s) indicating the labels
#' associated with the strata defining the different survival curves to plot
#' \code{colour} = a vector of characters defining the colours in which to
#' plot the different survival curves
#' \code{lab.model} = a vector of characters defining the names of the models
#' fitted
#' \code{add.km} = FALSE (whether to also add the Kaplan Meier estimates of
#' the data)
#' \code{annotate} = FALSE (whether to also add text to highlight the observed
#' vs extrapolated data)
#' \code{plot_ci} = FALSE (whether to plot the statistical uncertainty, when
#' it is available)
#' \code{ci_plot_ribbon} = FALSE (plot the uncertainty as a ribbon rather than
#' as dashed lines)
#' \code{plot_opinion} = FALSE (whether to overlay the expert opinion densities;
#' only done when the opinion was elicited on the survival scale)
#' \code{scale_expert_plot} = overrides the argument of the same name
#' \code{legend.position} = a vector of proportions to place the legend. Default
#' to 'c(.75,.78)', which means 75% across the x-axis and 78% across the y-axis
#' \code{legend.title} = suitable instructions to format the title of the legend;
#' defaults to 'element_text(size=15,face="bold")' but there may be other
#' arguments that can be added (using 'ggplot' facilities)
#' \code{legend.text} = suitable instructions to format the text of the legend;
#' defaults to 'element_text(colour="black", size=14, face="plain")' but there
#' may be other arguments that can be added (using 'ggplot' facilities)
#' An \code{xlim} entry is not read by this function.
#' @param scale_expert_plot Maximum width of each expert-opinion density,
#' expressed as a fraction of the smallest gap between consecutive expert
#' time points (and the right end of the x-axis). Default 0.4. It is
#' overridden by \code{exArgs$scale_expert_plot} when that is present.
#' @return The \code{ggplot2} object with the graph.
#' @note Something will go here
#' @author Gianluca Baio
#' @seealso Something will go here
#' @references Something will go here
#' @keywords Parametric survival models
#' @import ggplot2
#' @examples
#'
#' data(bc)
#'
#' mle = fit.models(formula=Surv(recyrs,censrec)~group,data=bc,
#'     distr="exp",method="mle")
#' plot(mle)
#' @noRd
plot_ggplot_expertsurv <- function(exArgs, scale_expert_plot = 0.4) {
  # Accessors for the optional, named entries of 'exArgs'
  has_arg <- function(name) name %in% names(exArgs)
  get_arg <- function(name, default) if (has_arg(name)) exArgs[[name]] else default

  # The model objects are the elements inheriting from 'expertsurv'.
  # 'inherits' keeps the indices aligned even when an object carries more
  # than one class, and it copes with an empty 'exArgs'.
  w <- which(vapply(exArgs, inherits, logical(1), what = "expertsurv"))
  if (length(w) == 0) {
    stop("Please give at least one 'survHE' object, generated by a call to 'fit.models(...)")
  }
  survHE_objs <- exArgs[w]
  names(survHE_objs) <- names(exArgs)[w]
  # Number of models stored in each object
  n_models <- vapply(survHE_objs, function(m) length(m$models), integer(1))

  t <- get_arg("t", sort(unique(survHE_objs[[1]]$misc$km$time)))
  newdata <- get_arg("newdata", NULL)
  nsim <- get_arg("nsim", 1)
  mods <- get_arg("mods", seq_len(sum(n_models)))
  add.km <- get_arg("add.km", FALSE)
  # Not called 'annotate', so that it does not shadow ggplot2::annotate()
  add_annotation <- get_arg("annotate", FALSE)
  plot_ci <- get_arg("plot_ci", FALSE)
  ci_plot_ribbon <- get_arg("ci_plot_ribbon", FALSE)
  plot_opinion <- get_arg("plot_opinion", FALSE)

  # One row per (object, model) pair; 'mods' selects among these rows
  obj <- mod <- NULL
  all_models <- tibble(obj = rep(names(survHE_objs), n_models),
                       mod = unlist(lapply(n_models, seq_len), use.names = FALSE)) %>%
    slice(mods) %>%
    arrange(obj)

  sel_mods <- unique(match(all_models$obj, names(survHE_objs)))
  toplot <- lapply(sel_mods, function(i) {
    obj_name <- names(survHE_objs)[i]
    make_data_surv(survHE_objs[[i]],
                   mods = all_models %>% filter(obj == obj_name) %>% pull(mod),
                   nsim = nsim, t = t, newdata = newdata, add.km = add.km)[[1]] %>%
      mutate(object_name = as.factor(obj_name))
  }) %>%
    bind_rows() %>%
    group_by(object_name, model_name) %>%
    mutate(mods_id = cur_group_id()) %>%
    ungroup()

  if (add.km) {
    datakm <- lapply(seq_along(survHE_objs), function(i) {
      make_data_surv(survHE_objs[[i]], mods = 1, nsim = 1, t = t, newdata = newdata, add.km = add.km)[[2]] %>%
        mutate(object_name = as.factor(names(survHE_objs)[i]))
    }) %>%
      bind_rows() %>%
      group_by(object_name, model_name) %>%
      mutate(mods_id = cur_group_id()) %>%
      ungroup()
  } else {
    datakm <- NULL
  }

  surv.curv <- make_surv_curve_plot(toplot, datakm, mods, plot_ci = plot_ci, ci_plot_ribbon = ci_plot_ribbon)

  # Optional scales and labels
  if (has_arg("lab.profile")) {
    surv.curv <- surv.curv + scale_linetype_manual(labels = exArgs$lab.profile, values = seq_along(exArgs$lab.profile))
  }
  if (has_arg("colour") && has_arg("lab.model")) {
    surv.curv <- surv.curv + scale_color_manual(labels = exArgs$lab.model, values = exArgs$colour)
  } else if (has_arg("colour")) {
    surv.curv <- surv.curv + scale_color_manual(values = exArgs$colour)
  } else if (has_arg("lab.model")) {
    surv.curv <- surv.curv + scale_color_manual(values = seq_along(exArgs$lab.model), labels = exArgs$lab.model)
  }

  if (has_arg("xlab")) {
    surv.curv <- surv.curv + labs(x = exArgs$xlab)
  }
  if (has_arg("ylab")) {
    surv.curv <- surv.curv + labs(y = exArgs$ylab)
  }
  if (has_arg("main")) {
    surv.curv <- surv.curv + labs(title = exArgs$main) + theme(plot.title = element_text(size = 18, face = "bold"))
  }

  if (add_annotation) {
    # The observed period ends at the last KM time of the first object
    cutoff <- max(survHE_objs[[1]]$misc$km$time)
    surv.curv <- surv.curv +
      geom_segment(aes(x = cutoff, y = -Inf, xend = cutoff, yend = -0.01), size = 0.9) +
      geom_segment(aes(x = cutoff, y = -0.01, xend = cutoff * 0.85, yend = -0.01), arrow = arrow(length = unit(0.25, "cm"), type = "closed"), size = 1.1) +
      geom_segment(aes(x = cutoff, y = -0.01, xend = cutoff * 1.15, yend = -0.01), arrow = arrow(length = unit(0.25, "cm"), type = "closed"), size = 1.1) +
      annotate(geom = "text", x = cutoff, y = -Inf, hjust = 1.1, vjust = -1, label = "Observed data", size = 5) +
      annotate(geom = "text", x = cutoff, y = -Inf, hjust = -0.1, vjust = -1, label = "Extrapolation", size = 5) +
      ylim(-0.01, 1) +
      geom_rect(data = data.frame(xmin = -Inf, xmax = cutoff, ymin = -Inf, ymax = Inf), aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax), fill = "grey", alpha = 0.1)
  } else {
    surv.curv <- surv.curv + ylim(0, 1)
  }

  if (has_arg("legend.position")) {
    surv.curv <- surv.curv + theme(legend.position = exArgs$legend.position)
  }
  if (has_arg("legend.title")) {
    surv.curv <- surv.curv + theme(legend.title = exArgs$legend.title)
  }
  if (has_arg("legend.text")) {
    surv.curv <- surv.curv + theme(legend.text = exArgs$legend.text)
  }

  # Expert-opinion inputs come from the first model object, whatever name it
  # was given (previously hard-coded to 'exArgs$Object1', which is NULL when
  # the user names the object). isTRUE() keeps a missing 'opinion_type' from
  # breaking the 'if'.
  input_args <- survHE_objs[[1]]$misc$input_args

  if (isTRUE(input_args$opinion_type == "survival") && isTRUE(as.logical(plot_opinion))) {
    df_data_bind <- NULL

    if (all(vapply(input_args$param_expert, nrow, integer(1)) == 1)) {
      # A single distribution per time point: evaluate its density directly
      x_seq <- seq(0.01, 0.98, by = 0.01)
      for (i in seq_along(input_args$param_expert)) {
        dens_df_temp <- input_args$param_expert[[i]]
        dens_temp <- get_density(dist = dens_df_temp$dist, dens_df_temp$param1, dens_df_temp$param2, dens_df_temp$param3, x = x_seq, St_indic = 1)
        df_temp <- data.frame(x = x_seq, fx = dens_temp, expert = input_args$pool_type, ftype = input_args$pool_type, times_expert = input_args$times_expert[i])
        df_data_bind <- rbind(df_data_bind, df_temp)
      }
    } else {
      # Several experts per time point: reuse the data built by plot_expert_opinion()
      for (i in seq_along(input_args$param_expert)) {
        expert_data <- plot_expert_opinion(input_args$param_expert[[i]])$data
        df_data_bind <- rbind(df_data_bind, cbind(expert_data, times_expert = input_args$times_expert[i]))
      }
    }

    # Respect the function argument; an explicit option in 'exArgs' wins
    if (has_arg("scale_expert_plot")) {
      scale_expert_plot <- exArgs$scale_expert_plot
    }

    # Each density is drawn sideways at its time point. Its widest point is
    # scale_expert_plot times the smallest gap between consecutive expert
    # times (including the right end of the x-axis).
    xlim_vec <- ggplot_build(surv.curv)$layout$panel_params[[1]]$x.range
    max_time_x <- max(xlim_vec)
    min_diff_x <- min(diff(c(input_args$times_expert, max_time_x)))
    df_data_final <- dplyr::filter(df_data_bind, ftype == input_args$pool_type)
    df_scaling <- df_data_final %>%
      group_by(times_expert) %>%
      summarize(scaling_fac = min_diff_x * scale_expert_plot / max(fx))
    df_data_final2 <- df_data_final %>%
      left_join(df_scaling, by = "times_expert") %>%
      mutate(fx_final = fx * scaling_fac + times_expert)

    surv.curv <- surv.curv +
      geom_ribbon(data = df_data_final2, aes(x = x, y = x, xmin = fx_final, xmax = times_expert, group = times_expert), fill = "sky blue", alpha = 0.5, colour = "grey")
  }

  surv.curv
}

#' Make the dataset to be used by \code{ggplot2} to plot the survival curves
#'
#' @param x The 'survHE' object
#' @param mods The models to be considered (indices into \code{x$models})
#' @param nsim The number of simulations to generate
#' @param t The vector of times (default: the sorted unique times of the KM
#' estimate stored in \code{x$misc$km})
#' @param newdata The list of "new" covariate profiles
#' @param add.km Should the KM estimate be plotted too?
#' @return \item{out}{A list. Its first element is the dataset with the
#' survival curves: the rows returned by \code{make.surv} for every selected
#' model and covariate profile, plus the columns \code{strata} and
#' \code{model_name}. When \code{add.km} is TRUE, the list also has an
#' element \code{datakm} with the Kaplan Meier data (\code{t},
#' \code{n.risk}, \code{n.event}, \code{n.censor}, \code{S}, \code{lower},
#' \code{upper}, \code{model_name}, \code{strata}).}
#' @note Something will go here
#' @author Gianluca Baio
#' @keywords Parametric survival models
#' @noRd
make_data_surv <- function(x, mods = seq_along(x$models), nsim = 1, t = NULL, newdata = NULL, add.km = FALSE) {
  if (is.null(t)) {
    t <- sort(unique(x$misc$km$time))
  }
  # One 'make.surv' result per selected model: s[[i]] belongs to mods[i]
  s <- lapply(mods, function(i) {
    make.surv(x, mod = i, t = t, nsim = nsim, newdata = newdata)
  })
  # For every model, a one-column tibble 'strata' with one label per row of
  # its design matrix, of the form "cov1=value1,cov2=value2". The intercept
  # column is dropped and values are rounded to 2 digits. A design with no
  # covariates left yields the label "=".
  # NOTE(review): the 2nd argument of matches() is 'ignore.case', and
  # "(Intercept)" is a regex group (it matches any name containing
  # "Intercept"). Kept as-is to preserve behaviour; confirm intent.
  strata <- lapply(seq_along(s), function(i) {
    des_mat <- s[[i]]$des.mat
    lapply(seq_len(nrow(des_mat)), function(row_idx) {
      # '!!' injects the row index, so a covariate column that happens to
      # share the index variable's name cannot be picked up by slice()
      des_mat %>% as_tibble() %>% select(-matches("(Intercept)",everything())) %>% slice(!!row_idx) %>%
        round(digits=2) %>% mutate(strata=paste0(names(.),"=",.,collapse=","))
    }) %>% bind_rows(.) %>% select(strata)
  })

  # Attach the profile label and the model name to every simulated curve
  toplot <- lapply(seq_along(mods), function(i) {
    model_label <- as.factor(names(x$models)[mods[i]])
    lapply(seq_along(s[[i]]$S), function(j) {
      s[[i]]$S[[j]] %>% bind_cols(strata = as.factor(as.character(strata[[i]][j, ])), model_name = model_label)
    })
  }) %>% bind_rows(.)
  out <- list(toplot)

  # Add the data for the KM curve?
  if (isTRUE(as.logical(add.km))) {
    # If the number of strata in the KM computed in 'fit.models' is not the same as the
    # number of rows in the design matrix from 'make.surv', then re-do a KM with no covariates
    if (length(x$misc$km$strata) != nrow(s[[1]]$des.mat)) {
      x$misc$km <- rms::npsurv(update(x$misc$formula, ~1), data = x$misc$data)
      # deparse() may split a long formula over several strings; paste them
      # back together so that as.formula() always receives a single string
      x$misc$km$call$formula <- as.formula(paste(deparse(update(x$misc$formula, ~1)), collapse = " "))
    }
    # Now uses info in the KM table in the survHE object to create a dataset to plot
    datakm <- bind_cols(t = x$misc$km$time, n.risk = x$misc$km$n.risk, n.event = x$misc$km$n.event,
                        n.censor = x$misc$km$n.censor, S = x$misc$km$surv, lower = x$misc$km$lower,
                        upper = x$misc$km$upper) %>% mutate(model_name = "Kaplan Meier")
    # If 'strata' is not in the KM object (will happen if there's only 1)
    if (is.null(x$misc$km$strata)) {
      datakm$strata <- as.factor("all")
    } else {
      # km$strata holds the number of rows per stratum: label rows 1, 2, ...
      datakm$strata <- as.factor(rep(seq_along(x$misc$km$strata), x$misc$km$strata))
    }
    out$datakm <- datakm
  }
  # Returns the output as a list with the dataset(s) to plot
  out
}


#' Make the actual \code{ggplot2} plot with the survival curves
#'
#' @param toplot The dataset with the relevant data. Columns used here:
#' \code{t}, \code{S}, \code{strata}, \code{model_name}, \code{object_name},
#' and \code{low}/\code{upp} when the uncertainty is available.
#' @param datakm The dataset with the (optional) data for the KM estimate
#' (columns \code{t}, \code{S}, \code{lower}, \code{upper}, \code{strata}),
#' or NULL to omit it
#' @param mods The models to be plotted (a vector of numbers); only its length
#' is used here, to choose the legend title
#' @param plot_ci Plot the statistical uncertainty? If FALSE and n_sim > 1, it will plot the average of the survival estimates, rather than the survival at the estimated parameter values. For MLE approach this should be very similar (due to asymptotic normality), however, for the Bayesian approach the survival at the posterior mean parameters may be different to the average survival from the posterior (although usually not the case).
#' @param ci_plot_ribbon Plot the statistical uncertainty as a ribbon (TRUE) or dashed lines (FALSE)
#' @return \item{surv.curv}{The \code{ggplot2} object with the graph}
#' @import ggplot2
#' @note Something will go here
#' @author Gianluca Baio
#' @keywords Parametric survival models
#' @noRd
make_surv_curve_plot <- function(toplot, datakm = NULL, mods, plot_ci = TRUE, ci_plot_ribbon = FALSE) {
  # Intercept-only models carry the strata label "=" on every row: there is
  # no covariate profile to distinguish, so the linetype aesthetic is dropped.
  # Otherwise 'strata' gives one curve per profile.
  linetype <- if (all(toplot$strata == "=")) NULL else toplot$strata
  # One object: colour/group by model. Several objects: colour/group by
  # object and model, so identically named models stay separate.
  single_object <- length(levels(toplot$object_name)) == 1

  surv.curv <- ggplot()
  if (single_object) {
    surv.curv <- surv.curv +
      geom_line(data = toplot, aes(x = t, y = S, group = model_name:strata, col = model_name, linetype = linetype), size = .9)
  } else {
    surv.curv <- surv.curv +
      geom_line(data = toplot, aes(x = t, y = S, group = model_name:strata:object_name, col = object_name:model_name, linetype = linetype), size = .9)
  }
  surv.curv <- surv.curv +
    theme_bw() +
    theme(axis.text.x = element_text(color = "black", size = 12, angle = 0, hjust = .5, vjust = .5),
          axis.text.y = element_text(color = "black", size = 12, angle = 0, hjust = .5, vjust = .5),
          axis.title.x = element_text(color = "black", size = 14, angle = 0, hjust = .5, vjust = .5),
          axis.title.y = element_text(color = "black", size = 14, angle = 90, hjust = .5, vjust = .5)) +
    theme(axis.line = element_line(colour = "black"),
          panel.background = element_blank(),
          panel.border = element_blank(),
          plot.title = element_text(size = 18, face = "bold")) +
    theme(legend.position = c(.75, .78),
          legend.title = element_text(size = 15, face = "bold"),
          legend.text = element_text(colour = "black", size = 14, face = "plain"),
          legend.background = element_blank()) +
    labs(y = "Survival", x = "Time", title = NULL,
         color = if (length(mods) == 1) "Model" else "Models",
         linetype = "Profile") +
    # This ensures that the model legend is always before the profile legend
    guides(color = guide_legend(order = 1),
           linetype = guide_legend(order = 2))

  # Uncertainty bands are only drawn when both bounds are in the data
  has_ci <- all(c("low", "upp") %in% names(toplot))
  if (has_ci && plot_ci) {
    if (ci_plot_ribbon) {
      # Group ribbons exactly like the curves. Otherwise, with several
      # objects, ribbons of models sharing a name are merged into one shape.
      if (single_object) {
        surv.curv <- surv.curv +
          geom_ribbon(data = toplot, aes(x = t, y = S, ymin = low, ymax = upp, group = model_name:strata), alpha = .2)
      } else {
        surv.curv <- surv.curv +
          geom_ribbon(data = toplot, aes(x = t, y = S, ymin = low, ymax = upp, group = model_name:strata:object_name), alpha = .2)
      }
    } else if (single_object) {
      surv.curv <- surv.curv +
        geom_line(data = toplot, aes(x = t, y = low, group = model_name:strata, col = model_name), size = .9, linetype = "dashed") +
        geom_line(data = toplot, aes(x = t, y = upp, group = model_name:strata, col = model_name), size = .9, linetype = "dashed")
    } else {
      surv.curv <- surv.curv +
        geom_line(data = toplot, aes(x = t, y = low, group = model_name:strata:object_name, col = object_name:model_name), size = .9, linetype = "dashed") +
        geom_line(data = toplot, aes(x = t, y = upp, group = model_name:strata:object_name, col = object_name:model_name), size = .9, linetype = "dashed")
    }
  }

  # Add KM plot?
  if (!is.null(datakm)) {
    surv.curv <- surv.curv +
      geom_step(data = datakm, aes(t, S, group = as.factor(strata)), color = "darkgrey") +
      geom_ribbon(data = datakm, aes(x = t, y = S, ymin = lower, ymax = upper, group = as.factor(strata)), alpha = .2)
  }
  surv.curv
}

Try the expertsurv package in your browser

Any scripts or data that you put into this service are public.

expertsurv documentation built on April 3, 2025, 10:37 p.m.