# R/k_means.R
#
# Defines functions: aml_k_means, .run_k_means_until_convergence, .test_k_means_input, .find_initial_assignments, .calculate_centroids, .calculate_distances_from_centroids, .calculate_new_labels, .create_k_means_output_object, plot.aml_k_means, .make_ggplot_object, .make_ggplot_object.aml_k_means, .convert_to_ggname
#
# Documented in: aml_k_means, plot.aml_k_means

#' AML k-Means
#'
#' Clusters the rows of a numeric data.frame into k groups with the k-means
#' algorithm. Every column of the given data takes part in the calculation.
#' Initial cluster assignments are drawn at random from a fixed seed, so
#' repeated calls on the same input give the same result.
#'
#' @param data Input n x p sized data.frame of numeric data
#' @param k Number of clusters to fit (a positive whole number)
#' @param maximum_iterations Maximum number of iterations to allow while trying
#' to converge; a warning is issued when the limit is reached without
#' convergence
#' @return Results object of class aml_k_means: a list holding `labels`
#' (cluster assigned to each row), `iterations` (number of iterations run),
#' `centroids` (data.frame of cluster centres), `k` and the input `data`
#' @export
aml_k_means = function(data, k, maximum_iterations = 1e6){
    # Fail fast on unusable input before any clustering work starts.
    .test_k_means_input(data, k)

    # Fit the clusters, then wrap the raw fit into the public results object.
    .create_k_means_output_object(
        .run_k_means_until_convergence(data, k, maximum_iterations)
    )
}


# Runs the k-means update loop: starting from random labels, repeatedly
# recompute the centroids and reassign every row to its nearest centroid until
# the labels stop changing or `maximum_iterations` is exhausted.
#
# Returns a list with `new_labels` (final assignment), `data`, `iter`
# (iterations performed), `roids` (centroids used for the final assignment)
# and `k`. Issues a warning when the loop ends without converging.
.run_k_means_until_convergence = function(data, k, maximum_iterations){
    new_labels = .find_initial_assignments(nrow(data), k)
    labels = rep(0, nrow(data))
    # Fallback centroids, only returned if the loop below never executes.
    roids = .calculate_centroids(data, labels)
    iter = 0
    # Tracked explicitly so the warning below reflects actual convergence
    # rather than merely how many iterations were used.
    converged = FALSE

    # `&&` because both operands are scalars in a control-flow condition.
    while(!converged && iter < maximum_iterations){
        labels = new_labels
        roids = .calculate_centroids(data, labels)
        distances = .calculate_distances_from_centroids(data, roids)
        new_labels = .calculate_new_labels(distances)
        iter = iter + 1
        converged = all(labels == new_labels)
    }

    # Warn only when the labels were still changing; converging exactly on
    # the last permitted iteration is a success, not a failure.
    if(!converged){
        warning(paste("Convergence was not met, stopped at iteration", 
                      iter))
    }

    list(new_labels = new_labels, 
         data = data, 
         iter = iter, 
         roids = roids, 
         k = k)
}

# Validates the arguments of aml_k_means and stops with a descriptive error on
# the first problem found. Returns NULL invisibly when everything is usable.
#
# data: must be a data.frame with at least one row and only numeric columns.
# k:    must be a single, finite, positive whole number.
.test_k_means_input = function(data, k){
    # Guard first: the comparisons below would otherwise fail with an opaque
    # "missing value where TRUE/FALSE needed" or "condition has length > 1".
    if(!is.numeric(k) || length(k) != 1 || !is.finite(k)){
        stop("Number of groups k must be a single number!")
    }
    if(k <= 0){
        stop("Number of groups k must be positive!")
    }
    # `%%` instead of as.integer(): as.integer() overflows to NA (with a
    # warning) for whole numbers beyond the 32-bit integer range.
    if(k %% 1 != 0){
        stop("Number of groups k must be an integer!")
    }
    if(!is.data.frame(data)){
        stop("Data must be a data.frame!")
    }
    if(nrow(data) <= 0){
        stop("Data must have a positive number of rows!")
    }
    # vapply guarantees a logical vector regardless of the number of columns.
    if(!all(vapply(data, is.numeric, logical(1)))){
        stop("All columns of data must be numeric!")
    }
    invisible(NULL)
}

# Draws a reproducible random starting cluster (1..k) for each of n rows.
#
# n:    number of rows to label.
# k:    number of clusters (positive whole number).
# seed: seed used for the draw; the same seed always gives the same labels.
#
# The caller's global RNG state is saved before seeding and restored on exit,
# so calling this function does not alter later random draws in the session.
.find_initial_assignments = function(n, k, seed = 1337){
    global_env = globalenv()
    had_rng_state = exists(".Random.seed", envir = global_env, 
                           inherits = FALSE)
    if(had_rng_state){
        saved_rng_state = get(".Random.seed", envir = global_env, 
                              inherits = FALSE)
    }
    on.exit({
        if(had_rng_state){
            assign(".Random.seed", saved_rng_state, envir = global_env)
        }else if(exists(".Random.seed", envir = global_env, 
                        inherits = FALSE)){
            # No RNG state existed beforehand: leave none behind either.
            rm(".Random.seed", envir = global_env)
        }
    }, add = TRUE)

    set.seed(seed)
    sample(1:k, n, replace = TRUE)
}

# Computes the centroid (column-wise mean) of every cluster.
#
# data:   data.frame of numeric columns.
# labels: cluster label for each row of `data`.
#
# Returns a data.frame with one row per distinct label (row names are the
# labels) and one column per column of `data`.
.calculate_centroids = function(data, labels){
    groups = as.factor(labels)
    # lapply keeps one element per column whatever the number of clusters;
    # sapply would collapse the single-cluster case into a plain vector and
    # yield a transposed (p x 1) result.
    cluster_means = lapply(data, function(column){
        as.vector(tapply(column, groups, mean))
    })
    data.frame(cluster_means, row.names = levels(groups))
}

# Computes the Euclidean distance from every row of `data` to every centroid.
#
# data:      data.frame of numeric columns (n rows).
# centroids: data.frame with one row per centroid and the same columns, in
#            the same order, as `data`.
#
# Returns a list with one element per centroid; each element is an unnamed
# numeric vector of length n holding the distances of all rows to it.
.calculate_distances_from_centroids = function(data, centroids){
    # Convert once to matrices so each centroid is handled with vectorised
    # arithmetic instead of one dist() call per (row, centroid) pair.
    points = as.matrix(data)
    centers = as.matrix(centroids)

    lapply(seq_len(nrow(centers)), function(roid){
        # Subtract the centroid's coordinates from every row, column by column.
        offsets = sweep(points, 2, centers[roid, ])
        unname(sqrt(rowSums(offsets^2)))
    })
}

# Assigns each observation to its closest centroid.
#
# distances: list with one numeric vector per centroid, each holding the
#            distance of every observation to that centroid.
#
# Returns, for every observation, the position (within `distances`) of the
# centroid with the smallest distance; ties go to the earliest centroid.
.calculate_new_labels = function(distances){
    index_of_nearest = function(distances_for_one_row){
        which.min(distances_for_one_row)
    }
    # One column per centroid, one row per observation.
    distance_table = do.call(cbind, distances)
    apply(distance_table, MARGIN = 1, FUN = index_of_nearest)
}

# Repackages the raw list produced by the fitting loop under its public field
# names (labels, iterations, centroids, k, data) and tags the result with the
# "aml_k_means" class via .prepend_class.
.create_k_means_output_object = function(results){
    public_fields = list(
        labels = results[["new_labels"]],
        iterations = results[["iter"]],
        centroids = results[["roids"]],
        k = results[["k"]],
        data = results[["data"]]
    )
    tagged = .prepend_class(public_fields, "aml_k_means")
    tagged
}

#' AML k-means default plot function
#'
#' Draws a scatterplot of the clustered observations, coloured and shaped by
#' cluster label. Data with exactly 2 columns is plotted as is; otherwise the
#' data is first reduced to two dimensions, using the first two principal
#' components by default or t-SNE when requested.
#'
#' @param results_object Results object returned by aml_k_means function 
#' @param plot_centroids Logical to plot centroids along with points
#' @param tsne_reduction Logical to use t-SNE for dimensionality reduction
#' (instead of principal components) when a reduction is needed
#' @return The value returned by `plot()` for the underlying ggplot object;
#' the function is called for its side effect of drawing the plot
#' @export
plot.aml_k_means = function(results_object, 
                            plot_centroids = FALSE,
                            tsne_reduction = FALSE){
    # Build the plot first, then hand it to plot() to render it.
    scatterplot = .make_ggplot_object(results_object, 
                                      plot_centroids, 
                                      tsne_reduction)
    plot(scatterplot)
}

# Internal S3 generic that builds the ggplot object for a results object;
# the implementation is chosen from the class of `object`.
.make_ggplot_object = function(object, plot_centroids, tsne_reduction){
    UseMethod(".make_ggplot_object")
}

# Builds the ggplot scatterplot for an aml_k_means results object.
#
# results_object: list with `data` (data.frame that was clustered) and
#                 `labels` (cluster label per row).
# plot_centroids: if TRUE, overlay the per-cluster mean of the plotted
#                 coordinates as large points.
# tsne_reduction: if TRUE and a reduction is needed, use t-SNE instead of
#                 principal components.
#
# Data with exactly 2 columns is plotted directly; any other width is first
# reduced to two plotting coordinates. Returns the ggplot object.
.make_ggplot_object.aml_k_means = function(results_object, 
                                           plot_centroids = FALSE,
                                           tsne_reduction = FALSE){
    cluster_labels = factor(results_object$labels)

    if(ncol(results_object$data) == 2){
        column_names = names(results_object$data)
        # check.names = FALSE keeps the columns named exactly as in
        # column_names; otherwise non-syntactic names would be rewritten and
        # the lookups below would miss them.
        plotting_data = data.frame(results_object$data, 
                                   Labels = cluster_labels,
                                   check.names = FALSE)
    }else{
        if(tsne_reduction){
            column_names = c("TSNE I",
                             "TSNE II")
            tsne_object = Rtsne(results_object$data, check_duplicates = FALSE)
            coordinates = tsne_object$Y
        }else{
            column_names = c("Principal Component I",
                             "Principal Component II")
            # `scale.` spelled out in full rather than relying on partial
            # argument matching of `scale`.
            prcomp_object = prcomp(results_object$data, 
                                   center = TRUE, 
                                   scale. = TRUE)
            coordinates = prcomp_object$x[, 1:2]
        }
        plotting_data = data.frame(coordinates, cluster_labels)
        names(plotting_data) = c(column_names, "Labels")
    }

    # Each point is drawn twice: a slightly larger grey point underneath
    # acts as an outline for the coloured point on top.
    # NOTE(review): aes_string() is soft-deprecated in recent ggplot2
    # releases - consider migrating once the minimum version is confirmed.
    ggplot_object = ggplot(plotting_data, 
                        aes_string(x = .convert_to_ggname(column_names[1]), 
                                   y = .convert_to_ggname(column_names[2]),
                                   col = "Labels",
                                   shape = "Labels")) + 
                        geom_point(aes(shape = Labels), 
                                   size = 2.1, 
                                   col = "grey") +
                        geom_point() +
                        scale_color_viridis(discrete = TRUE)

    if(plot_centroids){
        # Centroids are the per-label means of the plotted coordinates, so
        # they live in the same (possibly reduced) space as the points.
        centroids = aggregate(as.matrix(plotting_data[, column_names]) ~ Labels, 
                              plotting_data, 
                              mean)
        ggplot_object = ggplot_object + 
                            geom_point(data = centroids, 
                                       size = 6, 
                                       col = "grey") + 
                            geom_point(data = centroids, size = 5)

    }

    ggplot_object
}

# Wraps column names in backticks so that names which are not syntactically
# valid R (e.g. containing spaces) can be passed as aesthetic strings.
#
# x: character vector of column names; anything that is not a character
#    vector is returned untouched.
#
# Returns a character vector, named by the original values, in which every
# element starts and ends with a backtick. Existing backticks are kept and
# not doubled.
.convert_to_ggname = function(x){
    if(!is.character(x)){
        return(x)
    }
    vapply(x, function(s){
        if(!grepl("^`", s)){
            s = paste0("`", s)
        }
        if(!grepl("`$", s)){
            s = paste0(s, "`")
        }
        # Return the name explicitly: ending on the `if` above would yield
        # NULL whenever the name already ended in a backtick.
        s
    }, character(1))
}
# jmwerner/ArtisanalMachineLearning documentation built on Jan. 7, 2020, 1:50 a.m.