# R/distribution.R

# S4 generic for summarising (and optionally plotting) a distribution derived
# from a result object. Dispatch is on 'result'; any further arguments are
# passed through '...' to the selected method (see the ClassifyResult method).
setGeneric("distribution",
           function(result, ...) standardGeneric("distribution"))

# Summarise, and optionally plot, how often each feature was selected or how
# often each sample was misclassified across the resampling iterations stored
# in a ClassifyResult.
#
# Arguments:
#   result      A ClassifyResult object.
#   dataType    "features": selection frequency of each feature.
#               "samples": each sample's error rate over all its predictions.
#   plotType    "density" (ggplot2::stat_density drawn as a path) or
#               "histogram" (ggplot2::geom_histogram).
#   summaryType "percentage" divides feature counts by the number of
#               cross-validations and multiplies by 100; "count" keeps raw
#               counts. Only applies to features.
#   plot        If TRUE, the distribution is drawn with ggplot2.
#   xMax        Upper x-axis limit. NULL means max(scores) for features and 1
#               for samples.
#   xLabel, yLabel, title  Axis labels and plot title.
#   fontSizes   Font sizes of the title, the axis titles and the axis text.
#   ...         Further arguments passed to the density or histogram layer.
#
# Value: for features, a table of selection counts or percentages named by
# feature; for samples, a numeric vector of error rates named by sample, with
# NA for samples that have no predictions.
setMethod("distribution", "ClassifyResult",
          function(result, dataType = c("features", "samples"),
                   plotType = c("density", "histogram"), summaryType = c("percentage", "count"),
                   plot = TRUE, xMax = NULL, xLabel = "Percentage of Cross-validations",
                   yLabel = "Density", title = "Distribution of Feature Selections",
                   fontSizes = c(24, 16, 12), ...)
{
  # Fail before doing any work if plotting was requested but is impossible.
  if(plot == TRUE && !requireNamespace("ggplot2", quietly = TRUE))
    stop("The package 'ggplot2' could not be found. Please install it.")

  dataType <- match.arg(dataType)
  plotType <- match.arg(plotType)
  summaryType <- match.arg(summaryType)

  if(dataType == "samples")
  {
    allPredictions <- do.call(rbind, predictions(result))
    # A sample's error rate is the fraction of its predictions that differ from
    # its known class. Comparing labels directly is valid for any number of
    # classes and does not depend on which classes happened to be predicted,
    # unlike summing the off-diagonal cells of a (possibly non-square)
    # confusion table.
    errors <- by(allPredictions, allPredictions[, "sample"], function(samplePredictions)
    {
      knownClass <- actualClasses(result)[samplePredictions[1, "sample"]]
      # NOTE(review): assumes the predicted class is held in the second column.
      mean(as.character(samplePredictions[, 2]) != as.character(knownClass))
    })
    scores <- rep(NA, length(sampleNames(result)))
    # NOTE(review): assumes the 'sample' column holds numeric sample indices.
    scores[as.numeric(names(errors))] <- errors
    names(scores) <- sampleNames(result)
  } else { # features
    chosenFeatures <- features(result)
    firstSet <- chosenFeatures[[1]]
    if(is.character(firstSet)) # Character feature IDs.
      allFeatures <- unlist(chosenFeatures)
    else if(is.data.frame(firstSet)) # Data set and feature ID columns.
      allFeatures <- do.call(rbind, chosenFeatures)
    else if(inherits(firstSet, "Pairs")) # Pairs of features.
      allFeatures <- as.data.frame(do.call(c, unname(chosenFeatures)))
    else if(is.character(firstSet[[1]])) # Two-level list, such as generated by permuting and folding.
      allFeatures <- unlist(chosenFeatures)
    else if(is.data.frame(firstSet[[1]])) # Two-level list of data set and feature ID columns.
      allFeatures <- do.call(rbind, lapply(chosenFeatures, function(iteration) do.call(rbind, iteration)))
    else if(inherits(firstSet[[1]], "Pairs")) # Two-level list of feature pairs.
      allFeatures <- as.data.frame(do.call(c, unname(lapply(chosenFeatures, function(iteration) do.call(c, unname(iteration))))))
    else
      stop("Unrecognised format of the selected features.")

    if(is.data.frame(allFeatures))
    {
      if(all(c("feature", "dataset") %in% colnames(allFeatures)))
        allFeatures <- paste(allFeatures[, "feature"], paste0("(", allFeatures[, "dataset"], ")"))
      else # Feature pairs, held in 'first' and 'second' columns.
        allFeatures <- paste(allFeatures[, "first"], allFeatures[, "second"], sep = ", ")
    }
    scores <- table(allFeatures)

    if(summaryType == "percentage")
    {
      crossValidations <- length(chosenFeatures)
      if(result@validation[[1]] == "permuteFold") # Permutations times folds.
        crossValidations <- crossValidations * length(chosenFeatures[[1]])
      scores <- scores / crossValidations * 100
    }
  }

  if(plot == TRUE)
  {
    if(is.null(xMax))
      xMax <- if(dataType == "features") max(scores) else 1 # Error rates lie in [0, 1].

    # Use the package's plotting theme for this call only: theme_set returns
    # the previous theme, which is restored when the method exits.
    previousTheme <- ggplot2::theme_set(ggplot2::theme_classic() +
                                        ggplot2::theme(panel.border = ggplot2::element_rect(fill = NA)))
    on.exit(ggplot2::theme_set(previousTheme), add = TRUE)

    extras <- list(...) # An empty list when no extra layer arguments are given.
    if(plotType == "density")
      plottedGeom <- do.call(ggplot2::stat_density, c(geom = "path", position = "identity", extras))
    else # Histogram plot.
      plottedGeom <- do.call(ggplot2::geom_histogram, extras)

    plotData <- data.frame(scores = as.numeric(scores))
    print(ggplot2::ggplot(plotData, ggplot2::aes(x = scores)) + plottedGeom + ggplot2::xlim(0, xMax) +
          ggplot2::labs(x = xLabel, y = yLabel) + ggplot2::ggtitle(title) +
          ggplot2::theme(axis.title = ggplot2::element_text(size = fontSizes[2]),
                         axis.text = ggplot2::element_text(colour = "black", size = fontSizes[3]),
                         plot.title = ggplot2::element_text(size = fontSizes[1], hjust = 0.5)))
  }

  scores
})

# Try the ClassifyR package in your browser

# Any scripts or data that you put into this service are public.

# ClassifyR documentation built on Nov. 8, 2020, 6:53 p.m.