R/sample-trees-fast.R

Defines functions sample_connections sample_tree_perm get_in0x0 sample_in0x0 sample_unique_perms sample_uniform_trees_nx sample_uniform_trees

Documented in sample_connections sample_tree_perm sample_uniform_trees sample_uniform_trees_nx sample_unique_perms

## SKG
## March 9, 2020
## Faster, lower-memory tree sampling




#' Sample uniform trees
#'
#' Draws \code{B} sampled transmission trees for each
#' (\code{n_vec[k]}, \code{x_vec[k]}) pair and stacks the tabulated summaries.
#'
#' @param n_vec total sizes
#' @param x_vec total numbers of smear pos; must have the same length as
#'   \code{n_vec}
#' @param B number of trees to sample per (n, x) pair
#' @param use_one_index Do we sample assuming 1 root node?  Default is TRUE
#' @return data frame with the following columns
#' \describe{
#' \item{n}{total size}
#' \item{n_pos}{number of smear pos}
#' \item{i_pos}{number of i-positive transmissions}
#' \item{i_neg}{number of i-negative transmissions}
#' \item{freq}{frequency of (n,x,i)}
#' }
#' Rows with zero frequency are dropped.  If \code{n_vec} is empty, a
#' zero-row data frame with the columns above is returned.
#' @export
#' @details When \code{use_one_index} is TRUE the trees are drawn with
#'   \code{sample_uniform_trees_nx} (one root node); otherwise with
#'   \code{sample_uniform_trees_nxo}.
sample_uniform_trees <- function(n_vec, x_vec, B,
                                 use_one_index = TRUE){
  ## Indexing x_vec past its end would silently yield NA sizes
  if (length(n_vec) != length(x_vec)) {
    stop("`n_vec` and `x_vec` must have the same length", call. = FALSE)
  }
  K <- length(n_vec)
  ## Nothing to sample: bind_rows() of an empty list has no `freq` column
  ## to filter on, so return the documented empty shape directly.
  if (K == 0) {
    return(data.frame(n = numeric(0), n_pos = numeric(0),
                      i_pos = numeric(0), i_neg = numeric(0),
                      freq = integer(0)))
  }
  tree_sampler <- if (use_one_index) {
    sample_uniform_trees_nx
  } else {
    sample_uniform_trees_nxo
  }
  tree_list <- vector(mode = "list", length = K)
  for (ii in seq_len(K)) {
    tree_list[[ii]] <- tree_sampler(n = n_vec[ii],
                                    x = x_vec[ii],
                                    B = B,
                                    summarize_generator = FALSE)
  }
  trees <- dplyr::bind_rows(tree_list) %>%
    dplyr::filter(.data$freq != 0)
  return(trees)
}



#' Sample uniform trees
#'
#' Samples \code{B} transmission trees with \code{n} nodes (one root node),
#' \code{x} of which are smear positive, and tabulates their summaries.
#'
#' @param n total size
#' @param x total number of smear pos
#' @param B number of trees to sample
#' @param summarize_generator Do we also summarize the generator's infections?  Default is FALSE
#' @return data frame with the following columns
#' \describe{
#' \item{n}{total size}
#' \item{n_pos}{number of smear pos}
#' \item{i_pos}{number of i-positive transmissions}
#' \item{i_neg}{number of i-negative transmissions}
#' \item{n0}{number of infections by generator (if summarize_generator is TRUE)}
#' \item{x0}{smear status of generator (if summarize_generator is TRUE)}
#' \item{freq}{frequency of (n,x,i)}
#' }
#' Only combinations with positive frequency are returned.
#' @export
#' @details assumes one root node
sample_uniform_trees_nx <- function(n, x, B,
                                    summarize_generator = FALSE){

  ## Size-1 cluster: the only tree is the lone generator with no
  ## transmissions, so all B draws are identical.
  if (n == 1) {
    df <- data.frame(n = n,
                     n_pos = x,
                     i_pos = 0,
                     i_neg = 0,
                     stringsAsFactors = FALSE)
    ## Generator columns only when requested, matching the n > 1 output
    if (summarize_generator) {
      df$n0 <- 0
      df$x0 <- as.numeric(x == 1)
    }
    df$freq <- B
    return(df)
  }

  ## Number of generations g; the root is generation 1, so g >= 2 when
  ## n >= 2.  There are choose(n - 2, g - 2) generation-size vectors with
  ## g generations, hence g - 2 ~ Binomial(n - 2, 1/2).
  g_vec <- 2 + rbinom(n = B, size = n - 2, prob = 0.5)
  g_tab <- table(g_vec)
  unique_g <- sort(unique(g_vec))

  i_list <- vector(mode = "list", length = length(unique_g))
  for (ii in seq_along(unique_g)) {
    g <- unique_g[ii]
    ## One generation-size vector per tree that was assigned g generations
    perm_mat <- sample_unique_perms(g = g, n = n,
                                    B = as.numeric(g_tab[as.character(g)]))
    ## Sample each tree and summarise it; one row per tree
    i_list[[ii]] <- sample_in0x0(perm_mat, x,
                                 summarize_generator = summarize_generator)
  }
  i_mat <- do.call("rbind", i_list)

  ## Columns of i_mat are (i_pos, i_neg) or (i_pos, i_neg, n0, x0), the
  ## order in which get_in0x0() builds its summary vector.
  if (!summarize_generator) {
    df <- data.frame(table(i_mat[, 1], i_mat[, 2]))
    colnames(df) <- c("i_pos", "i_neg", "freq")
    keep <- c("n", "n_pos", "i_pos", "i_neg", "freq")
  } else {
    df <- data.frame(table(i_mat[, 1], i_mat[, 2], i_mat[, 3], i_mat[, 4]))
    colnames(df) <- c("i_pos", "i_neg", "n0", "x0", "freq")
    df$n0 <- as.numeric(as.character(df$n0))
    df$x0 <- as.numeric(as.character(df$x0))
    keep <- c("n", "n_pos", "i_pos", "i_neg", "n0", "x0", "freq")
  }
  df$n <- n
  df$n_pos <- x
  ## table() yields factor levels; convert them back to numbers
  df$i_pos <- as.numeric(as.character(df$i_pos))
  df$i_neg <- as.numeric(as.character(df$i_neg))

  ## The cross-tabulation includes unobserved combinations; drop them
  df <- df[df$freq > 0, keep, drop = FALSE]
  rownames(df) <- NULL
  df
}


#' Sample generation-size vectors from a partition space
#'
#' The first generation is fixed at size 1.  The remaining g - 1 sizes are
#' positive and sum to n - 1.  A partition of n - 1 into g - 1 parts is
#' drawn with probability proportional to its number of distinct orderings
#' (\code{RcppAlgos::permuteCount}).  A random ordering of it is then drawn
#' with \code{RcppAlgos::permuteSample}.
#'
#' @param g number of generations
#' @param n total size of cluster
#' @param B number of permutations to draw (with replacement, so columns
#'   need not be distinct)
#' @return matrix of size g x B whose first row is all 1.  Each column has
#'   positive entries summing to n.
sample_unique_perms <- function(g, n, B){

  ## Trivial cases: a single node, or every generation has exactly one node
  if (n == 1) {
    return(matrix(1, ncol = B, nrow = 1))
  } else if (g == n) {
    return(matrix(1, ncol = B, nrow = g))
  }

  ## Each column is a partition of n - 1 into exactly g - 1 positive parts
  parts <- partitions::restrictedparts(n = n - 1, m = g - 1,
                                       include.zero = FALSE)
  ## Weight = number of distinct orderings of the partition's parts
  wts <- apply(parts, 2, function(x){
    if (length(unique(x)) == 1) return(1)
    RcppAlgos::permuteCount(sort(unique(x)), freqs = table(x))
  })
  stopifnot(length(wts) == ncol(parts))
  ## sample.int() avoids sample()'s special case for a length-1 first argument
  part_inds <- sample.int(ncol(parts), size = B,
                          replace = TRUE, prob = wts / sum(wts))
  drawn_parts <- unclass(parts)[, part_inds, drop = FALSE]
  ## Draw one ordering per sampled partition, column by column
  drawn_perms <- lapply(seq_len(ncol(drawn_parts)), function(jj){
    x <- drawn_parts[, jj]
    if (length(unique(x)) == 1) {
      return(matrix(rep(x[1], g - 1), ncol = 1, nrow = g - 1))
    }
    matrix(RcppAlgos::permuteSample(v = sort(unique(x)),
                                    freqs = table(x), n = 1),
           ncol = 1)
  })
  drawn_perms <- do.call("cbind", drawn_perms)
  ## Prepend the fixed root generation of size 1
  rbind(rep(1, ncol(drawn_perms)), drawn_perms)
}


#' Sample and summarize a tree for given generation sizes
#'
#' For each column of \code{perm_mat}, samples a tree with
#' \code{sample_tree_perm()} and summarises it with \code{get_in0x0()}.
#'
#' @param perm_mat g x B matrix where each column is a vector of generation sizes
#' @param x number of smear positives
#' @param summarize_generator logical.  Do we summarize the generator with its smear status and number of infected as well?  Default is FALSE.
#' @return matrix with one row per column of \code{perm_mat}.  It has two
#' columns (i_pos, i_neg) when \code{summarize_generator} is FALSE, or four
#' columns (i_pos, i_neg, n0, x0) when it is TRUE.  i_pos and i_neg count
#' i-positive and i-negative transmissions, n0 is the number infected by
#' the generator, and x0 is the generator's smear status.
sample_in0x0 <- function(perm_mat, x,
                         summarize_generator = FALSE){

  ## apply() stacks the per-tree summaries as columns; transpose so that
  ## each row is one tree (also turns a single summary into a 1-row matrix)
  i_mat <- t(apply(perm_mat, 2, function(perm){
    tree <- sample_tree_perm(perm, x)
    get_in0x0(tree, summarize_generator)
  }))
  ## is.matrix() rather than comparing class(): since R 4.0 a matrix has
  ## class c("matrix", "array"), and a length-2 `if` condition errors in
  ## R >= 4.2.
  if (!is.matrix(i_mat)) {
    stop("expected a matrix of tree summaries", call. = FALSE)
  }
  i_mat
}

#' Count the number of smear positive and negative transmissions
#'
#' @param df data frame corresponding to a single tree with columns
#' \describe{
#' \item{gen}{generation number; the generator is in generation 1}
#' \item{id}{gen-id_num}
#' \item{inf_id}{infector ID or NA}
#' \item{smear}{0/1 for -/+ smear}
#' }
#' @param summarize_generator logical.  Do we summarize the generator with its smear status and number of infected as well?  Default is FALSE.
#' @return numeric vector \code{c(i_pos, i_neg)}.  i_pos (i_neg) is the
#' number of transmissions whose infector is smear positive (negative).
#' When \code{summarize_generator} is TRUE, returns
#' \code{c(i_pos, i_neg, n0, x0)}, where n0 is the number of nodes infected
#' directly by the generator and x0 is the generator's smear status.
get_in0x0 <- function(df, summarize_generator = FALSE){
  ## Row of each node's infector; NA for nodes without an infector
  inf_id_inds <- match(df$inf_id, df$id)
  infector_smear <- df$smear[inf_id_inds]
  i_pos <- sum(infector_smear, na.rm = TRUE)
  i_neg <- sum(infector_smear == 0, na.rm = TRUE)
  if (!summarize_generator) {
    return(c(i_pos, i_neg))
  }
  generator_ind <- which.min(df$gen)
  if (df$gen[generator_ind] != 1) {
    stop("tree has no generation-1 (generator) node", call. = FALSE)
  }
  x0 <- df$smear[generator_ind]
  n0 <- sum(df$inf_id == df$id[generator_ind], na.rm = TRUE)
  c(i_pos, i_neg, n0, x0)
}


#' Sample a complete tree for given generation sizes
#'
#' Builds the infector structure with \code{sample_connections()}, then
#' randomly assigns exactly \code{x} smear-positive labels among the nodes.
#'
#' @param gen_sizes vector of generation sizes
#' @param x number of smear positives
#' @return the data frame returned by \code{sample_connections(gen_sizes)}
#'   with an added 0/1 \code{smear} column
sample_tree_perm <- function(gen_sizes, x){
  tree <- sample_connections(gen_sizes)
  n_neg <- sum(gen_sizes) - x
  ## n_neg zeros followed by x ones, then shuffled
  tree$smear <- sample(rep(c(0, 1), times = c(n_neg, x)))
  tree
}


#' Sample connections and form a tree
#'
#' Every node in generation k > 1 is assigned an infector drawn uniformly
#' at random from the nodes of generation k - 1.  Generation-1 nodes have
#' no infector.
#'
#' @param gen_sizes non-empty vector of positive generation sizes
#' @return data frame with the following columns
#' \describe{
#' \item{gen}{generation of the node (1, 2, ...)}
#' \item{n_in_gen}{index of the node within its generation}
#' \item{inf_id}{id of the node's infector (character), NA in generation 1}
#' \item{id}{node id of the form "gen-n_in_gen"}
#' }
sample_connections <- function(gen_sizes){
  if (length(gen_sizes) == 0 || anyNA(gen_sizes) || any(gen_sizes < 1)) {
    stop("`gen_sizes` must be a non-empty vector of positive sizes",
         call. = FALSE)
  }
  gen <- as.numeric(rep(seq_along(gen_sizes), times = gen_sizes))
  n_in_gen <- as.numeric(sequence(gen_sizes))
  ## One sample.int() draw per non-root node, in node order
  inf_id <- vapply(gen, function(cur_gen){
    if (cur_gen == 1) return(NA_character_)
    id <- sample.int(gen_sizes[cur_gen - 1], size = 1)
    paste0(cur_gen - 1, "-", id)
  }, character(1))
  df <- data.frame(gen = gen, n_in_gen = n_in_gen, inf_id = inf_id,
                   stringsAsFactors = FALSE)
  df$id <- paste0(df$gen, "-", df$n_in_gen)
  df
}
skgallagher/TBornotTB documentation built on April 21, 2020, 1:19 p.m.