#' Read texts
#'
#' Read texts and create a \code{keyATM_docs} object, which is a list of texts.
#'
#'
#' @param texts input. keyATM takes a quanteda dfm (dgCMatrix), data.frame, \pkg{tibble} tbl_df, or a vector of file paths.
#' @param encoding character. Only used when \code{texts} is a vector of file paths. Default is \code{UTF-8}.
#' @param check logical. If \code{TRUE}, check whether there is anything wrong with the structure of texts. Default is \code{TRUE}.
#' @param keep_docnames logical. If \code{TRUE}, it keeps the document names in a quanteda dfm. Default is \code{FALSE}.
#' @param split numeric. This option works only with a quanteda dfm. It creates two subsets of the dfm by randomly splitting each document (i.e., the total number of documents is the same between the two subsets). This option specifies the split proportion and must be a single number in \code{[0, 1)}. Default is \code{0} (no split).
#'
#' @return a keyATM_docs object. The first element is a list whose elements are split texts. The length of the list equals the number of documents.
#'
#' @examples
#' \dontrun{
#' # Use quanteda dfm
#' keyATM_docs <- keyATM_read(texts = quanteda_dfm)
#'
#' # Use data.frame or tibble (texts should be stored in a column named `text`)
#' keyATM_docs <- keyATM_read(texts = data_frame_object)
#' keyATM_docs <- keyATM_read(texts = tibble_object)
#'
#' # Use a vector that stores full paths to the text files
#' files <- list.files(doc_folder, pattern = glob2rx("*.txt"), full.names = TRUE)
#' keyATM_docs <- keyATM_read(texts = files)
#'
#' }
#' @import magrittr
#' @importFrom rlang .data
#' @export
keyATM_read <- function(texts, encoding = "UTF-8", check = TRUE, keep_docnames = FALSE, split = 0)
{
  # Detect input. Exactly one of `text_dfm`, `files`, `text_df` is set below.
  # The tibble test must come before the data.frame test because a tibble
  # also inherits from "data.frame".
  text_dfm <- NULL
  files <- NULL
  text_df <- NULL
  if (inherits(texts, "tbl")) {
    cli::cli_alert_info("Using tibble.")
    text_df <- texts
  } else if (inherits(texts, "data.frame")) {
    cli::cli_alert_info("Using data.frame.")
    text_df <- tibble::as_tibble(texts)
  } else if (inherits(texts, c("dfm", "dgCMatrix"))) {
    cli::cli_alert_info("Using quanteda dfm.")
    text_dfm <- texts
  } else if (inherits(texts, "character")) {
    cli::cli_alert_info("Reading from files. Please make sure files are preprocessed. Encoding: {encoding}.")
    files <- texts
  } else {
    cli::cli_abort(c("x" = "Check `texts` argument.",
                     "i" = "It can take quanteda dfm, data.frame, tibble, and a vector of characters."))
  }

  # `split` must be a single proportion in [0, 1); 0 means "no split".
  # `||` short-circuits, so NA or non-numeric values reach this error instead
  # of failing inside `if()`.
  if (!is.numeric(split) || length(split) != 1L || is.na(split) ||
      split < 0 || split >= 1) {
    cli::cli_abort("Invalid option. `split` should be a proportion.")
  }

  # Read texts
  docnames <- NULL
  W_read <- list(W_raw = list(), W_split = list())
  if (!is.null(text_dfm)) {
    # quanteda dfm: tokens (and the optional split) are produced by read_dfm_cpp()
    vocabulary <- colnames(text_dfm)
    W_read <- read_dfm_cpp(text_dfm, W_read, vocabulary, split)
    if (keep_docnames) {
      docnames <- quanteda::docnames(text_dfm)
    }
  } else {
    # Files are first read into a one-column tibble so that both non-dfm
    # paths share the tokenization below; each file becomes one document
    # (its lines joined with "\n").
    if (is.null(text_df)) {
      text_df <- tibble::tibble(text = unlist(lapply(files, function(x) {
        paste0(readLines(x, encoding = encoding), collapse = "\n")
      })))
    }
    # Tokenize on single spaces; texts are expected to be preprocessed already
    text_df <- text_df %>% dplyr::mutate(text_split = stringr::str_split(.data$text, pattern = " "))
    # Extract split text and create a list
    W_read$W_raw <- text_df %>% dplyr::pull(tidyselect::all_of("text_split"))
  }
  W_raw <- W_read$W_raw

  # Optional structural checks; they also cache the vocabulary and the
  # positions of non-empty documents for later use
  if (check) {
    wd_names <- unique(unlist(W_raw, use.names = FALSE, recursive = FALSE))
    check_vocabulary(wd_names)
    doc_index <- get_doc_index(W_raw, check = TRUE)
  } else {
    doc_index <- NULL
    wd_names <- NULL
  }

  W_split <- list(W_raw = W_read$W_split)
  class(W_split) <- c("keyATM_docs", class(W_split))
  W_data <- list(W_raw = W_raw, W_split = W_split, doc_index = doc_index,
                 wd_names = wd_names, docnames = docnames)
  class(W_data) <- c("keyATM_docs", class(W_data))
  W_data
}
#' @noRd
#' @export
print.keyATM_docs <- function(x, ...)
{
  # Print a one-line description of the documents object and return `x`
  # invisibly, as print methods are expected to do.
  cli::cli_inform("keyATM_docs object of {length(x$W_raw)} documents.")
  invisible(x)
}
#' @noRd
#' @export
summary.keyATM_docs <- function(object, ...)
{
  # Report the number of documents, document-length statistics, and the size
  # of the cached vocabulary `wd_names`, which is NULL when keyATM_read() was
  # called with check = FALSE. Returns `object` invisibly.
  # lengths() always yields an integer vector, whereas sapply() returns
  # list() when there are no documents.
  doc_len <- lengths(object$W_raw, use.names = FALSE)
  cli::cli_inform("keyATM_docs object of {length(object$W_raw)} documents.")
  ul <- cli::cli_ul()
  cli::cli_li("Average (min/max) document length: {round(mean(doc_len), 2)} ({min(doc_len)}/{max(doc_len)}) words")
  cli::cli_li("Number of unique words: {length(object$wd_names)}")
  cli::cli_end(ul)
  invisible(object)
}
#' Visualize keywords
#'
#' Visualize the proportion of keywords in the documents.
#'
#' @param docs a keyATM_docs object, generated by \code{keyATM_read()} function
#' @param keywords a list of keywords
#' @param prune logical. If \code{TRUE}, prune keywords that do not appear in `docs`. Default is \code{TRUE}.
#' @param label_size the size of keyword labels in the output plot. Default is \code{3.2}.
#' @return keyATM_fig object
#' @examples
#' \dontrun{
#' # Prepare a keyATM_docs object
#' keyATM_docs <- keyATM_read(input)
#'
#' # Keywords are in a list
#' keywords <- list(Education = c("education", "child", "student"),
#' Health = c("public", "health", "program"))
#'
#' # Visualize keywords
#' keyATM_viz <- visualize_keywords(keyATM_docs, keywords)
#'
#' # View a figure
#' keyATM_viz
#'
#' # Save a figure
#' save_fig(keyATM_viz, filename)
#' }
#' @import magrittr
#' @import ggplot2
#' @importFrom rlang .data
#' @seealso [save_fig()]
#' @export
visualize_keywords <- function(docs, keywords, prune = TRUE, label_size = 3.2)
{
  # Check type
  check_arg_type(docs, "keyATM_docs", "Please use `keyATM_read()` to read texts.")
  check_arg_type(keywords, "list")
  # Every topic must be a character vector. Only the side effect (an error on
  # failure) matters, so nothing is stored; storing it in `c` would mask base::c.
  for (topic_keywords in keywords) {
    check_arg_type(topic_keywords, "character")
  }
  unlisted <- unlist(docs$W_raw, recursive = FALSE, use.names = FALSE)

  # Check keywords (prune, or abort on, keywords absent from the corpus)
  keywords <- check_keywords(unique(unlisted), keywords, prune)

  # Corpus-wide word frequencies: count, percentage share and overall rank
  unnested_data <- tibble::tibble(text_split = unlisted)
  totalwords <- nrow(unnested_data)
  data <- unnested_data %>%
    dplyr::rename(Word = "text_split") %>%
    dplyr::group_by(.data$Word) %>%
    dplyr::summarize(WordCount = dplyr::n()) %>%
    dplyr::mutate(`Proportion(%)` = round(.data$WordCount / totalwords * 100, 3)) %>%
    dplyr::arrange(dplyr::desc(.data$WordCount)) %>%
    dplyr::mutate(Ranking = seq_len(dplyr::n()))

  # Multi-word keywords are split on spaces and plotted word by word
  keywords <- lapply(keywords, function(x) {unlist(strsplit(x, " "))})
  ext_k <- length(keywords)
  words_per_topic <- lengths(keywords, use.names = FALSE)
  max_num_words <- max(words_per_topic)

  # Long table with one row per (topic, keyword), built in a single step
  # instead of growing a data.frame with rbind() inside a loop
  tnames <- names(keywords)
  if (is.null(tnames)) {
    topic_labels <- paste0("Topic", seq_len(ext_k))
  } else {
    topic_labels <- paste0(seq_len(ext_k), "_", tnames)
  }
  topic_col <- rep(topic_labels, times = words_per_topic)
  keywords_df <- data.frame(
    Topic = factor(topic_col, levels = unique(topic_col)),
    Word = unlist(keywords, use.names = FALSE),
    stringsAsFactors = FALSE
  )

  # Rank each keyword within its topic by corpus frequency
  temp <- dplyr::right_join(data, keywords_df, by = "Word") %>%
    dplyr::group_by(.data$Topic) %>%
    dplyr::arrange(dplyr::desc(.data$WordCount)) %>%
    dplyr::mutate(Ranking = seq_len(dplyr::n())) %>%
    dplyr::arrange(.data$Topic, .data$Ranking)

  # Visualize
  fig <- ggplot(temp, aes(x = .data$Ranking, y = .data$`Proportion(%)`, colour = .data$Topic)) +
    geom_line() +
    geom_point() +
    ggrepel::geom_label_repel(aes(label = .data$Word), size = label_size,
                              box.padding = 0.20, label.padding = 0.12,
                              arrow = arrow(angle = 10, length = unit(0.10, "inches"),
                                            ends = "last", type = "closed"),
                              show.legend = FALSE) +
    scale_x_continuous(breaks = seq_len(max_num_words)) +
    xlab("Ranking") + ylab("Proportion (%)") +
    theme_bw()

  keyATM_viz <- list(figure = fig, values = temp, keywords = keywords)
  class(keyATM_viz) <- c("keyATM_fig", "keyATM_viz", class(keyATM_viz))
  keyATM_viz
}
check_keywords <- function(unique_words, keywords, prune)
{
  # Validate `keywords` (a list of character vectors, one per topic) against
  # the corpus vocabulary `unique_words`.
  #   prune = TRUE : drop keywords absent from the corpus, with a warning
  #   prune = FALSE: abort if any keyword is absent
  # In both cases, abort if a topic ends up with no keywords.
  # Returns the (possibly pruned) keyword list.
  keywords_flat <- unlist(keywords, use.names = FALSE, recursive = FALSE)
  # unique() so that a keyword shared by several topics is reported once
  non_existent <- unique(keywords_flat[!keywords_flat %in% unique_words])
  if (prune) {
    # Prune keywords
    if (length(non_existent) != 0) {
      cli::cli_warn(
        "{?A keyword/Keywords} {?is/are} pruned because {?it does/they do} not appear in the documents: {non_existent}",
        immediate. = TRUE
      )
    }
    keywords <- lapply(keywords, function(x) {x[!x %in% non_existent]})
  } else {
    # Raise error
    if (length(non_existent) != 0) {
      cli::cli_abort("{?A keyword/Keywords} {?is/are} not found in the documents: {non_existent}", immediate. = TRUE)
    }
  }

  # Check there is at least one keyword in each topic
  num_keywords <- lengths(keywords, use.names = FALSE)
  check_zero <- which(num_keywords == 0)
  if (length(check_zero) != 0) {
    # Identify empty topics by name, falling back to their position when the
    # list is unnamed or a name is blank
    zero_names <- names(keywords)[check_zero]
    if (is.null(zero_names)) {
      zero_names <- as.character(check_zero)
    } else {
      blank <- is.na(zero_names) | zero_names == ""
      zero_names[blank] <- as.character(check_zero[blank])
    }
    zero_label <- paste(zero_names, collapse = ", ")
    # Interpolated rather than pasted into the template, so braces inside
    # topic names are not evaluated as cli/glue expressions
    cli::cli_abort("All keywords are pruned. Please check: {zero_label}")
  }
  keywords
}
get_doc_index <- function(W_raw, check = FALSE)
{
  # Return the (1-based) positions of the non-empty documents in `W_raw`
  # and warn about empty ones. `check` only selects the warning wording:
  # TRUE for the read-time check, FALSE for the message saying the empty
  # documents are dropped.
  doc_len <- lengths(W_raw, use.names = FALSE)
  nonzero_index <- which(doc_len != 0)
  zero_index <- as.character(which(doc_len == 0)) # as.character for use in `cli::cli_warn`
  if (length(zero_index) != 0) {
    if (check) {
      cli::cli_warn(c(
        "There {?is a/are} document{?s} of length 0. {?Index/Indexes} to check: {zero_index}",
        "i" = "This may cause invalid covariates or time indexes. Please review the preprocessing steps."
      ), immediate. = TRUE)
    } else {
      cli::cli_warn("There {?is a/are} document{?s} dropped because of 0 length. {?Index/Indexes} to check: {zero_index}", immediate. = TRUE)
    }
  }
  nonzero_index
}
#' Initialize a keyATM model
#'
#' keyATM_initialize is wrapped by keyATM() and weightedLDA()
#' @keywords internal
keyATM_initialize <- function(docs, model, no_keyword_topics,
                              keywords = list(), model_settings = list(),
                              priors = list(), options = list())
{
  # Validate the inputs, drop empty documents, and build the initial sampler
  # state. Returns a "keyATM_initialized" list holding the "keyATM_model"
  # object (`model`) and the model name (`model_name`).

  ##
  ## Check
  ##
  # Check type
  check_arg_type(docs, "keyATM_docs", "Please use `keyATM_read()` to read texts.")
  # is.numeric() is TRUE for both integer and double input
  if (!is.numeric(no_keyword_topics))
    cli::cli_abort("`no_keyword_topics` is neither numeric nor integer.")
  # It is added to the number of keyword topics to form `total_k`, so it must
  # be one non-negative whole number (as.integer() would silently truncate)
  if (length(no_keyword_topics) != 1L || !is.finite(no_keyword_topics) ||
      no_keyword_topics < 0 || no_keyword_topics %% 1 != 0)
    cli::cli_abort("`no_keyword_topics` should be a single non-negative integer.")
  no_keyword_topics <- as.integer(no_keyword_topics)
  if (!model %in% c("base", "cov", "hmm", "lda", "ldacov", "ldahmm")) {
    cli::cli_abort("Please select a correct model.")
  }
  info <- list(
    models_keyATM = c("base", "cov", "hmm"),
    models_lda = c("lda", "ldacov", "ldahmm")
  )
  keywords <- check_arg(keywords, "keywords", model, info)

  # Get Info: positions of non-empty documents (computed here when
  # keyATM_read() was called with check = FALSE)
  if (is.null(docs$doc_index)) {
    info$use_doc_index <- get_doc_index(docs$W_raw)
  } else {
    info$use_doc_index <- docs$doc_index
    if (length(docs$doc_index) != length(docs$W_raw))
      cli::cli_warn("Some documents have 0 length. Please review the preprocessing steps.")
  }
  docs$W_raw <- docs$W_raw[info$use_doc_index]
  info$num_doc <- length(docs$W_raw)
  info$keyword_k <- length(keywords)
  info$total_k <- length(keywords) + no_keyword_topics

  # Set default values
  model_settings <- check_arg(model_settings, "model_settings", model, info)
  priors <- check_arg(priors, "priors", model, info)
  options <- check_arg(options, "options", model, info)
  info$parallel_init <- options$parallel_init
  if (info$parallel_init) {
    cli::cli_alert_info("Parallel initialization is enabled. Please make sure to specify the {.fn future::plan} before calling the {.fn keyATM}.")
  }

  ##
  ## Initialization
  ##
  cli::cli_progress_step("Initializing the model", spinner = TRUE)
  set.seed(options$seed)
  # W: each document mapped to word ids through `wd_map` (ids 0 .. V - 1)
  if (is.null(docs$wd_names)) {
    info$wd_names <- unique(unlist(docs$W_raw, use.names = FALSE, recursive = FALSE))
    check_vocabulary(info$wd_names)
  } else {
    info$wd_names <- docs$wd_names
  }
  info$wd_map <- myhashmap(info$wd_names, seq_along(info$wd_names) - 1L)
  if (info$parallel_init) {
    W <- future.apply::future_lapply(docs$W_raw, function(x) { myhashmap_getvec(info$wd_map, x) }, future.seed = TRUE)
  } else {
    W <- lapply(docs$W_raw, function(x) { myhashmap_getvec(info$wd_map, x) })
  }

  # Check keywords
  keywords <- check_keywords(info$wd_names, keywords, options$prune)
  keywords_raw <- keywords # keep raw keywords (not word_id)
  keywords_id <- lapply(keywords, function(x) { myhashmap_getvec(info$wd_map, x) })
  info$keywords_id <- unlist(keywords_id, use.names = FALSE, recursive = FALSE)

  # Assign S and Z
  if (model %in% info$models_keyATM) {
    res <- make_sz_key(W, keywords, info)
  } else {
    # LDA based models
    res <- make_sz_lda(W, info)
  }
  S <- res$S
  Z <- res$Z

  # Organize. The document metadata keeps every element of `docs` except the
  # tokens (stored as `W`) and the document index (stored as `doc_index`).
  stored_values <- list(vocab_weights = rep(-1, length(info$wd_names)),
                        doc_index = info$use_doc_index,
                        keyATMdoc_meta = docs[setdiff(names(docs), c("W_raw", "doc_index"))])
  if (model %in% c("base", "lda")) {
    if (options$estimate_alpha)
      stored_values$alpha_iter <- list()
  }
  if (model %in% c("hmm", "ldahmm")) {
    options$estimate_alpha <- 1
    stored_values$alpha_iter <- list()
  }
  if (model %in% c("cov", "ldacov")) {
    stored_values$Lambda_iter <- list()
  }
  if (model %in% c("hmm", "ldahmm")) {
    stored_values$R_iter <- list()
    if (options$store_transition_matrix) {
      stored_values$P_iter <- list()
    } else {
      stored_values$P_last <- list() # for resume
    }
  }
  if (model %in% info$models_keyATM) {
    if (options$store_pi)
      stored_values$pi_vectors <- list()
  }
  if (options$store_theta) {
    stored_values$Z_tables <- list()
    stored_values$theta_PG <- list()
  }

  key_model <- list(
    W = W, Z = Z, S = S,
    model = abb_model_name(model),
    keywords = keywords_id, keywords_raw = keywords_raw,
    no_keyword_topics = no_keyword_topics,
    keyword_k = length(keywords_raw),
    vocab = info$wd_names,
    model_settings = model_settings, priors = priors, options = options,
    stored_values = stored_values, model_fit = list(), call = match.call()
  )
  class(key_model) <- c("keyATM_model", model, class(key_model))
  keyATM_initialized <- list(model = key_model, model_name = model)
  class(keyATM_initialized) <- c("keyATM_initialized", class(keyATM_initialized))
  keyATM_initialized
}
keyATM_fit <- function(keyATM_initialized, resume = FALSE)
{
  # Run the sampler that matches the model stored in `keyATM_initialized`
  # (an initialized or a resumed object) and return the fitted model,
  # tagged with the "keyATM_fitted" class.
  if (!inherits(keyATM_initialized, c("keyATM_initialized", "keyATM_resume")))
    cli::cli_abort("The input is not an initialized object.")

  key_model <- keyATM_initialized$model
  model_name <- keyATM_initialized$model_name
  iterations <- key_model$options$iter_new

  if (!inherits(keyATM_initialized, "keyATM_resume")) {
    set.seed(key_model$options$seed)
    cli::cli_progress_step("Fitting the model: {iterations} iteration{?s}", spinner = TRUE)
  } else {
    # Restore the RNG state saved with the object (to resume)
    .GlobalEnv$.Random.seed <- keyATM_initialized$rand_state
    cli::cli_progress_step("Fitting the model: adding {iterations} iteration{?s} to the existing {key_model$options$iterations - iterations} iteration{?s} for a total of {key_model$options$iterations} iteration{?s}", spinner = TRUE)
  }

  # Pick the fitting routine; NULL means no routine matches
  fit_fun <- switch(
    model_name,
    base = keyATM_fit_base,
    hmm = keyATM_fit_HMM,
    lda = keyATM_fit_LDA,
    ldacov = keyATM_fit_LDAcov,
    ldahmm = keyATM_fit_LDAHMM,
    cov = {
      covariates_model <- key_model$model_settings$covariates_model
      if (covariates_model == "PG") {
        keyATM_fit_covPG
      } else if (covariates_model == "DirMulti") {
        keyATM_fit_cov
      } else {
        NULL
      }
    },
    NULL
  )
  if (is.null(fit_fun))
    cli::cli_abort("Please check {.var model}.")

  key_model <- fit_fun(key_model, resume = resume)
  if (!inherits(key_model, "keyATM_fitted"))
    class(key_model) <- c("keyATM_fitted", class(key_model))
  key_model
}
#' @noRd
#' @export
print.keyATM_model <- function(x, ...)
{
  # Print a one-line description of the model object and return `x`
  # invisibly, as print methods are expected to do.
  cli::cli_inform("keyATM_model object for the {x$model} model.")
  invisible(x)
}
check_arg <- function(obj, name, model, info = list())
{
  # Route `obj` to the validator for the argument group `name`. An unknown
  # `name` matches no branch and yields an invisible NULL.
  switch(
    name,
    keywords = check_arg_keywords(obj, model, info),
    model_settings = check_arg_model_settings(obj, model, info),
    priors = check_arg_priors(obj, model, info),
    options = check_arg_options(obj, model, info),
    vb_options = check_arg_vboptions(obj, model, info)
  )
}
check_arg_keywords <- function(keywords, model, info)
{
  # Validate `keywords` for `model`: keyATM models (info$models_keyATM)
  # require keywords; LDA models (info$models_lda) must not receive any.
  # For keyATM models, topic names are prefixed with the topic index
  # ("1", "2", ... or "1_name", "2_name", ...). Returns the keyword list.
  check_arg_type(keywords, "list")
  if (length(keywords) == 0 && model %in% info$models_keyATM) {
    cli::cli_abort("Please provide keywords.")
  }
  if (length(keywords) != 0 && model %in% info$models_lda) {
    cli::cli_abort("This model does not take keywords.")
  }
  # Name of keywords topic
  if (model %in% info$models_keyATM) {
    # Each topic must be a character vector; only the check's side effect
    # (an error on failure) matters, so nothing is stored
    for (topic_keywords in keywords) {
      check_arg_type(topic_keywords, "character")
    }
    topic_ids <- as.character(seq_along(keywords))
    if (is.null(names(keywords))) {
      names(keywords) <- topic_ids
    } else {
      names(keywords) <- paste0(topic_ids, "_", names(keywords))
    }
  }
  keywords
}
show_unused_arguments <- function(obj, name, allowed_arguments)
{
  # Abort if `obj` has elements whose names are not in `allowed_arguments`.
  # `name` is the label of the argument group used in the message.
  # Returns NULL invisibly when every name is allowed.
  unused_input <- names(obj)[! names(obj) %in% allowed_arguments]
  if (length(unused_input) != 0) {
    unused_label <- paste(unused_input, collapse = ", ")
    # Interpolated rather than pasted into the template, so user-supplied
    # names containing `{` or `}` are not evaluated as cli/glue expressions
    cli::cli_abort("keyATM doesn't recognize some of the arguments in {name}: {unused_label}")
  }
  invisible(NULL)
}
check_arg_model_settings <- function(obj, model, info)
{
  # Validate `model_settings` for `model` and fill in defaults. Covariate
  # rows and time indexes are subset to the documents in
  # `info$use_doc_index` (positions in the original input). Aborts on
  # invalid or unrecognized settings; returns the completed list.
  check_arg_type(obj, "list")
  allowed_arguments <- c()

  if (model %in% c("base", "lda", "hmm", "ldahmm")) {
    # Slice Sampling Settings
    if (is.null(obj$slice_min)) {
      obj$slice_min <- 1e-9
    } else {
      if (!is.numeric(obj$slice_min)) {
        cli::cli_abort("`model_settings$slice_min` should be a numeric value.")
      }
      if (obj$slice_min <= 0) {
        cli::cli_abort("`model_settings$slice_min` should be a positive value.")
      }
    }
    if (is.null(obj$slice_max)) {
      obj$slice_max <- 100
    } else {
      if (!is.numeric(obj$slice_max)) {
        cli::cli_abort("`model_settings$slice_max` should be a numeric value.")
      }
      if (obj$slice_max <= 0) {
        cli::cli_abort("`model_settings$slice_max` should be a positive value.")
      }
    }
    allowed_arguments <- c(allowed_arguments, "slice_min", "slice_max")
  }

  # check model settings for covariate model
  if (model %in% c("cov", "ldacov")) {
    if (is.null(obj$covariates_data)) {
      cli::cli_abort("Please provide `obj$covariates_data`.")
    }
    covariates_all <- as.data.frame(obj$covariates_data)
    # Rows are picked by original document position; an index beyond the
    # supplied rows would otherwise silently turn into a row of NAs
    if (any(info$use_doc_index > nrow(covariates_all))) {
      cli::cli_abort("The row of `model_settings$covariates_data` should be the same as the number of documents.")
    }
    obj$covariates_data <- covariates_all[info$use_doc_index, , drop = FALSE]
    if (nrow(obj$covariates_data) != info$num_doc) {
      cli::cli_abort("The row of `model_settings$covariates_data` should be the same as the number of documents.")
    }
    if (sum(is.na(obj$covariates_data)) != 0) {
      cli::cli_abort("Covariate data should not contain missing values.")
    }
    if (is.null(obj$standardize))
      obj$standardize <- "non-factor"
    if (!obj$standardize %in% c("all", "none", "non-factor"))
      cli::cli_abort('Unknown option in `standardize`. It should be one of "all", "none", or "non-factor".')
    # Standardize (a NULL `covariates_formula` leaves the data matrix as is)
    obj$covariates_data_use <- covariates_standardize(obj$covariates_data, obj$standardize, obj$covariates_formula)

    # Check that the covariates form a valid regression design: fit a
    # throwaway regression on random noise; lm() reports NA coefficients for
    # linearly dependent (aliased) columns.
    temp <- as.data.frame(obj$covariates_data_use)
    temp$y <- stats::rnorm(nrow(obj$covariates_data_use))
    if ("(Intercept)" %in% colnames(obj$covariates_data_use)) {
      fit <- stats::lm(y ~ 0 + ., data = temp) # data.frame already includes the intercept
    } else {
      fit <- stats::lm(y ~ ., data = temp) # data.frame does not have an intercept
    }
    if (anyNA(fit$coefficients)) {
      cli::cli_abort("Covariates are invalid.")
    }

    # Slice Sampling Settings
    if (is.null(obj$slice_min)) {
      obj$slice_min <- -5.0
    } else {
      if (!is.numeric(obj$slice_min)) {
        cli::cli_abort("`model_settings$slice_min` should be a numeric value.")
      }
    }
    if (is.null(obj$slice_max)) {
      obj$slice_max <- 5.0
    } else {
      if (!is.numeric(obj$slice_max)) {
        cli::cli_abort("`model_settings$slice_max` should be a numeric value.")
      }
    }

    # MH option
    if (is.null(obj$mh_use)) {
      obj$mh_use <- 0
    } else {
      obj$mh_use <- as.integer(obj$mh_use)
      if (!obj$mh_use %in% c(0, 1)) {
        cli::cli_abort("`model_settings$mh_use` should be TRUE/FALSE (0/1)")
      }
    }

    # Model (both "cov" and "ldacov" default to Dirichlet-Multinomial)
    if (is.null(obj$covariates_model)) {
      obj$covariates_model <- "DirMulti"
    }
    if (obj$covariates_model != "DirMulti" && model == "ldacov")
      cli::cli_abort("Use Dirichlet-Multinomial model for LDA covariates.")
    if (!obj$covariates_model %in% c("PG", "DirMulti"))
      cli::cli_abort("Undefined model. `covariates_model` option in `model_settings` take `PG` or `DirMulti`.")

    if (obj$covariates_model == "PG") {
      K <- info$total_k
      X <- obj$covariates_data_use
      M <- ncol(X) # Number of covariates
      D <- nrow(X) # Number of documents
      Sigma_Lambda <- diag(rep(1, K - 1))
      if (M == 1) {
        # MASS::mvrnorm() returns a plain vector when n = 1; t() makes it 1 x (K - 1)
        obj$PG_params$PG_Lambda <- t(MASS::mvrnorm(n = M, mu = rep(0, K-1), Sigma = Sigma_Lambda))
      } else {
        obj$PG_params$PG_Lambda <- MASS::mvrnorm(n = M, mu = rep(0, K-1), Sigma = Sigma_Lambda)
      }
      obj$PG_params$PG_SigmaPhi <- diag(rep(1, K-1))
      Mu <- X %*% obj$PG_params$PG_Lambda
      Sigma <- diag(rep(1, K-1))
      # One (K - 1)-dimensional draw per document, stacked row-wise. Using
      # matrix(byrow = TRUE) keeps the D x (K - 1) shape even when K - 1 == 1,
      # where a sapply()/t() construction would produce a 1 x D matrix.
      Phi_draws <- lapply(seq_len(D), function(d) {
        as.vector(rmvn1(mu = Mu[d, ], Sigma = Sigma))
      })
      Phi <- matrix(unlist(Phi_draws, use.names = FALSE), nrow = D, ncol = K - 1, byrow = TRUE)
      obj$PG_params$PG_Phi <- Phi # D \times (K - 1)
      obj$PG_params$theta_tilda <- exp(Phi) / (1 + exp(Phi))
      obj$PG_params$theta_last <- matrix(rep(0, D*K), nrow = D, ncol = K)
      obj$PG_params$Lambda_list <- list()
      obj$PG_params$Sigma_list <- list()
    }
    allowed_arguments <- c(allowed_arguments, "covariates_data", "covariates_data_use",
                           "slice_min", "slice_max", "mh_use", "covariates_model", "PG_params",
                           "covariates_formula", "standardize", "info")
  } # cov model end

  # check model settings for dynamic model
  if (model %in% c("hmm", "ldahmm")) {
    if (is.null(obj$num_states)) {
      cli::cli_abort("`model_settings$num_states` is not provided.")
    }
    if (is.null(obj$time_index)) {
      cli::cli_abort("`model_settings$time_index` is not provided.")
    }
    # Entries are picked by original document position; an index beyond the
    # supplied vector would otherwise silently turn into NA
    if (any(info$use_doc_index > length(obj$time_index))) {
      cli::cli_abort("The length of the `model_settings$time_index` does not match with the number of documents.")
    }
    obj$time_index <- obj$time_index[info$use_doc_index]
    if (length(obj$time_index) != info$num_doc) {
      cli::cli_abort("The length of the `model_settings$time_index` does not match with the number of documents.")
    }
    if (min(obj$time_index) != 1 || max(obj$time_index) > info$num_doc) {
      cli::cli_abort("`model_settings$time_index` should start from 1 and not exceed the number of documents.")
    }
    if (max(obj$time_index) < obj$num_states)
      cli::cli_abort("`model_settings$num_states` should not exceed the maximum of `model_settings$time_index`.")
    # Consecutive documents stay at the same time point or move to the next
    # one; diff() is empty for a single document, which therefore passes
    if (!all(diff(obj$time_index) %in% c(0, 1)))
      cli::cli_abort("`model_settings$time_index` does not increment by 1.")
    obj$time_index <- as.integer(obj$time_index)
    allowed_arguments <- c(allowed_arguments, "num_states", "time_index")
  }

  show_unused_arguments(obj, "`model_settings`", allowed_arguments)
  obj
}
check_arg_priors <- function(obj, model, info)
{
  # Validate `priors` for `model` and fill in defaults. Reads
  # info$models_keyATM, info$keyword_k and info$total_k. Aborts on malformed
  # or unrecognized priors; returns the completed list.
  check_arg_type(obj, "list")
  # Base arguments
  allowed_arguments <- c("beta", "eta_1", "eta_2", "eta_1_regular", "eta_2_regular")

  # prior of pi (`gamma`, a total_k x 2 matrix); keyATM models only
  if (model %in% info$models_keyATM) {
    if (is.null(obj$gamma)) {
      obj$gamma <- matrix(1.0, nrow = info$total_k, ncol = 2)
    }
    # dim() is NULL for a plain vector; reject it here instead of letting
    # `if (NULL != ...)` fail with an uninformative error
    gamma_dim <- dim(obj$gamma)
    if (length(gamma_dim) != 2L || gamma_dim[1] != info$total_k || gamma_dim[2] != 2)
      cli::cli_abort("Check the dimension of `priors$gamma`")
    if (info$keyword_k < info$total_k) {
      # Regular topics are used in keyATM models
      # Priors for regular topics should be 0
      if (sum(obj$gamma[(info$keyword_k+1):info$total_k, ]) != 0) {
        obj$gamma[(info$keyword_k+1):info$total_k, ] <- 0
      }
    }
    allowed_arguments <- c(allowed_arguments, "gamma")
  }

  # beta
  if (!"beta" %in% names(obj)) {
    obj$beta <- 0.01
  }
  if (model %in% info$models_keyATM) { # models with keywords
    if (!"beta_s" %in% names(obj)) {
      obj$beta_s <- 0.1
    }
    allowed_arguments <- c(allowed_arguments, "beta_s")
  }

  # alpha
  if (model %in% c("base", "lda")) {
    if (is.null(obj$alpha)) {
      obj$alpha <- rep(1/info$total_k, info$total_k)
    }
    if (length(obj$alpha) != info$total_k) {
      # The expected length is interpolated into the message; extra
      # arguments to cli_abort() are not pasted into the text
      cli::cli_abort("Starting alpha must be a vector of length {info$total_k}")
    }
    allowed_arguments <- c(allowed_arguments, "alpha")
  }

  # eta
  if (!"eta_1" %in% names(obj)) {
    obj$eta_1 <- 1.0
  }
  if (!"eta_2" %in% names(obj)) {
    obj$eta_2 <- 1.0
  }
  if (!"eta_1_regular" %in% names(obj)) {
    obj$eta_1_regular <- 2.0
  }
  if (!"eta_2_regular" %in% names(obj)) {
    obj$eta_2_regular <- 1.0
  }

  show_unused_arguments(obj, "`priors`", allowed_arguments)
  obj
}
check_arg_options <- function(obj, model, info)
{
  # Validate `options` for `model`, fill in defaults, and coerce the 0/1
  # flags to integer. Aborts on invalid or unrecognized options; returns
  # the completed list.
  check_arg_type(obj, "list")
  allowed_arguments <- c("seed", "llk_per", "thinning",
                         "iterations", "iter_new", "verbose",
                         "use_weights", "weights_type",
                         "prune", "store_theta", "slice_shape",
                         "parallel_init", "resume")

  # TRUE for a single, finite, non-negative whole number. `&&` short-circuits,
  # so non-numeric input gets the intended error instead of failing in `%%`.
  is_count <- function(x) {
    is.numeric(x) && length(x) == 1L && is.finite(x) && x >= 0 && x %% 1 == 0
  }

  # llk_per
  if (is.null(obj$llk_per))
    obj$llk_per <- 10L
  if (!is_count(obj$llk_per)) {
    cli::cli_abort("An invalid value in `options$llk_per`")
  }

  # verbose
  if (is.null(obj$verbose)) {
    obj$verbose <- 0L
  } else {
    obj$verbose <- as.integer(obj$verbose)
    if (!obj$verbose %in% c(0, 1)) {
      cli::cli_abort("An invalid value in `options$verbose`")
    }
  }

  # thinning
  if (is.null(obj$thinning))
    obj$thinning <- 5L
  if (!is_count(obj$thinning)) {
    cli::cli_abort("An invalid value in `options$thinning`")
  }

  # iterations
  if (is.null(obj$iterations)) {
    obj$iterations <- 1500L
  }
  obj$iter_new <- obj$iterations # this is initialization so new iteration = total number of iterations
  if (!is_count(obj$iterations)) {
    cli::cli_abort("An invalid value in `options$iterations`")
  }

  # Store theta
  if (is.null(obj$store_theta)) {
    obj$store_theta <- 0L
  } else {
    obj$store_theta <- as.integer(obj$store_theta)
    if (!obj$store_theta %in% c(0, 1)) {
      cli::cli_abort("An invalid value in `options$store_theta`")
    }
  }

  # Store pi (keyATM models only)
  if (model %in% info$models_keyATM) {
    if (is.null(obj$store_pi)) {
      obj$store_pi <- 0L
    } else {
      obj$store_pi <- as.integer(obj$store_pi)
      if (!obj$store_pi %in% c(0, 1)) {
        cli::cli_abort("An invalid value in `options$store_pi`")
      }
    }
    allowed_arguments <- c(allowed_arguments, "store_pi")
  }

  # Estimate alpha
  if (model %in% c("base", "lda")) {
    if (is.null(obj$estimate_alpha)) {
      obj$estimate_alpha <- 1L
    } else {
      obj$estimate_alpha <- as.integer(obj$estimate_alpha)
      if (!obj$estimate_alpha %in% c(0, 1)) {
        cli::cli_abort("An invalid value in `options$estimate_alpha`")
      }
    }
    allowed_arguments <- c(allowed_arguments, "estimate_alpha")
  }

  # Slice shape (parameter for slice sampling)
  if (is.null(obj$slice_shape)) {
    obj$slice_shape <- 1.2
  }
  if (!is.numeric(obj$slice_shape) || length(obj$slice_shape) != 1L ||
      is.na(obj$slice_shape) || obj$slice_shape < 0) {
    cli::cli_abort("An invalid value in `options$slice_shape`")
  }

  # Use weights
  if (is.null(obj$use_weights)) {
    obj$use_weights <- 1L
  } else {
    obj$use_weights <- as.integer(obj$use_weights)
    if (!obj$use_weights %in% c(0, 1)) {
      cli::cli_abort("An invalid value in `options$use_weights`")
    }
  }

  # Type of the weights
  if (is.null(obj$weights_type)) {
    obj$weights_type <- "information-theory"
  } else {
    if (!obj$weights_type %in% c("information-theory", "information-theory-normalized",
                                 "inv-freq", "inv-freq-normalized")) {
      cli::cli_abort("An invalid value in `options$weights_type`")
    }
  }

  # Prune keywords
  if (is.null(obj$prune)) {
    obj$prune <- 1L
  } else {
    obj$prune <- as.integer(obj$prune)
    if (!obj$prune %in% c(0, 1)) {
      cli::cli_abort("An invalid value in `options$prune`")
    }
  }

  # Store transition matrix in Dynamic models
  if (model %in% c("hmm", "ldahmm")) {
    if (is.null(obj$store_transition_matrix)) {
      obj$store_transition_matrix <- 0L
    }
    if (!obj$store_transition_matrix %in% c(0, 1)) {
      cli::cli_abort("An invalid value in `options$store_transition_matrix`")
    }
    allowed_arguments <- c(allowed_arguments, "store_transition_matrix")
  }

  # Use parallel function in initialization
  if (!"parallel_init" %in% names(obj)) {
    obj$parallel_init <- FALSE
  } else {
    if (!obj$parallel_init %in% c(0, 1, FALSE, TRUE)) {
      cli::cli_abort("`obj$parallel_init` should be TRUE/FALSE")
    }
  }

  # Resume setting (defaults to an empty string)
  if (!"resume" %in% names(obj)) {
    obj$resume <- ""
  }

  # Check unused arguments
  show_unused_arguments(obj, "`options`", allowed_arguments)
  obj
}
check_vocabulary <- function(vocab)
{
  # Sanity-check a vocabulary (character vector of words).
  # Errors: an entry that is a single space, or an empty string.
  # Warnings: an entry made only of upper-case letters, or containing a tab
  # or a line break.

  # Hard failures first
  has_space_token <- " " %in% vocab
  if (has_space_token) {
    cli::cli_abort("A space is recognized as a vocabulary. Please remove an empty document or consider using quanteda::dfm.")
  }
  has_blank_token <- "" %in% vocab
  if (has_blank_token) {
    cli::cli_abort('A blank `""` is recognized as a vocabulary. Please review preprocessing steps.')
  }

  # Number of vocabulary entries matching `pattern`
  n_matching <- function(pattern) {
    sum(stringr::str_detect(vocab, pattern))
  }
  if (n_matching("^[:upper:]+$") > 0) {
    cli::cli_warn("Upper case letters are used. Please review preprocessing steps.", immediate. = TRUE)
  }
  if (n_matching("\t") > 0) {
    cli::cli_warn("Tab is detected in the vocabulary. Please review preprocessing steps.", immediate. = TRUE)
  }
  if (n_matching("\n") > 0) {
    cli::cli_warn("A line break is detected in the vocabulary. Please review preprocessing steps.", immediate. = TRUE)
  }
}
make_sz_key <- function(W, keywords, info)
{
# Build the initial sampler state (S and Z) for keyword (keyATM) models.
#
# W:        list with one vector of word ids per document (in
#           keyATM_initialize() the ids come from `info$wd_map`, which
#           assigns 0-based ids).
# keywords: list of keyword vectors, one per keyword topic; only the number
#           of keywords per topic is used here.
# info:     uses `keywords_id` (word ids of all keywords, concatenated in
#           topic order), `keyword_k`, `total_k` and `parallel_init`.
#
# Returns list(S = ..., Z = ...), each a list parallel to `W`:
#   Z: 0-based topic per token. A keyword token starts in its keyword topic
#      (one picked uniformly at random if the word is a keyword of several
#      topics); every other token gets a uniform draw from 0:(total_k - 1).
#   S: 0/1 per token. Non-keyword tokens get 0; keyword tokens get 1 with
#      probability 0.7 and 0 with probability 0.3.
#
# All draws consume the RNG (with `parallel_init`, through future.apply's
# `future.seed = TRUE`), so the statement order below matters for
# reproducibility.
# zs_assigner maps keywords to category ids
key_wdids <- info$keywords_id
# 0-based keyword-topic id for every entry of `key_wdids`
cat_ids <- rep(1:(info$keyword_k) - 1L, unlist(lapply(keywords, length)))
if (length(key_wdids) == length(unique(key_wdids))) {
#
# No keyword appears more than once
#
# keyword word id -> topic id (myhashmap_keyint is defined elsewhere)
zs_assigner <- myhashmap_keyint(as.integer(key_wdids), as.integer(cat_ids))
# if the word is a keyword, assign the appropriate (0 start) Z, else a random Z
topicvec <- 1:(info$total_k) - 1L
make_z <- function(s, topicvec) {
# lookups for non-keywords are expected to come back NA (the next line
# relies on that) -- behavior of myhashmap_getvec_keyint, defined elsewhere
zz <- myhashmap_getvec_keyint(zs_assigner, s)
zz[is.na(zz)] <- sample(topicvec, sum(is.na(zz)), replace = TRUE)
return(zz)
}
} else {
#
# Some keywords appear multiple times
#
# For each distinct keyword id, store its topic ids as one comma-separated
# string, e.g. "0,2"
keys_df <- data.frame(wid = key_wdids, cat = cat_ids)
keys_char <- sapply(unique(key_wdids),
function(x) {
paste(as.character(keys_df[keys_df$wid == x, "cat"]), collapse=",")
})
zs_hashtable <- myhashmap_keyint(as.integer(unique(key_wdids)), keys_char)
zs_assigner <- function(s) {
topic <- myhashmap_getvec_keyint(zs_hashtable, s)
topic <- strsplit(as.character(topic), split=",") # this function should take a character vector
# pick one of the candidate topics uniformly; NA stays NA
topic <- lapply(topic, sample, 1)
topic <- as.integer(unlist(topic, use.names = FALSE))
return(topic)
}
# if the word is a keyword, assign the appropriate (0 start) Z, else a random Z
topicvec <- 1:(info$total_k) - 1L
make_z <- function(s, topicvec) {
zz <- zs_assigner(s) # if it is a keyword, we already know the topic
zz[is.na(zz)] <- sample(topicvec, sum(is.na(zz)), replace = TRUE)
return(zz)
}
}
## S indicates whether the word comes from a keyword topic-word distribution or not
make_s <- function(s) {
key <- as.numeric(s %in% key_wdids) # 1 if it is a keyword
# Use s structure
s[key == 0] <- 0L # non-keyword words have s = 0
s[key == 1] <- sample(0:1, length(s[key == 1]), prob = c(0.3, 0.7), replace = TRUE)
# keywords have s = 1 with probability 0.7
return(s)
}
if (info$parallel_init) {
S <- future.apply::future_lapply(W, make_s, future.seed = TRUE)
Z <- future.apply::future_lapply(W, make_z, topicvec, future.seed = TRUE)
} else {
S <- lapply(W, make_s)
Z <- lapply(W, make_z, topicvec)
}
return(list(S = S, Z = Z))
}
make_sz_lda <- function(W, info)
{
  # Initial state for LDA-type models: every token of every document in `W`
  # receives a uniformly drawn 0-based topic id in 0:(total_k - 1).
  # S is not used by these models and is returned as an empty list.
  topic_ids <- 1:(info$total_k) - 1L
  draw_topics <- function(tokens, topics) {
    as.integer(sample(topics, length(tokens), replace = TRUE))
  }
  Z <- if (info$parallel_init) {
    future.apply::future_lapply(W, draw_topics, topic_ids, future.seed = TRUE)
  } else {
    lapply(W, draw_topics, topic_ids)
  }
  list(S = list(), Z = Z)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.