R/model.fit.plot.R

Defines functions model.fit.plot

Documented in model.fit.plot

#' Graphical representation of the measures of model fitting based on
#' Information Criteria
#' 
#' Plots a summary of the model fit for all the models fitted
#' 
#' Draws a bar chart of the selected information criterion (AIC, BIC or DIC)
#' for every model stored in one or more \code{survHE} objects. Unnamed
#' \code{survHE} objects are labelled \code{"Object1"}, \code{"Object2"}, ...
#' in the plot. Further optional arguments can be passed through \code{...}:
#' \code{mods} (indices of the models to plot, counted across all the
#' \code{survHE} objects in the order they are given); \code{xlim} (limits of
#' the axis showing the criterion, absolute scale only); \code{col},
#' \code{colour} or \code{color} (bar colours; with \code{stacked=TRUE} one
#' per \code{survHE} object); \code{main} (plot title); \code{models} (labels
#' for the bars, non-stacked plots only); and \code{name_legend} (legend title,
#' stacked plots only).
#' 
#' @param ...  Optional inputs. Must include at least one \code{survHE} object.
#' @param type Which criterion should be plotted: the AIC, the BIC or the DIC?
#' (values = \code{"aic"}, \code{"bic"} or \code{"dic"}; upper-case forms and
#' the single letters \code{"a"}, \code{"b"}, \code{"d"} are also accepted)
#' @param scale If \code{scale='absolute'} (default), then plot the absolute value 
#' of the *IC. If \code{scale='relative'} then plot a rescaled version taking
#' the percentage increase in the *IC in comparison with the best-fitting model.
#' \code{'abs'} and \code{'rel'} are accepted as abbreviations. Ignored when
#' \code{stacked=TRUE}, which always plots the absolute value.
#' @param stacked If \code{TRUE}, the bars for the same model coming from
#' different \code{survHE} objects are drawn side by side and filled by object
#' (default=\code{FALSE})
#' @return A \code{ggplot} object with the relevant model fitting statistics
#' @author Gianluca Baio
#' @seealso \code{fit.models}
#' @template refs
#' @keywords Model fitting Parametric survival models
#' @examples
#' \dontrun{ 
#' data(bc)
#' 
#' mle = fit.models(formula=Surv(recyrs,censrec)~group,data=bc,
#'     distr=c("exp","wei","lno"),method="mle")
#' model.fit.plot(mle)
#' }
#' 
#' @export model.fit.plot
model.fit.plot <- function(..., type = "aic", scale = "absolute", stacked = FALSE) {
  ## Plots a summary of the model fit for all the models
  ## Can also combine several survHE objects each containing the fit for one model
  exArgs <- list(...)

  # Unnamed arguments can only be 'survHE' objects: give them placeholder names
  # ("Object1", "Object2", ...) so that they can be labelled in the plot.
  # seq_along() keeps this safe when no argument at all is supplied.
  arg_names <- names(exArgs)
  if (is.null(arg_names)) {
    arg_names <- rep("", length(exArgs))
  }
  unnamed <- which(arg_names == "")
  arg_names[unnamed] <- paste0("Object", seq_along(unnamed))
  names(exArgs) <- arg_names

  # TRUE when an optional argument called 'nm' was passed through '...'
  has_arg <- function(nm) nm %in% names(exArgs)

  # Extracts the 'survHE' objects. inherits() also copes with arguments whose
  # class attribute has more than one element (e.g. tibbles or matrices).
  is_survHE <- vapply(exArgs, inherits, logical(1), what = "survHE")
  if (!any(is_survHE)) {
    stop("Please give at least one 'survHE' object, generated by a call to 'fit.models(...)")
  }
  survHE_objs <- exArgs[is_survHE]

  # Maps 'type' and 'scale' to canonical values, failing early on anything
  # unrecognised instead of producing an empty plot or a cryptic error later
  type_lookup <- c(aic = "AIC", AIC = "AIC", a = "AIC", A = "AIC",
                   bic = "BIC", BIC = "BIC", b = "BIC", B = "BIC",
                   dic = "DIC", DIC = "DIC", d = "DIC", D = "DIC")
  if (!(is.character(type) && length(type) == 1 && type %in% names(type_lookup))) {
    stop("'type' must be one of \"aic\", \"bic\" or \"dic\"", call. = FALSE)
  }
  type <- type_lookup[[type]]

  scale_lookup <- c(abs = "absolute", absolute = "absolute",
                    rel = "relative", relative = "relative")
  if (!(is.character(scale) && length(scale) == 1 && scale %in% names(scale_lookup))) {
    stop("'scale' must be either \"absolute\" or \"relative\"", call. = FALSE)
  }
  scale <- scale_lookup[[scale]]

  stacked <- isTRUE(as.logical(stacked))

  # Which models should be used from the 'survHE' objects? (default: all)
  if (has_arg("mods")) {
    mods <- exArgs[["mods"]]
  } else {
    n_models <- sum(vapply(survHE_objs, function(obj) length(obj$models), integer(1)))
    mods <- seq_len(n_models)
  }

  # Creates the dataset with the model fitting statistics for all the selected models
  toplot <- lapply(seq_along(survHE_objs), function(i) {
    survHE_objs[[i]]$model.fitting %>%
      bind_rows() %>%
      mutate(object_name = as.factor(names(survHE_objs)[i]),
             model_name = names(survHE_objs[[i]]$models))
  }) %>%
    bind_rows() %>%
    mutate(lab = paste0(model_name, ":", object_name)) %>%
    select(object_name, model_name, lab, everything()) %>%
    slice(mods)

  # Values of the selected criterion (columns 'aic', 'bic' or 'dic')
  ic <- toplot[[tolower(type)]]
  title_txt <- paste0("Model comparison based on ", type)
  legend_lab <- if (length(mods) == 1) "Model" else "Models"

  # Optional manual colours, given as 'col', 'colour' or 'color'
  # (if more than one is supplied, the last one in this order wins)
  user_cols <- NULL
  for (nm in c("col", "colour", "color")) {
    if (has_arg(nm)) {
      user_cols <- exArgs[[nm]]
    }
  }

  # Axis limits: user-supplied 'xlim', or a 'pretty' range around the values
  axis_limits <- function(values) {
    if (has_arg("xlim")) exArgs[["xlim"]] else range(pretty(range(values, na.rm = TRUE)))
  }

  # Theme shared by both layouts
  add_common_theme <- function(p) {
    p +
      theme_bw() +
      theme(axis.text.x = element_text(color = "black", size = 12, angle = 0, hjust = .5, vjust = .5),
            axis.text.y = element_text(color = "black", size = 12, angle = 0, hjust = .5, vjust = .5),
            axis.title.x = element_text(color = "black", size = 14, angle = 0, hjust = .5, vjust = .5),
            axis.title.y = element_text(color = "black", size = 14, angle = 90, hjust = .5, vjust = .5)) +
      theme(axis.line = element_line(colour = "black"),
            panel.background = element_blank(),
            panel.border = element_blank(),
            plot.title = element_text(size = 18, face = "bold"))
  }

  if (stacked) {
    # Bars for the same model from different 'survHE' objects are dodged side
    # by side and filled by object; 'scale' is not used in this layout
    toplot$ic_value <- ic
    # NOTE(review): 'digits = 1.5' is not an integer; confirm the intended precision
    mfp <- ggplot(data = toplot, aes(x = model_name, y = ic_value, fill = object_name)) +
      geom_bar(stat = "identity", position = position_dodge()) +
      geom_text(aes(label = round(ic_value, digits = 1.5)), hjust = 1.05,
                color = "white", size = 5.5, position = position_dodge(0.9)) +
      coord_flip(ylim = axis_limits(ic))
    # A single fill scale: manual colours if given, otherwise the "Paired" palette
    fill_scale <- if (is.null(user_cols)) {
      scale_fill_brewer(palette = "Paired")
    } else {
      scale_fill_manual(values = user_cols)
    }
    mfp <- add_common_theme(mfp) +
      labs(y = type, x = "", title = title_txt, fill = "survHE object") +
      fill_scale +
      theme(legend.position = "bottom")
    # Can modify the title of the legend
    if (has_arg("name_legend")) {
      mfp <- mfp + labs(fill = exArgs[["name_legend"]])
    }
  } else {
    # Default bar colour is 'steelblue'; a user vector should have one colour per bar
    bar_col <- if (is.null(user_cols)) "steelblue" else user_cols
    # Bars are labelled by model name alone when only one 'survHE' object is
    # shown, otherwise by "model:object" to keep them distinguishable
    toplot$bar_label <- if (length(unique(toplot$object_name)) == 1) {
      toplot$model_name
    } else {
      toplot$lab
    }

    if (scale == "absolute") {
      toplot$ic_value <- ic
      mfp <- ggplot(data = toplot) +
        geom_bar(mapping = aes(x = bar_label, y = ic_value), stat = "identity", fill = bar_col) +
        geom_text(aes(x = bar_label, y = ic_value, label = round(ic_value, digits = 1.5)),
                  hjust = 1.05, color = "white", size = 5.5) +
        labs(y = type, x = "", title = title_txt, color = legend_lab) +
        coord_flip(ylim = axis_limits(ic))
    } else {
      # Percentage increase with respect to the smallest (best-fitting) value.
      # na.rm = TRUE so that models lacking the criterion (NA, e.g. a DIC that
      # was not computed) do not turn every bar into NA.
      best <- min(ic, na.rm = TRUE)
      toplot$ic_value <- 100 * (ic - best) / best
      mfp <- ggplot(data = toplot) +
        geom_bar(mapping = aes(x = bar_label, y = ic_value), stat = "identity", fill = bar_col) +
        geom_text(aes(x = bar_label, y = ic_value, label = round(ic_value, digits = 1.5)),
                  hjust = -.05, color = "black", size = 5.5) +
        labs(y = paste0("Percentage increase in ", type), x = "", title = title_txt,
             color = legend_lab) +
        coord_flip()
    }
    mfp <- add_common_theme(mfp)
    # Optional relabelling of the bars
    if (has_arg("models")) {
      mfp <- mfp + scale_x_discrete(labels = exArgs[["models"]])
    }
  }

  # Optional user-defined title
  if (has_arg("main")) {
    mfp <- mfp + labs(title = exArgs[["main"]])
  }

  # Renders the graph
  mfp
}
giabaio/survHE documentation built on Sept. 9, 2023, 2:47 a.m.