Nothing
#' Save a figure
#'
#' @param x the keyATM_fig object.
#' @param filename file name to create on disk.
#' @param ... other arguments passed on to the [ggplot2::ggsave()][ggplot2::ggsave] function.
#' @seealso [visualize_keywords()], [plot_alpha()], [plot_modelfit()], [plot_pi()], [plot_timetrend()], [plot_topicprop()], [by_strata_DocTopic()], [values_fig()]
#' @export
save_fig <- function(x, filename, ...) {
UseMethod("save_fig")
}
#' Get values used to create a figure
#'
#' @param x the keyATM_fig object.
#' @seealso [save_fig()], [visualize_keywords()], [plot_alpha()], [plot_modelfit()], [plot_pi()], [plot_timetrend()], [plot_topicprop()], [by_strata_DocTopic()]
#' @export
values_fig <- function(x) {
UseMethod("values_fig")
}
#' @noRd
#' @export
values_fig.keyATM_fig <- function(x)
{
return(x$values)
}
#' @noRd
#' @export
save_fig.keyATM_fig <- function(x, filename, ...)
{
ggplot2::ggsave(filename = filename, plot = x$figure, ...)
}
#' @noRd
#' @export
print.keyATM_fig <- function(x, ...)
{
print(x$figure)
}
#' Show a diagnosis plot of alpha
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param start integer. The start of slice iteration. Default is \code{0}.
#' @param show_topic a vector to specify topic indexes to show. Default is \code{NULL}.
#' @param scales character. Control the scale of y-axis (the parameter in [ggplot2::facet_wrap()][ggplot2::facet_wrap]): \code{free} adjusts y-axis for parameters. Default is \code{fixed}.
#' @return keyATM_fig object
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_alpha <- function(x, start = 0, show_topic = NULL, scales = "fixed")
{
check_arg_type(x, "keyATM_output")
modelname <- extract_full_model_name(x)
if (modelname %in% c("lda", "ldacov", "ldahmm")) {
cli::cli_abort(paste0("This is not a model with keywords.")) # only plot keywords later
}
if (!"alpha_iter" %in% names(x$values_iter)) {
cli::cli_abort("`alpha` is not stored. Please check the options.\nNote that the covariate model does not have `alpha`.\nPlease check our paper for details.")
}
if (is.null(show_topic)) {
show_topic <- 1:x$keyword_k
} else {
if (!all(show_topic %in% 1:x$keyword_k))
cli::cli_abort("Topics specified in `show_topic` are not the keyword topics.")
}
if (!is.numeric(start) | length(start) != 1) {
cli::cli_abort("`start` argument is invalid.")
}
tnames <- as.character(show_topic)
names(tnames) <- c(names(x$keywords_raw))[show_topic]
temp <- x$values_iter$alpha_iter %>%
dplyr::filter(.data$Iteration >= start) %>%
dplyr::filter(.data$Topic %in% (!!show_topic)) %>%
tidyr::pivot_wider(names_from = "Topic", values_from = "alpha")
if (modelname %in% c("base", "lda")) {
res_alpha <- temp %>% dplyr::rename(tidyselect::all_of(tnames)) %>%
tidyr::pivot_longer(-"Iteration", names_to = "Topic", values_to = "alpha")
p <- ggplot(res_alpha, aes(x = .data$Iteration, y = .data$alpha, group = .data$Topic)) +
geom_line() +
geom_point(size = 0.3) +
facet_wrap(~ .data$Topic, ncol = 2, scales = scales) +
xlab("Iteration") + ylab("Value") +
ggtitle("Estimated alpha") + theme_bw() +
theme(plot.title = element_text(hjust = 0.5))
} else if (modelname %in% c("hmm", "ldahmm")) {
res_alpha <- temp %>% dplyr::rename(tidyselect::all_of(tnames)) %>%
tidyr::pivot_longer(-c("Iteration", "State"), names_to = "Topic", values_to = "alpha")
res_alpha$State <- factor(res_alpha$State, levels = 1:max(res_alpha$State))
p <- ggplot(res_alpha, aes(x = .data$Iteration, y = .data$alpha, group = .data$State, colour = .data$State)) +
geom_line() +
geom_point(size = 0.3) +
facet_wrap(~ .data$Topic, ncol = 2, scales = scales) +
xlab("Iteration") + ylab("Value") +
ggtitle("Estimated alpha") + theme_bw() +
theme(plot.title = element_text(hjust = 0.5))
}
p <- list(figure = p, values = res_alpha)
class(p) <- c("keyATM_fig", class(p))
return(p)
}
#' Show a diagnosis plot of log-likelihood and perplexity
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param start integer. The starting value of iteration to use in plot. Default is \code{1}.
#' @return keyATM_fig object.
#' @import ggplot2
#' @importFrom stats as.formula
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_modelfit <- function(x, start = 1)
{
check_arg_type(x, "keyATM_output")
modelfit <- x$model_fit
if (!is.numeric(start) | length(start) != 1) {
cli::cli_abort("`start` argument is invalid.")
}
if (!is.null(start)) {
modelfit <- modelfit[ modelfit$Iteration >= start, ]
}
modelfit <- tidyr::gather(modelfit, key = "Measures", value = "value", -"Iteration")
p <- ggplot(data = modelfit, aes(x = .data$Iteration, y = .data$value, group = .data$Measures, color = .data$Measures)) +
geom_line(show.legend = FALSE) +
geom_point(size = 0.3, show.legend = FALSE) +
facet_wrap(~ .data$Measures, ncol = 2, scales = "free") +
xlab("Iteration") + ylab("Value")
p <- p + ggtitle("Model Fit") + theme_bw() + theme(plot.title = element_text(hjust = 0.5))
p <- list(figure = p, values = modelfit)
class(p) <- c("keyATM_fig", class(p))
return(p)
}
#' Show a diagnosis plot of pi
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param show_topic an integer or a vector. Indicate topics to visualize. Default is \code{NULL}.
#' @param start integer. The starting value of iteration to use in the plot. Default is \code{0}.
#' @param ci value of the credible interval (between 0 and 1) to be estimated. Default is \code{0.9} (90%). This is an option when calculating credible intervals (you need to set \code{store_pi = TRUE} in [keyATM()]).
#' @param method method for computing the credible interval. The Highest Density Interval (\code{hdi}, default) or Equal-tailed Interval (\code{eti}). This is an option when calculating credible intervals (you need to set \code{store_pi = TRUE} in [keyATM()]).
#' @param point method for computing the point estimate. \code{mean} (default) or \code{median}. This is an option when calculating credible intervals (you need to set \code{store_pi = TRUE} in [keyATM()]).
#' @return keyATM_fig object.
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_pi <- function(x, show_topic = NULL, start = 0, ci = 0.9, method = c("hdi", "eti"), point = c("mean", "median"))
{
method <- rlang::arg_match(method)
point <- rlang::arg_match(point)
check_arg_type(x, "keyATM_output")
modelname <- extract_full_model_name(x)
if (modelname %in% c("lda", "ldacov", "ldahmm")) {
cli::cli_abort(paste0("This is not a model with keywords."))
}
if (is.null(show_topic)) {
show_topic <- 1:x$keyword_k
} else if (sum(!show_topic %in% 1:x$keyword_k) != 0) {
cli::cli_abort("`plot_pi` only visualize keyword topics.")
}
if (!is.numeric(start) | length(start) != 1) {
cli::cli_abort("`start` argument is invalid.")
}
tnames <- c(names(x$keywords_raw))[show_topic]
if (!is.null(x$values_iter$pi_iter)) {
pi_mat <- t(sapply(x$values_iter$pi_iter, unlist, use.names = FALSE))[, show_topic, drop = FALSE]
pi_mat %>%
tibble::as_tibble(.name_repair = ~ tnames) %>%
dplyr::mutate(Iteration = x$values_iter$used_iter) %>%
dplyr::filter(.data$Iteration >= start) %>%
dplyr::select(-tidyselect::all_of("Iteration")) -> pi_mat
if (nrow(pi_mat) == 0) {
cli::cli_abort("Nothing left to plot. Please check arguments.")
}
pi_mat %>%
tidyr::pivot_longer(cols = dplyr::everything(), names_to = "Topic") %>%
dplyr::group_by(.data$Topic) %>%
dplyr::summarise(x = list(tibble::enframe(calc_ci(.data$value, ci, method, point), "q", "value")), .groups = "drop_last") %>%
tidyr::unnest(x) %>% tidyr::pivot_wider(names_from = tidyselect::all_of("q"), values_from = tidyselect::all_of("value")) -> temp
p <- ggplot(temp, aes(y = .data$Point, x = .data$Topic)) +
theme_bw() + geom_point() +
geom_errorbar(aes(ymin = .data$Lower, ymax = .data$Upper), data = temp, width = 0.01, linewidth = 1) +
xlab("Topic") + ylab("Probability") +
ggtitle("Probability of words drawn from keyword topic-word distribution") +
theme(plot.title = element_text(hjust = 0.5))
} else {
cli::cli_alert_info("Plotting pi from the final MCMC draw. Please set `store_pi` to `TRUE` if you want to plot pi over iterations.")
x$pi %>%
dplyr::mutate(Probability = .data$Proportion / 100) %>%
dplyr::filter(.data$Topic %in% (!!show_topic)) %>%
dplyr::mutate(Topic = tnames) -> temp
p <- ggplot(temp, aes(x = .data$Topic, y = .data$Probability)) +
geom_bar(stat = "identity") +
theme_bw() +
xlab("Topic") + ylab("Probability") +
ggtitle("Probability of words drawn from keyword topic-word distribution") +
theme(plot.title = element_text(hjust = 0.5))
}
p <- list(figure = p, values = temp)
class(p) <- c("keyATM_fig", class(p))
return(p)
}
#' Show the expected proportion of the corpus belonging to each topic
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param n The number of top words to show. Default is \code{3}.
#' @param show_topic an integer or a vector. Indicate topics to visualize. Default is \code{NULL}.
#' @param show_topwords logical. Show topwords. The default is \code{TRUE}.
#' @param order The order of topics.
#' @param label_topic a character vector. The name of the topics in the plot.
#' @param xmax a numeric. Indicate the max value on the x axis
#' @return keyATM_fig object
#' @import magrittr
#' @import ggplot2
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_topicprop <- function(x, n = 3, show_topic = NULL, show_topwords = TRUE, label_topic = NULL, order = c("proportion", "topicid"), xmax = NULL)
{
check_arg_type(x, "keyATM_output")
order <- rlang::arg_match(order)
total_k <- x$keyword_k + x$no_keyword_topics
if (is.null(show_topic)) {
show_topic <- 1:total_k
} else {
if (max(show_topic) > total_k | min(show_topic) < 1) {
cli::cli_abort("Invalid topic ID in `show_topic`.")
}
}
topwords <- top_words(x, n = n)[, show_topic]
if (!is.null(label_topic)) {
if (length(label_topic) != ncol(topwords)) {
cli::cli_abort("The length of `label_topic` is incorrect.")
}
colnames(topwords) <- label_topic
}
topwords %>%
dplyr::summarise(dplyr::across(dplyr::everything(), ~stringr::str_c(.x, collapse = ", "))) %>%
tidyr::pivot_longer(dplyr::everything(),
values_to = "Topwords",
names_to = "Topic") -> topwords_commas
theta_use <- x$theta[, show_topic]
if (!is.null(label_topic)) {
colnames(theta_use) <- label_topic
}
theta_use %>%
tibble::as_tibble() %>%
dplyr::summarise(dplyr::across(dplyr::everything(), ~mean(.x))) %>%
tidyr::pivot_longer(dplyr::everything(),
values_to = "Topicprop",
names_to = "Topic") %>%
dplyr::left_join(topwords_commas, by = "Topic") -> theta_use_tbl
if (order == "proportion") {
theta_use_tbl %>%
dplyr::mutate(Topic = stringr::str_remove(.data$Topic, "^\\d+_")) -> theta_use_tbl
theta_use_tbl %>%
dplyr::arrange(dplyr::desc(.data$Topicprop)) %>%
dplyr::pull(.data$Topic) -> use_order
} else if (order == "topicid") {
theta_use_tbl %>%
dplyr::pull(.data$Topic) -> use_order
}
theta_use_tbl %>%
dplyr::mutate(
Topic = factor(.data$Topic, levels = rev(use_order)),
xpos = max(.data$Topicprop) + 0.01
) %>%
dplyr::arrange(dplyr::desc(.data$Topic)) -> plot_obj
if (is.null(xmax)) {
if (show_topwords) {
xmax <- min(max(plot_obj$Topicprop) * 2.5, 1)
} else {
xmax <- max(plot_obj$Topicprop) + 0.02
}
}
label_percent <- function(x) {
paste0(round(x * 100, 2), "%")
}
p <- ggplot(plot_obj, aes(x = .data$Topicprop, y = .data$Topic)) +
geom_col() +
{if (show_topwords)
geom_text(
aes(x = .data$xpos, y = .data$Topic, label = .data$Topwords),
hjust = 0, size = max(10 / n + 1, 2.5))
} +
scale_x_continuous(
"Expected topic proportions", limits = c(0, xmax), labels = label_percent
) +
theme_bw() +
theme(panel.grid.major.x = element_blank(),
panel.grid.minor.x = element_blank(),
panel.grid.major.y = element_blank(),
panel.grid.minor.y = element_blank())
p <- list(figure = p, values = plot_obj)
class(p) <- c("keyATM_fig", class(p))
return(p)
}
#' Plot document-topic distribution by strata (for covariate models)
#'
#' @param x a strata_doctopic object (see [by_strata_DocTopic()]).
#' @param show_topic a vector or an integer. Indicate topics to visualize.
#' @param var_name the name of the variable in the plot.
#' @param by `topic` or `covariate`. Default is by `topic`.
#' @param ci value of the credible interval (between 0 and 1) to be estimated. Default is \code{0.9} (90%).
#' @param method method for computing the credible interval. The Highest Density Interval (\code{hdi}, default) or Equal-tailed Interval (\code{eti}).
#' @param point method for computing the point estimate. \code{mean} (default) or \code{median}.
#' @param width numeric. Width of the error bars.
#' @param show_point logical. Show point estimates. The default is \code{TRUE}.
#' @param ... additional arguments not used.
#' @return keyATM_fig object.
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()], [by_strata_DocTopic()]
#' @export
plot.strata_doctopic <- function(x, show_topic = NULL, var_name = NULL, by = c("topic", "covariate"),
ci = 0.9, method = c("hdi", "eti"), point = c("mean", "median"),
width = 0.1, show_point = TRUE, ...)
{
by <- rlang::arg_match(by)
method <- rlang::arg_match(method)
point <- rlang::arg_match(point)
tables <- summary.strata_doctopic(x, ci, method, point)
by_var <- x$by_var
by_values <- x$by_values
if (!is.null(var_name)) {
by_var <- var_name
}
if (is.null(show_topic)) {
show_topic <- 1:nrow(tables[[1]])
}
tables <- dplyr::bind_rows(tables)
tnames <- unique(tables$Topic)
num_keytopic <- sum(!grepl("Other_[0-9]+", tnames))
topic_parse <- function(s) {
if (grepl("Other_[0-9]+", s)) {
return(as.numeric(strsplit(s, "_")[[1]][2]) + num_keytopic)
} else {
return(as.numeric(strsplit(s, "_")[[1]][1]))
}
}
tables$TopicID <- purrr::map_dbl(tables$Topic, topic_parse)
tables$Topic <- factor(tables$Topic, levels = tnames[order(purrr::map_dbl(tnames, topic_parse))])
tables <- tables %>% dplyr::filter(.data$TopicID %in% show_topic)
variables <- unique(tables$label)
if (point == "mean") {
ylabel <- expression(paste("Mean of ", theta))
} else {
ylabel <- expression(paste("Median of ", theta))
}
p <- ggplot(tables) +
coord_flip() +
scale_x_discrete(limits = rev(variables)) +
xlab(paste0(by_var)) +
ylab(ylabel) +
guides(color = guide_legend(title = "Topic")) +
theme_bw()
if (by == "topic") {
p <- p + geom_errorbar(width = width, aes(x = .data$label, ymin = .data$Lower, ymax = .data$Upper,
group = .data$Topic), position = position_dodge(width = -1/2)) + facet_wrap(~Topic)
if (show_point)
p <- p + geom_point(aes(x = .data$label, y = .data$Point))
} else {
p <- p + geom_errorbar(width = width, aes(x = .data$label, ymin = .data$Lower, ymax = .data$Upper,
group = .data$Topic, colour = .data$Topic), position = position_dodge(width = -1/2))
if (show_point)
p <- p + geom_point(aes(x = .data$label, y = .data$Point, colour = .data$Topic), position = position_dodge(width = -1/2))
}
p <- list(figure = p, values = tables)
class(p) <- c("keyATM_fig", class(p))
return(p)
}
#' Plot time trend
#'
#' @param x the output from the dynamic keyATM model (see [keyATM()]).
#' @param show_topic an integer or a vector. Indicate topics to visualize. Default is \code{NULL}.
#' @param time_index_label a vector. The label for time index. The length should be equal to the number of documents (time index provided to [keyATM()]).
#' @param ci value of the credible interval (between 0 and 1) to be estimated. Default is \code{0.9} (90%). This is an option when calculating credible intervals (you need to set \code{store_theta = TRUE} in [keyATM()]).
#' @param method method for computing the credible interval. The Highest Density Interval (\code{hdi}, default) or Equal-tailed Interval (\code{eti}). This is an option when calculating credible intervals (you need to set \code{store_theta = TRUE} in [keyATM()]).
#' @param point method for computing the point estimate. \code{mean} (default) or \code{median}. This is an option when calculating credible intervals (you need to set \code{store_theta = TRUE} in [keyATM()]).
#' @param xlab a character.
#' @param scales character. Control the scale of y-axis (the parameter in [ggplot2::facet_wrap()][ggplot2::facet_wrap]): \code{free} adjusts y-axis for parameters. Default is \code{fixed}.
#' @param show_point logical. The default is \code{TRUE}. This is an option when calculating credible intervals.
#' @param ... additional arguments not used.
#' @return keyATM_fig object.
#' @import ggplot2
#' @import magrittr
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
plot_timetrend <- function(x, show_topic = NULL, time_index_label = NULL,
ci = 0.9, method = c("hdi", "eti"), point = c("mean", "median"),
xlab = "Time", scales = "fixed", show_point = TRUE, ...)
{
method <- rlang::arg_match(method)
point <- rlang::arg_match(point)
check_arg_type(x, "keyATM_output")
modelname <- extract_full_model_name(x)
if (!modelname %in% c("hmm", "ldahmm")) {
cli::cli_abort(paste0("This is not a model with time trends."))
}
if (!is.null(time_index_label)) {
if (length(x$values_iter$time_index) != length(time_index_label)) {
cli::cli_abort("The length of `time_index_label` does not match with the number of documents.")
}
time_index <- time_index_label
} else {
time_index <- x$values_iter$time_index
}
time_index_tbl <- tibble::tibble(time_index = time_index, time_index_raw = x$values_iter$time_index) %>% dplyr::distinct()
if (is.null(show_topic)) {
show_topic <- 1:x$keyword_k
}
format_theta <- function(theta, time_index, tnames) {
theta[, show_topic, drop = FALSE] %>%
tibble::as_tibble(.name_repair = ~tnames) %>%
dplyr::mutate(time_index = time_index) %>%
tidyr::pivot_longer(-tidyselect::all_of("time_index"), names_to = "Topic", values_to = "Proportion") %>%
dplyr::group_by(.data$time_index, .data$Topic) %>%
dplyr::summarize(Proportion = base::mean(.data$Proportion), .groups = "drop_last") -> res
return(res)
}
tnames <- colnames(x$theta)
if (is.null(x$values_iter$theta_iter)) {
dat <- format_theta(x$theta, time_index, tnames[show_topic])
p <- ggplot(dat, aes(x = .data$time_index, y = .data$Proportion, group = .data$Topic)) +
geom_line(linewidth = 0.8, color = "blue") + geom_point(size = 0.9)
} else {
dat <- dplyr::bind_rows(lapply(x$values_iter$theta_iter, format_theta, time_index, tnames[show_topic])) %>%
dplyr::group_by(.data$time_index, .data$Topic) %>%
dplyr::summarise(x = list(tibble::enframe(calc_ci(.data$Proportion, ci, method, point), "q", "value"))) %>%
tidyr::unnest(tidyselect::all_of("x")) %>% dplyr::ungroup() %>%
tidyr::pivot_wider(names_from = tidyselect::all_of("q"), values_from = tidyselect::all_of("value")) %>%
stats::setNames(c("time_index", "Topic", "Lower", "Point", "Upper"))
p <- ggplot(dat, aes(x = .data$time_index, y = .data$Point, group = .data$Topic)) +
geom_ribbon(aes(ymin = .data$Lower, ymax = .data$Upper), fill = "gray75") +
geom_line(linewidth = 0.8, color = "blue")
if (show_point)
p <- p + geom_point(size = 0.9)
}
p <- p + xlab(xlab) + ylab(expression(paste("Mean of ", theta))) +
facet_wrap(~.data$Topic, scales = scales) + theme_bw() + theme(panel.grid.minor = element_blank())
dat <- dplyr::left_join(dat, time_index_tbl, by = "time_index") %>%
dplyr::mutate(state_id = x$values_iter$R_iter_last[.data$time_index_raw])
p <- list(figure = p, values = dat)
class(p) <- c("keyATM_fig", class(p))
return(p)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.