# R/set_up_general.R

# Defines functions union_terms

#' set_up_general
#'
#' Extract the ten highest-weighted terms ("key features") of every topic in
#' the original model and in each comparison model. When alternative
#' document-term matrices are supplied, additionally re-express every beta
#' matrix over the union of all vocabularies.
#'
#' @include rlda_general.R
#' @param r a rlda_general object
#' @return the rlda_general object with its key_features slot filled in; when
#'   other_dtms is non-empty, the beta_list and terms slots are replaced by the
#'   vocabulary-aligned versions returned by union_terms
#' @exportMethod set_up_general
#'

setGeneric("set_up_general", function(r) standardGeneric("set_up_general"))
setMethod("set_up_general",
          signature(r = "rlda_general"),
          function(r) {
            if (length(r@beta_list) == 0) {
              stop("must have a nonempty beta list")
            }

            if (length(r@model_type) == 0) {
              stop("must have a nonempty model_type list")
            }

            other_dtms <- r@other_dtms
            terms_u <- r@dtm$dimnames$Terms
            model_type <- r@model_type

            # Top 10 terms of each topic (row of beta), one column per topic.
            # If the vocabulary has fewer than 10 terms the tail is NA.
            top_terms <- function(beta, vocab) {
              apply(beta, 1, function(x) {
                vocab[order(x, decreasing = TRUE)][1:10]
              })
            }

            # Slot 1 always holds the original model's key features.
            feature_list <- list()
            feature_list[[1]] <- top_terms(r@beta_list[[r@idx]], terms_u)

            # j indexes feature_list; it only advances for non-original models.
            j <- 1
            # Position of the next unused alternative dtm. It must live outside
            # the loop so that the n-th "diff_dtm" model is paired with the
            # n-th entry of other_dtms (same pairing union_terms uses).
            ot_dtm_ct <- 1
            for (i in seq_along(r@K)) {
              if (model_type[i] == "or") {
                next
              }
              j <- j + 1
              if (model_type[i] == "diff_dtm") {
                # This model was fit on its own dtm, so its beta columns are
                # indexed by that dtm's vocabulary.
                diff_term <- other_dtms[[ot_dtm_ct]]$dimnames$Terms
                feature_list[[j]] <- top_terms(r@beta_list[[i]], diff_term)
                ot_dtm_ct <- ot_dtm_ct + 1
              } else {
                feature_list[[j]] <- top_terms(r@beta_list[[i]], terms_u)
              }
            }

            if (length(other_dtms) > 0) {
              # Align the object's actual beta matrices (not an empty
              # placeholder) to the union vocabulary.
              new_beta_tuple <- union_terms(terms_u, r@idx, other_dtms,
                                            r@beta_list, model_type)
              r@beta_list <- new_beta_tuple[[1]]
              r@terms <- new_beta_tuple[[2]]
            }

            r@key_features <- feature_list
            return(r)
          })



# utility functions for other dtms

# Re-express every beta matrix over the union of all vocabularies.
#
# Arguments:
#   dtm_terms    character vector of terms of the original dtm
#   or_idx       index of the original model in beta_list (kept for
#                interface compatibility; not needed by the computation)
#   list_of_dtms list of alternative dtms; each must expose $dimnames$Terms
#   beta_list    list of topic-by-term matrices, one per model
#   mod_type     model type per entry of beta_list; "diff_dtm" marks a model
#                whose columns follow the next unused alternative dtm
#
# Returns list(new_beta_list, term_order): the beta matrices zero-padded and
# column-aligned to term_order, which is dtm_terms followed by every term
# that appears only in the alternative dtms.
union_terms <- function(dtm_terms, or_idx, list_of_dtms, beta_list, mod_type)
{
  # get union of terms
  list_of_dtm_terms <- lapply(list_of_dtms, function(x) x$dimnames$Terms)
  list_of_dtm_terms[[length(list_of_dtms) + 1]] <- dtm_terms
  all_terms <- Reduce(union, list_of_dtm_terms)

  # Common column order: original vocabulary first, then the extra terms.
  extra_terms <- setdiff(all_terms, dtm_terms)
  term_order <- c(dtm_terms, extra_terms)

  dtm_ct <- 1
  new_beta_list <- vector("list", length(beta_list))
  for (i in seq_along(beta_list))
  {
    beta <- beta_list[[i]]
    if (mod_type[i] == "diff_dtm")
    {
      # Columns currently follow the alternative dtm's vocabulary: pad with
      # zero columns for the terms it lacks, then permute into term_order.
      alt_dtm_terms <- list_of_dtm_terms[[dtm_ct]]
      new_words <- setdiff(all_terms, alt_dtm_terms)
      sort_idx <- match(term_order, c(alt_dtm_terms, new_words))
      padding <- matrix(0, nrow(beta), length(new_words))
      # drop = FALSE keeps a single-topic model as a 1-row matrix.
      new_beta_list[[i]] <- cbind(beta, padding)[, sort_idx, drop = FALSE]
      dtm_ct <- dtm_ct + 1
    }
    else
    {
      # Columns already follow dtm_terms; only the extra terms need padding.
      padding <- matrix(0, nrow(beta), length(extra_terms))
      new_beta_list[[i]] <- cbind(beta, padding)
    }
  }
  return(list(new_beta_list, term_order))
}
# CasAndreu/ldaRobust documentation built on May 29, 2019, 3 p.m.