R/analysis_covariates.R

Defines functions calc_ci strata_doctopic_CI summary.strata_doctopic print.strata_doctopic by_strata_DocTopic covariates_get covariates_info by_strata_TopicWord

Documented in by_strata_DocTopic by_strata_TopicWord covariates_get covariates_info

#' Estimate subsetted topic-word distribution
#'
#' Re-estimates the topic-word distribution separately for each unique value
#' of `by`, using only the documents that belong to that value.
#'
#' @param x the output from a keyATM model (see [keyATM()]). Its `kept_values`
#'   must contain `Z` and `S` (see the `keep` option of [keyATM()]).
#' @param keyATM_docs an object generated by [keyATM_read()].
#' @param by a vector whose length is the number of documents.
#'
#' @return strata_topicword object (a list) with elements `phi` (a list of
#'   per-stratum estimates named by the unique values of `by`), and `theta` and
#'   `keywords_raw` copied from `x`.
#' @import magrittr
#' @export
by_strata_TopicWord <- function(x, keyATM_docs, by)
{
  # Check inputs
  if (!is.vector(by)) {
    cli::cli_abort("`by` should be a vector.")
  }
  if (!all(c("Z", "S") %in% names(x$kept_values))) {
    cli::cli_abort("`Z` and `S` should be in the output. Please check `keep` option in `keyATM()`.")
  }
  if (length(keyATM_docs$W_raw) != length(by)) {
    cli::cli_abort("The length of `by` should be the same as the length of documents.")
  }

  # Get unique values of `by`
  unique_val <- unique(by)
  tnames <- rownames(x$phi)

  # Estimate phi within each stratum, using only that stratum's documents
  obj <- lapply(unique_val, function(val) {
    doc_index <- which(by == val)
    all_words <- unlist(keyATM_docs$W_raw[doc_index], use.names = FALSE)
    # `use.names` belongs to unlist(); it had no effect when passed to as.integer()
    all_topics <- as.integer(unlist(x$kept_values$Z[doc_index], use.names = FALSE))
    all_s <- as.integer(unlist(x$kept_values$S[doc_index], use.names = FALSE))
    pi_estimated <- keyATM_output_pi(x$kept_values$Z[doc_index],
                                     x$kept_values$S[doc_index],
                                     x$priors$gamma)
    vocab <- sort(unique(all_words))
    phi_obj <- keyATM_output_phi_calc_key(all_words, all_topics, all_s, pi_estimated,
                                          x$keywords_raw,
                                          vocab, x$priors, tnames, model = x)
    phi_obj
  })
  names(obj) <- unique_val

  res <- list(phi = obj, theta = x$theta, keywords_raw = x$keywords_raw)
  class(res) <- c("strata_topicword", class(res))
  res
}


#' Show covariates information
#'
#' Prints the covariate column names, the standardization setting, the formula,
#' and the first rows of the covariate data used by a covariate keyATM model.
#'
#' @param x the output from the covariate keyATM model (see [keyATM()]).
#' @return Invisibly, the preview of the covariate data that was printed.
#' @export
covariates_info <- function(x) {
  # Check the class first and short-circuit, so non-model inputs reach the
  # intended error instead of failing inside `$` or an empty `if ()` condition.
  if (!inherits(x, "keyATM_output") || !isTRUE(x[["model"]] == "covariates")) {
    cli::cli_abort("This is not an output of the covariate model")
  }
  settings <- x$kept_values$model_settings
  cat(paste0("Colnames: ", paste(colnames(settings$covariates_data_use), collapse = ", "),
             "\nStandardization: ", as.character(settings$standardize),
             "\nFormula: ", paste(as.character(settings$covariates_formula), collapse = " "), "\n\nPreview:\n"))
  preview <- utils::head(settings$covariates_data_use)
  print(preview)
  invisible(preview)
}


#' Return covariates used in the iteration
#'
#' @param x the output from the covariate keyATM model (see [keyATM()])
#' @return The covariate data stored in `x$kept_values$model_settings$covariates_data_use`.
#' @export
covariates_get <- function(x) {
  # Check the class first and short-circuit, so non-model inputs reach the
  # intended error instead of failing inside `$` or an empty `if ()` condition.
  if (!inherits(x, "keyATM_output") || !isTRUE(x[["model"]] == "covariates")) {
    cli::cli_abort("This is not an output of the covariate model")
  }
  x$kept_values$model_settings$covariates_data_use
}


#' Estimate document-topic distribution by strata (for covariate models)
#'
#' For each value of `by_var`, sets that covariate to the value for every
#' document, predicts the document-topic distribution, and summarises it with
#' interval estimates.
#'
#' @param x the output from the covariate keyATM model (see [keyATM()]).
#' @param by_var character. The name of the variable to use.
#' @param labels character. The labels for the values specified in `by_var` (ascending order).
#' @param by_values numeric. Specific values for `by_var`, ordered from small to large. If it is not specified, all values in `by_var` will be used.
#' @param ... other arguments passed on to the [predict.keyATM_output()] function.
#'   They are also forwarded to the internal interval-summary helper.
#' @return strata_doctopic object (a list) with elements `theta` (predictions
#'   per stratum), `tables` (interval summaries named by `labels`),
#'   `by_values`, `by_var`, and `labels`.
#' @import magrittr
#' @importFrom stats predict
#' @export
by_strata_DocTopic <- function(x, by_var, labels, by_values = NULL, ...)
{
  # Check inputs
  covariates <- x$kept_values$model_settings$covariates_data_use
  variables <- colnames(covariates)
  if (length(by_var) != 1) {
    cli::cli_abort("`by_var` should be a single variable.")
  }
  if (!(by_var %in% variables)) {
    cli::cli_abort(paste0(by_var, " is not in the set of covariates in keyATM model. Check with `covariates_info()`. ",
                          "Covariates provided are: ",
                          paste(variables, collapse = " , ")))
  }

  # Default strata: every observed value of `by_var`, ascending
  if (is.null(by_values)) {
    by_values <- sort(unique(covariates[, by_var]))
  }
  if (length(by_values) != length(labels)) {
    cli::cli_abort("Length mismatches. Please check `labels`.")
  }

  # Predict with `by_var` fixed at the i-th stratum value for every document
  apply_predict <- function(i, ...) {
    new_data <- covariates
    new_data[, by_var] <- by_values[i]
    predict(object = x, newdata = new_data, raw_values = TRUE, ...)
  }

  set.seed(x$options$seed)
  res <- lapply(seq_along(by_values), apply_predict, ...)
  names(res) <- by_values
  res <- tibble::as_tibble(res)

  # Making CI. The last column of each prediction is excluded
  # (NOTE(review): presumably a non-topic column such as an iteration index —
  # confirm against predict.keyATM_output). `drop = FALSE` keeps a single
  # remaining topic column as a 2-D object.
  tables <- lapply(seq_along(by_values), function(index) {
    theta <- res[[index]]
    strata_doctopic_CI(theta[, seq_len(ncol(theta) - 1), drop = FALSE],
                       label = labels[index], ...)
  })
  names(tables) <- labels

  # Making a return object
  obj <- list(theta = res, tables = tables, by_values = by_values, by_var = by_var, labels = labels)
  class(obj) <- c("strata_doctopic", class(obj))
  obj
}


#' @noRd
#' @export
print.strata_doctopic <- function(x, ...)
{
  # One-line description naming the stratifying covariate
  cat("strata_doctopic object for the ", x$by_var, "\n", sep = "")
  # Print methods conventionally return their input invisibly
  invisible(x)
}


#' @noRd
#' @export
summary.strata_doctopic <- function(object, ...) {
  # The per-stratum interval tables are stored on the object; hand them back.
  object$tables
}


#' @noRd
#' @import magrittr
#' @keywords internal
strata_doctopic_CI <- function(theta, ci = 0.9, method = c("hdi", "eti"), point = c("mean", "median"), label = NULL, ...)
{
  # Summarise each topic column of `theta` with a point estimate and interval.
  # Returns a tibble with columns Topic, Lower, Point, Upper, TopicID (and
  # `label` when given), one row per column of `theta`, in column order.
  # `...` is accepted so callers can forward shared arguments; it is unused.
  method <- rlang::arg_match(method)
  point <- rlang::arg_match(point)

  # 3 x K matrix: rows "Lower", "Point", "Upper" (named by calc_ci()),
  # one column per topic.
  q <- apply(theta, 2, calc_ci, ci, method, point)
  topics <- colnames(q)
  if (is.null(topics)) {
    # Same fallback names as.data.frame() gives an unnamed matrix
    topics <- paste0("V", seq_len(ncol(q)))
  }

  # Build the table directly so rows keep the column order of `theta`.
  # The former gather()/spread() round trip sorted rows by topic name
  # (e.g. "Other_10" before "Other_2"), so TopicID did not match the topic's
  # position.
  res <- tibble::tibble(
    Topic = topics,
    Lower = unname(q["Lower", ]),
    Point = unname(q["Point", ]),
    Upper = unname(q["Upper", ]),
    TopicID = seq_len(ncol(q))
  )

  if (!is.null(label)) {
    res <- dplyr::mutate(res, label = label)
  }
  res
}


#' @noRd
#' @keywords internal
calc_ci <- function(vec, ci, method, point)
{
  # Point estimate plus credible interval for a vector of draws.
  # `method`: "hdi" (highest density interval: shortest window over the sorted
  # draws) or "eti" (equal-tailed quantile interval). `point`: "mean" or
  # "median". Returns c(Lower, Point, Upper).
  # Check bayestestR package and Kruschke's book

  # Keep the statistic's name (`point`) separate from its value, so the
  # HDI -> ETI fallback re-dispatches on the name and an NA estimate never
  # reaches an `if ()` condition.
  point_value <- switch(point,
    mean = mean(vec),
    median = stats::median(vec),
    cli::cli_abort("`point` should be either \"mean\" or \"median\".")
  )

  if (method == "hdi") {
    sorted_points <- sort.int(vec, method = "quick")
    n_points <- length(sorted_points)
    window_size <- ceiling(ci * n_points)
    n_windows <- n_points - window_size

    if (window_size < 2 || n_windows < 1) {
      cli::cli_warn("`ci` is too small or iterations are not enough, using `eti` option instead.")
      return(calc_ci(vec, ci, method = "eti", point))
    }

    # Width of every candidate window, computed in one vectorized step
    starts <- seq_len(n_windows)
    ci_width <- sorted_points[starts + window_size] - sorted_points[starts]
    slice_min <- which.min(ci_width)

    res <- c(sorted_points[slice_min], point_value, sorted_points[slice_min + window_size])
  } else if (method == "eti") {
    qua <- stats::quantile(vec, probs = c((1 - ci) / 2, (1 + ci) / 2))
    res <- c(qua[1], point_value, qua[2])
  } else {
    cli::cli_abort("`method` should be either \"hdi\" or \"eti\".")
  }

  names(res) <- c("Lower", "Point", "Upper")
  res
}

Try the keyATM package in your browser

Any scripts or data that you put into this service are public.

keyATM documentation built on May 31, 2023, 6:27 p.m.