Nothing
#' Estimate subsetted topic-word distribution
#'
#' Splits the documents by the values of `by` and re-estimates the topic-word
#' distribution within each subset, using the topic (`Z`) and keyword-switch
#' (`S`) assignments kept in the model output.
#'
#' @param x the output from a keyATM model (see [keyATM()]). It must contain
#'   `Z` and `S` in `kept_values`.
#' @param keyATM_docs an object generated by [keyATM_read()].
#' @param by a vector whose length is the number of documents.
#'
#' @return strata_topicword object (a list) with elements `phi` (one entry per
#'   unique value of `by`, named after that value), `theta` and `keywords_raw`
#'   (both copied from `x`).
#' @import magrittr
#' @export
by_strata_TopicWord <- function(x, keyATM_docs, by)
{
  # Check inputs
  if (!is.vector(by)) {
    cli::cli_abort("`by` should be a vector.")
  }
  if (!all(c("Z", "S") %in% names(x$kept_values))) {
    cli::cli_abort("`Z` and `S` should be in the output. Please check `keep` option in `keyATM()`.")
  }
  if (length(keyATM_docs$W_raw) != length(by)) {
    cli::cli_abort("The length of `by` should be the same as the length of documents.")
  }

  # Get unique values of `by`
  unique_val <- unique(by)
  tnames <- rownames(x$phi)

  # Get phi for each stratum
  obj <- lapply(unique_val, function(val) {
    doc_index <- which(by == val)
    z_sub <- x$kept_values$Z[doc_index]
    s_sub <- x$kept_values$S[doc_index]

    # Flatten the per-document lists into one long vector per quantity.
    # `use.names = FALSE` belongs to unlist(), not to as.integer().
    all_words <- unlist(keyATM_docs$W_raw[doc_index], use.names = FALSE)
    all_topics <- as.integer(unlist(z_sub, use.names = FALSE))
    all_s <- as.integer(unlist(s_sub, use.names = FALSE))

    pi_estimated <- keyATM_output_pi(z_sub, s_sub, x$priors$gamma)
    vocab <- sort(unique(all_words))

    keyATM_output_phi_calc_key(all_words, all_topics, all_s, pi_estimated,
                               x$keywords_raw,
                               vocab, x$priors, tnames, model = x)
  })
  names(obj) <- unique_val

  res <- list(phi = obj, theta = x$theta, keywords_raw = x$keywords_raw)
  class(res) <- c("strata_topicword", class(res))
  return(res)
}
#' Show covariates information
#'
#' Prints the column names of the covariate data, the standardization
#' setting, the formula, and a preview (first rows) of the covariate data
#' stored in the model output.
#'
#' @param x the output from the covariate keyATM model (see [keyATM()]).
#' @export
covariates_info <- function(x) {
  # Test the class first and short-circuit with `||`: an object without a
  # `model` field would otherwise produce a zero-length condition and a
  # cryptic "argument is of length zero" error instead of this message.
  if (!inherits(x, "keyATM_output") || !isTRUE(x$model == "covariates")) {
    cli::cli_abort("This is not an output of the covariate model")
  }

  settings <- x$kept_values$model_settings
  cat(paste0("Colnames: ", paste(colnames(settings$covariates_data_use), collapse = ", "),
             "\nStandardization: ", as.character(settings$standardize),
             "\nFormula: ", paste(as.character(settings$covariates_formula), collapse = " "), "\n\nPreview:\n"))
  print(utils::head(settings$covariates_data_use))
}
#' Return covariates used in the iteration
#'
#' @param x the output from the covariate keyATM model (see [keyATM()])
#' @return the covariate data stored in
#'   `x$kept_values$model_settings$covariates_data_use`.
#' @export
covariates_get <- function(x) {
  # Test the class first and short-circuit with `||`: an object without a
  # `model` field would otherwise produce a zero-length condition and a
  # cryptic "argument is of length zero" error instead of this message.
  if (!inherits(x, "keyATM_output") || !isTRUE(x$model == "covariates")) {
    cli::cli_abort("This is not an output of the covariate model")
  }
  return(x$kept_values$model_settings$covariates_data_use)
}
#' Estimate document-topic distribution by strata (for covariate models)
#'
#' For each value in `by_values`, the covariate `by_var` is set to that value
#' for every document, [predict.keyATM_output()] is called on the modified
#' covariate data, and the resulting draws are summarised into
#' point estimates with credible intervals.
#'
#' @param x the output from the covariate keyATM model (see [keyATM()]).
#' @param by_var character. The name of the variable to use.
#' @param labels character. The labels for the values specified in `by_var` (ascending order).
#' @param by_values numeric. Specific values for `by_var`, ordered from small to large. If it is not specified, all values in `by_var` will be used.
#' @param ... other arguments passed on to the [predict.keyATM_output()] function.
#' @return strata_doctopic object (a list) with elements `theta`, `tables`,
#'   `by_values`, `by_var` and `labels`.
#' @import magrittr
#' @importFrom stats predict
#' @export
by_strata_DocTopic <- function(x, by_var, labels, by_values = NULL, ...)
{
  # Check inputs
  covariates <- x$kept_values$model_settings$covariates_data_use
  variables <- colnames(covariates)
  if (length(by_var) != 1) {
    cli::cli_abort("`by_var` should be a single variable.")
  }
  if (!by_var %in% variables) {
    cli::cli_abort(paste0(by_var, " is not in the set of covariates in keyATM model. Check with `covariates_info()`. ",
                          "Covariates provided are: ",
                          paste(variables, collapse = " , ")))
  }

  # Info
  if (is.null(by_values)) {
    by_values <- sort(unique(covariates[, by_var]))
  }
  if (length(by_values) != length(labels)) {
    cli::cli_abort("Length mismatches. Please check `labels`.")
  }

  # Predict with `by_var` fixed to the i-th value for all documents
  apply_predict <- function(i, ...) {
    new_data <- covariates
    new_data[, by_var] <- by_values[i]
    predict(object = x, newdata = new_data, raw_values = TRUE, ...)
  }

  # NOTE: this resets the global RNG state to the seed stored in the model
  set.seed(x$options$seed)
  res <- lapply(seq_along(by_values), apply_predict, ...)
  names(res) <- by_values
  res <- tibble::as_tibble(res)

  # Making CI
  tables <- lapply(seq_along(by_values), function(index) {
    theta <- res[[index]]
    # The last column is excluded from the summary -- presumably it is not a
    # topic proportion (TODO confirm against predict.keyATM_output()).
    topic_cols <- seq_len(ncol(theta) - 1)
    strata_doctopic_CI(theta[, topic_cols, drop = FALSE],
                       label = labels[index], ...)
  })
  names(tables) <- labels

  # Making a return object
  obj <- list(theta = res, tables = tables, by_values = by_values, by_var = by_var, labels = labels)
  class(obj) <- c("strata_doctopic", class(obj))
  return(obj)
}
#' Print a one-line description of a strata_doctopic object
#'
#' Writes the name of the stratifying variable (`x$by_var`) to the console and
#' returns `x` invisibly, as print methods conventionally do.
#'
#' @noRd
#' @export
print.strata_doctopic <- function(x, ...)
{
  cat(paste0("strata_doctopic object for the ", x$by_var, "\n"))
  invisible(x)
}
#' Summary method for strata_doctopic objects
#'
#' Hands back the `tables` element stored in the object.
#'
#' @noRd
#' @export
summary.strata_doctopic <- function(object, ...)
{
  object$tables
}
#' Summarise each column of `theta` as lower bound, point estimate and upper bound
#'
#' Applies `calc_ci()` to every column of `theta` and reshapes the result
#' into a tibble with one row per column of `theta` (columns `Topic`,
#' `Lower`, `Point`, `Upper`, `TopicID`, and `label` when one is supplied).
#'
#' @noRd
#' @import magrittr
#' @keywords internal
strata_doctopic_CI <- function(theta, ci = 0.9, method = c("hdi", "eti"), point = c("mean", "median"), label = NULL, ...)
{
  method <- rlang::arg_match(method)
  point <- rlang::arg_match(point)

  # Three rows (Lower / Point / Upper), one column per column of `theta`
  bounds <- as.data.frame(apply(theta, 2, calc_ci, ci, method, point))
  bounds$CI <- c("Lower", "Point", "Upper")

  # Long format (CI, Topic, Value), then one row per Topic
  long <- tidyr::gather(bounds, key = "Topic", value = "Value", -"CI")
  wide <- tidyr::spread(long, key = "CI", value = "Value")

  # TopicID numbers the rows in the order spread() returns them
  numbered <- dplyr::mutate(wide, TopicID = 1:(dplyr::n()))
  res <- tibble::as_tibble(numbered)

  if (is.null(label)) {
    return(res)
  }
  dplyr::mutate(res, label = label)
}
#' Credible interval and point estimate for a vector of draws
#'
#' @param vec numeric vector of draws.
#' @param ci width of the interval (e.g. `0.9`).
#' @param method `"hdi"` (shortest interval covering `ci` of the sorted
#'   draws) or `"eti"` (equal-tailed quantile interval). `"hdi"` falls back to
#'   `"eti"` with a warning when the window is too small to search.
#' @param point `"mean"` or `"median"`; any other value is placed in the
#'   result as is.
#' @return a vector of length three named `Lower`, `Point`, `Upper`.
#' @noRd
#' @keywords internal
calc_ci <- function(vec, ci, method, point)
{
  # Check bayestestR package and Kruschke's book

  # Keep the estimate in its own variable: overwriting `point` and then
  # comparing it with "median" fails when the estimate is NA.
  if (identical(point, "mean")) {
    point_estimate <- mean(vec)
  } else if (identical(point, "median")) {
    point_estimate <- stats::median(vec)
  } else {
    point_estimate <- point
  }

  use_eti <- identical(method, "eti")

  if (identical(method, "hdi")) {
    sorted_points <- sort.int(vec, method = "quick")
    window_size <- ceiling(ci * length(sorted_points))
    n_windows <- length(sorted_points) - window_size

    if (window_size < 2 || n_windows < 1) {
      cli::cli_warn("`ci` is too small or iterations are not enough, using `eti` option instead.")
      use_eti <- TRUE
    } else {
      # Width of every candidate window; the narrowest one is the HDI
      starts <- seq_len(n_windows)
      widths <- sorted_points[starts + window_size] - sorted_points[starts]
      narrowest <- which.min(widths)
      res <- c(sorted_points[narrowest], point_estimate,
               sorted_points[narrowest + window_size])
    }
  } else if (!use_eti) {
    cli::cli_abort("`method` should be either `hdi` or `eti`.")
  }

  if (use_eti) {
    qua <- stats::quantile(vec, probs = c((1 - ci) / 2, (1 + ci) / 2))
    res <- c(qua[1], point_estimate, qua[2])
  }

  names(res) <- c("Lower", "Point", "Upper")
  return(res)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.