R/wild_binary_segmentation.R

#' Wild binary segmentation
#'
#' Builds a binary segmentation tree over the observations (rows) of \code{x}.
#' Before the recursion starts, \code{get_best_split} is evaluated on the
#' segments returned by \code{draw_segments}. At every node the full node
#' segment is evaluated as well, and the split is taken from the segment with
#' the largest length-scaled gain among all evaluated segments that lie inside
#' the node.
#'
#' \code{get_best_split} is expected to return a list containing at least
#' \code{gain}, \code{best_split} and \code{max_gain}; optional elements
#' \code{predictions} and \code{splits} of the full-node fit are stored on the
#' node.
#'
#' @param x An n x p data matrix
#' @inheritParams hdcd
#' @return The root of a \code{data.tree} of class
#'   \code{"wild_binary_segmentation_tree"}. The root additionally carries
#'   \code{lambda}, \code{delta}, \code{n_obs} and \code{segments}, a
#'   \code{data.table} with one row per evaluated segment.
#' @import data.table
#' @import data.tree
wild_binary_segmentation <- function(x, get_best_split, delta, lambda,
                                     model_selection_function = NULL,
                                     cross_validation_function = NULL,
                                     control = hdcd_control()) {
  n_obs <- nrow(x)
  minimal_segment_length <- ceiling(delta * n_obs)

  # Gains of segments of different lengths are compared after multiplying by
  # n_obs / (end - start). The pre-drawn segments and the segments added during
  # the recursion must use the same factor; otherwise which.max() below would
  # compare numbers on different scales.
  scale_gain <- function(max_gain, start, end) {
    max_gain * n_obs / (end - start)
  }

  # Create tree structure to save binary segmentation output. Each segment
  # corresponds to a node.
  tree <- data.tree::Node$new(paste0("(", 0, " ", n_obs, "]"),
                              start = 0, end = n_obs)
  class(tree) <- c("wild_binary_segmentation_tree", class(tree))

  # Draw segments for which gains and optimal splits will be calculated.
  segments <- draw_segments(start = 0, end = n_obs,
                            n_segments = control$wbs_n_segments,
                            minimal_segment_length = 2 * minimal_segment_length,
                            include_full_segment = FALSE)

  # Admissible split points of every drawn segment. SIMPLIFY = FALSE keeps the
  # result a list even when all segments have the same number of candidates
  # (mapply would otherwise collapse them into a matrix or a vector).
  segments[, split_candidates := mapply(function(i, j) {
    (i + minimal_segment_length):(j - minimal_segment_length)
  }, start, end, SIMPLIFY = FALSE)]

  # Evaluate get_best_split on every drawn segment; each element of its
  # returned list becomes a column of `segments`.
  best_splits <- t(mapply(get_best_split,
                          start = segments[, start],
                          end = segments[, end],
                          split_candidates = segments[, split_candidates],
                          MoreArgs = list(lambda = lambda, x = x)))
  segments <- cbind(segments, best_splits)
  segments[, rel_max_gain := scale_gain(unlist(max_gain), start, end)]

  # Store estimation parameters in root node.
  tree$lambda <- lambda
  tree$delta <- delta
  tree$n_obs <- n_obs

  # Recursively grows the tree below `node`. The return value is not used;
  # NA is returned wherever the recursion stops.
  wild_binary_segmentation_recursive <- function(node) {

    # Stop if the segment is too short to be split into two admissible parts.
    segment_length <- node$end - node$start
    if (segment_length < 2 * minimal_segment_length) {
      return(NA)
    }
    split_candidates <- (node$start + minimal_segment_length):
      (node$end - minimal_segment_length)

    temp <- get_best_split(x = x, start = node$start, end = node$end,
                           split_candidates = split_candidates, lambda = lambda)

    # Record the full node segment so it competes with the drawn segments.
    # `segments` lives in the enclosing function, hence `<<-`.
    segments <<- rbind(
      segments,
      c(list(start = node$start,
             end = node$end,
             split_candidates = list(split_candidates)),
        lapply(temp, function(i) if (length(i) == 1) i else list(i)),
        list(rel_max_gain = scale_gain(temp$max_gain, node$start, node$end)))
    )

    # Choose the best segment among those contained in the node. Rows outside
    # the node are excluded (not zeroed out): zeroing would let an outside
    # segment win whenever every contained gain is negative.
    inside <- which(segments$start >= node$start & segments$end <= node$end)
    best_row <- inside[which.max(segments$rel_max_gain[inside])]
    if (length(best_row) == 0) {
      # No contained segment has a non-missing gain.
      return(NA)
    }

    node$gain <- unlist(segments[best_row, gain])
    node$split_point <- unlist(segments[best_row, best_split])
    node$max_gain <- unlist(segments[best_row, max_gain])

    # Allow get_best_split to also return prediction probabilities and splits used
    if (!is.null(temp$predictions)) node$predictions <- temp$predictions
    if (!is.null(temp$splits)) node$splits <- temp$splits

    # Stop if no best split can be found or the best gain is not positive
    # (an NA gain also stops instead of erroring inside `if`).
    has_split <- length(node$split_point) == 1 && !is.na(node$split_point)
    if (!has_split || !isTRUE(node$max_gain > 0)) {
      return(NA)
    }

    child_left <- node$AddChild(
      paste0("(", node$start, " ", node$split_point, "]"),
      start = node$start,
      end = node$split_point,
      lambda = node$lambda
    )
    child_right <- node$AddChild(
      paste0("(", node$split_point, " ", node$end, "]"),
      start = node$split_point,
      end = node$end,
      lambda = node$lambda
    )
    class(child_left) <- c("wild_binary_segmentation_tree", class(child_left))
    class(child_right) <- c("wild_binary_segmentation_tree", class(child_right))

    # If a model_selection_function is supplied, test the split for
    # significance; it must return a list with `statistic` and `is_significant`.
    if (!is.null(model_selection_function)) {
      temp <- model_selection_function(x, node$start, node$split_point, node$end)

      if (is.null(temp$statistic) || is.null(temp$is_significant)) {
        stop(paste0("model_selection_function is not of the required form. ",
                    "Make sure model_selection_function returns a list with ",
                    "attributes statistic and is_significant"))
      }

      node$model_selection_statistic <- temp$statistic
      node$is_significant <- temp$is_significant

      if (!node$is_significant) {
        return(NA)
      }
    }

    # If a cross_validation_function is supplied, compute the CV losses of both
    # subsegments and the inner cross-validated improvement of the split.
    if (!is.null(cross_validation_function)) {
      folds <- node$root$folds

      # Children get their cv_loss below, but the root is created in this
      # function without one; compute it on first use so the improvement is
      # defined (otherwise NULL - loss yields numeric(0) and `if` errors).
      if (is.null(node$cv_loss)) {
        node$cv_loss <- cross_validation_function(
          x, start = node$start, end = node$end,
          lambda = node$lambda, folds = folds
        )$cv_loss
      }

      cv_left <- cross_validation_function(
        x, start = node$start, end = node$split_point,
        lambda = node$lambda, folds = folds
      )
      cv_right <- cross_validation_function(
        x, start = node$split_point, end = node$end,
        lambda = node$lambda, folds = folds
      )

      node$cv_improvement <- node$cv_loss - cv_left$cv_loss - cv_right$cv_loss
      node$relative_cv_improvement <- node$cv_improvement /
        (node$end - node$start)
      child_left$cv_loss <- cv_left$cv_loss
      child_left$lambda <- cv_left$lambda_opt
      child_right$cv_loss <- cv_right$cv_loss
      child_right$lambda <- cv_right$lambda_opt

      if (node$cv_improvement < 0) {
        return(NA)
      }
    }

    wild_binary_segmentation_recursive(child_left)
    wild_binary_segmentation_recursive(child_right)
  }

  wild_binary_segmentation_recursive(tree)

  tree$segments <- segments

  tree
}
MalteLond/rfcd documentation built on June 19, 2019, 2:52 p.m.