# print -------------------------------------------------------------------
sentopics_print_extend <- function(extended = FALSE) {
  # Print a short table of the main methods available for sentopics models.
  #
  # `extended` is the flag forwarded by the print methods. When it is FALSE a
  # closing note explains how to display the table again on demand.
  methods <- c("fit", "topics", "topWords", "plot")
  explain <- c("Estimate the model using Gibbs sampling",
               "Return the most important topic of each document",
               "Return a data.table with the top words of each topic/sentiment",
               "Plot a sunburst chart representing the estimated mixtures")
  cat("------------------Useful methods------------------\n")
  # `cat(x, sep = "\n")` only writes the separator *between* elements, which
  # left the last row unterminated and glued the note below onto it. Each row
  # now carries its own newline.
  cat(paste0(sprintf("%-10s:%s", methods, explain), "\n"), sep = "")
  if (!extended) cat("This message is displayed once per session, unless calling `print(x, extended = TRUE)`\n")
}
#' Print method for sentopics models
#'
#' @description Print methods for **sentopics** models. Once per session (or
#' forced by using `extended = TRUE`), it lists the most important functions
#' related to **sentopics** models.
#'
#' @rdname print.sentopicmodel
#'
#' @param x the model to be printed
#' @param extended if `TRUE`, extends the print to include some helpful related
#' functions. Automatically displayed once per session.
#' @param ... not used
#'
#' @return No return value, called for side effects (printing).
#'
#' @export
print.sentopicmodel <- function(x, extended = FALSE, ...) {
  # One-line summary built from the `L1`, `L2` and `it` elements of `x`.
  cat("A sentopicmodel topic model with", x$L1, "topics and", x$L2,
      "sentiments. Currently fitted by", x$it, "Gibbs sampling iterations.\n")
  # Scalar condition, hence `||`. The method table is shown when explicitly
  # requested or while the session option has not been switched off.
  if (getOption("sentopics_print_extended", TRUE) || extended) {
    sentopics_print_extend(extended)
    # Record that the table was displayed so later prints stay short.
    options(sentopics_print_extended = FALSE)
  }
  ## TODO: add model parameters? ex: # cat("Model parameters: alpha = ")
}
#' @rdname print.sentopicmodel
#' @export
print.rJST <- function(x, extended = FALSE, ...) {
  # One-line summary built from the `K`, `S` and `it` elements of `x`.
  cat("A reversed-JST model with", x$K, "topics and", x$S,
      "sentiments. Currently fitted by", x$it, "Gibbs sampling iterations.\n")
  # Scalar condition, hence `||`. The method table is shown when explicitly
  # requested or while the session option has not been switched off.
  if (getOption("sentopics_print_extended", TRUE) || extended) {
    sentopics_print_extend(extended)
    # Record that the table was displayed so later prints stay short.
    options(sentopics_print_extended = FALSE)
  }
}
#' @rdname print.sentopicmodel
#' @export
print.LDA <- function(x, extended = FALSE, ...) {
  # One-line summary built from the `K` and `it` elements of `x`.
  cat("An LDA model with", x$K, "topics. Currently fitted by",
      x$it, "Gibbs sampling iterations.\n")
  # Scalar condition, hence `||`. The method table is shown when explicitly
  # requested or while the session option has not been switched off.
  if (getOption("sentopics_print_extended", TRUE) || extended) {
    sentopics_print_extend(extended)
    # Record that the table was displayed so later prints stay short.
    options(sentopics_print_extended = FALSE)
  }
}
#' @rdname print.sentopicmodel
#' @export
print.JST <- function(x, extended = FALSE, ...) {
  # One-line summary built from the `S`, `K` and `it` elements of `x`.
  cat("A JST model with", x$S, "sentiments and", x$K,
      "topics. Currently fitted by", x$it, "Gibbs sampling iterations.\n")
  # Scalar condition, hence `||`. The method table is shown when explicitly
  # requested or while the session option has not been switched off.
  if (getOption("sentopics_print_extended", TRUE) || extended) {
    sentopics_print_extend(extended)
    # Record that the table was displayed so later prints stay short.
    options(sentopics_print_extended = FALSE)
  }
}
#' @export
print.multiChains <- function(x, ...) {
# Print a `multiChains` object through the next print method.
# `lapply(x, identity)` replaces `x` with a plain list of its elements before
# delegating; presumably the elements are rebuilt by `as.list.multiChains()`
# along the way -- TODO confirm.
x <- lapply(x, identity)
# The next method is handed the modified `x`.
NextMethod()
}
#' @export
print.topWords <- function(x, ...) {
  # When the object records which scoring method produced it, show that method
  # in the header of the "value" column before delegating the printing.
  scoring <- attr(x, "method")
  if (!is.null(scoring)) {
    is_value <- colnames(x) == "value"
    colnames(x)[is_value] <- paste0("value[", scoring, "]")
  }
  NextMethod()
}
# plot --------------------------------------------------------------------
#' Plot a topic model using Plotly
#'
#' @description Summarize and plot a **sentopics** model using a sunburst chart
#' from the [plotly::plotly] library.
#'
#' @param x a model created from the [LDA()], [JST()] or [rJST()] function and
#' estimated with [fit()]
#' @param nWords the number of words per topic/sentiment to display in the outer
#' layer of the plot
#' @param layers specifies the number of layers for the sunburst chart. This
#' will restrict the output to the `layers` uppermost levels of the chart. For
#' example, setting `layers = 1` will only display the top level of the
#' hierarchy (topics for an LDA model).
#' @param sort if `TRUE`, sorts the plotted topics in a decreasing frequency.
#' @param ... not used
#'
#' @return A `plotly` sunburst chart.
#'
#' @export
#' @seealso [topWords()] [LDAvis()]
#' @examples
#' lda <- LDA(ECB_press_conferences_tokens)
#' lda <- fit(lda, 100)
#' plot(lda, nWords = 5)
#'
#' # only displays the topic proportions
#' plot(lda, layers = 1)
## TODO: add other methods than probability?
plot.sentopicmodel <- function(x, nWords = 15, layers = 3, sort = FALSE, ...) {
  # Build a plotly sunburst from `melt(x)`: layer 1 aggregates by the second
  # column of the melted table, layer 2 by its first two columns, and layer 3
  # holds the `nWords` top words returned by `topWords_dt()`.
  mis <- missingSuggets(c("plotly", "RColorBrewer"))
  if (length(mis) > 0) stop("Suggested packages are missing for the plot.sentopicmodel function.\n",
                            "Please install first the following packages: ",
                            paste0(mis, collapse = ", "),".\n",
                            "Install command: install.packages(",
                            paste0("'", mis, "'", collapse = ", "),")" )
  # Remember the user-facing class before converting to the generic form.
  class <- class(x)[1]
  x <- as.sentopicmodel(x)
  real <- prob <- L1 <- L2 <- id <- value <- parent <- color <- NULL # due to NSE notes in R CMD check
  l1 <- l2 <- l3 <- NULL ## in case they exist in environment
  mixtureStats <- melt(x)
  L1_name <- colnames(mixtureStats)[2]
  L2_name <- colnames(mixtureStats)[1]

  # if (topicsOnly) method <- "topics" else method <- "probability"

  if (layers > 0) {
    # Share of each first-level component, averaged over documents.
    l1 <- mixtureStats[, list(value = sum(prob) / length(x$tokens)), by = L1_name]
    # l1$name <- l1[[L1_name]]
    l1$name <- create_labels(x, class, flat = FALSE)[["L1"]]
    l1$parent <- ""
    l1$id <- paste0("l1_", as.numeric(l1[[L1_name]]))
    # "Set1" is used directly below 10 components (brewer.pal() is asked for at
    # least 3 colours) and interpolated otherwise.
    if (x$L1 < 10) {
      if (x$L1 < 3) l1$color <- RColorBrewer::brewer.pal(3, "Set1")[seq(nrow(l1))]
      else l1$color <- RColorBrewer::brewer.pal(nrow(l1), "Set1")
    }
    else l1$color <- grDevices::colorRampPalette(
      RColorBrewer::brewer.pal(7, "Set1"))(x$L1)
    l1$real <- l1$value
  }

  if (layers > 1) {
    l2 <- mixtureStats[, list(value = sum(prob) / length(x$tokens)), by = c(L1_name, L2_name)]
    l2$name <- l2[[L2_name]]
    l2$parent <- paste0("l1_", as.numeric(l2[[L1_name]]))
    l2$id <- paste0(l2$parent, paste0("l2_", as.numeric(l2[[L2_name]])))
    l2$real <- l2$value
    ## Not working. Needs update
    # if (topicsOnly) l2 <- l2[, list(value = sum(value)), by = list(L1)][, id := paste0("l1_", L1, "l2_1")]
    for (i in l1[[L1_name]]) {
      # Children get colours derived from their parent's colour; `real` becomes
      # the share relative to the parent.
      cols <- spreadColor(l1[l1[[L1_name]] == i, color], x$L2)
      l2[l2[[L1_name]] == i, color := cols]
      l2[l2[[L1_name]] == i, real := (real / l1[l1[[L1_name]] == i, real])]
    }
  }

  # Scalar condition: `||` / `&&` instead of the elementwise `|` / `&`.
  if (layers > 2 || (layers > 1 && class == "LDA")) {
    l3 <- topWords_dt(x, nWords, "probability")
    l3$name <- l3$word
    l3$parent <- paste0("l1_", l3$L1, "l2_", l3$L2)
    l3$id <- paste0(l3$parent, l3$name)
    l3$real <- l3$value
    ## need to rescale l3 to sum up to l2. multiplying by 0.9 for a small space between groups
    ## It's okay as real is used to show the accurate probabilities
    for (i in unique(l3$parent)) {
      sub_l2 <- l2[id == i, value]
      sub_l3 <- l3[parent == i, value]
      l3[parent == i, "value"] <- (sub_l2 / sum(sub_l3) * 0.9) * sub_l3
      l3[parent == i, "color" := l2[id == i, "color"]]
    }
  }

  # Drop the grouping columns so the three layers share the same columns;
  # layers that were not built are still NULL at this point.
  LIST <- list(l1[, -1], l2[, -(1:2)], l3[, -(1:3)])
  data <- data.table::rbindlist(LIST, use.names = TRUE)
  if (class == "LDA") {
    # Remove the rows whose parent is a bare "l1_<n>" id and attach the
    # remaining rows directly to the first layer.
    data <- data[!grepl("^l1_[0-9]+$", parent)]
    data$parent <- sub("l2_[0-9]+$", "", data$parent)
  }

  # reorder data clockwise
  if (sort)
    data <- data[order(value)]
  else
    data <- data[.N:1, ]

  fig <- plotly::plot_ly(
    ids = data$id,
    labels = data$name,
    parents = data$parent,
    values = data$value,
    type = "sunburst",
    rotation = "180",
    branchvalues = "total",
    leaf = list(opacity = 1),
    marker = list(colors = data$color),
    sort = FALSE,
    # The hover text shows `real` (via customdata), not the rescaled `value`.
    hovertemplate = "%{label}\n%{customdata:.2%}<extra></extra>",
    customdata = data$real,
    hoverlabel = list(font = list(size = 20))
  )
  fig
}
#' Plot the distances between topic models (chains)
#'
#' @description Plot the results of `chainsDistances(x)` using multidimensional
#' scaling. See [chainsDistances()] for details on the distance computation
#' and [stats::cmdscale()] for the implementation of the multidimensional
#' scaling.
#'
#' @inheritParams chainsDistances
#'
#' @param ... not used
#' @seealso [chainsDistances()] [cmdscale()]
#'
#' @return Invisibly, the coordinates of each topic model resulting from the
#' multidimensional scaling.
#'
#' @examples
#' models <- LDA(ECB_press_conferences_tokens)
#' models <- fit(models, 10, nChains = 5)
#' plot(models)
#' @export
plot.multiChains <- function(x, ..., method = c("euclidean", "hellinger", "cosine", "minMax", "naiveEuclidean", "invariantEuclidean")) {
  # Project the pairwise chain distances with classical multidimensional
  # scaling and draw each chain as a text label at its coordinates.
  if (attr(x, "nChains") < 3) stop("At least 3 chains are required for proper plotting.")
  method <- match.arg(method)

  distances <- stats::as.dist(chainsDistances(x, method))
  mds <- stats::cmdscale(distances)
  first_axis <- mds[, 1]
  second_axis <- mds[, 2]

  ## Possible ggplot2 way of doing it
  # local({
  #   coord <- data.table::as.data.table(coord)
  #   coord$maxLik <- sapply(x, function(xx) max(xx$logLikelihood))
  #   coord$lastLik <- sapply(x, function(xx) tail(xx$logLikelihood, 1))
  #   coord
  #   print(ggplot(coord, aes(x = V1, y = V2, colour = lastLik)) + geom_point(size = 4))
  # })

  # The title quotes the expression supplied as `x`.
  title <- paste0("Chains multidimensional scaling of ", deparse(substitute(x)))
  # Empty canvas first (type = "n"), then labels and reference axes.
  plot(first_axis, second_axis, type = "n",
       xlab = "Coordinate 1", ylab = "Coordinate 2", main = title)
  graphics::text(first_axis, second_axis, rownames(mds), cex = 0.8)
  graphics::abline(h = 0, v = 0, col = "gray75")
  invisible(mds)
}
# fit --------------------------------------------------------------------
#' @importFrom generics fit
#' @export
generics::fit
#' @name fit.sentopicmodel
#' @aliases grow
#' @title Estimate a topic model
#'
#' @description This function is used to estimate a topic model created by
#' [LDA()], [JST()] or [rJST()]. In essence, this function iterates a Gibbs
#' sampler MCMC.
#'
#' @param object a model created with the [LDA()], [JST()] or [rJST()] function.
#' @param iterations the number of iterations by which the model should be
#' fitted.
#' @param nChains if set above 1, the model will be fitted multiple times
#' from various starting positions. Latent variables will be re-initialized if
#' `object` has not been fitted before.
#' @param displayProgress if `TRUE`, a progress bar will be displayed indicating
#' the progress of the computation. When `nChains` is greater than 1, this
#' requires the package \pkg{progressr} and optionally \pkg{progress}.
#' @param computeLikelihood if set to `FALSE`, does not compute the likelihood
#' at each iteration. This can slightly decrease the computing time.
#' @param seed for reproducibility, a seed can be provided.
#' @param ... arguments passed to other methods. Not used.
#'
#' @return a `sentopicmodel` of the relevant model class if `nChains` is
#' unspecified or equal to 1. A `multiChains` object if `nChains` is greater
#' than 1.
#'
#' @section Parallelism: When `nChains > 1`, the function can take advantage of
#' [future.apply::future_lapply] (if installed) to spread the computation over
#' multiple processes. This requires the specification of a parallel strategy
#' using [future::plan()]. See the examples below.
#'
#' @seealso [LDA()], [JST()], [rJST()], [reset()]
#' @export
#' @examples
#' model <- rJST(ECB_press_conferences_tokens)
#' fit(model, 10)
#'
#' # -- Parallel computation --
#' require(future.apply)
#' future::plan("multisession", workers = 2) # Set up 2 workers
#' fit(model, 10, nChains = 2)
#'
#' future::plan("sequential") # Shut down workers
NULL
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
fit.LDA <- function(object, iterations = 100, nChains = 1, displayProgress = TRUE, computeLikelihood = TRUE, seed = NULL, ...) {
  # Approximated models are rejected up front.
  if (isTRUE(attr(object, "approx"))) stop("Not possible for approximated models")
  # Fit through the generic representation, then convert the result back.
  generic_model <- as.sentopicmodel(object)
  fitted <- fit(generic_model,
                iterations = iterations,
                nChains = nChains,
                displayProgress = displayProgress,
                computeLikelihood = computeLikelihood,
                seed = seed)
  as.LDA(fitted)
}
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
fit.rJST <- function(object, iterations = 100, nChains = 1, displayProgress = TRUE, computeLikelihood = TRUE, seed = NULL, ...) {
  # Fit through the generic representation, then convert the result back.
  generic_model <- as.sentopicmodel(object)
  fitted <- fit(generic_model,
                iterations = iterations,
                nChains = nChains,
                displayProgress = displayProgress,
                computeLikelihood = computeLikelihood,
                seed = seed)
  as.rJST(fitted)
}
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
fit.JST <- function(object, iterations = 100, nChains = 1, displayProgress = TRUE, computeLikelihood = TRUE, seed = NULL, ...) {
  # Fit through the generic representation, then convert the result back.
  generic_model <- as.sentopicmodel(object)
  fitted <- fit(generic_model,
                iterations = iterations,
                nChains = nChains,
                displayProgress = displayProgress,
                computeLikelihood = computeLikelihood,
                seed = seed)
  as.JST(fitted)
}
#' @rdname fit.sentopicmodel
#' @export
fit.sentopicmodel <- function(object, iterations = 100, nChains = 1,
                              displayProgress = TRUE, computeLikelihood = TRUE,
                              seed = NULL, ...) {
  # Iterate the C++ sampler on a copy of `object`.
  #   nChains == 1 : returns the updated model.
  #   nChains  > 1 : runs `nChains` chains (through future.apply when it is
  #                  installed, sequentially otherwise) and returns them as a
  #                  `multiChains` list sharing a single `base` attribute.
  # Any other value of `nChains` raises an error.
  start_time <- Sys.time()

  ## force deep copy
  object <- data.table::copy(object)

  if (nChains == 1) {
    # The seed is scaled by the iteration count already reached.
    if (!is.null(seed)) set.seed(seed * (object$it + 1))
    ## rebuild c++ model
    base <- core(object) ## need to protect base$cleaned from gc
    cpp_model <- rebuild_cppModel(object, base)

    cpp_model$iterate(iterations, displayProgress, computeLikelihood)
    tmp <- extract_cppModel(cpp_model, base)
    object[names(tmp)] <- tmp

    reorder_sentopicmodel(object)
  } else if (nChains > 1) {

    ## CMD check
    chains <- NULL

    # The core is detached from the model and passed separately to each chain.
    base <- core(object)
    core(object) <- NULL
    if (!is.null(seed)) seed <- seed * (object$it + 1)

    # Worker run once per chain. `chunkProgress`, `p`, `base`, `iterations`,
    # `computeLikelihood` and `start_time` are free variables: they come from
    # the enclosing frame in the sequential branch and from `future.globals`
    # in the future.apply branch.
    FUN <- function(i) {
      report_time <- Sys.time()
      report_processed <- 0L
      ## need to duplicate memory location of x$za. Otherwise, all chains
      ## share the same memory location
      object$za <- data.table::copy(object$za)
      # Package internals are fetched explicitly because the environment of
      # FUN is reset to globalenv() before being sent to future.apply.
      rebuild_cppModel <- get("rebuild_cppModel",
                              envir = getNamespace("sentopics"))
      cpp_model <- rebuild_cppModel(object, base)
      ## generate different initial assignment for each chain
      if (object$it == 0 && i > 1) cpp_model$initAssignments()
      if (iterations > 0) {
        # Iterate in chunks of `chunkProgress` plus the remainder, so progress
        # can be reported between chunks.
        for (chunk in c(rep(chunkProgress, iterations %/% chunkProgress),
                        iterations %% chunkProgress)) {
          cpp_model$iterate(chunk, FALSE, computeLikelihood)
          report_processed <- report_processed + chunk
          if ((Sys.time() - report_time) > 1) { # not too often
            difftime <- difftime(Sys.time(), start_time)
            p(amount = report_processed,
              message = sprintf("Elapsed: %.2f %s", difftime, units(difftime)))
            report_time <- Sys.time()
            report_processed <- 0L
          }
        }
      }
      if (report_processed > 0L) { # if remaining update
        difftime <- difftime(Sys.time(), start_time)
        p(amount = report_processed,
          message = sprintf("Elapsed: %.2f %s", difftime, units(difftime)))
      }
      extract_cppModel <- get("extract_cppModel",
                              envir = getNamespace("sentopics"))
      tmp <- extract_cppModel(cpp_model, base)
      object[names(tmp)] <- tmp
      object
    }

    # The dispatch code is quoted so it can be evaluated in this frame either
    # inside `progressr::with_progress()` or directly, after `p` and
    # `chunkProgress` have been defined.
    expr <- quote({
      if (requireNamespace("future.apply", quietly = TRUE)) {
        if (is.null(seed)) seed <- TRUE
        environment(FUN) <- globalenv()
        chains <- future.apply::future_lapply(
          1:nChains, FUN, future.seed = seed, future.globals = list(
            object = object, base = base, iterations = iterations,
            chunkProgress = chunkProgress,
            computeLikelihood = computeLikelihood,
            p = p, start_time = start_time))
      } else {
        if (requireNamespace("future", quietly = TRUE)) {
          # Test the plan object itself. The previous
          # `inherits(class(plan()), ...)` tested a character vector and was
          # always FALSE, so the message appeared under a sequential plan too.
          if (!inherits(future::plan(), "sequential"))
            message("It seems that a parallel setup is registered, but `future.apply` is not installed. Did you omit installing it? Proceeding sequential computation...")
        }
        ## TODO: align RNGs?
        if (!is.null(seed)) set.seed(seed)
        chains <- lapply(1:nChains, FUN)
      }
      names(chains) <- paste0("chain", 1:nChains)
      class(chains) <- "multiChains"
      attr(chains, "nChains") <- nChains
      attr(chains, "base") <- base
      attr(chains, "containedClass") <- "sentopicmodel"
    })

    if (displayProgress && !requireNamespace("progressr", quietly = TRUE)) {
      message("The `progressr` package is required to track progress of multiple chains.")
      displayProgress <- FALSE
    }

    if (displayProgress) {
      ## determine how often the progress bar is refreshed (too high refresh rate can negatively affect performance)
      # chunkProgress <- min(max((iterations * nChains) %/% 10, 10), 1000, iterations)
      chunkProgress <- min(100, iterations)
      progressr::with_progress({
        # The extra 0.001 step is consumed by the final "Done!" message below.
        p <- progressr::progressor(steps = nChains * iterations + 0.001, enable = displayProgress)
        eval(expr)
        difftime <- difftime(Sys.time(), start_time)
        p(message = sprintf("Done! Elapsed: %.2f %s",
                            difftime, units(difftime)),
          amount = 0.001)
      },
      handlers = custom_handler())
    } else {
      # No progress reporting: a single chunk and a no-op progressor.
      chunkProgress <- iterations
      p <- function(...) {}
      environment(p) <- globalenv()
      eval(expr)
    }

    chains
  } else {
    stop("Incorrect number of chains specified.")
  }
}
#' @rdname fit.sentopicmodel
#' @export
fit.multiChains <- function(object, iterations = 100, nChains = NULL,
                            displayProgress = TRUE, computeLikelihood = TRUE,
                            seed = NULL, ...) {
  # Continue the estimation of every chain stored in a `multiChains` object
  # and return a new `multiChains` object with the same attributes.
  # The `nChains` argument is ignored: the value stored on `object` is used.
  start_time <- Sys.time()

  ## CMD check
  chains <- NULL

  nChains <- attr(object, "nChains")
  base <- attr(object, "base")
  containedClass <- attr(object, "containedClass")
  object <- unclass(object) ## unclass to prevent usage of `[[.multiChains`
  object <- lapply(object, as.sentopicmodel) ## return to sentopicmodel objects

  ## erase posterior in each chain to limit memory transfers
  for (i in seq_along(object)) {
    object[[i]][["L1post"]] <- object[[i]][["L2post"]] <- object[[i]][["phi"]] <- NULL
  }
  # The seed is scaled by the iteration count of the first chain.
  if (!is.null(seed)) seed <- seed * (object[[1]]$it + 1)

  # Worker run once per chain; its argument is one chain of the list.
  # `chunkProgress`, `p`, `base`, `iterations`, `computeLikelihood` and
  # `start_time` are free variables: they come from the enclosing frame in the
  # sequential branch and from `future.globals` in the future.apply branch.
  FUN <- function(object) {
    report_time <- Sys.time()
    report_processed <- 0L
    # Package internals are fetched explicitly because the environment of FUN
    # is reset to globalenv() before being sent to future.apply.
    rebuild_cppModel <- get("rebuild_cppModel",
                            envir = getNamespace("sentopics"))
    cpp_model <- rebuild_cppModel(object, base)
    # Iterate in chunks of `chunkProgress` plus the remainder, so progress can
    # be reported between chunks.
    for (chunk in c(rep(chunkProgress, iterations %/% chunkProgress),
                    iterations %% chunkProgress)) {
      cpp_model$iterate(chunk, FALSE, computeLikelihood)
      report_processed <- report_processed + chunk
      if ((Sys.time() - report_time) > 1) { # not too often
        difftime <- difftime(Sys.time(), start_time)
        p(amount = report_processed,
          message = sprintf("Elapsed: %.2f %s", difftime, units(difftime)))
        report_time <- Sys.time()
        report_processed <- 0L
      }
    }
    if (report_processed > 0L) { # if remaining update
      difftime <- difftime(Sys.time(), start_time)
      p(amount = report_processed,
        message = sprintf("Elapsed: %.2f %s", difftime, units(difftime)))
    }
    extract_cppModel <- get("extract_cppModel",
                            envir = getNamespace("sentopics"))
    tmp <- extract_cppModel(cpp_model, base)
    object[names(tmp)] <- tmp
    object
  }

  # The dispatch code is quoted so it can be evaluated in this frame either
  # inside `progressr::with_progress()` or directly, after `p` and
  # `chunkProgress` have been defined.
  expr <- quote({
    if (requireNamespace("future.apply", quietly = TRUE)) {
      if (is.null(seed)) seed <- TRUE
      environment(FUN) <- globalenv()
      # NOTE(review): the whole list is also exported as the global `object`,
      # although FUN receives each chain as its argument -- confirm whether
      # this export is needed.
      chains <- future.apply::future_lapply(
        object, FUN, future.seed = seed, future.globals = list(
          object = object, base = base, iterations = iterations,
          chunkProgress = chunkProgress,
          computeLikelihood = computeLikelihood,
          p = p, start_time = start_time))
    } else {
      if (requireNamespace("future", quietly = TRUE)) {
        # Test the plan object itself. The previous
        # `inherits(class(plan()), ...)` tested a character vector and was
        # always FALSE, so the message appeared under a sequential plan too.
        if (!inherits(future::plan(), "sequential"))
          message("It seems that a parallel setup is registered, but `future.apply` is not installed. Did you omit installing it? Proceeding sequential computation...")
      }
      ## TODO: align RNGs?
      if (!is.null(seed)) set.seed(seed)
      chains <- lapply(object, FUN)
    }
    class(chains) <- "multiChains"
    attr(chains, "nChains") <- nChains
    attr(chains, "base") <- base
    attr(chains, "containedClass") <- containedClass
  })

  if (displayProgress && !requireNamespace("progressr", quietly = TRUE)) {
    message("The `progressr` package is required to track progress of multiple chains.")
    displayProgress <- FALSE
  }

  if (displayProgress) {
    ## determine how often the progress bar is refreshed (too high refresh rate can negatively affect performance)
    # chunkProgress <- min(max((iterations * nChains) %/% 10, 10), 1000, iterations)
    chunkProgress <- min(100, iterations)
    progressr::with_progress({
      # The extra 0.001 step is consumed by the final "Done!" message below.
      p <- progressr::progressor(steps = nChains * iterations + 0.001, enable = displayProgress)
      eval(expr)
      difftime <- difftime(Sys.time(), start_time)
      p(message = sprintf("Done! Elapsed: %.2f %s",
                          difftime, units(difftime)),
        amount = 0.001)
    },
    handlers = custom_handler())
  } else {
    # No progress reporting: a single chunk and a no-op progressor.
    chunkProgress <- iterations
    p <- function(...) {}
    environment(p) <- globalenv()
    eval(expr)
  }

  chains
}
# Retain `grow` methods for compatibility
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
grow <- function(object, iterations = 100, nChains = 1, displayProgress = TRUE, computeLikelihood = TRUE, seed = NULL, ...) {
# S3 generic retained for compatibility; dispatches on the class of `object`.
# The `grow.*` methods in this file are the corresponding `fit.*` functions.
UseMethod("grow")
}
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
grow.LDA <- fit.LDA # compatibility alias: same function object as `fit.LDA`
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
grow.rJST <- fit.rJST # compatibility alias: same function object as `fit.rJST`
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
grow.JST <- fit.JST # compatibility alias: same function object as `fit.JST`
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
grow.sentopicmodel <- fit.sentopicmodel # compatibility alias: same function object as `fit.sentopicmodel`
#' @rdname fit.sentopicmodel
#' @export
#' @usage NULL
grow.multiChains <- fit.multiChains # compatibility alias: same function object as `fit.multiChains`
#' Re-initialize a topic model
#'
#' @author Olivier Delmarcelle
#'
#' @description This function is used to re-initialize a topic model, as if it was
#' created from [LDA()], [JST()] or another model. The re-initialized model
#' retains its original parameter specification.
#'
#' @param object a model created from the [LDA()], [JST()] or [rJST()] function and
#' estimated with [fit()]
#'
#' @return a `sentopicmodel` of the relevant model class, with the iteration count
#' reset to zero and re-initialized assignment of latent variables.
#'
#' @seealso [fit()]
#' @export
#' @examples
#' model <- LDA(ECB_press_conferences_tokens)
#' model <- fit(model, 10)
#' reset(model)
reset <- function(object) {
  # Bring a model back to iteration zero with freshly initialised assignments,
  # returning it in the class it was supplied in.
  original_class <- class(object)[1]
  ## work on a copy of the assignments, otherwise they are reset by reference
  object$za <- data.table::copy(object$za)
  object <- as.sentopicmodel(object)

  # Clear everything produced by previous iterations.
  object$it <- 0
  object$logLikelihood <- NULL
  object$L1post <- NULL
  object$L2post <- NULL
  object$phi <- NULL

  model_core <- core(object) ## need to protect base$cleaned from gc
  sampler <- rebuild_cppModel(object, model_core)
  sampler$initAssignments()
  # object <- utils::modifyList(object, extract_cppModel(cpp_model, core(object)))
  refreshed <- extract_cppModel(sampler, model_core)
  object[names(refreshed)] <- refreshed

  # Convert back with the `as.<class>` function matching the input class.
  restore <- get(paste0("as.", original_class))
  restore(reorder_sentopicmodel(object))
}
# melt --------------------------------------------------------------------
#' Replacement generic for [data.table::melt()]
#'
#' As of the CRAN release of the 1.14.8 version of **data.table**, the
#' [data.table::melt()] function is not a generic. This function aims to
#' temporarily provide a generic to this function, so that [melt.sentopicmodel()]
#' can be effectively dispatched when used. Expect this function to disappear
#' shortly after the release of **data.table** 1.14.9.
#'
#' @param data an object to melt
#' @param ... arguments passed to other methods
#'
#' @return An unkeyed `data.table` containing the molten data.
#'
#' @seealso [data.table::melt()], [melt.sentopicmodel()]
#' @export
melt <- function(data, ...) {
# S3 generic: dispatches on the class of `data`.
UseMethod("melt")
}
#' @export
melt.default <- function(data, ...) {
# Fallback method: forwards `data` and all further arguments to
# data.table::melt().
data.table::melt(data, ...)
}
# TODO: re-activate once data.table 1.14.9 is released.
# #' @importFrom data.table melt
# #' @export
# data.table::melt
#' Melt for sentopicmodels
#'
#' @author Olivier Delmarcelle
#'
#' @description This function extracts the estimated document mixtures from a
#' topic model and returns them in a long [data.table::data.table] format.
#'
#' @param data a model created from the [LDA()], [JST()] or [rJST()] function
#' and estimated with [fit()]
#' @param ... not used
#' @param include_docvars if `TRUE`, the melted result will also include the
#' *docvars* stored in the [tokens] object provided at model initialization
#'
#' @seealso [topWords()] for extracting representative words,
#' [data.table::melt()] and [data.table::dcast()]
#'
#' @return A [data.table::data.table] in the long format, where each line is the estimated
#' proportion of a single topic/sentiment for a document. For JST and rJST
#' models, the probability is also decomposed into 'L1' and 'L2' layers,
#' representing the probability at each layer of the topic-sentiment
#' hierarchy.
#'
#' @export
#' @examples
#' # only returns topic proportion for LDA models
#' lda <- LDA(ECB_press_conferences_tokens)
#' lda <- fit(lda, 10)
#' melt(lda)
#'
#' # includes sentiment for JST and rJST models
#' jst <- JST(ECB_press_conferences_tokens)
#' jst <- fit(jst, 10)
#' melt(jst)
melt.sentopicmodel <- function(data, ..., include_docvars = FALSE) {
# Return the document mixtures of a fitted model as a long data.table.
# Stops when the model has no iteration (`data$it <= 0`).
if (data$it <= 0) stop("Nothing to melt. Estimate the model with fit() first.")
# Keep the user-facing class: it selects the LDA / non-LDA branch below.
class <- class(data)[1]
data <- as.sentopicmodel(data)
.id <- topic <- sentiment <- prob <- L1_prob <- L2_prob <- NULL # due to NSE notes in R CMD check
# First layer: `data$L1post` converted to a data.table (row names stored in
# `.id`), then melted to one row per `.id` and first-layer component.
L1stats <- data.table::as.data.table(data$L1post, sorted = FALSE, keep.rownames = ".id")
L1_name <- names(dimnames(data$L1post))[2]
L1stats <- melt(L1stats, id = ".id", variable.factor = TRUE,
variable.name = L1_name, value.name = "L1_prob")
if (attr(data, "Sdim") == "L1") stopifnot(identical(levels(L1stats[[L1_name]]), levels(data$vocabulary$lexicon)))
if (class != "LDA") {
# Second layer: the third dimension of `data$L2post` is renamed `.id` before
# the array is flattened to a data.table with an `L2_prob` column.
L2stats <- data$L2post
names(dimnames(L2stats))[3] <- ".id"
L2_name <- names(dimnames(L2stats))[1]
L2stats <- data.table::as.data.table(L2stats, sorted = FALSE, value.name = "L2_prob")
# The first two columns are converted to factors.
for (i in 1:2) {
L2stats[[i]] <- factor(L2stats[[i]])
}
# NOTE(review): the next check compares `levels(L1stats[[L1_name]])` with
# itself, so it can never fail. A comparison against `L2stats` may have been
# intended -- confirm before changing.
if (attr(data, "Sdim") == "L1") stopifnot(identical(levels(L1stats[[L1_name]]), levels(L1stats[[L1_name]])))
if (attr(data, "Sdim") == "L2") stopifnot(identical(levels(L2stats[[L2_name]]), levels(data$vocabulary$lexicon)))
# Join both layers on `.id` and the first-layer column; the joint probability
# is the product of the two layer probabilities.
mixtureStats <- L2stats[L1stats, on = c(".id", L1_name)]
mixtureStats[, prob := L2_prob * L1_prob]
# `sentopic` combines the first- and second-layer labels, separated by "_".
mixtureStats$`sentopic` <- interaction(mixtureStats[[L1_name]], mixtureStats[[L2_name]], sep = "_", lex.order = TRUE)
data.table::setcolorder(mixtureStats, c(colnames(mixtureStats)[1:2], "sentopic"))
## quick integrity check
# Compares the three probability sums with `data$L1 * length(data$tokens)`
# for `L2_prob` and with `length(data$tokens)` for `L1_prob` and `prob`.
if (!all(isTRUE(all.equal(sum(L2stats$L2_prob), data$L1 * length(data$tokens))),
isTRUE(all.equal(sum(L1stats$L1_prob), length(data$tokens))),
isTRUE(all.equal(sum(mixtureStats$prob), length(data$tokens))))) {
stop("Issue found in the computation. Please verify that the input object is correct.")
}
} else {
# LDA: only the first layer exists; `L1_prob` is renamed to `prob` and the
# columns `topic`, `.id`, `prob` are returned.
mixtureStats <- L1stats
data.table::setnames(mixtureStats, "L1_prob", "prob")
mixtureStats <- mixtureStats[, list(topic, .id, prob)]
}
if (include_docvars) {
# Merge in the docvars of the stored tokens, keyed by document name in `.id`.
docvars <- quanteda::docvars(data$tokens)
docvars$.id <- names(data$tokens)
mixtureStats <- merge(mixtureStats, docvars, sort = FALSE)
data.table::setcolorder(mixtureStats, c(colnames(mixtureStats)[2:4], ".id"))
}
mixtureStats[]
}
# Accessors ---------------------------------------------------------------
## TODO: structure() for attributes ie structure(NextMethod(), class="foo")?
#' @export
`[[.multiChains` <- function(x, i, ...) {
  # Extract one chain, re-attach the core stored on the container and convert
  # it with the `as.<containedClass>` function.
  shared_core <- attr(x, "base")
  target_class <- attr(x, "containedClass")
  chain <- NextMethod()
  core(chain) <- shared_core
  restore <- get(paste0("as.", target_class))
  restore(reorder_sentopicmodel(chain))
}
##TODO: add test for this?
#' @export
`$.multiChains` <- function(x, name, ...) {
# `$` accessor with two behaviours: a chain name returns that chain as a full
# model, any other name is looked up in the "base" attribute.
if (!name %in% names(x)) {
#### This wont work with "alpha" "gamma" because they are under sentopicmodel form
# Not a chain name: `x` is replaced by the base before delegating.
x <- attr(x, "base")
x <- NextMethod()
} else {
# Chain name: extract the chain, re-attach the core stored on the container
# and convert with the `as.<containedClass>` function.
base <- attr(x, "base")
containedClass <- attr(x, "containedClass")
x <- NextMethod()
core(x) <- base
fun <- get(paste0("as.", containedClass))
x <- fun(reorder_sentopicmodel(x))
}
x
}
#' @export
`[.multiChains` <- function(x, i, ...) {
  # Subset the chains while keeping the result a valid `multiChains` object:
  # the "base", "containedClass" and "rng" attributes are carried over and
  # "nChains" is recomputed.
  base <- attr(x, "base")
  containedClass <- attr(x, "containedClass")
  rng <- attr(x, "rng")
  # With an empty index (`x[]`) every chain is kept, and so is every rng entry.
  if (!missing(i)) rng <- rng[i]
  x <- NextMethod()
  attr(x, "base") <- base
  # Count the chains actually selected. `length(i)` was wrong for negative or
  # logical indices and failed when `i` was missing.
  attr(x, "nChains") <- length(x)
  attr(x, "rng") <- rng
  attr(x, "containedClass") <- containedClass
  # TODO : check attributes?
  class(x) <- "multiChains"
  x
}
#' @export
`as.list.multiChains` <- function(x, copy = TRUE, ...) {
  # Turn the container into a plain list of stand-alone models: every chain
  # gets the shared core back (a copy of it when `copy` is TRUE) and is
  # converted with the `as.<containedClass>` function.
  chains <- unclass(x)
  for (k in seq_along(chains)) {
    shared_core <- attr(chains, "base")
    core(chains[[k]]) <- if (copy) data.table::copy(shared_core) else shared_core
  }
  restore <- get(paste0("as.", attr(chains, "containedClass")))
  chains[] <- lapply(chains, function(chain) restore(reorder_sentopicmodel(chain)))
  # The shared core now lives in each chain; drop it from the list itself.
  attr(chains, "base") <- NULL
  chains
  # res <- list()
  # for (i in seq_along(x)) {
  #   res[[i]] <- data.table::copy(x[[i]])
  # }
  # names(res) <- names(x)
  # res
}
# quanteda ----------------------------------------------------------------
#' @importFrom quanteda docvars
#' @export
quanteda::docvars
#' @method docvars sentopicmodel
#' @export
docvars.sentopicmodel <- function(x, field = NULL) {
  # Return the docvars of the tokens stored in the model. `field` is now
  # forwarded; it was previously accepted but silently ignored.
  quanteda::docvars(x$tokens, field = field)
}
#' @importFrom quanteda tokens
#' @export
quanteda::tokens
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.