R/search_n_samples.R

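#' Binary search for the smallest sufficient number of samples.
#'
#' @description Performs a binary search over \code{[lower, upper]} for the
#' smallest number of samples whose delta, as returned by \code{compute_delta()},
#' is at most \code{epsilon}. The search expects delta to be non-increasing in
#' the number of samples, \code{lower} samples to be insufficient and
#' \code{upper} samples to be sufficient.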
#'
#' @param rpart.tree rpart. A decision tree fitted with the rpart package.
#' @param lower int. Lower bound of the search: a number of samples whose delta is greater than epsilon.
#' @param upper int. Upper bound of the search: a number of samples whose delta is at most epsilon.
#' @param epsilon float. Passed to \code{compute_delta()} and also used as the threshold that delta must not exceed.
#'
#' @usage search_n_samples(rpart.tree, lower, upper, epsilon)
#'
#' @return int. The smallest number of samples in \code{[lower, upper]} whose delta is at most epsilon, i.e. the number of samples needed to ensure learning.
#'
#' @export search_n_samples
search_n_samples <- function(rpart.tree, lower, upper, epsilon){
    # Assumes delta does not increase as the number of samples grows.
    if (compute_delta(rpart.tree, upper, epsilon) > epsilon) {
        stop("`upper` samples are not enough: delta(upper) > epsilon.")
    }
    if (compute_delta(rpart.tree, lower, epsilon) <= epsilon) {
        return(lower)  # `lower` samples already suffice
    }
    # Invariant: delta(lower) > epsilon >= delta(upper).
    while (upper - lower > 1) {
        mid_samples <- (lower + upper) %/% 2
        if (compute_delta(rpart.tree, mid_samples, epsilon) <= epsilon) {
            upper <- mid_samples  # mid suffices: answer is in (lower, mid]
        } else {
            lower <- mid_samples  # mid falls short: answer is in (mid, upper]
        }
    }
    # `upper` is now the smallest sufficient number of samples.
    upper
}
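
The bisection that `search_n_samples` performs can be tried without the shatteringdt package. The sketch below is a minimal stand-alone version of it under stated assumptions: `binary_search_min_n` and `delta_fn` are hypothetical names, `delta_fn` stands in for `compute_delta(rpart.tree, n, epsilon)`, and `toy_delta(n) = exp(-n / 100)` is an invented, strictly decreasing delta used only for illustration. Like the function above, it assumes delta does not increase as the number of samples grows.

```r
# Hypothetical stand-in for compute_delta(): strictly decreasing in n.
# It ignores `epsilon`; the argument is kept to mirror compute_delta()'s shape.
toy_delta <- function(n, epsilon) exp(-n / 100)

# Smallest n in [lower, upper] with delta_fn(n, epsilon) <= epsilon,
# assuming delta_fn is non-increasing in n.
binary_search_min_n <- function(delta_fn, lower, upper, epsilon) {
    if (delta_fn(upper, epsilon) > epsilon) {
        stop("`upper` is too small: delta(upper) > epsilon.")
    }
    if (delta_fn(lower, epsilon) <= epsilon) {
        return(lower)
    }
    # Invariant: delta(lower) > epsilon >= delta(upper).
    while (upper - lower > 1) {
        mid <- (lower + upper) %/% 2
        if (delta_fn(mid, epsilon) <= epsilon) {
            upper <- mid   # mid suffices: answer is in (lower, mid]
        } else {
            lower <- mid   # mid does not suffice: answer is in (mid, upper]
        }
    }
    upper
}

n <- binary_search_min_n(toy_delta, lower = 1, upper = 1000, epsilon = 0.05)
print(n)  # [1] 300
```

For this toy delta the answer can be checked by hand: exp(-n / 100) <= 0.05 holds exactly when n >= 100 * ln(20), which is about 299.6, so the smallest integer is 300. The search needs only about log2(upper - lower) delta evaluations plus the two endpoint checks, which matters when each delta evaluation is expensive.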

shatteringdt documentation built on March 3, 2021, 9:05 a.m.