nlcm_doubletree: wrapper function for fitting and summaries

View source: R/nlcm_doubletree.R

nlcm_doubletree    R Documentation

wrapper function for fitting and summaries

Description

Wrapper function that fits the doubletree latent class model by variational inference (optionally with multiple random restarts) and returns the fitted model together with summaries.

Usage

nlcm_doubletree(
  Y,
  leaf_ids,
  mytrees,
  weighted_edges = c(TRUE, TRUE),
  ci_level = 0.95,
  get_lcm_by_group = FALSE,
  update_hyper_freq = 50,
  print_freq = 10,
  quiet = FALSE,
  plot_fig = FALSE,
  hyper_fixed = list(K = 2, LD = TRUE, do_tree1_update = TRUE),
  tol = 1e-08,
  tol_hyper = 1e-04,
  max_iter = 5000,
  nrestarts = 3,
  keep_restarts = TRUE,
  parallel = TRUE,
  log_restarts = FALSE,
  log_dir = ".",
  vi_params_init = list(),
  hyperparams_init = list(),
  random_init = FALSE,
  random_init_vals = list(mu_gamma_sd_frac = 0.2, mu_alpha_sd_frac = 0.2,
    tau1_lims = c(0.5, 1.5), tau2_lims = c(0.5, 1.5), u_sd_frac = 0.2,
    psi_sd_frac = 0.2, phi_sd_frac = 0.2)
)

Arguments

Y

An N by J binary data matrix; rows correspond to subjects and columns to features. Missing entries, if present, are encoded as NA. Note that the rows of Y are reordered internally twice: first according to the leaf nodes (or missingness) in tree1, and then according to the leaf nodes in tree2.

leaf_ids

A list of two elements giving each observation's leaf memberships in the two trees contained in mytrees. The first element is a vector of character strings naming the leaf node in tree1 for each observation; the second element does the same for tree2. NA indicates missing leaf information. For example, in verbal autopsy (VA) applications, observations from the source domains must have both leaf ids observed, whereas observations from the target domain may have an NA leaf id in the first tree (the cause tree), indicating an unobserved cause of death (i.e., it is unknown to the analyst which leaf of the cause tree the observation belongs to). This package only allows NA leaf labels in tree1; the leaves of tree2 represent domains, which must be known when doing domain adaptation. Currently, the NAs in tree1, if present, must comprise ALL subjects from a single leaf node of tree2. NB: extensions to handle causes of death that are only partially observed in the target domain require additional work.
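As a minimal sketch with hypothetical cause and domain labels, leaf_ids could be constructed as follows:

# Hypothetical labels; NA is only allowed for tree1 (unknown cause), and all
# NAs must come from a single tree2 leaf (here, the target domain "domB"):
leaf_ids <- list(
  c("causeA", "causeB", NA, NA),          # tree1 (cause) labels
  c("domA",   "domA",   "domB", "domB")   # tree2 (domain) labels, always observed
)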

mytrees

A list of two igraph objects, tree1 and tree2. Each may carry node and edge attributes, e.g., node levels and edge lengths. (NB: the required attributes need further documentation.)
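A minimal sketch of constructing the two trees, mirroring the Examples below (the edge list here is made up for illustration):

library(igraph)
tree2 <- graph_from_edgelist(rbind(c("root", "domA"),
                                   c("root", "domB")), directed = TRUE)
V(tree2)$levels <- c(1, 2, 2)                  # one level for the root, one for the leaves
E(tree2)$weight <- rep(1, ecount(tree2))       # unit edge lengths
mytrees <- list(tree1 = tree2, tree2 = tree2)  # reuse the same tree for both, as in the Examples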

weighted_edges

A vector of two logical values indicating whether to use weighted edges in tree1 and tree2, respectively. If FALSE for a tree, weighted edges are not used and its edges are assumed to have unit length; see the Usage section for the default.

ci_level

A number between 0 and 1 giving the desired credible interval level. For example, ci_level = 0.95 (the default) returns a 95% credible interval.

get_lcm_by_group

If TRUE, nlcm_doubletree also returns the maximum likelihood estimates of the coefficients for each leaf_ids group discovered by the model; see the Usage section for the default.

update_hyper_freq

How frequently to update the hyperparameters. The default is every 50 iterations.

print_freq

How often (in iterations) to print the iteration number and the current value of epsilon (the change in the objective function between the two most recent iterations).

quiet

Defaults to FALSE, in which case empirical class probabilities and updates of the tau parameters are printed.

plot_fig

If TRUE, plot figures showing prob (the probability that each node diffuses from its parent node, i.e., s_u = 1, so that the slab component is used) and the response profile of the first node.

hyper_fixed

A list of fixed hyperprior parameter values, e.g., the number of latent classes K; see the Usage section for the default and the Examples for a fuller specification.
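For concreteness, here is a minimal sketch of the kind of list used in the (commented) model fit in the Examples; the numeric values are illustrative only, and the placeholder dimensions mirror quantities computed in the Examples:

# placeholders mirroring the Examples (tree levels, node/leaf counts):
n_levels1 <- 2; n_levels2 <- 2   # max(V(cause_tree)$levels), max(V(domain_tree)$levels)
p1 <- 3; pL1 <- 2; pL2 <- 2      # tree1 node count, tree1 leaf count, tree2 leaf count

hyper_fixed <- list(
  K = 2, LD = TRUE,                                  # number of latent classes; LD flag
  a1 = rep(20, n_levels1), b1 = rep(1, n_levels1),   # tree1 level-specific prior parameters
  a2 = matrix(1,  nrow = pL1, ncol = n_levels2),     # tree2 prior parameters, one row per tree1 leaf
  b2 = matrix(10, nrow = pL1, ncol = n_levels2),
  dmat = matrix(1, nrow = pL1, ncol = pL2),          # (cause, domain) Dirichlet parameters
  s1_u_zeroset = NULL, s1_u_oneset = 1:p1,           # force diffusion at all tree1 nodes
  s2_cu_zeroset = NULL, s2_cu_oneset = rep(list(1), pL1),
  tau_update_levels = list(c(1, 2), c(1, 2))
)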

tol

Convergence tolerance for the objective function. Default is 1E-8.

tol_hyper

The convergence tolerance for the objective function between subsequent hyperparameter updates. Typically it is a more generous tolerance than tol. Default is 1E-4.

max_iter

Maximum number of iterations of the VI algorithm. Default is 5000. NB: check this number before package submission.

nrestarts

Number of random restarts of the VI algorithm. The restart that gives the highest value of the objective function is returned. Choosing nrestarts > 1 is recommended. The default is 3.

keep_restarts

If TRUE, the results from all random restarts will be returned. If FALSE, only the restart with the highest objective function is returned. Default is TRUE.

parallel

If TRUE, the random restarts will be run in parallel. It is recommended to first set the number of cores using doParallel::registerDoParallel(). Otherwise, the default number of cores specified by the doParallel package will be used. Default is TRUE.
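For example (as in the commented lines of the Examples):

nrestarts <- 3
doParallel::registerDoParallel(cores = nrestarts)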

log_restarts

If TRUE and nrestarts > 1, the progress of each random restart will be logged to a text file in log_dir. If FALSE and nrestarts > 1, progress will not be shown. If nrestarts = 1, progress will always be printed to the console. Default is FALSE.

log_dir

Directory for logging progress of random restarts. Default is the working directory.
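A possible logging setup, mirroring the commented lines in the Examples:

log_dir <- tempdir()                       # or any writable directory
dir.create(log_dir, showWarnings = FALSE)  # no-op if it already exists
# then call nlcm_doubletree(..., log_restarts = TRUE, log_dir = log_dir)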

vi_params_init, hyperparams_init

Named lists containing initial values for the variational parameters and hyperparameters. Supplying good initial values can be challenging; this package provides a way to guess them from transformations of latent class model estimates fitted to each individual leaf_ids group (see initialize_nlcm_doubletree()). The most common use of vi_params_init and hyperparams_init is to supply starting values based on previous output from nlcm_doubletree(); see the package documentation and the Examples below. The user can provide initial values for all parameters or a subset; any missing values will be filled in by initialize_nlcm_doubletree().
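A sketch of warm-starting from a previous fit; fit0 is hypothetical, and the element names below follow the access pattern used in the Examples (the hyperparams element name in particular is an assumption):

# fit0: a previous nlcm_doubletree() fit on the same design (hypothetical)
fit1 <- nlcm_doubletree(Y, leaf_ids, mytrees,
                        vi_params_init   = fit0$mod$vi_params,    # as accessed in the Examples
                        hyperparams_init = fit0$mod$hyperparams)  # element name assumed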

random_init

If TRUE, some random variability will be added to the initial values. The default is FALSE, unless nrestarts > 1, in which case random_init will be set to TRUE and a warning message will be printed. The amount of variability is determined by random_init_vals.

random_init_vals

If random_init = TRUE, a list containing the following parameters that control how much the initial values are randomly perturbed. (NB: the descriptions below were adapted from lotR and may still need edits.)

tau1_lims, tau2_lims

each a vector of length 2, with the first element between 0 and 1 and the second element > 1. The initial value of the corresponding hyperparameter (tau_1 or tau_2) is drawn uniformly at random from the range (tau_init * lims[1], tau_init * lims[2]), where tau_init is the initial value either supplied in hyperparams_init or guessed by initialize_nlcm_doubletree().

psi_sd_frac

a value between 0 and 1. A normal random variate is added to each initial value of the auxiliary parameters psi, with standard deviation equal to psi_sd_frac multiplied by the corresponding initial value either supplied in hyperparams_init or guessed by initialize_nlcm_doubletree(). Absolute values are then taken for any resulting values of psi that are < 0.

phi_sd_frac

same as psi_sd_frac, but applied to the auxiliary parameters phi.


mu_gamma_sd_frac

a value between 0 and 1. A normal random variate is added to each initial value of mu_gamma, with standard deviation equal to mu_gamma_sd_frac multiplied by the absolute value of the initial value either supplied in vi_params_init or guessed by initialize_nlcm_doubletree().

mu_alpha_sd_frac

same as mu_gamma_sd_frac, but applied to mu_alpha.

u_sd_frac

a value between 0 and 1. The initial values of the node inclusion probabilities are first transformed to the log-odds scale to obtain u. A normal random variate is added to u, with standard deviation equal to u_sd_frac multiplied by the absolute value of the initial value of u (either supplied in vi_params_init or guessed by initialize_nlcm_doubletree()); u is then transformed back to the probability scale.
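A sketch of this perturbation for a single node-inclusion probability (using the default u_sd_frac = 0.2; the initial probability is hypothetical):

u_sd_frac <- 0.2
prob_init <- 0.8                                  # hypothetical initial inclusion probability
u         <- qlogis(prob_init)                    # to the log-odds scale
u_new     <- u + rnorm(1, sd = u_sd_frac * abs(u))
prob_new  <- plogis(u_new)                        # back to the probability scale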

Value

A list of class "nlcm_doubletree" (which also inherits from "list"), containing the elements mod, mod_restarts, mytrees, dsgn, prob_est, and est_ad_hoc. (NB: a simulated example that uses this function still needs to be created.)
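A sketch of inspecting the returned object, assuming fit is the value returned by nlcm_doubletree():

str(fit, max.level = 1)  # elements: mod, mod_restarts, mytrees, dsgn, prob_est, est_ad_hoc
fit$mod$vi_params        # variational parameters (accessed this way in the Examples below)
fit$prob_est             # estimates (element name as listed in Value)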

Examples

rm(list=ls())

library(igraph)
library(doubletree)
library(MASS)
library(poLCA)
library(BayesLCA)

data("example_data_doubletree")

# second tree - over domains:
data("example_domain_edges")
domain_tree <- graph_from_edgelist(example_domain_edges, directed = TRUE)

# set the levels l*_u for the nodes in the domain tree; nodes in the same level
# share a slab variance multiplier tau_l* (times the length of the edge
# emanating from the parent).
igraph::V(domain_tree)$levels <- c(1,rep(2,length(igraph::V(domain_tree))-1))
igraph::E(domain_tree)$weight <- rep(1,length(E(domain_tree)))
nodes2  <- names(igraph::V(domain_tree))
leaves2 <- names(igraph::V(domain_tree)[igraph::degree(domain_tree, mode = "out") == 0])
rootnode2 <- names(igraph::V(domain_tree)[igraph::degree(domain_tree, mode = "in") == 0])
pL2 <- length(leaves2)
p2  <- length(nodes2)

# first tree - over causes:
cause_tree <- domain_tree # for simplicity, make them the same for illustration.

igraph::V(cause_tree)$levels <- c(1,rep(2,length(igraph::V(cause_tree))-1))
igraph::E(cause_tree)$weight <- rep(1,length(E(cause_tree)))

nodes1  <- names(igraph::V(cause_tree))
leaves1 <- names(igraph::V(cause_tree)[igraph::degree(cause_tree, mode = "out") == 0])
rootnode1 <- names(igraph::V(cause_tree)[igraph::degree(cause_tree, mode = "in") == 0])
pL1 <- length(leaves1)
p1  <- length(nodes1)

# create a new doubletree list for potentially modifying the levels for the nodes
# in the two trees:
working_mytrees <- list(tree1 = cause_tree, tree2 = domain_tree)

# get lists of ancestors for each leaf in tree1:
d1 <- igraph::diameter(working_mytrees[[1]],weights=NA)
# need to set weights=NA to prevent the use of edge lengths in determining the diameter.
ancestors1 <- igraph::ego(working_mytrees[[1]], order = d1 + 1, nodes = leaves1, mode = "in")
ancestors1 <- sapply(ancestors1, names, simplify = FALSE)
ancestors1 <- sapply(ancestors1, function(a, nodes) which(nodes %in% a),
                     nodes = nodes1, simplify = FALSE)
names(ancestors1) <- leaves1

# get lists of ancestors for each leaf in tree2:
d2 <- igraph::diameter(working_mytrees[[2]],weights=NA)
# need to set weights=NA to prevent the use of edge lengths in determining the diameter.
ancestors2 <- igraph::ego(working_mytrees[[2]], order = d2 + 1, nodes = leaves2, mode = "in")
ancestors2 <- sapply(ancestors2, names, simplify = FALSE)
ancestors2 <- sapply(ancestors2, function(a, nodes) which(nodes %in% a),
                     nodes = nodes2, simplify = FALSE)
names(ancestors2) <- leaves2

# print("counts in simulation:")
example_data_doubletree$N_sim_mat

# FITTING MODELS:
working_leaf_ids <- vector("list",2)
working_leaf_ids[[1]] <- leaves1[example_data_doubletree$truth$true_leaf_ids[,1]]
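# treat the first tree2 leaf (domain) as the target domain: mask its tree1 (cause) labels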
working_leaf_ids[[1]][example_data_doubletree$truth$true_leaf_ids[,2]==1] <- NA
# working_leaf_ids[[1]][example_data_doubletree$truth$true_leaf_ids[,2]==2] <- NA
working_leaf_ids[[2]] <- leaves2[example_data_doubletree$truth$true_leaf_ids[,2]]

nrestarts <- 3 # the number of random initializations.

# doParallel::registerDoParallel(cores = nrestarts)
# log_dir <- tempdir()
# dir.create(log_dir)

## Not run:------------------------------------------------
# mod <- nlcm_doubletree(
#   example_data_doubletree$Y,
#   working_leaf_ids,
#   working_mytrees,
#   weighted_edges = c(FALSE,FALSE),
#   ci_level = 0.95,
#   get_lcm_by_group = FALSE,
#   update_hyper_freq = 20,
#   print_freq = 20,
#   quiet      = FALSE,
#   plot_fig   = FALSE, # <-- used?
#   tol        = 1E-8,
#   tol_hyper = 1E-4,
#   max_iter = 1000,
#   nrestarts = nrestarts,
#   keep_restarts = TRUE,
#   parallel = TRUE,
#   # log_restarts = TRUE,
#   # log_dir = log_dir,
#   #log_restarts = FALSE,
#   #log_dir = ".",
#   vi_params_init = list(),
#   hyperparams_init = list(),
#   random_init = FALSE,
#   hyper_fixed = list(
#     K=2, LD=TRUE,# number of latent classes.
#     a1 = rep(20,max(igraph::V(cause_tree)$levels)),
#     b1 = rep(1,max(igraph::V(cause_tree)$levels)),
#     a2=matrix(1,nrow=length(ancestors1),ncol=max(igraph::V(domain_tree)$levels)),
#     # <-- NB: where do we specify levels? in the tree.
#     b2=matrix(10,nrow=length(ancestors1),ncol=max(igraph::V(domain_tree)$levels)),
#     # both (a1,b1),(a2,b2) can encourage shrinkage towards the parent.
#     dmat = matrix(1,nrow=length(ancestors1),ncol=length(ancestors2)), # (cause,domain).
#     s1_u_zeroset = NULL,
#     #s1_u_oneset = NULL, # not force diffusion.
#     s1_u_oneset = 1:p1, # force diffusion.
#     #s2_cu_zeroset = rep(list(2:p2),pL1), # force NO diffusion in non-roots tree2.
#     s2_cu_zeroset = NULL,
#     s2_cu_oneset = rep(list(1),pL1), # no force diffusion tree2.
#     hyperparams_init = list(tau_1=1.5^2,
#                             tau_2=1.5^2),
#     tau_update_levels = list(c(1,2),c(1,2)))
# )
#
#
# # get the design output; this is needed because the function reorders the data rows:
# dsgn0 <- design_doubletree(example_data_doubletree$Y,working_leaf_ids,
#                            working_mytrees)
#
# # for each tree1 leaf, look at shrinkage structure across tree2:
# par(mfrow=c(ceiling(sqrt(pL1+1)),ceiling(sqrt(pL1+1))),
#     mar=c(1,1,1,1))
# for (u in 1:pL1){
#   plot(mod$mod$vi_params$prob2[[u]],type="h",ylim=c(0,1));abline(h=0.5)
# }
# # look at shrinkage structure across tree1:
# plot(mod$mod$vi_params$prob1,type="h",ylim=c(0,1),col="blue");abline(h=0.5)
#
# do.call("rbind",mod$mod$vi_params$prob2)
#
# heatmap(mod$mod$vi_params$emat[is.na(dsgn0$leaf_ids[[1]]),],Rowv=NA,Colv=NA)
# # heatmap(mod$mod$vi_params$emat[!is.na(dsgn0$leaf_ids[[1]]),],Rowv=NA,Colv=NA)
# # heatmap(mod$mod$vi_params$emat,Rowv=NA,Colv=NA)
#
# # posterior means of CSMFs:
# sweep(mod$mod$vi_params$dirich_mat,MARGIN = 2,colSums(mod$mod$vi_params$dirich_mat),"/")
#
# # visualize tree2 root node class probabilities for each tree1 leaf; can change to
# # nodes other than tree2 root node:
# heatmap(apply(mod$mod$vi_params$mu_alpha[[1]],1,function(v) tsb(c(expit(v),1))),Rowv=NA,Colv=NA)
#
# #
# # CLASSIFICATION:
# #
# # MAP cause assignment:
# xx <- mod$mod$vi_params$emat[is.na(dsgn0$leaf_ids[[1]]),]
# apply(xx,1,which.max)
#
# # true causes:
# na_index <- which(is.na(dsgn0$leaf_ids[[1]]))
# example_data_doubletree$truth$true_leaf_ids[dsgn0$all_ord[na_index],1]
#
#
# # domains of the observations missing tree1 leaf label:
# truth <- example_data_doubletree$truth$true_leaf_ids[dsgn0$all_ord[na_index],1]
# map_nlcm <- apply(xx,1,which.max)
#
# table(map_nlcm,truth)
# sum(map_nlcm!=truth)/length(map_nlcm)
#
#
# #
# # CSMF accuracy (the function is obtained from openVA):
# # check how to get top1, top3 cause classification accuracy.
# #
# acc_CSMF <- rep(NA,pL2)
# for (g in 1:pL2){
#   acc_CSMF[g] <- openVA::getCSMF_accuracy(
#     sweep(mod$mod$vi_params$dirich_mat,MARGIN = 2,
#           colSums(mod$mod$vi_params$dirich_mat),"/")[,g],
#     example_data_doubletree$truth$pi_mat[,g])
# }
# print(acc_CSMF)
#
# #
# # top k accuracy:
# #
# k = 1
# pred_top <- get_topk_COD(xx,1)
# acc_topk(pred_top,truth)
#
# #
# # RESPONSE PROBABILITIES:
# #
# itemprob_list_est <-list()
# for (v1 in 1:pL1){
#   itemprob_list_est[[v1]] <- t(expit(Reduce("+",mod$mod$vi_params$mu_gamma[ancestors1[[v1]]])))
# }
#
# par(mfrow=c(ceiling(sqrt(pL1+1)),ceiling(sqrt(pL1+1))),
#     mar=c(1,1,1,1))
# for (v1 in 1:pL1){
#   image(itemprob_list_est[[v1]],main=v1)
# }
#
# #
# # LATENT CLASS PROBABILITIES:
# #
#
# # mixture probabilities:
# mixprob_list_est <-list()
# for (v2 in 1:pL2){
#   tmp <- t(expit(Reduce("+",mod$mod$vi_params$mu_alpha[ancestors2[[v2]]])))
#   mixprob_list_est[[v2]] <- apply(tmp,2,function(v){tsb(c(v,1))})
# }
#
# par(mfrow=c(ceiling(sqrt(pL2+1)),ceiling(sqrt(pL2+1))),
#     mar=c(1,1,1,1))
# for (v2 in 1:pL2){
#   image(mixprob_list_est[[v2]],main=v2)
# }
#
# # individual-level:
# heatmap(mod$mod$vi_params$rmat,Rowv=NA,Colv=NA)
#
# #
# # ELBO trajectory
# #
# plot(mod$mod$ELBO_track,type="o",main="ELBO trajectory")
#
# ## show the ELBO trajectory:
# # library(plotly)
# # tmp_df <- data.frame(iteration = 1:length(mod$mod$ELBO_track),
# #                      ELBO=mod$mod$ELBO_track)
# # plot_ly(tmp_df,x=~iteration,y = ~ELBO,type = 'scatter', mode = 'lines')
#
#
# # need to write a function to determine the misclassification rates.
#
# # how do we avoid tree1 leaf specific class relabeling.
# # perhaps need to consider the same set of alphas
# # no matter which tree1 leaf it is.
#
# ## need to confirm the classification performance; here we have very collapsed
# ## tree for the causes, so the classification performance likely would not be great.
#
# # window <- -c(1:400)
# # plot(mod$mod$line_track[window,1],type="o")
# # for (l in 2:17){
# #   plot(mod$mod$line_track[window,l],type="o",pch=l)
# # }

