R/edd_plot.r

Defines functions edd_plot_stats_graph edd_plot_stats edd_plot_grouped_eds edd_plot_eds edd_plot_grouped_mus edd_plot_mus edd_plot_grouped_las edd_plot_las edd_plot_grouped_ltt edd_plot_ltt edd_plot_nltt edd_plot_histree edd_plot_tree edd_plot

Documented in edd_plot edd_plot_eds edd_plot_grouped_eds edd_plot_grouped_las edd_plot_grouped_ltt edd_plot_grouped_mus edd_plot_las edd_plot_ltt edd_plot_mus edd_plot_nltt edd_plot_stats edd_plot_tree

#' @name edd_plot
#' @title Generating plots for a replicated edd simulation
#' @description Function to automatically generate several plots from raw data
#' of a replicated edd simulation
#' @param raw_data a list of results generated by edd simulation function; its
#' `data` element holds one result per parameter set
#' @param which Character vector naming which plots to generate. `"all"`
#' generates every plot; otherwise any of `"ltt"`, `"grouped_ltt"`, `"las"`,
#' `"grouped_las"`, `"mus"`, `"grouped_mus"`, `"eds"`, `"grouped_eds"`,
#' `"balance"`, `"branch"`, `"temporal"`, `"grouped_temporal"`,
#' `"best_histrees"`
#' @param save_plot Logical, decides whether to save the plots to files
#' @param path The path to save the plots
#' @param strategy Determine if the simulation is sequential or multi-sessioned
#' or multi-cored
#' @param workers Determine how many sessions are participated in the simulation
#' @param verbose Logical, decides whether to print loading details
#' @return When `save_plot` is not `TRUE`, a nested list of the generated plots:
#' each selected component is appended as `list(previous, new)`. Otherwise
#' `NULL`, invisibly.
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot
edd_plot <- function(raw_data = NULL,
                     which = "all",
                     save_plot = FALSE,
                     path = NULL,
                     strategy = "sequential",
                     workers = 1,
                     verbose = TRUE) {
  check_parallel_arguments(strategy, workers, verbose)
  check_raw_data(raw_data$data)
  plots <- list()
  # Collapse `which` so a multi-element selection produces one readable
  # message instead of several strings glued together
  message(paste0("Generating plots for ", paste(which, collapse = ", ")))

  if (save_plot == TRUE) {
    message("Saving plots to files")
  }

  # TRUE when the plot identified by `key` was requested, either explicitly
  # or through "all" (both operands are scalar logicals, hence `||`)
  wants <- function(key) {
    "all" %in% which || key %in% which
  }

  # plot normalized lineages through time (currently disabled)
  #if (wants("nltt")) {
  #  plot_nltt <- lapply(raw_data$data, edd_plot_nltt, save_plot = save_plot, path = path)
  #  plots <- list(plots, plot_nltt)
  #}

  # plot lineages through time, per parameter set and grouped
  if (wants("ltt")) {
    message("Plotting lineages-through-time")
    plot_ltt <- lapply(raw_data$data, edd_plot_ltt, save_plot = save_plot, path = path)
    plots <- list(plots, plot_ltt)
  }

  if (wants("grouped_ltt")) {
    message("Plotting grouped lineages-through-time")
    plot_grouped_ltt <- edd_plot_grouped_ltt(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_grouped_ltt)
  }

  # plot speciation rates
  if (wants("las")) {
    message("Plotting speciation rates")
    plot_las <- lapply(raw_data$data, edd_plot_las, save_plot = save_plot, path = path)
    plots <- list(plots, plot_las)
  }

  if (wants("grouped_las")) {
    message("Plotting grouped speciation rates")
    plot_grouped_las <- edd_plot_grouped_las(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_grouped_las)
  }

  # plot extinction rates
  if (wants("mus")) {
    message("Plotting extinction rates")
    plot_mus <- lapply(raw_data$data, edd_plot_mus, save_plot = save_plot, path = path)
    plots <- list(plots, plot_mus)
  }

  if (wants("grouped_mus")) {
    message("Plotting grouped extinction rates")
    plot_grouped_mus <- edd_plot_grouped_mus(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_grouped_mus)
  }

  # plot evolutionary distinctiveness-es
  if (wants("eds")) {
    message("Plotting evolutionary distinctiveness-es")
    plot_eds <- lapply(raw_data$data, edd_plot_eds, save_plot = save_plot, path = path)
    plots <- list(plots, plot_eds)
  }

  if (wants("grouped_eds")) {
    message("Plotting grouped evolutionary distinctiveness-es")
    plot_grouped_eds <- edd_plot_grouped_eds(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_grouped_eds)
  }

  # tree shape metrics
  if (wants("balance")) {
    message("Plotting balance metrics")
    plot_balance <- edd_plot_balance(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_balance)
  }

  if (wants("branch")) {
    message("Plotting branching metrics")
    plot_branch <- edd_plot_branch(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_branch)
  }

  # temporal dynamics, per parameter set and grouped
  if (wants("temporal")) {
    message("Plotting temporal dynamics")
    plot_temporal <- lapply(raw_data$data, edd_plot_temporal_dynamics, save_plot = save_plot, path = path)
    plots <- list(plots, plot_temporal)
  }

  if (wants("grouped_temporal")) {
    message("Plotting grouped temporal dynamics")
    plot_grouped_temporal <- edd_plot_grouped_temporal_dynamics(raw_data, save_plot = save_plot, path = path)
    plots <- list(plots, plot_grouped_temporal)
  }

  # history trees of the replicates chosen by find_best_rep_ids()
  if (wants("best_histrees")) {
    message("Plotting best represented histrees")
    message("Looking for best representatives")
    rep_ids <- find_best_rep_ids(raw_data)
    plot_best_histrees <- edd_plot_best_histrees(raw_data,
                                                 rep_ids = rep_ids,
                                                 which = "las",
                                                 save_plot = save_plot,
                                                 path = path)
    plots <- list(plots, plot_best_histrees)
  }

  if (save_plot != TRUE) {
    message("Plots are not saved, returning ggplot2 objects as a list")
    return(plots)
  }
  invisible(NULL)
}


#' @name edd_plot_tree
#' @title Plotting a tree for a selected parameter set
#' @description Function to plot a tree for a selected parameter set. When
#' extinct lineages are kept and the tree contains any, they are marked and
#' drawn with their own line type.
#' @param raw_data a list of results generated by edd simulation function
#' @param rep_id the id of the replicate to be plotted
#' @param drop_extinct Logical, decides whether to drop extinct lineages
#' (the tree is taken from `raw_data$tes` when `TRUE`, `raw_data$tas` otherwise)
#' @param save_plot Logical, decides whether to save the plots to files;
#' saving is not supported by this function and `TRUE` raises an error
#' @return a ggtree plot object
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_tree
edd_plot_tree <- function(raw_data = NULL,
                          rep_id = stop("Please specify the id of the replicate to be plotted"),
                          drop_extinct = TRUE,
                          save_plot = FALSE
) {
  # Reject the unsupported option before doing any plotting work
  if (save_plot == TRUE) {
    stop("Saving plots to files is not supported for this function")
  }

  if (drop_extinct == TRUE) {
    tree <- raw_data$tes[[rep_id]]
  } else {
    tree <- raw_data$tas[[rep_id]]
  }

  # Only a tree that keeps extinct lineages can contain extinct tips
  show_extinct <- drop_extinct == FALSE &&
    !is.null(find_extinct(tree, which = "tip_label"))

  if (show_extinct) {
    # mark_extinct_tips() supplies the `linetype` variable mapped below
    tree <- mark_extinct_tips(tree)
    plot_tree <- ggtree::ggtree(tree, ggplot2::aes(linetype = linetype)) +
      ggplot2::ylab("Representative tree") +
      ggtree::theme_tree2() +
      ggplot2::theme(legend.position = "none")
  } else {
    plot_tree <- ggtree::ggtree(tree) +
      ggtree::theme_tree2() +
      ggplot2::ylab("Representative tree")
  }

  plot_tree
}


#' @title Plotting a tree coloured by a lineage-level history
#' @description Internal helper that draws the phylogeny of one replicate as
#' segments (computed by `stat_histree()`) coloured by the recorded speciation
#' rates (`"las"`), extinction rates (`"mus"`) or evolutionary
#' distinctiveness (`"eds"`).
#' @param raw_data a list of results generated by edd simulation function
#' @param rep_id the id of the replicate to be plotted
#' @param which which history to colour the tree by: "las", "mus" or "eds"
#' @param drop_extinct Logical, decides whether to drop extinct lineages
#' @param save_plot Logical; saving is not supported and `TRUE` raises an error
#' @return a ggplot object
#' @noRd
edd_plot_histree <- function(raw_data = NULL,
                             rep_id = stop("Please specify the id of the replicate to be plotted"),
                             which = stop("Please specify which history to be plotted"),
                             drop_extinct = FALSE,
                             save_plot = FALSE
) {
  # `which` must be a single value: it drives switch() and sample_end_state()
  stopifnot(length(which) == 1, which %in% c("las", "mus", "eds"))
  if (save_plot == TRUE) {
    stop("Saving plots to files is not supported for this function")
  }

  l_table <- raw_data$l_tables[[rep_id]]
  if (drop_extinct == TRUE) {
    phy <- raw_data$tes[[rep_id]]
  } else {
    phy <- raw_data$tas[[rep_id]]
  }
  params <- raw_data$all_pars

  history <- switch(which,
    las = raw_data$las[[rep_id]],
    mus = raw_data$mus[[rep_id]],
    eds = raw_data$eds[[rep_id]]
  )
  # Prepend the recorded time points as a `time` column
  history <- cbind(time = raw_data$ltt[[rep_id]]$time, history)

  # Append the row(s) produced by sample_end_state() (presumably the state at
  # the end of the simulation -- TODO confirm)
  end_state <- sample_end_state(l_table, params, which)
  history <- dplyr::bind_rows(history, end_state)

  segments <- stat_histree(phy, history)
  histree <- ggplot2::ggplot(segments) +
    ggplot2::geom_segment(ggplot2::aes(x = x, y = y, xend = xend, yend = yend, color = state)) +
    ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
                   panel.grid.minor = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   axis.line.x = ggplot2::element_line(colour = "black"),
                   axis.text.y = ggplot2::element_blank(),
                   axis.ticks.y = ggplot2::element_blank()) +
    # guides() expects aesthetic names, so `guides(title = ...)` never set the
    # legend title; label the colour aesthetic instead
    ggplot2::labs(x = NULL, y = NULL, color = "State") +
    viridis::scale_color_viridis(discrete = FALSE, option = "turbo")

  histree
}


#' @name edd_plot_nltt
#' @title Generating nLTT plot for a replicated edd simulation
#' @description Function to generate normalized lineages through time plot from
#' raw data of a replicated edd simulation
#' @param raw_data a list of results generated by edd simulation function
#' @param drop_extinct Logical, decides whether to drop extinct lineages
#' @param save_plot Logical, decides whether to save the plots to files
#' @param path The path to save the plots
#' @param dt Time step passed to `nLTT::get_nltt_values()`
#' @return an plot object when `save_plot` is not `TRUE`, otherwise the value
#' returned by `save_with_parameters()`
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_nltt
edd_plot_nltt <- function(raw_data = NULL,
                          drop_extinct = TRUE,
                          save_plot = FALSE,
                          path = NULL,
                          dt = 0.01
) {
  if (drop_extinct == TRUE) {
    message("Drawing nLTT plot with trees of extant species")
    trees <- raw_data$tes
  } else {
    message("Drawing nLTT plot with trees of all species")
    trees <- raw_data$tas
  }
  # Columns t, nltt and id of this data frame are mapped below
  df <- nLTT::get_nltt_values(trees, dt = dt)

  pars_list <- extract_parameters(raw_data)
  anno <- create_annotation(pars_list, vjust = c(1, 2.1))

  plot_nltt <-
    ggplot2::ggplot() +
      # individual replicates as points, coloured by replicate id
      ggplot2::geom_point(data = df,
                          ggplot2::aes(t, nltt, color = id),
                          size = 0.1) +
      # bootstrapped mean across replicates
      ggplot2::stat_summary(data = df,
                            ggplot2::aes(t, nltt),
                            fun.data = ggplot2::mean_cl_boot,
                            geom = "smooth") +
      ggplot2::ggtitle("Average nLTT plot of phylogenies") +
      ggplot2::labs(x = "Normalized time", y = "Normalized number of lineages") +
      ggtext::geom_richtext(
        data = anno,
        ggplot2::aes(
          x,
          y,
          label = label,
          angle = angle,
          hjust = hjust,
          vjust = vjust
        ),
        fill = "#E8CB9C"
      ) +
      viridis::scale_colour_viridis(discrete = TRUE, option = "A") +
      ggplot2::xlim(0, 1) +
      ggplot2::ylim(0, 1) +
      ggplot2::theme(legend.position = "none",
                     aspect.ratio = 3 / 4)

  if (save_plot == TRUE) {
    save_with_parameters(pars_list = pars_list,
                         plot = plot_nltt,
                         which = "nltt",
                         path = path,
                         device = "png",
                         width = 5,
                         height = 4,
                         dpi = "retina")
  } else {
    return(plot_nltt)
  }
}


#' @name edd_plot_ltt
#' @title Generating LTT plot for a replicated edd simulation
#' @description Builds a lineages-through-time plot for one parameter set:
#' the mean number of lineages over the replicates, optionally with a
#' confidence ribbon and a parameter annotation.
#' @param raw_data A list of results generated by edd simulation function
#' @param alpha Resolution of the confidence interval
#' @param save_plot true or false, to decide whether to save the plots to files
#' @param path The path to save the plots
#' @param annotation Logical, decide whether to add annotation to the plot
#' @param ribbon Logical, decide whether to add ribbon of confidence interval to the plot
#' @param trans Specify the transformation of the y axis
#' @return an plot object, or the result of `save_with_parameters()` when saving
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_ltt
edd_plot_ltt <- function(raw_data = NULL,
                         alpha = 0.05,
                         save_plot = FALSE,
                         path = NULL,
                         annotation = TRUE,
                         ribbon = TRUE,
                         trans = "log10") {
  pars_list <- extract_parameters(raw_data)

  # First column of every L table with its first row removed
  branching_times <- lapply(raw_data$l_tables, function(l_table) l_table[-1, 1])
  # Every distinct time point across all replicates, in increasing order
  time_points <- sort(unique(unlist(branching_times)))

  ci <- calculate_CI(branching_times, time_points, alpha = alpha)
  colnames(ci) <- c("t", "median", "minalpha", "maxalpha", "mean")
  ci <- as.data.frame(ci)

  y_label <- paste0("Number of lineages (", trans, ")")
  plain_numbers <- function(x) format(x, scientific = FALSE)

  plot_ltt <- ggplot2::ggplot(ci) +
    ggplot2::geom_line(ggplot2::aes(t, mean)) +
    ggplot2::theme(legend.position = "none", aspect.ratio = 3 / 4) +
    ggplot2::xlab("Age") +
    ggplot2::ylab(y_label) +
    ggplot2::scale_y_continuous(trans = trans, labels = plain_numbers) +
    ggplot2::xlim(0, 6) +
    ggplot2::ggtitle(toupper(pars_list$metric))

  if (annotation == TRUE) {
    anno <- create_annotation(pars_list, vjust = c(1, 2.1))
    plot_ltt <- plot_ltt + ggtext::geom_richtext(
      data = anno,
      mapping = ggplot2::aes(x = x, y = y, label = label, angle = angle,
                             hjust = hjust, vjust = vjust),
      fill = "#E8CB9C"
    )
  }

  if (ribbon == TRUE) {
    plot_ltt <- plot_ltt + ggplot2::geom_ribbon(
      ggplot2::aes(t, mean, ymax = maxalpha, ymin = minalpha),
      alpha = 0.2
    )
  }

  if (save_plot != TRUE) {
    return(plot_ltt)
  }
  save_with_parameters(pars_list = pars_list,
                       plot = plot_ltt,
                       which = "ltt",
                       path = path,
                       device = "png",
                       width = 5,
                       height = 4,
                       dpi = "retina")
}


#' @name edd_plot_grouped_ltt
#' @title Generating LTT plot for a replicated edd simulation
#' @description Function to generate grouped lineages through time plot from
#' raw data of a replicated edd simulation. The parameter sets of each group
#' are plotted side by side in one row.
#' @param raw_data a list of results generated by edd simulation function
#' @param group specify the group to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @return A list with one element per group: the combined patchwork plot when
#' `save_plot` is not `TRUE`, otherwise the value returned by
#' `save_with_parameters()`
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_grouped_ltt
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_grouped_ltt <- function(raw_data = NULL, group = "metric", save_plot = FALSE, path = NULL) {
  tally <- tally_by_group(raw_data$params, group)
  indexes <- create_indexes_by_group(tally)
  lapply(indexes, function(x) {
    # One LTT panel per parameter set in this group
    plots <- lapply(x, function(y) edd_plot_ltt(raw_data$data[[y]], save_plot = FALSE, annotation = FALSE))
    # The title is taken from the group's first parameter set (presumably the
    # shared parameters of the group)
    grouped_plot <- patchwork::wrap_plots(plots, nrow = 1) +
      patchwork::plot_annotation(title = pars_to_title(raw_data$data[[x[1]]]$all_pars),
                                 theme = ggplot2::theme(plot.title = ggplot2::element_text(size = 25)))
    if (save_plot == TRUE) {
      pars_list <- extract_parameters(raw_data$data[[x[1]]])
      save_with_parameters(pars_list = pars_list,
                           plot = grouped_plot,
                           which = "grouped_ltt",
                           path = path,
                           device = "png",
                           width = 4 * tally$groups,
                           height = 10,
                           dpi = "retina")
    } else {
      # Without this branch the closure returned NULL and the plot was lost
      grouped_plot
    }
  })
}


#' @name edd_plot_las
#' @title Generating line plot of speciation rates of each lineage
#' @description Function to generate a plot showing the transition of speciation
#' rates of each lineage, placed below the phylogeny of the replicate coloured
#' by speciation rate
#' @param raw_data a list of results generated by edd simulation function
#' @param rep_id specify the id of replication to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @param deviation Logical, whether to additionally plot each lineage's
#' deviation from the mean speciation rate at each time point (stacked below
#' the rate plot)
#' @param annotation Logical, decide whether to add annotation to the plot
#' @return a patchwork object when `save_plot` is not `TRUE`, otherwise the
#' value returned by `save_with_parameters()`
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_las
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_las <- function(raw_data = NULL,
                         rep_id = 1,
                         save_plot = FALSE,
                         path = NULL,
                         deviation = FALSE,
                         annotation = TRUE) {
  pars_list <- extract_parameters(raw_data)

  # Long format (one row per time point and tip); rows with missing values
  # are dropped (presumably times at which a tip does not exist)
  las_table <- cbind(Time = raw_data$ltt[[rep_id]]$time, raw_data$las[[rep_id]])
  las_long <-
    las_table %>%
      tidyr::gather("Tip", "Lambda", -Time) %>%
      na.omit()

  plot_las1 <- ggplot2::ggplot(las_long) +
    ggplot2::geom_path(ggplot2::aes(Time, Lambda, group = Tip, color = Lambda)) +
    viridis::scale_color_viridis(option = "D") +
    ggplot2::ylab("Speciation rate") +
    ggplot2::xlab("Age") +
    ggplot2::xlim(0, 6) +
    ggplot2::theme_classic() +
    ggplot2::theme(legend.position = "none",
                   aspect.ratio = 1 / 1)

  if (annotation == TRUE) {
    anno <- create_annotation(pars_list, y = c(-Inf, -Inf), vjust = c(-1, 0))
    plot_las1 <- plot_las1 +
      ggtext::geom_richtext(
        data = anno,
        ggplot2::aes(
          x = x,
          y = y,
          label = label,
          angle = angle,
          hjust = hjust,
          vjust = vjust
        ),
        fill = "#E8CB9C"
      )
  }

  if (deviation == TRUE) {
    # Deviation of each lineage from the mean rate across lineages at the
    # same time point
    las_long <- las_long %>%
      dplyr::group_by(Time) %>%
      dplyr::mutate(Deviation = Lambda - mean(Lambda)) %>%
      dplyr::ungroup()
    plot_las2 <- ggplot2::ggplot(las_long) +
      ggplot2::geom_path(ggplot2::aes(Time, Deviation, group = Tip, color = Lambda)) +
      viridis::scale_color_viridis(option = "D") +
      ggplot2::geom_hline(yintercept = 0,
                          linetype = "twodash",
                          color = "grey") +
      ggplot2::xlim(0, 6) +
      ggplot2::theme_classic() +
      ggplot2::theme(legend.position = "right",
                     aspect.ratio = 1 / 1)
    # Pass both plots so `ncol = 1` stacks them; `plot_las1 + plot_las2`
    # would first build a side-by-side patchwork
    plot_las <- patchwork::wrap_plots(plot_las1, plot_las2, ncol = 1)
  } else {
    plot_las <- plot_las1
  }

  plot_tree <- edd_plot_histree(raw_data, rep_id = rep_id, which = "las", drop_extinct = FALSE, save_plot = FALSE) +
    ggplot2::ggtitle(toupper(pars_list$metric))

  plot_las <- patchwork::wrap_plots(plot_tree, plot_las, ncol = 1)

  if (save_plot == TRUE) {
    save_with_parameters(pars_list = pars_list,
                         plot = plot_las,
                         which = "las",
                         path = path,
                         device = "png",
                         width = 5,
                         height = 8,
                         dpi = "retina")
  } else {
    return(plot_las)
  }
}

#' @name edd_plot_grouped_las
#' @title Generating line plot of speciation rates of each lineage by group
#' @description Function to generate a plot showing the transition of speciation
#' rates of each lineage by specified group. The parameter sets of each group
#' are plotted side by side in one row.
#' @param raw_data a list of results generated by edd simulation function
#' @param group specify the group to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @return A list with one element per group: the combined patchwork plot when
#' `save_plot` is not `TRUE`, otherwise the value returned by
#' `save_with_parameters()`
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_grouped_las
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_grouped_las <- function(raw_data = NULL, group = "metric", save_plot = FALSE, path = NULL) {
  tally <- tally_by_group(raw_data$params, group)
  indexes <- create_indexes_by_group(tally)
  lapply(indexes, function(x) {
    # One speciation-rate panel per parameter set in this group
    plots <- lapply(x, function(y) edd_plot_las(raw_data$data[[y]], save_plot = FALSE, annotation = FALSE))
    # The title is taken from the group's first parameter set (presumably the
    # shared parameters of the group)
    grouped_plot <- patchwork::wrap_plots(plots, nrow = 1) +
      patchwork::plot_annotation(title = pars_to_title(raw_data$data[[x[1]]]$all_pars),
                                 theme = ggplot2::theme(plot.title = ggplot2::element_text(size = 25)))
    if (save_plot == TRUE) {
      pars_list <- extract_parameters(raw_data$data[[x[1]]])
      save_with_parameters(pars_list = pars_list,
                           plot = grouped_plot,
                           which = "grouped_las",
                           path = path,
                           device = "png",
                           width = 4 * tally$groups,
                           height = 10,
                           dpi = "retina")
    } else {
      # Without this branch the closure returned NULL and the plot was lost
      grouped_plot
    }
  })
}


#' @name edd_plot_mus
#' @title Generating line plot of extinction rates of each lineage
#' @description Draws the extinction-rate trajectory of every lineage of one
#' replicate over time, stacked under the replicate's phylogeny coloured by
#' extinction rate.
#' @param raw_data a list of results generated by edd simulation function
#' @param rep_id specify the id of replication to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @param deviation Logical, whether plot the deviation of extinction rates
#' @param annotation Logical, whether plot the annotation of parameters
#' @return an ggplot object, or the result of `save_with_parameters()` when saving
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_mus
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_mus <- function(raw_data = NULL,
                         rep_id = 1,
                         save_plot = FALSE,
                         path = NULL,
                         deviation = FALSE,
                         annotation = TRUE) {
  pars_list <- extract_parameters(raw_data)

  # Wide table (Time + one column per tip) reshaped to long format; rows with
  # missing values are dropped
  rate_wide <- cbind(Time = raw_data$ltt[[rep_id]]$time, raw_data$mus[[rep_id]])
  rate_long <- na.omit(tidyr::gather(rate_wide, "Tip", "Mu", -Time))

  rate_plot <- ggplot2::ggplot(rate_long) +
    ggplot2::geom_path(ggplot2::aes(Time, Mu, group = Tip, color = Mu)) +
    viridis::scale_color_viridis(option = "D") +
    ggplot2::ylab("Extinction rate") +
    ggplot2::xlab("Age") +
    ggplot2::xlim(0, 6) +
    ggplot2::theme_classic() +
    ggplot2::theme(legend.position = "none", aspect.ratio = 1 / 1)

  if (annotation == TRUE) {
    anno <- create_annotation(pars_list, y = c(-Inf, -Inf), vjust = c(-1, 0))
    rate_plot <- rate_plot + ggtext::geom_richtext(
      data = anno,
      mapping = ggplot2::aes(x = x, y = y, label = label, angle = angle,
                             hjust = hjust, vjust = vjust),
      fill = "#E8CB9C"
    )
  }

  rate_panel <- rate_plot
  if (deviation == TRUE) {
    # Per time point, each lineage's distance from the across-lineage mean
    deviation_long <- rate_long %>%
      dplyr::group_by(Time) %>%
      dplyr::mutate(Deviation = Mu - mean(Mu))
    deviation_plot <- ggplot2::ggplot(deviation_long) +
      ggplot2::geom_path(ggplot2::aes(Time, Deviation, group = Tip, color = Mu)) +
      viridis::scale_color_viridis(option = "D") +
      ggplot2::geom_hline(yintercept = 0, linetype = "twodash", color = "grey") +
      ggplot2::xlim(0, 6) +
      ggplot2::theme_classic() +
      ggplot2::theme(legend.position = "right", aspect.ratio = 1 / 1)
    rate_panel <- patchwork::wrap_plots(rate_plot, deviation_plot, ncol = 1)
  }

  tree_plot <- edd_plot_histree(raw_data, rep_id = rep_id, which = "mus",
                                drop_extinct = FALSE, save_plot = FALSE) +
    ggplot2::ggtitle(toupper(pars_list$metric))
  combined <- patchwork::wrap_plots(tree_plot, rate_panel, ncol = 1)

  if (save_plot != TRUE) {
    return(combined)
  }
  save_with_parameters(pars_list = pars_list,
                       plot = combined,
                       which = "mus",
                       path = path,
                       device = "png",
                       width = 5,
                       height = 8,
                       dpi = "retina")
}


#' @name edd_plot_grouped_mus
#' @title Generating line plot of extinction rates of each lineage by group
#' @description Function to generate a plot showing the transition of extinction
#' rates of each lineage by specified group. The parameter sets of each group
#' are plotted side by side in one row.
#' @param raw_data a list of results generated by edd simulation function
#' @param group specify the group to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @return A list with one element per group: the combined patchwork plot when
#' `save_plot` is not `TRUE`, otherwise the value returned by
#' `save_with_parameters()`
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_grouped_mus
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_grouped_mus <- function(raw_data = NULL, group = "metric", save_plot = FALSE, path = NULL) {
  tally <- tally_by_group(raw_data$params, group)
  indexes <- create_indexes_by_group(tally)
  lapply(indexes, function(x) {
    # One extinction-rate panel per parameter set in this group
    plots <- lapply(x, function(y) edd_plot_mus(raw_data$data[[y]], save_plot = FALSE, annotation = FALSE))
    # The title is taken from the group's first parameter set (presumably the
    # shared parameters of the group)
    grouped_plot <- patchwork::wrap_plots(plots, nrow = 1) +
      patchwork::plot_annotation(title = pars_to_title(raw_data$data[[x[1]]]$all_pars),
                                 theme = ggplot2::theme(plot.title = ggplot2::element_text(size = 25)))
    if (save_plot == TRUE) {
      pars_list <- extract_parameters(raw_data$data[[x[1]]])
      save_with_parameters(pars_list = pars_list,
                           plot = grouped_plot,
                           which = "grouped_mus",
                           path = path,
                           device = "png",
                           width = 4 * tally$groups,
                           height = 10,
                           dpi = "retina")
    } else {
      # Without this branch the closure returned NULL and the plot was lost
      grouped_plot
    }
  })
}


#' @name edd_plot_eds
#' @title Generating line plot of evolutionary distinctiveness of each lineage
#' @description Draws the evolutionary-distinctiveness trajectory of every
#' lineage of one replicate over time, stacked under the replicate's phylogeny
#' coloured by evolutionary distinctiveness.
#' @param raw_data a list of results generated by edd simulation function
#' @param rep_id specify the id of replication to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @param deviation Logical, whether to plot the deviation of evolutionary distinctiveness
#' @param annotation Logical, whether to plot the annotation of parameters
#' @return an ggplot object, or the result of `save_with_parameters()` when saving
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_eds
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_eds <- function(raw_data = NULL,
                         rep_id = 1,
                         save_plot = FALSE,
                         path = NULL,
                         deviation = FALSE,
                         annotation = TRUE) {
  pars_list <- extract_parameters(raw_data)

  # Wide table (Time + one column per tip) reshaped to long format; rows with
  # missing values are dropped
  ed_wide <- cbind(Time = raw_data$ltt[[rep_id]]$time, raw_data$eds[[rep_id]])
  ed_long <- na.omit(tidyr::gather(ed_wide, "Tip", "ED", -Time))

  ed_plot <- ggplot2::ggplot(ed_long) +
    ggplot2::geom_path(ggplot2::aes(Time, ED, group = Tip, color = ED)) +
    viridis::scale_color_viridis(option = "D") +
    ggplot2::ylab("Evolutionary distinctiveness") +
    ggplot2::xlab("Age") +
    ggplot2::xlim(0, 6) +
    ggplot2::theme_classic() +
    ggplot2::theme(legend.position = "none", aspect.ratio = 1 / 1)

  if (annotation == TRUE) {
    # Annotation anchored at the top of the panel
    anno <- create_annotation(pars_list, y = c(Inf, Inf), vjust = c(1, 2.1))
    ed_plot <- ed_plot + ggtext::geom_richtext(
      data = anno,
      mapping = ggplot2::aes(x = x, y = y, label = label, angle = angle,
                             hjust = hjust, vjust = vjust),
      fill = "#E8CB9C"
    )
  }

  ed_panel <- ed_plot
  if (deviation == TRUE) {
    # Per time point, each lineage's distance from the across-lineage mean
    deviation_long <- ed_long %>%
      dplyr::group_by(Time) %>%
      dplyr::mutate(Deviation = ED - mean(ED))
    deviation_plot <- ggplot2::ggplot(deviation_long) +
      ggplot2::geom_path(ggplot2::aes(Time, Deviation, group = Tip, color = ED)) +
      viridis::scale_color_viridis(option = "D") +
      ggplot2::geom_hline(yintercept = 0, linetype = "twodash", color = "grey") +
      ggplot2::xlim(0, 6) +
      ggplot2::theme_classic() +
      ggplot2::theme(legend.position = "right", aspect.ratio = 1 / 1)
    ed_panel <- patchwork::wrap_plots(ed_plot, deviation_plot, ncol = 1)
  }

  tree_plot <- edd_plot_histree(raw_data, rep_id = rep_id, which = "eds",
                                drop_extinct = FALSE, save_plot = FALSE) +
    ggplot2::ggtitle(toupper(pars_list$metric))
  combined <- patchwork::wrap_plots(tree_plot, ed_panel, ncol = 1)

  if (save_plot != TRUE) {
    return(combined)
  }
  save_with_parameters(pars_list = pars_list,
                       plot = combined,
                       which = "eds",
                       path = path,
                       device = "png",
                       width = 10,
                       height = 8,
                       dpi = "retina")
}


#' @name edd_plot_grouped_eds
#' @title Generating line plot of evolutionary distinctiveness of each lineage by group
#' @description Function to generate a plot showing the transition of evolutionary distinctiveness
#' of each lineage by specified group. The parameter sets of each group are
#' plotted side by side in one row.
#' @param raw_data a list of results generated by edd simulation function
#' @param group specify the group to plot
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @return A list with one element per group: the combined patchwork plot when
#' `save_plot` is not `TRUE`, otherwise the value returned by
#' `save_with_parameters()`
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_grouped_eds
#' @importFrom magrittr %>%
#' @import patchwork
edd_plot_grouped_eds <- function(raw_data = NULL, group = "metric", save_plot = FALSE, path = NULL) {
  tally <- tally_by_group(raw_data$params, group)
  indexes <- create_indexes_by_group(tally)
  lapply(indexes, function(x) {
    # One evolutionary-distinctiveness panel per parameter set in this group
    plots <- lapply(x, function(y) edd_plot_eds(raw_data$data[[y]], save_plot = FALSE, annotation = FALSE))
    # The title is taken from the group's first parameter set (presumably the
    # shared parameters of the group)
    grouped_plot <- patchwork::wrap_plots(plots, nrow = 1) +
      patchwork::plot_annotation(title = pars_to_title(raw_data$data[[x[1]]]$all_pars),
                                 theme = ggplot2::theme(plot.title = ggplot2::element_text(size = 25)))
    if (save_plot == TRUE) {
      pars_list <- extract_parameters(raw_data$data[[x[1]]])
      save_with_parameters(pars_list = pars_list,
                           plot = grouped_plot,
                           which = "grouped_eds",
                           path = path,
                           device = "png",
                           width = 4 * tally$groups,
                           height = 10,
                           dpi = "retina")
    } else {
      # Without this branch the closure returned NULL and the plot was lost
      grouped_plot
    }
  })
}


#' @name edd_plot_stats
#' @title Generating boxplots of tree statistics
#' @description Function to generate boxplots showing the tree statistics e.g. Colless, J-One, B1, Gamma etc.
#' One set of plots (one per level of `lambda`) is produced for every statistic.
#' @param raw_data  a list of results generated by edd simulation function
#' @param method Specify which package to be used to calculate statistics
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @return When `save_plot` is not `TRUE`, a list named by statistic; each
#' element is the list returned by applying `edd_plot_stats_graph()` to every
#' level of `lambda`. Otherwise `NULL`, invisibly.
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_stats
edd_plot_stats <- function(raw_data = NULL, method = "treestats", save_plot = FALSE, path = NULL) {
  stats <- edd_stat_cached(raw_data$data, method = method)
  stats_names <- colnames(stats)
  model <- raw_data$params$model[1]
  # Drop the leading non-statistic columns (presumably the simulation
  # parameters, whose count depends on the model)
  if (model == "dsce2") {
    stats_names <- stats_names[-(1:8)]
  } else if (model == "dsde2") {
    stats_names <- stats_names[-(1:10)]
  } else {
    stop("Model not supported")
  }

  stats <- tidyr::pivot_longer(stats, cols = -(lambda:offset), names_to = "stats", values_to = "value")
  stats <- transform_data(stats)

  rates <- levels(stats$lambda)

  # Keep the plots of every statistic; an empty `stats_names` yields an empty
  # list rather than an undefined result
  plot_object <- lapply(stats_names, function(stat_name) {
    lapply(rates, edd_plot_stats_graph,
           stats = stats,
           params = raw_data$params,
           offset = "Simulation time", name = stat_name,
           save_plot = save_plot,
           path = path)
  })
  names(plot_object) <- stats_names

  if (save_plot != TRUE) {
    return(plot_object)
  }
  invisible(NULL)
}


# Build (and optionally save) boxplots of one tree statistic for a single
# lambda level, comparing PD (restricted to `offset`), ED and NND, faceted by
# mu and beta_n with beta_phi on the x axis.
#
# rates:     lambda level; only the first element is used.
# stats:     long-format statistics table (columns used: lambda, metric,
#            offset, stats, value, beta_phi, beta_n, mu).
# params:    simulation parameters; their beta_phi values label the x axis.
# offset:    PD offset method to keep (e.g. "Simulation time").
# name:      statistic to plot; must be a character string.
# save_plot: if TRUE the plot is written via save_with_rates_offset()
#            instead of being returned.
# path:      output directory used when saving.
edd_plot_stats_graph <- function(rates, stats, params, offset = NULL, name = NULL, save_plot = FALSE, path = NULL) {
  if (!inherits(name, "character")) {
    stop("Stat name must be of type character")
  }

  lambda_num <- rates[1]

  lambda <- as.character(rates[1])
  offset_char <- as.character(offset)

  # PD rows are restricted to the requested offset; ED and NND rows are not
  # filtered by offset.
  plot_data_pd <- dplyr::filter(stats,
                                lambda == lambda_num &
                                  metric == "pd" &
                                  offset == offset_char)

  plot_data_ed <- dplyr::filter(stats,
                                lambda == lambda_num &
                                  metric == "ed")

  plot_data_nnd <- dplyr::filter(stats,
                                 lambda == lambda_num &
                                   metric == "nnd")

  plot_data <- rbind(plot_data_pd, plot_data_ed, plot_data_nnd)

  plot_data_stats <- dplyr::filter(plot_data, stats == name)

  # Zoom the y axis onto the boxplot whiskers (outliers are hidden), padded by
  # 10% of the whisker span on both sides. Multiplying the limits by 1.1 only
  # widens them when min < 0 < max; for all-positive statistics it raised the
  # lower limit and clipped the lower whisker.
  sts <- grDevices::boxplot.stats(plot_data_stats$value)$stats
  sts_range <- range(sts)
  sts_pad <- diff(sts_range) * 0.1
  y_limits <- sts_range + c(-sts_pad, sts_pad)

  stats_plot <- ggplot2::ggplot(plot_data_stats) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, value, fill = metric), outlier.shape = NA) +
    ggplot2::facet_grid(mu ~ beta_n,
                        labeller = ggplot2::labeller(
                          beta_n = ggplot2::as_labeller(~paste0("beta[italic(N)]:", .x), ggplot2::label_parsed),
                          mu = ggplot2::as_labeller(~paste0("mu[0]:", .x), ggplot2::label_parsed))) +
    ggplot2::xlab(bquote(beta[italic(Phi)])) +
    ggplot2::ylab(name) +
    ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
    ggplot2::coord_cartesian(ylim = y_limits) +
    ggplot2::theme(strip.background = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   panel.grid = ggplot2::element_blank())

  final_plot <- stats_plot +
    ggplot2::ggtitle(bquote(.(name) ~ lambda[0] ~ "=" ~ .(lambda))) +
    ggplot2::labs(fill = "Metric")

  if (save_plot == TRUE) {
    save_with_rates_offset(rates = rates,
                           offset = offset,
                           plot = final_plot,
                           which = name,
                           path = path,
                           device = "png",
                           width = 10, height = 8,
                           dpi = "retina")
  } else {
    return(final_plot)
  }
}



#' @name edd_plot_balance
#' @title Generating boxplots of tree balance indices
#' @description Function to generate boxplots of a tree balance statistic, one plot per lambda level,
#' comparing PD (simulation-time offset), ED and NND. The statistic shown is the one selected by
#' `edd_plot_balance_pd_ed`.
#' @param raw_data  a list of results generated by edd simulation function
#' @param method Specify which package to be used to calculate statistics
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path The path to save the plots
#' @return a list of ggplot objects (one per lambda level), or `NULL` when the plots are saved to files
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_balance
edd_plot_balance <- function(raw_data = NULL, method = "treestats", save_plot = FALSE, path = NULL) {
  # Long format: every column outside lambda..offset becomes a (balance, value) pair
  balance_long <- edd_stat(raw_data$data, method = method) %>%
    tidyr::pivot_longer(cols = -(lambda:offset), names_to = "balance", values_to = "value") %>%
    transform_data()
  lambda_levels <- levels(balance_long$lambda)

  pd_ed_plots <- lapply(lambda_levels,
                        edd_plot_balance_pd_ed,
                        stat_balance = balance_long,
                        params = raw_data$params,
                        offset = "Simulation time",
                        save_plot = save_plot,
                        path = path)

  if (save_plot != TRUE) pd_ed_plots
}


# Plot a single statistic for one (lambda, mu, beta_n) combination, comparing
# the PD ("Simulation time" offset), ED and NND metrics.
#
# raw_data:  replicated simulation results (raw_data$data feeds
#            edd_stat_cached(); raw_data is also passed to check_stats_names()).
# rates:     vector of at least three values used as lambda, mu and beta_n.
# name:      statistic to plot; must be a character string.
# method:    package used to compute statistics (passed to edd_stat_cached()).
# save_plot: if TRUE the plot is saved instead of returned.
# path:      output directory used when saving.
edd_plot_stats_single <- function(raw_data = NULL, rates = NULL, name = "J_One", method = "treestats", save_plot = FALSE, path = NULL) {
  if (!inherits(name, "character")) {
    stop("Stat name must be of type character")
  }
  # rates[1:3] are used as lambda, mu and beta_n downstream; fail early with a
  # clear message instead of an obscure filtering error (covers NULL too).
  if (length(rates) < 3) {
    stop("rates must contain lambda, mu and beta_n values")
  }

  stats <- edd_stat_cached(raw_data$data, method = method)

  check_stats_names(raw_data, stats, name)

  stats <- tidyr::pivot_longer(stats, cols = -(lambda:offset), names_to = "stats", values_to = "value")
  stats <- transform_data(stats)

  plot_between_metric <- edd_plot_stats_single_between_metric(rates,
                                                              name,
                                                              stats = stats,
                                                              offset = "Simulation time",
                                                              save_plot = save_plot,
                                                              path = path)

  if (save_plot != TRUE) {
    return(plot_between_metric)
  }
}


# Boxplots of one statistic for a fixed (lambda, mu, beta_n) combination,
# comparing PD (restricted to `offset`), ED and NND; one facet row per beta_phi.
#
# rates:     at least three values, used as lambda, mu and beta_n.
# name:      statistic to plot (matched against the `stats` column).
# stats:     long-format statistics table.
# offset:    PD offset method to keep.
# save_plot: if TRUE the plot is written via save_with_rates_offset()
#            (using rates[1:2]) instead of being returned.
# path:      output directory used when saving.
edd_plot_stats_single_between_metric <- function(rates, name, stats, offset = NULL, save_plot = FALSE, path = NULL) {
  lambda_sel <- rates[1]
  mu_sel <- rates[2]
  beta_n_sel <- rates[3]
  offset_sel <- as.character(offset)

  # Rows for this (lambda, mu); PD rows are further restricted to the offset.
  in_rates <- dplyr::filter(stats, lambda == lambda_sel & mu == mu_sel)
  pd_rows <- dplyr::filter(in_rates, metric == "pd" & offset == offset_sel)
  ed_rows <- dplyr::filter(in_rates, metric == "ed")
  nnd_rows <- dplyr::filter(in_rates, metric == "nnd")

  selected <- dplyr::filter(rbind(pd_rows, ed_rows, nnd_rows),
                            stats == name,
                            beta_n == beta_n_sel)

  # Whisker statistics of every metric, used to zoom the y axis while
  # outliers are hidden.
  whiskers <- selected %>%
    dplyr::group_by(metric) %>%
    dplyr::reframe(bp = boxplot.stats(value)$stats) %>%
    dplyr::pull(bp)

  y_low <- min(whiskers)
  # J_One is capped at 1, so the window always extends to that bound.
  y_high <- if (name == "J_One") 1 else max(whiskers)

  base_plot <- ggplot2::ggplot(selected) +
    ggplot2::facet_wrap(. ~ beta_phi, ncol = 1) +
    ggplot2::geom_boxplot(ggplot2::aes(metric, value, fill = metric), outlier.shape = NA) +
    ggplot2::scale_y_continuous(position = "right") +
    ggplot2::coord_cartesian(ylim = c(y_low, y_high)) +
    ggplot2::ylab(NULL) +
    ggplot2::ggtitle(name) +
    ggplot2::theme(axis.title.x = ggplot2::element_blank(),
                   axis.text.x = ggplot2::element_blank(),
                   axis.line.x = ggplot2::element_blank(),
                   axis.ticks.x = ggplot2::element_blank(),
                   strip.text.x = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   panel.grid = ggplot2::element_blank())

  # NOTE(review): the labels assume the metric factor levels are ordered
  # pd, ed, nnd — TODO confirm against transform_data().
  labelled_plot <- base_plot +
    ggplot2::scale_fill_discrete(name = "Scenario", labels = c("PD", "ED", "NND"))

  if (save_plot != TRUE) {
    return(labelled_plot)
  }
  save_with_rates_offset(rates = rates[1:2],
                         offset = offset,
                         plot = labelled_plot,
                         which = "stats_single",
                         path = path,
                         device = "png",
                         width = 2, height = 20,
                         dpi = "retina")
}


# Build (and optionally save) boxplots of one tree balance statistic for a
# single lambda level, comparing PD (restricted to `offset`), ED and NND,
# faceted by mu and beta_n with beta_phi on the x axis.
#
# rates:        lambda level; only the first element is used.
# stat_balance: long-format table with a `balance` column naming the
#               statistic and a `value` column.
# params:       simulation parameters; their beta_phi values label the x axis.
# offset:       PD offset method to keep (e.g. "Simulation time").
# save_plot:    if TRUE the plot is written via save_with_rates_offset()
#               instead of being returned.
# path:         output directory used when saving.
# index:        value of the `balance` column to plot. The default "Gamma" is
#               the statistic this function always selected; the axis label
#               and title now name it instead of claiming "Colless".
edd_plot_balance_pd_ed <- function(rates, stat_balance, params, offset = NULL, save_plot = FALSE, path = NULL,
                                   index = "Gamma") {
  lambda_num <- rates[1]

  lambda <- as.character(rates[1])
  offset_char <- as.character(offset)
  balance_index <- as.character(index)

  plot_data_pd <- dplyr::filter(stat_balance,
                                lambda == lambda_num &
                                  metric == "pd" &
                                  offset == offset_char)

  plot_data_ed <- dplyr::filter(stat_balance,
                                lambda == lambda_num &
                                  metric == "ed")

  plot_data_nnd <- dplyr::filter(stat_balance,
                                 lambda == lambda_num &
                                   metric == "nnd")

  plot_data <- rbind(plot_data_pd, plot_data_ed, plot_data_nnd)

  plot_data_index <- dplyr::filter(plot_data, balance == balance_index)

  # Zoom onto the whiskers (outliers are hidden) with 10% padding of the
  # whisker span on each side; scaling the limits by 1.1 would clip the lower
  # whisker of all-positive indices such as Colless.
  sts <- grDevices::boxplot.stats(plot_data_index$value)$stats
  sts_range <- range(sts)
  sts_pad <- diff(sts_range) * 0.1
  y_limits <- sts_range + c(-sts_pad, sts_pad)

  index_plot <- ggplot2::ggplot(plot_data_index) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, value, fill = metric), outlier.shape = NA) +
    ggplot2::facet_grid(mu ~ beta_n,
                        labeller = ggplot2::labeller(
                          beta_n = ggplot2::as_labeller(~paste0("beta[italic(N)]:", .x), ggplot2::label_parsed),
                          mu = ggplot2::as_labeller(~paste0("mu[0]:", .x), ggplot2::label_parsed))) +
    ggplot2::xlab(bquote(beta[italic(Phi)])) +
    ggplot2::ylab(balance_index) +
    ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
    ggplot2::coord_cartesian(ylim = y_limits) +
    ggplot2::theme(strip.background = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   panel.grid = ggplot2::element_blank())

  pd_ed_plot <- index_plot +
    ggplot2::ggtitle(bquote(.(balance_index) ~ lambda[0] ~ "=" ~ .(lambda))) +
    ggplot2::labs(fill = "Metric")

  if (save_plot == TRUE) {
    save_with_rates_offset(rates = rates,
                           offset = offset,
                           plot = pd_ed_plot,
                           which = "balance_pd_ed",
                           path = path,
                           device = "png",
                           width = 10, height = 8,
                           dpi = "retina")
  } else {
    return(pd_ed_plot)
  }
}


# Grouped between-metric comparison (ggstatsplot) of every balance statistic
# for one parameter set.
#
# params:       parameter values whose lambda, mu, beta_n and beta_phi select
#               the rows to plot (presumably a single row — confirm with callers).
# stat_balance: long-format table with metric, offset, balance and value columns.
# save_plot:    if TRUE the plot is written via save_with_parameters()
#               instead of being returned.
# path:         output directory used when saving.
edd_plot_balance_significance <- function(params, stat_balance, save_plot = FALSE, path = NULL) {
  # Drop PD rows computed with any offset other than "Simulation time";
  # ED and NND rows are kept whatever their offset.
  offset_ok <- dplyr::filter(stat_balance, !(metric == "pd" & offset != "Simulation time"))
  plot_data <- dplyr::filter(offset_ok,
                             lambda == params$lambda,
                             mu == params$mu,
                             beta_n == params$beta_n,
                             beta_phi == params$beta_phi)

  plot_balance <- ggstatsplot::grouped_ggbetweenstats(x = metric,
                                                      y = value,
                                                      data = plot_data,
                                                      grouping.var = balance,
                                                      pairwise.comparisons = FALSE,
                                                      results.subtitle = FALSE,
                                                      subtitle = NULL,
                                                      bf.message = FALSE,
                                                      caption = pars_to_title2(params),
                                                      xlab = "Metric",
                                                      ylab = "Value")

  if (save_plot != TRUE) {
    return(plot_balance)
  }
  save_with_parameters(pars_list = params,
                       plot = plot_balance,
                       which = "balance_signif",
                       path = path,
                       device = "png",
                       width = 12,
                       height = 5,
                       dpi = "retina")
}


#' @name edd_plot_branch
#' @title Generating boxplots of various tree branching measures
#' @description Function to generate boxplots showing branching measures of the trees (mean branch length,
#' phylogenetic diversity and mean nearest taxon distance), plus per-parameter significance plots.
#' @param raw_data  a list of results generated by edd simulation function
#' @param method Specify which package to be used to calculate statistics
#' @param save_plot Logical, whether save to file or return a ggplot object
#' @param path Path to save the plot
#' @return a named list of plot collections, or `NULL` when the plots are saved to files
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_branch
edd_plot_branch <- function(raw_data = NULL, method = "treestats", save_plot = FALSE, path = NULL) {
  stat_branch <- edd_stat(raw_data$data, method = method)
  stat_branch <- transform_data(stat_branch)

  lambdas <- levels(stat_branch$lambda)
  mus <- levels(stat_branch$mu)
  rates <- expand.grid(lambdas, mus)

  # One data frame per parameter row. seq_len(NROW()) gives an empty split for
  # zero rows, whereas seq(nrow()) would give c(1, 0) and two bogus groups.
  param_rows <- split(raw_data$params, seq_len(NROW(raw_data$params)))
  plot_significance <- lapply(param_rows, edd_plot_branch_significance,
                              stat_branch = stat_branch,
                              save_plot = save_plot, path = path)
  plot_pd_offsets <- apply(rates, 1, edd_plot_branch_pd_offsets,
                           stat_branch = stat_branch, params = raw_data$params,
                           save_plot = save_plot, path = path)

  # PD (with the given offset) vs ED vs NND for every (lambda, mu) row
  plot_pd_ed <- function(offset) {
    apply(rates, 1, edd_plot_branch_pd_ed,
          stat_branch = stat_branch, params = raw_data$params, offset = offset,
          save_plot = save_plot, path = path)
  }
  plot_pd_ed_none <- plot_pd_ed("None")
  plot_pd_ed_simtime <- plot_pd_ed("Simulation time")
  plot_pd_ed_spcount <- plot_pd_ed("Species count")

  if (save_plot != TRUE) {
    # The misspelled names "pd_pffsets" and "pd_simetime_ed" are kept so
    # callers that index the result by name keep working.
    return(list(significance = plot_significance,
                pd_pffsets = plot_pd_offsets,
                pd_none_ed = plot_pd_ed_none,
                pd_simetime_ed = plot_pd_ed_simtime,
                pd_spcount_ed = plot_pd_ed_spcount))
  }
}


# Build (and optionally save) a three-panel figure for one (lambda, mu)
# combination comparing PD-metric simulations across offset methods: mean
# branch length, phylogenetic diversity (log2 axis) and MNTD, each against
# beta_phi and filled by beta_n.
#
# rates:       (lambda, mu) pair, e.g. one row of the grid built in
#              edd_plot_branch().
# stat_branch: branch statistics with lambda, mu, metric, offset, beta_phi,
#              beta_n, MBL, PD and MNTD columns.
# params:      simulation parameters; their beta_phi values label the x axis.
# save_plot:   if TRUE the figure is written via save_with_rates() instead of
#              being returned.
# path:        output directory used when saving.
edd_plot_branch_pd_offsets <- function(rates, stat_branch, params, save_plot = FALSE, path = NULL) {
  lambda_num <- rates[1]
  mu_num <- rates[2]

  lambda <- as.character(rates[1])
  mu <- as.character(rates[2])

  plot_data <- dplyr::filter(stat_branch,
                             lambda == lambda_num &
                               mu == mu_num &
                               metric == "pd")

  # Top panel: keeps the facet strips, hides the x axis.
  mbl_plot <- ggplot2::ggplot(data = plot_data) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, MBL, fill = beta_n)) +
    ggplot2::facet_wrap(. ~ offset, nrow = 1) +
    ggplot2::scale_y_continuous() +
    ggplot2::ylab("Mean branch length") +
    ggplot2::theme(axis.title.x = ggplot2::element_blank(),
                   axis.text.x = ggplot2::element_blank(),
                   axis.ticks.x = ggplot2::element_blank(),
                   axis.line.x = ggplot2::element_blank())

  # Middle panel: no strips, no x axis.
  pd_plot <- ggplot2::ggplot(data = plot_data) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, PD, fill = beta_n)) +
    ggplot2::facet_wrap(. ~ offset, nrow = 1) +
    ggplot2::scale_y_continuous(trans = "log2") +
    ggplot2::ylab("Phylogenetic diversity") +
    ggplot2::theme(axis.title.x = ggplot2::element_blank(),
                   axis.text.x = ggplot2::element_blank(),
                   axis.ticks.x = ggplot2::element_blank(),
                   strip.background = ggplot2::element_blank(),
                   strip.text.x = ggplot2::element_blank(),
                   axis.line.x = ggplot2::element_blank())

  # Bottom panel: no strips, carries the beta_phi x axis.
  mntd_plot <- ggplot2::ggplot(data = plot_data) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, MNTD, fill = beta_n)) +
    ggplot2::facet_wrap(. ~ offset, nrow = 1) +
    ggplot2::scale_y_continuous() +
    ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
    ggplot2::ylab("Mean nearest taxon distance") +
    ggplot2::xlab(expression(beta[italic(Phi)])) +
    ggplot2::theme(strip.background = ggplot2::element_blank(),
                   strip.text.x = ggplot2::element_blank())

  # Branching measures by beta_phi, faceted by offset, grouped by beta_n
  pd_offsets_plot <- mbl_plot +
    ggplot2::ggtitle(bquote("Branching metrics of PD, comparisons between offset methods, " ~ lambda ~ "=" ~ .(lambda) ~ mu ~ "=" ~ .(mu))) +
    pd_plot +
    mntd_plot +
    patchwork::plot_layout(ncol = 1, guides = "collect") &
    ggplot2::labs(fill = expression(beta[italic(N)]))

  if (save_plot == TRUE) {
    save_with_rates(rates = rates,
                    plot = pd_offsets_plot,
                    which = "branch_pd_offsets",
                    path = path,
                    device = "png",
                    width = 10, height = 8,
                    dpi = "retina")
  } else {
    return(pd_offsets_plot)
  }
}


# Build (and optionally save) a three-panel figure for one (lambda, mu)
# combination comparing PD (restricted to `offset`), ED and NND: mean branch
# length, phylogenetic diversity (log2 axis) and MNTD, each against beta_phi,
# faceted by beta_n and filled by metric.
#
# rates:       (lambda, mu) pair, e.g. one row of the grid built in
#              edd_plot_branch().
# stat_branch: branch statistics with lambda, mu, metric, offset, beta_phi,
#              beta_n, MBL, PD and MNTD columns.
# params:      simulation parameters; their beta_phi values label the x axis.
# offset:      PD offset method to keep (e.g. "None", "Simulation time").
# save_plot:   if TRUE the figure is written via save_with_rates_offset()
#              instead of being returned.
# path:        output directory used when saving.
edd_plot_branch_pd_ed <- function(rates, stat_branch, params, offset = NULL, save_plot = FALSE, path = NULL) {
  lambda_num <- rates[1]
  mu_num <- rates[2]

  lambda <- as.character(rates[1])
  mu <- as.character(rates[2])
  offset_char <- as.character(offset)

  # PD rows are restricted to the requested offset; ED and NND rows are not.
  plot_data_pd <- dplyr::filter(stat_branch,
                                lambda == lambda_num &
                                  mu == mu_num &
                                  metric == "pd" &
                                  offset == offset_char)

  plot_data_ed <- dplyr::filter(stat_branch,
                                lambda == lambda_num &
                                  mu == mu_num &
                                  metric == "ed")

  plot_data_nnd <- dplyr::filter(stat_branch,
                                 lambda == lambda_num &
                                   mu == mu_num &
                                   metric == "nnd")

  plot_data <- rbind(plot_data_pd, plot_data_ed, plot_data_nnd)

  # Top panel: parsed beta_N facet strips, no x axis.
  mbl_plot <- ggplot2::ggplot(plot_data) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, MBL, fill = metric)) +
    ggplot2::facet_wrap(. ~ beta_n, labeller = ggplot2::as_labeller(
      ~paste0("beta[italic(N)]:", .x), ggplot2::label_parsed
    )) +
    ggplot2::scale_y_continuous() +
    ggplot2::ylab("Mean branch length") +
    ggplot2::theme(axis.title.x = ggplot2::element_blank(),
                   axis.text.x = ggplot2::element_blank(),
                   axis.ticks.x = ggplot2::element_blank(),
                   axis.line.x = ggplot2::element_blank())

  # Middle panel: no strips, no x axis.
  pd_plot <- ggplot2::ggplot(plot_data) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, PD, fill = metric)) +
    ggplot2::facet_wrap(. ~ beta_n) +
    ggplot2::scale_y_continuous(trans = "log2") +
    ggplot2::ylab("Phylogenetic diversity") +
    ggplot2::theme(axis.title.x = ggplot2::element_blank(),
                   axis.text.x = ggplot2::element_blank(),
                   axis.ticks.x = ggplot2::element_blank(),
                   axis.line.x = ggplot2::element_blank(),
                   strip.background = ggplot2::element_blank(),
                   strip.text.x = ggplot2::element_blank())

  # Bottom panel: no strips, carries the beta_phi x axis.
  mntd_plot <- ggplot2::ggplot(plot_data) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, MNTD, fill = metric)) +
    ggplot2::facet_wrap(. ~ beta_n) +
    ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
    ggplot2::ylab("Mean nearest taxon distance") +
    ggplot2::xlab(expression(beta[italic(Phi)])) +
    ggplot2::theme(strip.background = ggplot2::element_blank(),
                   strip.text.x = ggplot2::element_blank())

  pd_ed_plot <- mbl_plot +
    ggplot2::ggtitle(bquote("Branching metrics, comparisons between PD (" ~ .(offset) ~ "), ED and NND, " ~ lambda ~ "=" ~ .(lambda) ~ mu ~ "=" ~ .(mu))) +
    pd_plot +
    mntd_plot +
    patchwork::plot_layout(nrow = 3, guides = "collect") &
    ggplot2::labs(fill = "Metric")

  if (save_plot == TRUE) {
    save_with_rates_offset(rates = rates,
                           offset = offset,
                           plot = pd_ed_plot,
                           which = "branch_pd_ed",
                           path = path,
                           device = "png",
                           width = 10, height = 8,
                           dpi = "retina")
  } else {
    return(pd_ed_plot)
  }
}


# Grouped between-metric comparison (ggstatsplot) of the branching measures
# MBL, PD and MNTD for one parameter set.
#
# params:      parameter values whose lambda, mu, beta_n and beta_phi select
#              the rows to plot (edd_plot_branch() passes one split row).
# stat_branch: branch statistics with metric, offset, MBL, PD and MNTD columns.
# save_plot:   if TRUE the plot is written via save_with_parameters()
#              instead of being returned.
# path:        output directory used when saving.
edd_plot_branch_significance <- function(params, stat_branch, save_plot = FALSE, path = NULL) {
  # Keep PD rows only for the "Simulation time" offset, then reshape the three
  # measures into (measure, value) pairs. pivot_longer() replaces the
  # superseded tidyr::gather().
  plot_data <- stat_branch %>%
    dplyr::filter(!(metric == "pd" & offset != "Simulation time")) %>%
    dplyr::filter(lambda == params$lambda &
                    mu == params$mu &
                    beta_n == params$beta_n &
                    beta_phi == params$beta_phi) %>%
    tidyr::pivot_longer(cols = c(MBL, PD, MNTD), names_to = "measure", values_to = "value")
  plot_branch <- ggstatsplot::grouped_ggbetweenstats(x = metric,
                                                     y = value,
                                                     data = plot_data,
                                                     grouping.var = measure,
                                                     pairwise.comparisons = FALSE,
                                                     results.subtitle = FALSE,
                                                     subtitle = NULL,
                                                     bf.message = FALSE,
                                                     caption = pars_to_title2(params),
                                                     xlab = "Metric",
                                                     ylab = "Value")

  if (save_plot == TRUE) {
    save_with_parameters(pars_list = params,
                         plot = plot_branch,
                         which = "branch_signif",
                         path = path,
                         device = "png",
                         width = 12,
                         height = 5,
                         dpi = "retina")
  } else {
    return(plot_branch)
  }
}


#' @name edd_plot_temporal_dynamics
#' @title Generating plots of temporal dynamics of a single simulation
#' @description Function to generate plots of temporal dynamics of a single simulation, metrics including tree balance
#' indexes and tree branch related measures
#' @param raw_data A list of results generated by edd simulation function
#' @param rep_id specify the id of replication to plot
#' @param save_plot true or false, to decide whether to save the plots to files
#' @param path The path to save the plots
#' @return a ggplot object; when `save_plot` is `TRUE` the plot is written to a file instead of being returned
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_temporal_dynamics
edd_plot_temporal_dynamics <- function(raw_data = NULL,
                                       rep_id = 1,
                                       save_plot = FALSE,
                                       path = NULL) {
  pars_list <- extract_parameters(raw_data)
  temporal_data <- reconstruct_temporal_dynamics(raw_data$l_tables[[rep_id]], pars_list$age)
  # Long format: every column other than Age and Event becomes a
  # (Metric, Value) pair. pivot_longer() replaces the superseded gather().
  temporal_data <- temporal_data %>%
    tidyr::pivot_longer(cols = -c(Age, Event), names_to = "Metric", values_to = "Value")
  # Fixed facet / legend order; values outside these levels become NA.
  temporal_data$Metric <- factor(temporal_data$Metric, levels = c("Blum", "Colless", "Sackin", "PD", "MNTD", "MBL"))
  temporal_data$Event <- factor(temporal_data$Event, levels = c("speciation", "extinction", "present"))
  plot_temporal <- ggplot2::ggplot(temporal_data) +
    ggplot2::geom_line(ggplot2::aes(Age, Value)) +
    ggplot2::geom_point(ggplot2::aes(Age, Value, color = Event, shape = Event)) +
    ggplot2::facet_wrap(. ~ Metric, scales = "free_y") +
    ggplot2::scale_color_discrete(labels = c("Speciation", "Extinction", "Present")) +
    ggplot2::scale_shape_discrete(labels = c("Speciation", "Extinction", "Present")) +
    ggplot2::theme(legend.title = ggplot2::element_blank(),
                   aspect.ratio = 1 / 1) +
    ggplot2::ggtitle(pars_to_title2(pars_list))

  if (save_plot == TRUE) {
    save_with_parameters(pars_list = pars_list,
                         plot = plot_temporal,
                         which = "temporal",
                         path = path,
                         device = "png",
                         width = 10,
                         height = 6,
                         dpi = "retina")
  } else {
    return(plot_temporal)
  }
}


#' @name edd_plot_grouped_temporal_dynamics
#' @title Generating plots of grouped temporal dynamics
#' @description Function to generate plots of grouped temporal dynamics for replication `rep_id` of every parameter
#' combination, comparing PD (simulation-time offset), ED and NND; metrics include tree balance indexes and tree
#' branch related measures
#' @param raw_data A list of results generated by edd simulation function
#' @param rep_id specify the id of replication to plot
#' @param save_plot true or false, to decide whether to save the plots to files
#' @param path The path to save the plots
#' @return a list of plot objects, one per (lambda, mu, beta_n, beta_phi) combination, or `NULL` when the plots are
#' saved to files
#' @author Tianjian Qin
#' @keywords phylogenetics
#' @export edd_plot_grouped_temporal_dynamics
edd_plot_grouped_temporal_dynamics <- function(raw_data = NULL,
                                               rep_id = 1,
                                               save_plot = FALSE,
                                               path = NULL) {
  stat_temporal <- edd_summarize_temporal_dynamics(raw_data$data, rep_id = rep_id)
  # Long format over the six tree metrics; pivot_longer() replaces the
  # superseded tidyr::gather().
  stat_temporal <- stat_temporal %>%
    tidyr::pivot_longer(cols = c(Blum, Colless, Sackin, PD, MNTD, MBL),
                        names_to = "Metric", values_to = "Value") %>%
    transform_data()
  # Fixed facet / legend order; values outside these levels become NA.
  stat_temporal$Metric <- factor(stat_temporal$Metric, levels = c("Blum", "Colless", "Sackin", "PD", "MNTD", "MBL"))
  stat_temporal$Event <- factor(stat_temporal$Event, levels = c("speciation", "extinction", "present"))
  # One row per parameter combination: (lambda, mu, beta_n, beta_phi)
  rates <- expand.grid(unique(raw_data$params$lambda),
                       unique(raw_data$params$mu),
                       unique(raw_data$params$beta_n),
                       unique(raw_data$params$beta_phi))

  plot_pd_ed_simtime <- apply(rates, 1, edd_plot_temporal_pd_ed,
                              stat_temporal = stat_temporal,
                              offset = "Simulation time",
                              save_plot = save_plot, path = path)

  if (save_plot != TRUE) {
    return(plot_pd_ed_simtime)
  }
}


# Build (and optionally save) a three-row figure of temporal dynamics for one
# (lambda, mu, beta_n, beta_phi) combination: PD (restricted to `offset`),
# ED and NND, each faceted by Metric.
#
# rates:         (lambda, mu, beta_n, beta_phi), e.g. one row of the grid
#                built in edd_plot_grouped_temporal_dynamics().
# stat_temporal: long-format table with Age, Event, Metric, Value, metric,
#                offset and the four rate columns.
# offset:        PD offset method to keep; also shown in the title.
# save_plot:     if TRUE the figure is written via save_with_rates2()
#                instead of being returned.
# path:          output directory used when saving.
edd_plot_temporal_pd_ed <- function(rates, stat_temporal, offset, save_plot = FALSE, path = NULL) {
  lambda_num <- rates[1]
  mu_num <- rates[2]
  beta_n_num <- rates[3]
  beta_phi_num <- rates[4]
  # Local copy so the filter compares the `offset` column with the argument
  # rather than a hard-coded value that could disagree with the title.
  offset_char <- as.character(offset)

  rate_data <- dplyr::filter(stat_temporal,
                             lambda == lambda_num &
                               mu == mu_num &
                               beta_n == beta_n_num &
                               beta_phi == beta_phi_num)

  plot_data_pd <- dplyr::filter(rate_data, metric == "pd" & offset == offset_char)
  plot_data_ed <- dplyr::filter(rate_data, metric == "ed")
  plot_data_nnd <- dplyr::filter(rate_data, metric == "nnd")

  event_labels <- c("Speciation", "Extinction", "Present")

  # One row of facets (one per Metric) for a single scenario. In the stacked
  # figure, only the bottom row draws the x axis and only the top row draws
  # facet strips.
  build_panel <- function(panel_data, y_label, show_x_axis, show_strips) {
    panel <- ggplot2::ggplot(panel_data) +
      ggplot2::geom_line(ggplot2::aes(Age, Value)) +
      ggplot2::geom_point(ggplot2::aes(Age, Value, color = Event, shape = Event)) +
      ggplot2::facet_wrap(. ~ Metric, scales = "free_y", nrow = 1) +
      ggplot2::scale_color_discrete(labels = event_labels) +
      ggplot2::scale_shape_discrete(labels = event_labels) +
      ggplot2::ylab(y_label) +
      ggplot2::theme(legend.title = ggplot2::element_blank(),
                     aspect.ratio = 1 / 1)
    if (!show_x_axis) {
      panel <- panel +
        ggplot2::theme(axis.title.x = ggplot2::element_blank(),
                       axis.text.x = ggplot2::element_blank(),
                       axis.ticks.x = ggplot2::element_blank(),
                       axis.line.x = ggplot2::element_blank())
    }
    if (!show_strips) {
      panel <- panel +
        ggplot2::theme(strip.background = ggplot2::element_blank(),
                       strip.text.x = ggplot2::element_blank())
    }
    panel
  }

  plot_pd <- build_panel(plot_data_pd, "PD", show_x_axis = FALSE, show_strips = TRUE)
  plot_ed <- build_panel(plot_data_ed, "ED", show_x_axis = FALSE, show_strips = FALSE)
  plot_nnd <- build_panel(plot_data_nnd, "NND", show_x_axis = TRUE, show_strips = FALSE)

  plot_grouped_temporal <- plot_pd +
    ggplot2::ggtitle(bquote("Temporal dynamics of tree topology, comparisons between PD (" ~ .(offset) ~ "), ED and NND, " ~ lambda[0] ~ "=" ~ .(lambda_num) ~ mu[0] ~ "=" ~ .(mu_num) ~ beta[italic(N)] ~ "=" ~ .(beta_n_num) ~ beta[italic(Phi)] ~ "=" ~ .(beta_phi_num))) +
    plot_ed +
    plot_nnd +
    patchwork::plot_layout(nrow = 3, guides = "collect") &
    ggplot2::labs(fill = "Metric")

  if (save_plot == TRUE) {
    save_with_rates2(rates = rates,
                     plot = plot_grouped_temporal,
                     which = "grouped_temporal",
                     path = path,
                     device = "png",
                     width = 20,
                     height = 10,
                     dpi = "retina")
  } else {
    return(plot_grouped_temporal)
  }
}


# Plot (or save) the history tree of a chosen replication for every
# simulation in raw_data$data.
#
# raw_data:     replicated simulation results; raw_data$data is a list of
#               single-simulation results.
# rep_ids:      object whose `row_id` gives, for each simulation in order,
#               the replication to plot (required).
# which:        which history to plot, passed to edd_plot_histree() (required).
# drop_extinct: passed to edd_plot_histree().
# save_plot:    if TRUE each plot is saved via save_with_parameters().
# path:         output directory used when saving.
# Returns a list of plots (one per simulation) when save_plot is not TRUE,
# otherwise NULL.
edd_plot_best_histrees <- function(raw_data = NULL,
                                   rep_ids = stop("Best replication IDs not provided"),
                                   which = stop("Please specify which history to be plotted"),
                                   drop_extinct = FALSE,
                                   save_plot = FALSE,
                                   path = NULL) {
  if (length(raw_data$data) != length(rep_ids$row_id)) {
    stop("The lengths of the data and the replication IDs do not match")
  }

  plots <- vector("list", length(raw_data$data))
  for (i in seq_along(raw_data$data)) {
    if (save_plot == TRUE) {
      plot_temp <- edd_plot_histree(raw_data$data[[i]],
                                    rep_id = rep_ids$row_id[i],
                                    which = which,
                                    drop_extinct = drop_extinct,
                                    save_plot = FALSE)
      save_with_parameters(pars_list = extract_parameters(raw_data$data[[i]]),
                           plot = plot_temp,
                           which = "best_histrees",
                           path = path,
                           device = "png",
                           width = 5,
                           height = 8,
                           dpi = "retina")
    } else {
      # Keep the plot instead of discarding it; `[i] <- list()` also keeps a
      # NULL result in its slot rather than deleting the element.
      plots[i] <- list(edd_plot_histree(raw_data$data[[i]],
                                        rep_id = rep_ids$row_id[i],
                                        which = which,
                                        drop_extinct = drop_extinct,
                                        save_plot = save_plot))
    }
  }

  if (save_plot != TRUE) {
    return(plots)
  }
}


# Build one grouped history-tree figure per (lambda, mu, beta_n) combination
# present in raw_data$params, delegating each combination to
# edd_plot_grouped_histrees_core().
#
# `name` lists the summary statistics shown beside the trees and must not
# contain duplicates. Returns the list of figures when save_plot is FALSE;
# otherwise the core function saves them and NULL is returned invisibly.
edd_plot_grouped_histrees <- function(raw_data = NULL,
                                      sample_rep = NULL,
                                      which = "las",
                                      name = "J_One",
                                      drop_extinct = FALSE,
                                      save_plot = FALSE,
                                      path = NULL) {
  repeated <- name[duplicated(name)]
  if (length(repeated) > 0) {
    stop("Duplicated statistic names: ", paste(repeated, collapse = ", "))
  }

  params <- raw_data$params
  rate_grid <- expand.grid(unique(params$lambda),
                           unique(params$mu),
                           unique(params$beta_n))

  figures <- apply(rate_grid, MARGIN = 1, FUN = edd_plot_grouped_histrees_core,
                   raw_data = raw_data,
                   sample_rep = sample_rep,
                   name = name,
                   which = which,
                   drop_extinct = drop_extinct,
                   save_plot = save_plot,
                   path = path)

  if (save_plot == FALSE) figures else invisible(NULL)
}


# Build the grouped history-tree figure for a single rate combination.
#
# `rates` is one row of the grid built in edd_plot_grouped_histrees():
# element 1 is lambda, element 2 is mu and element 3 is beta_n. The sampled
# replicates in `sample_rep` are restricted to these rates, and PD replicates
# with offset "none" are dropped. Each remaining replicate is drawn as a
# history tree with the LTT curve of its parameter set stacked underneath.
# The resulting grid is followed by one summary-statistic panel per entry of
# `name`.
#
# Saves the figure when save_plot is TRUE, otherwise returns it.
edd_plot_grouped_histrees_core <- function(rates, raw_data, sample_rep, name, which, drop_extinct, save_plot, path) {
  lambda_num <- rates[1]
  mu_num <- rates[2]
  beta_n_num <- rates[3]

  # Namespace filter() explicitly, as the rest of this file does: an
  # unqualified `filter` can resolve to stats::filter when dplyr's version is
  # not the one in scope.
  sample_rep <- sample_rep %>%
    dplyr::filter(lambda == lambda_num, mu == mu_num, beta_n == beta_n_num) %>%
    dplyr::filter(!(metric == "pd" & offset == "none"))

  tally <- tally_by_group(sample_rep, "metric")
  indexes <- create_indexes_by_group(tally, unlist = TRUE)
  plots <- list()
  j <- 1
  for (i in indexes) {
    plots[[j]] <- edd_plot_histree(raw_data$data[[sample_rep$pars_id[i]]],
                                   rep_id = sample_rep$rep_id[i],
                                   which = which,
                                   drop_extinct = drop_extinct,
                                   save_plot = FALSE) +
      ggplot2::theme(legend.position = "none")
    # The first `tally$groups` panels are titled with their metric name
    if (j <= tally$groups) {
      plots[[j]] <- plots[[j]] +
        ggplot2::ggtitle(bquote(.(toupper(sample_rep$metric[i]))))
    }
    # Panels j = 1, 1 + groups, 1 + 2 * groups, ... carry the beta_phi label
    # (presumably the first panel of each row of the grid)
    if ((j + tally$groups - 1) %% tally$groups == 0) {
      plots[[j]] <- plots[[j]] +
        ggplot2::ylab(bquote(italic(β)[italic(Φ)] ~ "=" ~ .(sample_rep$beta_phi[i])))
    }
    # Stack the tree above the LTT curve of the same parameter set. The
    # label/theme calls after edd_plot_ltt() apply to the LTT panel, because
    # patchwork applies `+` to the last added plot.
    plots[[j]] <- plots[[j]] + ggplot2::theme(axis.text.x = ggplot2::element_blank(),
                                              axis.line.x = ggplot2::element_blank(),
                                              axis.ticks.x = ggplot2::element_blank(),
                                              axis.line.y = ggplot2::element_blank()) +
      edd_plot_ltt(raw_data$data[[sample_rep$pars_id[i]]],
                   save_plot = FALSE,
                   annotation = FALSE,
                   ribbon = TRUE) +
      ggplot2::ylab(NULL) +
      ggplot2::xlab(NULL) +
      ggplot2::ggtitle(NULL) +
      ggplot2::theme(aspect.ratio = NULL,
                     panel.background = ggplot2::element_blank(),
                     panel.grid = ggplot2::element_blank()) +
      patchwork::plot_layout(nrow = 2)

    j <- j + 1
  }

  plot_grouped_histrees <- patchwork::wrap_plots(plots, nrow = tally$rows)

  # Tree grid first, then one statistic panel per requested statistic
  plots <- list()
  plots[[1]] <- plot_grouped_histrees

  j <- 2
  for (i in name) {
    plots[[j]] <- edd_plot_stats_single(raw_data, rates = rates, name = i, save_plot = FALSE)
    j <- j + 1
  }

  title_str <- "Phylogenetic patterns ("

  if (mu_num == 0) {
    title_str <- paste0(title_str, "pure birth, ")
  }

  if (beta_n_num == 0) {
    title_str <- paste0(title_str, "no species richness effect)")
  } else {
    title_str <- paste0(title_str, "with species richness effect)")
  }

  plot_final <- patchwork::wrap_plots(plots, nrow = 1, widths = c(8, rep(1.1, length(plots) - 1)), guides = "collect") +
    patchwork::plot_annotation(title = title_str,
                               subtitle = bquote(italic(λ)[0] ~ "=" ~ .(sample_rep$lambda[1]) ~ italic(μ)[0] ~ "=" ~ .(sample_rep$mu[1]) ~ italic(β)[italic(N)] ~ "=" ~ .(sample_rep$beta_n[1])))

  if (save_plot == TRUE) {
    save_with_rates_and_index_name(rates = rates,
                                   plot = plot_final,
                                   which = paste0("grouped_histrees_", paste(name, collapse = "_")),
                                   path = path,
                                   device = "png",
                                   width = 9 + length(name) * 1.1,
                                   height = 15,
                                   dpi = "retina")
  } else {
    return(plot_final)
  }
}


# Boxplots of a single summary statistic, one plot per value of `by`.
#
# Computes (cached) tree statistics for raw_data$data with `method`, reshapes
# them to long format and calls edd_plot_stats_grouped_single_core() for every
# level of the `by` rate ("lambda" or "mu"). PD data are restricted to
# `offset` (default "Simulation time", the value used before this parameter
# was added). Returns the list of plots unless save_plot is TRUE.
edd_plot_stats_grouped_single <- function(raw_data = NULL, method = "treestats", by = "lambda", name = NULL, save_plot = FALSE, path = NULL, offset = "Simulation time") {
  if (length(name) > 1) {
    stop("Only one statistic can be specified")
  }

  if (!is.character(name)) {
    stop("Statistic name must be a character")
  }

  # An empty or missing name would otherwise fail much later with an
  # unrelated error
  if (length(name) == 0 || is.na(name)) {
    stop("Statistic name must be a single, non-missing character string")
  }

  # Check the length first so a vector `by` gives this message instead of
  # R's "the condition has length > 1" error
  if (length(by) != 1 || !(by %in% c("lambda", "mu"))) {
    stop("by must be one of 'lambda' or 'mu'")
  }

  stats <- edd_stat_cached(raw_data$data, method = method)

  check_stats_names(raw_data, stats, name)

  stats <- tidyr::pivot_longer(stats, cols = -(lambda:offset), names_to = "stats", values_to = "value")
  stats <- transform_data(stats)

  if (by == "lambda") {
    rates <- levels(stats$lambda)
  } else {
    rates <- levels(stats$mu)
  }

  plot_stats <- lapply(rates, edd_plot_stats_grouped_single_core,
                       stats = stats, by = by, name = name,
                       params = raw_data$params, offset = offset,
                       save_plot = save_plot, path = path)

  if (save_plot != TRUE) {
    return(plot_stats)
  }
}

# Boxplots of one summary statistic across beta_phi for a single rate value.
#
# `rates[1]` is the value of the `by` rate ("lambda" or "mu") to keep. The
# plot is faceted with beta_n on the rows and the other rate on the columns.
# PD data are restricted to `offset`; ED and NND data are used in full. Saves
# the plot when save_plot is TRUE, otherwise returns it.
edd_plot_stats_grouped_single_core <- function(rates, stats, by, name, params, offset = NULL, save_plot = FALSE, path = NULL) {
  rate_num <- rates[1]
  rate_char <- as.character(rate_num)

  if (by == "lambda") {
    greek <- "λ"
    other <- "mu"
    other_greek <- "μ"
  } else {
    greek <- "μ"
    other <- "lambda"
    other_greek <- "λ"
  }

  offset_char <- as.character(offset)

  plot_data_pd <- dplyr::filter(stats,
                                !!as.symbol(by) == rate_num &
                                  metric == "pd" &
                                  offset == offset_char)

  plot_data_ed <- dplyr::filter(stats,
                                !!as.symbol(by) == rate_num &
                                  metric == "ed")

  plot_data_nnd <- dplyr::filter(stats,
                                 !!as.symbol(by) == rate_num &
                                   metric == "nnd")

  plot_data <- rbind(plot_data_pd, plot_data_ed, plot_data_nnd)

  plot_data_stats <- dplyr::filter(plot_data, stats == name)

  # Y limits come from the boxplot whiskers because outliers are hidden;
  # J_One is capped at 1
  sts <- plot_data_stats %>% dplyr::group_by(metric) %>%
    dplyr::reframe(bp = boxplot.stats(value)$stats) %>% dplyr::select(-metric)

  if (name == "J_One") {
    coord_lim <- c(min(sts), 1)
  } else {
    coord_lim <- c(min(sts), max(sts))
  }

  # The two `by` cases differ only in which rate is faceted on the columns,
  # so the labeller for that column variable is keyed dynamically instead of
  # duplicating the whole plot definition
  facet_labellers <- list(
    beta_n = as_labeller(~paste0("italic(β)[italic(N)]:", .x), label_parsed)
  )
  facet_labellers[[other]] <- as_labeller(~paste0("italic(", other_greek, ")", "[0]:", .x), label_parsed)

  stats_plot <- ggplot2::ggplot(plot_data_stats) +
    ggplot2::geom_boxplot(ggplot2::aes(beta_phi, value, fill = metric), outlier.shape = NA) +
    ggplot2::facet_grid(reformulate(other, "beta_n"),
                        labeller = do.call(labeller, facet_labellers)) +
    ggplot2::xlab(bquote(β[italic(Φ)])) +
    ggplot2::ylab(name) +
    ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
    ggplot2::coord_cartesian(ylim = coord_lim) +
    ggplot2::theme(strip.background = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   panel.grid = ggplot2::element_blank())

  final_plot <- stats_plot +
    ggplot2::ggtitle(bquote(.(name) ~ ~ italic(.(greek))[0] ~ "=" ~ .(rate_char))) +
    ggplot2::labs(fill = "Metric")

  if (save_plot == TRUE) {
    save_with_rates_offset(rates = rates,
                           offset = offset,
                           plot = final_plot,
                           which = paste("stat_single", by, name, sep = "_"),
                           path = path,
                           device = "png",
                           width = 10, height = 8,
                           dpi = "retina")
  } else {
    return(final_plot)
  }
}


# Build (and optionally save as a GIF) an animated boxplot of one statistic.
#
# `pairs` is one row of the grid built by edd_animate_grouped_single():
# element 1 is the variable animated over ("lambda", "mu", "beta_n" or
# "beta_phi"), element 2 is the statistic name. PD data are restricted to the
# "Simulation time" offset; ED and NND data are used in full. Saves
# <path>/plot/animate/<by>_<name>.gif when save_plot is TRUE, otherwise
# returns the animation object.
edd_animate_grouped_single_core <- function(pairs, stats = NULL, params = NULL, save_plot = FALSE, path = NULL) {
  by <- pairs[1]
  name <- pairs[2]
  offset_char <- "Simulation time"

  # Reject unknown variables up front; otherwise no branch below assigns
  # `stats_plot` and the function fails later with an unhelpful error
  if (!(by %in% c("lambda", "mu", "beta_n", "beta_phi"))) {
    stop("by must be one of 'lambda', 'mu', 'beta_n' or 'beta_phi'")
  }

  message(paste0("Now plotting ", by, " against ", name))

  name_title <- index_name_to_title(name)

  plot_data_pd <- dplyr::filter(stats,
                                metric == "pd" &
                                  offset == offset_char)

  plot_data_ed <- dplyr::filter(stats,
                                metric == "ed")

  plot_data_nnd <- dplyr::filter(stats,
                                 metric == "nnd")
  plot_data <- rbind(plot_data_pd, plot_data_ed, plot_data_nnd)

  plot_data_stats <- dplyr::filter(plot_data, stats == name)

  # Y limits come from the boxplot whiskers because outliers are hidden;
  # J_One is capped at 1
  sts <- plot_data_stats %>% dplyr::group_by(metric) %>%
    dplyr::reframe(bp = boxplot.stats(value)$stats) %>% dplyr::select(-metric)

  if (name == "J_One") {
    coord_lim <- c(min(sts), 1)
  } else {
    coord_lim <- c(min(sts), max(sts))
  }

  # Components shared by every variant of the animation
  shared_layers <- list(
    ggplot2::ylab(NULL),
    ggplot2::coord_cartesian(ylim = coord_lim),
    ggplot2::theme(strip.background = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   panel.grid = ggplot2::element_blank(),
                   plot.title = element_markdown())
  )

  if (by == "beta_phi" || by == "beta_n") {
    # Animate over one beta coefficient, faceted by lambda (rows) x mu (cols)
    rate_labeller <- labeller(lambda = as_labeller(~paste0("italic(λ)[0]:", .x), label_parsed),
                              mu = as_labeller(~paste0("italic(μ)[0]:", .x), label_parsed))
    if (by == "beta_phi") {
      stats_plot <- ggplot2::ggplot(plot_data_stats) +
        ggplot2::geom_boxplot(ggplot2::aes(beta_n, value, fill = metric), outlier.shape = NA) +
        ggplot2::facet_grid(reformulate("mu", "lambda"), labeller = rate_labeller) +
        ggplot2::xlab(bquote(italic(β)[italic(N)])) +
        ggplot2::scale_x_discrete(labels = format(unique(params$beta_n), scientific = FALSE)) +
        shared_layers +
        gganimate::transition_states(states = beta_phi, transition_length = 1, state_length = 1) +
        labs(title = paste0(name_title, " *β*<sub>*Φ*</sub> = ", "{closest_state}"))
    } else {
      stats_plot <- ggplot2::ggplot(plot_data_stats) +
        ggplot2::geom_boxplot(ggplot2::aes(beta_phi, value, fill = metric), outlier.shape = NA) +
        ggplot2::facet_grid(reformulate("mu", "lambda"), labeller = rate_labeller) +
        ggplot2::xlab(bquote(italic(β)[italic(Φ)])) +
        ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
        shared_layers +
        gganimate::transition_states(states = beta_n, transition_length = 1, state_length = 1) +
        labs(title = paste0(name_title, " *β*<sub>*N*</sub> = ", "{closest_state}"))
    }
  } else {
    # Animate over lambda or mu, faceted by beta_n (rows) x the other rate
    # (cols). Only the column labeller key differs between the two cases.
    if (by == "lambda") {
      greek <- "λ"
      other <- "mu"
      other_greek <- "μ"
    } else {
      greek <- "μ"
      other <- "lambda"
      other_greek <- "λ"
    }

    facet_labellers <- list(
      beta_n = as_labeller(~paste0("italic(β)[italic(N)]:", .x), label_parsed)
    )
    facet_labellers[[other]] <- as_labeller(~paste0("italic(", other_greek, ")", "[0]:", .x), label_parsed)

    stats_plot <- ggplot2::ggplot(plot_data_stats) +
      ggplot2::geom_boxplot(ggplot2::aes(beta_phi, value, fill = metric), outlier.shape = NA) +
      ggplot2::facet_grid(reformulate(other, "beta_n"),
                          labeller = do.call(labeller, facet_labellers)) +
      ggplot2::xlab(bquote(italic(β)[italic(Φ)])) +
      ggplot2::scale_x_discrete(labels = format(unique(params$beta_phi), scientific = FALSE)) +
      shared_layers +
      gganimate::transition_states(states = !!as.symbol(by), transition_length = 1, state_length = 1) +
      labs(title = paste0(name_title, " *", greek, "*<sub>0</sub> = ", "{closest_state}"))
  }

  # Named labels are matched to the metric values, so the legend stays
  # correct whatever order the metric levels come in (positional labels
  # would mislabel e.g. alphabetically ordered "ed", "nnd", "pd")
  stats_plot <- stats_plot +
    scale_fill_discrete(name = "Scenario", labels = c(pd = "PD", ed = "ED", nnd = "NND"))

  if (save_plot == TRUE) {
    save_path <- file.path(path, "plot", "animate")
    check_path(save_path)
    gganimate::anim_save(filename = paste0(by, "_", name, ".gif"),
                         animation = stats_plot,
                         path = save_path,
                         width = 10, height = 8, units = "in", res = 300)
  } else {
    return(stats_plot)
  }
}


# Animate every combination of rate parameter and summary statistic.
#
# Computes (cached) tree statistics for raw_data$data with `method`, reshapes
# them to long format and calls edd_animate_grouped_single_core() once per
# (by, name) pair, where `by` runs over lambda, mu, beta_n and beta_phi. When
# `name` is NULL a default set of statistics is animated. Returns the list of
# animations unless save_plot is TRUE.
edd_animate_grouped_single <- function(raw_data = NULL, name = NULL, method = "treestats", save_plot = FALSE, path = NULL) {
  tree_stats <- edd_stat_cached(raw_data$data, method = method)

  stat_names <- if (is.null(name)) c("J_One", "Gamma", "PD", "MBL", "MNTD") else name

  check_stats_names(raw_data, tree_stats, stat_names)

  long_stats <- transform_data(
    tidyr::pivot_longer(tree_stats,
                        cols = -(lambda:offset),
                        names_to = "stats",
                        values_to = "value")
  )

  anim_grid <- expand.grid(by = c("lambda", "mu", "beta_n", "beta_phi"),
                           name = stat_names)

  animations <- apply(anim_grid, MARGIN = 1, FUN = edd_animate_grouped_single_core,
                      stats = long_stats,
                      params = raw_data$params,
                      save_plot = save_plot,
                      path = path)

  if (save_plot != TRUE) animations else invisible(NULL)
}


# Animate phylogenetic evenness ("ERE") against each rate parameter.
#
# Computes evenness for raw_data$data with edd_phylogenetic_evenness() using
# the given parallel strategy/workers, reshapes it to long format and calls
# edd_animate_grouped_single_core() for lambda, mu, beta_n and beta_phi in
# turn. Returns the animations unless save_plot is TRUE.
edd_animate_phylogenetic_evenness <- function(raw_data = NULL, save_plot = FALSE, path = NULL, strategy = "sequential",
                                              workers = 1, verbose = TRUE) {
  evenness <- edd_phylogenetic_evenness(raw_data$data,
                                        strategy = strategy,
                                        workers = workers,
                                        verbose = verbose)

  evenness_long <- transform_data(
    tidyr::pivot_longer(evenness,
                        cols = -(lambda:offset),
                        names_to = "stats",
                        values_to = "value")
  )

  anim_grid <- expand.grid(by = c("lambda", "mu", "beta_n", "beta_phi"),
                           name = "ERE")

  animations <- apply(anim_grid, MARGIN = 1, FUN = edd_animate_grouped_single_core,
                      stats = evenness_long,
                      params = raw_data$params,
                      save_plot = save_plot,
                      path = path)

  if (save_plot != TRUE) animations else invisible(NULL)
}
EvoLandEco/eve documentation built on Sept. 14, 2024, 12:04 a.m.