Nothing
#' Estimate subsetted topic-word distribution
#'
#' Recomputes the topic-word distribution separately for each group of
#' documents defined by `by`, using the `Z` and `S` values kept in
#' `x$kept_values`.
#'
#' @param x the output from a keyATM model (see [keyATM()]).
#' @param keyATM_docs an object generated by [keyATM_read()].
#' @param by a vector whose length is the number of documents.
#'
#' @return strata_topicword object (a list) with elements `phi` (a list of
#'   per-group estimates named by the unique values of `by`), and `theta` and
#'   `keywords_raw` copied from `x`.
#' @import magrittr
#' @export
by_strata_TopicWord <- function(x, keyATM_docs, by) {
  # Check inputs
  if (!is.vector(by)) {
    cli::cli_abort("`by` should be a vector.")
  }
  # Both kept values are required to rebuild phi for each group
  if (!all(c("Z", "S") %in% names(x$kept_values))) {
    cli::cli_abort(
      "`Z` and `S` should be in the output. Please check `keep` option in `keyATM()`."
    )
  }
  if (length(keyATM_docs$W_raw) != length(by)) {
    cli::cli_abort(
      "The length of `by` should be the same as the length of documents."
    )
  }

  # One phi estimate per unique value of `by`
  unique_val <- unique(by)
  tnames <- rownames(x$phi)
  obj <- lapply(unique_val, function(val) {
    doc_index <- which(by == val)
    # Flatten the per-document vectors of the selected documents
    all_words <- unlist(keyATM_docs$W_raw[doc_index], use.names = FALSE)
    all_topics <- as.integer(
      unlist(x$kept_values$Z[doc_index], use.names = FALSE)
    )
    all_s <- as.integer(
      unlist(x$kept_values$S[doc_index], use.names = FALSE)
    )
    pi_estimated <- keyATM_output_pi(
      x$kept_values$Z[doc_index],
      x$kept_values$S[doc_index],
      x$priors$gamma
    )
    vocab <- sort(unique(all_words))
    keyATM_output_phi_calc_key(
      all_words,
      all_topics,
      all_s,
      pi_estimated,
      x$keywords_raw,
      vocab,
      x$priors,
      tnames,
      model = x
    )
  })
  names(obj) <- unique_val

  res <- list(phi = obj, theta = x$theta, keywords_raw = x$keywords_raw)
  class(res) <- c("strata_topicword", class(res))
  res
}
#' Show covariates information
#'
#' Prints the column names, the standardization setting, the formula, and the
#' first rows of the covariate data used by a covariate keyATM model.
#'
#' @param x the output from the covariate keyATM model (see [keyATM()]).
#' @export
covariates_info <- function(x) {
  # Check the class first and combine with scalar `||`. This way, objects
  # lacking a `model` element get the message below rather than an `if()`
  # "argument is of length zero" error.
  if (!inherits(x, "keyATM_output") || !isTRUE(x$model == "covariates")) {
    cli::cli_abort("This is not an output of the covariate model")
  }
  settings <- x$kept_values$model_settings
  cat(paste0(
    "Colnames: ",
    paste(colnames(settings$covariates_data_use), collapse = ", "),
    "\nStandardization: ",
    as.character(settings$standardize),
    "\nFormula: ",
    paste(as.character(settings$covariates_formula), collapse = " "),
    "\n\nPreview:\n"
  ))
  print(utils::head(settings$covariates_data_use))
}
#' Return covariates used in the iteration
#'
#' @param x the output from the covariate keyATM model (see [keyATM()])
#' @return The covariate data stored in
#'   `x$kept_values$model_settings$covariates_data_use`.
#' @export
covariates_get <- function(x) {
  # Check the class first and combine with scalar `||`. This way, objects
  # lacking a `model` element get the message below rather than an `if()`
  # "argument is of length zero" error.
  if (!inherits(x, "keyATM_output") || !isTRUE(x$model == "covariates")) {
    cli::cli_abort("This is not an output of the covariate model")
  }
  x$kept_values$model_settings$covariates_data_use
}
#' Estimate document-topic distribution by strata (for covariate models)
#'
#' For each value in `by_values`, sets `by_var` to that value for every
#' document, draws document-topic proportions with [predict.keyATM_output()],
#' and summarises the draws into credible-interval tables.
#'
#' @param x the output from the covariate keyATM model (see [keyATM()]).
#' @param by_var character. The name of the variable to use.
#' @param labels character. The labels for the values specified in `by_var` (ascending order).
#' @param by_values numeric. Specific values for `by_var`, ordered from small to large. If it is not specified, all values in `by_var` will be used.
#' @param ... other arguments passed on to the [predict.keyATM_output()]
#'   function and to the credible-interval computation (`ci`, `method`,
#'   `point`).
#' @return strata_doctopic object (a list) with elements `theta`, `tables`
#'   (one table per label), `by_values`, `by_var`, and `labels`.
#' @import magrittr
#' @importFrom stats predict
#' @export
by_strata_DocTopic <- function(x, by_var, labels, by_values = NULL, ...) {
  covariates <- x$kept_values$model_settings$covariates_data_use

  # Check inputs
  variables <- colnames(covariates)
  if (length(by_var) != 1) {
    cli::cli_abort("`by_var` should be a single variable.")
  }
  if (!by_var %in% variables) {
    cli::cli_abort(paste0(
      by_var,
      " is not in the set of covariates in keyATM model. ",
      "Check with `covariates_info()`. ",
      "Covariates provided are: ",
      paste(variables, collapse = " , ")
    ))
  }

  # Default: every observed value of the covariate, in ascending order
  if (is.null(by_values)) {
    by_values <- sort(unique(covariates[, by_var]))
  }
  if (length(by_values) != length(labels)) {
    cli::cli_abort("Length mismatches. Please check `labels`.")
  }

  # Posterior draws of theta when `by_var` is set to by_values[i] for every
  # document; all other covariates keep their observed values.
  apply_predict <- function(i, ...) {
    new_data <- covariates
    new_data[, by_var] <- by_values[i]
    predict(object = x, newdata = new_data, raw_values = TRUE, ...)
  }
  # Seed from the model options so the draws are reproducible
  # (note: this resets the global RNG state).
  set.seed(x$options$seed)
  res <- lapply(seq_along(by_values), apply_predict, ...)
  names(res) <- by_values
  res <- tibble::as_tibble(res)

  # Making CI. The last column of each draw table is excluded from the
  # summary (NOTE(review): presumably the iteration index returned by
  # predict(); confirm against predict.keyATM_output()).
  tables <- lapply(seq_along(by_values), function(index) {
    theta <- res[[index]]
    strata_doctopic_CI(
      theta[, -ncol(theta)],
      label = labels[index],
      ...
    )
  })
  names(tables) <- labels

  # Making a return object
  obj <- list(
    theta = res,
    tables = tables,
    by_values = by_values,
    by_var = by_var,
    labels = labels
  )
  class(obj) <- c("strata_doctopic", class(obj))
  obj
}
#' @noRd
#' @export
print.strata_doctopic <- function(x, ...) {
  cat(paste0("strata_doctopic object for the ", x$by_var, "\n"))
  # Print methods conventionally return their input invisibly, so that
  # `print(obj)` can be used inside assignments and pipelines.
  invisible(x)
}
#' @noRd
#' @export
summary.strata_doctopic <- function(object, ...) {
  # The summary is the list of per-stratum CI tables stored on the object;
  # `...` exists only for S3 generic compatibility and is unused.
  tables <- object$tables
  tables
}
#' @noRd
#' @import magrittr
#' @keywords internal
strata_doctopic_CI <- function(
  theta,
  ci = 0.9,
  method = c("hdi", "eti"),
  point = c("mean", "median"),
  label = NULL,
  ...
) {
  # Summarise each column of `theta` (one column per topic; each column is
  # treated as a sample of draws by calc_ci()) as Lower / Point / Upper.
  # `...` is accepted and ignored so callers can forward a shared arg list.
  method <- rlang::arg_match(method)
  point <- rlang::arg_match(point)

  # 3 x K matrix: rows Lower/Point/Upper (named by calc_ci()), one column
  # per topic in the column order of `theta`.
  q <- apply(theta, 2, calc_ci, ci, method, point)
  topics <- colnames(q)
  if (is.null(topics)) {
    # Same fallback names that as.data.frame() gives an unnamed matrix
    topics <- paste0("V", seq_len(ncol(q)))
  }

  # Build the table directly so rows keep the column order of `theta` and
  # TopicID is the topic's column position. A gather()/spread() round trip
  # would sort rows by topic name ("10_..." before "2_..."), which is
  # locale-dependent and misaligns TopicID once there are 10+ topics.
  res <- tibble::tibble(
    Topic = topics,
    Lower = unname(q["Lower", ]),
    Point = unname(q["Point", ]),
    Upper = unname(q["Upper", ]),
    TopicID = seq_len(ncol(q))
  )
  if (!is.null(label)) {
    res <- dplyr::mutate(res, label = label)
  }
  res
}
#' @noRd
#' @keywords internal
calc_ci <- function(vec, ci, method, point) {
  # Credible interval of the draws in `vec` plus a point estimate, returned
  # as a numeric vector named c("Lower", "Point", "Upper").
  #   method: "hdi" (narrowest interval) or "eti" (equal-tailed quantiles)
  #   point:  "mean" or "median"
  # Check bayestestR package and Kruschke's book

  # Kept separate from `point` so the statistic name can still be forwarded
  # (as a string) to the `eti` fallback below.
  estimate <- switch(
    point,
    mean = mean(vec),
    median = stats::median(vec),
    cli::cli_abort("`point` should be either \"mean\" or \"median\".")
  )

  if (method == "hdi") {
    # sort.int() drops NAs by default, so the sizes below use the sorted vector
    sorted_points <- sort.int(vec, method = "quick")
    n_points <- length(sorted_points)
    window_size <- ceiling(ci * n_points)
    n_cis <- n_points - window_size
    if (window_size < 2 || n_cis < 1) {
      cli::cli_warn(
        "`ci` is too small or iterations are not enough, using `eti` option instead."
      )
      return(calc_ci(vec, ci, method = "eti", point))
    }
    # Candidate interval i runs from sorted draw i to draw i + window_size;
    # the HDI is the narrowest one (the first one on ties).
    starts <- seq_len(n_cis)
    ci_width <- sorted_points[starts + window_size] - sorted_points[starts]
    slice_min <- which.min(ci_width)
    res <- c(
      sorted_points[slice_min],
      estimate,
      sorted_points[slice_min + window_size]
    )
  } else if (method == "eti") {
    qua <- stats::quantile(vec, probs = c((1 - ci) / 2, (1 + ci) / 2))
    res <- c(qua[[1]], estimate, qua[[2]])
  } else {
    cli::cli_abort("`method` should be either \"hdi\" or \"eti\".")
  }
  names(res) <- c("Lower", "Point", "Upper")
  res
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.