R/model.R

Defines functions visualize_keywords summary.keyATM_docs print.keyATM_docs keyATM_read

Documented in keyATM_read visualize_keywords

#' Read texts
#'
#' Read texts and create a \code{keyATM_docs} object, which is a list of texts.
#'
#'
#' @param texts input. keyATM takes a quanteda dfm (dgCMatrix), data.frame, \pkg{tibble} tbl_df, or a vector of file paths.
#' @param encoding character. Only used when \code{texts} is a vector of file paths. Default is \code{UTF-8}.
#' @param check logical. If \code{TRUE}, check whether there is anything wrong with the structure of texts. Default is \code{TRUE}.
#' @param keep_docnames logical. If \code{TRUE}, it keeps the document names in a quanteda dfm. Default is \code{FALSE}.
#' @param split numeric. This option works only with a quanteda dfm. It creates two subsets of the dfm by randomly splitting each document (i.e., the total number of documents is the same between the two subsets). This option specifies the split proportion. Default is \code{0}.
#'
#' @return a keyATM_docs object. The first element is a list whose elements are split texts. The length of the list equals the number of documents.
#'
#' @examples
#' \dontrun{
#'  # Use quanteda dfm
#'  keyATM_docs <- keyATM_read(texts = quanteda_dfm)
#'
#'  # Use data.frame or tibble (texts should be stored in a column named `text`)
#'  keyATM_docs <- keyATM_read(texts = data_frame_object)
#'  keyATM_docs <- keyATM_read(texts = tibble_object)
#'
#'  # Use a vector that stores full paths to the text files
#'  files <- list.files(doc_folder, pattern = "*.txt", full.names = TRUE)
#'  keyATM_docs <- keyATM_read(texts = files)
#'
#' }
#' @import magrittr
#' @importFrom rlang .data
#' @export
keyATM_read <- function(texts, encoding = "UTF-8", check = TRUE, keep_docnames = FALSE, split = 0)
{
  # Detect input. Exactly one of `text_dfm`, `files`, `text_df` is set below.
  # The tibble test must come before the data.frame test because a tibble
  # also inherits from "data.frame".
  text_dfm <- NULL
  files <- NULL
  text_df <- NULL
  if (inherits(texts, "tbl")) {
    cli::cli_alert_info("Using tibble.")
    text_df <- texts
  } else if (inherits(texts, "data.frame")) {
    cli::cli_alert_info("Using data.frame.")
    text_df <- tibble::as_tibble(texts)
  } else if (inherits(texts, c("dfm", "dgCMatrix"))) {
    cli::cli_alert_info("Using quanteda dfm.")
    text_dfm <- texts
  } else if (inherits(texts, "character")) {
    cli::cli_alert_info("Reading from files. Please make sure files are preprocessed. Encoding: {encoding}.")
    files <- texts
  } else {
    cli::cli_abort(c("x" = "Check `texts` argument.",
         "i" = "It can take quanteda dfm, data.frame, tibble, and a vector of characters."))
  }

  # `split` must be a single proportion in [0, 1); 0 means "no split"
  if (!is.numeric(split) || length(split) != 1L || is.na(split) ||
      split < 0 || split >= 1) {
    cli::cli_abort("Invalid option. `split` should be a proportion.")
  }

  # Read texts
  docnames <- NULL
  W_read <- list(W_raw = list(), W_split = list())
  if (!is.null(text_dfm)) {
    # quanteda object: the conversion is delegated to the compiled reader
    vocabulary <- colnames(text_dfm)
    W_read <- read_dfm_cpp(text_dfm, W_read, vocabulary, split)
    if (keep_docnames) {
      docnames <- quanteda::docnames(text_dfm)
    }
  } else {
    if (is.null(text_df)) {
      # One string per file; lines are joined with "\n".
      # `vapply` keeps the result a character vector even when `files` is empty.
      file_texts <- vapply(
        files,
        function(path) {
          paste0(readLines(path, encoding = encoding), collapse = "\n")
        },
        character(1),
        USE.NAMES = FALSE
      )
      text_df <- tibble::tibble(text = file_texts)
    }

    if (!"text" %in% colnames(text_df)) {
      cli::cli_abort("Texts should be stored in a column named {.var text}.")
    }

    # Tokenize each document on single spaces; one character vector per document
    W_read$W_raw <- stringr::str_split(text_df[["text"]], pattern = " ")
  }
  W_raw <- W_read$W_raw

  # Check whether there is anything wrong with the structure of texts
  if (check) {
    wd_names <- unique(unlist(W_raw, use.names = FALSE, recursive = FALSE))
    check_vocabulary(wd_names)
    doc_index <- get_doc_index(W_raw, check = TRUE)
  } else {
    doc_index <- NULL
    wd_names <- NULL
  }

  # The split half is wrapped as its own keyATM_docs object
  W_split <- list(W_raw = W_read$W_split)
  class(W_split) <- c("keyATM_docs", class(W_split))

  W_data <- list(W_raw = W_raw, W_split = W_split, doc_index = doc_index,
                 wd_names = wd_names, docnames = docnames)
  class(W_data) <- c("keyATM_docs", class(W_data))
  return(W_data)
}


#' @noRd
#' @export
print.keyATM_docs <- function(x, ...)
{
  # One-line description: the number of documents stored in `W_raw`
  cli::cli_inform("keyATM_docs object of {length(x$W_raw)} documents.")
  # Print methods conventionally return their argument invisibly
  invisible(x)
}


#' @noRd
#' @export
summary.keyATM_docs <- function(object, ...)
{
  # `lengths()` always returns an integer vector, including for an empty
  # corpus, where `sapply(..., length)` would return a list
  doc_len <- lengths(object$W_raw, use.names = FALSE)
  cli::cli_inform("keyATM_docs object of {length(object$W_raw)} documents.")
  ul <- cli::cli_ul()
  cli::cli_li("Average (min/max) document length: {round(mean(doc_len), 2)} ({min(doc_len)}/{max(doc_len)}) words")
  # `wd_names` is NULL when the object was read with `check = FALSE`,
  # in which case this reports 0
  cli::cli_li("Number of unique words: {length(object$wd_names)}")
  cli::cli_end(ul)
  invisible(object)
}


#' Visualize keywords
#'
#' Visualize the proportion of keywords in the documents.
#'
#' @param docs a keyATM_docs object, generated by \code{keyATM_read()} function
#' @param keywords a list of keywords
#' @param prune logical. If \code{TRUE}, prune keywords that do not appear in `docs`. Default is \code{TRUE}.
#' @param label_size the size of keyword labels in the output plot. Default is \code{3.2}.
#' @return keyATM_fig object
#' @examples
#' \dontrun{
#'  # Prepare a keyATM_docs object
#'  keyATM_docs <- keyATM_read(input)
#'
#'  # Keywords are in a list
#'  keywords <- list(Education = c("education", "child", "student"),
#'                   Health    = c("public", "health", "program"))
#'
#'  # Visualize keywords
#'  keyATM_viz <- visualize_keywords(keyATM_docs, keywords)
#'
#'  # View a figure
#'  keyATM_viz
#'
#'  # Save a figure
#'  save_fig(keyATM_viz, filename)
#' }
#' @import magrittr
#' @import ggplot2
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
visualize_keywords <- function(docs, keywords, prune = TRUE, label_size = 3.2)
{
  # Check type
  check_arg_type(docs, "keyATM_docs", "Please use `keyATM_read()` to read texts.")
  check_arg_type(keywords, "list")
  for (topic_words in keywords) {
    check_arg_type(topic_words, "character")
  }

  # All tokens of the corpus in one flat vector
  all_words <- unlist(docs$W_raw, recursive = FALSE, use.names = FALSE)

  # Check keywords (prunes or aborts on keywords absent from the corpus)
  keywords <- check_keywords(unique(all_words), keywords, prune)

  # Word frequency table over the whole corpus
  total_words <- length(all_words)
  word_freq <- tibble::tibble(Word = all_words) %>%
    dplyr::group_by(.data$Word) %>%
    dplyr::summarize(WordCount = dplyr::n()) %>%
    dplyr::mutate(`Proportion(%)` = round(.data$WordCount / total_words * 100, 3)) %>%
    dplyr::arrange(dplyr::desc(.data$WordCount)) %>%
    dplyr::mutate(Ranking = seq_len(dplyr::n()))

  keywords <- lapply(keywords, function(x) {unlist(strsplit(x, " "))})
  words_per_topic <- lengths(keywords)
  max_num_words <- max(words_per_topic)

  # Topic labels: "Topic<k>" for unnamed lists, "<k>_<name>" otherwise
  topic_ids <- seq_along(keywords)
  tnames <- names(keywords)
  if (is.null(tnames)) {
    topic_labels <- paste0("Topic", topic_ids)
  } else {
    topic_labels <- paste0(topic_ids, "_", tnames)
  }

  # One row per (topic, keyword), built in a single vectorized step rather
  # than by growing a data.frame inside a loop
  keywords_df <- data.frame(
    Topic = rep(topic_labels, times = words_per_topic),
    Word = unlist(keywords, use.names = FALSE),
    stringsAsFactors = FALSE
  )
  keywords_df$Topic <- factor(keywords_df$Topic, levels = unique(keywords_df$Topic))

  # Rank keywords within each topic by corpus frequency.
  # The result is left grouped by Topic, as callers of `values` received before.
  temp <- dplyr::right_join(word_freq, keywords_df, by = "Word") %>%
    dplyr::group_by(.data$Topic) %>%
    dplyr::arrange(dplyr::desc(.data$WordCount)) %>%
    dplyr::mutate(Ranking = seq_len(dplyr::n())) %>%
    dplyr::arrange(.data$Topic, .data$Ranking)

  # Visualize
  fig <-
    ggplot(temp, aes(x = .data$Ranking, y = .data$`Proportion(%)`, colour = .data$Topic)) +
      geom_line() +
      geom_point() +
      ggrepel::geom_label_repel(aes(label = .data$Word), size = label_size,
                       box.padding = 0.20, label.padding = 0.12,
                       arrow = arrow(angle = 10, length = unit(0.10, "inches"),
                                   ends = "last", type = "closed"),
                       show.legend = FALSE) +
      scale_x_continuous(breaks = seq_len(max_num_words)) +
      xlab("Ranking") + ylab("Proportion (%)") +
      theme_bw()

  keyATM_viz <- list(figure = fig, values = temp, keywords = keywords)
  class(keyATM_viz) <- c("keyATM_fig", "keyATM_viz", class(keyATM_viz))

  return(keyATM_viz)
}


# Validate keywords against the vocabulary.
#
# unique_words: character vector of words that appear in the documents.
# keywords:     list of character vectors, one element per keyword topic.
# prune:        if true, keywords absent from `unique_words` are dropped with
#               a warning; otherwise their presence is an error.
#
# Returns `keywords` (pruned when requested). Aborts when a topic ends up
# with no keywords at all.
check_keywords <- function(unique_words, keywords, prune)
{
  keywords_flat <- unlist(keywords, use.names = FALSE, recursive = FALSE)
  # Report each missing keyword once, even if it is listed in several topics
  non_existent <- unique(keywords_flat[!keywords_flat %in% unique_words])

  if (prune) {
    # Prune keywords
    if (length(non_existent) != 0) {
      cli::cli_warn(
        "{?A keyword/Keywords} {?is/are} pruned because {?it does/they do} not appear in the documents: {non_existent}",
        immediate. = TRUE
      )
    }

    keywords <- lapply(keywords, function(x) {x[!x %in% non_existent]})

  } else {

    # Raise error (`immediate.` is a warning-only option, so it is not passed)
    if (length(non_existent) != 0) {
      cli::cli_abort("{?A keyword/Keywords} {?is/are} not found in the documents: {non_existent}")
    }

  }

  # Check there is at least one keyword in each topic
  empty_topics <- which(lengths(keywords, use.names = FALSE) == 0)

  if (length(empty_topics) != 0) {
    # Fall back to positions when the list is unnamed (or a name is blank)
    # so the message always identifies the offending topics
    zero_names <- names(keywords)[empty_topics]
    if (is.null(zero_names)) {
      zero_names <- as.character(empty_topics)
    } else {
      unnamed <- is.na(zero_names) | zero_names == ""
      zero_names[unnamed] <- as.character(empty_topics)[unnamed]
    }
    zero_label <- paste(zero_names, collapse = ", ")
    # Interpolated, not pre-pasted, so braces in topic names are not parsed as glue
    cli::cli_abort("All keywords are pruned. Please check: {zero_label}")
  }

  return(keywords)
}


# Return the positions of documents that contain at least one token.
#
# W_raw: list of documents, each a vector of tokens.
# check: selects the wording of the warning raised when empty documents
#        exist (TRUE: flagged for review; FALSE: reported as dropped).
#
# Returns an integer vector of indexes of the non-empty documents.
get_doc_index <- function(W_raw, check = FALSE)
{
  doc_len <- lengths(W_raw, use.names = FALSE)
  nonzero_index <- which(doc_len != 0)
  zero_index <- as.character(which(doc_len == 0))  # as.character to use `cli::cli_warn`
  if (length(zero_index) != 0) {
    if (check) {
      cli::cli_warn(c(
        "There {?is a/are} document{?s} with 0 length. {?Index/Indexes} to check: {zero_index}",
        "i" = paste0(
          "This may cause invalid covariates or time indexes. Please review the preprocessing steps."
        )), immediate. = TRUE)
    } else {
      cli::cli_warn("There {?is a/are} document{?s} dropped because of 0 length. {?Index/Indexes} to check: {zero_index}", immediate. = TRUE)
    }
  }
  return(nonzero_index)
}


#' Initialize a keyATM model
#'
#' keyATM_initialize is wrapped by keyATM() and weightedLDA()
#'
#' @param docs a keyATM_docs object, generated by \code{keyATM_read()}.
#' @param model character. One of \code{"base"}, \code{"cov"}, \code{"hmm"}, \code{"lda"}, \code{"ldacov"}, or \code{"ldahmm"}.
#' @param no_keyword_topics numeric. The number of topics without keywords.
#' @param keywords a list of character vectors, one per keyword topic.
#' @param model_settings a list of model specific settings.
#' @param priors a list of priors.
#' @param options a list of options.
#' @return a keyATM_initialized object: a list with \code{model} (a keyATM_model object) and \code{model_name}.
#' @keywords internal
keyATM_initialize <- function(docs, model, no_keyword_topics,
                       keywords = list(), model_settings = list(),
                       priors = list(), options = list())
{
  ##
  ## Check
  ##

  # Check type (`is.numeric()` is TRUE for both integer and double)
  check_arg_type(docs, "keyATM_docs", "Please use `keyATM_read()` to read texts.")
  if (!is.numeric(no_keyword_topics))
    cli::cli_abort("`no_keyword_topics` is neither numeric nor integer.")

  no_keyword_topics <- as.integer(no_keyword_topics)

  if (!model %in% c("base", "cov", "hmm", "lda", "ldacov", "ldahmm")) {
    cli::cli_abort("Please select a correct model.")
  }

  info <- list(
                models_keyATM = c("base", "cov", "hmm"),
                models_lda = c("lda", "ldacov", "ldahmm")
              )
  keywords <- check_arg(keywords, "keywords", model, info)

  # Get Info
  if (is.null(docs$doc_index)) {
    info$use_doc_index <- get_doc_index(docs$W_raw)
  } else {
    info$use_doc_index <- docs$doc_index
    if (length(docs$doc_index) != length(docs$W_raw))
      cli::cli_warn("Some documents have 0 length. Please review the preprocessing steps.")
  }
  # Keep only the documents selected above
  docs$W_raw <- docs$W_raw[info$use_doc_index]
  info$num_doc <- length(docs$W_raw)
  info$keyword_k <- length(keywords)
  info$total_k <- length(keywords) + no_keyword_topics

  # Set default values
  model_settings <- check_arg(model_settings, "model_settings", model, info)
  priors <- check_arg(priors, "priors", model, info)
  options <- check_arg(options, "options", model, info)
  info$parallel_init <- options$parallel_init

  if (info$parallel_init) {
    cli::cli_alert_info("Parallel initialization is enabled. Please make sure to specify the {.fn future::plan} before calling the {.fn keyATM}.")
  }

  ##
  ## Initialization
  ##
  cli::cli_progress_step("Initializing the model", spinner = TRUE)
  # Seed before any random assignment below so initialization is reproducible
  set.seed(options$seed)

  # W
  if (is.null(docs$wd_names)) {
    info$wd_names <- unique(unlist(docs$W_raw, use.names = FALSE, recursive = FALSE))
    check_vocabulary(info$wd_names)
  } else {
    info$wd_names <- docs$wd_names
  }

  # Map each word to a 0-based integer id
  info$wd_map <- myhashmap(info$wd_names, seq_along(info$wd_names) - 1L)

  if (info$parallel_init) {
    W <- future.apply::future_lapply(docs$W_raw, function(x) { myhashmap_getvec(info$wd_map, x) }, future.seed = TRUE)
  } else {
    W <- lapply(docs$W_raw, function(x) { myhashmap_getvec(info$wd_map, x) })
  }

  # Check keywords
  keywords <- check_keywords(info$wd_names, keywords, options$prune)

  keywords_raw <- keywords  # keep raw keywords (not word_id)
  keywords_id <- lapply(keywords, function(x) { myhashmap_getvec(info$wd_map, x) })
  info$keywords_id <- unlist(keywords_id, use.names = FALSE, recursive = FALSE)

  # Assign S and Z
  if (model %in% info$models_keyATM) {
    res <- make_sz_key(W, keywords, info)
  } else {
    # LDA based models
    res <- make_sz_lda(W, info)
  }
  S <- res$S
  Z <- res$Z

  # Organize. Document metadata is everything in `docs` except the raw texts
  # and the document index; selected by name rather than by position.
  meta_fields <- setdiff(names(docs), c("W_raw", "doc_index"))
  stored_values <- list(vocab_weights = rep(-1, length(info$wd_names)),
                        doc_index = info$use_doc_index, keyATMdoc_meta = docs[meta_fields])

  if (model %in% c("base", "lda")) {
    if (options$estimate_alpha)
      stored_values$alpha_iter <- list()
  }

  if (model %in% c("hmm", "ldahmm")) {
    # Dynamic models always estimate alpha
    options$estimate_alpha <- 1
    stored_values$alpha_iter <- list()
  }

  if (model %in% c("cov", "ldacov")) {
    stored_values$Lambda_iter <- list()
  }

  if (model %in% c("hmm", "ldahmm")) {
    stored_values$R_iter <- list()

    if (options$store_transition_matrix) {
      stored_values$P_iter <- list()
    } else {
      stored_values$P_last <- list()  # for resume
    }
  }

  if (model %in% info$models_keyATM) {
    if (options$store_pi)
      stored_values$pi_vectors <- list()
  }

  if (options$store_theta) {
    stored_values$Z_tables <- list()
    stored_values$theta_PG <- list()
  }

  key_model <- list(
    W = W, Z = Z, S = S,
    model = abb_model_name(model),
    keywords = keywords_id, keywords_raw = keywords_raw,
    no_keyword_topics = no_keyword_topics,
    keyword_k = length(keywords_raw),
    vocab = info$wd_names,
    model_settings = model_settings, priors = priors, options = options,
    stored_values = stored_values, model_fit = list(), call = match.call()
  )

  class(key_model) <- c("keyATM_model", model, class(key_model))

  keyATM_initialized <- list(model = key_model, model_name = model)
  class(keyATM_initialized) <- c("keyATM_initialized", class(keyATM_initialized))
  return(keyATM_initialized)
}

# Fit an initialized (or resumed) keyATM model.
#
# keyATM_initialized: an object of class "keyATM_initialized" or
#                     "keyATM_resume", holding `model` and `model_name`
#                     (and `rand_state` when resuming).
# resume:             forwarded to the model-specific fitting function.
#
# Returns the fitted model, with "keyATM_fitted" prepended to its class.
keyATM_fit <- function(keyATM_initialized, resume = FALSE)
{
  is_resume <- inherits(keyATM_initialized, "keyATM_resume")
  if (!inherits(keyATM_initialized, "keyATM_initialized") && !is_resume)
    cli::cli_abort("The input is not an initialized object.")

  key_model <- keyATM_initialized$model
  model_name <- keyATM_initialized$model_name
  iterations <- key_model$options$iter_new

  if (is_resume) {
    .GlobalEnv$.Random.seed <- keyATM_initialized$rand_state  # to resume
    cli::cli_progress_step("Fitting the model: adding {iterations} iteration{?s} to the existing {key_model$options$iterations - iterations} iteration{?s} for a total of {key_model$options$iterations} iteration{?s}", spinner = TRUE)
  } else {
    set.seed(key_model$options$seed)
    cli::cli_progress_step("Fitting the model: {iterations} iteration{?s}", spinner = TRUE)
  }

  # `covariates_model` is only read for the "cov" model. It is compared with
  # `isTRUE()` so that a missing (NULL) setting or an unknown model name
  # reaches the informative abort instead of failing inside `if`.
  cov_model <- key_model$model_settings$covariates_model

  if (model_name == "base") {
    key_model <- keyATM_fit_base(key_model, resume = resume)
  } else if (model_name == "hmm") {
    key_model <- keyATM_fit_HMM(key_model, resume = resume)
  } else if (model_name == "lda") {
    key_model <- keyATM_fit_LDA(key_model, resume = resume)
  } else if (model_name == "ldacov") {
    key_model <- keyATM_fit_LDAcov(key_model, resume = resume)
  } else if (model_name == "ldahmm") {
    key_model <- keyATM_fit_LDAHMM(key_model, resume = resume)
  } else if (model_name == "cov" && isTRUE(cov_model == "PG")) {
    key_model <- keyATM_fit_covPG(key_model, resume = resume)
  } else if (model_name == "cov" && isTRUE(cov_model == "DirMulti")) {
    key_model <- keyATM_fit_cov(key_model, resume = resume)
  } else {
    cli::cli_abort("Please check {.var model}.")
  }

  if (!inherits(key_model, "keyATM_fitted"))
    class(key_model) <- c("keyATM_fitted", class(key_model))

  return(key_model)
}


#' @noRd
#' @export
print.keyATM_model <- function(x, ...)
{
  # One-line description using the model name stored in `x$model`
  cli::cli_inform("keyATM_model object for the {x$model} model.")
  # Print methods conventionally return their argument invisibly
  invisible(x)
}


# Dispatch validation of a user-supplied argument list to the checker that
# matches `name`. Returns whatever that checker returns; an unrecognized
# `name` yields NULL.
check_arg <- function(obj, name, model, info = list())
{
  switch(name,
    keywords = check_arg_keywords(obj, model, info),
    model_settings = check_arg_model_settings(obj, model, info),
    priors = check_arg_priors(obj, model, info),
    options = check_arg_options(obj, model, info),
    vb_options = check_arg_vboptions(obj, model, info)
  )
}


# Validate the `keywords` argument for the chosen model and label the topics.
#
# keywords: list of character vectors (must be empty for LDA-type models).
# model:    model name.
# info:     list providing `models_keyATM` and `models_lda`.
#
# Returns `keywords`; for keyATM models the names become "<k>" (unnamed
# input) or "<k>_<name>".
check_arg_keywords <- function(keywords, model, info)
{
  check_arg_type(keywords, "list")

  uses_keywords <- model %in% info$models_keyATM

  if (length(keywords) == 0 && uses_keywords) {
    cli::cli_abort("Please provide keywords.")
  }

  if (length(keywords) != 0 && model %in% info$models_lda) {
    cli::cli_abort("This model does not take keywords.")
  }

  # Name of keywords topic
  if (uses_keywords) {
    for (topic_words in keywords) {
      check_arg_type(topic_words, "character")
    }

    topic_ids <- as.character(seq_along(keywords))
    if (is.null(names(keywords))) {
      names(keywords) <- topic_ids
    } else {
      names(keywords) <- paste0(topic_ids, "_", names(keywords))
    }
  }

  return(keywords)
}


# Abort when `obj` carries element names outside `allowed_arguments`.
# `name` is the label of the argument list used in the error message.
show_unused_arguments <- function(obj, name, allowed_arguments)
{
  supplied <- names(obj)
  unrecognized <- supplied[!(supplied %in% allowed_arguments)]

  # Nothing unexpected: return quietly
  if (length(unrecognized) == 0) {
    return(invisible(NULL))
  }

  msg <- paste0(
    "keyATM doesn't recognize some of the arguments ",
    "in ", name, ": ",
    paste(unrecognized, collapse = ", ")
  )
  cli::cli_abort(msg)
}


# Validate `model_settings` and fill in defaults for the chosen model.
#
# obj:   the user-supplied `model_settings` list.
# model: model name ("base", "lda", "hmm", "ldahmm", "cov", "ldacov").
# info:  list with `use_doc_index`, `num_doc` and `total_k`.
#
# Returns the completed settings list. Aborts on invalid or unrecognized
# settings. Draws random numbers for covariate models (regression check and
# PG initialization), so the call order affects reproducibility.
check_arg_model_settings <- function(obj, model, info)
{
  check_arg_type(obj, "list")
  allowed_arguments <- c()

  if (model %in% c("base", "lda", "hmm", "ldahmm")) {
    # Slice Sampling Settings (bounds must be positive for these models)
    if (is.null(obj$slice_min)) {
      obj$slice_min <- 1e-9
    } else {
      if (!is.numeric(obj$slice_min)) {
        cli::cli_abort("`model_settings$slice_min` should be a numeric value.")
      }
      if (obj$slice_min <= 0) {
        cli::cli_abort("`model_settings$slice_min` should be a positive value.")
      }
    }

    if (is.null(obj$slice_max)) {
      obj$slice_max <- 100
    } else {
      if (!is.numeric(obj$slice_max)) {
        cli::cli_abort("`model_settings$slice_max` should be a numeric value.")
      }
      if (obj$slice_max <= 0) {
        cli::cli_abort("`model_settings$slice_max` should be a positive value.")
      }
    }
    allowed_arguments <- c(allowed_arguments, "slice_min", "slice_max")
  }

  # check model settings for covariate model
  if (model %in% c("cov", "ldacov")) {
    if (is.null(obj$covariates_data)) {
      cli::cli_abort("Please provide `model_settings$covariates_data`.")
    }

    # Keep only the rows of the documents that are actually used
    obj$covariates_data <- as.data.frame(obj$covariates_data)[info$use_doc_index, , drop = FALSE]

    if (nrow(obj$covariates_data) != info$num_doc) {
      cli::cli_abort("The row of `model_settings$covariates_data` should be the same as the number of documents.")
    }

    if (any(is.na(obj$covariates_data))) {
      cli::cli_abort("Covariate data should not contain missing values.")
    }

    # A missing `covariates_formula` stays NULL: the data are used as they are

    if (is.null(obj$standardize))
      obj$standardize <- "non-factor"
    if (!obj$standardize %in% c("all", "none", "non-factor"))
      cli::cli_abort('Unknown option in `standardize`. It should be one of "all", "none", or "non-factor".')

    # Standardize
    obj$covariates_data_use <- covariates_standardize(obj$covariates_data, obj$standardize, obj$covariates_formula)

    # Check if it works as a valid regression: regress random noise on the
    # covariates; an NA coefficient signals an invalid (e.g. collinear) design
    temp <- as.data.frame(obj$covariates_data_use)
    temp$y <- stats::rnorm(nrow(obj$covariates_data_use))

    if ("(Intercept)" %in% colnames(obj$covariates_data_use)) {
      fit <- stats::lm(y ~ 0 + ., data = temp)  # data.frame already includes the intercept
    } else {
      fit <- stats::lm(y ~ ., data = temp)  # data.frame does not have an intercept
    }
    if (anyNA(fit$coefficients)) {
      cli::cli_abort("Covariates are invalid.")
    }

    # Slice Sampling Settings (bounds may be negative for covariate models)
    if (is.null(obj$slice_min)) {
      obj$slice_min <- -5.0
    } else {
      if (!is.numeric(obj$slice_min)) {
        cli::cli_abort("`model_settings$slice_min` should be a numeric value.")
      }
    }

    if (is.null(obj$slice_max)) {
      obj$slice_max <- 5.0
    } else {
      if (!is.numeric(obj$slice_max)) {
        cli::cli_abort("`model_settings$slice_max` should be a numeric value.")
      }
    }

    # MH option
    if (is.null(obj$mh_use)) {
      obj$mh_use <- 0
    } else {
      obj$mh_use <- as.integer(obj$mh_use)
      if (!obj$mh_use %in% c(0, 1)) {
        cli::cli_abort("`model_settings$mh_use` should be TRUE/FALSE (0/1)")
      }
    }

    # Model (both covariate models default to Dirichlet-Multinomial)
    if (is.null(obj$covariates_model)) {
      obj$covariates_model <- "DirMulti"
    }

    if (obj$covariates_model != "DirMulti" && model == "ldacov")
      cli::cli_abort("Use Dirichlet-Multinomial model for LDA covariates.")

    if (!obj$covariates_model %in% c("PG", "DirMulti"))
      cli::cli_abort("Undefined model. `covariates_model` option in `model_settings` take `PG` or `DirMulti`.")

    if (obj$covariates_model == "PG") {
      K <- info$total_k
      X <- obj$covariates_data_use
      M <- ncol(X)  # Number of covariates
      D <- nrow(X)  # Number of documents

      Sigma_Lambda <- diag(rep(1, K - 1))
      if (M == 1) {
        # A single draw comes back as a vector; transpose it into one row
        obj$PG_params$PG_Lambda <- t(MASS::mvrnorm(n = M, mu = rep(0, K-1), Sigma = Sigma_Lambda))
      } else {
        obj$PG_params$PG_Lambda <- MASS::mvrnorm(n = M, mu = rep(0, K-1), Sigma = Sigma_Lambda)
      }
      obj$PG_params$PG_SigmaPhi <- diag(rep(1, K-1))
      Mu <- X %*% obj$PG_params$PG_Lambda
      Sigma <- diag(rep(1, K-1))
      # NOTE(review): assumes `rmvn1()` returns a vector of length K - 1, and
      # that K > 2; if each draw had length 1, `sapply` would return a plain
      # vector and `t(Phi)` would be 1 x D instead of D x 1 -- confirm.
      Phi <- sapply(seq_len(D),
             function(d) {
                  rmvn1(mu = Mu[d, ], Sigma = Sigma)
             })
      Phi <- t(Phi)
      colnames(Phi) <- NULL
      row.names(Phi) <- NULL
      obj$PG_params$PG_Phi <- Phi  # D \times (K - 1)
      obj$PG_params$theta_tilda <- exp(Phi) / (1 + exp(Phi))
      obj$PG_params$theta_last <- matrix(rep(0, D*K), nrow = D, ncol = K)
      obj$PG_params$Lambda_list <- list()
      obj$PG_params$Sigma_list <- list()
    }

    allowed_arguments <- c(allowed_arguments, "covariates_data", "covariates_data_use",
                           "slice_min", "slice_max", "mh_use", "covariates_model", "PG_params",
                           "covariates_formula", "standardize", "info")
  }  # cov model end

  # check model settings for dynamic model
  if (model %in% c("hmm", "ldahmm")) {
    if (is.null(obj$num_states)) {
      cli::cli_abort("`model_settings$num_states` is not provided.")
    }

    if (is.null(obj$time_index)) {
      cli::cli_abort("`model_settings$time_index` is not provided.")
    }

    # Keep only the entries of the documents that are actually used
    obj$time_index <- obj$time_index[info$use_doc_index]

    if (length(obj$time_index) != info$num_doc) {
      cli::cli_abort("The length of the `model_settings$time_index` does not match with the number of documents.")
    }

    if (min(obj$time_index) != 1 || max(obj$time_index) > info$num_doc) {
      cli::cli_abort("`model_settings$time_index` should start from 1 and not exceed the number of documents.")
    }

    if (max(obj$time_index) < obj$num_states)
      cli::cli_abort("`model_settings$num_states` should not exceed the maximum of `model_settings$time_index`.")

    # Consecutive documents must stay in the same period or move to the next.
    # `diff()` yields an empty vector for a single document, which passes.
    steps <- diff(obj$time_index)
    if (!all(steps %in% c(0, 1)))
      cli::cli_abort("`model_settings$time_index` does not increment by 1.")

    obj$time_index <- as.integer(obj$time_index)

    allowed_arguments <- c(allowed_arguments, "num_states", "time_index")
  }

  show_unused_arguments(obj, "`model_settings`", allowed_arguments)
  return(obj)
}


# Validate `priors` and fill in defaults for the chosen model.
#
# obj:   the user-supplied `priors` list.
# model: model name.
# info:  list with `models_keyATM`, `keyword_k` and `total_k`.
#
# Returns the completed priors list. Aborts on wrong dimensions or
# unrecognized entries.
check_arg_priors <- function(obj, model, info)
{
  check_arg_type(obj, "list")
  # Base arguments
  allowed_arguments <- c("beta", "eta_1", "eta_2", "eta_1_regular", "eta_2_regular")

  # prior of pi
  if (model %in% info$models_keyATM) {
    if (is.null(obj$gamma)) {
      obj$gamma <- matrix(1.0, nrow = info$total_k, ncol = 2)
    }

    # gamma must be two-dimensional: total_k rows, 2 columns.
    # Checking the rank first gives non-matrix input the same informative error.
    gamma_dim <- dim(obj$gamma)
    if (length(gamma_dim) != 2 || gamma_dim[1] != info$total_k || gamma_dim[2] != 2) {
      cli::cli_abort("Check the dimension of `priors$gamma`")
    }

    if (info$keyword_k < info$total_k) {
      # Regular topics are used in keyATM models
      # Priors for regular topics should be 0
      regular_rows <- (info$keyword_k + 1):info$total_k
      if (sum(obj$gamma[regular_rows, ]) != 0) {
        obj$gamma[regular_rows, ] <- 0
      }
    }
    allowed_arguments <- c(allowed_arguments, "gamma")
  }

  # beta
  if (!"beta" %in% names(obj)) {
    obj$beta <- 0.01
  }

  if (model %in% info$models_keyATM) {  # models with keywords
    if (!"beta_s" %in% names(obj)) {
      obj$beta_s <- 0.1
    }
    allowed_arguments <- c(allowed_arguments, "beta_s")
  }

  # alpha
  if (model %in% c("base", "lda")) {
    if (is.null(obj$alpha)) {
      obj$alpha <- rep(1/info$total_k, info$total_k)
    }
    if (length(obj$alpha) != info$total_k) {
      # The expected length is interpolated into the message; it was
      # previously passed as a stray positional argument
      cli::cli_abort("Starting alpha must be a vector of length {info$total_k}")
    }
    allowed_arguments <- c(allowed_arguments, "alpha")
  }

  # eta
  if (!"eta_1" %in% names(obj)) {
    obj$eta_1 <- 1.0
  }
  if (!"eta_2" %in% names(obj)) {
    obj$eta_2 <- 1.0
  }
  if (!"eta_1_regular" %in% names(obj)) {
    obj$eta_1_regular <- 2.0
  }
  if (!"eta_2_regular" %in% names(obj)) {
    obj$eta_2_regular <- 1.0
  }

  show_unused_arguments(obj, "`priors`", allowed_arguments)
  return(obj)
}


# Validate `options` and fill in defaults for the chosen model.
#
# obj:   the user-supplied `options` list.
# model: model name.
# info:  list with `models_keyATM`.
#
# Returns the completed options list. Aborts on invalid or unrecognized
# options.
check_arg_options <- function(obj, model, info)
{
  check_arg_type(obj, "list")
  allowed_arguments <- c("seed", "llk_per", "thinning",
                         "iterations", "iter_new", "verbose",
                         "use_weights", "weights_type",
                         "prune", "store_theta", "slice_shape",
                         "parallel_init", "resume")

  # TRUE for a single non-negative whole number. The scalar `&&` chain stops
  # at the first failure, so non-numeric input never reaches the arithmetic.
  is_valid_count <- function(value) {
    is.numeric(value) && length(value) == 1L && !is.na(value) &&
      value >= 0 && value %% 1 == 0
  }

  # Coerce a 0/1 (or TRUE/FALSE) option to integer, or use `default` if unset
  as_flag <- function(value, default, arg) {
    if (is.null(value)) {
      return(default)
    }
    value <- as.integer(value)
    if (length(value) != 1L || !value %in% c(0, 1)) {
      cli::cli_abort(paste0("An invalid value in `options$", arg, "`"))
    }
    value
  }

  # llk_per
  if (is.null(obj$llk_per))
    obj$llk_per <- 10L

  if (!is_valid_count(obj$llk_per)) {
    cli::cli_abort("An invalid value in `options$llk_per`")
  }

  # verbose
  obj$verbose <- as_flag(obj$verbose, 0L, "verbose")

  # thinning
  if (is.null(obj$thinning))
    obj$thinning <- 5L

  if (!is_valid_count(obj$thinning)) {
    cli::cli_abort("An invalid value in `options$thinning`")
  }

  # iterations
  if (is.null(obj$iterations)) {
    obj$iterations <- 1500L
  }
  if (!is_valid_count(obj$iterations)) {
    cli::cli_abort("An invalid value in `options$iterations`")
  }
  obj$iter_new <- obj$iterations  # this is initialization so new iteration = total number of iterations

  # Store theta
  obj$store_theta <- as_flag(obj$store_theta, 0L, "store_theta")

  # Store pi
  if (model %in% info$models_keyATM) {
    obj$store_pi <- as_flag(obj$store_pi, 0L, "store_pi")
    allowed_arguments <- c(allowed_arguments, "store_pi")
  }

  # Estimate alpha
  if (model %in% c("base", "lda")) {
    obj$estimate_alpha <- as_flag(obj$estimate_alpha, 1L, "estimate_alpha")
    allowed_arguments <- c(allowed_arguments, "estimate_alpha")
  }

  # Slice shape
  if (is.null(obj$slice_shape)) {
    # parameter for slice sampling
    obj$slice_shape <- 1.2
  }
  if (!is.numeric(obj$slice_shape) || obj$slice_shape < 0) {
    cli::cli_abort("An invalid value in `options$slice_shape`")
  }

  # Use weights
  obj$use_weights <- as_flag(obj$use_weights, 1L, "use_weights")

  # Type of the weights
  if (is.null(obj$weights_type)) {
    obj$weights_type <- "information-theory"
  } else {
    if (!obj$weights_type %in% c("information-theory", "information-theory-normalized",
                                 "inv-freq", "inv-freq-normalized"))
    {
      cli::cli_abort("An invalid value in `options$weights_type`")
    }
  }

  # Prune keywords
  obj$prune <- as_flag(obj$prune, 1L, "prune")

  # Store transition matrix in Dynamic models
  # (the value is validated but not coerced, as before)
  if (model %in% c("hmm", "ldahmm")) {
    if (is.null(obj$store_transition_matrix)) {
      obj$store_transition_matrix <- 0L
    }
    if (!obj$store_transition_matrix %in% c(0, 1)) {
      cli::cli_abort("An invalid value in `options$store_transition_matrix`")
    }
    allowed_arguments <- c(allowed_arguments, "store_transition_matrix")
  }

  # Use parallel function in initialization
  if (!"parallel_init" %in% names(obj)) {
    obj$parallel_init <- FALSE
  } else {
    if (!obj$parallel_init %in% c(0, 1, FALSE, TRUE)) {
      cli::cli_abort("`options$parallel_init` should be TRUE/FALSE")
    }
  }

  # Resume: defaults to an empty string when not supplied
  if (!"resume" %in% names(obj)) {
    obj$resume <- ""
  }

  # Check unused arguments
  show_unused_arguments(obj, "`options`", allowed_arguments)
  return(obj)
}


# Sanity-check a vocabulary (character vector of unique tokens).
# Aborts when a single space or an empty string is a token; warns when a
# token consists only of upper-case letters or contains a tab or a line break.
check_vocabulary <- function(vocab)
{
  # Number of vocabulary entries matching a pattern
  n_matching <- function(pattern) {
    sum(stringr::str_detect(vocab, pattern))
  }

  if (is.element(" ", vocab)) {
    cli::cli_abort("A space is recognized as a vocabulary. Please remove an empty document or consider using quanteda::dfm.")
  }

  if (is.element("", vocab)) {
    cli::cli_abort('A blank `""` is recognized as a vocabulary. Please review preprocessing steps.')
  }

  if (n_matching("^[:upper:]+$") != 0) {
    cli::cli_warn("Upper case letters are used. Please review preprocessing steps.", immediate. = TRUE)
  }

  if (n_matching("\t") != 0) {
    cli::cli_warn("Tab is detected in the vocabulary. Please review preprocessing steps.", immediate. = TRUE)
  }

  if (n_matching("\n") != 0) {
    cli::cli_warn("A line break is detected in the vocabulary. Please review preprocessing steps.", immediate. = TRUE)
  }
}


# Create the initial S and Z assignments for keyword (keyATM) models.
#
# W:        list of documents, each a vector of word ids.
# keywords: list of keyword vectors, one per keyword topic; only the number
#           of keywords per topic is used here.
# info:     list providing `keywords_id` (word ids of all keywords, flattened),
#           `keyword_k`, `total_k` and `parallel_init`.
#
# Returns list(S = ..., Z = ...), each a list with one vector per document.
# The values come from `sample()`, so they depend on the RNG state and on the
# order of the calls below.
make_sz_key <- function(W, keywords, info)
{
  # zs_assigner maps keywords to category ids
  key_wdids <- info$keywords_id
  # 0-based topic id of every keyword, aligned with `key_wdids`
  cat_ids <- rep(1:(info$keyword_k) - 1L, unlist(lapply(keywords, length)))

  if (length(key_wdids) == length(unique(key_wdids))) {
    #
    # No keyword appears more than once
    #
    zs_assigner <- myhashmap_keyint(as.integer(key_wdids), as.integer(cat_ids))

    # if the word is a keyword, assign the appropriate (0 start) Z, else a random Z
    # (presumably the lookup yields NA for ids that are not keywords -- the NA
    # entries are the ones replaced by a random topic; confirm in the hashmap helper)
    topicvec <- 1:(info$total_k) - 1L
    make_z <- function(s, topicvec) {
      zz <- myhashmap_getvec_keyint(zs_assigner, s)
      zz[is.na(zz)] <- sample(topicvec, sum(is.na(zz)), replace = TRUE)
      return(zz)
    }

  } else {
    #
    # Some keywords appear multiple times
    #
    keys_df <- data.frame(wid = key_wdids, cat = cat_ids)
    # For each distinct keyword id, its candidate topics as one
    # comma-separated string
    keys_char <- sapply(unique(key_wdids),
                        function(x) {
                          paste(as.character(keys_df[keys_df$wid == x, "cat"]), collapse=",")
                        })
    zs_hashtable <- myhashmap_keyint(as.integer(unique(key_wdids)), keys_char)

    # Look up the candidate topics of each word and pick one at random
    zs_assigner <- function(s) {
      topic <- myhashmap_getvec_keyint(zs_hashtable, s)
      topic <- strsplit(as.character(topic), split=",")  # this function should take a character vector
      topic <- lapply(topic, sample, 1)
      topic <- as.integer(unlist(topic, use.names = FALSE))
      return(topic)
    }

    # if the word is a seed, assign the appropriate (0 start) Z, else a random Z
    topicvec <- 1:(info$total_k) - 1L
    make_z <- function(s, topicvec) {
      zz <- zs_assigner(s) # if it is a seed word, we already know the topic
      zz[is.na(zz)] <- sample(topicvec, sum(is.na(zz)), replace = TRUE)
      return(zz)
    }
  }

  ## ss indicates whether the word comes from a seed topic-word distribution or not
  make_s <- function(s) {
    key <- as.numeric(s %in% key_wdids) # 1 if they're a seed
    # Use s structure
    s[key == 0] <- 0L # non-keyword words have s = 0
    # keyword positions get 1 with probability 0.7 and 0 with probability 0.3
    s[key == 1] <- sample(0:1, length(s[key == 1]), prob = c(0.3, 0.7), replace = TRUE)
      # keywords have x = 1 probabilistically
    return(s)
  }

  # S is generated for all documents before Z in both branches
  if (info$parallel_init) {
    S <- future.apply::future_lapply(W, make_s, future.seed = TRUE)
    Z <- future.apply::future_lapply(W, make_z, topicvec, future.seed = TRUE)
  } else {
    S <- lapply(W, make_s)
    Z <- lapply(W, make_z, topicvec)
  }

  return(list(S = S, Z = Z))
}


# Create the initial topic assignments for LDA-type models.
#
# W:    list of documents, each a vector of word ids.
# info: list providing `total_k` and `parallel_init`.
#
# Returns list(S = list(), Z = ...), where Z holds, for every document, one
# randomly drawn 0-based topic id per word.
make_sz_lda <- function(W, info)
{
  # 0-based topic ids
  topic_ids <- 0:(info$total_k - 1L)

  # Draw one topic uniformly at random for every word of a document
  draw_topics <- function(doc, topics) {
    as.integer(sample(topics, length(doc), replace = TRUE))
  }

  Z <- if (info$parallel_init) {
    future.apply::future_lapply(W, draw_topics, topic_ids, future.seed = TRUE)
  } else {
    lapply(W, draw_topics, topic_ids)
  }

  list(S = list(), Z = Z)
}
keyATM/keyATM documentation built on June 15, 2025, 7:16 a.m.