#' @export
update_prior.beta_dist <- function(prior, evidence = NULL, stats = NULL, ...) {
  # Beta prior update: observed successes are added to alpha and observed
  # failures to beta. Raw evidence takes precedence over summary stats.
  if (is.null(evidence)) {
    if (is.null(stats)) {
      stop('If evidence is NULL, stats cannot be NULL')
    }
    # Recover the counts from the number of observations and the success rate.
    trials <- stats[['num_obs']]
    successes <- trials * stats[['observed_rate']]
    failures <- trials - successes
  } else {
    # Evidence has to be numeric and contain nothing but 0s and 1s.
    is_binary <- is.numeric(evidence) && all(evidence %in% c(0, 1))
    if (!is_binary) {
      stop('data passed to update_prior for beta distribution must be a binary vector')
    }
    successes <- sum(evidence == 1)
    failures <- sum(evidence == 0)
  }

  beta_dist(
    alpha = prior[['alpha']] + successes,
    beta = prior[['beta']] + failures
  )
}
#' @export
update_prior.normal_gamma_dist <- function(prior, evidence = NULL, stats = NULL, ...) {
  # Normal-gamma prior update. Both input forms are first reduced to three
  # quantities: number of observations, sample mean, and sum of squared
  # deviations from that mean. Raw evidence takes precedence over stats.
  if (is.null(evidence)) {
    if (is.null(stats)) {
      stop('If evidence is NULL, stats cannot be NULL')
    }
    n <- stats[['num_obs']]
    sample_mean <- stats[['avg']]
    # The squared std_dev is scaled by (n - 1) to get back the sum of squares.
    sum_sq_dev <- (stats[['std_dev']] ^ 2) * (n - 1)
  } else {
    if (!is.numeric(evidence)) {
      stop('evidence must be a numeric vector')
    }
    n <- length(evidence)
    sample_mean <- mean(evidence)
    sum_sq_dev <- sum((evidence - sample_mean) ^ 2)
  }

  mu_0 <- prior[['mu']]
  lambda_0 <- prior[['lambda']]
  lambda_n <- lambda_0 + n

  # Extra spread contributed by the gap between the prior mean and the sample mean.
  mean_shift <- (lambda_0 * n * (sample_mean - mu_0) ^ 2) / lambda_n

  normal_gamma_dist(
    mu = (lambda_0 * mu_0 + n * sample_mean) / lambda_n,
    lambda = lambda_n,
    alpha = prior[['alpha']] + n / 2,
    beta = prior[['beta']] + (sum_sq_dev + mean_shift) / 2
  )
}
#' @export
update_prior.gamma_dist <- function(prior, evidence = NULL, stats = NULL, ...) {
  # Gamma prior update for count data: the total observed count is added to
  # alpha and the number of sessions to beta. Raw evidence takes precedence
  # over summary stats.
  if (!is.null(evidence)) {
    # Reject non-numeric input up front, as the sibling update_prior methods do,
    # and reject NA explicitly: `any(evidence < 0)` evaluates to NA when a
    # missing value is present, which would make `if` fail with an unrelated
    # "missing value where TRUE/FALSE needed" error.
    if (!is.numeric(evidence) || anyNA(evidence) || any(evidence < 0)) {
      stop('evidence must be a positive integer vector')
    }
    num_sessions <- length(evidence)
    observed_count <- sum(evidence)
  } else {
    if (is.null(stats)) {
      stop('If evidence is NULL, stats cannot be NULL')
    }
    num_sessions <- stats[['num_sessions']]
    observed_count <- stats[['observed_count']]
  }

  gamma_dist(
    alpha = prior[['alpha']] + observed_count,
    beta = prior[['beta']] + num_sessions
  )
}
#' @title Update Prior Parameters
#' @name update_prior
#' @description Combine a prior distribution with observed data, supplied either as raw
#'              observations or as summary statistics, to obtain the updated
#'              distribution. This is an S3 generic; the method is selected by the
#'              class of \code{prior}.
#' @export
#' @importFrom purrr map
#' @param prior An object of class \code{'beta_dist'}, \code{'normal_gamma_dist'},
#'              or \code{'gamma_dist'} holding the parameters of the prior distribution.
#' @param evidence A numeric vector of observed data. Default is \code{NULL}. When
#'                 supplied, it is used and \code{stats} is ignored.
#' @param stats An object of class \code{'beta_stats'}, \code{'normal_gamma_stats'}, or
#'              \code{'gamma_stats'} containing the sufficient statistics for the
#'              update. Default is \code{NULL}. Only used when \code{evidence}
#'              is \code{NULL}.
#' @param ... Arguments passed on to the selected method.
#' @return The distribution object returned by the method selected for \code{prior}.
update_prior <- function(prior, evidence = NULL, stats = NULL, ...) {
  # Dispatch explicitly on the class of the prior distribution object.
  UseMethod('update_prior', prior)
}
#' @title Update the hyper-parameters of prior distributions
#' @name update_priors
#' @description Update the hyper-parameters of every variant's prior distribution by
#'              calling \code{\link{update_prior}} once per variant, using either raw
#'              data or summary statistics.
#' @inheritParams ab_arguments
#' @param evidence_dt A data.table containing the raw data generated by each variant.
#'                    Column names must include every variant name in \code{priors}.
#'                    When supplied, \code{stats_dt} is ignored.
#' @param stats_dt A data.table containing statistics about the raw data generated by each
#'                 variant, with a \code{variant} column identifying the variant of each
#'                 row. Only used when \code{evidence_dt} is \code{NULL}.
#' @export
#' @importFrom purrr map
#' @return A list of distribution objects representing the updated prior distributions,
#'         named like \code{priors}.
update_priors <- function(priors, evidence_dt = NULL, stats_dt = NULL) {
  variants <- names(priors)

  if (!is.null(evidence_dt)) {
    validate_dt(evidence_dt, expected_cols = variants)
    posteriors <- purrr::map(variants, function(x) {
      update_prior(prior = priors[[x]]
                   , evidence = evidence_dt[[x]])
    })
  } else {
    # Take the distribution class from the first prior rather than from a
    # variant literally named 'a', so that any variant naming scheme works.
    validate_stats_dt(stats_dt, class(priors[[1L]]))
    posteriors <- purrr::map(variants, function(x) {
      # Select this variant's row and drop the identifier column, leaving only
      # the statistics as a named list.
      update_prior(prior = priors[[x]]
                   , stats = as.list(stats_dt[variant == x, !'variant']))
    })
  }

  names(posteriors) <- variants
  return(posteriors)
}
#' @title Check That Data Table Matches Expectations
#' @name validate_dt
#' @description Assert that an object is a data.table and that it contains a given
#'              set of columns. Extra columns are allowed.
#' @export
#' @param dt A data.table
#' @param expected_cols A vector of column names that \code{dt} must have
#' @return \code{NULL} if successful, else a fatal error
validate_dt <- function(dt, expected_cols) {
  is_dt <- inherits(dt, 'data.table')
  if (!is_dt) {
    stop('dt must be a data.table')
  }

  # Required columns that are not present in dt.
  absent <- setdiff(expected_cols, names(dt))
  if (length(absent) != 0) {
    stop(paste('missing columns:', toString(absent)))
  }

  NULL
}
# Check that a stats data.table has exactly the columns required by the given
# prior distribution type.
#
# dt:        a data.table of summary statistics, one row per variant.
# dist_name: class name of the prior distribution; a class vector is accepted
#            and the first supported entry is used.
#
# Returns NULL invisibly on success; stops with an error if dt is not a
# data.table, the distribution type is unsupported, or the column names differ
# from the expected set (column order is ignored).
validate_stats_dt <- function(dt, dist_name) {
  if (!inherits(dt, 'data.table')) {
    stop('stats_dt must be a data.table')
  }

  expected_by_dist <- list(
    beta_dist = c('variant', 'num_obs', 'observed_rate'),
    normal_gamma_dist = c('variant', 'num_obs', 'avg', 'std_dev'),
    gamma_dist = c('variant', 'num_sessions', 'observed_count')
  )

  # class() can return several class names; pick the first supported one
  # instead of requiring an exact single-string match.
  supported <- intersect(dist_name, names(expected_by_dist))
  if (length(supported) == 0) {
    stop('Unsupported prior distribution type')
  }
  expected <- expected_by_dist[[supported[1L]]]

  cols <- names(dt)
  if (!identical(sort(cols), sort(expected))) {
    # Separators are spelled out so the expected and received lists do not run
    # into each other in the message.
    stop(paste0('bad column names in stats_dt provided to update_priors. expected: '
                , paste(expected, collapse = ', '), '; received: '
                , paste(cols, collapse = ', ')))
  }

  invisible(NULL)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.