#' @title Fit a negative binomial
#'
#' @description Fit a negative binomial distribution to a set of counts by
#'   minimising the negative log-likelihood with \code{stats::nlminb}.
#'
#' @param input_dat a data frame with one column called "counts"
#' @param nb_init optional length-2 vector giving the (mean, size) starting
#'   point handed to the optimiser
#' @return the list returned by \code{stats::nlminb}; \code{par[1]} is the
#'   fitted mean (\code{mu}) and \code{par[2]} the fitted size
fit_nb = function(input_dat,
                  nb_init = c(10, 1)) {

  counts = input_dat$counts

  # Negative log-likelihood in the (mu, size) parameterisation:
  # mu_size[1] is the mean, mu_size[2] is the size.
  neg_log_lik = function(mu_size) {
    log_dens = stats::dnbinom(counts,
                              size = mu_size[2],
                              mu = mu_size[1],
                              log = TRUE)
    -sum(log_dens)
  }

  # Both parameters are bounded below by the smallest positive double so the
  # optimiser never evaluates a non-positive mean or size.
  smallest_positive = .Machine$double.xmin

  stats::nlminb(start = nb_init,
                objective = neg_log_lik,
                lower = c(smallest_positive, smallest_positive))
}
# Fit a gamma distribution (shape, rate) to input_vec by minimising the
# (optionally weighted) negative log-likelihood with stats::nlminb.
#
# input_vec:  numeric vector of observations
# weights:    optional numeric weights multiplied element-wise onto each
#             observation's log-density; NULL (the default) weights every
#             observation equally
# gamma_init: length-2 starting point, c(shape, rate)
#
# Returns the list produced by stats::nlminb (par[1] = shape, par[2] = rate).
fit_gamma = function(input_vec,
                     weights = NULL,
                     gamma_init = c(1, 1)) {

  # TODO: if fails to fit, stop() with error message suggesting different inits

  # Test for NULL rather than missing(): a caller that forwards an explicit
  # `weights = NULL` must get the unweighted fit. With missing(), NULL reached
  # the weighted branch, `NULL * dgamma(...)` collapsed to numeric(0) and the
  # objective became a constant 0.
  if (is.null(weights)) {
    obs_weights = 1 # multiplying by 1 leaves the log-likelihood unchanged
  } else {
    obs_weights = weights
  }

  fn_to_min = function(ab_vec){
    -sum(obs_weights * stats::dgamma(input_vec,
                                     shape = ab_vec[1],
                                     rate = ab_vec[2],
                                     log = TRUE))
  }

  stats::nlminb(start = gamma_init,
                objective = fn_to_min,
                lower = rep(.Machine$double.xmin, 2))
}
#' Fit a Bayesian MPRA model
#'
#' @description This function fits a negative-binomial based Bayesian model to
#' MPRA data. Optional annotations can be included to allow for more
#' informative conditional priors.
#'
#' @param mpra_data a data frame of MPRA data with 1 column called variant_id,
#' an allele column, a barcode column, and additional columns per sequencing
#' sample. Each row is for a single barcode.
#' @param annotations an optional data frame of annotations with identical
#' variant_ids and an arbitrary number of functional annotations. If omitted,
#' the prior for a given variant is influenced by all other variants in the
#' assay equally.
#' @param group_df an optional data frame giving group identity by variant_id in
#' mpra_data
#' @param priors optional objects provided by either fit_marg_prior() or
#' fit_cond_prior.
#' @param out_dir path to output directory
#' @param save_nonfunctional logical indicating whether or not to save the
#' sampler results for variants identified as non-functional
#' @param n_cores number of cores across which to parallelize variant MPRA
#' samplers
#' @param n_chains number of MCMC chains to run in each sampler
#' @param tot_samp total number of MCMC draws to take, spread evenly across
#' chains
#' @param n_warmup total number of warmup draws to take from each MCMC chain
#' @param vb_pass logical indicating whether to use a variational first pass
#' @param vb_prob numeric 0 - 1 indicating probability mass to use as a TS HDI
#' for identifying "promising" candidates for MCMC followup
#' @param ts_hdi_prob probability mass to include in the highest density
#' interval on transcription shift to call MPRA-functional variants.
#' @param ts_rope length 2 numeric vector describing the boundaries of the
#' transcription shift region of practical equivalence (ROPE), defaulting to
#' +/- log(3/2)
#' @param rep_cutoff a representation cutoff quantile (0 to 1)
#' @param adaptive_precision logical indicating whether to adaptively adjust the
#' length of the posterior MCMC chain for borderline functional variants
#' @param verbose logical indicating whether to print messages
#' @details \code{mpra_data} must contain the following groups of columns:
#' \itemize{ \item{variant_id} \item{allele - either 'ref' or 'alt'}
#' \item{barcode - a unique index sequence for that row (ideally the same
#' barcode used in the assay)} \item{at least one column of MPRA counts whose
#' column name(s) matches 'DNA'} \item{at least one column of MPRA counts
#' whose column name(s) matches 'RNA'} }
#'
#' \code{annotations} must contain the same variant_id's used in mpra_data.
#' Additional columns are used as informative predictors: when estimating the
#' priors for one variant, other variants with similar annotations will be
#' upweighted in the prior-fitting process.
#'
#' If \code{priors} is provided, any annotations input will be ignored. This
#' can be useful when you want to fit models again without having to spend
#' time re-fitting the priors.
#'
#' Sampler results will be saved to out_dir. By default, only the sampler
#' results for variants identified as MPRA-functional will be saved. This
#' behavior can be changed by setting \code{save_nonfunctional} to TRUE.
#'
#' We've set the sampler parameters (n_chains to n_warmup) to values that work
#' reasonably well at reasonable speeds for typical MPRA data on typical
#' hardware. Final analyses and/or models fit to larger MPRA experiments will
#' likely want to increase n_chains and tot_samp considerably to ensure
#' precise convergence.
#'
#' \code{vb_pass} indicates whether to use a first pass variational check to
#' see if a given variant is worth running the MCMC sampler. It does this by
#' checking if a 40% HDI on the variational transcription shift posterior
#' excludes 0. This speeds up posterior evaluation considerably, but gives
#' approximate results. If \code{vb_pass} is set to FALSE, all variants get
#' MCMC.
#'
#' \code{ts_rope} can be used to define a "Region Of Practical Equivalence"
#' for transcription shift. This is some small-ish region around 0 where
#' observed posterior samples are "practically equivalent" to 0. The output
#' column \code{ts_rope_mass} returns the fraction of transcription shift
#' posterior samples that fall within the defined ROPE along with the usual
#' model outputs. If this fraction is small, one can say that there is very
#' little posterior belief that the variant's transcription shift is
#' practically equivalent to 0. The user must be cognizant of defining the
#' region in accordance with observed noise and effect size levels. Note that
#' the output ROPE fractions are NOT p-values.
#'
#' Barcodes below the \code{rep_cutoff} quantile of representation in the DNA
#' pools are discarded.
#'
#' \code{adaptive_precision} indicates whether to adaptively increase the
#' length of the MCMC chains for borderline functional variants. The edges of
#' a 95% HDI have high variance with a small number of posterior samples. If
#' one edge of a variant's HDI interval is close to 0, this setting will tell
#' the sampler to double the length of the MCMC chain. Using the default 95%
#' HDI, this means that if a 92.5% HDI excludes 0, and a 97.5% HDI includes 0,
#' the sampler will double the length of the MCMC chain. This argument should
#' be turned off if tot_samp is already set to a high value.
#'
#' @return a data frame with a row for each variant_id that specifies the
#' posterior mean TS, upper and lower HDI bounds, a binary call of functional
#' or non-functional, and other appropriate outputs. The output column
#' \code{is_functional} is defined by the TS HDI excluding 0.
#' @note Sampler results for individual variants will be saved to the specified
#' out_dir as they can be several megabytes each. The table of summary
#' statistics that this function returns will also be saved into this
#' directory into an object called "analysis_res.RData".
#' @examples
#' # This example fits the malacoda model on 3 variants with too-short MCMC chains
#'
#' example_variants = c("1_205247315_2-3", "10_101274365_1-3", "10_45966422_2-3")
#'
#' examples_to_evaluate = umpra_example[umpra_example$variant_id %in% example_variants,]
#'
#' # tot_samp should be set to >50,000 to ensure the posterior chains converge
#' example_result = fit_mpra_model(mpra_data = examples_to_evaluate,
#' priors = marg_prior_example,
#' vb_pass = FALSE,
#' tot_samp = 20,
#' n_warmup = 10,
#' adaptive_precision = FALSE)
#'
#' print(example_result)
#' @export
fit_mpra_model = function(mpra_data,
                          annotations = NULL,
                          group_df = NULL,
                          out_dir = NULL,
                          save_nonfunctional = FALSE,
                          priors = NULL,
                          n_cores = 1,
                          n_chains = 4,
                          tot_samp = 2e3,
                          n_warmup = 200,
                          vb_pass = TRUE,
                          vb_prob = .8,
                          ts_hdi_prob = .95,
                          ts_rope = c(-.405, .405),
                          rep_cutoff = .15,
                          adaptive_precision = TRUE,
                          verbose = TRUE) {

  # Overview: validate inputs, establish priors (grouped, conditional or
  # marginal), drop poorly represented barcodes, run run_mpra_sampler() once
  # per variant and return the per-variant summary sorted by |ts_post_mean|.

  start_time = Sys.time()

  #### Input checks ----

  # NOTE(review): when ts_rope is not supplied explicitly it is passed to the
  # sampler as NULL, so the default shown in the signature is never used.
  # This is kept as-is; confirm it is intended.
  if (missing(ts_rope)) {
    ts_rope = NULL
  }

  if (missing(mpra_data)) {
    stop('mpra_data is missing: You must provide MPRA data to fit a MPRA model!')
  }

  if (is.null(out_dir) && verbose) {
    message('out_dir is missing: Results will not be saved')
  }

  if (!is.null(out_dir) && !dir.exists(out_dir)) {
    stop('specified out_dir does not exist')
  }

  # is.null() rather than missing(): wrappers that forward an explicit
  # `annotations = NULL` must be treated as "no annotations".
  if (!is.null(annotations)) {
    if (!all(mpra_data$variant_id %in% annotations$variant_id)) {
      stop('Some mpra_data$variant_id\'s missing from annotations')
    }
  }

  if (ts_hdi_prob < 0 || ts_hdi_prob > 1) {
    stop('ts_hdi_prob must be between 0 and 1!')
  }

  if (vb_prob > ts_hdi_prob){
    stop('vb_prob must be less than ts_hdi_prob')
  }

  if (!is.null(out_dir)) {
    # make sure the out_directory ends in a slash, if not, add it
    dir_ends_in_slash = grepl('/$', out_dir)
    if (!dir_ends_in_slash){
      out_dir = paste0(out_dir, '/')
    }
  }

  # Every column name must look like one of the expected column types
  correct_columns = all(grepl('variant_id|allele|barcode|[DR]NA', names(mpra_data)))
  if (!correct_columns){
    stop('mpra_data columns must be: variant_id, allele, barcode, and DNA/RNA columns.\ndplyr::rename(), dplyr::select(), malacoda::count_barcodes(), and the tidyr package might be helpful for preparing your input.')
  }

  # Check that there are 2 alleles for each variant
  variant_allele_counts = mpra_data %>%
    select(.data$variant_id, .data$allele) %>%
    unique %>%
    dplyr::count(.data$variant_id)

  if (!all(variant_allele_counts$n == 2)){
    stop('Non-biallelic variants detected. This may be due to an ill-formatted variant_id column. Note that the variant_id column should be the same for both alleles of a given variant.')
  }

  if (!any(tolower(mpra_data$allele) == 'ref')) {
    stop('Cannot identify which alleles are reference or alternate. Map existing values onto "ref" or "alt" and try again. The function stringr::str_replace_all might help with this. Try str_replace_all(allele, c("A" = "ref", "B" = "alt")) where allele is the existing allele column and "A" and "B" are the current allele indicators.')
  }

  if (vb_pass && verbose) {
    message('Using variantional approximation first pass. Set vb_pass = FALSE for publication quality analyses.')
  }

  if (tot_samp < 5e4 && verbose){
    message('Using less than 50,000 MCMC samples is not recommended for publication quality analyses. Inspect convergence metrics in any case.')
  }

  #### Initial cleanup ----
  mpra_data %<>%
    arrange(.data$variant_id)

  if (!is.null(annotations)) {
    annotations %<>%
      arrange(.data$variant_id)
  }

  #### Establish priors ----
  # prior_type records which kind of prior is in play so that the sampler
  # section below attaches it to each variant in the matching format.
  if (is.null(priors)) { # No prior given --> Auto-priors

    if (!is.null(group_df)) {
      # A supplied group_df always selects grouped priors. Previously it was
      # only honoured when annotations were also given, and the resulting
      # grouped prior was then sent down the conditional-prior sampler path.
      prior_type = 'grouped'

      if (verbose) {
        message('Fitting group-wise priors...')
      }

      # assumes fit_grouped_prior() returns a data frame with group_id and
      # group_prior columns, as the sampler section requires -- TODO confirm
      priors = fit_grouped_prior(mpra_data,
                                 group_df = group_df,
                                 n_cores = n_cores,
                                 plot_rep_cutoff = TRUE,
                                 rep_cutoff = rep_cutoff,
                                 verbose = verbose)

    } else if (!is.null(annotations)) {
      prior_type = 'conditional'

      if (verbose) {
        message('Fitting annotation-based conditional priors...')
        message('Defaulting to min_neighbors = 50 and kernel_fold_increase = 1.3 . See ?fit_cond_prior for alternatives.')
      }

      priors = fit_cond_prior(mpra_data,
                              annotations,
                              n_cores = n_cores,
                              plot_rep_cutoff = TRUE,
                              rep_cutoff = rep_cutoff,
                              min_neighbors = 50,
                              kernel_fold_increase = 1.3,
                              verbose = verbose)

      if (verbose) {
        message('Conditional prior fitting done...')
      }

      if (!is.null(out_dir)) {
        save(priors,
             file = paste0(out_dir, 'conditional_prior.RData'))
      }

    } else {
      prior_type = 'marginal'

      # Only announce missing annotations when that is actually the case
      if (verbose) {
        message('No annotations provided, fitting marginal priors...')
      }

      priors = fit_marg_prior(mpra_data,
                              n_cores = n_cores,
                              rep_cutoff = rep_cutoff,
                              plot_rep_cutoff = TRUE,
                              verbose = verbose)
    }

  } else {
    # Prior given --> use that, inferring its type from its structure
    if (all(class(priors) == 'list')){
      if (verbose) {
        message('Input prior class is list, interpreting as conditional priors.')
      }
      prior_type = 'conditional'
    } else if ('group_prior' %in% names(priors)) {
      if (verbose) {
        message('Interpreting input prior as a grouped prior.')
      }
      prior_type = 'grouped'
    } else {
      if (verbose) {
        message('Input prior class is not list, interpreting as marginal priors.')
      }
      prior_type = 'marginal'
    }
  }

  # The grouped sampler path joins group_df onto the variants; fail early with
  # a clear message instead of erroring inside the join.
  if (prior_type == 'grouped' && is.null(group_df)) {
    stop('A grouped prior was provided without group_df. group_df is required to map each variant_id onto a group_id.')
  }

  #### Prepare to run samplers ----
  n_rna = mpra_data %>% select(matches('RNA')) %>% ncol
  n_dna = mpra_data %>% select(matches('DNA')) %>% ncol

  sample_depths = get_sample_depths(mpra_data)

  well_represented = get_well_represented(mpra_data,
                                          sample_depths,
                                          rep_cutoff = rep_cutoff,
                                          plot_rep_cutoff = FALSE, # this will have been plotted in the prior fitting already if necessary
                                          verbose = verbose)

  # TODO, make the user aware that this step is happening
  # Variants that lose every barcode of one allele to the representation
  # filter are dropped: only variants with both alleles remaining are kept.
  wr_counts = mpra_data %>%
    filter(.data$barcode %in% well_represented$barcode) %>%
    select(.data$variant_id, .data$allele, .data$barcode) %>%
    group_by(.data$variant_id) %>%
    mutate(n_alleles = n_distinct(.data$allele))

  if (any(wr_counts$n_alleles != 2)) {
    message('Non-biallelic variants detected after filtering to well-represented barcodes. See mono-allelic MPRA model: malacoda/src/stan_files/monoallelic_model.stan')
  }

  biallelic_wr = wr_counts %>%
    filter(.data$n_alleles == 2) %>%
    ungroup

  well_represented = well_represented %>%
    filter(.data$barcode %in% biallelic_wr$barcode)

  arg_list = list(n_chains = n_chains, # To pass to run_mpra_sampler()
                  n_warmup = n_warmup,
                  tot_samp = tot_samp,
                  n_rna = n_rna,
                  n_dna = n_dna,
                  depth_factors = sample_depths,
                  out_dir = out_dir,
                  save_nonfunctional = save_nonfunctional,
                  ts_hdi_prob = ts_hdi_prob,
                  ts_rope = ts_rope,
                  vb_pass = vb_pass,
                  vb_prob = vb_prob,
                  adaptive_precision = adaptive_precision,
                  verbose = verbose)

  #### Run samplers ----
  if (verbose) {
    message('Running model samplers...')
  }

  # One row per variant, with that variant's barcode rows nested in
  # variant_data
  sampler_input = mpra_data %>%
    filter(.data$barcode %in% well_represented$barcode) %>%
    group_by(.data$variant_id) %>%
    nest() %>%
    dplyr::rename('variant_data' = 'data') %>%
    ungroup

  # Attach a variant_prior column in the form expected by run_mpra_sampler
  if (prior_type == 'conditional') {
    sampler_input = sampler_input %>%
      mutate(variant_prior = map(.data$variant_id,
                                 format_conditional_prior,
                                 cond_priors = priors))
  } else if (prior_type == 'grouped') {
    sampler_input = sampler_input %>%
      left_join(group_df, by = 'variant_id') %>%
      left_join(priors, by = 'group_id') %>% # give the grouped_prior by variant_id
      dplyr::rename('variant_prior' = 'group_prior')
  } else {
    sampler_input = sampler_input %>%
      mutate(variant_prior = list(priors)) # give the same marg prior to every variant
  }

  analysis_res = sampler_input %>%
    mutate(sampler_stats = parallel::mcmapply(run_mpra_sampler,
                                              .data$variant_id, .data$variant_data, .data$variant_prior,
                                              MoreArgs = arg_list,
                                              mc.cores = n_cores,
                                              mc.preschedule = FALSE,
                                              SIMPLIFY = FALSE)) %>%
    unnest(.data$sampler_stats) %>%
    arrange(desc(abs(.data$ts_post_mean)))

  if (!is.null(out_dir)) {
    save(analysis_res,
         file = paste0(out_dir, 'analysis_res.RData'))
  }

  end_time = Sys.time()
  time_diff = end_time - start_time

  if (verbose) {
    message(paste0('MPRA data for ', n_distinct(mpra_data$variant_id), ' variants analyzed in ',
                   round(time_diff, digits = 3), ' ', attr(time_diff, 'units')))
  }

  return(analysis_res)
}
#' Fit a Bayesian model of dropout CRISPR screen data
#'
#' @description This function fits a Bayesian model of survival/dropout CRISPR
#' screen data. It uses a negative binomial to model the input and output
#' counts of gRNAs, adjusting the results appropriately to account for
#' sequencing depth.
#' @param dropout_data a data frame of dropout data. See details for column
#' requirements.
#' @param n_cores number of cores to utilize
#' @inheritParams fit_mpra_model
#' @param plot_rep_cutoff logical indicating whether to plot the representation
#' cutoff histogram
#' @details \code{dropout_data} requires the following columns:
#' \itemize{\item{gene_id - character column giving a unique identifier for
#' each gene} \item{gRNA - character column giving identifiers for individual
#' gRNAs (usually the gRNA sequence itself)} \item{input count columns -
#' columns of sequencing counts of the input gRNA library. Multiple columns
#' for sequencing replicates are allowed (which require unique identifiers).
#' Column names must contain the string "input".} \item{output count columns -
#' columns of sequencing counts of gRNAs in the output libraries. Multiple
#' columns allowed (which in turn require unique names). Column name must
#' contain the string "output".}}
#' @note Currently this function only supports marginal priors. If you want to
#' use grouped/conditional priors, contact the malacoda developers.
#'
#' The \code{gene_data} column in the output contains only the gRNAs that
#' passed the representation cutoff.
#' @return a data frame of input counts, fit and model statistics for the
#' log-fold-change for each input gene.
#' @examples
#' # This example uses too-few MCMC samples for the sake of run time. Convergence will be poor.
#'
#' fit_dropout_model(dropout_data = dropout_example,
#' n_cores = 1,
#' tot_samp = 20,
#' n_warmup = 5)
#' @export
fit_dropout_model = function(dropout_data,
                             out_dir = NULL,
                             n_cores = 1,
                             tot_samp = 1e4,
                             n_warmup = 500,
                             n_chains = 4,
                             rep_cutoff = .1,
                             plot_rep_cutoff = TRUE,
                             verbose = TRUE) {

  # Overview: check the input columns, drop gRNAs below the rep_cutoff
  # quantile of depth-adjusted input representation, estimate gamma priors
  # from per-gene moment estimates, run run_dropout_sampler() once per gene
  # and return one row per gene sorted by |lfc_post_mean|.

  #### input checks ----
  input_names = names(dropout_data)

  if (!('gene_id' %in% input_names)){
    stop('No gene_id column found!')
  }

  if (!('gRNA' %in% input_names)) {
    stop('No gRNA column found!')
  }

  if (!any(grepl('input', input_names))) {
    stop('No input count columns found!')
  }

  if (!any(grepl('output', input_names))){
    stop('No output count columns found!')
  }

  if (tot_samp < 5e4){
    warning('Using less than 50,000 MCMC samples is not recommended for publication quality analyses. Inspect convergence metrics in any case.')
  }

  if (!is.null(out_dir)) {
    # make sure out_dir ends in a slash so file names can be pasted onto it
    dir_ends_in_slash = grepl('/$', out_dir)
    if (!dir_ends_in_slash){
      out_dir = paste0(out_dir, '/')
    }
  }

  #### Clean up input ----
  if (verbose) {
    message('Determining input gRNA representation parameters...')
  }

  # Depth factor = total counts in a sample, in millions
  sample_depths = dropout_data %>%
    gather('sample_id', 'gRNA_count', matches('input|output')) %>%
    group_by(.data$sample_id) %>%
    summarise(depth_factor = sum(.data$gRNA_count) / 1e6)

  # find well represented gRNAs in depth-adjusted input sequencing samples
  depth_adj_input = dropout_data %>%
    dplyr::select(.data$gRNA, matches('input')) %>%
    gather('sample_id', 'gRNA_count', -.data$gRNA) %>%
    left_join(sample_depths, by = 'sample_id') %>%
    mutate(depth_adj_count = .data$gRNA_count / .data$depth_factor) %>%
    dplyr::select(-.data$gRNA_count, -.data$depth_factor)

  mean_input = depth_adj_input %>%
    group_by(.data$gRNA) %>%
    summarise(mean_depth_adj_count = mean(.data$depth_adj_count))

  cutoff_point = quantile(mean_input$mean_depth_adj_count,
                          probs = rep_cutoff)

  well_rep = mean_input %>%
    filter(.data$mean_depth_adj_count > cutoff_point)

  if (plot_rep_cutoff) {
    if (verbose) {
      message('Plotting representation cutoff. Stop and adjust rep_cutoff if necessary.')
    }

    rep_plot = ggplot(depth_adj_input,
                      mapping = aes(x = .data$depth_adj_count)) +
      geom_histogram(aes(y = .data$..density..),
                     boundary = 0) +
      geom_density(aes(color = .data$sample_id)) +
      geom_vline(xintercept = cutoff_point,
                 lty = 2,
                 color = 'grey20') +
      labs(x = 'Depth adjusted counts in input library',
           color = 'Sample ID',
           title = 'gRNA representation in input libraries',
           subtitle = paste0('Data from gRNAs with average representation below ', round(cutoff_point, digits = 3),
                             ' are discarded')) +
      theme_light()

    # ggplot objects are only auto-printed at top level; inside a function
    # the plot must be printed explicitly or it is silently discarded.
    print(rep_plot)
  }

  #### Estimate prior ----
  # This should go into its own function eventually... TODO
  if (verbose) {
    message('Estimating marginal prior...')
  }

  # Genes with more than two well-represented gRNAs
  multiple_gRNA = dropout_data %>%
    filter(.data$gRNA %in% well_rep$gRNA) %>%
    dplyr::count(.data$gene_id) %>%
    filter(.data$n > 2)

  # Per gene and sample moment estimates of the negative binomial mean and
  # size; non-positive and non-finite size estimates are dropped.
  gamma_mle = dropout_data %>%
    filter(.data$gRNA %in% well_rep$gRNA,
           .data$gene_id %in% multiple_gRNA$gene_id) %>%
    gather('sample_id', 'gRNA_count', matches('input|output')) %>%
    left_join(sample_depths, by = 'sample_id') %>%
    group_by(.data$gene_id, .data$sample_id) %>%
    summarise(mean_mle = mean(.data$gRNA_count / .data$depth_factor), # I think these are biased estimators... TODO
              size_mle = mean(.data$gRNA_count) ^ 2 /
                (var(.data$gRNA_count) - mean(.data$gRNA_count))) %>%
    ungroup %>%
    filter(.data$size_mle > 0, is.finite(.data$size_mle))

  # Fit gamma priors off those, separately for input and output samples
  gamma_priors = gamma_mle %>%
    mutate(param_type = stringr::str_extract(.data$sample_id, pattern = 'input|output')) %>%
    group_by(.data$param_type) %>%
    summarise(mean_prior = list(fit_dropout_gamma(.data$mean_mle)),
              size_prior = list(fit_dropout_gamma(.data$size_mle))) %>%
    mutate(mean_alpha = map_dbl(.data$mean_prior, ~.x$par[1]),
           mean_beta = map_dbl(.data$mean_prior, ~.x$par[2]),
           size_alpha = map_dbl(.data$size_prior, ~.x$par[1]),
           size_beta = map_dbl(.data$size_prior, ~.x$par[2]))

  #### Evaluate models ----
  if (verbose) {
    message('Running model samplers...')
  }

  start_time = Sys.time()

  fit_summary = dropout_data %>%
    filter(.data$gRNA %in% well_rep$gRNA) %>%
    group_by(.data$gene_id) %>%
    nest() %>%
    dplyr::rename('gene_data' = 'data') %>%
    ungroup %>%
    dplyr::mutate(fit_statistics = parallel::mcmapply(run_dropout_sampler,
                                                      .data$gene_id, .data$gene_data,
                                                      mc.cores = n_cores,
                                                      MoreArgs = list(gene_prior = gamma_priors,
                                                                      tot_samp = tot_samp,
                                                                      n_warmup = n_warmup,
                                                                      n_chains = n_chains,
                                                                      depth_factors = sample_depths,
                                                                      out_dir = out_dir),
                                                      SIMPLIFY = FALSE))

  #### Compile results and return summary data frame ----
  fit_summary %<>% mutate(lfc_post_mean = map_dbl(.data$fit_statistics,
                                                  ~.x$mean[.x$parameter == 'log_fold_change']),
                          `lfc_2.5%` = map_dbl(.data$fit_statistics,
                                               ~.x$`2.5%`[.x$parameter == 'log_fold_change']),
                          `lfc_97.5%` = map_dbl(.data$fit_statistics,
                                                ~.x$`97.5%`[.x$parameter == 'log_fold_change'])) %>%
    arrange(-abs(.data$lfc_post_mean))

  end_time = Sys.time()
  time_diff = end_time - start_time

  # Respect verbose here as well, like every other status message
  if (verbose) {
    message(paste0('Data for ', n_distinct(dropout_data$gene_id), ' genes analyzed in ',
                   round(time_diff, digits = 3), ' ', attr(time_diff, 'units')))
  }

  return(fit_summary)
}
# Optimise the bc_mpra_model_mle Stan model for a single variant.
#
# variant_data:  data frame of one variant's barcodes with an allele column
#                and count columns whose names match 'DNA' / 'RNA'
# n_dna, n_rna:  passed to Stan as n_dna_samples / n_rna_samples
# n_ref, n_alt:  passed to Stan as n_ref / n_alt
# depth_factors: data frame with sample_id and depth_factor columns
#
# Returns the object produced by rstan::optimizing (hessian requested).
fit_mpra_mle = function(variant_data,
                        n_dna,
                        n_rna,
                        n_ref,
                        n_alt,
                        depth_factors){

  # Rows whose lower-cased allele is 'ref' are reference; every other row is
  # treated as alternate.
  ref_rows = variant_data %>% filter(tolower(.data$allele) == 'ref')
  alt_rows = variant_data %>% filter(tolower(.data$allele) != 'ref')

  # Matrix of the count columns whose names match acid_pattern
  as_count_matrix = function(allele_rows, acid_pattern) {
    allele_rows %>% select(matches(acid_pattern)) %>% as.matrix
  }

  # Depth factors of the samples whose sample_id matches acid_pattern
  depths_for = function(acid_pattern) {
    depth_factors %>%
      filter(grepl(acid_pattern, .data$sample_id)) %>%
      pull(.data$depth_factor)
  }

  stan_inputs = list(n_rna_samples = n_rna,
                     n_dna_samples = n_dna,
                     n_ref = n_ref,
                     n_alt = n_alt,
                     ref_dna_counts = as_count_matrix(ref_rows, 'DNA'),
                     alt_dna_counts = as_count_matrix(alt_rows, 'DNA'),
                     ref_rna_counts = as_count_matrix(ref_rows, 'RNA'),
                     alt_rna_counts = as_count_matrix(alt_rows, 'RNA'),
                     rna_depths = depths_for('RNA'),
                     dna_depths = depths_for('DNA'))

  # TODO - make the initializations more user-accessible
  start_values = list(dna_m = 100,
                      dna_p = 1,
                      rna_m = c(1,1),
                      rna_p = c(1,1))

  rstan::optimizing(object = stanmodels$bc_mpra_model_mle,
                    data = stan_inputs,
                    init = start_values,
                    hessian = TRUE)
}
# Run fit_mpra_mle() on every variant that keeps more than two
# well-represented barcodes for both the reference and alternate allele.
#
# mpra_data:        barcode-level data frame (variant_id, allele, barcode and
#                   count columns whose names match 'DNA' / 'RNA')
# well_represented: data frame whose barcode column lists the barcodes to keep
# sample_depths:    forwarded to fit_mpra_mle() as depth_factors
# n_cores:          forwarded to parallel::mcmapply as mc.cores
#
# Returns one row per fitted variant holding the nested variant_data, the
# n_ref / n_alt counts, the raw mle_fit, a converged flag (return_code == 0)
# and one column per fitted parameter.
fit_all_nb_mle = function(mpra_data,
                          well_represented,
                          sample_depths,
                          n_cores){

  column_names = names(mpra_data)
  n_dna = sum(grepl('DNA', column_names))
  n_rna = sum(grepl('RNA', column_names))

  count_ref = function(variant_df) sum(tolower(variant_df$allele) == 'ref')
  count_alt = function(variant_df) sum(tolower(variant_df$allele) != 'ref')

  # Spread a fit's named parameter vector into a one-row tibble
  widen_estimates = function(fit) {
    spread(tibble(par = names(fit$par),
                  val = fit$par),
           par, val)
  }

  # One row per variant, restricted to variants with enough barcodes per allele
  eligible_variants = mpra_data %>%
    filter(.data$barcode %in% well_represented$barcode) %>%
    group_by(.data$variant_id) %>%
    nest() %>%
    dplyr::rename('variant_data' = 'data') %>%
    ungroup %>%
    mutate(n_ref = map_dbl(.data$variant_data, count_ref),
           n_alt = map_dbl(.data$variant_data, count_alt)) %>%
    filter(.data$n_ref > 2, .data$n_alt > 2)

  fitted_variants = eligible_variants %>%
    mutate(mle_fit = parallel::mcmapply(fit_mpra_mle,
                                        .data$variant_data, .data$n_ref, .data$n_alt,
                                        MoreArgs = list(n_dna = n_dna,
                                                        n_rna = n_rna,
                                                        depth_factors = sample_depths),
                                        SIMPLIFY = FALSE,
                                        mc.cores = n_cores))

  all_nb_mle = fitted_variants %>%
    mutate(converged = map_lgl(.data$mle_fit,
                               function(fit) fit$return_code == 0),
           ml_estimates = map(.data$mle_fit, widen_estimates)) %>%
    unnest(cols = .data$ml_estimates) %>%
    ungroup

  all_nb_mle
}
# Optimise the bc_mpra_fit_gamma Stan model on a vector of estimates.
#
# mles: vector of values handed to Stan as `mles`, with its length as `N`
#
# Returns the object produced by rstan::optimizing.
fit_gamma_stan = function(mles){

  stan_inputs = list(N = length(mles),
                     mles = mles)

  rstan::optimizing(stanmodels$bc_mpra_fit_gamma,
                    data = stan_inputs)
}
# Fit one weighted gamma per parameter and return the fits split by prior
# type.
#
# anno_weight_df:      data frame with variant_id and the weight column used
#                      below
# estimates_to_weight: data frame with variant_id, par and value columns
#                      (column names as referenced below; presumably one row
#                      per variant and parameter -- verify against caller)
#
# Returns a list of data frames, one per prior_type, with columns allele,
# prior_type, prior (the raw fit), alpha_est, beta_est and acid_type.
fit_weighted_gammas_stan = function(anno_weight_df,
                                    estimates_to_weight) {

  weighted_estimates = inner_join(anno_weight_df,
                                  estimates_to_weight,
                                  by = 'variant_id')

  # converged = map_lgl(gamma_fit, ~.x$return_code == 0)
  # ^ you can add this to the mutate below to check for convergence
  fits_by_par = weighted_estimates %>%
    group_by(.data$par) %>%
    summarise(gamma_fit = list(fit_weighted_gamma_stan(.data$value, .data$weight))) %>%
    mutate(alpha_est = map_dbl(.data$gamma_fit, function(fit) fit$par[1]),
           beta_est = map_dbl(.data$gamma_fit, function(fit) fit$par[2]))

  # Parameter names are split on '_' or '[' into acid type, prior type and
  # allele index, then mapped onto readable labels.
  labelled_fits = fits_by_par %>%
    separate(col = 'par',
             into = c('acid_type', 'prior_type', 'allele'),
             sep = '_|\\[') %>%
    mutate(acid_type = toupper(.data$acid_type),
           prior_type = str_replace_all(.data$prior_type,
                                        c('p' = 'phi_prior',
                                          'm' = 'mu_prior')),
           allele = str_replace_all(.data$allele,
                                    c('1\\]' = 'ref',
                                      '2\\]' = 'alt')))

  g_priors = labelled_fits %>%
    select(.data$allele, .data$prior_type, 'prior' = .data$gamma_fit, .data$alpha_est, .data$beta_est, .data$acid_type)

  # this is purely to make it work with the old, pre-Stan code
  split(g_priors,
        f = g_priors$prior_type)
}
# Optimise the weighted_gamma Stan model on a vector of estimates.
#
# mles:    vector of values handed to Stan as `mles`, with its length as `N`
# weights: vector handed to Stan as `weights`
#
# Returns the object produced by rstan::optimizing.
fit_weighted_gamma_stan = function(mles, weights){

  stan_inputs = list(N = length(mles),
                     mles = mles,
                     weights = weights)

  rstan::optimizing(stanmodels$weighted_gamma,
                    data = stan_inputs)
}
# Fit a gamma distribution (shape, rate) to `values` by minimising the
# negative log-likelihood with stats::optim's default method, starting from
# c(1, 1).
#
# values: numeric vector of observations
#
# Returns the list produced by stats::optim (par[1] = shape, par[2] = rate).
fit_dropout_gamma = function(values){
  # This should be deprecated in favor of fit_gamma eventually... TODO

  fn_to_min = function(par_vec) {
    -sum(stats::dgamma(values,
                       shape = par_vec[1],
                       rate = par_vec[2],
                       log = TRUE))
  }

  # Return the optimiser result directly. Ending the function on an
  # assignment to an unused local made the result invisible.
  stats::optim(par = c(1,1),
               fn = fn_to_min)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.