## SKG
## June 22, 2021
## JCGS revision to reproduce code for most likely trees
#################################################
## A couple helper functions
#' Assign x-coordinates for the members of a single cluster
#'
#' The horizontal position of an individual is simply its generation number,
#' so this is the identity map on `gen`.
#'
#' @param gen generation number(s)
#' @return `gen`, unchanged
assign_coords_x <- function(gen) {
  gen
}
#' Assign y-coordinates to the members of ONE generation of a single cluster
#'
#' Members are spread evenly over [ymin, ymax] in the order given by
#' `rank(index)` (tied values share their averaged rank). A generation with
#' a single member is placed at the midpoint of the range.
#'
#' @param index values used to order the members of a single generation ONLY
#' @param ymin lower end of the y range; default is -1
#' @param ymax upper end of the y range; default is 1
#' @return numeric vector of y positions, one per element of `index`
assign_coords_y <- function(index,
                            ymin = -1, ymax = 1) {
  gen_ranks <- rank(index)
  n_in_gen <- length(gen_ranks)
  if (n_in_gen == 1) {
    # Lone member: centre it (0 for the default range, as before).
    return((ymin + ymax) / 2)
  }
  # ymin/ymax used to be ignored (range hard-coded to [-1, 1]). With the
  # defaults this equals -1 + (rank - 1) / (0.5 * (n - 1)) exactly.
  ymin + (gen_ranks - 1) * (ymax - ymin) / (n_in_gen - 1)
}
###########################################################
rep_from_lib <- TRUE
if(rep_from_lib){
library(InfectionTrees)
} else {
devtools::load_all()
}
library(ggplot2)
library(data.table)
library(dplyr)
library(stringr)
## Subset data to cluster 27 and build the per-person covariates:
##   rel_time    - as.numeric(rel_time) / 365 (presumably days -> years; confirm)
##   order       - detection order after sorting by rel_time (1 = earliest)
##   smear       - 1 if spsmear == "Positive", else 0 (NA stays NA)
##   hiv_f       - "neg" / "pos" recoding of hivstatus; any other non-NA
##                 value becomes "unk" (NA stays NA)
##   hiv_neg_pos, hiv_unk_pos - 0/1 indicators for hiv_f == "neg" / "unk"
tb_ex <- tb_clean %>%
  filter(group == 27) %>%
  mutate(rel_time = as.numeric(rel_time) / 365) %>%
  arrange(rel_time) %>%
  mutate(order = seq_len(n())) %>%
  dplyr::mutate(smear = ifelse(spsmear == "Positive", 1, 0),
                cluster_id = group,
                hiv_f = ifelse(hivstatus == "Negative", "neg",
                               ifelse(hivstatus == "Positive", "pos",
                                      "unk"))) %>%
  dplyr::mutate(hiv_neg_pos = ifelse(hiv_f == "neg", 1, 0),
                hiv_unk_pos = ifelse(hiv_f == "unk", 1, 0)) %>%
  ungroup()
## Everything below (tree sampling, plotting) needs a non-empty cluster;
## fail here with a clear message rather than deep inside the sampler.
if (nrow(tb_ex) == 0) {
  stop("No rows in tb_clean with group == 27", call. = FALSE)
}
## Load in best results from base model.
## Looks in the working directory for exactly one file ending in ".RDS" whose
## name contains "base". The pattern escapes the dot (an unescaped "." matches
## any character). Zero or several matches stop with an explicit message
## instead of an opaque readRDS() connection error.
fns <- list.files()
fns_rds <- grep("\\.RDS$", fns, value = TRUE)
base_fn <- grep("base", fns_rds, value = TRUE)
if (length(base_fn) != 1) {
  stop("Expected exactly one base-model .RDS file in '", getwd(),
       "' but found ", length(base_fn), ": ",
       paste(base_fn, collapse = ", "), call. = FALSE)
}
base_model_results <- readRDS(base_fn)
## NOTE(review): element 4 of beta_list, first column, is used as the
## infectiousness coefficients; confirm against the fitting script.
inf_params <- base_model_results$beta_list[[4]][, 1]
## Sample a bunch of MC trees (seeded so the sample is reproducible)
B <- 100000
set.seed(622021)
covariate_names <- c("smear",
                     "hiv_neg_pos", "hiv_unk_pos",
                     "rel_time")
mc_trees <- sample_mc_trees(tb_ex,
                            B = B,
                            multiple_outside_transmissions = FALSE,
                            covariate_names = c(covariate_names,
                                                "order"))
## get the covariate matrix
cov_mat <- covariate_df_to_mat(mc_trees,
                               cov_names = covariate_names)
## Per-individual infection probability: logistic function of the linear
## predictor. Computed once and reused (it used to be evaluated twice).
## drop() turns the n x 1 matrix from %*% into a plain vector so it is
## stored as an ordinary column rather than a matrix column.
my_prob_inf <- drop(1 / (1 + exp(-(cov_mat %*% inf_params))))
mc_trees$prob_inf <- my_prob_inf
cluster_id <- n_inf <- orig_id <- prob_inf <- NULL
## Turning mc_trees to a data table
mc_trees.dt <- data.table::as.data.table(mc_trees)
mc_trees.dt <- mc_trees.dt[, prob_inf := my_prob_inf]
cluster_id <- NULL
## Getting the likelihood for each sampled cluster (one row per tree)
like_df <- mc_trees.dt[,
                       .(like = general_cluster_like.dt(prob_inf, n_inf)),
                       by = .(orig_id, cluster_id)]
## Attach each tree's likelihood to every individual row of that tree.
## The join keys must be quoted column names. The bare symbols orig_id and
## cluster_id are NULL globals here, so c(orig_id, cluster_id) was NULL and
## dplyr silently fell back to a natural join on all shared columns.
mc_trees_like <- left_join(mc_trees, like_df,
                           by = c("orig_id", "cluster_id")) %>%
  ## NOTE(review): sum(like) runs over every row (one per individual), not
  ## one per tree. This scales all probs by the same constant when every
  ## tree has the same number of rows, so the ranking is unchanged. Confirm
  ## this is the intended P(C) normalisation.
  mutate(prob = like / sum(like))
## Plot layout for every sampled tree:
##   x        - generation (via assign_coords_x), computed per cluster
##   y        - spread within each (cluster, generation) via assign_coords_y
##   gen_size - largest generation number in the tree
## Rows are sorted from most to least probable; result stays grouped by
## cluster_id. (group_by() without .add replaces the existing grouping, so
## no separate ungroup() is needed before regrouping.)
## NOTE(review): assign_coords_y() ranks its input. If n_in_gen is a count
## that is constant within a generation, all members tie and share one y.
## Confirm this column is a within-generation index.
df <- mc_trees_like %>%
  group_by(cluster_id) %>%
  mutate(x = assign_coords_x(gen)) %>%
  group_by(cluster_id, gen) %>%
  mutate(y = assign_coords_y(n_in_gen)) %>%
  group_by(cluster_id) %>%
  mutate(gen_size = max(gen)) %>%
  arrange(desc(prob))
## For each value of gen_size keep the (up to) 3 most probable trees.
## min(prob) collapses each tree's rows to a single value (prob is presumably
## constant within a tree). With .groups = "drop_last" the result stays
## grouped by gen_size, so slice_max() picks the top 3 per gen_size.
top_groups_by_gen <- df %>%
  group_by(gen_size, cluster_id) %>%
  summarize(prob = min(prob), .groups = "drop_last") %>%
  slice_max(prob, n = 3, with_ties = FALSE)
top_clusts <- top_groups_by_gen[["cluster_id"]]
df_sub <- filter(df, cluster_id %in% top_clusts)
## Facet key "<gen_size>-<cluster_id>" for each selected tree.
df_sub <- mutate(df_sub, facet = paste(gen_size, cluster_id, sep = "-"))
## Position of every node keyed by its id. Joining it on inf_id gives each
## row the coordinates (x_to, y_to) of the node it points to (presumably its
## infector); these become the edge endpoints in the plot.
inf_df <- select(df_sub, inf_id = id, cluster_id, x_to = x, y_to = y)
jsub <- left_join(df_sub, inf_df, by = c("cluster_id", "inf_id"))
## One row per selected tree with the prob of its first row.
fctr_sub <- summarize(group_by(df_sub, cluster_id), like = prob[1])
## Per-tree labels "Sample <cluster_id>: P(C) = <prob>".
## levels come from fctr_sub itself, so each label is paired with the row it
## was built from. factor()'s default (sorted) levels are not guaranteed to
## follow the same order as group_by()/summarize(), e.g. for character keys.
jsub$factor_id <- factor(jsub$cluster_id,
                         levels = fctr_sub$cluster_id,
                         labels = paste0("Sample ",
                                         fctr_sub$cluster_id,
                                         ": P(C) = ",
                                         formatC(fctr_sub$like,
                                                 format = "e",
                                                 digits = 2)))
## Generation-count labels "Gen. 1", "Gen. 2", ... up to the largest gen_size.
jsub$factor_gen <- factor(jsub$gen_size,
                          levels = seq_len(max(jsub$gen_size)),
                          labels = paste0("Gen. ",
                                          seq_len(max(jsub$gen_size))))
## One row per facet (tree): its number of generations, a short id (last
## three characters of cluster_id) and the prob of its first row.
fctr_sub2 <- df_sub %>%
  group_by(facet, gen_size, cluster_id) %>%
  summarize(like = prob[1]) %>%
  mutate(id = stringr::str_sub(cluster_id, start = -3))
## Facet titles. Levels and labels are both taken in fctr_sub2 row order so
## they stay paired.
jsub$facet_lab <- factor(
  jsub$facet,
  levels = unique(fctr_sub2$facet),
  labels = with(fctr_sub2,
                paste0("# Gens: ", gen_size,
                       "; ID: ", id,
                       "; P(C) = ", formatC(like, format = "e", digits = 2)))
)
## Plot the selected trees (one facet per tree) and save them to PDF.
## The subtitle's cluster id is read from the data that was actually
## filtered (tb_ex), so it cannot drift from the filter above.
my_orig_id <- unique(tb_ex$group)
p_trees <- ggplot(data = jsub,
                  aes(x = x, y = y,
                      group = cluster_id,
                      col = factor(order),
                      shape = factor(hiv_neg_pos))) +
  ## Edges from each node to (x_to, y_to); curvature = 0 draws straight lines.
  ## NOTE(review): ggplot2 >= 3.4 prefers `linewidth` over `size` for lines.
  geom_curve(aes(xend = x_to, yend = y_to),
             size = 2, curvature = 0,
             col = "black") +
  ## Larger dark-gray point drawn first acts as an outline for the coloured one.
  geom_point(size = 6, stroke = 4, col = "darkgray") +
  geom_point(size = 5, stroke = 4) +
  facet_wrap(~facet_lab, ncol = 3) +
  ## NOTE(review): the "Set1" palette has at most 9 colours; a cluster with
  ## more than 9 members would leave some detection orders uncoloured.
  scale_color_brewer(palette = "Set1", name = "Detection Order") +
  ## Labels assume both levels "0" (hiv_neg_pos = 0: pos./unk.) and
  ## "1" (neg.) are present in the plotted data.
  scale_shape_manual(values = c(3, 16),
                     labels = c("Pos./Unk.", "Neg."),
                     name = "HIV Status") +
  ## Fixed limits: anything beyond generation 7 is dropped from the plot.
  xlim(.8, 7.2) +
  ylim(-1.2, 1.2) +
  theme_bw(base_size = 18) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()) +
  labs(x = "Generation in transmission tree",
       y = latex2exp::TeX("Order of infection in gen. among 'siblings' $\\rightarrow$"),
       title = "Most likely sampled trees by number of generations",
       subtitle = paste0("Cluster ", my_orig_id)) +
  theme(legend.position = "bottom", legend.box = "horizontal") +
  guides(color = guide_legend(nrow = 2, byrow = TRUE))
## Top-level autoprint, as the bare ggplot() expression did before.
p_trees
## Save this plot explicitly rather than relying on last_plot().
ggsave("trees-7.pdf", plot = p_trees, height = 15, width = 13)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.