#' Generate distinct colours
#'
#' Hues are spaced evenly over (0, 1]. Saturation is jittered uniformly
#' between 0.8 and 0.9 and value between 0.6 and 0.7, so results differ
#' between calls unless the RNG seed is fixed.
#'
#' @param ncolours number of colours we want
#' @return character vector of `ncolours` hex colours (empty when
#'   `ncolours` is less than 1)
#' @export
generate_colours <- function(ncolours) {
  # `1:ncolours` would iterate over c(1, 0) for an empty request and divide
  # by zero, so handle that case up front.
  if (ncolours < 1) {
    return(character(0))
  }
  hue <- seq_len(ncolours) / ncolours
  # Draw the jitter interleaved (s1, v1, s2, v2, ...) so that a fixed seed
  # gives the same colours as drawing one value at a time per colour.
  jitter <- matrix(runif(2 * length(hue)), nrow = 2)
  saturation <- 0.8 + 0.1 * jitter[1, ]
  value <- 0.6 + 0.1 * jitter[2, ]
  hsv(hue, saturation, value)
}
#' Generate a set of blues
#'
#' Colours are taken from the RColorBrewer "Blues" palette with the first
#' palette entry always skipped. For 7 or more colours the remaining entries
#' of the 8-class palette are interpolated to the requested length.
#'
#' @param n number of colours to generate
#' @return character vector of `n` hex colours
#' @import RColorBrewer
generate_blues <- function(n) {
  if (n >= 7) {
    colours <- colorRampPalette(brewer.pal(8, "Blues")[2:8])(n)
  } else {
    # brewer.pal() warns when asked for fewer than 3 classes, so always
    # request at least 3 and then keep entries 2..(n + 1).
    # seq_len() keeps the selection empty for n = 0 instead of 2:1.
    palette <- brewer.pal(max(n + 1, 3), "Blues")
    colours <- palette[seq_len(n) + 1]
  }
  colours
}
#' Mean AUC over a set of ROC curves
#'
#' Invokes `calc_auc()` on every curve and then averages their `auc` fields.
#'
#' @param rocs a list of ROC_curve objects
#' @return average AUC
average_auc <- function(rocs) {
  # Called for its side effect only; the `auc` field is read afterwards.
  for (roc in rocs) {
    roc$calc_auc()
  }
  auc_values <- sapply(rocs, function(roc) roc$auc)
  mean(auc_values)
}
#' Plot Multiple ROC curves
#'
#' Draws the dashed chance diagonal and overlays one line per ROC curve,
#' using each object's `FP` and `TP` fields as x and y.
#'
#' @param rocs a non-empty list of ROC_curve objects
#' @param title str title to give plot, if null a default title is generated
#' @param info additional info to add to the default title
#' @param scale controls plot cex arguments to size text
#' @export
plot_rocs <- function(rocs, title = NULL, info = "", scale = 0.8) {
  n <- length(rocs)
  # Fail fast: `1:n` on an empty list would iterate over c(1, 0) and stop
  # with a subscript error after the plot had already been started.
  if (n == 0) {
    stop("`rocs` must contain at least one ROC curve", call. = FALSE)
  }
  auc <- average_auc(rocs)
  colours <- generate_blues(n)
  if (is.null(title)) {
    title <- sprintf("%s ROC curves with mean AUC of %.3f", info, auc)
  }
  # Chance line first; it also sets up the unit-square axes.
  plot(0:1, 0:1,
    type = "l", lty = 2, xlab = "False Positive Rate",
    ylab = "True Positive Rate", main = title,
    cex.main = scale, cex.lab = scale, cex.axis = scale
  )
  for (i in seq_len(n)) {
    lines(rocs[[i]]$FP, rocs[[i]]$TP, col = colours[i])
  }
  # A single legend entry (in the last curve's colour) stands for all models.
  legend("bottomright",
    legend = c("Chance", paste0("Models 1-", n)),
    col = c("black", colours[n]), lty = 2:1, cex = scale
  )
}
#' Calculate optimal threshold
#'
#' Builds the pointwise minimum of all AMS curves (floored at 0.1), places it
#' temporarily in the first object's `ams` field and returns that object's
#' `get_max_thresh()` result. The first object's `ams` field is restored
#' before returning, so the input curves are left as they were.
#'
#' @param amss a non-empty list of AMS_data objects whose `ams` fields have
#'   already been calculated
#' @return optimal decision threshold considering all curves
ams_threshold <- function(amss) {
  # One column per curve, one row per threshold. cbind() always yields a
  # matrix, even when each curve has a single point.
  ams_mat <- do.call(cbind, unname(lapply(amss, function(x) x$ams)))
  # minimum bounding curve of the curves in amss
  min_ams <- pmax(apply(ams_mat, 1, min), 1e-1)

  # The first object is borrowed to evaluate get_max_thresh() on the bounding
  # curve. If the objects have reference semantics the assignment below would
  # otherwise overwrite the caller's first curve, so put it back on exit.
  min_ams_obj <- amss[[1]]
  original_ams <- min_ams_obj$ams
  on.exit(min_ams_obj$ams <- original_ams, add = TRUE)
  min_ams_obj$ams <- min_ams
  min_ams_obj$get_max_thresh()
}
#' Plot multiple AMS curves on the same axes
#'
#' Calls `calc_ams()` on every object, then draws each object's `ams` field
#' against its `thresholds` field.
#'
#' @param amss non-empty list of ams objects
#' @param title str title to give plot, if null a default title is generated
#' @param info additional info to add to the default title
#' @param min.max if TRUE, mark the threshold returned by `ams_threshold()`
#'   with a dashed vertical line and report it in the default title;
#'   if FALSE, draw a per-curve legend instead
#' @param scale controls plot cex arguments to size text
#' @export
plot_amss <- function(amss, title = NULL, info = "", min.max = TRUE,
                      scale = 0.8) {
  n <- length(amss)
  if (n == 0) {
    stop("`amss` must contain at least one AMS object", call. = FALSE)
  }
  # Called for its side effect only; the `ams` field is read afterwards.
  for (obj in amss) {
    obj$calc_ams()
  }
  # Snapshot the curves before the threshold is derived so the plot always
  # shows them as calculated, independent of any field updates made later.
  curves <- lapply(amss, function(x) x$ams)
  y_max <- max(unlist(curves))
  colours <- generate_blues(n)

  if (min.max) {
    min_max_thresh <- ams_threshold(amss)
    if (is.null(title)) {
      title <- sprintf(
        "%s AMS plots with threshold at %.3f", info, min_max_thresh
      )
    }
  } else if (is.null(title)) {
    title <- "AMS data"
  }

  plot(amss[[1]]$thresholds, curves[[1]],
    type = "l", col = colours[1], main = title,
    xlab = "Decision Threshold", ylab = "AMS", ylim = c(0, y_max),
    cex.main = scale, cex.lab = scale, cex.axis = scale
  )
  # seq_len(n)[-1] is empty for a single curve, whereas 2:n would be c(2, 1).
  for (i in seq_len(n)[-1]) {
    lines(amss[[i]]$thresholds, curves[[i]], type = "l", col = colours[i])
  }

  if (min.max) {
    abline(v = min_max_thresh, lty = 2)
    legend("bottomleft",
      legend = c("Optimal threshold", paste0("Models 1-", n)),
      col = c("black", colours[n]), lty = 2:1, bg = "white", cex = scale
    )
  } else {
    legend("bottomleft",
      legend = 1:n, col = colours, lty = 1, bg = "white", cex = scale
    )
  }
}
#' Wrap figure saving
#'
#' Opens a graphics device on `filepath`, runs `plot_func()` and closes the
#' device again. The device is closed even when `plot_func()` fails, so no
#' device is left open.
#'
#' @param plot_func partially called function of no arguments to generate plot
#' @param filepath string file path to save to
#' @param filetype graphics device function (e.g. `pdf`, `png`), called as
#'   `filetype(file = filepath)`
#' @return invisibly, the value of `dev.off()`
#' @export
save_fig <- function(plot_func, filepath, filetype = pdf) {
  filetype(file = filepath)
  # Remember which device was just opened so exactly that one is closed,
  # even if plot_func() errors or switches the current device.
  device <- dev.cur()
  closed <- FALSE
  on.exit(if (!closed) dev.off(device), add = TRUE)

  plot_func()

  closed <- TRUE
  invisible(dev.off(device))
}
#' Density plots of variables
#'
#' Given a matrix plot the density of the listed variables using ggplot2
#' facet_wrap.
#'
#' If the matrix X has more than 10,000 samples, a random 10,000 samples will
#' be selected to keep the amount of data plotted reasonable.
#'
#' @param X An nxd matrix with samples as rows and features as columns.
#' @param variables Optional vector of column names to be plotted
#' @param labels Optional vector of class labels (one per row of `X`) to view
#'   distributions by class
#' @return A ggplot object with one density facet per variable
#'
#' @import ggplot2
#' @importFrom tidyr pivot_longer
#' @export
plot_distributions <- function(X, variables = NULL, labels = NULL) {
  # if labels not provided just use 1s
  if (is.null(labels)) {
    labels <- rep(1, nrow(X))
  }

  # if X is large, take a sample
  if (nrow(X) > 10000) {
    idx <- sample(seq_len(nrow(X)), 10000)
    X <- X[idx, , drop = FALSE]
    labels <- labels[idx]
  }

  # if columns specified, select just these
  # drop = FALSE keeps a single selected column as a one-column matrix
  if (!is.null(variables)) {
    X <- X[, variables, drop = FALSE]
  }

  # Labels go in a named first column so that every feature column is
  # pivoted, whether or not X carries row names. They are stored as a factor
  # so each class gets its own density curve.
  plot_data <- data.frame(
    labels = as.factor(labels),
    as.data.frame(X),
    check.names = FALSE
  )

  # unpivot (make a long format)
  plot_data <- pivot_longer(
    plot_data,
    cols = -1, values_to = "Value", names_to = "Variable"
  )

  # create plot
  ggplot(plot_data, aes(Value, colour = labels)) +
    geom_density() +
    theme_minimal() +
    theme(legend.position = "none") +
    facet_wrap(vars(Variable), scales = "free")
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.