R/plot_PCA.R

Defines functions plot_PCA

Documented in plot_PCA

#' Plot a PCA analysis obtained with do_PCA
#'
#' Draws either one scatter plot per pair of principal components (the
#' default) or, with `parallel = TRUE`, a single parallel coordinates plot of
#' the components rescaled to the range -1 to 1.
#'
#' @param pc result of do_PCA function. This function reads `pc$pc` (the
#'   scores, one column per component), `pc$var` (used for the axis labels)
#'   and `pc$cumvar` (used for the plot title).
#' @param level1 first treatment factor (mapped to colour)
#' @param level2 second treatment factor (mapped to point shape). Default
#'   NULL (not used)
#' @param level1_name legend title for `level1`
#' @param level2_name legend title for `level2`
#' @param interactive Generate an interactive plot with plotly?
#' @param type decoration drawn around each group of points, where a group is
#'   a level of `level1` or, when `level2` is given, a combination of both:
#'   "none", "polygons" (convex hulls) or "ellipses" (`stat_ellipse`).
#'   "centroids" is accepted for backward compatibility but currently adds
#'   nothing; use `draw_lines` to draw centroids.
#' @param draw_lines draw the centroid of each group and a segment from the
#'   centroid to every point of the group?
#' @param arrange_plots combine all the pairwise plots into one figure? If
#'   FALSE each plot is printed separately.
#' @param n number of axes to plot. Default NULL (all)
#' @param parallel parallel coordinates plot?
#' @param line_size size of line for parallel coordinate plot
#' @return With `parallel = TRUE`, a ggplot object. Otherwise, when
#'   `arrange_plots = TRUE`, the value returned by `multiplot()` (or by
#'   `plotly::subplot()` when `interactive = TRUE`); when
#'   `arrange_plots = FALSE`, the list of ggplot objects, invisibly.
#' @export
plot_PCA <- function(pc, level1, level2 = NULL, level1_name = "Treatment",
                     level2_name = "Condition", interactive = FALSE,
                     type = c("none", "polygons", "centroids", "ellipses"),
                     draw_lines = FALSE, arrange_plots = TRUE,
                     n = NULL, parallel = FALSE, line_size = 0.75) {

  scores <- pc$pc

  # Resolve the default for n *before* it is used to index the cumulative
  # variance and to build the title.
  if (is.null(n)) {
    n <- ncol(scores)
  }
  if (length(n) != 1L || is.na(n) || n < 1 || n > ncol(scores)) {
    stop("'n' must be a single number between 1 and ", ncol(scores),
         call. = FALSE)
  }

  total_var <- round(pc$cumvar, 1)[n]
  plot_title <- paste0("PCA ( ", n, " PCs, ", total_var,
                       "% explained variance)")

  # ---- Parallel coordinates plot ------------------------------------------
  if (parallel) {
    # rescale every component to -1/1
    rescale_fun <- function(x) 2 * ((x - min(x)) / (max(x) - min(x))) - 1
    scaled <- apply(scores, 2, rescale_fun)
    # drop = FALSE keeps a matrix when a single component is requested
    scaled <- scaled[, seq_len(n), drop = FALSE]

    long <- data.frame(var = seq_len(nrow(scaled)), group = level1, scaled)
    long <- reshape2::melt(long, id.vars = c("var", "group"))
    # the component index (1, 2, ...) is used as a numeric x position
    long$variable <- as.numeric(as.factor(long$variable))

    # `size` (rather than `linewidth`) is kept so that older ggplot2
    # releases keep working.
    out <- ggplot2::ggplot() +
      ggplot2::geom_line(data = long,
                         ggplot2::aes(x = variable, y = value,
                                      group = var, col = group),
                         size = line_size) +
      ggplot2::xlab("PC") + ggplot2::ylab("Value") +
      ggplot2::ggtitle(plot_title)
    return(out)
  }

  # ---- Pairwise scatter plots ---------------------------------------------
  type <- match.arg(type)

  # Grouping used for hulls, ellipses and centroids.
  if (!is.null(level2)) {
    factor_convex <- paste0(level1, "_", level2)
  } else {
    factor_convex <- level1
  }

  # Build the scatter plot of one pair of components.
  # `input` holds the two selected components in columns 1 and 2 plus a
  # `factor_convex` column; `labs_in` are the two variance labels.
  pair_plot <- function(input, labs_in, axis1, axis2) {
    out <- ggplot2::ggplot()

    if (is.null(level2)) {
      out <- out +
        ggplot2::geom_point(data = input,
                            ggplot2::aes(x = input[, 1], y = input[, 2],
                                         color = level1),
                            show.legend = TRUE, size = 2.3) +
        ggplot2::labs(color = level1_name)
    } else {
      # The shape legend only exists (and needs its title) when a second
      # factor is supplied.
      out <- out +
        ggplot2::geom_point(data = input,
                            ggplot2::aes(x = input[, 1], y = input[, 2],
                                         color = level1, shape = level2),
                            show.legend = TRUE, size = 2.3) +
        ggplot2::labs(color = level1_name, shape = level2_name)
    }

    out <- out +
      ggplot2::xlab(paste0("PC", axis1, " (", labs_in[1], ")")) +
      ggplot2::ylab(paste0("PC", axis2, " (", labs_in[2], ")"))

    if (type == "polygons") {
      # convex hull of every group
      c_h <- plyr::ddply(input, plyr::.(factor_convex),
                         function(df) {
                           df[grDevices::chull(df[, 1], df[, 2]), ]
                         })
      out <- out +
        ggplot2::geom_polygon(data = c_h,
                              ggplot2::aes(x = c_h[, 1], y = c_h[, 2],
                                           group = factor_convex,
                                           fill = factor_convex),
                              alpha = 0.2, show.legend = FALSE) +
        ggplot2::guides(fill = "none")
    } else if (type == "ellipses") {
      out <- out +
        ggplot2::stat_ellipse(data = input,
                              ggplot2::aes(x = input[, 1], y = input[, 2],
                                           color = factor_convex,
                                           group = factor_convex))
    }

    if (draw_lines) {
      # centroid of every group, then one centroid row per observation
      centroids <- stats::aggregate(input[, 1:2],
                                    list(input$factor_convex), mean)
      matched <- centroids[match(input$factor_convex, centroids$Group.1), ]
      input <- cbind(input, matched[, 2:3])
      colnames(input)[4:5] <- c("x.mean", "y.mean")

      out <- out +
        ggplot2::geom_point(data = centroids,
                            ggplot2::aes(centroids[, 2], centroids[, 3]),
                            size = 1.5, show.legend = FALSE) +
        ggplot2::geom_segment(data = input,
                              ggplot2::aes(x = x.mean, y = y.mean,
                                           xend = input[, 1],
                                           yend = input[, 2],
                                           col = factor(factor_convex)),
                              show.legend = FALSE)
    }

    out
  }

  data_labs <- paste0(round(pc$var, 1), "%")
  data <- data.frame(scores)

  # every pair of the first n components (a single column when n is 1)
  if (n > 1) {
    comb <- utils::combn(seq_len(n), 2)
  } else {
    comb <- matrix(1)
  }

  this_plot <- lapply(seq_len(ncol(comb)), function(i) {
    which_cols <- comb[, i]
    data_i <- data.frame(data[, which_cols], factor_convex)
    pair_plot(data_i, labs_in = data_labs[which_cols],
              axis1 = which_cols[1], axis2 = which_cols[2])
  })

  # grid layout: two columns for up to three plots, three columns otherwise
  n_plots <- length(this_plot)
  if (n_plots == 1) {
    n_rows <- n_cols <- 1
  } else {
    n_cols <- if (n_plots <= 3) 2 else 3
    n_rows <- ceiling(n_plots / n_cols)
  }

  if (!arrange_plots) {
    this_plot[[1]] <- this_plot[[1]] + ggplot2::ggtitle(plot_title)
    for (this in this_plot) {
      if (interactive) {
        # plotly widgets are displayed through print()
        print(plotly::ggplotly(this))
      } else {
        print(this)
      }
    }
    return(invisible(this_plot))
  }

  if (interactive) {
    # subplot() accepts a list of ggplot2/plotly objects directly
    out <- plotly::subplot(this_plot, nrows = n_rows)
  } else {
    out <- multiplot(plotlist = this_plot, cols = n_cols, top = plot_title)
  }

  out
}
leandroroser/RNASeqFunctions documentation built on May 17, 2019, 7:31 p.m.