# R/val.R
#
# Defines functions val batch_validate
#
# Documented in batch_validate val

#' Batch validate algorithms over a list of dfms
#'
#' Runs \code{val()} once for every combination of algorithm in \code{alg}
#' and dfm in \code{x} and collects the results in a single tibble.
#'
#' @param x A dfm or list of dfms to evaluate.
#' @param y The name of a variable in docvars of the dfms with the class (must
#'   be logical for now).
#' @param set The name of a variable in docvars indicating membership to
#'   training or test set. It is passed on to \code{val()}, which uses it to
#'   subset each dfm (truthy values form the training set).
#' @param alg A function or a named list of functions containing the
#'   algorithms to be evaluated (see example). Unnamed entries are labelled
#'   with their position.
#' @param pred A list of prediction functions (most packages use predict(),
#'   which is the default). Either length of 1 or same length as alg.
#' @param as_matrix A list of logical values indicating if the respective alg
#'   needs the dfm converted to a matrix. Recycled to the length of alg.
#' @param positive value for the positive class.
#' @param write_out Should intermediate results be written to the working
#'   directory.
#'
#' @return A tibble with the columns \code{algorithm} (name of the entry in
#'   \code{alg}), \code{prep} (name of the entry in \code{x}, if \code{x} is
#'   named) and the unnested \code{res} element returned by \code{val()} for
#'   each run. The predictions of all runs are attached as the attribute
#'   \code{"prediction"} (a list per algorithm with one element per dfm).
#'
#' @import purrr
#' @import tibble
#' @importFrom utils head
#' @export
#'
#' @examples
#' \dontrun{
#'
#' library(quanteda)
#' corp <- corpus(c(d1 = "Chinese Beijing Chinese",
#'                  d2 = "Chinese Chinese Shanghai",
#'                  d3 = "Chinese Macao",
#'                  d4 = "Tokyo Japan Chinese",
#'                  d5 = "Chinese Chinese Chinese Tokyo Japan"))
#'
#' docvars(corp, "class") <- c(TRUE, TRUE, TRUE, FALSE, FALSE)
#' docvars(corp, "training") <- c(TRUE, TRUE, TRUE, TRUE, FALSE)
#'
#' test_prep <- batch_prep(corp)
#'
#' batch_validate(x = test_prep[1:3],
#'                y = "class",
#'                set = "training",
#'                alg = list(textmodel_nb = quanteda.textmodels::textmodel_nb,
#'                           textmodel_svm = quanteda.textmodels::textmodel_svm))
#' }
batch_validate <- function(x,
                           y,
                           set = "training",
                           alg = NULL,
                           pred = predict,
                           as_matrix = FALSE,
                           positive = FALSE,
                           write_out = TRUE) {

  if (!is.list(x)) x <- list(x)
  if (!is.list(alg)) alg <- list(alg)
  if (!is.list(pred)) pred <- list(pred)
  if (!is.list(as_matrix)) as_matrix <- list(as_matrix)

  # Fail early instead of deep inside val() with "attempt to apply
  # non-function" (e.g. when alg is left at its NULL default).
  if (length(alg) == 0L || !all(purrr::map_lgl(alg, is.function))) {
    stop("`alg` must be a function or a (named) list of functions.",
         call. = FALSE)
  }

  # Resolve algorithm labels up front so the progress bar, the written files
  # and the result table all use the same, never-missing name.
  alg_names <- names(alg)
  if (is.null(alg_names)) {
    alg_names <- rep(NA_character_, length(alg))
  }
  unnamed <- is.na(alg_names) | !nzchar(alg_names)
  alg_names[unnamed] <- as.character(seq_along(alg))[unnamed]

  # force same length: one prediction function / as_matrix flag per algorithm
  pred <- rep_len(pred, length(alg))
  as_matrix <- rep_len(as_matrix, length(alg))

  out <- purrr::map(seq_along(alg), function(a) {

    # One progress bar per algorithm, ticking once per dfm; only shown in
    # interactive sessions.
    pb <- NULL
    if (interactive()) {
      pb <- progress::progress_bar$new(
        total = length(x),
        format = ":what [:bar] :current/:total (:percent) :eta"
      )
    }

    purrr::map(x, .f = val, y = y, set = set, alg = alg[[a]],
               pred = pred[[a]], as_matrix = as_matrix[[a]],
               pb = pb, what = alg_names[a], positive = positive,
               write_out = write_out)
  })
  names(out) <- alg_names

  # Flatten algorithm x dfm into one row per run, then expand each run's
  # `res` element into columns.
  results <- purrr::map(out, ~ purrr::map(.x, ~ .x[["res"]]))
  results2 <- tibble::tibble(
    algorithm = rep(names(out), purrr::map_int(out, length)),
    prep = unlist(purrr::map(out, names)),
    x = unlist(results, recursive = FALSE)
  )
  results2 <- tidyr::unnest(results2, x)

  attr(results2, "prediction") <- purrr::map(
    out,
    ~ purrr::map(.x, ~ .x[["prediction"]])
  )

  results2
}


#' Evaluate an algorithm with one dfm
#'
#' Splits \code{x} into a training and a test set, fits \code{alg} on the
#' training set, predicts the test set with \code{pred} and compares the
#' predictions to the true classes with \code{confu_mat()}.
#'
#' @param x A dfm to evaluate.
#' @param y The name of a variable in docvars of the dfm with the class (must
#'   be logical for now).
#' @param set The name of a variable in docvars indicating membership to
#'   training or test set. Documents where it is \code{TRUE} are used for
#'   training, all others for testing.
#' @param alg A function containing the algorithm to be evaluated. It is
#'   called as \code{alg(training_data, training_classes)}.
#' @param pred A function used for prediction (most packages use predict(),
#'   which is the default). It is called as \code{pred(model, test_data)} and
#'   its result is coerced with \code{as.logical()}.
#' @param as_matrix Logical. Indicating if alg needs the dfm converted to a
#'   matrix.
#' @param positive value for the positive class (passed on to
#'   \code{confu_mat()}).
#' @param pb,what Used to display status bar when used in batch mode.
#'   \code{what} is also used as the prefix of the file name when
#'   \code{write_out = TRUE}.
#' @param write_out Should intermediate results be written to the working
#'   directory (as an RDS file named after \code{what} and the current time).
#'
#' @return A list with two elements: \code{prediction}, the logical
#'   predictions for the test set named by document, and \code{res}, the
#'   result of \code{confu_mat()} for these predictions.
#' @importFrom  stats predict
#' @export
val <- function(x,
                y,
                set = "training",
                alg = NULL,
                pred = stats::predict,
                as_matrix = FALSE,
                positive = TRUE,
                pb = NULL,
                what = NULL,
                write_out = FALSE) {

  # Fail with a clear message instead of "attempt to apply non-function"
  # after the data has already been split.
  if (!is.function(alg)) {
    stop("`alg` must be a function that fits a model.", call. = FALSE)
  }

  if (!is.null(pb)) {
    pb$tick(tokens = list(what = what))
  }

  in_training <- quanteda::docvars(x, set)
  training_dfm <- quanteda::dfm_subset(x, in_training)
  test_dfm <- quanteda::dfm_subset(x, !in_training)

  # Pull classes and document names before a possible matrix conversion,
  # which would drop the docvars.
  train_v <- quanteda::docvars(training_dfm, y)
  true <- quanteda::docvars(test_dfm, y)
  docn <- quanteda::docnames(test_dfm)
  if (as_matrix) {
    training_dfm <- as.matrix(training_dfm)
    test_dfm <- as.matrix(test_dfm)
  }

  model <- alg(training_dfm, train_v)

  # quanteda warns if test_dfm has features not in the model
  predicted <- suppressWarnings(as.logical(pred(model, test_dfm)))
  names(predicted) <- docn

  out <- list(
    prediction = predicted,
    res = confu_mat(predicted, true, positive = positive)
  )

  if (write_out) {
    # Filesystem-safe time stamp: the default Sys.time() format contains
    # spaces and colons (invalid in Windows file names); millisecond
    # resolution avoids overwriting files from runs within the same second.
    stamp <- format(Sys.time(), "%Y-%m-%d_%H-%M-%OS3")
    saveRDS(out, paste0(what, stamp, ".RDS"))
  }

  out
}
# JBGruber/smlhelper documentation built on Oct. 7, 2022, 3:43 p.m.