# branchmodel
#' Return an empty branchmodel object.
#'
#' `new_branchmodel_helper` is the generator returned by `setClass()`: call it
#' with no arguments for an empty object, or with named slot values.
#'
#' @slot raw_data One row is one datum. Stores coords in some space.
#' @slot dist_df One row is one datum. Stores distances from each branch of
#'   the latent 'Y' shape (three columns).
#' @slot assignments Length is number of data points. Entries give rows of
#'   `tips` corresponding to the closest branch; 0 marks ambiguous points.
#' @slot tips Matrix with one row per branch (three total). One column per
#'   variable (same dimension as `raw_data`).
#' @slot tip_indices Row indices into `raw_data` of the three tips.
#' @slot models List with one element per branch (three total). Currently
#'   supports only `principal.curve` objects from the `princurve` package.
#' @slot center Atomic of the same dimension as `raw_data[1, ]`.
new_branchmodel_helper <- setClass(
  "branchmodel",
  slots = c(
    raw_data = "data.frame",
    dist_df = "data.frame",
    assignments = "integer",
    tips = "matrix",
    tip_indices = "integer",
    models = "list",
    center = "numeric"
  )
)

#' Check whether a branchmodel object is valid.
#'
#' @param object A `branchmodel` object.
#' @return A length-1 character. The empty string `""` means every check
#'   passed; otherwise it lists the failed check codes, each preceded by a
#'   space, e.g. `" dim_center num_tips"`.
setGeneric("get_issues", function(object, ...) standardGeneric("get_issues"))

setMethod(
  "get_issues",
  valueClass = "character",
  signature = signature(object = "branchmodel"),
  function(object) {
    n_points <- nrow(object@raw_data)
    dimension <- ncol(object@raw_data)
    # TRUE marks a failed check; names are the reported issue codes and their
    # order fixes the order of codes in the returned string.
    failed <- c(
      dim_center = length(object@center) != dimension,
      dim_tips = ncol(object@tips) != dimension,
      nrow_dist_df = nrow(object@dist_df) != n_points,
      len_assignments = length(object@assignments) != n_points,
      num_tips = nrow(object@tips) != 3,
      tip_indices = length(object@tip_indices) != 3,
      num_models = length(object@models) != 3,
      # inherits() accepts models carrying extra S3 classes and never
      # recycles, unlike comparing sapply(models, class) to a fixed vector.
      model_class = !all(vapply(
        object@models, inherits, logical(1), what = "principal.curve"
      )),
      dim_dist_df = ncol(object@dist_df) != 3
    )
    # The leading "" reproduces the historical " code1 code2" format and
    # yields "" when nothing failed.
    paste(c("", names(failed)[failed]), collapse = " ")
  }
)

#' Set up the initial guess for the branchmodel. Inspired by kmeans++.
#'
#' @param branchmodel A branchmodel whose `raw_data` slot is filled in.
#' @return The same branchmodel with `tips` and `tip_indices` set.
#' @details Fix a tip randomly. Take the farthest point from that, and the
#'   farthest point from those two, and then reset the random one to the
#'   farthest from the other two. Distance to the pair is the minimum over the
#'   individual distances.
initialize_tips <- function(branchmodel) {
  # Seeds `@tips` and `@tip_indices` with three mutually distant rows of
  # `@raw_data`. `distance_sq()` is a helper defined elsewhere in the package.
  raw_data <- branchmodel@raw_data
  n_points <- nrow(raw_data)
  if (n_points < 1) {
    stop("initialize_tips: `raw_data` must have at least one row.",
         call. = FALSE)
  }

  # Squared distance from `x` to the nearer of the anchors `y` and `z`
  # (just `y` when `z` is NULL).
  min_dist_sq_y_z <- function(x, y, z = NULL) {
    d1 <- distance_sq(x, y)
    if (is.null(z)) {
      return(d1)
    }
    min(d1, distance_sq(x, z))
  }

  # Row index of the point farthest from its nearest anchor.
  farthest_from <- function(y, z = NULL) {
    which.max(apply(
      X = raw_data, MARGIN = 1, FUN = min_dist_sq_y_z, y = y, z = z
    ))
  }
  row_of <- function(idx) raw_data[idx, , drop = FALSE]

  # Provisional random tip. sample.int(n, 1) makes the same RNG draw as the
  # former sample(1:n, size = 1).
  tip3_idx <- sample.int(n_points, size = 1)
  tip1_idx <- farthest_from(row_of(tip3_idx))
  tip2_idx <- farthest_from(row_of(tip3_idx), row_of(tip1_idx))
  # Replace the random tip with the point farthest from the other two tips.
  tip3_idx <- farthest_from(row_of(tip2_idx), row_of(tip1_idx))

  tip_idx <- c(tip1_idx, tip2_idx, tip3_idx)
  branchmodel@tips <- as.matrix(raw_data[tip_idx, , drop = FALSE])
  branchmodel@tip_indices <- tip_idx
  branchmodel
}

#' For each point, fill in the distance to each branch.
#'
#' Each branch is approximated by a ray involving `@center` and the matching
#' row of `@tips`, measured with `distance_to_ray()` (defined elsewhere).
#' NOTE(review): whether the values are squared, as the generic's name
#' suggests, depends on `distance_to_ray()` -- confirm.
#'
#' @param branchmodel A branchmodel with `raw_data`, `tips`, `center` and a
#'   `dist_df` holding one row per point.
#' @return The branchmodel with one `dist_df` column per tip filled in.
setGeneric(
  "get_simple_sq_distances",
  function(branchmodel, ...) standardGeneric("get_simple_sq_distances")
)

setMethod(
  "get_simple_sq_distances",
  valueClass = "branchmodel",
  signature = signature(branchmodel = "branchmodel"),
  function(branchmodel) {
    # Use every coordinate. Tips and center span all columns of raw_data, so
    # restricting points to columns 1:2 mismatched dimensions (and failed
    # outright on one-column data).
    points <- as.matrix(branchmodel@raw_data)
    for (i in seq_len(nrow(branchmodel@tips))) {
      dtip <- function(x) {
        distance_to_ray(
          tip1 = branchmodel@tips[i, ],
          tip2 = branchmodel@center,
          point = x
        )
      }
      branchmodel@dist_df[, i] <- apply(X = points, MARGIN = 1, FUN = dtip)
    }
    branchmodel
  }
)
#' Fit a "Y" shape to data
#'
#' @param raw_data Data frame with numeric columns, one row per point. A
#'   matrix is also accepted and converted with `as.data.frame()`.
#' @param max_iter Maximum number of refinement iterations. Default 100.
#' @param tol Stops when less than 100*tol percent of points are reclassified
#'   in one iteration. Default 0.01.
#' @return S4 object of class "branchmodel".
#' @details This function imposes a "Y" shape on data in a (preferably 2D)
#'   space. It represents the shape as three principal curves (from
#'   `princurve`), with each point hard-assigned to one curve or labelled 0
#'   (ambiguous). The internal methods iterate between reassigning the data to
#'   the nearest branch and adjusting the branches, with a heuristic to make
#'   the curves roughly meet in the center. Stops with an error if the fitted
#'   object fails `get_issues()`.
#' @export
fit_branchmodel <- function(raw_data, max_iter = 100, tol = 0.01) {
  branchmodel <- new_branchmodel_helper()
  branchmodel@raw_data <- as.data.frame(raw_data)

  # Initialize the center to the coordinate-wise median, and the tips via a
  # kmeans++-like farthest-point heuristic.
  branchmodel@center <- vapply(
    branchmodel@raw_data, stats::median, numeric(1)
  )
  branchmodel <- initialize_tips(branchmodel)

  # Distances to branches (using distances to rays from center through tips).
  branchmodel@dist_df <- data.frame(
    matrix(NA, ncol = 3, nrow = nrow(branchmodel@raw_data))
  )
  branchmodel <- get_simple_sq_distances(branchmodel)

  # Assign each point to its nearest branch.
  branchmodel <- reassign_points(branchmodel)

  # Alternate branch fitting and reassignment.
  branchmodel <- fit_branchmodel_internal(
    branchmodel, max_iter = max_iter, tol = tol
  )

  # Check for issues and report the failing codes if there are any.
  issues <- get_issues(branchmodel)
  assertthat::assert_that(
    identical(issues, ""),
    msg = paste0("fit_branchmodel produced an invalid branchmodel:", issues)
  )
  branchmodel
}
#' Internal branchmodel function that performs fitting.
#'
#' Repeats `fit_branches()`, `reassign_points()` and `relocate_center()` until
#' the fraction of points whose assignment changed in one iteration drops
#' below `tol`, or until `max_iter` iterations have run.
#'
#' @param branchmodel An initialized branchmodel (assignments already set).
#' @param max_iter Maximum number of iterations.
#' @param tol Convergence threshold on the fraction of reassigned points.
#' @return The updated branchmodel. A message is emitted on convergence and a
#'   warning otherwise.
setGeneric(
  "fit_branchmodel_internal",
  function(branchmodel, max_iter, tol) {
    standardGeneric("fit_branchmodel_internal")
  }
)

setMethod(
  "fit_branchmodel_internal",
  valueClass = "branchmodel",
  signature = signature(
    branchmodel = "branchmodel", max_iter = "numeric", tol = "numeric"
  ),
  function(branchmodel, max_iter, tol) {
    # Stays NA when no iteration runs (max_iter < 1); previously 1:max_iter
    # counted downwards in that case.
    prop_just_reassigned <- NA_real_
    converged <- FALSE
    for (i in seq_len(max_iter)) {
      old_assignments <- branchmodel@assignments
      branchmodel <- fit_branches(branchmodel)
      branchmodel <- reassign_points(branchmodel)
      branchmodel <- relocate_center(branchmodel, iter = i)
      prop_just_reassigned <- mean(old_assignments != branchmodel@assignments)
      # isTRUE guards against NaN (e.g. zero points), which would break `if`.
      if (isTRUE(prop_just_reassigned < tol)) {
        converged <- TRUE
        message(paste0("converged after ", i, " iterations"))
        break
      }
    }
    if (!converged) {
      warning(paste0(
        "Did not converge: prop_just_reassigned = ", prop_just_reassigned,
        ", max_iter = ", max_iter, "."
      ))
    }
    branchmodel
  }
)
# Helpers ----
#' Update the branch-point.
#'
#' Moves `@center` to the mean of the points currently labelled ambiguous
#' (assignment 0). When no point is ambiguous, the center is left unchanged.
#'
#' @param branchmodel A branchmodel.
#' @param iter Current iteration number (not used by the method below).
#' @param previous_max_ambiguity Not used by the method below.
#' @details This function updates `@center`.
setGeneric(
  "relocate_center",
  function(branchmodel, iter, previous_max_ambiguity) {
    standardGeneric("relocate_center")
  }
)

setMethod(
  "relocate_center",
  valueClass = "branchmodel",
  signature = signature(branchmodel = "branchmodel", iter = "numeric"),
  function(branchmodel, iter) {
    ambiguous <- branchmodel@assignments == 0
    if (any(ambiguous)) {
      # drop = FALSE keeps a data frame even for one-column data, where
      # colMeans() would otherwise fail on a plain vector.
      branchmodel@center <- colMeans(
        branchmodel@raw_data[ambiguous, , drop = FALSE]
      )
    }
    branchmodel
  }
)

#' Optimize individual branch models given branch assignments.
#'
#' For each of the three branches, fits a principal curve to the points
#' assigned to that branch. To make the curve touch the center, the fit data is
#' augmented with `ceiling(0.15 * n)` copies of `@center`, where `n` is the
#' branch size. The copies are removed again with `princurve_truncate()`.
#' Every point is then projected onto the curve to refresh that branch's
#' column of `@dist_df`.
#'
#' This function updates `@dist_df` and `@models`.
setGeneric(
  "fit_branches",
  function(branchmodel, ...) standardGeneric("fit_branches")
)

setMethod(
  "fit_branches",
  valueClass = "branchmodel",
  signature = signature(branchmodel = "branchmodel"),
  function(branchmodel) {
    all_points <- as.matrix(branchmodel@raw_data)
    n_dim <- length(branchmodel@center)
    for (i in 1:3) {
      this_cluster <- which(branchmodel@assignments == i)
      # Dummy rows at the current center pull the curve to the branch point.
      n_aug <- ceiling(0.15 * length(this_cluster))
      center_copies <- matrix(
        branchmodel@center, nrow = n_aug, ncol = n_dim, byrow = TRUE
      )
      colnames(center_copies) <- colnames(branchmodel@raw_data)
      # drop = FALSE keeps one-column data as a data frame for rbind().
      pc_input <- rbind(
        branchmodel@raw_data[this_cluster, , drop = FALSE],
        center_copies
      )
      model <- princurve::principal.curve(x = as.matrix(pc_input))
      # Remove the dummy center rows again.
      model <- princurve_truncate(model, n_remove = n_aug)
      branchmodel@models[[i]] <- model

      # Distance from every point to its projection on this branch's curve.
      projected_all <- princurve::get.lam(
        x = all_points, s = model$s, tag = model$tag, stretch = 2
      )
      branchmodel@dist_df[, i] <- apply(
        all_points - projected_all$s, 1, norm2
      )
    }
    branchmodel
  }
)

#' Reassign each point to the nearest branch.
#'
#' Points whose two nearest branches are nearly equidistant are labelled 0
#' (ambiguous). Tips always keep their own branch label. A branch left with
#' fewer than two points is restarted from its tip, and each branch is trimmed
#' to the contiguous region containing its tip.
#'
#' This function updates `@assignments`.
setGeneric(
  "reassign_points",
  function(branchmodel, ...) standardGeneric("reassign_points")
)

setMethod(
  "reassign_points",
  valueClass = "branchmodel",
  signature = signature(branchmodel = "branchmodel"),
  function(branchmodel) {
    dist_df <- branchmodel@dist_df

    # Hard assignment to the nearest branch.
    branchmodel@assignments <- apply(X = dist_df, MARGIN = 1, FUN = which.min)

    # Ambiguity = closest distance minus second-closest distance (<= 0).
    # 0 means completely ambiguous (closest two branches equal); -Inf means
    # well determined (closest branch infinitely closer).
    min2next <- function(x) {
      x <- sort(x)
      x[1] - x[2]
    }
    ambiguity <- apply(X = dist_df, MARGIN = 1, FUN = min2next)
    # Scale for the threshold: mean distance from a point to its nearest
    # branch. (Named to avoid shadowing stats::sd.)
    typical_dist <- mean(apply(X = dist_df, MARGIN = 1, FUN = min))
    shared <- ambiguity > -2 * typical_dist
    branchmodel@assignments[shared] <- 0L

    # Never change the assignment of the tips.
    branchmodel@assignments[branchmodel@tip_indices[1:3]] <- 1:3

    # Restart any branch that holds fewer than two points (i.e. only its tip):
    # assign it the tip plus up to 10 nearest neighbors, so the principal
    # curve fit has enough data.
    n_neighbors <- min(11L, nrow(branchmodel@raw_data))
    for (cluster in 1:3) {
      if (sum(branchmodel@assignments == cluster) < 2) {
        warning("Restarting empty branch! This is a bad sign for convergence. Check your results visually.\n")
        # The query must be the whole tip row. tips[cluster] alone picks one
        # matrix element and gives a query of the wrong dimension.
        neighbors <- c(FNN::knnx.index(
          query = branchmodel@tips[cluster, , drop = FALSE],
          data = as.matrix(branchmodel@raw_data),
          k = n_neighbors
        ))
        branchmodel@assignments[neighbors] <- cluster
      }
    }

    # Make sure assigned cells are in one contiguous block: keep the region
    # connected to the tip and mark disconnected segments ambiguous (0).
    for (cluster in 1:3) {
      this_cluster_idx <- which(branchmodel@assignments == cluster)
      conn_comp <- find_contiguous_region(
        all_points = branchmodel@raw_data,
        good_idx = this_cluster_idx,
        root_idx = branchmodel@tip_indices[cluster]
      )
      branchmodel@assignments[setdiff(this_cluster_idx, conn_comp)] <- 0L
    }
    branchmodel
  }
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.