
#' Initialise validatr
#'
#' Initialises a validatr object.
#'
#' The type of data being tested influences how the validatr and its methods
#' respond. The following types are supported:
#'
#' * __regression__ regression data (default).
#' * __ts__ time-series data (`ts` argument specified).
#' * __classification__ classification data (`y` variable character, factor or
#' logical).
#'
#' The type of cross-validation and accuracy measures to be calculated are
#' influenced by this parameter. For regression, k-fold cross-validation is
#' carried out and requires the number of folds `k` to be specified. Leave one
#' out cross-validation can easily be accomplished by setting `k` equal to the
#' number of observations.
#'
#' For time-series, time-series cross-validation takes place. This requires the
#' `start`, `horizon`, `shift` and `ts` arguments to be specified:
#'
#' * `start` is the start of the first fold.
#' * `horizon` is the length of the fold.
#' * `shift` is the length of time to move forward.
#' * `ts` is the name of the variable containing time-series data.
#'
#' If `start` is numeric, then `horizon` and `shift` are also numeric. If
#' `start` is date or POSIX, then `horizon` and `shift` follow the same
#' convention as for `seq.Date` and `seq.POSIXt`. Hence, they are a character
#' string, containing one of "sec", "min", "hour", "day", "DSTday", "week",
#' "month", "quarter" or "year".
#'
#' Finally, classification carries out k-fold cross validation as well, but its
#' accuracy measures will be different to regression.
#'
#' @param data data frame containing variables for modelling.
#' @param y dependent variable name. Non-standard evaluation.
#' @param k integer. Number of folds.
#' @param start numeric, date or POSIX object specifying the start date for
#'   time-series validation folds.
#' @param horizon forecast horizon to evaluate.
#' @param shift length of time to move forward for each new fold.
#' @param ts time-series variable name. Non-standard evaluation.
#'
#' @return
#' `validatr` returns an initial validatr object. This object contains cross
#' validation folds and validation parameters.
#'
#' @export
#'
#' @examples
#' validatr_obj <- validatr(iris, y = Sepal.Length, k = 5)
#' head(validatr_obj$folds[[5]]$train)
#' head(validatr_obj$folds[[5]]$validation)
validatr <- function(data, y, k = 10, ts = NULL, start = NULL,
                     horizon = NULL, shift = NULL) {
  # S3 generic: dispatch on the class of `data`. The arguments are passed on
  # unevaluated, so the methods can still capture `y` and `ts` by name.
  UseMethod("validatr")
}

# validatr method for plain (ungrouped) data frames.
#
# Determines the data type ("ts", "classification" or "regression"), stores
# the call parameters and builds the cross-validation folds. Each fold is a
# list with integer row indices `train` and `validation`.
#
# * regression / classification: rows are randomly assigned to `k` folds.
# * ts: folds start at `start`, advance by `shift` and each validate on a
#   window of length `horizon`; folds whose validation window is shorter than
#   the longest one found are dropped.
#
# Returns an object of class "validatr" with elements `params` and `folds`.
#' @export
validatr.data.frame <- function(data, y, k = 10, ts = NULL, start = NULL,
                                horizon = NULL, shift = NULL) {

  # Capture the unquoted column names supplied by the caller as strings.
  # An unspecified `ts` deparses to the string "NULL".
  y <- deparse(substitute(y))
  ts <- deparse(substitute(ts))

  if (ts != "NULL") {
    data_type <- "ts"
  } else if (!y %in% names(data)) {
    stop("y variable not found in data: ", y, call. = FALSE)
  } else if (is.character(data[[y]]) || is.factor(data[[y]]) ||
             is.logical(data[[y]])) {
    data_type <- "classification"
  } else {
    data_type <- "regression"
  }

  # Snapshot the arguments and data_type now, before any working variables
  # are created, so `params` contains only the validation parameters.
  validatr <- list(params = as.list(environment()),
                   folds = list())

  if (data_type %in% c("regression", "classification")) {
    n_obs <- nrow(data)
    fold_id <- cut(sample(n_obs), breaks = k, labels = FALSE)
    for (i_fold in seq_len(k)) {
      idx <- which(fold_id == i_fold)
      validatr$folds[[as.character(i_fold)]] <- list(
        # setdiff() keeps every row when a fold happens to be empty, whereas
        # negative indexing with an empty index would drop all rows.
        train = setdiff(seq_len(n_obs), idx),
        validation = idx
      )
    }
  } else if (data_type == "ts") {
    if (is.null(start) || is.null(horizon) || is.null(shift)) {
      stop("a time-series cross-validation parameter has not been entered.")
    }
    if (!identical(class(start), class(data[[ts]]))) {
      print(class(data[[ts]]))
      stop("start is not same class as ts variable.")
    }

    end <- max(data[[ts]])
    if (end <= start) stop("Start of fold is later then final ts value.")
    fold_names <- seq(start, end, shift)
    validatr$folds <- lapply(fold_names, function(x) {
      # End of the validation window: one `horizon` step after the fold start.
      val_end <- seq(x, length.out = 2, by = horizon)[2]
      list(train = which(data[[ts]] < x),
           validation = which(data[[ts]] >= x & data[[ts]] < val_end))
    })

    # Keep only folds whose validation window is as long as the longest one.
    n_validation <- vapply(validatr$folds,
                           function(fold) length(fold$validation),
                           integer(1))
    idx_complete <- n_validation == max(n_validation)
    validatr$folds <- validatr$folds[idx_complete]
    names(validatr$folds) <- fold_names[idx_complete]
  }

  class(validatr) <- "validatr"
  validatr
}

# validatr method for grouped data frames.
#
# Splits `data` into one data frame per group (with the grouping columns
# removed) and builds a separate set of cross-validation folds for each group.
# `folds` is therefore a list keyed by group label, each element holding that
# group's folds in the same `train` / `validation` index format used by the
# ungrouped method. Indices are relative to the rows of the group's own data
# frame.
#
# Returns an object of class c("grouped_validatr", "validatr").
#
# NOTE(review): group information is read from the "vars", "labels" and
# "indices" attributes of `data`; confirm these attributes exist for the dplyr
# version in use.
#' @export
validatr.grouped_df <- function(data, y, k = 10, ts = NULL, start = NULL,
                                horizon = NULL, shift = NULL) {

  # Capture the unquoted column names supplied by the caller as strings.
  # An unspecified `ts` deparses to the string "NULL".
  y <- deparse(substitute(y))
  ts <- deparse(substitute(ts))
  group_vars <- attr(data, "vars")
  group_labels <- apply(attr(data, "labels"), 1, paste0, collapse = "_")
  group_indices <- lapply(attr(data, "indices"), function(x) x + 1) # 0 base index fix
  names(group_indices) <- group_labels
  data <- lapply(group_indices, function(x) {
    # remove group columns so user can do lm(y ~ ., train)
    keep <- !(names(data) %in% group_vars)
    dplyr::ungroup(data[x, ])[, keep, drop = FALSE]
  })

  # `data` is now a list of per-group data frames, so the response column is
  # inspected on the first group rather than on `data` itself.
  if (ts != "NULL") {
    data_type <- "ts"
  } else if (!y %in% names(data[[1]])) {
    stop("y variable not found in data: ", y, call. = FALSE)
  } else if (is.character(data[[1]][[y]]) || is.factor(data[[1]][[y]]) ||
             is.logical(data[[1]][[y]])) {
    data_type <- "classification"
  } else {
    data_type <- "regression"
  }

  # Snapshot the arguments, group metadata and data_type now, before any
  # working variables are created.
  validatr <- list(params = as.list(environment()),
                   folds = list())

  if (data_type %in% c("regression", "classification")) {
    for (i_group in group_labels) {
      validatr$folds[[i_group]] <- list()
      n_obs <- nrow(data[[i_group]])
      fold_id <- cut(sample(n_obs), breaks = k, labels = FALSE)
      for (i_fold in seq_len(k)) {
        idx <- which(fold_id == i_fold)
        validatr$folds[[i_group]][[as.character(i_fold)]] <- list(
          # setdiff() keeps every row when a fold happens to be empty, whereas
          # negative indexing with an empty index would drop all rows.
          train = setdiff(seq_len(n_obs), idx),
          validation = idx
        )
      }
    }
  } else if (data_type == "ts") {
    if (is.null(start) || is.null(horizon) || is.null(shift)) {
      stop("a time-series cross-validation parameter has not been entered.")
    }
    if (!identical(class(start), class(data[[1]][[ts]]))) {
      print(class(data[[1]][[ts]]))
      stop("start is not same class as ts variable.")
    }

    for (i_group in group_labels) {
      group_ts <- data[[i_group]][[ts]]
      end <- max(group_ts)
      if (end <= start) stop("Start of fold is later then final ts value.")
      fold_names <- seq(start, end, shift)
      group_folds <- lapply(fold_names, function(x) {
        # End of the validation window: one `horizon` step after fold start.
        val_end <- seq(x, length.out = 2, by = horizon)[2]
        list(train = which(group_ts < x),
             validation = which(group_ts >= x & group_ts < val_end))
      })

      # Keep only folds whose validation window is as long as the longest one
      # found within this group.
      n_validation <- vapply(group_folds,
                             function(fold) length(fold$validation),
                             integer(1))
      idx_complete <- n_validation == max(n_validation)
      group_folds <- group_folds[idx_complete]
      names(group_folds) <- fold_names[idx_complete]
      validatr$folds[[i_group]] <- group_folds
    }
  }

  class(validatr) <- c("grouped_validatr", "validatr")
  validatr
}

# Print method for validatr objects.
#
# Writes a short summary to the console: number of folds, data type and
# response variable, followed by whether models have been fitted, whether
# predictions have been calculated and, if present, the average accuracy
# stored in `x$accuracy$average_accuracy`.
#
# `x` is a validatr object; `...` is accepted for compatibility with the
# print generic and is not used. Returns `x` invisibly.
#' @export
print.validatr <- function(x, ...) {
  cat("You are working with a validatr object. Good job!\n\n",
      "Number of folds: ", length(x$folds), "\n",
      "Data type: ", x$params$data_type, "\n",
      "Response variable: ", x$params$y, "\n",
      sep = "")

  if (is.null(x$models)) {
    cat("Models not fitted.\n")
  } else {
    cat("Number of models fitted:", length(x$models[[1]]), "\n")
  }

  if (is.null(x$params$models_predicted)) {
    cat("Predictions not calculated.\n")
  } else {
    cat("Predictions have been calculated.\n")
  }

  if (is.null(x$accuracy$average_accuracy)) {
    cat("Accuracy measures not calculated.\n")
  } else {
    cat("\nAverage accuracy:\n\n")
    # Show the values announced by the header above.
    print(x$accuracy$average_accuracy)
  }

  invisible(x)
}


# Print method for grouped validatr objects.
#
# Writes a short summary to the console: number of groups, data type and
# response variable, followed by whether models have been fitted, whether
# predictions have been calculated and, if present, the average accuracy
# stored in `x$accuracy$average_accuracy`.
#
# `x` is a grouped validatr object; `...` is accepted for compatibility with
# the print generic and is not used. Returns `x` invisibly.
#' @export
print.grouped_validatr <- function(x, ...) {
  cat("You are working with a grouped validatr object. Good job!\n\n",
      "Number of groups: ", length(x$folds), "\n",
      "Data type: ", x$params$data_type, "\n",
      "Response variable: ", x$params$y, "\n",
      sep = "")

  if (is.null(x$models)) {
    cat("Models not fitted.\n")
  } else {
    cat("Number of models fitted:", length(x$models[[1]]), "\n")
  }

  if (is.null(x$params$models_predicted)) {
    cat("Predictions not calculated.\n")
  } else {
    cat("Predictions have been calculated.\n")
  }

  if (is.null(x$accuracy$average_accuracy)) {
    cat("Accuracy measures not calculated.\n")
  } else {
    cat("\nAverage accuracy:\n\n")
    # Show the values announced by the header above.
    print(x$accuracy$average_accuracy)
  }

  invisible(x)
}


#' Validatr object
#'
#' `is.validatr` tests if its argument is a validatr object.
#'
#' @param x an R object.
#' @return `TRUE` if `x` inherits from class "validatr", otherwise `FALSE`.
#' @export
is.validatr <- function(x) {
  inherits(x, "validatr")
}
# camroach87/validatr documentation built on May 14, 2019, 2:41 p.m.