# ======================================================================
# Description: R implementation for Rotation Forest
# Author: Haedong Kim
# 2018-08-19: created
# 2018-08-26: prototype
# ======================================================================
# 1. Helper functions --------------------------------------------------
#' Split the data into k subsets
#'
#' Renames the columns of `x` to `X1`, `X2`, ... (so the original position of
#' each variable can be recovered later), shuffles them, and partitions them
#' into `k` disjoint subsets whose sizes differ by at most one. Each subset is
#' returned as a data frame with the dependent variable appended as column `y`.
#'
#' @inheritParams rotation_forest
#' @return list of k data frames; the first `ncol(x) %% k` of them hold one
#'   more variable than the rest
get_rdm_subset <- function(x, y, k) {
  p <- ncol(x)
  # Number of features per subset: floor(p / k) each, plus one extra for the
  # first (p %% k) subsets so that all p features are used exactly once.
  num_fts <- p %/% k + (seq_len(k) <= p %% k)
  # Assign variable names for identifying the original sequence of variables
  colnames(x) <- paste0("X", seq_len(p))
  # Randomly shuffle variables. drop = FALSE keeps a one-column x as a
  # data frame/matrix so its column name is not lost.
  x <- x[, sample.int(p), drop = FALSE]
  # Column offset at which each subset starts within the shuffled x
  offsets <- cumsum(num_fts) - num_fts
  lapply(seq_len(k), function(i) {
    cols <- offsets[i] + seq_len(num_fts[i])
    # drop = FALSE: a single-feature subset must keep its "X<n>" name
    data.frame(x[, cols, drop = FALSE], y = y)
  })
}
#' Bootstrapping
#'
#' Resamples the rows of every subset produced by `get_rdm_subset()`
#' independently, with replacement, keeping `round(boot_rate * nrow)` rows of
#' each one.
#' @param subsets list of randomly generated subsets (data frames)
#' @param boot_rate bootstrapping rate (fraction of rows to draw per subset)
#' @return list of resampled data frames, named "bootstraped1",
#'   "bootstraped2", ... in the order of `subsets`
bootstrap <- function(subsets, boot_rate) {
  # Draw one bootstrap sample from a single subset
  resample_one <- function(subset) {
    n_rows <- nrow(subset)
    picked <- sample.int(n_rows, round(boot_rate * n_rows), replace = TRUE)
    subset[picked, ]
  }
  resampled <- lapply(subsets, resample_one)
  if (length(resampled) > 0L) {
    names(resampled) <- paste0("bootstraped", seq_along(resampled))
  }
  resampled
}
#' Generate a rotation matrix
#'
#' Runs a centred (unscaled) PCA separately on the feature columns of each
#' bootstrapped subset and assembles the loading matrices into one
#' block-diagonal rotation matrix.
#'
#' @param boot_subsets bootstrapped subset list, as returned by `bootstrap()`.
#'   Each element is a data frame whose feature columns are named `X<index>`
#'   (as assigned by `get_rdm_subset()`) plus the dependent variable `y`.
#' @return numeric matrix with one row per feature, ordered by feature index
#'   (`X1`, `X2`, ..., `X10`, ...) so that it lines up with the columns of the
#'   original data, and one column per principal component (`PC1`, `PC2`,
#'   ...). `as.matrix(x) %*% rot` projects the original data onto the
#'   rotated axes.
rot_mtrx <- function(boot_subsets) {
  # PCA loadings for each subset: rows = features, columns = components.
  # drop = FALSE keeps a single-feature subset as a one-column data frame so
  # its feature name survives into the loading matrix's row names.
  loadings <- lapply(boot_subsets, function(s) {
    features <- s[, names(s) != "y", drop = FALSE]
    prcomp(features, center = TRUE)$rotation
  })
  n_rows <- vapply(loadings, nrow, integer(1))
  n_cols <- vapply(loadings, ncol, integer(1))
  row_offsets <- cumsum(n_rows) - n_rows
  col_offsets <- cumsum(n_cols) - n_cols
  # Assemble the loadings into one block-diagonal matrix (zeros elsewhere)
  rot <- matrix(0, nrow = sum(n_rows), ncol = sum(n_cols))
  for (i in seq_along(loadings)) {
    rot[row_offsets[i] + seq_len(n_rows[i]),
        col_offsets[i] + seq_len(n_cols[i])] <- loadings[[i]]
  }
  feature_names <- unlist(lapply(loadings, rownames), use.names = FALSE)
  rownames(rot) <- feature_names
  colnames(rot) <- paste0("PC", seq_len(ncol(rot)))
  # Put the rows back into the original column order of x. The names are
  # "X<index>", so order by the numeric index: a plain sort() would place
  # "X10" before "X2" and misalign every variable past the ninth.
  feature_idx <- as.integer(substring(feature_names, 2L))
  rot[order(feature_idx), , drop = FALSE]
}
# 2. Main part of Rotation Forest --------------------------------------
#' Classification algorithm: Rotation Forest
#'
#' rotation_forest actually performs Rotation Forest algorithm: for each tree
#' it splits the variables into `k` random subsets, bootstraps each subset,
#' builds a block-diagonal PCA rotation matrix from them, and fits a
#' classification tree (`rpart`) to the rotated data.
#'
#' @param x dataframe of independent variables (every column numeric)
#' @param y vector of a dependent variable (has to be a factor), with one
#'   element per row of `x`
#' @param k positive whole number of random variable subsets the columns of
#'   `x` are split into (1 <= k <= ncol(x)); both `3` and `3L` are accepted
#' @param boot_rate bootstrapping rate: each subset is resampled, with
#'   replacement, to `round(boot_rate * nrow(x))` rows before its PCA
#' @param ntree number of trees to grow (at least 1)
#' @return object of class "rotation_forest": a list with elements `trees`
#'   (the fitted rpart models), `rot_mtrcs` (the rotation matrix used for each
#'   tree) and `rotated_data` (each tree's rotated training data)
rotation_forest <- function(x, y, k, ntree = 15, boot_rate = 0.75) {
  # TRUE for a single, non-missing number
  is_scalar_number <- function(v) {
    is.numeric(v) && length(v) == 1L && !is.na(v)
  }
  # Check x; numeric
  if (!all(vapply(x, is.numeric, logical(1)))) {
    stop("x has non-numeric column")
  }
  # Check y; a factor with one label per row of x
  if (!is.factor(y)) {
    stop("y is not a factor")
  }
  if (length(y) != nrow(x)) {
    stop("y must have one element per row of x")
  }
  # Check k; a natural number 1 <= k <= ncol(x). Whole-number doubles such as
  # 3 are accepted as well as integers such as 3L.
  if (!is_scalar_number(k) || k != round(k)) {
    stop("k is not an integer")
  }
  if (k < 1 || k > ncol(x)) {
    stop("k has to be greater than or equal to 1 ",
         "and less than or equal to the number of columns of x")
  }
  if (!is_scalar_number(ntree) || ntree < 1) {
    stop("ntree has to be a number greater than or equal to 1")
  }
  # A non-positive rate would give empty bootstrap samples and a cryptic
  # failure inside prcomp()
  if (!is_scalar_number(boot_rate) || boot_rate <= 0) {
    stop("boot_rate has to be a positive number")
  }
  k <- as.integer(k)
  n_tree <- as.integer(ntree)
  # Loop-invariant numeric design matrix used for every projection
  x_mat <- as.matrix(x)
  trees <- vector("list", n_tree)
  rot_mtrcs <- vector("list", n_tree)
  rotated_data <- vector("list", n_tree)
  for (i in seq_len(n_tree)) {
    # Generate a rotation matrix
    subsets <- get_rdm_subset(x, y, k)
    boot_subsets <- bootstrap(subsets, boot_rate)
    rot_mtrcs[[i]] <- rot_mtrx(boot_subsets)
    # Fit a tree on the rotated data
    rotated_data[[i]] <- data.frame(x_mat %*% rot_mtrcs[[i]], y = y)
    trees[[i]] <- rpart::rpart(y ~ ., method = "class",
                               data = rotated_data[[i]])
  }
  rlt <- list(trees = trees, rot_mtrcs = rot_mtrcs,
              rotated_data = rotated_data)
  class(rlt) <- "rotation_forest"
  rlt
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.