R/rotation_forest.R

Defines functions get_rdm_subset bootstrap rot_mtrx rotation_forest

# ======================================================================
# Description: R implementation for Rotation Forest
# Author: Haedong Kim
# 2018-08-19: created
# 2018-08-26: prototype
# ======================================================================


# 1. Helper functions --------------------------------------------------
#' Split the columns of x into k random feature subsets
#'
#' The columns of `x` are renamed `X1`, `X2`, ... after their original
#' positions, shuffled, and cut into `k` consecutive groups whose sizes differ
#' by at most one (the first `ncol(x) %% k` groups get one extra column).
#' Each group is returned as a data frame together with the response `y`.
#'
#' @inheritParams rotation_forest
#' @param k number of subsets to split the columns of `x` into
#'   (`1 <= k <= ncol(x)`)
#' @return list of k data frames; each holds one feature subset (columns
#'   named `X<j>`, where j is the column's original position in `x`) plus a
#'   column `y`
get_rdm_subset <- function(x, y, k) {
  n_col <- ncol(x)
  if (k < 1 || k > n_col) {
    stop("k has to be between 1 and the number of columns of x")
  }

  # Group sizes: every group gets ncol %/% k features, the first
  # ncol %% k groups (always fewer than k) get one more
  num_fts <- rep(n_col %/% k, k)
  extra <- seq_len(n_col %% k)
  num_fts[extra] <- num_fts[extra] + 1

  # Assign variable names for identifying the original sequence of variables
  colnames(x) <- paste0("X", seq_len(n_col))

  # Randomly shuffle variables; drop = FALSE keeps a one-column x 2-D
  x <- x[, sample.int(n_col), drop = FALSE]

  # Cut the shuffled columns into consecutive runs of num_fts[i] columns.
  # drop = FALSE keeps single-feature subsets as a named column (X<j>)
  # instead of collapsing them to an anonymous vector.
  ends <- cumsum(num_fts)
  starts <- ends - num_fts + 1
  lapply(seq_len(k), function(i) {
    data.frame(x[, starts[i]:ends[i], drop = FALSE], y = y)
  })
}

#' Bootstrapping
#'
#' Resamples the rows of every subset produced by get_rdm_subset, with
#' replacement. Each subset is resampled independently to
#' `round(boot_rate * nrow(subset))` rows.
#'
#' @param subsets list of randomly generated subsets
#' @param boot_rate bootstrapping rate (fraction of rows drawn per subset)
#' @return list of bootstrapped subsets named "bootstraped1", "bootstraped2", ...
bootstrap <- function(subsets, boot_rate) {

  # Draw row indices with replacement and return the resampled rows
  draw_rows <- function(subset) {
    n_rows <- nrow(subset)
    n_draw <- round(boot_rate * n_rows)
    subset[sample.int(n_rows, n_draw, replace = TRUE), ]
  }

  resampled <- list()
  for (idx in seq_along(subsets)) {
    resampled[[paste0("bootstraped", idx)]] <- draw_rows(subsets[[idx]])
  }
  resampled
}

#' Generate a rotation matrix
#'
#' Runs a centred PCA on the feature columns of every bootstrapped subset.
#' The resulting loading matrices are placed on the diagonal of one
#' block-diagonal rotation matrix. Its rows are then put back into the
#' original column order of `x` (`X1`, `X2`, ..., ordered numerically so that
#' `X10` follows `X9`). Training data is therefore rotated with
#' `as.matrix(x) %*% rot_mtrx(boot_subsets)`.
#'
#' @param boot_subsets bootstrapped subset list; each element holds feature
#'   columns named `X<j>` (j = original column position) plus a column `y`
#' @return numeric matrix with one row per original variable (row names `X1`,
#'   ..., `Xp` in that order) and one column per principal component (column
#'   names `PC1`, `PC2`, ...); entries outside the diagonal blocks are zero
rot_mtrx <- function(boot_subsets) {

  # PCA loadings of each subset: rows = variables (X<j>), columns = components.
  # drop = FALSE keeps single-feature subsets as a named one-column frame.
  loadings <- lapply(boot_subsets, function(subset) {
    features <- subset[, colnames(subset) != "y", drop = FALSE]
    prcomp(features, center = TRUE)$rotation
  })

  n_vars <- vapply(loadings, nrow, integer(1))
  n_comps <- vapply(loadings, ncol, integer(1))

  # Block-diagonal assembly: block i sits right after blocks 1..(i-1)
  rotation <- matrix(0, nrow = sum(n_vars), ncol = sum(n_comps))
  row_off <- c(0L, cumsum(n_vars))
  col_off <- c(0L, cumsum(n_comps))
  for (i in seq_along(loadings)) {
    rows <- row_off[i] + seq_len(n_vars[i])
    cols <- col_off[i] + seq_len(n_comps[i])
    rotation[rows, cols] <- loadings[[i]]
  }
  rownames(rotation) <- unlist(lapply(loadings, rownames), use.names = FALSE)
  colnames(rotation) <- paste0("PC", seq_len(ncol(rotation)))

  # Rearrange rows to the original variable order so that row j multiplies
  # column j of x. Indexing by the X<j> names orders them numerically.
  rotation[paste0("X", seq_len(nrow(rotation))), , drop = FALSE]
}


# 2. Main part of Rotation Forest --------------------------------------
#' Classification algorithm: Rotation Forest
#'
#' Fits an ensemble of `rpart` classification trees. For each tree the
#' columns of `x` are randomly split into `k` subsets (get_rdm_subset). Each
#' subset is bootstrapped (bootstrap) and turned into a PCA rotation matrix
#' (rot_mtrx), and the tree is fit on `as.matrix(x) %*% rotation`.
#'
#' @param x dataframe (or matrix) of numeric independent variables
#' @param y factor of class labels, one per row of `x`
#' @param k number of feature subsets the columns of `x` are split into; a
#'   whole number with `1 <= k <= ncol(x)` (both `3` and `3L` are accepted)
#' @param ntree number of trees to grow (a positive whole number)
#' @param boot_rate bootstrapping rate: each subset is resampled with
#'   replacement to `round(boot_rate * nrow(x))` rows
#' @return object of class "rotation_forest": a list with `trees` (fitted
#'   rpart trees), `rot_mtrcs` (the rotation matrix used for each tree) and
#'   `rotated_data` (the rotated training data each tree was fit on)
rotation_forest <- function(x, y, k, ntree = 15, boot_rate = 0.75) {

  # TRUE for a single finite whole number, integer or double (e.g. 3 or 3L)
  is_whole_number <- function(v) {
    is.numeric(v) && length(v) == 1L && is.finite(v) && v == round(v)
  }

  # Check x; numeric
  if (!all(vapply(x, is.numeric, logical(1)))) {
    stop("x has non-numeric column")
  }

  # Check y; factor with one label per row of x
  if (!is.factor(y)) {
    stop("y is not a factor")
  }
  if (length(y) != NROW(x)) {
    stop("y must have one element per row of x")
  }

  # Check k; a natural number 1 <= k <= ncol(x)
  if (!is_whole_number(k)) {
    stop("k is not an integer")
  }
  k <- as.integer(k)
  if (k < 1L || k > ncol(x)) {
    stop("k has to be greater than or equal to 1 ",
         "and less than or equal to the number of columns of x")
  }

  # Check ntree and boot_rate
  if (!is_whole_number(ntree) || ntree < 1) {
    stop("ntree has to be a positive integer")
  }
  if (!is.numeric(boot_rate) || length(boot_rate) != 1L ||
      !is.finite(boot_rate) || boot_rate <= 0) {
    stop("boot_rate has to be a positive number")
  }

  x_mat <- as.matrix(x)  # loop-invariant
  trees <- vector("list", ntree)
  rot_mtrcs <- vector("list", ntree)
  rotated_data <- vector("list", ntree)
  for (i in seq_len(ntree)) {

    # Generate a rotation matrix
    subsets <- get_rdm_subset(x, y, k)
    boot_subsets <- bootstrap(subsets, boot_rate)
    rot_mtrcs[[i]] <- rot_mtrx(boot_subsets)

    # Fit a tree on the rotated data
    rotated_data[[i]] <- data.frame(x_mat %*% rot_mtrcs[[i]], y = y)
    trees[[i]] <- rpart::rpart(y ~ ., method = "class", data = rotated_data[[i]])
  }

  rlt <- list(trees = trees, rot_mtrcs = rot_mtrcs, rotated_data = rotated_data)
  class(rlt) <- "rotation_forest"

  rlt
}
haedong31/rotation_forest documentation built on Nov. 4, 2019, 1:26 p.m.