# R/plotSplines.R

# Defines functions plotSplines

# Documented in plotSplines

#' Plot splines used by the CytoNorm model
#'
#' @param model Model as generated by CytoNorm.train
#' @param batches Batches to include. One plot per batch is generated.
#'                Default = all batches used in the model.
#' @param channels Channels to include. Default = first three channels used
#'                 (or all of them if the model has fewer than three).
#' @param clusters Clusters to include. Default = all clusters.
#' @param groupClusters Logical, if TRUE all clusters are shown on one subplot
#'                      per channel and distinguished by color,
#'                      if FALSE, there will be a separate row per cluster.
#'
#' @returns Named list with one plot per batch. If \code{groupClusters} is
#'          FALSE, the figure shows a grid with the specified clusters in rows
#'          and the specified markers in columns. In every subfigure, black
#'          dots indicate the quantiles used by the model and a red line shows
#'          the spline. If \code{groupClusters} is TRUE, there is one subfigure
#'          per marker in which dots and splines are colored by cluster.
#'          All plots share the same axis limits, derived from the quantiles
#'          of all batches and clusters in the model. For a batch in which all
#'          selected quantiles are NA, a text-only placeholder plot is
#'          returned instead.
#'
#' @examples
#'
#' dir <- system.file("extdata", package = "CytoNorm")
#' files <- list.files(dir, pattern = "fcs$")
#' data <- data.frame(File = files,
#'                    Path = file.path(dir, files),
#'                    Type = stringr::str_match(files, "_([12]).fcs")[,2],
#'                    Batch = stringr::str_match(files, "PTLG[0-9]*")[,1],
#'                    stringsAsFactors = FALSE)
#' data$Type <- c("1" = "Train", "2" = "Validation")[data$Type]
#' train_data <- dplyr::filter(data, Type == "Train")
#' validation_data <- dplyr::filter(data, Type == "Validation")
#'
#' ff <- flowCore::read.FCS(data$Path[1])
#' channels <- grep("Di$", flowCore::colnames(ff), value = TRUE)
#' transformList <- flowCore::transformList(channels,
#'                                          cytofTransform)
#' transformList.reverse <- flowCore::transformList(channels,
#'                                                  cytofTransform.reverse)
#'
#' model <- CytoNorm.train(files = train_data$Path,
#'                         labels = train_data$Batch,
#'                         channels = channels,
#'                         transformList = transformList,
#'                         FlowSOM.params = list(nCells = 10000, #1000000
#'                                               xdim = 10,
#'                                               ydim = 10,
#'                                               nClus = 5,
#'                                               scale = FALSE),
#'                         normParams = list(nQ = 99,
#'                                           limit = c(0,7)),
#'                         seed = 1)
#'
#' plotlist <- plotSplines(model)
#' plotlist <- plotSplines(model, groupClusters = TRUE)
#' plotlist[[1]]
#'
#' @export
#' @importFrom ggplot2 ggplot aes geom_point facet_grid facet_wrap geom_abline
#'             geom_line .data theme_minimal xlim ylim xlab ylab ggtitle
#'             geom_text theme_void
plotSplines <- function(model,
                        batches = names(model$clusterRes[[1]]$splines),
                        channels = model$clusterRes[[1]]$channels[
                            seq_len(min(3, length(model$clusterRes[[1]]$channels)))],
                        clusters = names(model$clusterRes),
                        groupClusters = FALSE){

    # Shared axis limits: span of the quantiles of *all* clusters and batches
    # in the model (for the selected channels), padded by 10% on both sides.
    quantile_values <- unlist(lapply(model$clusterRes, function(cluster_res){
        lapply(cluster_res$quantiles, function(batch_quantiles){
            batch_quantiles[, channels]
        })
    }), use.names = FALSE)
    quantile_values <- quantile_values[!is.na(quantile_values)]

    if(length(quantile_values) > 0){
        value_range <- range(quantile_values)
        padding <- 0.1 * diff(value_range)
        minValue <- value_range[1] - padding
        maxValue <- value_range[2] + padding
        # 101 evenly spaced x positions at which every spline is evaluated
        spline_x <- seq(minValue, maxValue, length.out = 101)
    } else {
        # No usable quantiles at all: every batch below ends up in the
        # placeholder branch, so the limits and the grid are never used.
        minValue <- NA_real_
        maxValue <- NA_real_
        spline_x <- numeric(0)
    }

    # Channels are shown with their pretty names only if every selected
    # channel has one; otherwise the raw channel names are kept.
    pretty_names <- model$fsom$prettyColnames
    use_pretty_names <- all(channels %in% names(pretty_names))

    # Stack one data frame per cluster/channel combination and turn the
    # cluster and channel columns into factors that fix the facet order.
    build_long_frame <- function(make_frame){
        per_cluster <- lapply(clusters, function(cluster){
            do.call(rbind, lapply(channels, function(channel){
                make_frame(cluster, channel)
            }))
        })
        long <- do.call(rbind, per_cluster)
        long$cluster <- factor(long$cluster, levels = clusters)
        if(use_pretty_names){
            # unique() guards against duplicated pretty names, which are not
            # allowed as factor levels
            long$channel <- factor(unname(pretty_names[long$channel]),
                                   levels = unique(pretty_names))
        }
        long
    }

    plotlist <- list()
    for(batch in batches){
        # Quantiles of this batch (x) against the goal quantiles (y)
        ref_points <- build_long_frame(function(cluster, channel){
            cluster_res <- model$clusterRes[[cluster]]
            # stringsAsFactors = FALSE keeps the channel column a character
            # vector (also on R < 4.0), so that it indexes the pretty names
            # by name rather than by factor code.
            data.frame("cluster" = cluster,
                       "channel" = channel,
                       "batch_quantiles" = cluster_res$quantiles[[batch]][, channel],
                       "goal_quantiles" = cluster_res$refQuantiles[, channel],
                       stringsAsFactors = FALSE)
        })

        if(all(is.na(ref_points$batch_quantiles))){
            # Nothing to draw for this batch: return a text-only placeholder
            # without evaluating any spline.
            plotlist[[batch]] <- ggplot2::ggplot() +
                ggplot2::geom_text(ggplot2::aes(x = .data$x, y = .data$y,
                                                label = .data$label),
                                   data.frame(x = 0, y = 0,
                                              label = paste("Batch ", batch, ": NA"))) +
                ggplot2::theme_void()
            next
        }

        # Spline of this batch evaluated on the shared x grid
        spline_points <- build_long_frame(function(cluster, channel){
            spline <- model$clusterRes[[cluster]]$splines[[batch]][[channel]]
            data.frame("cluster" = cluster,
                       "channel" = channel,
                       "x" = spline_x,
                       "y" = spline(spline_x),
                       stringsAsFactors = FALSE)
        })

        # Grey diagonal = identity mapping (no change by the normalization)
        p <- ggplot2::ggplot(ref_points) +
            ggplot2::geom_abline(slope = 1, intercept = 0, col = "#999999") +
            ggplot2::theme_minimal() +
            ggplot2::xlim(minValue, maxValue) +
            ggplot2::ylim(minValue, maxValue) +
            ggplot2::xlab("Original distribution") +
            ggplot2::ylab("Goal distribution") +
            ggplot2::ggtitle(paste("Batch", batch))

        if(!groupClusters){
            # One row per cluster, one column per channel
            p <- p +
                ggplot2::facet_grid(.data$cluster ~ .data$channel) +
                ggplot2::geom_line(ggplot2::aes(x = .data$x, y = .data$y),
                                   data = spline_points, col = "#b30000") +
                ggplot2::geom_point(ggplot2::aes(x = .data$batch_quantiles,
                                                 y = .data$goal_quantiles))
        } else {
            # One panel per channel, clusters distinguished by color
            p <- p +
                ggplot2::facet_wrap(~ .data$channel) +
                ggplot2::geom_line(ggplot2::aes(x = .data$x, y = .data$y,
                                                color = .data$cluster),
                                   data = spline_points) +
                ggplot2::geom_point(ggplot2::aes(x = .data$batch_quantiles,
                                                 y = .data$goal_quantiles,
                                                 color = .data$cluster))
        }

        plotlist[[batch]] <- p
    }
    plotlist
}
# saeyslab/CytoNorm documentation built on Nov. 2, 2024, 12:39 p.m.