R/hybridalg.R

Defines functions hybridalg

Documented in hybridalg

#' Segment data into change points using a mixed hierarchical-exact approach
#'
#' For larger datasets, assume the data is hierarchical, but calculate the
#' exact segments when they are no larger than a threshold.
#'
#' This algorithm implements an approach mixing the hierarchical and exact
#' algorithms. It uses the hierarchical algorithm while the segment has more
#' columns than the threshold, and switches to the exact algorithm once the
#' number of columns in the segment is less than or equal to the threshold.
#'
#' When the combined result would contain more than `max_segments` segments,
#' only the `max_segments - 1` change points with the highest associated
#' segment cost are kept, and they are returned in increasing position order.
#'
#' @inherit base_segment
#' @param threshold the threshold for which the exact algorithm will be used,
#'   i.e. when the number of columns in the segment is less than or equal to the
#'   threshold.
#'
#' @export
hybridalg <- function(
                      data,
                      cost,
                      likelihood,
                      allow_parallel = TRUE,
                      max_segments = ncol(data),
                      threshold = 50) {
  cost <- get_cost(cost, likelihood)

  # Dispatcher used both for the top-level call and as the `recursive_fn`
  # callback given to `recursive_hieralg`. The `recursive_fn` argument is
  # accepted only to match that callback signature; the closure always passes
  # itself on, so every sub-segment is re-checked against `threshold`.
  recursive_hybrid <- function(
                                 data,
                                 initial_position,
                                 cost,
                                 allow_parallel,
                                 recursive_fn) {
    if (ncol(data) > threshold) {
      recursive_hieralg(
        data = data,
        initial_position = initial_position,
        cost = cost,
        allow_parallel = allow_parallel,
        recursive_fn = recursive_hybrid
      )
    } else {
      exact_segments(
        data = data,
        cost = cost,
        max_segments = max_segments,
        allow_parallel = allow_parallel,
        initial_position = initial_position
      )
    }
  }

  segs <- recursive_hybrid(
    data = data,
    initial_position = 1,
    cost = cost,
    allow_parallel = allow_parallel,
    recursive_fn = recursive_hybrid
  )
  changepoints <- vapply(segs, "[[", FUN.VALUE = numeric(1), "changepoint")

  if (length(changepoints) > 0) {
    changepoints <- sort(changepoints)
  }

  if (length(changepoints) > 0 && length(changepoints) + 1 > max_segments) {
    temp_results <- list(changepoints = changepoints)
    costs <- calculate_segment_costs(temp_results, data, cost = cost)

    # Rank change points by the computed segment costs (`costs`), not by the
    # cost *function* (`cost`). The last entry is dropped so the vector lines
    # up one-to-one with `changepoints`.
    # NOTE(review): assumes `calculate_segment_costs` returns one numeric cost
    # per segment, i.e. length(changepoints) + 1 values - confirm.
    ranking <- order(-head(costs, -1))

    # seq_len() yields an empty index when max_segments is 1, whereas
    # 1:(max_segments - 1) would yield c(1, 0) and keep one change point.
    n_keep <- max(max_segments - 1, 0)
    kept <- changepoints[ranking[seq_len(n_keep)]]

    # Restore positional order, which the cost ranking has scrambled, before
    # the change points are turned into segments.
    changepoints <- sort(kept)
  }

  results <- list(
    changepoints = changepoints,
    segments = calculate_segments(changepoints, ncol(data))
  )
  class(results) <- "segmentr"
  results
}
thalesmello/segmentr documentation built on March 4, 2020, 1 a.m.