# R/RSLSMOTE.R

# Defines functions RSLSMOTE

# Documented in RSLSMOTE

#' @title  Relocating safe-level SMOTE with minority outcast handling
#'
#' @description The Relocating Safe-Level SMOTE (RSLS) algorithm improves the
#' quality of synthetic samples generated by Safe-Level SMOTE (SLS) by
#' relocating specific synthetic data points that are too close to the majority
#' class distribution towards the original minority class distribution in the
#' feature space.
#'
#' @param x feature matrix or data.frame.
#' @param y a factor class variable with two classes.
#' @param k1 number of neighbors to link. Default is 5.
#' @param k2 number of neighbors to determine safe levels. Default is 5.
#'
#' @details
#' In Safe-level SMOTE (SLS), a safe-level threshold is used to control the number of synthetic
#' samples generated from each minority instance. This threshold is calculated
#' based on the number of minority and majority instances in the local
#' neighborhood of each minority instance. SLS generates synthetic samples that
#' are located closer to the original minority class distribution in the feature
#' space.
#'
#' In Relocating safe-level SMOTE (RSLS), after generating synthetic samples
#' using the SLS algorithm, the algorithm relocates specific synthetic data
#' points that are deemed to be too close to the majority class distribution in
#' the feature space. The relocation process moves these synthetic data points
#' towards the original minority class distribution in the feature space.
#'
#' This relocation process is performed by first identifying the synthetic data
#' points that are too close to the majority class distribution. Then, for each
#' identified synthetic data point, the algorithm calculates a relocation vector
#' based on the distance between the synthetic data point and its k nearest
#' minority class instances. This relocation vector is used to move the
#' synthetic data point towards the minority class distribution in the feature
#' space.
#'
#' Note: Much faster than \code{smotefamily::RSLS()}.
#'
#' @return a list with resampled dataset.
#'  \item{x_new}{Resampled feature matrix.}
#'  \item{y_new}{Resampled target variable.}
#'  \item{x_syn}{Generated synthetic data.}
#'  \item{C}{Number of synthetic samples generated for each positive class sample.}
#'
#' @author Fatih Saglam, saglamf89@gmail.com
#'
#' @importFrom  Rfast Dist
#' @importFrom  FNN knnx.index
#' @importFrom  stats rnorm
#' @importFrom  stats sd
#'
#' @references
#' Siriseriwan, W., & Sinapiromsaran, K. (2016). The effective redistribution
#' for imbalance dataset: Relocating safe-level SMOTE with minority outcast
#' handling. Chiang Mai J. Sci, 43(1), 234-246.
#'
#' @examples
#'
#' set.seed(1)
#' x <- rbind(matrix(rnorm(2000, 3, 1), ncol = 2, nrow = 1000),
#'            matrix(rnorm(100, 5, 1), ncol = 2, nrow = 50))
#' y <- as.factor(c(rep("negative", 1000), rep("positive", 50)))
#'
#' plot(x, col = y)
#'
#' # resampling
#' m <- RSLSMOTE(x = x, y = y, k1 = 5, k2 = 5)
#'
#' plot(m$x_new, col = m$y_new)
#'
#' @rdname RSLSMOTE
#' @export

RSLSMOTE <- function(x, y, k1 = 5, k2 = 5) {
  # Oversamples the minority class of a two-class problem: synthetic points are
  # interpolated between a "safe" minority sample and one of its k1 minority
  # neighbours, then pulled back along that segment while any nearby majority
  # sample is closer to the synthetic point than its two parents are.
  #
  # x      : feature matrix or data.frame (coerced with as.matrix()).
  # y      : factor with exactly two observed classes, one entry per row of x.
  # k1     : number of minority neighbours used as interpolation partners.
  # k2     : number of neighbours (any class) used to compute safe levels.
  # return : list(x_new, y_new, x_syn, C); rows of x_new are ordered
  #          synthetic, minority, majority.

  # ---- Argument validation ----
  if (!is.data.frame(x) && !is.matrix(x)) {
    stop("x must be a matrix or dataframe")
  }

  if (!is.factor(y)) {
    stop("y must be a factor")
  }

  if (!is.numeric(k1)) {
    stop("k1 must be numeric")
  }

  if (k1 < 1) {
    stop("k1 must be positive")
  }

  if (!is.numeric(k2)) {
    stop("k2 must be numeric")
  }

  if (k2 < 1) {
    stop("k2 must be positive")
  }

  x <- as.matrix(x)
  var_names <- colnames(x)
  p <- ncol(x)

  if (nrow(x) != length(y)) {
    stop("x and y must have the same number of observations")
  }

  # Count only the classes that actually occur: an unused factor level has a
  # count of zero and would otherwise be picked as the minority class.
  class_counts <- table(droplevels(y))
  if (length(class_counts) != 2) {
    stop("y must have exactly two classes")
  }

  class_pos <- names(which.min(class_counts))
  class_neg <- setdiff(names(class_counts), class_pos)

  x_pos <- x[y == class_pos, , drop = FALSE]
  x_neg <- x[y == class_neg, , drop = FALSE]

  n_pos <- nrow(x_pos)
  n_neg <- nrow(x_neg)

  # Each query point is also present in the searched data, so k + 1 neighbours
  # are requested and the first column is discarded below.
  if (k1 >= n_pos) {
    stop("k1 must be smaller than the number of minority class samples")
  }

  if (k2 >= n_pos + n_neg) {
    stop("k2 must be smaller than the number of samples")
  }

  # Minority rows first: an index <= n_pos refers to a minority sample.
  x <- rbind(x_pos, x_neg)

  # ---- Neighbourhoods and safe levels ----
  # drop = FALSE keeps the matrix shape when k1 or k2 equals 1.
  nn_pos2all <- FNN::knnx.index(
    data = x, query = x_pos, k = k2 + 1
  )[, -1, drop = FALSE]
  nn_pos2pos <- FNN::knnx.index(
    data = x_pos, query = x_pos, k = k1 + 1
  )[, -1, drop = FALSE]

  # Safe level = number of minority samples among the k2 nearest neighbours.
  safe_levels <- rowSums(nn_pos2all <= n_pos)
  i_safe <- which(safe_levels > 0)
  n_safe <- length(i_safe)
  n_syn <- n_neg - n_pos

  if (n_syn > 0 && n_safe == 0) {
    stop("no minority sample has a minority neighbour; cannot generate synthetic samples")
  }

  # ---- Number of synthetic samples per minority sample ----
  # Every safe sample gets the same base count; the remainder is handed out
  # one by one to randomly chosen safe samples so that sum(C) == n_syn.
  C <- rep(0, n_pos)
  if (n_safe > 0) {
    C[i_safe] <- ceiling(n_syn / n_safe) - 1
    n_diff <- n_syn - sum(C)
    i_extra <- sample.int(n_safe, size = abs(n_diff))
    C[i_safe[i_extra]] <- C[i_safe[i_extra]] + sign(n_diff)
  }

  # ---- Generation and relocation ----
  # Blocks are collected in a list and bound once instead of growing a matrix.
  syn_blocks <- vector("list", n_pos)

  for (i in seq_len(n_pos)) {
    if (safe_levels[i] > 0 && C[i] > 0) {
      i_k <- sample.int(k1, size = C[i], replace = TRUE)
      i_nn_pos2pos <- nn_pos2pos[i, i_k]

      # Majority samples in the neighbourhoods of the current sample and of
      # its selected partners (row i, not a fixed row).
      i_nn <- nn_pos2all[c(i, i_nn_pos2pos), , drop = FALSE]
      i_nn_neg <- unique(i_nn[i_nn > n_pos])
      k_safe_levels <- safe_levels[i_nn_pos2pos]

      # Interpolation gap, biased towards whichever parent is safer.
      gap <- rep(0, C[i])
      for (j in seq_len(C[i])) {
        if (k_safe_levels[j] == 0) {
          gap[j] <- 0
        } else if (k_safe_levels[j] == safe_levels[i]) {
          gap[j] <- stats::runif(1, 0, 1)
        } else if (k_safe_levels[j] < safe_levels[i]) {
          gap[j] <- stats::runif(1, 0, k_safe_levels[j] / safe_levels[i])
        } else {
          gap[j] <- stats::runif(1, 1 - safe_levels[i] / k_safe_levels[j], 1)
        }
      }

      x_pos_step <- x_pos[rep(i, C[i]), , drop = FALSE]
      x_pos_k <- x_pos[i_nn_pos2pos, , drop = FALSE]
      # gap has one entry per row, so recycling scales each row by its own gap.
      x_syn_step <- x_pos_step + (x_pos_k - x_pos_step) * gap

      if (length(i_nn_neg) > 0) {
        x_nn_neg <- x[i_nn_neg, , drop = FALSE]
        x_parent <- x_pos[i, ]

        for (j in seq_len(C[i])) {
          x_partner <- x_pos[i_nn_pos2pos[j], ]

          repeat {
            # Smallest pairwise distance among synthetic point and parents.
            d_tri <- Rfast::Dist(rbind(x_syn_step[j, ], x_parent, x_partner))
            min_dist <- min(d_tri[row(d_tri) != col(d_tri)])

            # Distances from the synthetic point to nearby majority samples.
            d_neg <- Rfast::Dist(rbind(x_syn_step[j, ], x_nn_neg))[1, -1]

            if (!any(d_neg < min_dist)) {
              break
            }

            # Re-draw the point on the part of the segment that leads to the
            # safer parent.
            if (safe_levels[i] >= k_safe_levels[j]) {
              x_start <- x_parent
              x_end <- x_syn_step[j, ]
            } else {
              x_start <- x_syn_step[j, ]
              x_end <- x_partner
            }
            u <- stats::runif(1)
            x_syn_step[j, ] <- x_start + u * (x_end - x_start)
          }
        }
      }

      syn_blocks[[i]] <- x_syn_step
    }
  }

  # The leading zero-row matrix guarantees a p-column result even when no
  # synthetic sample was generated; NULL list entries are ignored by rbind().
  x_syn <- do.call(rbind, c(list(matrix(nrow = 0, ncol = p)), syn_blocks))

  # ---- Assemble output ----
  x_new <- rbind(
    x_syn,
    x_pos,
    x_neg
  )
  y_new <- c(
    rep(class_pos, n_syn + n_pos),
    rep(class_neg, n_neg)
  )
  y_new <- factor(y_new, levels = levels(y), labels = levels(y))
  colnames(x_new) <- var_names

  list(
    x_new = x_new,
    y_new = y_new,
    # seq_len() yields zero rows when n_syn == 0, whereas 1:0 would select row 1.
    x_syn = x_new[seq_len(n_syn), , drop = FALSE],
    C = C
  )
}

# Try the SMOTEWB package in your browser

# Any scripts or data that you put into this service are public.

# SMOTEWB documentation built on May 29, 2024, 11:15 a.m.