R/less-regressor.R

#' @title LESSBase
#'
#' @description The base class for LESSRegressor and LESSClassifier
#'
#' @param isFitted Flag to check whether LESS is fitted
#' @param replications List to store the replications
#' @param scobject Scaling object used for normalization (less::StandardScaler)
#'
#' @return R6 class of LESSBase
LESSBase <- R6::R6Class(classname = "LESSBase",
                      inherit = SklearnEstimator,
                      private = list(
                        isFitted = FALSE,
                        replications = NULL,
                        scobject = NULL,
                        set_local_attributes = function() {
                          # Validates the hyper-parameters held in the private fields and
                          # reconciles conflicting ones. Invalid values stop with an error;
                          # questionable combinations are reported through LESSWarn and the
                          # affected fields are adjusted in place. Takes no arguments.

                          if (is.null(private$local_estimator)) {
                            stop("\tLESS does not work without a local estimator.")
                          }

                          if (is_classifier(private$local_estimator)) {
                            LESSWarn$new("\tLESS might work with local classifiers.\n\tHowever, we recommend using regressors as the local estimators.",
                                         private$warnings)
                          }

                          # Scalar conditions use && / ||, the operators meant for `if`
                          if (getClassName(self) == "LESSRegressor" && is_classifier(private$global_estimator)) {
                            LESSWarn$new("\tLESSRegressor might work with a global classifier.\n\tHowever, we recommend using a regressor as the global estimator.",
                                         private$warnings)
                          }

                          if (getClassName(self) == "LESSClassifier" && is_regressor(private$global_estimator)) {
                            LESSWarn$new("\tLESSClassifier might work with a global regressor.\n\tHowever, we recommend using a classifier as the global estimator.",
                                         private$warnings)
                          }

                          # val_size must lie strictly between 0 and 1
                          if (!is.null(private$val_size)) {
                            if (private$val_size <= 0.0 || private$val_size >= 1.0) {
                              stop("\tParameter val_size should be in the interval (0, 1).")
                            }
                          }

                          # frac == 1 is accepted by the check, so the valid interval is (0, 1]
                          if (!is.null(private$frac)) {
                            if (private$frac <= 0.0 || private$frac > 1.0) {
                              stop("\tParameter frac should be in the interval (0, 1].")
                            }
                          }

                          if (private$n_replications < 1) {
                            stop("\tThe number of replications should be greater than or equal to one.")
                          }

                          # length of NULL is zero; a cluster method object has non-zero length
                          if (length(private$cluster_method) != 0) {
                            if (!is.null(private$frac) || !is.null(private$n_neighbors) || !is.null(private$n_subsets)) {
                              LESSWarn$new("\tParameter cluster_method overrides parameters frac, n_neighbors and n_subsets.\n\tProceeding with clustering...",
                                           private$warnings)
                              private$frac <- NULL
                              private$n_neighbors <- NULL
                            }

                            # Different numbers of subsets may be generated by the clustering method
                            private$n_subsets <- list()

                            if ('n_clusters' %in% private$cluster_method$get_all_fields()) {
                              # Use the exact field name instead of relying on `$` partial matching
                              if (private$cluster_method$get_attributes()$n_clusters == 1) {
                                LESSWarn$new("\tThere is only one cluster, so the global estimator is set to NULL.",
                                             private$warnings)
                                private$global_estimator <- NULL
                                private$d_normalize <- TRUE
                                # If there is also no validation step, then there is
                                # no randomness. So, no need for replications.
                                if (is.null(private$val_size)) {
                                  LESSWarn$new("\tSince validation set is not used, there is no randomness.\n\tThus, the number of replications is set to one.",
                                               private$warnings)
                                  private$n_replications <- 1
                                }
                              }
                            }
                          } else if (is.null(private$frac) &&
                                     is.null(private$n_neighbors) &&
                                     is.null(private$n_subsets)) {
                            # Nothing was specified: fall back to the default fraction
                            private$frac <- 0.05
                          }

                          # When there is no global estimator, the scaling should be set to FALSE
                          if (is.null(private$global_estimator)) {
                            private$scaling <- FALSE
                          }
                        },

                        check_input = function(len_X) {
                          # Reconciles n_neighbors and n_subsets with the sample count len_X.
                          # Nothing is touched when a cluster method is set.
                          if (length(private$cluster_method) != 0) {
                            return(invisible(NULL))
                          }

                          # integer-truncated quotient shared by the derivations below
                          int_ratio <- function(numer, denom) {
                            as.integer(numer / denom)
                          }

                          # frac, when given, determines both quantities
                          if (!is.null(private$frac)) {
                            private$n_neighbors <- as.integer(ceiling(private$frac * len_X))
                            private$n_subsets <- int_ratio(len_X, private$n_neighbors)
                          }

                          # otherwise derive whichever of the two is still missing
                          if (is.null(private$n_subsets)) {
                            private$n_subsets <- int_ratio(len_X, private$n_neighbors)
                          }
                          if (is.null(private$n_neighbors)) {
                            private$n_neighbors <- int_ratio(len_X, private$n_subsets)
                          }

                          # neither quantity may exceed the number of samples
                          if (private$n_neighbors > len_X) {
                            LESSWarn$new("\tThe number of neighbors is larger than the number of samples. \n\tSetting number of subsets to one.",
                                         private$warnings)
                            private$n_neighbors <- len_X
                            private$n_subsets <- 1
                          }
                          if (private$n_subsets > len_X) {
                            LESSWarn$new("\tThe number of subsets is larger than the number of samples. \n\tSetting number of neighbors to one.",
                                         private$warnings)
                            private$n_neighbors <- 1
                            private$n_subsets <- len_X
                          }

                          # a single subset leaves nothing for a global estimator to combine
                          if (private$n_subsets == 1) {
                            LESSWarn$new("\tThere is only one subset, so the global estimator is set to NULL",
                                         private$warnings)
                            private$global_estimator <- NULL
                            private$d_normalize <- TRUE
                            # Without a validation split there is no randomness either,
                            # so a single replication is enough.
                            if (is.null(private$val_size)) {
                              LESSWarn$new("\tSince validation set is not used, there is no randomness. \n\tThus, the number of replications is set to one.",
                                           private$warnings)
                              private$n_replications <- 1
                            }
                          }
                        },

                        fitnoval = function(X, y) {
                          # Fit function: All data is used with the global estimator (no validation)
                          # Tree method is used (no clustering)
                          #
                          # X: feature data, indexed by row
                          # y: target values; length(y) is taken as the number of samples
                          # Returns self invisibly. One Replication per replication is stored
                          # in private$replications.
                          len_X <- length(y)
                          private$check_input(len_X)
                          tree <- private$tree_method(X)
                          private$replications <- list()
                          # set_local_attributes rejects n_replications < 1, so 1:n is safe here
                          for (rep_idx in 1:private$n_replications) {
                            # Draw the anchor samples and look up their nearest neighbors
                            sample_indices <- private$rng$choice(range = len_X, size = private$n_subsets)
                            nearest_neighbors <- tree$query(X[sample_indices,], private$n_neighbors)
                            neighbor_indices_list <- nearest_neighbors[[1]]
                            n_local <- nrow(neighbor_indices_list)

                            local_models <- vector("list", n_local)
                            dists <- matrix(0, len_X, private$n_subsets)
                            predicts <- matrix(0, len_X, private$n_subsets)

                            # A separate index name keeps the replication counter untouched
                            for (subset_idx in seq_len(n_local)) {
                              Xneighbors <- as.matrix(X[neighbor_indices_list[subset_idx, ],])
                              yneighbors <- as.matrix(y[neighbor_indices_list[subset_idx, ]])
                              if (nrow(yneighbors) == 1) {
                                # if there is only one sample in a group,
                                # prevent Xneighbors being a (n,1) dimensional matrix
                                Xneighbors <- t(Xneighbors)
                              }

                              # Centroid is used as the center of the local sample set
                              local_center <- colMeans(Xneighbors)

                              # if random_state is one of the estimator's parameters,
                              # set it to an integer drawn from rng before fitting
                              if ('random_state' %in% (private$local_estimator$get_all_fields())) {
                                private$local_estimator$set_random_state(private$rng$integers(32767))
                              }
                              local_model <- private$local_estimator$fit(Xneighbors, yneighbors)$clone()
                              local_models[[subset_idx]] <- LocalModel$new(estimator = local_model, center = local_center)

                              predicts[, subset_idx] <- local_model$predict(X)
                              if (is.null(private$distance_function)) {
                                dists[, subset_idx] <- rbf(X, local_center, 1.0/(private$n_subsets ^ 2.0))
                              } else {
                                dists[, subset_idx] <- private$distance_function(X, local_center)
                              }
                            }

                            # Normalize the distances from samples to the local subsets
                            if (private$d_normalize) {
                              denom <- rowSums(dists)
                              # floor the row sums to avoid dividing by (almost) zero
                              denom[denom < 1e-08] <- 1e-08
                              dists <- dists/denom
                            }

                            # Distance-weighted local predictions are the inputs of the global estimator
                            Z <- dists * predicts
                            scobject <- StandardScaler$new()
                            if (private$scaling) {
                              Z <- scobject$fit_transform(Z)
                            }

                            if (length(private$global_estimator) != 0) { # for a null environment, the length is 0
                              # if random_state is one of the estimator's parameters
                              if ('random_state' %in% (private$global_estimator$get_all_fields())) {
                                private$global_estimator$set_random_state(private$rng$integers(32767))
                              }
                              global_model <- private$global_estimator$fit(Z, y)$clone()
                            } else {
                              global_model <- NULL
                            }
                            private$replications <- append(private$replications,
                                                           Replication$new(local_estimators = local_models,
                                                                           sc_object = scobject,
                                                                           global_estimator = global_model))
                          }

                          invisible(self)
                        },

                        fitval = function(X, y) {
                          # Fit function: (val_size x data) is used for the global estimator (validation)
                          # Tree method is used (no clustering)
                          #
                          # X: feature data, indexed by row
                          # y: target values, bound to X column-wise before splitting
                          # Returns self invisibly. One Replication per replication is stored
                          # in private$replications.

                          private$replications <- list()
                          # set_local_attributes rejects n_replications < 1, so 1:n is safe here
                          for (rep_idx in 1:private$n_replications) {
                            # Split for global estimation
                            split_list <- train_test_split(cbind(X, y), test_size =  private$val_size,
                                                           random_state = private$rng$integers(32767))
                            X_train <- split_list[[1]]
                            X_val <- split_list[[2]]
                            y_train <- split_list[[3]]
                            y_val <- split_list[[4]]

                            len_X_val <- length(y_val)
                            len_X_train <- length(y_train)
                            # Check the validity of the input (once, on the first split)
                            if (rep_idx == 1) {
                              private$check_input(len_X_train)
                            }

                            # A nearest neighbor tree is grown for querying
                            tree <- private$tree_method(X_train)

                            # Select n_subsets many samples to construct the local sample sets
                            sample_indices <- private$rng$choice(range = len_X_train, size = private$n_subsets)
                            # Construct the local sample sets. sample_indices index the
                            # training split, so the query rows must come from X_train.
                            nearest_neighbors <- tree$query(X_train[sample_indices,], private$n_neighbors)
                            neighbor_indices_list <- nearest_neighbors[[1]]
                            n_local <- nrow(neighbor_indices_list)

                            local_models <- vector("list", n_local)
                            dists <- matrix(0, len_X_val, private$n_subsets)
                            predicts <- matrix(0, len_X_val, private$n_subsets)

                            # A separate index name keeps the replication counter untouched
                            for (subset_idx in seq_len(n_local)) {
                              Xneighbors <- as.matrix(X_train[neighbor_indices_list[subset_idx, ],])
                              yneighbors <- as.matrix(y_train[neighbor_indices_list[subset_idx, ]])
                              if (nrow(yneighbors) == 1) {
                                # if there is only one sample in a group,
                                # prevent Xneighbors being a (n,1) dimensional matrix
                                Xneighbors <- t(Xneighbors)
                              }

                              # Centroid is used as the center of the local sample set
                              local_center <- colMeans(Xneighbors)

                              # if random_state is one of the estimator's parameters,
                              # set it to an integer drawn from rng before fitting
                              if ('random_state' %in% (private$local_estimator$get_all_fields())) {
                                private$local_estimator$set_random_state(private$rng$integers(32767))
                              }
                              local_model <- private$local_estimator$fit(Xneighbors, yneighbors)$clone()
                              local_models[[subset_idx]] <- LocalModel$new(estimator = local_model, center = local_center)

                              # Predictions and distances are both taken on the validation
                              # split; dists/predicts have one row per validation sample.
                              predicts[, subset_idx] <- local_model$predict(X_val)
                              if (is.null(private$distance_function)) {
                                dists[, subset_idx] <- rbf(X_val, local_center, 1.0/(private$n_subsets ^ 2.0))
                              } else {
                                dists[, subset_idx] <- private$distance_function(X_val, local_center)
                              }
                            }

                            # Normalize the distances from samples to the local subsets
                            if (private$d_normalize) {
                              denom <- rowSums(dists)
                              # floor the row sums to avoid dividing by (almost) zero
                              denom[denom < 1e-08] <- 1e-08
                              dists <- dists/denom
                            }

                            # Distance-weighted local predictions are the inputs of the global estimator
                            Z <- dists * predicts
                            scobject <- StandardScaler$new()
                            if (private$scaling) {
                              Z <- scobject$fit_transform(Z)
                            }

                            if (length(private$global_estimator) != 0) { # for a null environment, the length is 0
                              # if random_state is one of the estimator's parameters
                              if ('random_state' %in% (private$global_estimator$get_all_fields())) {
                                private$global_estimator$set_random_state(private$rng$integers(32767))
                              }
                              global_model <- private$global_estimator$fit(Z, y_val)$clone()
                            } else {
                              global_model <- NULL
                            }
                            private$replications <- append(private$replications,
                                                           Replication$new(local_estimators = local_models,
                                                                           sc_object = scobject,
                                                                           global_estimator = global_model))
                          }
                          invisible(self)
                        },

                        fitnovalc = function(X, y){
                          # Fit function: All data is used for the global estimator (no validation)
                          # Clustering is used (no tree method)
                          #
                          # X: feature data, indexed by row
                          # y: target values; length(y) is taken as the number of samples
                          # Returns self invisibly. One Replication per replication is stored
                          # in private$replications, and the number of clusters found in each
                          # replication is appended to private$n_subsets.

                          len_X <- length(y)
                          # Check the validity of the input
                          private$check_input(len_X)

                          # if the cluster method does not have parameter named 'random_state'
                          if (!('random_state' %in% (private$cluster_method$get_all_fields()))) {
                            LESSWarn$new("\tClustering method is not random,\n\tso there is no need for replications unless validation set is used.\n\tThe number of replications is set to one.",
                                         private$warnings)
                            private$n_replications <- 1
                          }

                          # With a single replication the clustering is fitted once, unseeded
                          if (private$n_replications == 1) {
                            cluster_fit <- private$cluster_method$fit(X)
                          }

                          private$replications <- list()
                          for (rep_idx in 1:private$n_replications) {

                            # With several replications each one gets a freshly seeded clustering
                            if (private$n_replications > 1) {
                              cluster_fit <- private$cluster_method$
                                set_random_state(private$rng$integers(32767))$
                                fit(X)
                            }

                            cluster_labels <- cluster_fit$get_labels()
                            unique_labels <- unique(cluster_labels)
                            # unique() keeps first-appearance order, but centers are picked by
                            # position below; sort so position k refers to the k-th smallest label.
                            # NOTE(review): assumes rows of get_cluster_centers() follow ascending
                            # label order - confirm against the clustering classes.
                            if (is.atomic(unique_labels)) {
                              unique_labels <- sort(unique_labels, na.last = TRUE)
                            }
                            # Some clustering methods may find less number of clusters than requested 'n_clusters'
                            n_subsets <- length(unique_labels)
                            # Use the local count instead of reading private$n_subsets[[i]], which
                            # would be a stale entry if the list were not empty on entry
                            private$n_subsets <- append(private$n_subsets, n_subsets)

                            local_models <- vector("list", n_subsets)
                            dists <- matrix(0, len_X, n_subsets)
                            predicts <- matrix(0, len_X, n_subsets)

                            use_cluster_centers <- !is.null(cluster_fit$get_cluster_centers())

                            for (cluster_indx in seq_along(unique_labels)) {
                              # logical mask of the samples that belong to this cluster
                              neighbor_indices <- cluster_labels == unique_labels[[cluster_indx]]
                              Xneighbors <- as.matrix(X[neighbor_indices, ])
                              yneighbors <- as.matrix(y[neighbor_indices])
                              if (nrow(yneighbors) == 1) {
                                # if there is only one sample in a group,
                                # prevent Xneighbors being a (n,1) dimensional matrix
                                Xneighbors <- t(Xneighbors)
                              }

                              # The cluster center, when available, is the center of the local
                              # sample set; otherwise the centroid of the members is used
                              if (use_cluster_centers) {
                                local_center <- cluster_fit$get_cluster_centers()[cluster_indx, ]
                              } else {
                                local_center <- colMeans(Xneighbors)
                              }

                              # if random_state is one of the estimator's parameters,
                              # set it to an integer drawn from rng before fitting
                              if ('random_state' %in% (private$local_estimator$get_all_fields())) {
                                private$local_estimator$set_random_state(private$rng$integers(32767))
                              }
                              local_model <- private$local_estimator$fit(Xneighbors, yneighbors)$clone()
                              local_models[[cluster_indx]] <- LocalModel$new(estimator = local_model, center = local_center)

                              predicts[, cluster_indx] <- local_model$predict(X)
                              if (is.null(private$distance_function)) {
                                dists[, cluster_indx] <- rbf(X, local_center, 1.0/(n_subsets ^ 2.0))
                              } else {
                                dists[, cluster_indx] <- private$distance_function(X, local_center)
                              }
                            }

                            # Normalize the distances from samples to the local subsets
                            if (private$d_normalize) {
                              denom <- rowSums(dists)
                              # floor the row sums to avoid dividing by (almost) zero
                              denom[denom < 1e-08] <- 1e-08
                              dists <- dists/denom
                            }

                            # Distance-weighted local predictions are the inputs of the global estimator
                            Z <- dists * predicts
                            scobject <- StandardScaler$new()
                            if (private$scaling) {
                              Z <- scobject$fit_transform(Z)
                            }

                            if (length(private$global_estimator) != 0) { # for a null environment, the length is 0
                              # if random_state is one of the estimator's parameters
                              if ('random_state' %in% (private$global_estimator$get_all_fields())) {
                                private$global_estimator$set_random_state(private$rng$integers(32767))
                              }
                              global_model <- private$global_estimator$fit(Z, y)$clone()
                            } else {
                              global_model <- NULL
                            }
                            private$replications <- append(private$replications,
                                                           Replication$new(local_estimators = local_models,
                                                                           sc_object = scobject,
                                                                           global_estimator = global_model))
                          }

                          invisible(self)
                        },

                        fitvalc = function(X, y){
                          # Fit function: (val_size x data) is used for the global estimator (validation)
                          # Clustering is used (no tree method)
                          #
                          # X: feature data, indexed by row
                          # y: target values, bound to X column-wise before splitting
                          # Returns self invisibly. One Replication per replication is stored
                          # in private$replications, and the number of clusters found in each
                          # replication is appended to private$n_subsets.

                          private$replications <- list()
                          for (rep_idx in 1:private$n_replications){
                            # Split for global estimation
                            split_list <- train_test_split(cbind(X, y), test_size =  private$val_size,
                                                           random_state = private$rng$integers(32767))
                            X_train <- split_list[[1]]
                            X_val <- split_list[[2]]
                            y_train <- split_list[[3]]
                            y_val <- split_list[[4]]

                            len_X_val <- length(y_val)
                            len_X_train <- length(y_train)
                            # Check the validity of the input (once, on the first split)
                            if (rep_idx == 1) {
                              private$check_input(len_X_train)
                            }

                            # Seed the clustering from rng when it exposes a random_state
                            if ('random_state' %in% (private$cluster_method$get_all_fields())) {
                              cluster_fit <- private$cluster_method$
                                set_random_state(private$rng$integers(32767))$
                                fit(X_train)
                            } else {
                              cluster_fit <- private$cluster_method$fit(X_train)
                            }

                            # Whether centers are available is decided on the first replication
                            if (rep_idx == 1) {
                              use_cluster_centers <- !is.null(cluster_fit$get_cluster_centers())
                            }

                            cluster_labels <- cluster_fit$get_labels()
                            unique_labels <- unique(cluster_labels)
                            # unique() keeps first-appearance order, but centers are picked by
                            # position below; sort so position k refers to the k-th smallest label.
                            # NOTE(review): assumes rows of get_cluster_centers() follow ascending
                            # label order - confirm against the clustering classes.
                            if (is.atomic(unique_labels)) {
                              unique_labels <- sort(unique_labels, na.last = TRUE)
                            }
                            # Some clustering methods may find less number of clusters than requested 'n_clusters'
                            n_subsets <- length(unique_labels)
                            # Use the local count instead of reading private$n_subsets[[i]], which
                            # would be a stale entry if the list were not empty on entry
                            private$n_subsets <- append(private$n_subsets, n_subsets)

                            local_models <- vector("list", n_subsets)
                            dists <- matrix(0, len_X_val, n_subsets)
                            predicts <- matrix(0, len_X_val, n_subsets)

                            for (cluster_indx in seq_along(unique_labels)){
                              # logical mask of the training samples that belong to this cluster
                              neighbor_indices <- cluster_labels == unique_labels[[cluster_indx]]
                              Xneighbors <- as.matrix(X_train[neighbor_indices, ])
                              yneighbors <- as.matrix(y_train[neighbor_indices])
                              if (nrow(yneighbors) == 1) {
                                # if there is only one sample in a group,
                                # prevent Xneighbors being a (n,1) dimensional matrix
                                Xneighbors <- t(Xneighbors)
                              }

                              # The cluster center, when available, is the center of the local
                              # sample set; otherwise the centroid of the members is used
                              if (use_cluster_centers) {
                                local_center <- cluster_fit$get_cluster_centers()[cluster_indx, ]
                              } else {
                                local_center <- colMeans(Xneighbors)
                              }

                              # if random_state is one of the estimator's parameters,
                              # set it to an integer drawn from rng before fitting
                              if ('random_state' %in% (private$local_estimator$get_all_fields())) {
                                private$local_estimator$set_random_state(private$rng$integers(32767))
                              }
                              local_model <- private$local_estimator$fit(Xneighbors, yneighbors)$clone()
                              local_models[[cluster_indx]] <- LocalModel$new(estimator = local_model, center = local_center)

                              predicts[, cluster_indx] <- local_model$predict(X_val)
                              if (is.null(private$distance_function)) {
                                dists[, cluster_indx] <- rbf(X_val, local_center, 1.0/(n_subsets ^ 2.0))
                              } else {
                                dists[, cluster_indx] <- private$distance_function(X_val, local_center)
                              }
                            }

                            # Normalize the distances from samples to the local subsets
                            if (private$d_normalize) {
                              denom <- rowSums(dists)
                              # floor the row sums to avoid dividing by (almost) zero
                              denom[denom < 1e-08] <- 1e-08
                              dists <- dists/denom
                            }

                            # Distance-weighted local predictions are the inputs of the global estimator
                            Z <- dists * predicts
                            scobject <- StandardScaler$new()
                            if (private$scaling) {
                              Z <- scobject$fit_transform(Z)
                            }

                            if (length(private$global_estimator) != 0) { # for a null environment, the length is 0
                              # if random_state is one of the estimator's parameters
                              if ('random_state' %in% (private$global_estimator$get_all_fields())) {
                                private$global_estimator$set_random_state(private$rng$integers(32767))
                              }
                              global_model <- private$global_estimator$fit(Z, y_val)$clone()
                            } else {
                              global_model <- NULL
                            }
                            private$replications <- append(private$replications,
                                                           Replication$new(local_estimators = local_models,
                                                                           sc_object = scobject,
                                                                           global_estimator = global_model))
                          }

                          invisible(self)
                        }
                      ),
                      public = list(
                        #' @description Creates a new instance of R6 Class of LESSBase
                        initialize = function(replications = NULL, scobject = NULL, isFitted = FALSE) {
                          # Copy each constructor argument into the private field of the same name
                          private$replications <- replications
                          private$scobject <- scobject
                          private$isFitted <- isFitted
                        },

                        #' @description Auxiliary function that sets random state attribute of the self class
                        #'
                        #' @param random_state seed number to be set as random state
                        #' @return self
                        set_random_state = function(random_state) {
                          # Remember the seed and rebuild the generator so that it is driven by the new seed
                          private$random_state <- random_state
                          private$rng <- RandomGenerator$new(random_state = random_state)
                          invisible(self)
                        },

                        #' @description  Auxiliary function returning the number of subsets
                        get_n_subsets = function(){
                          # Plain accessor for the private n_subsets field
                          private$n_subsets
                        },

                        #' @description Auxiliary function returning the number of neighbors
                        get_n_neighbors = function(){
                          # Plain accessor for the private n_neighbors field
                          private$n_neighbors
                        },

                        #' @description Auxiliary function returning the percentage of samples used to set the number of neighbors
                        get_frac = function(){
                          # Plain accessor for the private frac field
                          private$frac
                        },

                        #' @description Auxiliary function returning the number of replications
                        get_n_replications = function(){
                          # Plain accessor for the private n_replications field
                          private$n_replications
                        },

                        #' @description Auxiliary function returning the flag for normalization
                        get_d_normalize = function(){
                          # Plain accessor for the private d_normalize flag
                          private$d_normalize
                        },

                        #' @description Auxiliary function returning the flag for scaling
                        get_scaling = function(){
                          # Plain accessor for the private scaling flag
                          private$scaling
                        },

                        #' @description Auxiliary function returning the validation set size
                        get_val_size = function(){
                          # Plain accessor for the private val_size field
                          private$val_size
                        },

                        #' @description Auxiliary function returning the random seed
                        get_random_state = function(){
                          # Plain accessor for the private random_state field
                          private$random_state
                        },

                        #' @description Auxiliary function returning the isFitted flag
                        get_isFitted = function(){
                          # Plain accessor for the private isFitted flag
                          private$isFitted
                        },

                        #' @description Auxiliary function returning the list of stored replications
                        get_replications = function(){
                          # Plain accessor for the private replications field
                          private$replications
                        }
                      )
                    )

#' @title  LESSRegressor
#'
#' @description Regressor for Learning with Subset Stacking (LESS)
#'
#' @param frac fraction of total samples used for the number of neighbors (default is 0.05)
#' @param n_neighbors number of neighbors (default is NULL)
#' @param n_subsets number of subsets (default is NULL)
#' @param n_replications number of replications (default is 20)
#' @param d_normalize distance normalization (default is TRUE)
#' @param val_size percentage of samples used for validation (default is NULL - no validation)
#' @param random_state initialization of the random seed (default is NULL)
#' @param tree_method method used for constructing the nearest neighbor tree, e.g., less::KDTree (default)
#' @param cluster_method method used for clustering the subsets, e.g., less::KMeans (default is NULL)
#' @param local_estimator estimator for the local models (default is less::LinearRegression)
#' @param global_estimator estimator for the global model (default is less::DecisionTreeRegressor)
#' @param distance_function distance function evaluating the distance from a subset to a sample,
#' e.g., df(subset, sample) which returns a vector of distances (default is RBF(subset, sample, 1.0/n_subsets^2))
#' @param scaling flag to normalize the input data (default is TRUE)
#' @param warnings flag to turn on (TRUE) or off (FALSE) the warnings (default is TRUE)
#'
#' @return R6 class of LESSRegressor
#' @seealso [LESSBase]
#' @export
LESSRegressor <- R6::R6Class(classname = "LESSRegressor",
                             inherit = LESSBase,
                             private = list(
                               # Tag reported by get_estimator_type()
                               estimator_type = "regressor",
                               # The fields below are placeholders; initialize() fills them
                               # from its arguments
                               frac = NULL,
                               n_neighbors = NULL,
                               n_subsets = NULL,
                               n_replications = NULL,
                               d_normalize = NULL,
                               val_size = NULL,
                               random_state = NULL,
                               tree_method = NULL,
                               cluster_method = NULL,
                               local_estimator = NULL,
                               global_estimator = NULL,
                               distance_function = NULL,
                               scaling = NULL,
                               warnings = NULL,
                               # Random generator built from random_state in initialize()
                               rng = NULL
                             ),
                             public = list(
                               #' @description Creates a new instance of R6 Class of LESSRegressor
                               #'
                               #' @examples
                               #' lessRegressor <- LESSRegressor$new()
                               #' lessRegressor <- LESSRegressor$new(val_size = 0.3)
                               #' lessRegressor <- LESSRegressor$new(cluster_method = less::KMeans$new())
                               #' lessRegressor <- LESSRegressor$new(val_size = 0.3, cluster_method = less::KMeans$new())
                               initialize = function(frac = NULL, n_neighbors = NULL, n_subsets = NULL, n_replications = 20, d_normalize = TRUE, val_size = NULL,
                                                     random_state = NULL, tree_method = function(X) KDTree$new(X), cluster_method = NULL,
                                                     local_estimator = LinearRegression$new(), global_estimator = DecisionTreeRegressor$new(), distance_function = NULL,
                                                     scaling = TRUE, warnings = TRUE) {
                                 private$frac <- frac
                                 private$n_replications <- n_replications
                                 private$random_state <- random_state
                                 private$n_subsets <- n_subsets
                                 private$n_neighbors <- n_neighbors
                                 private$local_estimator <- local_estimator
                                 private$d_normalize <- d_normalize
                                 private$global_estimator <- global_estimator
                                 private$scaling <- scaling
                                 private$cluster_method <- cluster_method
                                 private$distance_function <- distance_function
                                 # The generator must be created after random_state has been stored
                                 private$rng <- RandomGenerator$new(random_state = private$random_state)
                                 private$warnings <- warnings
                                 private$val_size <- val_size
                                 private$tree_method <- tree_method
                               },

                               #' @description
                               #' Dummy fit function that calls the proper method according to validation and clustering parameters
                               #' Options are:
                               #' - Default fitting (no validation set, no clustering)
                               #' - Fitting with validation set (no clustering)
                               #' - Fitting with clustering (no validation set)
                               #' - Fitting with validation set and clustering
                               #'
                               #' @param X 2D matrix or dataframe that includes predictors
                               #' @param y 1D vector or (n,1) dimensional matrix/dataframe that includes response variables
                               #'
                               #' @return Fitted R6 Class of LESSRegressor
                               #'
                               #' @examples
                               #' data(abalone)
                               #' split_list <- train_test_split(abalone[1:100,], test_size =  0.3)
                               #' X_train <- split_list[[1]]
                               #' X_test <- split_list[[2]]
                               #' y_train <- split_list[[3]]
                               #' y_test <- split_list[[4]]
                               #'
                               #' lessRegressor <- LESSRegressor$new()
                               #' lessRegressor$fit(X_train, y_train)
                               fit = function(X, y){

                                 # Check that X and y have correct shape
                                 X_y_list <- check_X_y(X, y)
                                 X <- X_y_list[[1]]
                                 y <- X_y_list[[2]]
                                 private$set_local_attributes()

                                 if(private$scaling){
                                   # Keep the fitted scaler: predict() reuses it to transform new inputs
                                   private$scobject <- StandardScaler$new()
                                   X <- private$scobject$fit_transform(X)
                                 }

                                 # An unset cluster method has length 0
                                 use_validation <- !is.null(private$val_size)
                                 use_clustering <- length(private$cluster_method) != 0

                                 if(use_validation){
                                   # Validation set is used for global estimation
                                   if(use_clustering){
                                     private$fitvalc(X, y)
                                   }else{
                                     private$fitval(X, y)
                                   }
                                 }else{
                                   # Validation set is not used for global estimation
                                   if(use_clustering){
                                     private$fitnovalc(X, y)
                                   }else{
                                     private$fitnoval(X, y)
                                   }
                                 }
                                 private$isFitted <- TRUE
                                 invisible(self)
                               },
                               #' @description
                               #' Predictions are evaluated for the test samples in X0.
                               #' For every replication the local predictions are weighted by the distances
                               #' to the subset centers and then combined by the global estimator (or summed
                               #' when there is no global estimator); the result is averaged over the replications.
                               #'
                               #' @param X0 2D matrix or dataframe that includes predictors
                               #'
                               #' @return Predicted values of the given predictors
                               #'
                               #' @examples
                               #' preds <- lessRegressor$predict(X_test)
                               #' print(head(matrix(c(y_test, preds), ncol = 2, dimnames = (list(NULL, c("True", "Prediction"))))))
                               predict = function(X0) {

                                 check_is_fitted(self)
                                 # Input validation
                                 check_matrix(X0)

                                 if(private$scaling){
                                   X0 <- private$scobject$transform(X0)
                                 }

                                 len_X0 <- nrow(X0)
                                 # Running sum of the per-replication predictions (one column)
                                 yhat <- matrix(0, len_X0, 1)

                                 # Loop-invariant settings
                                 use_clustering <- length(private$cluster_method) != 0
                                 use_default_distance <- is.null(private$distance_function)

                                 # seq_len() yields an empty sequence for a zero count, unlike 1:n
                                 for (i in seq_len(private$n_replications)) {
                                   # Get the fitted global and local estimators
                                   replication <- private$replications[[i]]
                                   global_model <- replication$global_estimator
                                   local_models <- replication$local_estimators

                                   # With clustering, the number of subsets is stored per replication
                                   if(use_clustering){
                                     n_subsets <- private$n_subsets[[i]]
                                   }else{
                                     n_subsets <- private$n_subsets
                                   }

                                   dists <- matrix(0, len_X0, n_subsets)
                                   predicts <- matrix(0, len_X0, n_subsets)
                                   for(j in seq_len(n_subsets)){
                                     local_center <- local_models[[j]]$center
                                     local_model <- local_models[[j]]$estimator
                                     predicts[, j] <- local_model$predict(X0)

                                     if(use_default_distance) {
                                       dists[, j] <- rbf(X0, local_center, 1.0/(n_subsets ^ 2.0))
                                     }else {
                                       dists[, j] <- private$distance_function(X0, local_center)
                                     }
                                   }

                                   # Normalize the distances from samples to the local subsets
                                   if(private$d_normalize) {
                                     denom <- rowSums(dists)
                                     # Clamp tiny row sums to avoid division by (near) zero
                                     denom[denom < 1e-08] <- 1e-08
                                     # denom has one entry per row, so recycling divides each row by its own sum
                                     dists <- dists/denom
                                   }

                                   Z0 <- dists * predicts
                                   if(private$scaling){
                                     # Each replication carries its own scaler for the weighted predictions
                                     Z0 <- replication$sc_object$transform(Z0)
                                   }

                                   if(length(global_model) != 0){
                                     yhat <- yhat + global_model$predict(Z0)
                                   }else{
                                     # No global estimator: combine the weighted local predictions by summing
                                     yhat <- yhat + rowSums(Z0)
                                   }
                                 }

                                 # Average over the replications
                                 yhat <- yhat/private$n_replications

                                 return(yhat)
                               },
                               #' @description Auxiliary function returning the estimator type e.g 'regressor', 'classifier'
                               #'
                               #' @return The estimator type stored in the class ("regressor")
                               #'
                               #' @examples
                               #' lessRegressor$get_estimator_type()
                               get_estimator_type = function() {
                                 return(private$estimator_type)
                               }
                             ))

Try the less package in your browser

Any scripts or data that you put into this service are public.

less documentation built on Sept. 27, 2022, 5:05 p.m.