R/bagging.R

#' Bagging
#'
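#' An R6 implementation of the bagging ensemble method: several copies of a
#' base learner are each trained on a random sample of rows, and predictions
#' are made by majority vote across the learners.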
#'
#' @examples
#' clf <- Bagging$new(10)
#' clf$train(100, monks1_train[, -c(1,8)], monks1_train[,1])
#' preds <- clf$predict(monks1_test[, -c(1,8)])
#' accuracy <- sum(preds == monks1_test[,1])/nrow(monks1_test)
#'
#' @export

Bagging <- R6::R6Class("Bagging",

  public = list(

    # Base learners in the ensemble; populated by `initialize()`.
    models = list(),

    # Create an ensemble of `x` untrained base learners.
    #
    # x:        number of base learners; must be a single value >= 1.
    # ml_class: generator whose `$new()` returns a learner exposing
    #           `$train(features, labels)` and `$predict(features)`
    #           (ID3 by default).
    initialize = function(x = 10, ml_class = ID3) {
      # A majority vote needs at least one learner; `1:x` used to build two
      # learners for x = 0, so reject that explicitly.
      stopifnot(length(x) == 1, x >= 1)
      self$models <- lapply(seq_len(x), function(i) ml_class$new())
    },

    # Fit every base learner on its own random sample of `m` rows.
    #
    # m:        number of rows drawn for each learner.
    # features: data frame or matrix of predictors, one row per example.
    # labels:   vector of targets aligned with the rows of `features`.
    # replace:  FALSE (default) draws rows without replacement, so `m` may
    #           not exceed nrow(features); TRUE draws a bootstrap sample,
    #           as in classic bagging.
    train = function(m, features, labels, replace = FALSE) {
      n <- nrow(features)
      stopifnot(replace || m <= n)
      for (model in self$models) {
        rows <- sample.int(n, size = m, replace = replace)
        # drop = FALSE keeps a single-column data frame from collapsing
        # into a bare vector.
        model$train(features[rows, , drop = FALSE], labels[rows])
      }
    },

    # Predict by majority vote across the base learners.
    #
    # features: either a single observation given as a plain vector, or a
    #           data frame / matrix with one observation per row.
    # Returns a character vector with one winning label per row (a single
    # string for vector input). Ties go to the label that table() lists
    # first; an observation that receives no non-NA votes gets NA.
    predict = function(features) {
      if (is.vector(features)) {
        return(private$majority_vote(private$collect_votes(features)))
      }
      n <- nrow(features)
      winners <- character(n)
      for (r in seq_len(n)) {
        votes <- private$collect_votes(features[r, , drop = FALSE])
        winners[r] <- private$majority_vote(votes)
      }
      winners
    }

  ),

  private = list(

    # Ask every base learner to predict the given observation and return
    # all of their predictions as one vector.
    collect_votes = function(observation) {
      unlist(lapply(self$models, function(model) model$predict(observation)))
    },

    # Return the most frequent value in `votes` as a string. which.max()
    # keeps the first maximum in table() order; NA_character_ is returned
    # when there is nothing to count, so callers stay row-aligned.
    majority_vote = function(votes) {
      counts <- table(votes)
      if (length(counts) == 0L) {
        return(NA_character_)
      }
      names(counts)[which.max(counts)]
    }

  )

)
msats5/capstoneProject documentation built on May 18, 2019, 12:27 p.m.