R/tree_utils.R

Defines functions get_splits print.binary_segmentation_tree plot.binary_segmentation_tree get_change_points_from_tree

Documented in get_change_points_from_tree get_splits plot.binary_segmentation_tree print.binary_segmentation_tree

#' Get Change Points from a binary_segmentation_tree
#'
#' Utility function to collect the split points of all nodes of a
#' binary_segmentation_tree whose split is deemed significant.
#'
#' Which nodes are kept depends on the fields set on the root \code{tree}:
#' \itemize{
#'   \item if \code{tree$gamma} is set: nodes with \code{max_gain > gamma};
#'   \item else if \code{tree$cv_improvement} is set: nodes with
#'     \code{cv_improvement > 0};
#'   \item otherwise: nodes with \code{max_gain > 0}.
#' }
#' Nodes where the relevant value is missing (\code{NULL} or \code{NA}) are
#' skipped.
#'
#' @param tree An object of class \strong{binary_segmentation_tree}
#' @export
#' @return A vector with the sorted change points.
get_change_points_from_tree <- function(tree) {
  if (!is.null(tree$gamma)) {
    criterion <- "max_gain"
    threshold <- tree$gamma
  } else if (!is.null(tree$cv_improvement)) {
    criterion <- "cv_improvement"
    threshold <- 0
  } else {
    criterion <- "max_gain"
    threshold <- 0
  }

  alpha <- tree$Get(
    "split_point",
    filterFun = function(node) {
      # isTRUE() is FALSE for NULL (comparison yields logical(0)) and for NA,
      # so nodes without a usable value are dropped without relying on `&&`
      # with zero-length operands.
      isTRUE(node[[criterion]] > threshold)
    }
  )

  unname(sort(alpha))
}

#' S3 Method for plotting an object of class binary_segmentation_tree
#'
#' Simultaneously plots the gain curves of all segments of a
#' binary_segmentation_tree and marks the best split of each segment with a
#' point.
#'
#' @param x An object of class binary_segmentation_tree
#' @param true_change_points An optional numeric vector containing the true
#'   underlying change points; drawn as vertical lines.
#' @param ... Currently unused.
#' @return A ggplot2 object.
#' @importFrom stats na.omit
#' @export
plot.binary_segmentation_tree <-
  function(x, true_change_points = NULL, ...) {
    # Placeholders so that R CMD check does not flag the column names used
    # with data.table / ggplot2 non-standard evaluation as undefined globals.
    best_split <-
      start <- end <- max_gain <- y <- gain <- segment <- NULL

    if (!requireNamespace("ggplot2", quietly = TRUE)) {
      stop("Please install ggplot2: install.packages('ggplot2')")
    }

    n_segments <- nrow(x$segments)
    line_parts <- vector("list", n_segments)
    dot_parts <- vector("list", n_segments)

    for (i in seq_len(n_segments)) {
      label <- paste0("(", x$segments[i, start], " ", x$segments[i, end], "]")
      gain_values <- unlist(x$segments[i, gain])
      observed <- !is.na(gain_values)

      # data.frame() cannot recycle the length-one columns to zero rows, so a
      # segment whose gain curve is entirely missing contributes no line.
      if (any(observed)) {
        line_parts[[i]] <- data.frame(
          x = which(observed),
          y = gain_values[observed],
          i = i,
          segment = label
        )
      }

      dot_parts[[i]] <- data.frame(
        x = x$segments[i, best_split][[1]],
        max_gain = x$segments[i, max_gain][[1]],
        segment = label
      )
    }

    # rbind() skips the NULL entries left by segments without a gain curve.
    data_lines <- do.call(rbind, line_parts)
    data_dot <- do.call(rbind, dot_parts)

    p <- ggplot2::ggplot()
    if (!is.null(data_lines)) {
      p <- p + ggplot2::geom_line(
        data = data_lines,
        ggplot2::aes(x = x, y = y, group = i, col = segment)
      )
    }
    p <- p + ggplot2::geom_point(
      data = data_dot,
      ggplot2::aes(x = x, y = max_gain, col = segment)
    )

    if (!is.null(true_change_points)) {
      p <- p + ggplot2::geom_vline(xintercept = true_change_points)
    }

    # Return the plot object itself; passing it to ggplot2::ggplot() again
    # would treat it as `data`, which fortify() rejects.
    p
  }

#' print.binary_segmentation_tree
#'
#' S3 method for printing a binary_segmentation_tree object.
#'
#' Decorates the print method of the data.tree package so that additional
#' node attributes are shown as columns. The columns depend on which field is
#' set on the root node \code{x}:
#' \itemize{
#'   \item \code{pvalue} set: \code{split_point}, \code{max_gain},
#'     \code{pvalue};
#'   \item else \code{cv_improvement} set: \code{split_point},
#'     \code{max_gain}, \code{cv_loss}, \code{cv_improvement}, \code{lambda};
#'   \item otherwise: \code{split_point}, \code{max_gain}.
#' }
#'
#' @param x A data.tree node.
#' @param ... Further arguments passed on to the next print method.
#' @export
print.binary_segmentation_tree <- function(x, ...) {
  # NextMethod() re-dispatches to the next class in class(x) (presumably
  # data.tree's print.Node -- TODO confirm) with the original arguments, and
  # appends the unnamed strings below: the node attributes to display.
  if (!is.null(x$pvalue)) {
    NextMethod(generic = NULL,
               object = NULL,
               'split_point',
               'max_gain',
               'pvalue')
  } else if (!is.null(x$cv_improvement)) {
    NextMethod(
      generic = NULL,
      object = NULL,
      'split_point',
      'max_gain',
      'cv_loss',
      'cv_improvement',
      'lambda'
    )
  } else {
    NextMethod(generic = NULL,
               object = NULL,
               'split_point',
               'max_gain')
  }
}

#' Obtain all segmentation scenarios that can be obtained from x by pruning
#'
#' Recursively enumerates every set of split points reachable by pruning the
#' binary segmentation tree \code{x}. At each node with a split, the node can
#' be pruned entirely (empty scenario). Otherwise its split is kept and
#' combined with any scenario of its left child and any scenario of its right
#' child.
#'
#' @param x tree as generated by \link{binary_segmentation}
#' @return A list of numeric vectors without duplicates, one per scenario.
#'   The empty scenario is represented by \code{NULL}.
get_splits <- function(x) {
  variable <- "split_point"
  point <- x[[variable]]

  if (is.null(point)) {
    return(list(NULL))
  }

  left <- get_splits(x$children[[1]])
  right <- get_splits(x$children[[2]])

  # Pair every left scenario with every right scenario, with the left one
  # varying fastest (the same order expand.grid() produces). Using lapply
  # instead of apply() on a data.frame avoids coercion to a matrix. It also
  # avoids apply() simplifying equal-length results into a matrix that c()
  # would then flatten into single numbers.
  combined <- unlist(
    lapply(right, function(r) {
      lapply(left, function(l) unique(c(l, point, r)))
    }),
    recursive = FALSE
  )

  unique(c(combined, list(point), list(NULL)))
}
mlondschien/hdcd documentation built on Jan. 5, 2021, 11:26 p.m.