R/wmmTree.R

Defines functions wmmTree

Documented in wmmTree

#' @title wmmTree
#' @description Main function. Generate weighted estimates using the weighted
#'  multiplier method.
#' @param tree A makeTree object
#' @param sample_length An integer for number of sampling passes. Must be at
#'  least 1 unless \code{single.source = TRUE}, in which case it is not used.
#' @param method Method specifying weighting. Only the default, 'mmEstimate',
#'  is supported at this time.
#' @param int.type A string specifying interval type. Default "quants" generates
#'  the interval using the quantiles giving the central 95\% of the samples.
#'  Alternatively, "var" can be used to generate a variance-weighted confidence
#'  interval, and "cox" generates a Cox interval.
#' @param single.source Set to TRUE if all data comes from single, fully informed
#'  source.  Default is FALSE.
#' @return A list with elements \code{root} (the rounded root estimate),
#'  \code{uncertainty} (the root's uncertainty/interval entry),
#'  \code{estimates} (root estimates on the original scale, rounded to 2
#'  decimals) and \code{weights} (the weights used to combine paths). The
#'  nodes of \code{tree} are also updated in place with the estimates and
#'  samples generated by the weighted multiplier method (see the examples,
#'  which read them back with \code{tree2$Get()}).
#' @examples \donttest{
#'   data(treeData1)
#'   tree <- makeTree(treeData1)
#'   Zhats <- wmmTree(tree, sample_length = 3)
#'
#'   message("Another example with a larger tree")
#'   message("note - longer run time example")
#'   data(treeData2)
#'   tree2 <- makeTree(treeData2)
#'   Zhats <- wmmTree(tree2, sample_length = 3)
#'   Zhats$estimates # print the sampled estimates of the root node
#'   Zhats$weights # prints the weights of each branch
#'   Zhats$root # prints the final estimate of the root node by WMM
#'   Zhats$uncertainty # prints the final rounded estimate of the root with conf. int.
#'
#'   message(paste("show the average root estimate with 95% confidence interval,",
#'           "as well as average estimates with confidence interval for each parameter"))
#'   tree2$Get('uncertainty')
#'
#'   message("show the samples generated from each path which provides root estimates")
#'   tree2$Get('targetEst_samples')
#'
#'   message("show the probabilities sampled at each branch leading into the given node")
#'   tree2$Get('probability_samples')
#' }
#' @export
#' @import data.tree
#' @importFrom magrittr %>%
#' @importFrom tidyselect all_of
#' @importFrom rlang "is_empty"
#' @importFrom gtools "rdirichlet"

wmmTree <- function(tree, sample_length = 10, method = 'mmEstimate',
                    int.type = 'quants', single.source = FALSE) {
  # Choose the per-leaf estimation method - currently only mmEstimate.
  if (method == 'mmEstimate') {
    message('using variance-weighted mean with multiplier method sampled path estimates')
    methodFunction <- mmEstimate
  } else {
    stop(paste(method, 'is not a known method'))
  }

  # The sampling scheme needs at least one pass; 1:0 would silently run twice.
  if (!single.source && sample_length < 1) {
    stop('sample_length must be at least 1')
  }

  # Reset the per-node accumulators that each pass appends to.
  tree$Set(targetEst_samples = numeric())
  tree$Set(probability_samples = numeric())
  tree$Set(imp_weights = numeric())

  if (single.source) {
    # If all informative paths are obtained using a single, fully informative
    # source for all sibling data, closed form calculation can be used
    message('using closed-form expressions to generate estimates - be sure tree satisfies required assumptions')

    # Use parameters for each Dirichlet/Beta distributions from data table
    methodFunction <- ssEstimate
    tree$Do(methodFunction, traversal = "post-order")

    # Record targetEst as each node's sample; an all-NA sample vector is
    # replaced by numeric() for the drawing functions.
    tree$Do(function(node) {
      node$targetEst_samples <- c(node$targetEst_samples, node$targetEst)
      if (all(is.na(node$targetEst_samples))) {
        node$targetEst_samples <- numeric()
      }
    })

    # Per-path means and variances; only leaves carry the estimates.
    means <- .wmm_leaf_matrix(tree$Get('targetEst',
                                       filterFun = function(node) node$isLeaf,
                                       traversal = 'post-order'))
    vars <- .wmm_leaf_matrix(tree$Get('variance',
                                      filterFun = function(node) node$isLeaf,
                                      traversal = 'post-order'))
    means <- as.data.frame(means)
    vars <- as.data.frame(vars)

    ## select only leaves with marginal counts
    getleaves <- which(tree$Get('TerminalCount', filterFun = isLeaf,
                                traversal = 'post-order'))

    # if there are more columns than leaves, keep the leaf columns only
    if (ncol(means) > length(getleaves)) {
      means <- means %>%
        select(all_of(getleaves))
    }
    if (ncol(vars) > length(getleaves)) {
      vars <- vars %>%
        select(all_of(getleaves))
    }

    # Precision-weighted (1 / variance) root estimate and its variance.
    prec <- 1/vars
    w <- prec/sum(prec)
    rootEst <- as.matrix(means) %*% t(w)
    rootVar <- as.matrix(vars) %*% t(w^2)

    # final values of the root are retained in 'Estimate' and 'variance'
    tree$Do(function(node) {
      if (isRoot(node)) {
        node$Estimate <- round(rootEst, 2)
        node$variance <- rootVar
      }
    })

    m <- log(tree$Get('Estimate', filterFun = isRoot))

    # add 95% confidence intervals
    tree$Do(function(node) {
      node$uncertainty <- ss.confInts(node)
    })
  } else {
    for (pass in seq_len(sample_length)) {
      # Sample branch probabilities within every sibling group.
      tree$Do(.wmm_sample_sibling_probs)

      tree$Do(function(node) {
        if (node$isRoot || is.null(node$Estimate)) {
          node$probability <- NA
        }
      })

      # calculate target estimates from each leaf based on method above
      tree$Do(methodFunction, traversal = "post-order")

      # accumulate this pass's estimates, probabilities and importance weights
      tree$Do(function(node) {
        node$targetEst_samples <- c(node$targetEst_samples, node$targetEst)
        node$probability_samples <- c(node$probability_samples,
                                      node$probability)
        node$imp_weights <- c(node$imp_weights, node$impweight)
      })
    }

    # Importance weight is the same within a sibling group, on each pass of
    # wmm. Normalize each node's accumulated imp_weights, and set
    # targetEst_samples to numeric() when it is entirely NA (drawing functions).
    tree$Do(function(node) {
      node$imp_weights <- node$imp_weights / sum(node$imp_weights)
      if (all(is.na(node$targetEst_samples))) {
        node$targetEst_samples <- numeric()
      }
    })

    # after method is applied to each of the leaf node, targetEst_samples of
    # each node are used to calculate weights
    w <- ko.weights(tree)

    # weights are multiplied by mean estimates of each path to calculate a
    # final estimate of the root
    m <- logEstimates(tree)

    # final estimate of the root is retained in 'Estimate' at the root
    tree$Do(function(node) {
      if (isRoot(node)) {
        node$Estimate <- round(exp(mean(m)), 2)
      }
    })

    # add 95% confidence intervals
    tree$Do(function(node) {
      node$uncertainty <- if (isRoot(node)) {
        root.confInt(tree, int.type)
      } else {
        confInts(node$targetEst_samples)
      }
    })
  }

  # Finally, output the full list of Nhat estimates
  list("root" = tree$Get("Estimate", filterFun = isRoot),
       "uncertainty" = tree$Get("uncertainty", filterFun = isRoot),
       "estimates" = round(exp(m), 2),
       "weights" = w)
}

# Convert a tree$Get() result into a matrix with one named column per
# non-empty element. Non-list results are returned unchanged.
.wmm_leaf_matrix <- function(x) {
  if (!is.list(x)) {
    return(x)
  }
  out <- NULL
  for (i in seq_along(x)) {
    if (length(x[[i]]) > 0) {
      out <- cbind(out, x[[i]])
      colnames(out)[ncol(out)] <- names(x)[[i]]
    }
  }
  out
}

# Sample branch probabilities for the children of one node. Each sampled
# child's value goes to its `probability` field, plus `impweight` when
# importance sampling is used. Nodes without children are left untouched.
.wmm_sample_sibling_probs <- function(node) {
  if (length(node$children) == 0) {
    return(invisible(NULL))
  }
  siblist <- node$children
  k <- length(siblist)

  # 'Estimate' and 'Total' for each sibling; NA when either is missing.
  est.vec <- rep(NA_real_, k)
  total.vec <- rep(NA_real_, k)
  for (i in seq_len(k)) {
    child <- siblist[[i]]
    if (!is.null(child$Estimate) && !is.null(child$Total)) {
      est.vec[i] <- child$Estimate
      total.vec[i] <- child$Total
    }
  }

  # siblings with an NA estimate are not sampled
  informed <- which(!is.na(est.vec))
  est.vec <- est.vec[informed]
  total.vec <- total.vec[informed]

  sample.vec <- numeric(length(informed))
  names(sample.vec) <- names(siblist[informed])

  if (length(informed) < k) {
    # FIRST CASE: at least one sibling is uninformed; the informed ones are
    # drawn with rejection sampling so their probabilities sum to < 1.
    sample.vec <- .wmm_rejection_sample(sample.vec, est.vec, total.vec)
    for (j in seq_along(informed)) {
      node$children[[informed[j]]]$probability <- sample.vec[j]
    }
    return(invisible(NULL))
  }

  # SECOND CASE: every sibling is informed.
  branch.samps <- which(total.vec == total.vec[1])
  if (length(branch.samps) == k &&
      sum(est.vec[branch.samps]) == total.vec[1]) {
    # One study informs all branches (same 'Total', estimates sum to it):
    # a single Dir(Estimate_1 + 1, ..., Estimate_k + 1) draw.
    sample.vec[branch.samps] <- rdirichlet(1, est.vec[branch.samps] + 1)
    for (i in seq_len(k)) {
      node$children[[i]]$probability <- sample.vec[i]
    }
  } else {
    # Several studies: draw candidates from a proposal and pick one by
    # importance weight.
    imp <- .wmm_importance_sample(sample.vec, est.vec, total.vec)
    for (i in seq_len(k)) {
      node$children[[i]]$probability <- imp$sample[i]
      # NOTE(review): this stores the whole vector of candidate weights (one
      # per proposal draw), not the chosen draw's weight, so imp_weights grows
      # by num.samp entries per pass -- confirm this is what ko.weights expects.
      node$children[[i]]$impweight <- imp$weights
    }
  }
  invisible(NULL)
}

# Draw one probability per informed branch in a group that also has uninformed
# siblings. Branches sharing a 'Total' whose estimates sum below it are drawn
# jointly from a Dirichlet (complement component discarded). Any other branch
# is drawn alone from Beta(Estimate + 1, Total - Estimate + 1). Draws that
# would push the running sum to >= 1 are rejected and redrawn.
.wmm_rejection_sample <- function(sample.vec, est.vec, total.vec) {
  pos <- seq_along(sample.vec)  # positions still to be sampled
  while (length(pos) > 0) {
    branch.samps <- which(total.vec == total.vec[1])
    if (length(branch.samps) > 1 &&
        sum(est.vec[branch.samps]) < total.vec[1]) {
      dir.params <- c(est.vec[branch.samps] + 1,
                      total.vec[1] - sum(est.vec[branch.samps]) + 1)
      probs <- rdirichlet(1, dir.params)
      # discard last probability (represents complement branches)
      probs <- probs[-length(probs)]
    } else {
      # a single Beta draw covers exactly one branch
      branch.samps <- branch.samps[1]
      probs <- rbeta(1, est.vec[1] + 1, total.vec[1] - est.vec[1] + 1)
    }
    if (sum(sample.vec) + sum(probs) < 1) {
      sample.vec[pos[branch.samps]] <- probs
      pos <- pos[-branch.samps]
      est.vec <- est.vec[-branch.samps]
      total.vec <- total.vec[-branch.samps]
    }
    # otherwise reject and redraw the same branches
  }
  sample.vec
}

# Importance sampling for a fully informed sibling group that is not covered by
# one study. Each of `num.samp` candidates samples all branches but one
# (Dirichlet for groups sharing a 'Total', else Beta, with rejection). The
# remaining branch is set to 1 - sum(others). Candidates are weighted by the
# likelihood of that deterministic branch's own data, and one is chosen with
# probability proportional to its weight.
# Returns list(sample = chosen probability vector, weights = normalized weights).
.wmm_importance_sample <- function(sample.vec, est.vec, total.vec,
                                   num.samp = 100) {
  n.branch <- length(sample.vec)
  group.samples <- matrix(nrow = num.samp, ncol = n.branch)
  last.pos <- n.branch

  for (s in seq_len(num.samp)) {
    draw <- sample.vec
    est.rem <- est.vec
    total.rem <- total.vec
    pos <- seq_len(n.branch)  # positions still to be sampled

    while (length(pos) > 1) {
      # find other remaining branches informed by the same study
      branch.samps <- which(total.rem == total.rem[1])
      if (length(branch.samps) > 1 &&
          sum(est.rem[branch.samps]) < total.rem[1]) {
        dir.params <- c(est.rem[branch.samps] + 1,
                        total.rem[1] - sum(est.rem[branch.samps]) + 1)
        probs <- rdirichlet(1, dir.params)
        # discard last probability (represents complement branches)
        probs <- probs[-length(probs)]
        # if the group covers every remaining branch, leave its last branch
        # to be set deterministically
        if (length(probs) == length(pos)) {
          probs <- probs[-length(probs)]
          branch.samps <- branch.samps[-length(branch.samps)]
        }
      } else {
        branch.samps <- branch.samps[1]
        probs <- rbeta(1, est.rem[1] + 1, total.rem[1] - est.rem[1] + 1)
      }
      if (sum(draw) + sum(probs) < 1) {
        draw[pos[branch.samps]] <- probs
        pos <- pos[-branch.samps]
        est.rem <- est.rem[-branch.samps]
        total.rem <- total.rem[-branch.samps]
      }
      # otherwise reject and redraw
    }

    # exactly one branch remains; it absorbs the leftover probability
    last.pos <- pos
    draw[last.pos] <- 1 - sum(draw)
    group.samples[s, ] <- draw
  }

  # Weight = likelihood of the deterministic branch's data at its implied
  # probability, computed in log space and stabilized by subtracting the max.
  p.last <- group.samples[, last.pos]
  e.last <- est.vec[last.pos]
  n.last <- total.vec[last.pos]
  log.w <- .wmm_xlogy(e.last, p.last) + .wmm_xlogy(n.last - e.last, 1 - p.last)
  imp.weights <- exp(log.w - max(log.w))
  imp.weights <- imp.weights / sum(imp.weights)

  # degenerate weights (e.g. every candidate has zero likelihood): uniform
  if (anyNA(imp.weights)) {
    imp.weights <- rep(1 / num.samp, num.samp)
  }

  chosen <- sample.int(num.samp, size = 1, prob = imp.weights)
  list(sample = group.samples[chosen, ], weights = imp.weights)
}

# x * log(y), with 0 * log(0) treated as 0 (matching 0^0 == 1).
.wmm_xlogy <- function(x, y) {
  if (isTRUE(x == 0)) {
    return(rep(0, length(y)))
  }
  x * log(y)
}

Try the AutoWMM package in your browser

Any scripts or data that you put into this service are public.

AutoWMM documentation built on June 8, 2025, 11:10 a.m.