R/class_balance.R

Defines functions class_balance.nestcv.train class_balance.default class_balance

Documented in class_balance class_balance.default class_balance.nestcv.train

#' Check class balance in training folds
#'
#' Tabulates the response classes pooled across the training folds of
#' `object` and prints the counts, followed by the class counts of the final
#' fit when `object$yfinal` is present.
#'
#' @param object Object of class `nestedcv.glmnet`, `nestcv.train` or `outercv`
#' @return Invisibly a table of the response classes in the training folds
#' @export
#'
class_balance <- function(object) UseMethod("class_balance")


#' @rdname class_balance
#' @export
#'
class_balance.default <- function(object) {
  # Pool the training outcomes stored for every outer fold. unlist() merges
  # a list of factors into one factor whose levels are the union of the
  # per-fold levels, so no class is dropped from the table.
  ytrain <- unlist(lapply(object$outer_result, "[[", "ytrain"))
  if (is.numeric(ytrain)) stop("Not classification", call. = FALSE)
  # Without this check a missing `ytrain` (NULL after unlist) would pass the
  # numeric test and be reported as an empty, misleading table.
  if (length(ytrain) == 0L) {
    stop("No training fold outcomes found", call. = FALSE)
  }
  tab <- table(ytrain)
  cat("Training folds:\n")
  # c() drops the table class so counts print as a named vector
  print(c(tab))
  # The final-fit outcomes are reported only when the object stores them
  yfinal <- object$yfinal
  if (!is.null(yfinal)) {
    cat("Final fit:\n")
    print(c(table(yfinal)))
  }
  invisible(tab)
}


#' @rdname class_balance
#' @export
#'
class_balance.nestcv.train <- function(object) {
  # Outcomes are read from the observed values (`obs`) of the resampling
  # predictions saved in each outer fold's caret fit (`fit$pred`).
  ytrain <- unlist(lapply(object$outer_result, function(i) i$fit$pred$obs))
  if (is.numeric(ytrain)) stop("Not classification", call. = FALSE)
  # If caret kept no hold-out predictions, `fit$pred` is expected to be NULL
  # and nothing can be tabulated; fail loudly instead of printing an empty
  # table.
  if (length(ytrain) == 0L) {
    stop("No training fold outcomes found in `fit$pred`; were hold-out ",
         "predictions saved (caret::trainControl(savePredictions))?",
         call. = FALSE)
  }
  tab <- table(ytrain)
  cat("Training folds:\n")
  # c() drops the table class so counts print as a named vector
  print(c(tab))
  # The final-fit outcomes are reported only when the object stores them
  yfinal <- object$yfinal
  if (!is.null(yfinal)) {
    cat("Final fit:\n")
    print(c(table(yfinal)))
  }
  invisible(tab)
}

Try the nestedcv package in your browser

Any scripts or data that you put into this service are public.

nestedcv documentation built on Oct. 26, 2023, 5:08 p.m.