#' Basic methods for processing outbreaker results
#'
#' Several methods are defined for instances of the class
#' \code{outbreaker_chains}, returned by \code{\link{outbreaker}}, including:
#' \code{print}, \code{plot}, \code{summary}
#' @rdname outbreaker_chains
#' @aliases outbreaker_chains print.outbreaker_chains plot.outbreaker_chains
#'   summary.outbreaker_chains
#' @author Thibaut Jombart (\email{thibautjombart@@gmail.com}).
#' @param x an \code{outbreaker_chains} object as returned by \code{outbreaker}.
#' @param n_row the number of rows to display in head and tail; defaults to 3.
#' @param n_col the number of columns to display; defaults to 8.
#' @param ... further arguments to be passed to other methods.
#' @export
#' @importFrom utils head tail
print.outbreaker_chains <- function(x, n_row = 3, n_col = 8, ...) {
  ## Prints a short header (class and dimensions), then, when 'x' has more
  ## than 'n_col' columns, the ranges of the hidden 'alpha', 't_inf' and
  ## 'kappa' columns, and finally the first and last 'n_row' rows of the
  ## columns that are displayed.
  cat("\n\n ///// outbreaker results ///\n")
  cat("\nclass: ", class(x))
  cat("\ndimensions", nrow(x), "rows, ", ncol(x), "columns")

  ## process names of variables not shown
  if (ncol(x) > n_col) {
    ori_names <- names(x)
    ## drop = FALSE keeps a data.frame (and its column names) when n_col is 1
    x <- x[, seq_len(n_col), drop = FALSE]

    not_shown <- setdiff(ori_names, names(x))

    ## "first - last" hidden column matching 'pattern', or NULL when none of
    ## the hidden columns match (range() on an empty match gives Inf / -Inf,
    ## which cannot be used as an index)
    hidden_range <- function(pattern) {
      hits <- grep(pattern, not_shown)
      if (length(hits) == 0L) {
        return(NULL)
      }
      paste(not_shown[range(hits)], collapse = " - ")
    }

    alpha_txt <- hidden_range("alpha")
    t_inf_txt <- hidden_range("t_inf")
    kappa_txt <- hidden_range("kappa")

    if (!is.null(alpha_txt)) {
      cat("\nancestries not shown:", alpha_txt)
    }
    if (!is.null(t_inf_txt)) {
      cat("\ninfection dates not shown:", t_inf_txt)
    }
    if (!is.null(kappa_txt)) {
      cat("\nintermediate generations not shown:", kappa_txt)
    }
  }

  ## heads and tails
  cat("\n\n/// head //\n")
  print(head(as.data.frame(x), n_row))
  cat("\n/// tail //\n")
  print(tail(as.data.frame(x), n_row))
}

#' @rdname outbreaker_chains
#' @param y a character string indicating which element of an
#'   \code{outbreaker_chains} object to plot.
#' @param type a character string indicating the kind of plot to be used (see details).
#' @param group a vector of character strings indicating the parameters to display, 
#' or "all" to display all global parameters (non node-specific parameters).
#' @param burnin the number of iterations to be discarded as burnin.
#' @param min_support a number between 0 and 1 indicating the minimum support of
#' ancestries to be plotted; only used if 'type' is 'network'.
#' @param labels a vector of length N indicating the case labels (must be
#'   provided in the same order used for dates of symptom onset).
## #' @param dens_all a logical indicating if the overall density computed over
## all runs should be displayed; defaults to TRUE #' @param col the colors to be
## used for different runs.
#' @export
#' @seealso See \href{http://www.repidemicsconsortium.org/outbreaker2/articles/introduction.html#graphics}{introduction vignette} for detailed examples on how to visualise \code{outbreaker_chains} objects.
#' @details \code{type} indicates the type of graphic to plot:
#' \itemize{
#' \item \code{trace} to visualise MCMC traces for parameters or augmented data (plots the
#' log-likelihood by default)
#' \item \code{hist} to plot histograms of quantitative values
#' \item \code{density} to plot kernel density estimations of quantitative values
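#' \item \code{alpha} to visualise the posterior frequency of ancestries
#' \item \code{t_inf} to visualise the distributions of infection times of the
#' cases as violin plots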
#' \item \code{network} to visualise the transmission tree; note that
#'  this opens up an interactive plot and requires a web browser with
#'  Javascript enabled; the argument `min_support` is useful to select only the
#'  most supported ancestries and avoid displaying too many links
#' \item \code{kappa} to visualise the distribution of the number of generations
#' between cases and their ancestor/infector
#' }
#' @importFrom ggplot2 ggplot geom_line geom_point geom_histogram geom_density
#'   geom_violin aes aes_string coord_flip labs guides scale_size_area
#'   scale_x_discrete scale_y_discrete scale_color_manual scale_fill_manual
#'   scale_x_continuous scale_y_continuous theme_bw facet_wrap
#' @importFrom stats reshape
#' @importFrom grDevices xyTable
#' @importFrom graphics plot
plot.outbreaker_chains <- function(x, y = "post",
                                   type = c("trace", "hist", "density",
                                            "alpha", "t_inf", "kappa", "network"),
                                   group = NULL,
                                   burnin = 0, min_support = 0.1, labels = NULL, ...) {
  ## Builds and returns one graphic for an outbreaker_chains object. 'type'
  ## selects the graphic; 'trace', 'hist' and 'density' plot column 'y' (or the
  ## columns selected through 'group', one facet each); 'alpha', 't_inf' and
  ## 'kappa' plot the columns whose names contain that string; 'network'
  ## builds a visNetwork graph from the 'alpha' columns. Iterations with
  ## step <= burnin are discarded first. '...' is forwarded as the custom
  ## colours for 'alpha' / 't_inf' and to visNetwork::visNetwork for 'network'.

  ## CHECKS ##
  type <- match.arg(type)
  if (!y %in% names(x)) {
    stop(paste(y,"is not a column of x"))
  }

  ## hopefully cran will avoid spurious warnings along the lines of "no
  ## visible binding for global variable" when using ggplot2::aes(...)
  frequency <- NULL

  if (burnin > max(x$step)) {
    stop("burnin exceeds the number of steps in x")
  }
  x <- x[x$step > burnin, , drop = FALSE]

  ## Share of each count within the rows that have the same 'groups' value;
  ## used to turn xyTable() counts into per-case frequencies.
  prop_within <- function(counts, groups) {
    vapply(seq_along(counts),
           function(i) counts[i] / sum(counts[groups == groups[i]]),
           numeric(1))
  }

  ## User-supplied labels, or 1..n when none are given.
  default_labels <- function(labels, n) {
    if (is.null(labels)) {
      labels <- seq_len(n)
    }
    labels
  }

  ## check group
  if (!is.null(group)) {
    if (length(group) == 1L && group == "all") {
      ## remove _[digit] vars
      y_vars <- names(x)[!grepl("(_[[:digit:]]+$)", names(x))]
    } else if (all(group %in% names(x))) {
      y_vars <- c("step", group)
    } else {
      stop("grouping variables not found in outbreaker object")
    }
    ## get only relevant data
    x_sub <- as.data.frame(x)[, y_vars, drop = FALSE]
    param_names <- setdiff(names(x_sub), "step")
    ## switch it to long format to use in ggplot
    x_long <- reshape(x_sub,
                      idvar = "step",
                      ids = x_sub$step,
                      direction = "long",
                      new.row.names = NULL,
                      timevar = "Parameters",
                      v.names = "y",
                      varying = list(param_names),
                      times = param_names)
  }

  ## MAKE PLOT ##
  if (type == "trace") {
    if (!is.null(group)) {
      out <- ggplot(x_long) +
        geom_line(aes_string(x = "step", y = "y")) +
        scale_x_continuous(name = "Iteration") +
        scale_y_continuous(name = NULL) +
        facet_wrap(~ Parameters, scales = "free")
    } else {
      out <- ggplot(x) +
        geom_line(aes_string(x = "step", y = y)) +
        labs(x = "Iteration",
             y = y,
             title = paste("trace:",y))
    }
  }

  if (type == "hist") {
    if (!is.null(group)) {
      out <- ggplot(x_long) +
        geom_histogram(aes_string(x = "y")) +
        geom_point(aes_string(x = "y", y = 0),
                   alpha = 0.5,
                   size = 3) +
        scale_x_continuous(name = NULL) +
        scale_y_continuous(name = NULL) +
        facet_wrap(~Parameters, scales = "free")
    } else {
      out <- ggplot(x) +
        geom_histogram(aes_string(x = y)) +
        geom_point(aes_string(x = y, y = 0),
                   alpha = 0.5,
                   size = 3) +
        labs(x = y,
             title = paste("histogram:",y))
    }
  }

  if (type == "density") {
    if (!is.null(group)) {
      out <- ggplot(x_long) +
        geom_density(aes_string(x = "y")) +
        geom_point(aes_string(x = "y", y = 0),
                   alpha = 0.5,
                   size = 3) +
        scale_x_continuous(name = NULL) +
        scale_y_continuous(name = NULL) +
        facet_wrap(~Parameters, scales = "free")
    } else {
      out <- ggplot(x) +
        geom_density(aes_string(x = y)) +
        geom_point(aes_string(x = y, y = 0),
                   alpha = 0.5,
                   size = 3) +
        labs(x = y,
             title = paste("density:",y))
    }
  }

  if (type == "alpha") {
    alpha <- as.matrix(x[, grep("alpha", names(x)), drop = FALSE])
    colnames(alpha) <- seq_len(ncol(alpha))
    from <- as.vector(alpha)
    to <- as.vector(col(alpha))
    ## missing ancestors are recoded as 0 and labelled "Import" on the y axis
    from[is.na(from)] <- 0
    out_dat <- data.frame(xyTable(from, to))
    names(out_dat) <- c("from", "to", "frequency")

    ## Return labels, if provided
    get_alpha_lab <- function(axis, labels = NULL) {
      labels <- default_labels(labels, ncol(alpha))
      if (axis == "x") {
        labels
      } else if (axis == "y") {
        c("Import", labels)
      }
    }
    ## Return custom colors if provided
    get_alpha_color <- function(color = NULL) {
      if (is.null(color)) {
        return(NULL)
      }
      scale_color_manual(values = color)
    }
    ## This joining function is needed so that the '...' argument can be passed
    ## to two functions with different arguments
    get_lab_color <- function(labels = NULL, color = NULL) {
      list(alpha_lab_x = get_alpha_lab("x", labels),
           alpha_lab_y = get_alpha_lab("y", labels),
           alpha_color = get_alpha_color(color))
    }

    tmp <- get_lab_color(labels, ...)
    ## Calculate proportion among ancestries of each case
    out_dat$frequency <- prop_within(out_dat$frequency, out_dat$to)
    out_dat$from <- factor(out_dat$from, levels = c(0, sort(unique(out_dat$to))))
    out_dat$to <- factor(out_dat$to, levels = sort(unique(out_dat$to)))
    out <- ggplot(out_dat) +
      geom_point(aes(x = to, y = from, size = frequency, color = to)) +
      scale_x_discrete(drop = FALSE, labels = tmp$alpha_lab_x) +
      scale_y_discrete(drop = FALSE, labels = tmp$alpha_lab_y) +
      labs(x = "To", y = "From", size = "Posterior\nfrequency") +
      tmp$alpha_color +
      scale_size_area() +
      guides(colour = "none")
  }

  if (type == "t_inf") {
    t_inf <- as.matrix(x[, grep("t_inf", names(x)), drop = FALSE])

    ## Return custom colors if provided
    get_t_inf_color <- function(color = NULL) {
      if (is.null(color)) {
        return(NULL)
      }
      scale_fill_manual(values = color)
    }
    ## This joining function is needed so that the '...' argument can be passed
    ## to two functions with different arguments
    get_lab_color <- function(labels = NULL, color = NULL) {
      list(t_inf_lab_x = default_labels(labels, ncol(t_inf)),
           t_inf_color = get_t_inf_color(color))
    }

    tmp <- get_lab_color(labels, ...)
    dates <- as.vector(t_inf)
    cases <- as.vector(col(t_inf))
    out_dat <- data.frame(cases = factor(cases), dates = dates)
    out <- ggplot(out_dat) +
      geom_violin(aes(x = cases, y = dates, fill = cases)) +
      coord_flip() +
      guides(fill = "none") +
      labs(y = "Infection time", x = NULL) +
      tmp$t_inf_color +
      ## exact element name: do not rely on partial matching of `$`
      scale_x_discrete(labels = tmp$t_inf_lab_x)
  }

  if (type == "kappa") {
    kappa <- as.matrix(x[, grep("kappa", names(x)), drop = FALSE])
    generations <- as.vector(kappa)
    cases <- as.vector(col(kappa))
    to_keep <- !is.na(generations)
    generations <- generations[to_keep]
    cases <- cases[to_keep]
    out_dat <- data.frame(xyTable(generations, cases))
    names(out_dat) <- c("generations", "cases", "frequency")
    ## proportion of each number of generations within each case
    out_dat$frequency <- prop_within(out_dat$frequency, out_dat$cases)
    out <- ggplot(out_dat) +
      geom_point(aes(x = generations, y = as.factor(cases), size = frequency, color = factor(cases))) +
      scale_size_area() +
      scale_y_discrete(labels = default_labels(labels, ncol(kappa))) +
      guides(colour = "none") +
      labs(title = "number of generations between cases",
           x = "number of generations to ancestor",
           y = NULL)
  }

  if (type == "network") {
    ## extract edge info: ancestries
    alpha <- x[, grep("alpha",names(x)), drop = FALSE]
    from <- unlist(alpha)
    to <- as.vector(col(alpha))
    N <- ncol(alpha)
    edges <- stats::na.omit(data.frame(xyTable(from, to)))
    ## edge weight = share of retained iterations showing this ancestry
    edges[3] <- edges$number/nrow(alpha)
    names(edges) <- c("from", "to", "value")
    edges <- edges[edges$value > min_support,,drop = FALSE]
    edges$arrows <- "to"
    ## NOTE(review): cases_pal() is not defined in this chunk; presumably a
    ## colour palette function defined elsewhere in the package.
    case_cols <- cases_pal(N)
    edges$color <- case_cols[edges$from]

    ## ## extract edge info: timing
    ## t_inf <- x[, grep("t_inf",names(x)), drop = FALSE]
    ## mean_time <- apply(t_inf, 2, mean)
    ## mean_delay <- mean_time[edges$to] - mean_time[edges$from]
    ## mean_delay[mean_delay<1] <- 1
    ## edges$label <- paste(round(mean_delay), "days")

    ## node info
    ## node size = number of times case i appears as an ancestor, divided by
    ## the number of retained iterations
    find_nodes_size <- function(i) {
      sum(from==i, na.rm = TRUE) / nrow(alpha)
    }
    nodes <- data.frame(id = seq_len(N),
                        label = seq_len(N))
    nodes$value <- vapply(nodes$id, find_nodes_size, numeric(1))
    nodes$color <- case_cols
    nodes$shape <- rep("dot", N)
    nodes$label <- default_labels(labels, N)

    ## cases whose consensus ancestor is missing are drawn as stars
    smry <- summary(x, burnin = burnin)
    is_imported <- is.na(smry$tree$from)
    nodes$shape[is_imported] <- "star"

    ## generate graph
    out <- visNetwork::visNetwork(nodes = nodes, edges = edges, ...)
    out <- visNetwork::visNodes(out, shadow = list(enabled = TRUE, size = 10),
                                color = list(highlight = "red"))
    out <- visNetwork::visEdges(out, arrows = list(
      to = list(enabled = TRUE, scaleFactor = 0.2)),
      color = list(highlight = "red"))
  }

  out
}



#' @rdname outbreaker_chains
#' @param object an \code{outbreaker_chains} object as returned by \code{outbreaker}.
#' @param method the method used to determine consensus ancestries. 'mpa'
#'   (maximum posterior ancestry) simply returns the posterior ancestry with the
#'   highest posterior support for each case, even if this includes
#'   cycles. 'decycle' will return the maximum posterior ancestry, except when
#'   cycles are detected, in which case the link in the cycle with the lowest
#'   support is pruned and the tree recalculated.
#' @export
#' @importFrom stats median
summary.outbreaker_chains <- function(object, burnin = 0, method = c("mpa", "decycle"), ...) {
  ## Summarises an outbreaker_chains object after discarding iterations with
  ## step <= burnin. Returns a list with:
  ##   $step   first / last retained step, sampling interval, number of steps
  ##   $post, $like, $prior, $mu, $pi   summary() of the matching columns
  ##   $tree   data.frame with one row per case: consensus ancestor ('from'),
  ##           case index ('to'), median infection time ('time'), support of
  ##           the ancestry ('support') and median number of generations
  ##           ('generations', NA where 'from' is NA)
  method <- match.arg(method)

  ## check burnin ##
  x <- object
  if (burnin > max(x$step)) {
    stop("burnin exceeds the number of steps in object")
  }
  x <- x[x$step > burnin, , drop = FALSE]

  ## make output ##
  out <- list()

  ## summary for $step ##
  ## interval between the last two retained steps; NA when there are too few
  interv <- if (nrow(x) > 2) diff(tail(x$step, 2)) else NA
  out$step <- c(first = min(x$step),
                last = max(x$step),
                interval = interv,
                n_steps = length(x$step))

  ## summary of post, like, prior ##
  out$post <- summary(x$post)
  out$like <- summary(x$like)
  out$prior <- summary(x$prior)

  ## summary for mu ##
  out$mu <- summary(x$mu)

  ## summary for pi ##
  out$pi <- summary(x$pi)

  ## summary tree ##
  out$tree <- list()

  ## summary of alpha ##
  alpha <- as.matrix(x[, grep("alpha", names(x)), drop = FALSE])

  if (method == "mpa") {

    ## function to get most frequent item (NA counts as a candidate, so a
    ## case can have NA as its consensus ancestor)
    f1 <- function(x) {
      as.integer(names(sort(table(x, exclude = NULL), decreasing = TRUE)[1]))
    }
    out$tree$from <- apply(alpha, 2, f1)
    out$tree$to <- seq_len(ncol(alpha))

    ## function to get frequency of most frequent item
    ## NOTE(review): table() drops NA here, so when NA is the consensus
    ## ancestor the support reported is that of the most frequent non-NA
    ## ancestor -- confirm this is intended.
    f2 <- function(x) {
      (sort(table(x), decreasing = TRUE)/length(x))[1]
    }
    support <- apply(alpha, 2, f2)

  } else if (method == "decycle") {

    ## NOTE(review): .decycle_tree() is defined outside this chunk; it is
    ## expected to return a list with elements 'from', 'to' and 'support'.
    cons <- .decycle_tree(x)
    out$tree$from <- cons$from
    out$tree$to <- cons$to
    support <- cons$support

  }

  ## summary of t_inf ##
  t_inf <- as.matrix(x[, grep("t_inf", names(x)), drop = FALSE])
  out$tree$time <- apply(t_inf, 2, median)

  out$tree$support <- support

  ## summary of kappa ##
  kappa <- as.matrix(x[, grep("kappa", names(x)), drop = FALSE])
  out$tree$generations <- apply(kappa, 2, median, na.rm = TRUE)
  out$tree$generations[is.na(out$tree$from)] <- NA

  ## shape tree as a data.frame
  out$tree <- as.data.frame(out$tree)
  rownames(out$tree) <- NULL

  out
}

