#' Wild binary segmentation
#'
#' Builds a binary segmentation tree over the rows of \code{x}. In addition to
#' the segment that belongs to the current node, the segments drawn up front by
#' \code{draw_segments()} that lie completely inside the node are considered,
#' and the node is split at the best split of the segment with the largest
#' length-adjusted gain.
#'
#' @param x An n x p data matrix
#' @param get_best_split Function called as
#'   \code{get_best_split(x = , start = , end = , split_candidates = , lambda = )}.
#'   Its result must be a list with (at least) the elements \code{gain},
#'   \code{best_split} and \code{max_gain}. Optional elements
#'   \code{predictions} and \code{splits} are stored on the node.
#' @param model_selection_function Optional function called as
#'   \code{model_selection_function(x, start, split_point, end)}. It must
#'   return a list with elements \code{statistic} and \code{is_significant};
#'   recursion stops below a split that is not significant.
#' @param cross_validation_function Optional function called with
#'   \code{x}, \code{start}, \code{end}, \code{lambda} and \code{folds}. Its
#'   result must contain \code{cv_loss} and \code{lambda_opt}; recursion stops
#'   below a split with negative cross-validated improvement.
#' @inheritParams hdcd
#' @return The root node of the tree (class
#'   \code{"wild_binary_segmentation_tree"}). The root carries \code{lambda},
#'   \code{delta}, \code{n_obs} and \code{segments}, a data.table of all
#'   evaluated segments.
#' @import data.table
#' @import data.tree
wild_binary_segmentation <- function(x,
                                     get_best_split,
                                     delta,
                                     lambda,
                                     model_selection_function = NULL,
                                     cross_validation_function = NULL,
                                     control = hdcd_control()) {
  control$n_obs <- n_obs <- nrow(x)
  minimal_segment_length <- ceiling(delta * n_obs)

  # Create tree structure to save binary segmentation output. Each segment
  # corresponds to a node.
  tree <- data.tree::Node$new(
    paste0("(", 0, " ", n_obs, "]"),
    start = 0,
    end = n_obs
  )
  class(tree) <- c("wild_binary_segmentation_tree", class(tree))

  # Draw segments for which gains and optimal splits will be calculated.
  segments <- draw_segments(
    start = 0,
    end = n_obs,
    n_segments = control$wbs_n_segments,
    minimal_segment_length = 2 * minimal_segment_length,
    include_full_segment = FALSE
  )

  # Add a list column with the split candidates of every segment.
  # SIMPLIFY = FALSE guarantees a list even if all candidate vectors happen to
  # have the same length (or only one segment was drawn); the default would
  # collapse those cases into a matrix.
  segments[, split_candidates := mapply(
    function(i, j) (i + minimal_segment_length):(j - minimal_segment_length),
    start, end,
    SIMPLIFY = FALSE
  )]

  # Evaluate get_best_split on every drawn segment; each element of its result
  # becomes one (list) column of `segments`.
  segments <- cbind(
    segments,
    t(mapply(
      get_best_split,
      start = segments[, start],
      end = segments[, end],
      split_candidates = segments[, split_candidates],
      MoreArgs = list(lambda = lambda, x = x)
    ))
  )

  # Length-adjusted gain. The scaling has to be the same as the one used for
  # the node segments inside the recursion (n_obs / segment length), otherwise
  # the two kinds of segments cannot be compared.
  segments[, rel_max_gain := unlist(max_gain) * n_obs / (end - start)]

  # Store estimation parameters in root node.
  tree$lambda <- lambda
  tree$delta <- delta
  tree$n_obs <- n_obs

  # Recursive function that creates the binary segmentation tree. It appends
  # the node's own segment to `segments` in the enclosing function via `<<-`.
  wild_binary_segmentation_recursive <- function(node) {
    # Stop if segment is not long enough to allow for splitting.
    segment_length <- node$end - node$start
    if (segment_length < 2 * minimal_segment_length) {
      return(NA)
    }

    split_candidates <-
      (node$start + minimal_segment_length):(node$end - minimal_segment_length)
    temp <- get_best_split(
      x = x,
      start = node$start,
      end = node$end,
      split_candidates = split_candidates,
      lambda = lambda
    )

    # Register the node's own segment next to the drawn ones. Elements of
    # length != 1 are wrapped so that they fit into a single list cell.
    segments <<- rbind(
      segments,
      c(
        list(
          start = node$start,
          end = node$end,
          split_candidates = list(split_candidates)
        ),
        lapply(temp, function(el) if (length(el) == 1) el else list(el)),
        list(rel_max_gain = temp$max_gain * n_obs / segment_length)
      )
    )

    # Only segments lying completely inside the node may determine its split.
    # Selecting among the contained rows (instead of zeroing the gain of the
    # others) keeps outside segments from winning when all contained gains are
    # negative.
    contained <- segments[, which(start >= node$start & end <= node$end)]
    best <- contained[which.max(segments[["rel_max_gain"]][contained])]
    if (length(best) == 0) {
      # every contained gain is missing
      return(NA)
    }

    best_split_point <- unlist(segments[best, best_split])
    best_gain <- unlist(segments[best, max_gain])
    node$gain <- unlist(segments[best, gain])
    node$split_point <- best_split_point
    node$max_gain <- best_gain

    # Allow get_best_split to also return prediction probabilities and splits
    # used.
    if (!is.null(temp$predictions)) node$predictions <- temp$predictions
    if (!is.null(temp$splits)) node$splits <- temp$splits

    # Stop if no best split can be found (e.g. when the gain is nonpositive
    # for each split).
    if (length(best_split_point) != 1 || is.na(best_split_point)) {
      return(NA)
    }
    if (length(best_gain) != 1 || is.na(best_gain) || best_gain <= 0) {
      return(NA)
    }

    # Create left child
    child_left <- node$AddChild(
      paste0("(", node$start, " ", node$split_point, "]"),
      start = node$start,
      end = node$split_point,
      lambda = node$lambda
    )
    # Create right child
    child_right <- node$AddChild(
      paste0("(", node$split_point, " ", node$end, "]"),
      start = node$split_point,
      end = node$end,
      lambda = node$lambda
    )
    class(child_left) <- c("wild_binary_segmentation_tree", class(child_left))
    class(child_right) <- c("wild_binary_segmentation_tree", class(child_right))

    # If a model_selection_function is supplied, test for significance of
    # split.
    if (!is.null(model_selection_function)) {
      temp <- model_selection_function(x, node$start, node$split_point, node$end)

      # Stop with error if model_selection_function is not of required form.
      if (is.null(temp$statistic) || is.null(temp$is_significant)) {
        stop(
          "model_selection_function is not of the required form. ",
          "Make sure model_selection_function returns a list with ",
          "attributes statistic and is_significant"
        )
      }

      # Save output of model_selection_function in node.
      node$model_selection_statistic <- temp$statistic
      node$is_significant <- temp$is_significant

      # Stop if split is not significant.
      if (!node$is_significant) {
        return(NA)
      }
    }

    # If cross_validation_function is supplied, calculate cv_losses of both
    # subsegments and calculate inner cross-validated gain.
    # NOTE(review): node$cv_loss and node$root$folds are never set for the root
    # inside this function - presumably the caller has to provide them; confirm.
    if (!is.null(cross_validation_function)) {
      temp_left <- cross_validation_function(
        x,
        start = node$start,
        end = node$split_point,
        lambda = node$lambda,
        folds = node$root$folds
      )
      temp_right <- cross_validation_function(
        x,
        start = node$split_point,
        end = node$end,
        lambda = node$lambda,
        folds = node$root$folds
      )

      node$cv_improvement <- node$cv_loss - temp_left$cv_loss - temp_right$cv_loss
      node$relative_cv_improvement <- node$cv_improvement / (node$end - node$start)

      child_left$cv_loss <- temp_left$cv_loss
      child_left$lambda <- temp_left$lambda_opt
      child_right$cv_loss <- temp_right$cv_loss
      child_right$lambda <- temp_right$lambda_opt

      # Drop the cross-validation results before descending further.
      rm(temp_left)
      rm(temp_right)

      if (node$cv_improvement < 0) {
        return(NA)
      }
    }

    wild_binary_segmentation_recursive(child_left)
    wild_binary_segmentation_recursive(child_right)
  }

  wild_binary_segmentation_recursive(tree)

  tree$segments <- segments
  tree
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.