library(toftools)
library(magrittr)
library(purrr)
library(stringr)
library(knitr)
library(dplyr)
library(tidyr)

knitr::opts_chunk$set(message = FALSE, warning = FALSE, echo = FALSE)

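# Report parameters: X is the sample-by-feature matrix and y the vector of
# class labels, both supplied via the params list when the report is rendered.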
X <- params$X
y <- params$y

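# Basic input checks: X must be a matrix with named rows and y a factor.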
stopifnot(is.matrix(X))
stopifnot(!is.null(rownames(X)))
stopifnot(is.factor(y))
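
# A minimal sketch of how this report might be rendered; the file name below is
# an assumption, not something defined in this template:
#
#   rmarkdown::render(
#     "classification_report.Rmd",
#     params = list(X = my_matrix, y = my_labels, tune = TRUE)
#   )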

# Results {.tabset .tabset-fade}

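# Count the number of samples in each class.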
sample_count <- tibble::tibble(class = y) %>%
  count(class, name = "number")

knitr::kable(sample_count, caption = "samples per class")

if (nrow(sample_count) != 2) {
  stop(paste0("Expected exactly two classes, found ", nrow(sample_count)))
}

models <- c("xgbTree", "glmnet")
model_summaries <- c(
  "extreme gradient boosting XGBoost (https://xgboost.readthedocs.io)",
  "logistic regression with elastic net regularisation using glmnet (https://cran.r-project.org/web/packages/glmnet/index.html)"
)
names(model_summaries) <- models

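# Cross-validate each model in turn and render its results as a separate tab.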
for (model in models) {
  cat("\n## ", model, "\n")
  cat(model_summaries[model], "\n\n")

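  # Run cross-validation for this model; purrr::quietly() captures printed
  # output and messages so only the fitted result is kept.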
  classifier <- purrr::quietly(crossvalidate)(
    X, y,
    model = model, n_folds = 10, tune = params$tune
  )$result

  print(crossvalidation_roc(classifier))
  cat("\n")
  print(knitr::kable(crossvalidation_metrics(classifier)))
  cat("\n")


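  # Use the first sample as a reference matrix for the feature-importance
  # plots, and split each feature name ("GC:IMS") into its two coordinates.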
  reference_mat <- vec_to_mat(X[1,]) # %>% coarsen(900, 900)
  importance <- crossvalidation_feature_importance(classifier, output_dataframe = TRUE) %>%
    filter(importance > 0) %>%
    mutate(
      GC  = feature %>% str_remove(":.*$") %>% as.numeric(),
      IMS = feature %>% str_remove("^.*:") %>% as.numeric()
    )

  print(toftools:::plot_features(importance,
                                 reference_mat,
                                 ims_transformation = "log1p"))

  print(toftools:::plot_features(importance,
                                 reference_mat,
                                 ims_transformation = "log1p",
                                 discretise_features = TRUE))

  cat("\n")

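  # Cross-validated predictive probabilities for each sample.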
  print(knitr::kable(crossvalidation_predictive_probabilities(classifier)))
  cat("\n")
  if (params$tune) {
    cat("Predictive accuracy on different hyper-parameter settings (best performing selected):")
    cat("\n")
    tryCatch({
      print(ggplot2::ggplot(classifier))
    }, error = function(err) {
      cat(err$message)
    })
    cat("\n")
  }
}

