# R/plots.R

# Defines functions: plot_timetrend, plot.strata_doctopic, plot_topicprop, plot_pi, plot_modelfit, plot_alpha, print.keyATM_fig, save_fig.keyATM_fig, values_fig.keyATM_fig, values_fig, save_fig

# Documented in: plot_alpha, plot_modelfit, plot_pi, plot.strata_doctopic, plot_timetrend, plot_topicprop, save_fig, values_fig

#' Save a figure
#'
#' S3 generic: the actual work is done by the method registered for the class
#' of `x`.
#'
#' @param x the keyATM_fig object.
#' @param filename file name to create on disk.
#' @param ... other arguments passed on to the [ggplot2::ggsave()][ggplot2::ggsave] function.
#' @return The value returned by the dispatched method.
#' @seealso [visualize_keywords()], [plot_alpha()], [plot_modelfit()], [plot_pi()], [plot_timetrend()], [plot_topicprop()], [by_strata_DocTopic()], [values_fig()]
#' @export
save_fig <- function(x, filename, ...) {
  # Dispatch on the class of the first argument.
  UseMethod("save_fig", x)
}


#' Get values used to create a figure
#'
#' S3 generic: the actual work is done by the method registered for the class
#' of `x`.
#'
#' @param x the keyATM_fig object.
#' @return The value returned by the dispatched method.
#' @seealso [save_fig()], [visualize_keywords()], [plot_alpha()], [plot_modelfit()], [plot_pi()], [plot_timetrend()], [plot_topicprop()], [by_strata_DocTopic()]
#' @export
values_fig <- function(x) {
  # Dispatch on the class of the first argument.
  UseMethod("values_fig", x)
}

#' @noRd
#' @export
values_fig.keyATM_fig <- function(x) {
  # The data behind the figure is kept in the `values` element of the object.
  x$values
}


#' @noRd
#' @export
save_fig.keyATM_fig <- function(x, filename, ...) {
  # Write the plot stored in the `figure` element to `filename`; any extra
  # arguments are forwarded to ggsave() unchanged.
  fig <- x$figure
  ggplot2::ggsave(filename = filename, plot = fig, ...)
}


#' @noRd
#' @export
print.keyATM_fig <- function(x, ...) {
  # Printing the wrapper prints the plot stored in its `figure` element.
  fig <- x$figure
  print(fig)
}


#' Show a diagnosis plot of alpha
#'
#' Plots the stored values of `alpha` against the iteration number, with one
#' panel per keyword topic. For models with states (`hmm`), each state is drawn
#' as its own coloured line within a panel.
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param start integer. The start of slice iteration. Default is \code{0}.
#' @param show_topic a vector to specify topic indexes to show. Default is \code{NULL} (all keyword topics).
#' @param scales character. Control the scale of y-axis (the parameter in [ggplot2::facet_wrap()][ggplot2::facet_wrap]): \code{free} adjusts y-axis for parameters. Default is \code{fixed}.
#' @return keyATM_fig object
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_alpha <- function(x, start = 0, show_topic = NULL, scales = "fixed") {
  check_arg_type(x, "keyATM_output")
  modelname <- extract_full_model_name(x)

  if (modelname %in% c("lda", "ldacov", "ldahmm")) {
    cli::cli_abort("This is not a model with keywords.") # only plot keywords later
  }
  if (!"alpha_iter" %in% names(x$values_iter)) {
    cli::cli_abort(
      "`alpha` is not stored. Please check the options.\nNote that the covariate model does not have `alpha`.\nPlease check our paper for details."
    )
  }

  keyword_ids <- seq_len(x$keyword_k)
  if (is.null(show_topic)) {
    show_topic <- keyword_ids
  } else if (!all(show_topic %in% keyword_ids)) {
    cli::cli_abort(
      "Topics specified in `show_topic` are not the keyword topics."
    )
  }
  # `||` because this is a scalar condition; an NA `start` cannot be compared
  # with the iteration numbers, so it is rejected as well.
  if (!is.numeric(start) || length(start) != 1 || is.na(start)) {
    cli::cli_abort("`start` argument is invalid.")
  }

  # rename() takes `new_name = old_name` pairs: after pivot_wider() the columns
  # are named by topic index and are relabelled with the keyword-topic names.
  tnames <- as.character(show_topic)
  names(tnames) <- names(x$keywords_raw)[show_topic]
  temp <- x$values_iter$alpha_iter %>%
    dplyr::filter(.data$Iteration >= start) %>%
    dplyr::filter(.data$Topic %in% (!!show_topic)) %>%
    tidyr::pivot_wider(names_from = "Topic", values_from = "alpha")

  if (modelname %in% c("base", "lda")) {
    res_alpha <- temp %>%
      dplyr::rename(tidyselect::all_of(tnames)) %>%
      tidyr::pivot_longer(-"Iteration", names_to = "Topic", values_to = "alpha")

    p <- ggplot(
      res_alpha,
      aes(x = .data$Iteration, y = .data$alpha, group = .data$Topic)
    )
  } else if (modelname %in% c("hmm", "ldahmm")) {
    res_alpha <- temp %>%
      dplyr::rename(tidyselect::all_of(tnames)) %>%
      tidyr::pivot_longer(
        -c("Iteration", "State"),
        names_to = "Topic",
        values_to = "alpha"
      )
    res_alpha$State <- factor(
      res_alpha$State,
      levels = seq_len(max(res_alpha$State))
    )

    p <- ggplot(
      res_alpha,
      aes(
        x = .data$Iteration,
        y = .data$alpha,
        group = .data$State,
        colour = .data$State
      )
    )
  } else {
    # Without this branch an unhandled model name would fall through and fail
    # later with "object 'p' not found".
    cli::cli_abort("`plot_alpha` does not support this model.")
  }

  # Layers shared by both kinds of model.
  p <- p +
    geom_line() +
    geom_point(size = 0.3) +
    facet_wrap(~ .data$Topic, ncol = 2, scales = scales) +
    xlab("Iteration") +
    ylab("Value") +
    ggtitle("Estimated alpha") +
    theme_bw() +
    theme(plot.title = element_text(hjust = 0.5))

  p <- list(figure = p, values = res_alpha)
  class(p) <- c("keyATM_fig", class(p))
  return(p)
}


#' Show a diagnosis plot of log-likelihood and perplexity
#'
#' Plots every measure stored in the `model_fit` element of `x` against the
#' iteration number, one panel per measure with free axis scales.
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param start integer. The starting value of iteration to use in plot. Default is \code{1}.
#' @return keyATM_fig object.
#' @import ggplot2
#' @importFrom stats as.formula
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_modelfit <- function(x, start = 1) {
  check_arg_type(x, "keyATM_output")

  # `||` because this is a scalar condition. NULL fails `is.numeric()`, so no
  # separate NULL check is needed afterwards; NA is rejected because it would
  # turn the row filter below into NA rows.
  if (!is.numeric(start) || length(start) != 1 || is.na(start)) {
    cli::cli_abort("`start` argument is invalid.")
  }

  modelfit <- x$model_fit
  modelfit <- modelfit[modelfit$Iteration >= start, ]

  # Long format: one row per (Iteration, Measures) pair. gather() is kept so
  # the rows exposed through the `values` element stay in their current order.
  modelfit <- tidyr::gather(
    modelfit,
    key = "Measures",
    value = "value",
    -"Iteration"
  )

  p <- ggplot(
    data = modelfit,
    aes(
      x = .data$Iteration,
      y = .data$value,
      group = .data$Measures,
      color = .data$Measures
    )
  ) +
    geom_line(show.legend = FALSE) +
    geom_point(size = 0.3, show.legend = FALSE) +
    facet_wrap(~ .data$Measures, ncol = 2, scales = "free") +
    xlab("Iteration") +
    ylab("Value") +
    ggtitle("Model Fit") +
    theme_bw() +
    theme(plot.title = element_text(hjust = 0.5))

  p <- list(figure = p, values = modelfit)
  class(p) <- c("keyATM_fig", class(p))
  return(p)
}


#' Show a diagnosis plot of pi
#'
#' When `pi` was stored over iterations, plots a point estimate with a credible
#' interval for each keyword topic. Otherwise plots a bar chart of the values
#' in the `pi` element of `x`.
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param show_topic an integer or a vector. Indicate topics to visualize. Default is \code{NULL} (all keyword topics).
#' @param start integer. The starting value of iteration to use in the plot. Default is \code{0}.
#' @param ci value of the credible interval (between 0 and 1) to be estimated. Default is \code{0.9} (90%). This is an option when calculating credible intervals (you need to set \code{store_pi = TRUE} in [keyATM()]).
#' @param method method for computing the credible interval. The Highest Density Interval (\code{hdi}, default) or Equal-tailed Interval (\code{eti}). This is an option when calculating credible intervals (you need to set \code{store_pi = TRUE} in [keyATM()]).
#' @param point method for computing the point estimate. \code{mean} (default) or \code{median}. This is an option when calculating credible intervals (you need to set \code{store_pi = TRUE} in [keyATM()]).
#' @return keyATM_fig object.
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_pi <- function(
  x,
  show_topic = NULL,
  start = 0,
  ci = 0.9,
  method = c("hdi", "eti"),
  point = c("mean", "median")
) {
  method <- rlang::arg_match(method)
  point <- rlang::arg_match(point)
  check_arg_type(x, "keyATM_output")
  modelname <- extract_full_model_name(x)

  if (modelname %in% c("lda", "ldacov", "ldahmm")) {
    cli::cli_abort("This is not a model with keywords.")
  }

  keyword_ids <- seq_len(x$keyword_k)
  if (is.null(show_topic)) {
    show_topic <- keyword_ids
  } else if (!all(show_topic %in% keyword_ids)) {
    cli::cli_abort("`plot_pi` only visualize keyword topics.")
  }

  # `||` because this is a scalar condition; NA cannot be compared with the
  # iteration numbers, so it is rejected as well.
  if (!is.numeric(start) || length(start) != 1 || is.na(start)) {
    cli::cli_abort("`start` argument is invalid.")
  }

  keyword_names <- names(x$keywords_raw)
  tnames <- keyword_names[show_topic]

  if (!is.null(x$values_iter$pi_iter)) {
    # One row per stored iteration, one column per keyword topic. Binding the
    # rows explicitly keeps that shape even with a single keyword topic, where
    # t(sapply(...)) would collapse to a 1 x n_iter matrix.
    pi_all <- do.call(
      rbind,
      lapply(x$values_iter$pi_iter, unlist, use.names = FALSE)
    )
    pi_mat <- pi_all[, show_topic, drop = FALSE] %>%
      tibble::as_tibble(.name_repair = ~tnames) %>%
      dplyr::mutate(Iteration = x$values_iter$used_iter) %>%
      dplyr::filter(.data$Iteration >= start) %>%
      dplyr::select(-tidyselect::all_of("Iteration"))

    if (nrow(pi_mat) == 0) {
      cli::cli_abort("Nothing left to plot. Please check arguments.")
    }

    # Summarise each topic into the named values returned by calc_ci(), then
    # spread them into one column per name.
    temp <- pi_mat %>%
      tidyr::pivot_longer(cols = dplyr::everything(), names_to = "Topic") %>%
      dplyr::group_by(.data$Topic) %>%
      dplyr::summarise(
        x = list(tibble::enframe(
          calc_ci(.data$value, ci, method, point),
          "q",
          "value"
        )),
        .groups = "drop_last"
      ) %>%
      # Select the list-column by name: a bare `x` is ambiguous with the
      # function argument of the same name.
      tidyr::unnest(tidyselect::all_of("x")) %>%
      tidyr::pivot_wider(
        names_from = tidyselect::all_of("q"),
        values_from = tidyselect::all_of("value")
      )

    p <- ggplot(temp, aes(y = .data$Point, x = .data$Topic)) +
      theme_bw() +
      geom_point() +
      geom_errorbar(
        aes(ymin = .data$Lower, ymax = .data$Upper),
        data = temp,
        width = 0.01,
        linewidth = 1
      ) +
      xlab("Topic") +
      ylab("Probability") +
      ggtitle(
        "Probability of words drawn from keyword topic-word distribution"
      ) +
      theme(plot.title = element_text(hjust = 0.5))
  } else {
    cli::cli_alert_info(
      "Plotting pi from the final MCMC draw. Please set `store_pi` to `TRUE` if you want to plot pi over iterations."
    )
    # Label each remaining row through its own topic id rather than by
    # position: filter() keeps the table's row order, which need not match the
    # order of `show_topic`.
    temp <- x$pi %>%
      dplyr::mutate(Probability = .data$Proportion / 100) %>%
      dplyr::filter(.data$Topic %in% (!!show_topic)) %>%
      dplyr::mutate(
        Topic = (!!keyword_names)[match(.data$Topic, (!!keyword_ids))]
      )

    p <- ggplot(temp, aes(x = .data$Topic, y = .data$Probability)) +
      geom_bar(stat = "identity") +
      theme_bw() +
      xlab("Topic") +
      ylab("Probability") +
      ggtitle(
        "Probability of words drawn from keyword topic-word distribution"
      ) +
      theme(plot.title = element_text(hjust = 0.5))
  }
  p <- list(figure = p, values = temp)
  class(p) <- c("keyATM_fig", class(p))
  return(p)
}


#' Show the expected proportion of the corpus belonging to each topic
#'
#' Draws a horizontal bar per topic whose length is the column mean of the
#' `theta` element of `x`, optionally annotated with the topic's top words.
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param n The number of top words to show. Default is \code{3}.
#' @param show_topic an integer or a vector. Indicate topics to visualize. Default is \code{NULL} (all topics).
#' @param show_topwords logical. Show topwords. The default is \code{TRUE}.
#' @param label_topic a character vector. The name of the topics in the plot. Its length must equal the number of topics shown.
#' @param order The order of topics: \code{proportion} (default) sorts the topics by their proportion, \code{topicid} keeps the order of \code{show_topic}.
#' @param xmax a numeric. Indicate the max value on the x axis
#' @return keyATM_fig object
#' @import magrittr
#' @import ggplot2
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_topicprop <- function(
  x,
  n = 3,
  show_topic = NULL,
  show_topwords = TRUE,
  label_topic = NULL,
  order = c("proportion", "topicid"),
  xmax = NULL
) {
  check_arg_type(x, "keyATM_output")
  order <- rlang::arg_match(order)

  total_k <- x$keyword_k + x$no_keyword_topics
  if (is.null(show_topic)) {
    show_topic <- seq_len(total_k)
  } else if (max(show_topic) > total_k || min(show_topic) < 1) {
    cli::cli_abort("Invalid topic ID in `show_topic`.")
  }

  # drop = FALSE keeps the tabular shape when a single topic is requested;
  # without it the result collapses to a plain vector and the steps below fail.
  topwords <- top_words(x, n = n)[, show_topic, drop = FALSE]

  if (!is.null(label_topic)) {
    if (length(label_topic) != ncol(topwords)) {
      cli::cli_abort("The length of `label_topic` is incorrect.")
    }
    colnames(topwords) <- label_topic
  }

  # One row per topic holding its top words as a comma-separated string.
  topwords_commas <- topwords %>%
    dplyr::summarise(dplyr::across(
      dplyr::everything(),
      ~ stringr::str_c(.x, collapse = ", ")
    )) %>%
    tidyr::pivot_longer(
      dplyr::everything(),
      values_to = "Topwords",
      names_to = "Topic"
    )

  theta_use <- x$theta[, show_topic, drop = FALSE]
  if (!is.null(label_topic)) {
    colnames(theta_use) <- label_topic
  }

  # Mean of each selected theta column, joined with the top words.
  theta_use_tbl <- theta_use %>%
    tibble::as_tibble() %>%
    dplyr::summarise(dplyr::across(dplyr::everything(), ~ mean(.x))) %>%
    tidyr::pivot_longer(
      dplyr::everything(),
      values_to = "Topicprop",
      names_to = "Topic"
    ) %>%
    dplyr::left_join(topwords_commas, by = "Topic")

  if (order == "proportion") {
    # The leading "<digits>_" prefix of the topic name is stripped in this
    # ordering only.
    theta_use_tbl <- theta_use_tbl %>%
      dplyr::mutate(
        Topic = stringr::str_remove(.data$Topic, "^\\d+_")
      )
    use_order <- theta_use_tbl %>%
      dplyr::arrange(dplyr::desc(.data$Topicprop)) %>%
      dplyr::pull("Topic")
  } else {
    use_order <- dplyr::pull(theta_use_tbl, "Topic")
  }

  # Levels are reversed so the first topic in `use_order` is drawn at the top
  # of the y axis; `xpos` is where the top-words text starts.
  plot_obj <- theta_use_tbl %>%
    dplyr::mutate(
      Topic = factor(.data$Topic, levels = rev(use_order)),
      xpos = max(.data$Topicprop) + 0.01
    ) %>%
    dplyr::arrange(dplyr::desc(.data$Topic))

  if (is.null(xmax)) {
    if (show_topwords) {
      # Leave room to the right of the bars for the text, capped at 1.
      xmax <- min(max(plot_obj$Topicprop) * 2.5, 1)
    } else {
      xmax <- max(plot_obj$Topicprop) + 0.02
    }
  }

  label_percent <- function(x) {
    paste0(round(x * 100, 2), "%")
  }

  p <- ggplot(plot_obj, aes(x = .data$Topicprop, y = .data$Topic)) +
    geom_col() +
    {
      # Evaluates to NULL when top words are hidden; adding NULL to a ggplot
      # leaves it unchanged.
      if (show_topwords) {
        geom_text(
          aes(x = .data$xpos, y = .data$Topic, label = .data$Topwords),
          hjust = 0,
          size = max(10 / n + 1, 2.5)
        )
      }
    } +
    scale_x_continuous(
      "Expected topic proportions",
      limits = c(0, xmax),
      labels = label_percent
    ) +
    theme_bw() +
    theme(
      panel.grid.major.x = element_blank(),
      panel.grid.minor.x = element_blank(),
      panel.grid.major.y = element_blank(),
      panel.grid.minor.y = element_blank()
    )

  p <- list(figure = p, values = plot_obj)
  class(p) <- c("keyATM_fig", class(p))
  return(p)
}


#' Plot document-topic distribution by strata (for covariate models)
#'
#' Draws, for every stratum and topic, the interval (and optionally the point
#' estimate) taken from the summary tables of `x`.
#'
#' @param x a strata_doctopic object (see [by_strata_DocTopic()]).
#' @param show_topic a vector or an integer. Indicate topics to visualize. Default is \code{NULL} (all topics in the first summary table).
#' @param var_name the name of the variable in the plot. Defaults to the `by_var` element of `x`.
#' @param by `topic` or `covariate`. Default is by `topic`.
#' @param ci value of the credible interval (between 0 and 1) to be estimated. Default is \code{0.9} (90%).
#' @param method method for computing the credible interval. The Highest Density Interval (\code{hdi}, default) or Equal-tailed Interval (\code{eti}).
#' @param point method for computing the point estimate. \code{mean} (default) or \code{median}.
#' @param width numeric. Width of the error bars.
#' @param show_point logical. Show point estimates. The default is \code{TRUE}.
#' @param ... additional arguments not used.
#' @return keyATM_fig object.
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()], [by_strata_DocTopic()]
#' @export
plot.strata_doctopic <- function(
  x,
  show_topic = NULL,
  var_name = NULL,
  by = c("topic", "covariate"),
  ci = 0.9,
  method = c("hdi", "eti"),
  point = c("mean", "median"),
  width = 0.1,
  show_point = TRUE,
  ...
) {
  by <- rlang::arg_match(by)
  method <- rlang::arg_match(method)
  point <- rlang::arg_match(point)

  tables <- summary.strata_doctopic(x, ci, method, point)
  by_var <- x$by_var
  if (!is.null(var_name)) {
    by_var <- var_name
  }

  if (is.null(show_topic)) {
    show_topic <- seq_len(nrow(tables[[1]]))
  }

  tables <- dplyr::bind_rows(tables)
  tnames <- unique(tables$Topic)
  num_keytopic <- sum(!grepl("Other_[0-9]+", tnames))
  # Map a topic name to a numeric id: names matching "Other_<k>" are numbered
  # after the remaining topics, all other names use their leading number.
  topic_parse <- function(s) {
    if (grepl("Other_[0-9]+", s)) {
      return(as.numeric(strsplit(s, "_")[[1]][2]) + num_keytopic)
    } else {
      return(as.numeric(strsplit(s, "_")[[1]][1]))
    }
  }
  tables$TopicID <- purrr::map_dbl(tables$Topic, topic_parse)
  # Factor levels follow the numeric ids so panels and legends are ordered by
  # id rather than alphabetically.
  tables$Topic <- factor(
    tables$Topic,
    levels = tnames[order(purrr::map_dbl(tnames, topic_parse))]
  )
  tables <- tables %>% dplyr::filter(.data$TopicID %in% show_topic)

  variables <- unique(tables$label)

  if (point == "mean") {
    ylabel <- expression(paste("Mean of ", theta))
  } else {
    ylabel <- expression(paste("Median of ", theta))
  }

  # The axes are flipped, so the strata labels run down the vertical axis;
  # reversing the limits puts the first label at the top.
  p <- ggplot(tables) +
    coord_flip() +
    scale_x_discrete(limits = rev(variables)) +
    xlab(paste0(by_var)) +
    ylab(ylabel) +
    guides(color = guide_legend(title = "Topic")) +
    theme_bw()

  if (by == "topic") {
    # One panel per topic.
    p <- p +
      geom_errorbar(
        width = width,
        aes(
          x = .data$label,
          ymin = .data$Lower,
          ymax = .data$Upper,
          group = .data$Topic
        ),
        position = position_dodge(width = -1 / 2)
      ) +
      facet_wrap(~Topic)
    if (show_point) {
      p <- p + geom_point(aes(x = .data$label, y = .data$Point))
    }
  } else {
    # Single panel; topics are told apart by colour and dodged side by side.
    p <- p +
      geom_errorbar(
        width = width,
        aes(
          x = .data$label,
          ymin = .data$Lower,
          ymax = .data$Upper,
          group = .data$Topic,
          colour = .data$Topic
        ),
        position = position_dodge(width = -1 / 2)
      )
    if (show_point) {
      p <- p +
        geom_point(
          aes(x = .data$label, y = .data$Point, colour = .data$Topic),
          position = position_dodge(width = -1 / 2)
        )
    }
  }

  p <- list(figure = p, values = tables)
  class(p) <- c("keyATM_fig", class(p))
  return(p)
}


#' Plot time trend
#'
#' Plots, for each selected topic, the per-time-index mean of theta. When theta
#' was stored over iterations, the line is a point estimate across iterations
#' and a ribbon shows the credible interval.
#'
#' @param x the output from the dynamic keyATM model (see [keyATM()]).
#' @param show_topic an integer or a vector. Indicate topics to visualize. Default is \code{NULL} (all keyword topics, or all topics when the model has no keyword topics).
#' @param time_index_label a vector. The label for time index. The length should be equal to the number of documents (time index provided to [keyATM()]).
#' @param ci value of the credible interval (between 0 and 1) to be estimated. Default is \code{0.9} (90%). This is an option when calculating credible intervals (you need to set \code{store_theta = TRUE} in [keyATM()]).
#' @param method method for computing the credible interval. The Highest Density Interval (\code{hdi}, default) or Equal-tailed Interval (\code{eti}). This is an option when calculating credible intervals (you need to set \code{store_theta = TRUE} in [keyATM()]).
#' @param point method for computing the point estimate. \code{mean} (default) or \code{median}. This is an option when calculating credible intervals (you need to set \code{store_theta = TRUE} in [keyATM()]).
#' @param xlab a character. Label of the x axis.
#' @param scales character. Control the scale of y-axis (the parameter in [ggplot2::facet_wrap()][ggplot2::facet_wrap]): \code{free} adjusts y-axis for parameters. Default is \code{fixed}.
#' @param show_point logical. The default is \code{TRUE}. This is an option when calculating credible intervals.
#' @param ... additional arguments not used.
#' @return keyATM_fig object.
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_timetrend <- function(
  x,
  show_topic = NULL,
  time_index_label = NULL,
  ci = 0.9,
  method = c("hdi", "eti"),
  point = c("mean", "median"),
  xlab = "Time",
  scales = "fixed",
  show_point = TRUE,
  ...
) {
  method <- rlang::arg_match(method)
  point <- rlang::arg_match(point)
  check_arg_type(x, "keyATM_output")
  modelname <- extract_full_model_name(x)
  if (!modelname %in% c("hmm", "ldahmm")) {
    cli::cli_abort("This is not a model with time trends.")
  }

  if (!is.null(time_index_label)) {
    if (length(x$values_iter$time_index) != length(time_index_label)) {
      cli::cli_abort(
        "The length of `time_index_label` does not match with the number of documents."
      )
    }
    time_index <- time_index_label
  } else {
    time_index <- x$values_iter$time_index
  }
  # Lookup from the plotted time index (possibly a label) back to the raw one.
  time_index_tbl <- tibble::tibble(
    time_index = time_index,
    time_index_raw = x$values_iter$time_index
  ) %>%
    dplyr::distinct()

  if (is.null(show_topic)) {
    # Keyword topics by default. With no keyword topics `1:0` would evaluate to
    # c(1, 0) and silently select topic 1 only, so fall back to every topic.
    n_keyword <- x$keyword_k
    if (!is.null(n_keyword) && n_keyword > 0) {
      show_topic <- seq_len(n_keyword)
    } else {
      show_topic <- seq_len(ncol(x$theta))
    }
  }

  # Mean of the selected theta columns within each time index, in long format.
  # `show_topic` is taken from the enclosing function.
  format_theta <- function(theta, time_index, tnames) {
    theta[, show_topic, drop = FALSE] %>%
      tibble::as_tibble(.name_repair = ~tnames) %>%
      dplyr::mutate(time_index = time_index) %>%
      tidyr::pivot_longer(
        -tidyselect::all_of("time_index"),
        names_to = "Topic",
        values_to = "Proportion"
      ) %>%
      dplyr::group_by(.data$time_index, .data$Topic) %>%
      dplyr::summarize(
        Proportion = base::mean(.data$Proportion),
        .groups = "drop_last"
      )
  }

  tnames <- colnames(x$theta)
  ylabel <- expression(paste("Mean of ", theta))

  if (is.null(x$values_iter$theta_iter)) {
    # Only the `theta` element is available: plot it without an interval.
    dat <- format_theta(x$theta, time_index, tnames[show_topic])
    p <- ggplot(
      dat,
      aes(x = .data$time_index, y = .data$Proportion, group = .data$Topic)
    ) +
      geom_line(linewidth = 0.8, color = "blue") +
      geom_point(size = 0.9)
  } else {
    # Summarise the per-iteration means with calc_ci() for every
    # (time index, topic) pair.
    dat <- dplyr::bind_rows(lapply(
      x$values_iter$theta_iter,
      format_theta,
      time_index,
      tnames[show_topic]
    )) %>%
      dplyr::group_by(.data$time_index, .data$Topic) %>%
      dplyr::summarise(
        x = list(tibble::enframe(
          calc_ci(.data$Proportion, ci, method, point),
          "q",
          "value"
        )),
        .groups = "drop"
      ) %>%
      tidyr::unnest(tidyselect::all_of("x")) %>%
      tidyr::pivot_wider(
        names_from = tidyselect::all_of("q"),
        values_from = tidyselect::all_of("value")
      ) %>%
      stats::setNames(c("time_index", "Topic", "Lower", "Point", "Upper"))
    p <- ggplot(
      dat,
      aes(x = .data$time_index, y = .data$Point, group = .data$Topic)
    ) +
      geom_ribbon(
        aes(ymin = .data$Lower, ymax = .data$Upper),
        fill = "gray75"
      ) +
      geom_line(linewidth = 0.8, color = "blue")

    if (show_point) {
      p <- p + geom_point(size = 0.9)
    }

    # The point estimate follows `point` in this branch, so the axis label
    # should too.
    if (point == "median") {
      ylabel <- expression(paste("Median of ", theta))
    }
  }
  p <- p +
    xlab(xlab) +
    ylab(ylabel) +
    facet_wrap(~ .data$Topic, scales = scales) +
    theme_bw() +
    theme(panel.grid.minor = element_blank())
  # Attach the raw time index and the corresponding entry of `R_iter_last` to
  # the returned values.
  dat <- dplyr::left_join(dat, time_index_tbl, by = "time_index") %>%
    dplyr::mutate(state_id = x$values_iter$R_iter_last[.data$time_index_raw])
  p <- list(figure = p, values = dat)
  class(p) <- c("keyATM_fig", class(p))
  return(p)
}

# Try the keyATM package in your browser

# Any scripts or data that you put into this service are public.

# keyATM documentation built on Aug. 8, 2025, 6:14 p.m.