### autoplot.predictCox.R ---
##----------------------------------------------------------------------
## author: Brice Ozenne
## created: feb 17 2017 (10:06)
## Version:
## last-updated: Jan 29 2019 (10:49)
## By: Thomas Alexander Gerds
## Update #: 482
##----------------------------------------------------------------------
##
### Commentary:
##
### Change Log:
##----------------------------------------------------------------------
##
### Code:
# {{{ autoplot.predictCox
## * autoplot.predictCox (documentation)
#' @title Plot Predictions From a Cox Model
#' @description Plot predictions from a Cox model.
#' @name autoplot.predictCox
#'
#' @param object Object obtained with the function \code{predictCox}.
#' @param type [character] The type of predicted value to display.
#' Choices are:
#' \code{"hazard"} the hazard function,
#' \code{"cumhazard"} the cumulative hazard function,
#' or \code{"survival"} the survival function.
#' @param ci [logical] If \code{TRUE} display the confidence intervals for the predictions.
#' @param band [logical] If \code{TRUE} display the confidence bands for the predictions.
#' @param group.by [character] The grouping factor used to color the prediction curves. Can be \code{"row"}, \code{"strata"}, or \code{"covariates"}.
#' @param reduce.data [logical] If \code{TRUE} only the covariates that does take indentical values for all observations are displayed.
#' @param plot [logical] Should the graphic be plotted.
#' @param digits [integer] Number of decimal places.
#' @param ylab [character] Label for the y axis.
#' @param alpha [numeric, 0-1] Transparency of the confidence bands. Argument passed to \code{ggplot2::geom_ribbon}.
#' @param ... Not used. Only for compatibility with the plot method.
## * autoplot.predictCox (examples)
#' @rdname autoplot.predictCox
#' @examples
#' library(survival)
#' library(ggplot2)
#'
#' #### simulate data ####
#' set.seed(10)
#' d <- sampleData(1e2, outcome = "survival")
#'
#' #### Cox model ####
#' m.cox <- coxph(Surv(time,event)~ X1 + X2 + X3,
#' data = d, x = TRUE, y = TRUE)
#'
#' ## display baseline hazard
#' e.basehaz <- predictCox(m.cox)
#'
#' autoplot(e.basehaz, type = "cumhazard")
#'
#' ## display predicted survival
#' pred.cox <- predictCox(m.cox, newdata = d[1:4,],
#' times = 1:5, type = "survival", keep.newdata = TRUE)
#' autoplot(pred.cox)
#' autoplot(pred.cox, group.by = "covariates")
#' autoplot(pred.cox, group.by = "covariates", reduce.data = TRUE)
#'
#' ## predictions with confidence interval/bands
#' pred.cox <- predictCox(m.cox, newdata = d[1,,drop=FALSE],
#' times = 1:5, type = "survival", band = TRUE, se = TRUE, keep.newdata = TRUE)
#' autoplot(pred.cox, ci = TRUE, band = TRUE)
#' autoplot(pred.cox, ci = TRUE, band = TRUE, alpha = 0.1)
#'
#' #### Stratified Cox model ####
#' m.cox.strata <- coxph(Surv(time,event)~ strata(X1) + strata(X2) + X3 + X6,
#' data = d, x = TRUE, y = TRUE)
#'
#' pred.cox.strata <- predictCox(m.cox.strata, newdata = d[1:5,,drop=FALSE],
#' time = 1:5, keep.newdata = TRUE)
#'
#' ## display
#' res <- autoplot(pred.cox.strata, type = "survival", group.by = "strata")
#'
#' ## customize display
#' res$plot + facet_wrap(~strata, labeller = label_both)
#' res$plot %+% res$data[strata == "0, 1"]
## * autoplot.predictCox (code)
#' @rdname autoplot.predictCox
#' @method autoplot predictCox
#' @export
autoplot.predictCox <- function(object,
type = NULL,
ci = FALSE,
band = FALSE,
group.by = "row",
reduce.data = FALSE,
plot = TRUE,
ylab = NULL,
digits = 2, alpha = NA, ...){
## initialize and check
possibleType <- c("cumhazard","survival")
possibleType <- possibleType[possibleType %in% names(object)]
if(is.null(type)){
if(length(possibleType) == 1){
type <- possibleType
}else{
stop("argument \'type\' must be specified to choose between ",paste(possibleType, collapse = " "),"\n")
}
}else{
type <- match.arg(type, possibleType)
}
if(is.null(ylab)){
ylab <- switch(type,
"cumhazard" = "cumulative hazard",
"survival" = "survival")
}
group.by <- match.arg(group.by, c("row","covariates","strata"))
if(group.by[[1]] == "covariates" && (("newdata" %in% names(object)) == FALSE)){
stop("argument \'group.by\' cannot be \"covariates\" when newdata is missing in the object \n",
"set argment \'keep.newdata\' to TRUE when calling predictCox \n")
}
if(group.by[[1]] == "strata" && ("strata" %in% names(object) == FALSE)){
stop("argument \'group.by\' cannot be \"strata\" when strata is missing in the object \n",
"set argment \'keep.strata\' to TRUE when calling predictCox \n")
}
if(ci[[1]]==TRUE && (object$se[[1]]==FALSE || is.null(object$conf.level))){
stop("argument \'ci\' cannot be TRUE when no standard error have been computed \n",
"set arguments \'se\' and \'confint\' to TRUE when calling predictCox \n")
}
if(band[[1]] && (object$band[[1]]==FALSE || is.null(object$conf.level))){
stop("argument \'band\' cannot be TRUE when the quantiles for the confidence bands have not been computed \n",
"set arguments \'band\' and \'confint\' to TRUE when calling predictCox \n")
}
dots <- list(...)
if(length(dots)>0){
txt <- names(dots)
txt.s <- if(length(txt)>1){"s"}else{""}
stop("unknown argument",txt.s,": \"",paste0(txt,collapse="\" \""),"\" \n")
}
## reshape data
if(!is.matrix(object[[type]])){
if(is.null(object[["strata"]])){
object[[type]] <- rbind(object[[type]])
}else{
strata <- unique(object[["strata"]])
n.strata <- length(strata)
time <- unique(sort(object[["times"]]))
n.time <- length(time)
type.tempo <- matrix(NA, nrow = n.strata, ncol = n.time)
init <- switch(type,
"cumhazard" = 0,
"survival" = 1)
for(iStrata in 1:n.strata){ ## iStrata <- 1
index.strata <- which(object[["strata"]]==strata[iStrata])
type.tempo[iStrata,] <- stats::approx(x = object[["times"]][index.strata],
y = object[[type]][index.strata],
yleft = init,
yright = NA,
xout = time,
method = "constant")$y
}
object[[type]] <- type.tempo
object[["strata"]] <- strata
object[["times"]] <- time
group.by <- "strata"
}
newdata <- NULL
}else{
newdata <- data.table::copy(object$newdata) ## can be NULL
if(!is.null(newdata) && reduce.data[[1]]==TRUE){
test <- unlist(newdata[,lapply(.SD, function(col){length(unique(col))==1})])
if(any(test)){
newdata[, (names(test)[test]):=NULL]
}
}
}
dataL <- predict2melt(outcome = object[[type]], ci = ci, band = band,
outcome.lower = if(ci){object[[paste0(type,".lower")]]}else{NULL},
outcome.upper = if(ci){object[[paste0(type,".upper")]]}else{NULL},
outcome.lowerBand = if(band){object[[paste0(type,".lowerBand")]]}else{NULL},
outcome.upperBand = if(band){object[[paste0(type,".upperBand")]]}else{NULL},
newdata = newdata,
strata = object$strata,
times = object$times,
name.outcome = type,
group.by = group.by,
digits = digits
)
## display
gg.res <- predict2plot(dataL = dataL,
name.outcome = type,
ci = ci,
band = band,
group.by = group.by,
conf.level = object$conf.level,
alpha = alpha,
ylab = ylab
)
if(plot){
print(gg.res$plot)
}
return(invisible(gg.res))
}
# }}}
# {{{ predict2melt
## * predict2melt
predict2melt <- function(outcome, name.outcome,
ci, outcome.lower, outcome.upper,
band, outcome.lowerBand, outcome.upperBand,
newdata, strata, times, group.by, digits){
patterns <- NULL ## [:CRANtest:] data.table
n.time <- NCOL(outcome)
if(!is.null(time)){
time.names <- times
}else{
time.names <- 1:n.time
}
colnames(outcome) <- paste0(name.outcome,"_",time.names)
keep.cols <- unique(c("time",name.outcome,"row",group.by))
## add initial values ####
first.dt <- switch(name.outcome,
"cumhazard" = data.table(time = 0, cumhazard = 0),
"survival" = data.table(time = 0, survival = 1),
"absRisk" = data.table(time = 0, absRisk = 0))
## merge outcome with CI and band ####
pattern <- paste0(name.outcome,"_")
if(ci){
pattern <- c(pattern,"lowerCI_","upperCI_")
colnames(outcome.lower) <- paste0("lowerCI_",time.names)
colnames(outcome.upper) <- paste0("upperCI_",time.names)
first.dt[, c("lowerCI") := .SD[[1]], .SDcols = name.outcome]
first.dt[, c("upperCI") := .SD[[1]], .SDcols = name.outcome]
}
if(band){
pattern <- c(pattern,"lowerBand_","upperBand_")
keep.cols <- c(keep.cols,"lowerBand","upperBand")
colnames(outcome.lowerBand) <- paste0("lowerBand_",time.names)
colnames(outcome.upperBand) <- paste0("upperBand_",time.names)
first.dt[, c("lowerBand") := .SD[[1]], .SDcols = name.outcome]
first.dt[, c("upperBand") := .SD[[1]], .SDcols = name.outcome]
}
outcome <- data.table::as.data.table(
cbind(outcome,
outcome.lower, outcome.upper,
outcome.lowerBand,outcome.upperBand)
)
## merge with convariates ####
outcome[, row := 1:.N]
if(group.by == "covariates"){
cov.names <- names(newdata)
newdata <- newdata[, (cov.names) := lapply(cov.names,function(col){
if (is.numeric(.SD[[col]]))
paste0(col,"=",round(.SD[[col]],digits)) else paste0(col,"=",.SD[[col]])})]
outcome[, ("covariates") := interaction(newdata,sep = " ")]
}else if(group.by == "strata"){
outcome[, strata := strata]
}
## reshape to long format ####
dataL <- melt(outcome, id.vars = union("row",group.by),
measure = patterns(pattern),
variable.name = "time", value.name = gsub("_","",pattern))
dataL[, time := as.numeric(as.character(factor(time, labels = time.names)))]
dataL <- dataL[!is.na(dataL[[name.outcome]])]
dataL <- dataL[, rbind(first.dt,.SD), by = c(union("row",group.by))]
return(dataL)
}
# }}}
# {{{ predict2plot
## * predict2plot
predict2plot <- function(dataL, name.outcome,
ci, band, group.by,
conf.level, alpha, ylab){
# for CRAN tests
original <- lowerCI <- upperCI <- lowerBand <- upperBand <- timeLeft <- NULL
#### duplicate observations to obtain step curves ####
keep.cols <- unique(c("time",name.outcome,"row",group.by,"original"))
if(ci){
keep.cols <- c(keep.cols,"lowerCI","upperCI")
}
if(band){
keep.cols <- c(keep.cols,"lowerBand","upperBand")
}
## set at t- the value of t-1
dtTempo <- copy(dataL)
vec.outcome <- name.outcome
if(ci){
vec.outcome <- c(vec.outcome,"lowerCI","upperCI")
}
if(band){
vec.outcome <- c(vec.outcome,"lowerBand","upperBand")
}
dataL[,c("timeLeft") := .SD[c(1,1:(.N-1))], .SDcols = "time", by = "row"]
## display ####
labelCI <- paste0(conf.level*100,"% confidence \n interval")
labelBand <- paste0(conf.level*100,"% confidence \n band")
gg.base <- ggplot(data = dataL, mapping = aes_string(group = "row", color = group.by))
gg.base <- gg.base + ggplot2::geom_segment(aes_string(x = "timeLeft", y = name.outcome, xend = "time", yend = name.outcome), size = 1.5)
gg.base <- gg.base + ggplot2::geom_point(aes_string(x = "timeLeft", y = name.outcome), size = 2)
if(group.by=="row"){
gg.base <- gg.base + ggplot2::labs(color="observation") + theme(legend.key.height=unit(0.1,"npc"),
legend.key.width=unit(0.08,"npc"))
# display only integer values
uniqueObs <- unique(dataL$row)
if(length(uniqueObs)==1){
gg.base <- gg.base + scale_color_continuous(guide=FALSE)
}else{
gg.base <- gg.base + scale_color_continuous(breaks = uniqueObs[seq(1,length(uniqueObs), length.out = min(10,length(uniqueObs)))],
limits = c(0.5, length(uniqueObs) + 0.5))
}
}
if(ci){
if(!is.na(alpha)){
gg.base <- gg.base + ggplot2::geom_errorbar(aes(x = time, ymin = lowerCI, ymax = upperCI, linetype = labelCI))
gg.base <- gg.base + scale_linetype_manual("",values=setNames(1,labelCI))
}else{
gg.base <- gg.base + ggplot2::geom_segment(aes(x = timeLeft, y = lowerCI, xend = time, yend = lowerCI, linetype = "ci"),
size = 1.2, color = "black")
gg.base <- gg.base + ggplot2::geom_segment(aes(x = timeLeft, y = upperCI, xend = time, yend = upperCI, linetype = "ci"),
size = 1.2, color = "black")
}
}
if(band){
if(!is.na(alpha)){
gg.base <- gg.base + ggplot2::geom_rect(aes(xmin = timeLeft, xmax = time, ymin = lowerBand, ymax = upperBand,
fill = labelBand), linetype = 0, alpha = alpha)
gg.base <- gg.base + scale_fill_manual("", values="grey12")
}else{
gg.base <- gg.base + ggplot2::geom_segment(aes(x = timeLeft, y = lowerBand, xend = time, yend = lowerBand, linetype = "band"),
size = 1.2, color = "black")
gg.base <- gg.base + ggplot2::geom_segment(aes(x = timeLeft, y = upperBand, xend = time, yend = upperBand, linetype = "band"),
size = 1.2, color = "black")
}
}
if(is.na(alpha)[[1]] && (band[[1]] || ci[[1]])){
indexTempo <- which(c(ci,band)==1)
if(band && ci){
value <- c(1,2)
}else{
value <- 1
}
gg.base <- gg.base + scale_linetype_manual("", breaks = c("ci","band")[indexTempo],
labels = c(labelCI,labelBand)[indexTempo],
values = value)
}else if(ci[[1]] && band[[1]]){
gg.base <- gg.base + ggplot2::guides(linetype = ggplot2::guide_legend(order = 1),
fill = ggplot2::guide_legend(order = 2),
group = ggplot2::guide_legend(order = 3)
)
}
gg.base <- gg.base + ggplot2::ylab(ylab)
## export
ls.export <- list(plot = gg.base,
data = dataL)
return(ls.export)
}
# }}}
#----------------------------------------------------------------------
### autoplot.predictCox.R ends here
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.