# R/update_priors.R

# Conjugate Beta-Bernoulli update: the posterior adds the number of successes
# to alpha and the number of failures to beta.
#  - evidence: binary observations, either numeric 0/1 or logical TRUE/FALSE.
#    Takes precedence over stats.
#  - stats: list-like object with 'num_obs' and 'observed_rate'; successes
#    are num_obs * observed_rate and failures are the remainder.
# Returns a beta_dist with the updated alpha and beta.
#' @export
update_prior.beta_dist <- function(prior, evidence = NULL, stats = NULL, ...) {
    if (!is.null(evidence)) {
        # Logical vectors are valid binary data: TRUE matches 1, FALSE matches 0.
        # NA is not in c(0, 1), so missing values are rejected here as well.
        is_binary_type <- is.numeric(evidence) || is.logical(evidence)
        if (!is_binary_type || any(!(evidence %in% c(0, 1)))) {
            stop('data passed to update_prior for beta distribution must be a binary vector')
        }
        num_success <- sum(evidence == 1)
        num_failure <- sum(evidence == 0)
    } else {
        if (is.null(stats)) {
            stop('If evidence is NULL, stats cannot be NULL')
        }
        # Without this check a missing field yields NULL, and the arithmetic
        # below silently produces zero-length parameters.
        missing_fields <- setdiff(c('num_obs', 'observed_rate'), names(stats))
        if (length(missing_fields) > 0) {
            stop(paste('stats is missing required fields:', paste(missing_fields, collapse = ', ')))
        }
        num_success <- stats[['num_obs']] * stats[['observed_rate']]
        num_failure <- stats[['num_obs']] - num_success
    }
    return(beta_dist(alpha = prior[['alpha']] + num_success
                     , beta = prior[['beta']] + num_failure))
}

# Conjugate Normal-Gamma update for normally distributed observations with
# unknown mean and precision. The data enter through three sufficient
# statistics: the observation count, the sample mean, and the sum of squared
# deviations from that mean (sse).
#  - evidence: numeric vector of raw observations (no missing values).
#    Takes precedence over stats.
#  - stats: list-like object with 'num_obs', 'avg' and 'std_dev', where
#    std_dev uses the (num_obs - 1) denominator, so sse = std_dev^2 * (num_obs - 1).
# Returns a normal_gamma_dist with updated mu, lambda, alpha and beta.
#' @export
update_prior.normal_gamma_dist <- function(prior, evidence = NULL, stats = NULL, ...) {
    if (!is.null(evidence)) {
        if (!is.numeric(evidence)) {
            stop('evidence must be a numeric vector')
        }
        if (anyNA(evidence)) {
            stop('evidence must not contain missing values')
        }
        num_obs <- length(evidence)
        avg <- mean(evidence)
        sse <- sum((evidence - avg) ^ 2)
    } else {
        if (is.null(stats)) {
            stop('If evidence is NULL, stats cannot be NULL')
        }
        missing_fields <- setdiff(c('num_obs', 'avg', 'std_dev'), names(stats))
        if (length(missing_fields) > 0) {
            stop(paste('stats is missing required fields:', paste(missing_fields, collapse = ', ')))
        }
        num_obs <- stats[['num_obs']]
        avg <- stats[['avg']]
        std_dev <- stats[['std_dev']]
        # A single observation has no spread (sd() reports NA for it), so the
        # sum of squares is 0, matching what the raw-evidence path computes.
        sse <- if (isTRUE(num_obs <= 1)) 0 else (std_dev ^ 2) * (num_obs - 1)
    }

    # With no observations the posterior must equal the prior. mean() of an
    # empty vector is NaN, which would otherwise poison every parameter; any
    # finite avg works because it is multiplied by num_obs = 0 below.
    if (isTRUE(num_obs == 0)) {
        avg <- 0
        sse <- 0
    }

    lambda_prior <- prior[['lambda']]
    mu_prior <- prior[['mu']]
    lambda_post <- lambda_prior + num_obs
    # Extra spread due to the gap between the prior mean and the sample mean.
    beta_term <- (lambda_prior * num_obs * (avg - mu_prior) ^ 2) / lambda_post
    return(normal_gamma_dist(mu = (lambda_prior * mu_prior + num_obs * avg) / lambda_post
                             , lambda = lambda_post
                             , alpha = prior[['alpha']] + num_obs / 2
                             , beta = prior[['beta']] + (sse + beta_term) / 2))
}

# Conjugate Gamma-Poisson update for count data: the total observed count is
# added to alpha and the number of sessions (exposure units) is added to beta.
#  - evidence: vector of non-negative per-session counts, with no missing
#    values. Takes precedence over stats.
#  - stats: list-like object with 'num_sessions' and 'observed_count'.
# Returns a gamma_dist with the updated alpha and beta.
#' @export
update_prior.gamma_dist <- function(prior, evidence = NULL, stats = NULL, ...) {
    if (!is.null(evidence)) {
        # Check type and missingness before comparing to 0. Otherwise an NA makes
        # any(evidence < 0) return NA, and character input slips past the check.
        # Logical input keeps working, as it did before.
        is_count_type <- is.numeric(evidence) || is.logical(evidence)
        if (!is_count_type || anyNA(evidence) || any(evidence < 0)) {
            stop('evidence must be a positive integer vector')
        }
        num_sessions <- length(evidence)
        observed_count <- sum(evidence)
    } else {
        if (is.null(stats)) {
            stop('If evidence is NULL, stats cannot be NULL')
        }
        missing_fields <- setdiff(c('num_sessions', 'observed_count'), names(stats))
        if (length(missing_fields) > 0) {
            stop(paste('stats is missing required fields:', paste(missing_fields, collapse = ', ')))
        }
        num_sessions <- stats[['num_sessions']]
        observed_count <- stats[['observed_count']]
    }
    return(gamma_dist(alpha = prior[['alpha']] + observed_count
                      , beta = prior[['beta']] + num_sessions))
}

#' @title Update Prior Parameters
#' @name update_prior
#' @description Use observed data to update the parameters of a prior distribution.
#'              This is an S3 generic; the method is selected by the class of
#'              \code{prior}.
#' @export
#' @importFrom purrr map
#' @param prior An object of class \code{'beta_dist'}, \code{'normal_gamma_dist'}
#'             , or \code{'gamma_dist'} that specifies the parameters of some distribution
#' @param evidence A numeric vector that contains observed data. Default is \code{NULL},
#'                 will override \code{stats} if specified.
#' @param stats A list holding the sufficient statistics for the update, used only
#'              when \code{evidence} is \code{NULL}. Required elements are
#'              \code{num_obs} and \code{observed_rate} for \code{'beta_dist'};
#'              \code{num_obs}, \code{avg} and \code{std_dev} for
#'              \code{'normal_gamma_dist'}; \code{num_sessions} and
#'              \code{observed_count} for \code{'gamma_dist'}.
#' @param ... Arguments to be passed onto other methods
#' @return A distribution object of the same family as \code{prior} holding the
#'         updated parameters.
update_prior <- function(prior, evidence = NULL, stats = NULL, ...) UseMethod('update_prior')

#' @title Update the hyper-parameters of prior distributions
#' @name update_priors
#' @description This function updates the hyper-parameters of our prior distributions,
#'              using either raw observations or pre-computed summary statistics.
#' @inheritParams ab_arguments
#' @param evidence_dt A data.table containing the raw data generated by each variant.
#'                    Column names must be identical to variant names in \code{priors}.
#'                    Takes precedence over \code{stats_dt} when supplied.
#' @param stats_dt A data.table containing statistics about the raw data generated by each
#'                 variant: a \code{variant} column holding the variant names plus the
#'                 statistic columns required by the prior type. Must contain exactly
#'                 one row per variant.
#' @export
#' @importFrom purrr map
#' @return A list of distribution objects representing the updated prior distributions,
#'         named and ordered like \code{priors}
update_priors <- function(priors, evidence_dt = NULL, stats_dt = NULL) {
    variants <- names(priors)

    if (!is.null(evidence_dt)) {
        validate_dt(evidence_dt, expected_cols = variants)
        posteriors <- purrr::map(variants, function(x) {
            update_prior(prior = priors[[x]]
                         , evidence = evidence_dt[[x]])
        })
    } else {
        if (length(priors) == 0) {
            stop('priors must contain at least one distribution')
        }
        # The expected stats_dt columns depend on the prior type. Take it from
        # the first prior instead of a variant hard-coded as 'a', and pass only
        # its primary (first) class so the name comparison succeeds.
        validate_stats_dt(stats_dt, class(priors[[1]])[[1]])
        posteriors <- purrr::map(variants, function(x) {
            # Select rows with explicit column access rather than NSE, so the
            # comparison cannot be shadowed by a column name.
            rows <- which(stats_dt[['variant']] == x)
            if (length(rows) != 1) {
                stop(paste0('stats_dt must contain exactly one row for variant ', x
                            , ', found ', length(rows)))
            }
            update_prior(prior = priors[[x]]
                         , stats = as.list(stats_dt[rows, !'variant']))
        })
    }

    names(posteriors) <- names(priors)
    return(posteriors)
}

#' @title Check That Data Table Matches Expectations
#' @name validate_dt
#' @description Assert that an object is a data.table containing every column in a
#'              given set. Columns beyond the expected ones are allowed.
#' @export
#' @param dt A data.table
#' @param expected_cols A vector of column names that \code{dt} must have
#' @return \code{NULL} if successful, else a fatal error naming the missing columns
validate_dt <- function(dt, expected_cols) {
    if (inherits(dt, 'data.table')) {
        # setdiff keeps the order of expected_cols and drops duplicates.
        absent <- setdiff(expected_cols, names(dt))
        if (length(absent) == 0) {
            return(NULL)
        }
        stop(paste('missing columns:', paste(absent, collapse = ', ')))
    }
    stop('dt must be a data.table')
}

# Check that a statistics table has exactly the columns needed to update a
# prior of the given distribution type.
#  - dt: the statistics table; must inherit from 'data.table'.
#  - dist_name: class name of the prior (e.g. 'beta_dist'). A full class vector
#    is also accepted; the first supported class found in it is used.
# Returns NULL invisibly on success; otherwise stops with an error.
validate_stats_dt <- function(dt, dist_name) {
    if (!inherits(dt, 'data.table')) {
        stop('stats_dt must be a data.table')
    }

    # Required columns for each supported prior type; 'variant' identifies the row.
    required_cols <- list(
        beta_dist = c('variant', 'num_obs', 'observed_rate'),
        normal_gamma_dist = c('variant', 'num_obs', 'avg', 'std_dev'),
        gamma_dist = c('variant', 'num_sessions', 'observed_count')
    )
    supported <- intersect(dist_name, names(required_cols))
    if (length(supported) == 0) {
        stop('Unsupported prior distribution type')
    }
    expected <- required_cols[[supported[[1]]]]

    cols <- names(dt)
    # Column order is irrelevant, but the column sets must match exactly.
    if (!identical(sort(cols), sort(expected))) {
        stop(paste0('bad column names in stats_dt provided to update_priors. expected: '
                    , paste(expected, collapse = ', '), '; received: '
                    , paste(cols, collapse = ', ')))
    }
    invisible(NULL)
}
# convoyinc/abayes documentation built on May 12, 2019, 1:34 a.m.