library(toftools)
library(magrittr)
library(purrr)
library(stringr)
library(knitr)
library(dplyr)
library(tidyr)

# Report-wide chunk defaults: hide source code, messages and warnings
knitr::opts_chunk$set(echo = FALSE, message = FALSE, warning = FALSE)

# Path of the ToF file to analyse, taken from the document parameters
file <- params$tof_file

# Exactly one path must be supplied ...
if (length(file) != 1) {
  stop("Expected parameter 'tof_file' to contain a single file path")
}

# ... and it must point at a file that exists
if (!file.exists(file)) {
  stop("Cannot find file: ", file)
}

# {.tabset .tabset-fade}

# Load the file; purrr::quietly() captures the loader's printed output,
# messages and warnings, and only the returned value is kept
tof <- purrr::quietly(load_tof)(file)[["result"]]

# A single-column result is treated as "not a ToF file"
if (ncol(tof) == 1) {
  stop(stringr::str_glue("This does not look like a ToF file: {file}"))
}

# Add a 'label' column (T/H) derived from the sample names
tof <- mutate(tof, label = extract_labels(sample))
# Put 'sample' and 'label' first, followed by every remaining column
tof <- select(tof, sample, label, everything())

labels <- tof[["label"]]

# Number of samples in each class
sample_count <- count(tof, label, name = "number")

knitr::kable(sample_count, caption = "samples per class")

# Matrix of everything except the two leading columns (sample, label),
# with one row per sample, named by sample
data_mat <- as.matrix(tof[, -c(1, 2)])
rownames(data_mat) <- tof[["sample"]]

# Classifiers to evaluate, as a single named vector:
#   name  = model identifier (passed as `model` to the cross-validation step)
#   value = human-readable description printed under the model's heading
# Keeping identifier and description side by side means an entry can be
# commented in or out in one place without the two falling out of step.
model_summaries <- c(
  # xgbTree = "extreme gradient boosting XGBoost (https://xgboost.readthedocs.io)",
  rf = "random forest",
  # ranger = "random forest using ranger (https://github.com/imbs-hl/ranger)",
  glmnet = "logistic regression with elastic net regularisation using glmnet (https://cran.r-project.org/web/packages/glmnet/index.html)",
  gaussprRadial = "Gaussian Process classifier radial basis kernel (using 'kernlab' R package)",
  # lda2 = "Linear discriminant analysis",
  nnet = "Neural network using 'nnet' package",
  svmRadial = "Support Vector Machine (RBF kernel)"
)

# Model identifiers, in the order they are fitted and reported
models <- names(model_summaries)

# Fit and report one cross-validated classifier per entry in `models`.
# Each iteration writes a level-2 markdown heading and the model description,
# then prints the ROC plot, a metrics table, the top-15 feature importances
# and the predictive probabilities. When tuning is enabled it also plots
# performance across hyper-parameter settings.
# NOTE(review): the cat()-emitted markdown presumably relies on the enclosing
# chunk being rendered with results = "asis" -- confirm in the original .Rmd.
for (model in models) {
  cat("\n## ", model, "\n")
  cat(model_summaries[model], "\n\n")

  # Capture a fitting failure as a condition object rather than letting a
  # bare try() print it to stderr, so the reason can appear in the report.
  classifier <- tryCatch(
    purrr::quietly(crossvalidate)(
      data_mat, labels,
      model = model, n_folds = 10, tune = params$tune
    )$result,
    error = function(err) err
  )

  if (inherits(classifier, "error")) {
    # cat() rather than print(): avoids the `[1] "..."` vector decoration
    cat(
      "Error fitting classifier: ", conditionMessage(classifier), "\n\n",
      sep = ""
    )
    next
  }

  print(crossvalidation_roc(classifier))
  cat("\n")
  print(knitr::kable(crossvalidation_metrics(classifier)))
  cat("\n")
  print(crossvalidation_feature_importance(classifier, n_features = 15))
  cat("\n")
  print(knitr::kable(crossvalidation_predictive_probabilities(classifier)))
  cat("\n")

  if (params$tune) {
    cat("Predictive accuracy on different hyper-parameter settings (best performing selected):")
    cat("\n")
    # Plotting the tuning results is best-effort: if it fails, show the
    # error text in place of the plot and carry on with the next model.
    tryCatch({
      print(ggplot2::ggplot(classifier))
    }, error = function(err) {
      cat(conditionMessage(err))
    })
    cat("\n")
  }
}


# JimSkinner/toftools documentation built on Oct. 30, 2019, 7:55 p.m.