R/mallet_lda.R

#' A wrapper function for LDA using the MALLET machine learning toolkit -- an incredibly efficient, fast and well tested implementation of LDA. See http://mallet.cs.umass.edu/ and https://github.com/mimno/Mallet for much more information on this amazing set of libraries.
#'
#' @param documents Optional argument for providing the documents we wish to run
#' LDA on. Can be either a character vector with one string per document, a list
#' object where each entry is an (ordered) document-term vector with one list
#' entry per document, a dense document-term matrix where each row represents a
#' document, each column represents a term in the vocabulary, and entries are
#' document-term counts, or a sparse document term matrix (simple triplet matrix
#' from the slam library) -- preferably generated by quanteda::dfm and then
#' converted using convert_quanteda_to_slam().
#' @param document_directory Optional argument specifying a directory containing
#' .txt files (one per document) to be used for LDA. May only be used if
#' documents is NULL.
#' @param document_csv Optional argument specifying the path to a csv file
#' containing one document per line. May only be used if documents and
#' document_directory are NULL.
#' @param vocabulary An optional character vector specifying the vocabulary
#' (required if the user is not using hyperparameter optimization). If a
#' (sparse) document-term matrix is provided, then this must be the same length
#' as the number of columns in the matrix, and should correspond to those columns.
#' @param topics The number of topics the user wishes to specify for LDA.
#' Defaults to 10.
#' @param iterations The number of collapsed Gibbs sampling iterations the user
#' wishes to specify. Defaults to 1000.
#' @param burnin The number of iterations to be discarded before assessing topic
#' model convergence via a Geweke test. Must be less than iterations. Not a
#' parameter passed to MALLET, only used for post-hoc convergence checking.
#' Defaults to 100.
#' @param alpha The alpha LDA hyperparameter. Defaults to 1.
#' @param beta The beta LDA hyperparameter. This value is multiplied by the size
#' of the vocabulary. Defaults to 0.01 which has worked well for the author in
#' the past.
#' @param hyperparameter_optimization_interval The interval (number of
#' iterations) at which LDA hyperparameters should be optimized. Defaults to 0,
#' meaning no hyperparameter optimization will be performed. If greater than
#' zero, the beta term need not be specified as it will be optimized regardless.
#' Generally a value of 5-10 works well, and hyperparameter optimization will
#' often produce much better quality topics.
#' @param num_top_words The number of topic top-words returned in the model
#' output. Defaults to 20.
#' @param optional_arguments Allows the user to specify a string with additional
#' arguments for MALLET.
#' @param tokenization_regex Regular expression used for tokenization by MALLET.
#' Defaults to '[\\p{L}\\p{N}\\p{P}]+' meaning that all letters, numbers and
#' punctuation will be counted as tokens. May be adapted by the user, but double
#' escaping (\\) must be used due to the way escaping is removed by R when the
#' command is piped to the console. Another perfectly reasonable choice is
#' '[\\p{L}]+', which only counts letters in tokens.
#' @param stopword_list Defaults to NULL. If not NULL, then a vector of terms
#' to be removed from the input text should be provided. Only implemented when
#' supplying the documents argument.
#' @param cores Number of cores to be used to train the topic model. Defaults
#' to 1.
#' @param delete_intermediate_files Defaults to TRUE. If FALSE, then all raw
#' output from MALLET will be left in a "./mallet_intermediate_files"
#' subdirectory of the current working directory.
#' @param memory The amount of Java heap space to be allocated to MALLET.
#' Defaults to '-Xmx10g', indicating 10GB of RAM will be allocated (at maximum).
#' Users may increase this limit if they are working with an exceptionally large
#' corpus.
#' @param only_read_in Defaults to FALSE. If TRUE, then the function only
#' attempts to read back in files from the completed MALLET run. This can be
#' useful if there was an error reading back in the topic reports (usually
#' due to some sort of weird symbols getting in).
#' @param unzip_command Defaults to "gunzip -k", which should work on a mac. This
#' command should be able to unzip a .txt.gz file and keep the original input
#' as a backup, which is what the "-k" option does here.
#' @param return_predictive_distribution Defaults to TRUE, but can be set to
#' FALSE if using a large corpus on a computer with relatively little RAM.
#' @param use_phrases Defaults to TRUE. When TRUE, the topic phrase reports are
#' returned. If FALSE, they are excluded.
#' @return Returns a list object with the following fields: lda_trace_stats is a
#' data frame reporting the beta hyperparameter value and model log likelihood
#' per token every ten iterations, which can be useful for assessing convergence;
#' document_topic_proportions reports the document topic proportions for all
#' topics; topic_metadata reports the alpha x basemeasure values for all topics,
#' along with the total number of tokens assigned to each topic; topic_top_words
#' reports the 'num_top_words' top words for each topic (in descending order);
#' topic_top_word_counts reports the count of each top word in their respective
#' topics; topic_top_phrases reports top phrases (as found post-hoc by MALLET)
#' associated with each topic; topic_top_phrase_counts reports the counts of
#' these phrases in each topic.
#' @examples
#' \dontrun{
#'files <- get_file_paths(source = "test sparse doc-term")
#'
#'sdtm <- generate_sparse_large_document_term_matrix(
#'    file_list = files,
#'    maximum_vocabulary_size = -1,
#'    using_document_term_counts = TRUE)
#'
#'test <- mallet_lda(documents = sdtm,
#'                   topics = 10,
#'                   iterations = 1000,
#'                   burnin = 100,
#'                   alpha = 1,
#'                   beta = 0.01,
#'                   hyperparameter_optimization_interval = 5,
#'                   cores = 1)
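#'
#'# A minimal sketch (toy documents, not from the package data) showing the
#'# character-vector input form -- one string per document -- and how to
#'# access fields of the returned list. With hyperparameter optimization
#'# turned on, no vocabulary needs to be supplied.
#'docs <- c("the quick brown fox jumps over the lazy dog",
#'          "lorem ipsum dolor sit amet consectetur adipiscing elit")
#'fit <- mallet_lda(documents = docs,
#'                  topics = 2,
#'                  iterations = 200,
#'                  burnin = 50,
#'                  hyperparameter_optimization_interval = 5,
#'                  cores = 1)
#'head(fit$document_topic_proportions)
#'fit$topic_top_words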
#' }
#' @export
mallet_lda <- function(documents = NULL,
                       document_directory = NULL,
                       document_csv = NULL,
                       vocabulary = NULL,
                       topics = 10,
                       iterations = 1000,
                       burnin = 100,
                       alpha = 1,
                       beta = 0.01,
                       hyperparameter_optimization_interval = 0,
                       num_top_words = 20,
                       optional_arguments = "",
                       tokenization_regex = '[\\p{L}\\p{N}\\p{P}]+',
                       stopword_list = NULL,
                       cores = 1,
                       delete_intermediate_files = TRUE,
                       memory = "-Xmx10g",
                       only_read_in = FALSE,
                       unzip_command = "gunzip -k",
                       return_predictive_distribution = TRUE,
                       use_phrases = TRUE){

    docnames <- NULL

    if (!only_read_in) {
        ###############################
        #### Step 0: Preliminaries ####
        ###############################

        #check to see that we have the MALLET JAR files installed
        test1 <- system.file("extdata","mallet.jar", package = "SpeedReader")[1]
        test2 <- system.file("extdata","mallet-deps.jar", package = "SpeedReader")[1]

        if(test1 != "" & test2 != ""){
            cat("Found MALLET JAR files...\n")
        }else{
            cat("MALLET Jar files not found, downloading...\n")
            download_mallet()
        }

        if(hyperparameter_optimization_interval == 0 & is.null(vocabulary)){
            stop("You must provide the vocabulary if you are not using hyperparameter optimization.")
        }

        if(hyperparameter_optimization_interval == 0){
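            # as documented above, when hyperparameter optimization is disabled
            # the user-supplied beta is multiplied by the vocabulary size before
            # being passed to MALLET via its --beta flag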
            beta <- beta * length(vocabulary)
        }

        if(burnin >= iterations){
            burnin <- ceiling(iterations/2)
            cat("Burnin selected was too large, setting burnin to:",burnin,"...\n")
        }

        # save the current working directory
        currentwd <- getwd()

        USING_EXTERNAL_FILES <- FALSE
        USING_CSV <- FALSE
        #check to make sure that we have the right kind of input
        if(!is.null(documents) & is.null(document_directory) & is.null(document_csv)){
            if(inherits(documents, "list") | inherits(documents, "character")){

                if(inherits(documents, "list")){

                    # get the rownames
                    if (!is.null(names(documents))) {
                        docnames <- names(documents)
                    }

                    # deal with the case where we got a list of term vectors
                    temp <- rep("", length(documents))
                    for(i in 1:length(temp)){
                        temp[i] <- paste0(documents[[i]],collapse = " ")
                    }
                    documents <- temp

                }
            }else if(inherits(documents, "matrix")){
                if(is.null(vocabulary)){
                    vocabulary <- colnames(documents)
                    cat("No vocabulary supplied, using column names of document term matrix...\n")
                }
                if(length(vocabulary) != ncol(documents)){
                    stop(paste("Length of vocabulary:",length(vocabulary),"is not equal to the number of columns in the document term matrix:",ncol(documents)))
                }
                vocabulary <- stringr::str_replace_all(vocabulary," ","_")

                # optionally remove stopwords
                if (!is.null(stopword_list)) {
                    remove <- which(vocabulary %in% stopword_list)
                    if (length(remove) > 0) {
                        vocabulary <- vocabulary[-remove]
                        documents <- documents[,-remove]
                    }
                }

                # get the rownames
                if (!is.null(rownames(documents))) {
                    docnames <- rownames(documents)
                }

                #populate a string vector of documents from dtm
                cat("Populating document vector from document term matrix...\n")
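                # build a progress-reporting schedule: roughly 1000 status
                # updates for very large corpora, 100 for medium ones, or one
                # per row for small ones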
                printseq_counter <- 1
                if(nrow(documents) > 9999){
                    printseq <- round(seq(1,nrow(documents), length.out = 1001)[2:1001],0)
                }else if(nrow(documents) > 199){
                    printseq <- round(seq(1,nrow(documents), length.out = 101)[2:101],0)
                }else{
                    printseq <- 1:nrow(documents)
                }
                temp_docs <- rep("",nrow(documents))
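                # rebuild each document as a single string by repeating every
                # term with a nonzero count according to its count in that row
                # of the document-term matrix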
                for(i in 1:length(temp_docs)){
                    if(printseq[printseq_counter] == i){
                        cat(printseq_counter,"/",length(printseq)," complete...\n",sep = "")
                        printseq_counter <- printseq_counter +1
                    }
                    str <- NULL
                    colindexes <- which(documents[i,] > 0)
                    if(length(colindexes) > 0){
                        for(k in 1:length(colindexes)){
                            str <- c(str,
                                     rep(vocabulary[colindexes[k]],
                                         documents[i,colindexes[k]]))
                        }
                        temp  <- paste0(str,collapse = " ")
                    }else{
                        temp <- ""
                    }
                    temp_docs[i] <- temp
                }
                documents <- temp_docs


            } else if (inherits(documents, "simple_triplet_matrix")){
                if (is.null(vocabulary)){
                    vocabulary <- colnames(documents)
                    cat("No vocabulary supplied, using column names of document term matrix...\n")
                }
                if(length(vocabulary) != ncol(documents)){
                    stop(paste("Length of vocabulary:",length(vocabulary),"is not equal to the number of columns in the document term matrix:",ncol(documents)))
                }
                vocabulary <- stringr::str_replace_all(vocabulary," ","_")

                # get the rownames
                if (!is.null(rownames(documents))) {
                    docnames <- rownames(documents)
                }

                # optionally remove stopwords
                if (!is.null(stopword_list)) {
                    remove <- which(vocabulary %in% stopword_list)
                    if (length(remove) > 0) {
                        vocabulary <- vocabulary[-remove]
                        documents <- documents[,-remove]
                    }
                }

                cat("Making sure rows are properly ordered...\n")
                ord <- order(documents$i,decreasing = FALSE)
                documents$i <- documents$i[ord]
                documents$j <- documents$j[ord]
                documents$v <- documents$v[ord]

                #populate a string vector of documents from dtm
                cat("Populating document vector from document term matrix...\n")
                printseq_counter <- 1
                if (nrow(documents) > 9999){
                    printseq <- round(seq(1,nrow(documents), length.out = 1001)[2:1001],0)
                }else if(nrow(documents) > 199){
                    printseq <- round(seq(1,nrow(documents), length.out = 101)[2:101],0)
                }else{
                    printseq <- 1:nrow(documents)
                }
                temp_docs <- rep("",nrow(documents))

                # we are going to loop through the document indices, which we
                # expect to go up one at a time
                start <- 1
                stop <- 1
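                # start and stop index into the ordered triplet vectors
                # (documents$i, documents$j, documents$v); because rows were
                # sorted above, the entries for document i form a contiguous
                # run delimited by [start, stop - 1]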
                for(i in 1:length(temp_docs)){
                    if(printseq[printseq_counter] == i){
                        cat(i,"/",length(temp_docs)," complete...\n",sep = "")
                        printseq_counter <- printseq_counter +1
                    }
                    # if the current triplet entry already belongs to a later
                    # document, then document i has no entries (it is empty)
                    if (documents$i[stop] > i) {
                        indexes <- NULL
                    } else {
                        if (i < length(temp_docs)) {
                            # precheck to deal with documents of length zero
                            while (documents$i[stop] == i) {
                                stop <- stop + 1
                            }
                            indexes <- start:(stop - 1)
                            start <- stop
                        } else {
                            indexes <- start:length(documents$i)
                        }
                    }
                    if (length(indexes) > 0) {
                        colindexes <- documents$j[indexes]
                        repeats <- documents$v[indexes]
                        cur_vocab <- vocabulary[colindexes]
                        # now allocate a vector to fill so we do not need to keep
                        # concatenating
                        this_doc <- rep("", sum(repeats))
                        cur_counter <- 1
                        for (k in 1:length(colindexes)) {
                            # number of times to repeat the word
                            for (l in 1:repeats[k]) {
                                this_doc[cur_counter] <- cur_vocab[k]
                                cur_counter <- cur_counter + 1
                            }
                        }

                        # for(k in 1:length(colindexes)){
                        #     str <- c(str,
                        #              rep(vocabulary[colindexes[k]],
                        #                  repeats[k]))
                        # }
                        temp  <- paste0(this_doc,collapse = " ")
                    }else{
                        temp <- ""
                    }
                    temp_docs[i] <- temp
                }
                documents <- temp_docs


            }else{
                stop("You must provide a 'documents' object as either a vector of strings (one per document), a list of string vectors (one entry per document), or a dense (or sparse) document-term matrix...")
            }
        }else if(is.null(documents) & !is.null(document_directory) & is.null(document_csv)){
            USING_EXTERNAL_FILES <- TRUE
            substrRight <- function(x, n){
                substr(x, nchar(x)-n+1, nchar(x))
            }

            # prepare text to be used
            documents <- paste(check_directory_name(document_directory),
                               list.files(path = document_directory), sep = "")
            #only use files with a .txt ending
            endings <- as.character(sapply(documents,substrRight,4))
            txtfiles <- which(endings == ".txt")
            if (length(txtfiles) > 0) {
                documents <- documents[txtfiles]
                docnames <- documents
            }else{
                stop("Did not find any valid .txt files in the specified directory...")
            }
            #read in documents
            temp_docs <- rep("",length(documents))
            for (i in 1:length(documents)) {
                temp_docs[i] <- paste0(readLines(documents[i], warn = F),
                                       collapse = " ")
            }
            documents <- temp_docs
        } else if (is.null(documents) & is.null(document_directory) & !is.null(document_csv)) {
            USING_CSV <- TRUE
        } else {
            stop("You must specify exactly one of: a valid documents object, a valid document_directory path, or a valid document_csv path...")
        }

        directory <- system.file("extdata", package = "SpeedReader")[1]

        ##############################################
        #### Step 1: Output documents to tsv file ####
        ##############################################
        cat("Outputting documents in the correct format for MALLET...\n")

        # create an intermediate directory
        success <- dir.create("mallet_intermediate_files",showWarnings = FALSE)
        if (!success) {
            unlink("./mallet_intermediate_files", recursive = TRUE)
            success <- dir.create("mallet_intermediate_files")
        }
        if (!success) {
            stop("Could not create the intermediate file directory necessary to use MALLET. This is likely due to a file permission error. Make sure you have permission to create files or run your R session as root.")
        }
        setwd("./mallet_intermediate_files")

        if (!USING_CSV) {

            # make sure to remove all &:
            num_docs <- length(documents)
            for (i in 1:num_docs) {
                documents[i] <- stringr::str_replace_all(documents[i],"&","and")
            }


            # CSV format -- 1 line per document:
            # doc_id\t\tdoc_text
            data <- matrix("",nrow = num_docs,ncol = 3)
            for (i in 1:num_docs) {

                # set the document names
                if (!is.null(docnames)) {
                    data[i,1] <- docnames[i]
                } else {
                    data[i,1] <- i
                }

                data[i,2] <- ""
                data[i,3] <- documents[i]
            }
            cat("Writing corpus to file...")

            # write the data to file:
            write.table(data, file = "mallet_input_corpus.csv", quote = FALSE,
                        row.names = F,col.names = F, sep = "\t" )
        }


        ############################################
        #### Step 2: Preprocess data for MALLET ####
        ############################################

        # ," --print-output > stdout_intake.txt" should add this as an option in the future
        cat("Converting input to MALLET format...\n")
        # prepare the data for use with Mallet's LDA routine
        if (USING_CSV) {
            prepare_data <- paste("java -server ",memory," -XX:-UseConcMarkSweepGC -XX:-UseGCOverheadLimit -classpath ",directory,"/mallet.jar:",directory,"/mallet-deps.jar cc.mallet.classify.tui.Csv2Vectors --keep-sequence --token-regex '",tokenization_regex,"' --output mallet_corpus.dat --input ",document_csv, sep = "")
        } else {
            prepare_data <- paste("java -server ",memory," -XX:-UseConcMarkSweepGC -XX:-UseGCOverheadLimit -classpath ",directory,"/mallet.jar:",directory,"/mallet-deps.jar cc.mallet.classify.tui.Csv2Vectors --keep-sequence --token-regex '",tokenization_regex,"' --output mallet_corpus.dat --input mallet_input_corpus.csv", sep = "")
        }


        #2>&1
        #print(prepare_data)
        p <- pipe(prepare_data,"r")
        close(p)

        ####################################
        #### Step 3: Run LDA via MALLET ####
        ####################################

        cat("Fitting topic model. This may take anywhere from seconds to days depending on the size of your corpus. Check: ",getwd(),"/stdout.txt for estimation progress...\n", sep = "")
        # Now run LDA
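        # the MALLET call below logs its progress to stdout.txt and writes the
        # files read back in Step 4: the Gibbs sampling state
        # (output_state.txt.gz), topic-keys.txt, the XML topic and topic-phrase
        # reports, the saved inferencer, and doc-topics.txt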
        if (hyperparameter_optimization_interval != 0) {
            run_mallet <- paste("java -server ",memory," -XX:-UseConcMarkSweepGC -XX:-UseGCOverheadLimit -classpath ",directory,"/mallet.jar:",directory,"/mallet-deps.jar cc.mallet.topics.tui.Vectors2Topics --input mallet_corpus.dat --output-state output_state.txt.gz --output-topic-keys topic-keys.txt --xml-topic-report topic-report.xml --xml-topic-phrase-report topic-phrase-report.xml --inferencer-filename inferencer.mallet --output-doc-topics doc-topics.txt --num-topics ",topics," --num-iterations ",iterations," --num-top-words ",num_top_words," --num-threads ",cores," --optimize-interval ",hyperparameter_optimization_interval," --optimize-burn-in ",hyperparameter_optimization_interval," ",optional_arguments," > stdout.txt 2>&1", sep = "")
            # 2>&1&
        } else {
            run_mallet <- paste("java -server ",memory," -XX:-UseConcMarkSweepGC -XX:-UseGCOverheadLimit -classpath ",directory,"/mallet.jar:",directory,"/mallet-deps.jar cc.mallet.topics.tui.Vectors2Topics --input mallet_corpus.dat --output-state output_state.txt.gz --output-topic-keys topic-keys.txt --xml-topic-report topic-report.xml --xml-topic-phrase-report topic-phrase-report.xml --inferencer-filename inferencer.mallet --output-doc-topics doc-topics.txt --num-topics ",topics," --num-iterations ",iterations," --num-top-words ",num_top_words, " --num-threads ",cores," --beta ",beta," ",optional_arguments," > stdout.txt 2>&1", sep = "")
            #  2>&1&
        }
        #print(run_mallet)
        p <- pipe(run_mallet,"r")
        close(p)
    } else {
        # save the current working directory
        currentwd <- getwd()
        # if we are only reading back in the files, then just set the working
        # directory:
        setwd("./mallet_intermediate_files")
    }



    #######################################
    #### Step 4: Read the data back in ####
    #######################################
    cat("Reading MALLET output back into R...\n")
    topic_report <- XML::xmlParse("topic-report.xml")
    topic_report <- XML::xmlToList(topic_report)
    if (use_phrases) {
        topic_phrase_report <- XML::xmlParse("topic-phrase-report.xml")
        topic_phrase_report <- XML::xmlToList(topic_phrase_report)
    }
    stdout <- readLines("stdout.txt",warn = FALSE)
    document_topics <- read.table(file = "doc-topics.txt",sep = "\t",header = FALSE)

    #####################################################
    # 4.1 Turn document-topic table into a useable form #
    #####################################################

    # first get rid of bad columns
    vector_is_NA <- function(vec){
        if(length(which(is.na(vec))) == length(vec)){
            return(1)
        }else{
            return(0)
        }
    }
    NA_columns <- apply(document_topics,2,vector_is_NA)
    NA_columns <- which(NA_columns == 1)
    if(length(NA_columns) > 0){
        document_topics <- document_topics[,-NA_columns]
    }

    document_topics <- document_topics[,-c(1,2)]
    #create and populate a clean doc-topics table
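    # after dropping the first two columns, doc-topics.txt contains alternating
    # (topic index, proportion) pairs; the loop below unpacks these pairs into
    # a documents x topics matrix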
    temp <- matrix(0,nrow = nrow(document_topics),ncol = topics)
    col_ind_L <- TRUE
    col_ind <- 0
    for(i in 1:nrow(document_topics)){
        for(j in 1:ncol(document_topics)){
            if(col_ind_L){
                col_ind_L <- FALSE
                col_ind <- document_topics[i,j] + 1
            }else{
                col_ind_L <- TRUE
                temp[i,col_ind] <- document_topics[i,j]
            }
        }
    }
    document_topics <- temp

    if (!is.null(docnames)) {
        rownames(document_topics) <- docnames
    } else {
        rownames(document_topics) <- paste("document_",
                                           1:nrow(document_topics), sep = "")
    }

    colnames(document_topics) <- paste("topic_",
                                       1:topics, sep = "")

    document_topics <- as.data.frame(document_topics)

    ####################################################
    # 4.2 Extract useful trace information from stdout #
    ####################################################

    # assumes that LL/token is printed out every 10 iterations
    # I am only going to extract [beta: ] values on every 10.

    trace_stats <- data.frame(iteration = seq(10,iterations,by = 10),
                              beta = rep(NA,floor(iterations/10)),
                              LL_Token = rep(NA,floor(iterations/10)))
    counter <- 0
    for(i in 1:length(stdout)){
        if(grepl("LL/token",stdout[i])){
            counter <- counter + 1
            trace_stats$LL_Token[counter] <- as.numeric(stringr::str_split(stdout[i],":")[[1]][2])
            next_beta <- TRUE
        }
        if(grepl("\\[beta:",stdout[i])){
            temp <- stringr::str_split(stdout[i],":")[[1]][2]
            trace_stats$beta[counter] <- as.numeric(stringr::str_replace(temp,"\\]",""))
            next_beta <- FALSE
        }
    }

    ########################################################
    # 4.3 Read in topic report and put it in a nice format #
    ########################################################

    top_words <- data.frame(matrix("",nrow = topics, ncol = num_top_words),
                            stringsAsFactors = F)
    rownames(top_words) <- paste("topic_",1:topics, sep = "")
    colnames(top_words) <- paste("top_word_",1:num_top_words, sep = "")
    top_word_counts <- data.frame(matrix(NA,nrow = topics, ncol = num_top_words),
                            stringsAsFactors = F)
    rownames(top_word_counts) <- paste("topic_",1:topics, sep = "")
    colnames(top_word_counts) <- paste("top_word_",1:num_top_words, sep = "")
    topic_data <- data.frame(matrix(NA,nrow = topics, ncol = 2),
                                  stringsAsFactors = F)
    rownames(topic_data) <- paste("topic_",1:topics, sep = "")
    colnames(topic_data) <- c("alpha","total_tokens")

    for(i in 1:length(topic_report)){
        cur <- topic_report[[i]]
        if(length(cur) == num_top_words+1){
            for(j in 1:(length(cur)-1)){
                top_words[i,j] <- cur[[j]]$text
                top_word_counts[i,j] <- as.numeric(cur[[j]]$.attrs[2])
            }
        }
        topic_data[i,1] <- as.numeric(cur[[length(cur)]][2])
        topic_data[i,2] <- as.numeric(cur[[length(cur)]][3])
    }

    if (use_phrases) {
        ###############################################################
        # 4.4 Read in topic phrase report and put it in a nice format #
        ###############################################################

        top_phrases <- data.frame(matrix(NA,nrow = topics, ncol = num_top_words),
                                  stringsAsFactors = F)
        rownames(top_phrases) <- paste("topic_",1:topics, sep = "")
        colnames(top_phrases) <- paste("top_phrase_",1:num_top_words, sep = "")
        top_phrase_counts <- data.frame(matrix(NA,nrow = topics, ncol = num_top_words),
                                        stringsAsFactors = F)
        rownames(top_phrase_counts) <- paste("topic_",1:topics, sep = "")
        colnames(top_phrase_counts) <- paste("top_phrase_",1:num_top_words, sep = "")


        for(i in 1:length(topic_phrase_report)){
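            # each topic entry in the phrase report lists num_top_words
            # single-word items followed by num_top_words phrase items; only
            # the phrase items (offset by num_top_words) are read here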
            cur2 <- topic_phrase_report[[i]]
            if(length(cur2) == (2*num_top_words)+1){
                for(j in 1:num_top_words){
                    top_phrases[i,j] <- cur2[[j+num_top_words]]$text
                    top_phrase_counts[i,j] <- as.numeric(cur2[[j+num_top_words]]$.attrs[2])
                }
            }
        }

        LDA_Results <- list(lda_trace_stats = trace_stats,
                            document_topic_proportions = document_topics,
                            topic_metadata = topic_data,
                            topic_top_words = top_words,
                            topic_top_word_counts = top_word_counts,
                            topic_top_phrases = top_phrases,
                            topic_top_phrase_counts = top_phrase_counts)
    } else {

        LDA_Results <- list(lda_trace_stats = trace_stats,
                            document_topic_proportions = document_topics,
                            topic_metadata = topic_data,
                            topic_top_words = top_words,
                            topic_top_word_counts = top_word_counts)
    }



    #####################################################
    # 4.5 Read in topic word type counts, alpha_m, beta #
    #####################################################

    if (return_predictive_distribution) {
        temp <- get_term_topics(num_topics = topics,
                                unzip_command = unzip_command)

        LDA_Results <- append(LDA_Results,temp)
    }


    ###############################################
    #### Step 5: Cleanup and Return Everything ####
    ###############################################

    UMASS_BLUE <- rgb(51,51,153,255,maxColorValue = 255)
    cat("Assessing topic model convergence...\n")
    try({
        plot( y = trace_stats$LL_Token[ceiling(burnin/10):length(trace_stats$LL_Token)],
              x = trace_stats$iteration[ceiling(burnin/10):length(trace_stats$LL_Token)],
              pch = 19, col = UMASS_BLUE,
              main = paste(
                  "Un-Normalized Topic Model Log Likelihood \n",
                  " Geweke Statistic for Last",
                  length(ceiling(burnin/10):length(trace_stats$LL_Token)),
                  "Iterations:",
                  round(coda::geweke.diag(
                      trace_stats$LL_Token[ceiling(burnin/10):length(trace_stats$LL_Token)])$z,
                      2)),
              xlab = "Iteration", ylab = "Log Likelihood",
              cex.lab = 2, cex.axis = 1.4, cex.main = 1.4)
    })
    # remove
    if (delete_intermediate_files) {
        setwd("..")
        unlink("./mallet_intermediate_files",recursive = T)
    }

    #reset the working directory
    setwd(currentwd)

    return(LDA_Results)
}