# R/ggrisk.R

#' @title Risk Score Plot for Cox Regression
#' @description Draws a stacked figure for a Cox model fitted on `data`:
#'   (A) each patient's risk score (linear predictor), ordered from low to
#'   high and coloured by risk group; (B) each patient's survival time in the
#'   same order, coloured by event status; a thin coloured bar of the risk
#'   groups; and (C) a heatmap of covariate values. A dotted vertical line
#'   marks the cutoff position in panels A and B.
#'
#'   Side effect: the plotted data are exported with `<<-` (normally ending
#'   up in the global environment) as `fA_data` (risk scores of panel A),
#'   `fB_data` (time and event columns of panel B) and `data7` (long-format
#'   heatmap data of panel C).
#' @param fit Cox regression result of `coxph()` from 'survival' or `cph()`
#'   from 'rms'. It is passed through `to.cph()` before use.
#' @param heatmap.genes (optional) character vector naming numeric columns of
#'   `data` to show in the heatmap. Defaults to every model variable except
#'   the survival time and event.
#' @param data data frame (e.g. a validation set) containing every variable of
#'   the model formula. Risk scores are predicted for it. Required.
#' @param code.0 string. Label for event 0. Default is 'Alive'
#' @param code.1 string. Label for event 1. Default is 'Dead'
#' @param code.highrisk string. Label for the high-risk group. Default is 'High'
#' @param code.lowrisk string. Label for the low-risk group. Default is 'Low'
#' @param cutoff.show logical, whether to show the cutoff text in figure A.
#'   Default is TRUE
#' @param cutoff.value how to choose the cutoff: 'median' (default), 'roc'
#'   (via `cutoff::roc()`), 'cutoff' (via `cutoff::cox()`), or a single number
#'   (numeric or numeric string) that must lie within the range of risk scores
#' @param cutoff.x numeric (optional), x coordinate of the cutoff text. Default
#'   is the cutoff position + 3
#' @param cutoff.y numeric (optional), y coordinate of the cutoff text. Default
#'   is the cutoff value
#' @param cutoff.label string (optional), cutoff text. Default is
#'   'cutoff: <cutoff rounded to 2 decimals>'
#' @param title.A.ylab string, y-lab title for figure A. Default is 'Risk Score'
#' @param title.B.ylab string, y-lab title for figure B. Default is 'Survival Time'
#' @param title.A.legend string, legend title for figure A. Default is 'Risk Group'
#' @param title.B.legend string, legend title for figure B. Default is 'Status'
#' @param title.C.legend string, legend title for figure C. Default is 'Expression'
#' @param size.ABC numeric, size for the A/B/C panel labels. Default is 1.5
#' @param size.ylab.title numeric, size for y-axis label title. Default is 14
#' @param size.Atext numeric, size for y-axis text in figure A. Default is 11
#' @param size.Btext numeric, size for y-axis text in figure B. Default is 11
#' @param size.Ctext numeric, size for y-axis text in figure C. Default is 11
#' @param size.yticks numeric, size for y-axis ticks. Default is 0.5
#' @param size.yline numeric, size for y-axis line. Default is 0.5
#' @param size.points numeric, size for scatter points. Default is 2
#' @param size.dashline numeric, size for the cutoff line. Default is 1
#' @param size.cutoff numeric, size for cutoff text. Default is 5
#' @param size.legendtitle numeric, size for legend title. Default is 13
#' @param size.legendtext numeric, size for legend text. Default is 12
#' @param color.A colours for figure A, a vector named 'low' and 'high'.
#'   Default is low = 'blue', high = 'red'
#' @param color.B colours for figure B, a vector whose names are literally
#'   'code.0' and 'code.1'. Default is code.0 = 'blue', code.1 = 'red'
#' @param color.C colours for figure C, used positionally as low, mid and high
#'   of the gradient. Default is low = 'blue', median = 'white', high = 'red'
#' @param vjust.A.ylab numeric, vertical justification for y-label in figure A.
#'   Default is 1
#' @param vjust.B.ylab numeric, vertical justification for y-label in figure B.
#'   Default is 2
#' @param family font family. Default is 'sans'
#' @param expand.x numeric, additive expansion on both ends of the x-axis.
#'   Default is 3
#' @param relative_heights numeric, relative heights for figure A, B, colored
#'   side bar and heatmap. Default is 0.1 0.1 0.01 and 0.15
#' @importFrom ggplot2 aes aes_string geom_point geom_vline theme element_blank element_text scale_colour_hue coord_trans
#' @importFrom ggplot2 ylab geom_tile unit scale_fill_gradient2 scale_x_continuous geom_raster theme_classic annotate
#' @importFrom ggplot2 scale_color_manual element_line scale_fill_manual ggplot
#' @importFrom stats as.formula median sd cor
#' @importFrom stats predict update
#' @importFrom rms cph
#' @return The object returned by `egg::ggarrange()` stacking the four panels.
#' @export
ggrisk <- function(fit, heatmap.genes = NULL, data = NULL, code.0 = "Alive",
                   code.1 = "Dead", code.highrisk = "High", code.lowrisk = "Low",
                   cutoff.show = TRUE, cutoff.value = "median", cutoff.x = NULL,
                   cutoff.y = NULL, cutoff.label = NULL,
                   title.A.ylab = "Risk Score", title.B.ylab = "Survival Time",
                   title.A.legend = "Risk Group", title.B.legend = "Status",
                   title.C.legend = "Expression", size.ABC = 1.5,
                   size.ylab.title = 14, size.Atext = 11, size.Btext = 11,
                   size.Ctext = 11, size.yticks = 0.5, size.yline = 0.5,
                   size.points = 2, size.dashline = 1, size.cutoff = 5,
                   size.legendtitle = 13, size.legendtext = 12,
                   color.A = c(low = "blue", high = "red"),
                   color.B = c(code.0 = "blue", code.1 = "red"),
                   color.C = c(low = "blue", median = "white", high = "red"),
                   vjust.A.ylab = 1, vjust.B.ylab = 2, family = "sans",
                   expand.x = 3, relative_heights = c(0.1, 0.1, 0.01, 0.15)) {
  # to.cph() is defined elsewhere; presumably it normalises coxph fits to cph.
  fit <- to.cph(fit)
  if (is.null(data)) {
    stop("'data' must be a data frame containing the variables of the model",
         call. = FALSE)
  }
  data <- data[, all.vars(fit$terms)]
  # model.y() yields the column names of the survival time and event
  # (used below as column indices), in that order.
  outcome <- ggrisk:::model.y(fit)
  time <- outcome[1]
  event <- outcome[2]

  riskscore <- predict(fit, type = "lp", newdata = data)
  data2 <- cbind(data, riskscore)
  data3 <- data2[order(data2$riskscore), ]

  # ---- Cutoff point ----
  if (cutoff.value == "roc") {
    cutoff.point <- cutoff::roc(score = data3$riskscore,
                                class = data3[[event]])$cutoff
  } else if (cutoff.value == "cutoff") {
    rs <- cutoff::cox(data = data3, time = time, y = event, x = "riskscore",
                      cut.numb = 1, n.per = 0.1, y.per = 0.1, round = 20)
    # NOTE(review): presumably coerces p.adjust to numeric, filling values
    # that cannot be converted with 1 -- confirm against fastStat.
    fastStat::to.numeric(rs$p.adjust) <- 1
    # First cut with the smallest adjusted p-value (NA p-values are skipped).
    cutoff.point <- rs$cut1[which.min(rs$p.adjust)]
  } else if (cutoff.value == "median") {
    cutoff.point <- median(x = data3$riskscore, na.rm = TRUE)
  } else {
    # User-supplied cutoff; accept numbers given as strings, otherwise the
    # comparisons below would be lexicographic.
    cutoff.point <- suppressWarnings(as.numeric(cutoff.value))
    if (length(cutoff.point) != 1 || is.na(cutoff.point)) {
      stop("cutoff.value must be 'median', 'roc', 'cutoff' or a single number",
           call. = FALSE)
    }
  }
  # Rows with missing covariates get NA scores; ignore them here.
  score.range <- range(riskscore, na.rm = TRUE)
  if (cutoff.point < score.range[1] || cutoff.point > score.range[2]) {
    stop("cutoff must be between ", score.range[1], " and ", score.range[2])
  }

  # ---- Risk groups ----
  # Direction check: if survival probability at the median follow-up time
  # falls as the score rises, high scores mean high risk; otherwise low do.
  times <- median(data3[[time]], na.rm = TRUE)
  prob <- (rms::Survival(fit))(times = times, lp = riskscore)
  correlation <- cor(prob, riskscore, method = "spearman",
                     use = "complete.obs")
  is.high <- if (correlation < 0) {
    data3$riskscore > cutoff.point
  } else {
    data3$riskscore < cutoff.point
  }
  risk.group <- ifelse(is.high, code.highrisk, code.lowrisk)
  data4 <- data3
  data4[["Risk Group"]] <- risk.group

  # x position of the cutoff line: last exact match, else the nearest score.
  cut.position <- which(data4$riskscore == cutoff.point)
  if (length(cut.position) == 0) {
    cut.position <- which.min(abs(data4$riskscore - cutoff.point))
  } else {
    cut.position <- max(cut.position)
  }
  data4$riskscore <- round(data4$riskscore, 1)
  data4[[time]] <- round(data4[[time]], 1)
  n <- nrow(data4)

  # Theme pieces shared by several panels.
  blank.panel <- theme(panel.grid = element_blank(),
                       panel.background = element_blank())
  hide.x <- theme(axis.ticks.x = element_blank(), axis.line.x = element_blank(),
                  axis.text.x = element_blank(), axis.title.x = element_blank())
  legend.theme <- theme(
    legend.title = element_text(size = size.legendtitle, family = family),
    legend.text = element_text(size = size.legendtext, family = family)
  )

  # ---- Panel A: risk scores ----
  color.A <- c(color.A["low"], color.A["high"])
  names(color.A) <- c(code.lowrisk, code.highrisk)
  # aes_string() receives constant vectors (not column names), so covariates
  # that happen to share a name with a local variable cannot interfere.
  fA <- ggplot(data = data4,
               aes_string(x = seq_len(n), y = data4$riskscore,
                          color = factor(risk.group))) +
    geom_point(size = size.points) +
    scale_color_manual(name = title.A.legend, values = color.A) +
    geom_vline(xintercept = cut.position, linetype = "dotted",
               size = size.dashline) +
    blank.panel +
    hide.x +
    theme(
      axis.title.y = element_text(size = size.ylab.title, vjust = vjust.A.ylab,
                                  angle = 90, family = family),
      axis.text.y = element_text(size = size.Atext, family = family),
      axis.line.y = element_line(size = size.yline, colour = "black"),
      axis.ticks.y = element_line(size = size.yticks, colour = "black")
    ) +
    legend.theme +
    coord_trans() +
    ylab(title.A.ylab) +
    scale_x_continuous(expand = c(0, expand.x))
  fA_data <<- data4[, "riskscore", drop = FALSE]
  if (cutoff.show) {
    if (is.null(cutoff.label)) {
      cutoff.label <- paste0("cutoff: ", round(cutoff.point, 2))
    }
    if (is.null(cutoff.x)) {
      cutoff.x <- cut.position + 3
    }
    if (is.null(cutoff.y)) {
      cutoff.y <- cutoff.point
    }
    fA <- fA + annotate("text", x = cutoff.x, y = cutoff.y,
                        label = cutoff.label, family = family,
                        size = size.cutoff, fontface = "plain",
                        colour = "black")
  }

  # ---- Panel B: survival time and status ----
  color.B <- c(color.B["code.0"], color.B["code.1"])
  names(color.B) <- c(code.0, code.1)
  status <- factor(ifelse(data4[[event]] == 1, code.1, code.0))
  fB <- ggplot(data = data4,
               aes_string(x = seq_len(n), y = data4[[time]], color = status)) +
    geom_point(size = size.points) +
    scale_color_manual(name = title.B.legend, values = color.B) +
    geom_vline(xintercept = cut.position, linetype = "dotted",
               size = size.dashline) +
    blank.panel +
    hide.x +
    theme(
      axis.title.y = element_text(size = size.ylab.title, vjust = vjust.B.ylab,
                                  angle = 90, family = family),
      axis.text.y = element_text(size = size.Btext, family = family),
      axis.ticks.y = element_line(size = size.yticks),
      axis.line.y = element_line(size = size.yline, colour = "black")
    ) +
    legend.theme +
    ylab(title.B.ylab) +
    coord_trans() +
    scale_x_continuous(expand = c(0, expand.x))
  fB_data <<- data4[, c(time, event)]

  # ---- Side bar: risk group per patient ----
  middle <- ggplot(data4, aes(x = seq_len(nrow(data4)), y = 1)) +
    geom_tile(aes(fill = `Risk Group`)) +
    scale_fill_manual(name = title.A.legend, values = color.A) +
    theme(panel.grid = element_blank(), panel.background = element_blank(),
          axis.line = element_blank(), axis.ticks = element_blank(),
          axis.text = element_blank(), axis.title = element_blank(),
          plot.margin = unit(c(0.15, 0, -0.3, 0), "cm")) +
    legend.theme +
    scale_x_continuous(expand = c(0, expand.x))

  # ---- Panel C: heatmap of covariates ----
  if (is.null(heatmap.genes)) {
    heatmap.genes <- set::not(colnames(data4),
                              c(time, event, "Risk Group", "riskscore"))
  }
  data5 <- data4[, heatmap.genes, drop = FALSE]
  data6 <- cbind(id = seq_len(nrow(data5)), data5)
  heat.long <- do::reshape_toLong(data = data6, var.names = colnames(data5))
  data7 <<- heat.long
  fC <- ggplot(heat.long, aes_string(x = "id", y = "variable", fill = "value")) +
    geom_raster() +
    theme(panel.grid = element_blank(), panel.background = element_blank(),
          axis.line = element_blank(), axis.ticks = element_blank(),
          axis.text.x = element_blank(), axis.title = element_blank(),
          plot.background = element_blank()) +
    scale_fill_gradient2(name = title.C.legend, low = color.C[1],
                         mid = color.C[2], high = color.C[3]) +
    theme(axis.text = element_text(size = size.Ctext, family = family)) +
    legend.theme +
    scale_x_continuous(expand = c(0, expand.x))

  egg::ggarrange(fA, fB, middle, fC, ncol = 1,
                 labels = c("A", "B", "C", ""),
                 label.args = list(gp = grid::gpar(font = 2, cex = size.ABC,
                                                   family = family)),
                 heights = relative_heights)
}
# yikeshu0611/TCGAimmunelncRNA documentation built on Dec. 23, 2021, 7:20 p.m.