# R/utils_plot_survHE.R

# Defines functions plot_ggplot_survHE, make_data_surv, make_surv_curve_plot and plot_base_survHE

#' Plot the survival curves using \code{ggplot2}
#' 
#' Extracts the \code{survHE} objects from \code{exArgs}, builds the dataset
#' with the survival curves of the selected models (and, optionally, the Kaplan
#' Meier estimates) and returns the resulting \code{ggplot2} graph, to which the
#' optional graphical settings found in \code{exArgs} are applied.
#' 
#' @param exArgs named list of the arguments passed to \code{plot.survHE}. It
#' must include at least one \code{survHE} object (as generated by a call to
#' \code{fit.models}). The other (optional) elements are:
#' \code{t} = the vector of times at which to compute the curves (default = the
#' unique times of the Kaplan Meier estimate stored in the first \code{survHE}
#' object); \code{mods} = the indices of the models to plot, counted
#' sequentially across all the \code{survHE} objects (default = all the models);
#' \code{nsim} = the number of simulations from the distribution of the survival
#' curves (default = 1); \code{newdata} = a list (of lists) providing the values
#' for the relevant covariates. If NULL (default), then will use the mean values
#' for the covariates if at least one is a continuous variable, or the
#' combination of the categorical covariates; \code{add.km} = whether to also
#' add the Kaplan Meier estimates of the data (default = FALSE);
#' \code{what} = the scale of the plot, one of "survival" (default), "hazard"
#' (computed as the numerical derivative of -log(Survival)) or "cumhazard";
#' \code{annotate} = whether to also add text to highlight the observed vs
#' extrapolated data (default = FALSE);
#' \code{xlab} = a string with the label for the x-axis (default = "Time");
#' \code{ylab} = a string with the label for the y-axis (default = "Survival",
#' "Hazard" or "Cumulative hazard", depending on \code{what});
#' \code{main} = a string with the title of the plot;
#' \code{lab.profile} = a (vector of) string(s) with the labels associated with
#' the covariate profiles (strata) defining the different survival curves;
#' \code{lab.model} = a vector of strings with the labels of the models;
#' \code{colour} = a vector of colours in which to plot the different models;
#' \code{legend.position} = the position of the legend, e.g. a vector of
#' proportions. Default to 'c(.75,.78)', which means 75\% across the x-axis and
#' 78\% across the y-axis;
#' \code{legend.title} = suitable instructions to format the title of the legend;
#' defaults to 'element_text(size=15,face="bold")' but there may be other 
#' arguments that can be added (using 'ggplot' facilities);
#' \code{legend.text} = suitable instructions to format the text of the legend;
#' defaults to 'element_text(colour="black", size=14, face="plain")' but there 
#' may be other arguments that can be added (using 'ggplot' facilities)
#' @return A \code{ggplot2} object with the graph
#' @note Something will go here
#' @author Gianluca Baio
#' @seealso Something will go here
#' @references Something will go here
#' @keywords Parametric survival models
#' @examples
#' 
#' data(bc)
#' 
#' mle = fit.models(formula=Surv(recyrs,censrec)~group,data=bc,
#'     distr="exp",method="mle")
#' plot(mle)
#' @noRd 
plot_ggplot_survHE <- function(exArgs) {
  # Helpers to query the optional settings stored (by name) in 'exArgs'
  has_opt <- function(name) name %in% names(exArgs)
  get_opt <- function(name, default = NULL) {
    if (has_opt(name)) exArgs[[name]] else default
  }

  # Extracts the 'survHE' objects from 'exArgs'; if there are none, stops with
  # an error message. 'inherits()' is used because 'class()' can return more
  # than one class per object, which would misalign the indices from 'which()'
  is_survHE <- vapply(exArgs, function(obj) inherits(obj, "survHE"), logical(1))
  w <- unname(which(is_survHE))
  if (length(w) == 0) {
    stop("Please give at least one 'survHE' object, generated by a call to 'fit.models(...)",
         call. = FALSE)
  }
  # List subsetting keeps the names given to the objects in 'exArgs'
  survHE_objs <- exArgs[w]

  # Basic inputs: if not specified by the user, sets up the defaults used by
  # the helper functions
  # t = times to plot on the x-axis (default = the times in the first object's KM)
  t <- get_opt("t", sort(unique(survHE_objs[[1]]$misc$km$time)))
  # newdata = possible list with a profile of covariates
  newdata <- get_opt("newdata", NULL)
  # nsim = 1 for a single survival curve, > 1 for the PSA
  nsim <- get_opt("nsim", 1)
  # mods = models to plot, indexed sequentially across all the objects
  n_models <- vapply(survHE_objs, function(obj) length(obj$models), integer(1))
  mods <- get_opt("mods", seq_len(sum(n_models)))
  # Should the KM curve be added to the plots?
  add.km <- get_opt("add.km", FALSE)
  # Should the graph be annotated with extrapolation vs observed data? (not
  # called 'annotate' to avoid masking 'ggplot2::annotate', used below)
  add_annotation <- get_opt("annotate", FALSE)
  # Scale of the plot: 'survival' (default), 'hazard' or 'cumhazard'
  what <- get_opt("what", "survival")

  # Selects only the relevant models based on the choice indicated by the user.
  # 'all_models' has one row per (object, model) pair; 'obj' and 'mod' are
  # initialised to avoid binding notes
  obj <- mod <- NULL
  all_models <- tibble(
    obj = rep(names(survHE_objs), times = n_models),
    mod = unlist(lapply(n_models, seq_len), use.names = FALSE)
  ) %>%
    slice(mods) %>%
    arrange(obj)

  # Index of *only* the objects with at least one selected model
  sel_mods <- unique(match(all_models$obj, names(survHE_objs)))

  # Dataset to plot, including *only* the objects and models selected
  toplot <- lapply(sel_mods, function(i) {
    make_data_surv(survHE_objs[[i]],
                   mods = all_models %>% filter(obj == names(survHE_objs)[i]) %>% pull(mod),
                   nsim = nsim,
                   t = t,
                   newdata = newdata,
                   add.km = add.km)[[1]] |>
      mutate(object_name = as.factor(names(survHE_objs)[i]))
  }) |>
    bind_rows() |>
    group_by(object_name, model_name) |>
    mutate(mods_id = cur_group_id()) |>
    ungroup()

  # KM data (computed for every object, using its first model only)
  datakm <- NULL
  if (isTRUE(as.logical(add.km))) {
    datakm <- lapply(seq_along(survHE_objs), function(i) {
      make_data_surv(survHE_objs[[i]],
                     mods = 1,
                     nsim = 1,
                     t = t,
                     newdata = newdata,
                     add.km = add.km)[[2]] |>
        mutate(object_name = as.factor(names(survHE_objs)[i]))
    }) |>
      bind_rows() |>
      group_by(object_name, model_name) |>
      mutate(mods_id = cur_group_id()) |>
      ungroup()
  }

  # Now makes the plot using the helper function
  surv.curv <- make_surv_curve_plot(toplot, datakm, mods, what = what)

  # Optional labels for the covariate profiles
  if (has_opt("lab.profile")) {
    surv.curv <- surv.curv +
      scale_linetype_manual(labels = exArgs$lab.profile,
                            values = seq_along(exArgs$lab.profile))
  }
  # Optional colours and/or labels for the models
  has_colour <- has_opt("colour")
  has_lab_model <- has_opt("lab.model")
  if (has_colour && has_lab_model) {
    surv.curv <- surv.curv + scale_color_manual(labels = exArgs$lab.model, values = exArgs$colour)
  } else if (has_colour) {
    surv.curv <- surv.curv + scale_color_manual(values = exArgs$colour)
  } else if (has_lab_model) {
    surv.curv <- surv.curv +
      scale_color_manual(values = seq_along(exArgs$lab.model), labels = exArgs$lab.model)
  }
  if (has_opt("xlab")) {
    surv.curv <- surv.curv + labs(x = exArgs$xlab)
  }
  if (has_opt("ylab")) {
    surv.curv <- surv.curv + labs(y = exArgs$ylab)
  }
  if (has_opt("main")) {
    surv.curv <- surv.curv + labs(title = exArgs$main) +
      theme(plot.title = element_text(size = 18, face = "bold"))
  }

  ymax <- 1
  if (isTRUE(as.logical(add_annotation))) {
    # Marks the end of the observed data (last KM time of the first object)
    cutoff <- max(survHE_objs[[1]]$misc$km$time)
    surv.curv <- surv.curv +
      geom_segment(aes(x = cutoff, y = -Inf, xend = cutoff, yend = -.01), size = 0.9) +
      geom_segment(aes(x = cutoff, y = -.01, xend = cutoff * .85, yend = -.01),
                   arrow = arrow(length = unit(.25, "cm"), type = "closed"), size = 1.1) +
      geom_segment(aes(x = cutoff, y = -.01, xend = cutoff * 1.15, yend = -.01),
                   arrow = arrow(length = unit(.25, "cm"), type = "closed"), size = 1.1) +
      annotate(geom = "text", x = cutoff, y = -Inf, hjust = 1.1, vjust = -1,
               label = "Observed data", size = 5) +
      annotate(geom = "text", x = cutoff, y = -Inf, hjust = -0.1, vjust = -1,
               label = "Extrapolation", size = 5)
    # Constrains the y-axis to [0-1] only if the required plot is the survival curve
    if (what == "survival") {
      surv.curv <- surv.curv + ylim(-0.01, ymax)
    }
    # Shades the region of the observed data
    surv.curv <- surv.curv +
      geom_rect(data = data.frame(xmin = -Inf, xmax = cutoff, ymin = -Inf, ymax = Inf),
                aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
                fill = "grey", alpha = .1)
  } else if (what == "survival") {
    # Constrains the y-axis to [0-1] only if the required plot is the survival curve
    surv.curv <- surv.curv + ylim(0, ymax)
  }

  # If the scale is 'hazard' then adds a caption to the plot to highlight the
  # fact that the hazard function is computed as a *numerical* derivative
  if (what == "hazard") {
    msg <- bquote(~bold(.("NB")): "The hazard function is computed as the" ~bold(.("numerical"))~"derivative of -log(Survival)")
    surv.curv <- surv.curv + labs(caption = msg)
  }
  if (has_opt("legend.position")) {
    surv.curv <- surv.curv + theme(legend.position = exArgs$legend.position)
  }
  if (has_opt("legend.title")) {
    surv.curv <- surv.curv + theme(legend.title = exArgs$legend.title)
  }
  if (has_opt("legend.text")) {
    surv.curv <- surv.curv + theme(legend.text = exArgs$legend.text)
  }
  # Tips for further customisation of the returned object:
  # to remove the profiles legend
  #surv.curv=surv.curv+guides(linetype=FALSE)
  # to modify the profile legend
  #surv.curv=surv.curv+scale_linetype_discrete(name="XXX",label=c("XX","YY","ZZ"))
  # to remove the models legend
  #surv.curv=surv.curv+guides(colour=FALSE)
  # to modify the models legend
  #surv.curv=surv.curv+scale_color_discrete(name="XXX",label=c("XX","YY","ZZ"))
  # +scale_linetype_manual(labels=c("Control","Treated"),values=c("dotdash","solid"))

  surv.curv
}


#' Make the dataset to be used by \code{ggplot2} to plot the survival curves
#' 
#' @param x The 'survHE' object
#' @param mods The indices (within \code{x$models}) of the models to be considered
#' @param nsim The number of simulations to generate
#' @param t The vector of times (default = the unique times of the KM estimate
#' stored in \code{x})
#' @param newdata The list of "new" covariate profiles
#' @param add.km Should the data for the KM estimate be computed too? Logical
#' @return A list. Its first element is a tibble with the survival curves of
#' each selected model and covariate profile (the columns returned by
#' \code{make.surv}, plus \code{strata} and \code{model_name}). If
#' \code{add.km} is TRUE, the list also contains the element \code{datakm},
#' with the data for the Kaplan Meier estimate.
#' @note Something will go here
#' @author Gianluca Baio
#' @keywords Parametric survival models
#' @noRd 
make_data_surv <- function(x, mods=1:length(x$models), nsim=1,
                           t=NULL, newdata=NULL, add.km=FALSE) {
  if (is.null(t)) {
    t <- sort(unique(x$misc$km$time))
  }
  # Survival curves for each of the selected models
  s <- lapply(mods, function(i) {
    make.surv(x, mod=i, t=t, nsim=nsim, newdata=newdata)
  })

  # One label per row (covariate profile) of each model's design matrix, in the
  # form "name=value,name=value" (values rounded to 2 digits) after dropping the
  # intercept column. Intercept-only models end up with the label "="
  strata <- lapply(seq_along(s), function(i) {
    lapply(seq_len(nrow(s[[i]]$des.mat)), function(r) {
      s[[i]]$des.mat %>%
        as_tibble() %>%
        select(-matches("(Intercept)")) %>%
        slice(r) %>%
        round(digits=2) %>%
        mutate(strata=paste0(names(.),"=",.,collapse=","))
    }) %>% bind_rows(.) %>%
      select(strata)
  })

  # Stacks the curves of every model and profile, tagging each with its
  # profile label and model name
  toplot <- lapply(seq_along(mods), function(i) {
    lapply(seq_along(s[[i]]$S), function(j) {
      s[[i]]$S[[j]] %>%
        bind_cols(strata = as.factor(as.character(strata[[i]][j,])),
                  model_name = as.factor(names(x$models)[mods[i]]))
    })
  }) %>% bind_rows(.)

  out <- list(toplot)

  # Add the data for the KM curve?
  if (isTRUE(as.logical(add.km))) {
    # If the number of strata in the KM computed in 'fit.models' is not the same as the 
    # number of rows in the design matrix from 'make.surv', then re-do a KM with no covariates
    if (length(x$misc$km$strata) != nrow(s[[1]]$des.mat)) {
      x$misc$km <- rms::npsurv(update(x$misc$formula,~1), data=x$misc$data)
      x$misc$km$call$formula <- as.formula(deparse(update(x$misc$formula,~1)))
    }
    # Now uses info in the KM table in the survHE object to create a dataset to plot
    datakm <- bind_cols(time = x$misc$km$time,
                        n.risk = x$misc$km$n.risk,
                        n.event = x$misc$km$n.event,
                        n.censor = x$misc$km$n.censor,
                        S = x$misc$km$surv,
                        lower = x$misc$km$lower,
                        upper = x$misc$km$upper) %>%
      mutate(model_name = "Kaplan Meier")

    # If 'strata' is not in the KM object (will happen if there's only 1)
    if (is.null(x$misc$km$strata)) {
      datakm$strata <- as.factor("all")
    } else {
      # 'km$strata' holds the number of rows of each stratum
      datakm$strata <- as.factor(rep(seq_along(x$misc$km$strata), x$misc$km$strata))
    }
    out$datakm <- datakm
  }
  # Returns the output as a list with the dataset(s) to plot
  out
}


#' Make the actual \code{ggplot2} plot with the survival curves
#' 
#' @param toplot The dataset with the survival curves (as built by
#' \code{make_data_surv}), including the \code{object_name} column
#' @param datakm The dataset with the (optional) data for the KM estimate
#' @param mods The models to be plotted (a vector of numbers); only its length
#' is used here, to choose "Model" or "Models" as the legend title
#' @param what The scale of the plot: "survival" (default), "hazard" or
#' "cumhazard"
#' @return A \code{ggplot2} object with the survival curves
#' @author Gianluca Baio
#' @keywords Parametric survival models
#' @noRd 
make_surv_curve_plot <- function(toplot, datakm=NULL, mods, what="survival") {
  # Does the model have covariates? Intercept-only models have the strata
  # label "=", in which case the linetype is not needed
  if (all(toplot$strata=="=")) {
    linetype <- NULL
  } else {
    # If it does have covariates then use 'strata' to plot a curve per profile
    linetype <- toplot$strata
  }

  # Backward finite difference of 'y' with respect to 'tt' between consecutive
  # rows (NA for the first row of each group)
  num_deriv <- function(y, tt) (y - lag(y)) / (tt - lag(tt))

  ylab <- "Survival"

  # Change the scale from the survival to the (approximated) hazard function,
  # computed as the numerical derivative of the cumulative hazard -log(S).
  # Also groups by 'object_name' so that curves of different 'survHE' objects
  # with the same model name are differenced separately
  if (what=="hazard") {
    toplot <- toplot %>%
      group_by(object_name, model_name, strata) %>%
      mutate(S = num_deriv(-log(S), t)) %>%
      ungroup()

    # If 'low' is a column of 'toplot' (=nsim>1) then rescale the lower and
    # upper end of the ribbon in the same way as the central estimate
    if ("low" %in% names(toplot)) {
      toplot <- toplot %>%
        group_by(object_name, model_name, strata) %>%
        mutate(low = num_deriv(-log(low), t),
               upp = num_deriv(-log(upp), t)) %>%
        ungroup()
    }
    ylab <- "Hazard"
  }
  # Change the scale from the survival to the cumulative hazard (row-wise,
  # so no grouping is needed)
  if (what=="cumhazard") {
    toplot <- toplot %>%
      mutate(S = -log(S)) %>%
      ungroup()
    if ("low" %in% names(toplot)) {
      toplot <- toplot %>%
        mutate(low = -log(low),
               upp = -log(upp))
    }
    ylab <- "Cumulative hazard"
  }

  surv.curv <- ggplot()
  # Am I plotting a single 'survHE' object?
  if (length(levels(toplot$object_name))==1) {
    surv.curv <- surv.curv +
      geom_line(data = toplot,
                aes(x=time, y=S, group=model_name:strata, col=model_name, linetype=linetype),
                size = 0.9)
  } else {
    surv.curv <- surv.curv +
      geom_line(data = toplot,
                aes(x = time, y=S, group=model_name:strata:object_name,
                    col=object_name:model_name, linetype=linetype),
                size = 0.9)
  }
  surv.curv <- surv.curv +
    theme_bw() + 
    theme(axis.text.x = element_text(color="black",size=12,angle=0,hjust=.5,vjust=.5),
          axis.text.y = element_text(color="black",size=12,angle=0,hjust=.5,vjust=.5),
          axis.title.x = element_text(color="black",size=14,angle=0,hjust=.5,vjust=.5),
          axis.title.y = element_text(color="black",size=14,angle=90,hjust=.5,vjust=.5)) +
    theme(axis.line = element_line(colour = "black"),
          panel.background = element_blank(),
          panel.border = element_blank(),
          plot.title = element_text(size=18, face="bold")) +
    theme(legend.position=c(.75,.78),
          legend.title=element_text(size=15,face="bold"),
          legend.text = element_text(colour="black", size=14, face="plain"),
          legend.background=element_blank()) +
    labs(y=ylab, x="Time", title=NULL,
         color=if (length(mods)==1) "Model" else "Models",
         linetype="Profile") + 
    # This ensures that the model legend is always before the profile legend
    guides(color=guide_legend(order=1),
           linetype=guide_legend(order=2))

  # If uses more than 1 simulation from distribution of survival curves, then
  # add the ribbon (one per object, model and profile, like the lines)
  if ("low" %in% names(toplot)) {
    surv.curv <- surv.curv +
      geom_ribbon(data = toplot,
                  aes(x=time, y=S, ymin=low, ymax=upp,
                      group=model_name:strata:object_name),
                  alpha = 0.2)
  }

  # Add KM plot? 
  if (!is.null(datakm)) {
    # One KM curve per stratum *and* per 'survHE' object, so that the KM
    # estimates of different objects are not merged into a single step line
    datakm$km_group <- if ("object_name" %in% names(datakm)) {
      interaction(datakm$object_name, datakm$strata, drop = TRUE)
    } else {
      as.factor(datakm$strata)
    }
    surv.curv <- surv.curv +
      geom_step(data = datakm, aes(x = time, y = S, group = km_group),
                color="darkgrey") + 
      geom_ribbon(data = datakm,
                  aes(x = time, y = S, ymin=lower, ymax=upper, group = km_group),
                  alpha = 0.2) 
  }
  surv.curv
}


# #' Make the KM curve plot using \code{survminer::ggsurvplot} 
# #' 
# #' @param x The KM fit object
# #' @param conf.int Should the confidence intervals be plotted too?
# #' @param risk.table Should the table with the number at risk plotted too?
# #' @param risk.table.col The name of the variable indexing the number at risk 
# #' @param palette The colouring scheme
# #' @param legend.labs The vector of legend labels
# #' @return \item{kmplot}{The \code{ggplot2} plot}
# #' @note Something will go here
# #' @author Gianluca Baio
# #' @keywords Parametric survival models
# #' @noRd 
#make_KM_plot <- function(x,conf.int=TRUE,risk.table=TRUE,risk.table.col="strata",
#                         palette=NULL,legend.labs=NULL) {
#  # Helper function to plot the KM curve 
#  if(is.null(palette)) {
#    if (length(x$misc$km$strata)>0) {
#      greyscale=colorRampPalette(c("grey20","grey80"))
#      palette=greyscale(length(x$misc$km$strata))
#    } else {
#      palette="black"
#    }
#  }
#  kmplot <- survminer::ggsurvplot(fit=x$misc$km,data=x$misc$data,
#                                  conf.int=TRUE,
#                                  risk.table = risk.table,        
#                                  ##risk.table.col = "strata",
#                                  palette = palette,
#                                  legend.labs = legend.labs,
#                                  risk.table.height = 0.25, # Useful to change when you have multiple groups
#                                  ggtheme = ggplot2::theme_bw()
#  )
#  kmplot
#}

#' Plot survival curves for the models fitted using \code{fit.models}
#' 
#' Plots the results of model fit using base R graphics (the Kaplan Meier
#' estimate is drawn with \code{rms::survplot}).
#' 
#' @param x Not used: the \code{survHE} object(s) are taken from \code{exArgs}
#' and \code{x} is overwritten internally.
#' @param exArgs A list that must include at least one result object saved as
#' the call to the \code{fit.models} function. Other possibilities are
#' additional (mainly graphical) options. These are: \code{mod} = a numeric
#' vector selecting the models to plot (default = all); \code{t} = the vector
#' of times (default = the times of the Kaplan Meier estimate); \code{xlab} = a
#' string with the label for the x-axis (default = "time"); \code{ylab} = a
#' string with the label for the y-axis (default = "Survival");
#' \code{lab.profile} = a (vector of) string(s) indicating the labels associated
#' with the strata defining the different survival curves to plot. Default to
#' the value used by the Kaplan Meier estimate given in \code{fit.models};
#' \code{cex.trt} = factor by which the size of the font used to write the 
#' strata is resized (default = 0.8); \code{n.risk} = logical. If TRUE writes
#' the number at risk at different time points (as determined by the Kaplan
#' Meier estimate). Defaults to FALSE; \code{main} = the title of the plot;
#' \code{newdata} = a list (of lists) providing the values for the relevant
#' covariates. If NULL, then will use the mean values for the covariates if at
#' least one is a continuous variable, or the combination of the categorical
#' covariates; \code{xlim} = a vector determining the limits for the x-axis;
#' \code{colors} = a vector defining the colours in which to plot the different
#' survival curves; \code{axes} = whether to draw the axes when the Kaplan Meier
#' curve is not plotted (default = TRUE); \code{labs} = a vector of characters
#' defining the names of the models fitted; \code{add.km} = TRUE (whether to
#' also add the Kaplan Meier estimates of the data); \code{legend} = TRUE
#' (whether to also add the legend to the graph); \code{cex.lab} = factor by
#' which the legend text is resized (default = 0.8)
#' @note Something will go here
#' @author Gianluca Baio
#' @seealso Something will go here
#' @references Something will go here
#' @keywords Parametric survival models
#' @examples
#' 
#' data(bc)
#' 
#' mle = fit.models(formula=Surv(recyrs,censrec)~group,data=bc,
#'     distr="exp",method="mle")
#' plot(mle)
#' @noRd 
plot_base_survHE <- function(x, exArgs) {
  ## Plots the KM + the results of the model fitted by fit.models() with base
  ## R graphics. The 'survHE' objects are taken from 'exArgs'; 'x' is
  ## overwritten below with a single 'survHE' object combining the selected models.
  #
  # mod = a numeric vector --- selects the models to plot (so mod=c(1,3) only selects the 1st and 3rd models)
  # t, xlab, ylab, lab.profile, cex.trt, n.risk, main, cex.lab, legend, xlim,
  # colors, axes, labs
  # add.km = TRUE (whether to also add the Kaplan Meier estimates of the data)
  # newdata = a list (of lists), specifiying the values of the covariates at which the computation is performed. For example
  #           'list(list(arm=0),list(arm=1))' will create two survival curves, one obtained by setting the covariate 'arm'
  #           to the value 0 and the other by setting it to the value 1. In line with 'flexsurv' notation, the user needs
  #           to either specify the value for *all* the covariates or for none (in which case, 'newdata=NULL', which is the
  #           default). If no value is specified and at least one of the covariates is continuous, then a single survival
  #           curve will be computed in correspondence of the average values of all the covariates (including the factors, 
  #           which in this case are expanded into indicators). 

  # Finds the 'survHE' objects in 'exArgs'. 'inherits()' is used because
  # 'class()' can return several classes per object, which would misalign the
  # indices returned by 'which()'
  is_survHE <- vapply(exArgs, function(obj) inherits(obj, "survHE"), logical(1))
  w <- unname(which(is_survHE))
  if (length(w) == 0) {
    stop("You need to input at least one 'survHE' object to run this function!",
         call. = FALSE)
  }
  # Fitting method of each 'survHE' object: element k refers to exArgs[[w[k]]]
  original.method <- unlist(lapply(w, function(i) exArgs[[i]]$method))

  if (length(w) == 1) {
    mods <- exArgs[[w]]$models
    totmodels <- length(mods)
    method <- rep(exArgs[[w]]$method, totmodels)
    aic <- unlist(exArgs[[w]]$model.fitting$aic)
    bic <- unlist(exArgs[[w]]$model.fitting$bic)
    dic <- unlist(exArgs[[w]]$model.fitting$dic)
    # A single model needs no selection
    select_models <- totmodels > 1
  } else {
    mods <- unlist(lapply(w, function(i) exArgs[[i]]$models), recursive = FALSE)
    # One entry per model: each object's method repeated for all its models
    method <- unlist(lapply(w, function(i) {
      rep(exArgs[[i]]$method, length(exArgs[[i]]$models))
    }))
    aic <- unlist(lapply(w, function(i) exArgs[[i]]$model.fitting$aic))
    bic <- unlist(lapply(w, function(i) exArgs[[i]]$model.fitting$bic))
    dic <- unlist(lapply(w, function(i) exArgs[[i]]$model.fitting$dic))
    select_models <- TRUE
  }
  if (select_models) {
    # NB: '$' does partial matching, so an argument named e.g. 'mods' is also
    # picked up here when 'mod' itself is not given
    if (!is.null(exArgs$mod)) {which.model <- exArgs$mod} else {which.model <- seq_along(mods)}
    mods <- lapply(which.model, function(i) mods[[i]])
    method <- method[which.model]
    aic <- aic[which.model]
    bic <- bic[which.model]
    dic <- dic[which.model]
  }

  model.fitting <- list(aic=aic, bic=bic, dic=dic)
  x <- list()
  x$models <- mods
  nmodels <- length(x$models)  # Number of models to plot
  class(x) <- "survHE"
  x$model.fitting <- model.fitting
  ## Needs to include in the misc object the element vars (which is used for HMC models)
  if (any(method=="hmc")) {
    # 'original.method' is aligned with 'w', so map back to the 'exArgs' position
    first_hmc <- w[which(original.method=="hmc")[1]]
    x$misc <- exArgs[[first_hmc]]$misc
    x$misc$data.stan <- x$misc$data.stan[[1]]
    if (exists("X", x$misc$data.stan)) {
      x$misc$data.stan$X_obs <- x$misc$data.stan$X
    } else {
      x$misc$data.stan$X <- x$misc$data.stan$X_obs
    }
    x$misc$data.stan <- rep(list(x$misc$data.stan), nmodels)
  } else {
    # If none of the survHE objects are HMC, then just use the first one
    x$misc <- exArgs[[w[1]]]$misc
  }

  # Extra options (or their defaults)
  t <- if (is.null(exArgs$t)) sort(unique(x$misc$km$time)) else exArgs$t
  xl <- if (is.null(exArgs$xlab)) "time" else exArgs$xlab
  yl <- if (is.null(exArgs$ylab)) "Survival" else exArgs$ylab
  lab.profile <- if (is.null(exArgs$lab.profile)) names(x$misc$km$strata) else exArgs$lab.profile
  cex.trt <- if (is.null(exArgs$cex.trt)) 0.8 else exArgs$cex.trt
  nrisk <- if (is.null(exArgs$n.risk)) FALSE else exArgs$n.risk
  main <- if (is.null(exArgs$main)) "" else exArgs$main
  newdata <- exArgs$newdata
  cex.lab <- if (is.null(exArgs$cex.lab)) 0.8 else exArgs$cex.lab
  # Exact lookup ('[[') so that e.g. 'legend.position' is not mistaken for
  # 'legend'; the legend is shown unless the user asks otherwise
  show_legend <- if (is.null(exArgs[["legend"]])) TRUE else isTRUE(as.logical(exArgs[["legend"]]))

  # x-axis limits: user-given, or the 'pretty' range of the times ('t' defaults
  # to the KM times, whose pretty range is the same)
  xlm <- if (is.null(exArgs$xlim)) range(pretty(t)) else exArgs$xlim

  if (is.null(exArgs$colors)) {
    if (nmodels>1) {colors <- (2:(nmodels+1))} else {colors <- 2}
  } else {colors <- exArgs$colors}
  axes <- if (is.null(exArgs$axes)) TRUE else exArgs$axes
  if (is.null(exArgs$labs)) {
    model_labels <- unlist(lapply(seq_along(x$models), function(i) {
      if (inherits(x$models[[i]], "stanfit")) {tolower(x$models[[i]]@model_name)} else {x$models[[i]]$dlist$name}
    }))
    model_labels[model_labels %in% c("weibull.quiet","weibull","weibullaf","weibullph")] <- "Weibull"
    model_labels[model_labels %in% c("exp","exponential")] <- "Exponential"
    model_labels[model_labels %in% "gamma"] <- "Gamma"
    model_labels[model_labels %in% c("lnorm","lognormal")] <- "log-Normal"
    model_labels[model_labels %in% c("llogis","loglogistic","loglogis")] <- "log-Logistic"
    model_labels[model_labels %in% "gengamma"] <- "Gen. Gamma"
    model_labels[model_labels %in% "genf"] <- "Gen. F"
    model_labels[model_labels %in% "gompertz"] <- "Gompertz"
    model_labels[model_labels %in% c("survspline","rp")] <- "Royston & Parmar splines"
  } else {model_labels <- exArgs$labs}
  model_labels <- c("Kaplan Meier", model_labels)
  add.km <- if (is.null(exArgs$add.km)) TRUE else exArgs$add.km

  # Now plots the KM curve using "rms" if add.km is set to TRUE
  if (isTRUE(as.logical(add.km)) && is.null(newdata)) {
    rms::survplot(x$misc$km,                                     # Specialised plot from "rms" 
                  xlab=xl, ylab=yl,                              # x- and y- labels
                  label.curves=list(labels=lab.profile,cex=cex.trt), # specifies curve labels
                  n.risk=nrisk,                                  # tells R to show number at risk 
                  lwd=2, xlim=xlm                                # defines the size of the lines (2 pts)
    )
    col <- c("black", colors)
    title(main)
  } else {
    # No KM curve: drop its label and draw an empty frame for the model curves
    model_labels <- model_labels[-1]
    if (!inherits(colors, "character")) {colors <- colors-1}
    plot(0, 0, col="white", xlab=xl, ylab=yl, axes=FALSE, xlim=xlm, ylim=c(0,1), main=main)
    if (axes==TRUE) {
      axis(1)
      axis(2)
    }
    col <- colors
  }
  res <- lapply(seq_len(nmodels), function(i) {
    x$method <- method[i]
    make.surv(x, nsim=1, t=t, mod=i, newdata=newdata)
  })

  if (!is.null(newdata)) {
    # Number formatting for the legend; the previous global options are
    # restored when the function exits
    old_options <- options(digits=5, nsmall=2)
    on.exit(options(old_options), add = TRUE)
    # Needs to distinguish between mle and non-mle because of how make.surv saves the S list
    pts <- list()
    for (i in seq_len(nmodels)) {
      if (method[i]=="mle") {
        pts[[i]] <- lapply(seq_along(newdata), function(j) {
          tmp <- matrix(unlist(res[[i]]$S[[j]]), ncol=4)
          cbind(tmp[,1], tmp[,2])
        })
      } else {
        pts[[i]] <- lapply(seq_along(newdata), function(j) {
          res[[i]]$S[[1]][[j]]
        })
      }
    }
    colors <- seq_len(nmodels)
    leg.txt <- character()
    for (i in seq_len(nmodels)) {
      for (j in seq_along(newdata)) {
        points(pts[[i]][[j]], type="l", col=colors[i], lty=j)
        leg.txt[j] <- paste0(names(newdata[[j]]), "=", prettyNum(newdata[[j]], format="fg"), collapse=", ")
      }
    }
    if (show_legend) {legend("topright", legend=leg.txt, bty="n", lty=seq_along(newdata), cex=cex.lab)}
  }
  if (is.null(newdata)) {
    for (i in seq_len(nmodels)) {
      # One matrix per element of 'S[[1]]'. Assumes column 1 holds the times and
      # column 2 the survival values (as in the 'newdata' branch above) -- TODO
      # confirm against make.surv. ('$' cannot be used on a matrix)
      pts <- lapply(res[[i]]$S[[1]], as.matrix)
      for (k in seq_along(pts)) {
        points(pts[[k]][, 1], pts[[k]][, 2], type="l", col=colors[i], lty=k)
      }
    }
    if (show_legend) {legend(x="topright", legend=model_labels, lwd=2, bty="n", col=col, cex=cex.lab)}
  }
}
# giabaio/survHE documentation built on Sept. 9, 2023, 2:47 a.m.