R/bsmote_impl.R

Defines functions danger bsmote_impl bsmote

Documented in bsmote

#' borderline-SMOTE Algorithm
#'
#' BSMOTE generates new examples of the minority class using nearest
#'  neighbors of these cases in the border region between classes.
#'
#' @inheritParams step_smote
#' @param df data.frame or tibble. Must have 1 factor variable and remaining
#'  numeric variables.
#' @param var Character, name of variable containing factor variable.
#' @param k An integer. Number of nearest neighbors that are used
#'  to generate the new examples of the minority class.
#' @param all_neighbors Logical. Selects which of the two borderline-SMOTE
#'  methods is used. Defaults to FALSE. See details.
#'
#' @return A data.frame or tibble, depending on type of `df`.
#' @export
#'
#' @details
#' This method works the same way as [smote()], except that instead of
#' generating points around every point of the minority class each point is
#' first being classified into the boxes "danger" and "not". For each point the
#' k nearest neighbors are calculated. If all the neighbors come from a
#' different class it is labeled noise and put in to the "not" box. If more than
#' half of the neighbors come from a different class it is labeled "danger".
#' Points will be generated around points labeled "danger".
#'
#' If `all_neighbors = FALSE` then points will be generated between nearest
#' neighbors in its own class. If `all_neighbors = TRUE` then points will be
#' generated between any nearest neighbors. See examples for visualization.
#'
#' The parameter `k` controls the way the new examples are created.
#' For each currently existing minority class example X new examples will be
#' created (this is controlled by the parameter `over_ratio`). These examples
#' will be generated by using the information from the `k` nearest neighbors
#' of each example of the minority class.
#' The parameter `k` controls how many of these neighbors are used.
#'
#' All columns used in this step must be numeric with no missing data.
#'
#' @references Hui Han, Wen-Yuan Wang, and Bing-Huan Mao. Borderline-smote:
#' a new over-sampling method in imbalanced data sets learning. In
#' International Conference on Intelligent Computing, pages 878–887. Springer,
#' 2005.
#'
#' @seealso [step_bsmote()] for step function of this method
#' @family Direct Implementations
#'
#' @examples
#' circle_numeric <- circle_example[, c("x", "y", "class")]
#'
#' res <- bsmote(circle_numeric, var = "class")
#'
#' res <- bsmote(circle_numeric, var = "class", k = 10)
#'
#' res <- bsmote(circle_numeric, var = "class", over_ratio = 0.8)
#'
#' res <- bsmote(circle_numeric, var = "class", all_neighbors = TRUE)
bsmote <- function(df, var, k = 5, over_ratio = 1, all_neighbors = FALSE) {
  if (length(var) != 1) {
    rlang::abort("Please select a single factor variable for `var`.")
  }

  var <- rlang::arg_match(var, colnames(df))

  # Scalar condition, so use the short-circuiting `||` rather than `|`.
  if (!(is.factor(df[[var]]) || is.character(df[[var]]))) {
    rlang::abort(glue("{var} should be a factor or character variable."))
  }

  if (length(k) != 1) {
    rlang::abort("`k` must be length 1.")
  }

  # Any `k` below 1 is rejected, which includes zero.
  if (k < 1) {
    rlang::abort("`k` must be non-negative.")
  }

  predictors <- setdiff(colnames(df), var)

  # `drop = FALSE` keeps a data frame even when there is a single predictor.
  check_numeric(df[, predictors, drop = FALSE])
  check_na(select(df, -all_of(var)))

  # Forward `all_neighbors`; without it the argument would have no effect.
  bsmote_impl(df, var, k, over_ratio, all_neighbors)
}

# Internal implementation of borderline-SMOTE; performs no input validation.
#
# Arguments
#   df            data frame holding the class column `var`; every other
#                 column is used as a predictor and coerced via `as.matrix()`.
#   var           name of the class column.
#   k             number of neighbours; `k + 1` neighbours are searched per row.
#   over_ratio    multiplier applied to the largest class count to obtain the
#                 target count for the smaller classes.
#   all_neighbors if TRUE, `smote_data()` is run on all rows; otherwise only
#                 on the rows of the class being upsampled.
#
# Returns `df` with the generated rows appended, row names reset, and the
# class column as a factor.
#
# Aborts when the number of rows flagged by `danger()` for a class is not
# larger than `k`.
bsmote_impl <- function(df, var, k = 5, over_ratio = 1, all_neighbors = FALSE) {
  class_counts <- table(df[[var]])
  ratio_target <- max(class_counts) * over_ratio
  which_upsample <- which(class_counts < ratio_target)
  samples_needed <- ratio_target - class_counts[which_upsample]
  min_names <- names(samples_needed)

  # The predictor matrix and the neighbour search do not depend on the class
  # being upsampled, so compute them once. Skipped entirely when no class
  # needs upsampling, so that case still runs no neighbour search.
  if (length(min_names) > 0) {
    data_mat <- as.matrix(df[names(df) != var])
    ids <- RANN::nn2(data_mat, k = k + 1, searchtype = "priority")$nn.idx
  }

  out_dfs <- vector("list", length(min_names))
  for (i in seq_along(min_names)) {
    min_class_in <- df[[var]] == min_names[i]

    # Per row: how many of the k + 1 neighbours belong to the current class,
    # minus one (presumably to discount the row itself -- TODO confirm that
    # the query row is always among its own neighbours).
    danger_ids <- danger(
      x = rowSums(matrix((min_class_in)[ids], ncol = ncol(ids))) - 1,
      k = k
    )

    if (sum(danger_ids) <= k) {
      rlang::abort(glue(
        "Not enough danger observations of '{min_names[i]}' to perform BSMOTE."
      ))
    }

    if (all_neighbors) {
      new_points <- smote_data(
        data_mat, k, samples_needed[i], which(danger_ids)
      )
    } else {
      new_points <- smote_data(
        # `drop = FALSE` keeps a matrix even with a single predictor column
        # or a single matching row.
        data = data_mat[min_class_in, , drop = FALSE],
        k = k,
        n_samples = samples_needed[i],
        smote_ids = which(danger_ids[min_class_in])
      )
    }

    tmp_df <- as.data.frame(new_points)
    colnames(tmp_df) <- colnames(data_mat)
    tmp_df[[var]] <- min_names[i]
    out_dfs[[i]] <- tmp_df
  }

  final <- rbind(df, do.call(rbind, out_dfs))

  # A character class column has no levels; passing `levels = NULL` to
  # `factor()` would turn every value into NA, so derive the levels from the
  # observed values in that case.
  class_levels <- levels(df[[var]])
  if (is.null(class_levels)) {
    class_levels <- levels(factor(df[[var]]))
  }
  final[[var]] <- factor(final[[var]], levels = class_levels)
  rownames(final) <- NULL
  final
}

# Flags which elements of `x` fall in the "danger" range for a given `k`.
#
# An element is flagged when it is at least `k / 2` and not equal to `k`.
# In `bsmote_impl()` `x` is a per-row neighbour count and `k` the number of
# neighbours. Returns a logical vector, computed elementwise.
danger <- function(x, k) {
  not_all <- x != k
  at_least_half <- k / 2 <= x
  not_all & at_least_half
}

Try the themis package in your browser

Any scripts or data that you put into this service are public.

themis documentation built on Aug. 15, 2023, 1:05 a.m.