
#' Sample size attempts for a one-sample design
#'
#' Builds the list of sample sizes to try, one entry per attempt.
#'
#' @param n Vector of sample sizes to try. When `NULL`, the sequence
#'   `seq(min_n, max_n, increments)` is used instead.
#' @param min_n,max_n,increments Start, end and step of the default
#'   sequence; only used when `n` is `NULL`.
#' @return A list with one element per attempt, each of the form
#'   `list(n = <size>)`.
ss_one_sample <- function(n = NULL, min_n = 30, max_n = 100, increments = 10) {
  if (is.null(n)) {
    n <- seq(min_n, max_n, increments)
  }

  # NOTE(review): the call wrapping this mapping was missing from the reviewed
  # copy. Each attempt is wrapped as `list(n = x)`, mirroring the
  # `list(n1 = , n2 = )` entries produced by `ss_two_sample()`.
  lapply(n, function(x) {
    list(n = x)
  })
}

#' Sample size attempts for a two-sample design
#'
#' Builds the list of paired sample sizes to try, one entry per attempt.
#'
#' @param n1,n2 Vectors of sample sizes for each group. Either one defaults
#'   to `seq(min_n, max_n, increments)` when `NULL`.
#' @param min_n,max_n,increments Start, end and step of the default sequence.
#' @return A list with one element per attempt, each of the form
#'   `list(n1 = <size>, n2 = <size>)`.
ss_two_sample <- function(n1 = NULL, n2 = NULL, min_n = 30,
                          max_n = 100, increments = 10) {
  # number of attempts at reaching power given sample size at certain increments
  if (is.null(n1)) {
    n1 <- seq(min_n, max_n, increments)
  }
  if (is.null(n2)) {
    n2 <- seq(min_n, max_n, increments)
  }

  if (length(n1) == 0L || length(n2) == 0L) {
    stop("n1 and n2 must each contain at least one sample size")
  }

  n_attempts <- max(length(n1), length(n2))

  # If one group has fewer attempts than the other, extend it with values
  # taken from its own reversed, recycled sequence (positions after its
  # original length), so both groups end up with `n_attempts` entries.
  pad_sizes <- function(sizes) {
    have <- length(sizes)
    if (have >= n_attempts) {
      return(sizes)
    }
    recycled <- rep(rev(sizes), length.out = n_attempts)
    c(sizes, recycled[(have + 1):n_attempts])
  }

  n1 <- pad_sizes(n1)
  n2 <- pad_sizes(n2)

  lapply(seq_len(n_attempts), function(x) {
    list(n1 = n1[x], n2 = n2[x])
  })
}

#' Empty template for a refit request
#'
#' @return A list with two `NULL` slots: `data` (read by `power_stan()` as the
#'   Stan data) and `update_stan_args` (read by `power_stan()` as per-refit
#'   argument overrides).
refit_stan_template <- function() {
  list(data = NULL, update_stan_args = NULL)
}

#' Create a single power test record
#'
#' @param test_name Label stored in the `test` column.
#' @param passed Logical; whether the test passed.
#' @param check Logical; whether the test should be evaluated for stopping.
#' @param test_set Label stored in the `set` column.
#' @param store_value Optional atomic value stored in the `value` column.
#' @return A `data.table` with columns `set`, `test`, `passed`, `check` and
#'   `value`.
add_test_item <- function(test_name, passed, check = FALSE,
                          test_set = "pwr_test", store_value = NULL) {
  if (!is.null(store_value) && is.recursive(store_value)) {
    stop("store_value must be an atomic type as seen in is.atomic")
  }

  if (!is.logical(passed)) {
    stop(
      "passed must be a logical value indicating ",
      "if the test passed or not")
  }

  if (!is.logical(check)) {
    stop(
      "check must be a logical value indicating ",
      "if the test should be evaluated for stopping")
  }

  slot <- data.table::data.table(
    set = test_set, test = test_name, passed = passed,
    check = check, value = store_value)

  slot
}


#' Simulation-based Bayesian power analysis
#'
#' For each entry of `sample_sizes` (a "pass"), repeatedly draws parameters
#' from `stanfit`, builds a new data set, refits the model and runs the tests,
#' updating a running power record until a stopping rule fires.
#'
#' @param stanfit Fitted model; its `@sim$permutation` slot is used to count
#'   the available posterior draws.
#' @param draw_fun Function called as `draw_fun(stanfit, i)`; returns the
#'   parameters for posterior draw `i`.
#' @param make_fun Function called as `make_fun(pars, sample_size)`; its result
#'   is passed to `power_stan()`.
#' @param test_fun Function called on the refitted model; its result is passed
#'   to `power_record()` as `test_results`.
#' @param sample_sizes List with one sample size specification per pass. When
#'   `NULL`, a single pass is run and `make_fun` receives `NULL` as the sample
#'   size.
#' @param stan_args Arguments forwarded to `power_stan()`.
#' @param power_goal Power threshold forwarded to `power_record()`.
#' @param max_draws Maximum number of draws per pass.
#' @param min_draws Number of draws after which a pass may be abandoned.
#'   `NULL` or `NA` means `max_draws`.
#' @param seed Random seed.
#' @param print_progress Print the status whenever the record counter is a
#'   multiple of this value; `NULL` disables progress printing.
#' @param fileprint Passed to `print_power_status()` as its `filename`.
#' @return The power record. Returned as soon as the lower bound of power is
#'   reached within a pass, or when the central power estimate is reached at
#'   the end of a pass; otherwise returned with a warning after all passes.
bayesian_power <- function(stanfit, draw_fun, make_fun, test_fun,
                           sample_sizes = NULL, stan_args = NULL,
                           power_goal = 0.8, max_draws = 250,
                           min_draws = 25, seed = 19851905,
                           print_progress = 25, fileprint = NULL) {
  mcmc_samples <- sum(lengths(stanfit@sim$permutation))

  # `||` short-circuits, so a NULL `min_draws` never reaches `is.na()`
  if (is.null(min_draws) || is.na(min_draws)) {
    min_draws <- max_draws
  }
  if (min_draws > max_draws) {
    stop("min_draws must be less than max_draws")
  }

  # A NULL `sample_sizes` becomes one pass whose sample size is NULL.
  # NOTE(review): assumes `make_fun` accepts a NULL sample size - confirm.
  if (is.null(sample_sizes)) {
    sample_sizes <- list(NULL)
  }
  n_passes <- length(sample_sizes)

  # NOTE(review): `seed` was unused in the reviewed copy although a comment
  # said it is reset on every pass; it is applied here and at each pass start.
  set.seed(seed)

  # sample with replacement only when more draws are requested than exist
  draw_order <- sample(
    mcmc_samples, max_draws,
    replace = max_draws > mcmc_samples)

  # start from an empty record; `power_record()` appends one entry per pass
  pwr_record <- list()

  for (pass in seq_len(n_passes)) {
    # seed is reset to append simulation values from last pass
    set.seed(seed)
    sample_size <- sample_sizes[[pass]]

    for (draw in seq_len(max_draws)) {
      pars <- draw_fun(stanfit, draw_order[draw])
      stan_data <- make_fun(pars, sample_size)
      refitted <- power_stan(stan_data, stan_args)
      test_results <- test_fun(refitted)

      # save record of results from current draw
      pwr_record <- power_record(
        record = pwr_record, draw = draw, pass = pass,
        power_goal = power_goal, sample_size = sample_size,
        test_results = test_results,
        print_progress = print_progress
      )

      # print updated status given print_progress
      if (!is.null(print_progress) && print_progress > 0) {
        if (pwr_record[[pass]]$counter %% print_progress == 0) {
          print_power_status(pwr_record, fileprint)
        }
      }

      # power reached early, return from function
      if (pwr_record[[pass]]$lower_reached) {
        print_power_status(pwr_record, fileprint)
        message(
          sprintf("Power stoping point reached early after %d draws.", draw)
        )
        return(pwr_record)
      }

      # start checking if need to quit and move on to next sample size
      if (draw > min_draws) {
        upper_less_than_goal <- !pwr_record[[pass]]$upper_reached
        will_fail <- pwr_record[[pass]]$p_goal_met < .05
        if (upper_less_than_goal || will_fail) {
          print_power_status(pwr_record, fileprint)
          message(sprintf(
            "Likelihood of reaching goal is <5%% after %d draws. Trying next.",
            draw))
          break
        }
      }
    }

    # lower not reached but center power value reached
    if (pwr_record[[pass]]$center_reached) {
      print_power_status(pwr_record, fileprint)
      message("Power stoping point reached.")
      return(pwr_record)
    }
  }

  # all attempts failed
  print_power_status(pwr_record, fileprint)
  warning("Power was never reached!")
  pwr_record
}

#' Refit a Stan model for one simulated data set
#'
#' Merges `data$update_stan_args` into `stan_args` (updates win on name
#' clashes, new names are appended), sets `stan_args$data` to `data$data`, and
#' calls `rstan::stan()` with the result.
#'
#' @param data List with slots `data` and `update_stan_args`.
#' @param stan_args Named list of arguments for `rstan::stan()`, or `NULL` to
#'   rely on `data$update_stan_args` alone.
#' @return The value of `rstan::stan()`.
power_stan <- function(data, stan_args = NULL) {
  arg_updates <- data$update_stan_args

  if (is.null(stan_args)) {
    stan_args <- arg_updates
  } else if (!is.null(arg_updates)) {
    update_names <- names(arg_updates)
    is_shared <- update_names %in% names(stan_args)

    # first overwrite stan_args from arg_updates; assignment is by name so
    # each value lands in its own slot whatever the order of the two lists
    if (any(is_shared)) {
      to_update <- update_names[is_shared]
      stan_args[to_update] <- arg_updates[to_update]
    }

    # then append any additional args from arg_updates
    if (any(!is_shared)) {
      stan_args <- c(stan_args, arg_updates[!is_shared])
    }
  }

  if (is.null(stan_args) || length(stan_args) == 0L) {
    stop("No stan arguments found in stan_args or data$update_stan_args. Need at least `file`` or `model_code`")
  }

  stan_args$data <- data$data

  # restore the previous `auto_write` option even when the fit errors
  old_opt <- rstan::rstan_options(auto_write = TRUE)
  on.exit(rstan::rstan_options(auto_write = old_opt), add = TRUE)

  stanfit <- do.call(rstan::stan, stan_args)
  stanfit
}


#' Blank entry for the power record
#'
#' @return A list of length one holding a single blank record, so that
#'   `c(record, power_record_blank())` appends exactly one entry.
power_record_blank <- function() {
  list(list(
    pass = 0, draw = 0, counter = 0, total = NA,
    sample_size = NA, power_goal = NA, data = list()))
}

#' Update the running power record with the results of one draw
#'
#' On the first draw of a pass a blank entry is appended to `record`. The
#' running count of passed tests is then updated, power intervals are attached
#' to `test_results`, and the stopping indicators are stored on the entry.
#'
#' @param record Existing record list, or `NULL` to start a new one.
#' @param draw Draw number within the current pass (1 starts a new entry).
#'   When omitted, the (possibly empty) record is returned unchanged.
#' @param pass,power_goal,sample_size Stored on the entry when it is created.
#' @param test_results A single `data.table` of tests, or a list of them.
#' @param print_progress Unused here.
#' @return The updated record list.
power_record <- function(record = NULL, draw, pass, power_goal, sample_size,
                         test_results, print_progress = NULL) {
  if (is.null(record)) {
    record <- list()
  }

  # allows `power_record()` to be used to obtain an empty record
  if (missing(draw)) {
    return(record)
  }

  if (missing(test_results)) {
    stop("results argument empty")
  }

  r <- length(record)

  # the counter carries over from the previous entry
  if (r > 0) {
    counter <- record[[r]]$counter
  } else {
    counter <- 0
  }

  if (draw == 1) {
    record <- c(record, power_record_blank())
    r <- length(record)
    record[[r]]$counter <- counter
    record[[r]]$pass <- pass
    record[[r]]$sample_size <- sample_size
    record[[r]]$power_goal <- power_goal
  } else if (r == 0) {
    stop("record is empty; the first call for a pass must use draw = 1")
  }

  if (data.table::is.data.table(test_results)) {
    # test function returned a single test
    test_results[, test_id := 1]
  } else {
    # test function returned a list of tests
    test_results <- data.table::rbindlist(test_results, idcol = "test_id")
  }

  test_results[, value_id := 1:.N, .(test_id, set)]

  if (draw == 1) {
    total <- test_results[, numeric(length(unique(test_id)))]
  } else {
    total <- record[[r]]$total
  }

  # one success per test id, taken from the first `passed` value of that test
  total <- total + test_results[, as.integer(passed[1]), test_id][[2]]
  successes <- total + 1
  failures <- 1 + draw - total

  # NOTE(review): the closing of this expression was missing from the reviewed
  # copy; it is reconstructed without a `by` clause - confirm.
  test_results[
    , c(
      "power", "pow_lower",
      "pow_upper") := compute_power_interval(
      successes[test_id], failures[test_id],
      mass = 0.95
    )
  ]

  power_goals_met <- power_results_tests(test_results, power_goal, draw)

  # update records
  record[[r]]$draw <- draw
  record[[r]]$counter <- record[[r]]$counter + 1
  record[[r]]$total <- total
  record[[r]]$data[[draw]] <- as.data.frame(test_results)
  record[[r]][names(power_goals_met)] <- power_goals_met

  record
}


#' Evaluate stopping indicators from the current test results
#'
#' Within each `set`, the checked test (`check == TRUE`) with the largest
#' `pow_lower` is selected; the indicators are computed across those rows.
#'
#' @param results `data.table` with columns `check`, `set`, `test`, `power`,
#'   `pow_lower` and `pow_upper`.
#' @param power_goal Power threshold.
#' @param draw Number of draws so far in the pass.
#' @return Result of `nlist()` over `last_best`, `lower_reached`,
#'   `center_reached`, `upper_reached` and `p_goal_met`.
power_results_tests <- function(results, power_goal, draw) {
  max_set_pwr <- results[
    check == TRUE, .(row = .I[which.max(pow_lower)]), .(set)
  ]
  last_best <- results[max_set_pwr$row, .(test, power, pow_lower, pow_upper)]

  # without this guard `all(logical(0))` is TRUE and the goal would be
  # reported as reached although nothing was checked
  if (nrow(last_best) == 0L) {
    stop("no tests with check = TRUE; nothing to evaluate for stopping")
  }

  lower_reached <- last_best[, all(pow_lower >= power_goal)]
  center_reached <- last_best[, all(power >= power_goal)]
  upper_reached <- last_best[, min(pow_upper) > power_goal]

  # beta CDF at the goal, using the upper power bound scaled by the number of
  # draws as pseudo-counts
  p_goal_met <- last_best[, pbeta(
    power_goal, pow_upper * draw + 1,
    (1 - pow_upper) * draw + 1)]

  p_goal_met <- max(1 - p_goal_met)

  nlist(last_best, lower_reached, center_reached, upper_reached, p_goal_met)
}

#' Power estimate with a shortest beta interval
#'
#' For each pair of `successes` and `failures`, treated as the two shape
#' parameters of a beta distribution, finds the left tail probability that
#' minimises the width of the interval holding `mass` of the distribution.
#'
#' @param successes,failures Numeric vectors of beta shape parameters;
#'   recycled to a common length.
#' @param mass Single number strictly between 0 and 1.
#' @return A list with elements `power` (`successes / (successes + failures)`),
#'   `pow_lb` and `pow_ub`, each as long as the recycled inputs.
compute_power_interval <- function(successes, failures, mass = 0.95) {
  if (length(mass) != 1L || is.na(mass) || mass <= 0 || mass >= 1) {
    stop("HDI interval must have a mass between 0 and 1")
  }

  if (length(successes) == 0L || length(failures) == 0L) {
    n_vals <- 0L
  } else {
    n_vals <- max(length(successes), length(failures))
  }
  successes <- rep_len(successes, n_vals)
  failures <- rep_len(failures, n_vals)

  # `optimize()` needs a scalar objective, so each pair is optimised on its own
  left_tail_p <- vapply(
    seq_len(n_vals),
    function(i) {
      opt_result <- optimize(beta_optimization_fn, c(0, 1 - mass),
        a = successes[i],
        b = failures[i], mass = mass, tol = 1e-09)
      opt_result$minimum
    },
    numeric(1))

  list(
    power = successes / (successes + failures),
    pow_lb = qbeta(left_tail_p, successes, failures),
    pow_ub = qbeta(mass + left_tail_p, successes, failures))
}

#' Width of the beta interval starting at left tail probability `x`
#'
#' @param x Left tail probability.
#' @param a,b Beta shape parameters.
#' @param mass Probability mass covered by the interval.
#' @return `qbeta(x + mass, a, b) - qbeta(x, a, b)`.
beta_optimization_fn <- function(x, a, b, mass = 0.95) {
  qbeta(x + mass, a, b) - qbeta(x, a, b)
}

#' Print the status of the most recent power record entry
#'
#' Emits the status through `message()`. When `filename` is given, messages
#' are redirected to that file for the duration of the call.
#'
#' @param record Power record list; only its last entry is reported.
#' @param filename Optional path of a file to receive the messages.
#' @return `record`, invisibly.
print_power_status <- function(record, filename = NULL) {
  n_recs <- length(record)

  # nothing to report for an empty record
  if (n_recs == 0L) {
    return(invisible(record))
  }

  if (!is.null(filename)) {
    # NOTE(review): opening with "wt" truncates the file on every call, so it
    # only ever holds the latest status - confirm this is intended.
    plog_file <- file(filename, open = "wt")
    sink(plog_file, type = "message")
    # undo the redirection and release the connection even on error
    on.exit({
      sink(type = "message")
      close(plog_file)
    }, add = TRUE)
  }

  current <- record[[n_recs]]
  sample_size <- paste(unlist(current$sample_size), collapse = ",")
  pass <- current$pass
  draw <- current$draw
  counter <- current$counter
  last_best <- current$last_best
  power_goal <- current$power_goal
  p_goal_met <- current$p_goal_met

  # NOTE(review): the opening `message(sprintf(paste0(` line was missing from
  # the reviewed copy and is reconstructed from the closing brackets.
  message(sprintf(
    paste0(
      "Status:\n  Attempt %d, Draw %d, Iter %d, Sample Size %s\n",
      "  prob. of achieving power of %.2f is currently %.2f"),
    pass, draw, counter, sample_size, power_goal, p_goal_met))

  if (!is.null(last_best)) {
    last_best[, message(sprintf(
      "  Power: %.2f [%.2f, %.2f]\t%s",
      power, pow_lower, pow_upper, test)), test]
  }

  invisible(record)
}


# power tests ---------------------------------------------------------------------------------
#' Resolve an interval test operator to its canonical code
#'
#' @param op Single string naming the operation, e.g. `"!="`, `"l>"`, `"rleq"`.
#'   A leading `l` or `r` selects the left or right interval bound.
#' @param op_list When `TRUE`, return every accepted operator string instead.
#' @return One of `"neq"`, `"lgt"`, `"lgeq"`, `"llt"`, `"lleq"`, `"rgt"`,
#'   `"rgeq"`, `"rlt"`, `"rleq"`, `"eq"`; or, with `op_list = TRUE`, the
#'   character vector of all accepted operators.
test_interval_type <- function(op, op_list = FALSE) {
  # canonical code -> accepted spellings
  aliases <- list(
    neq = c("!", "!=", "neq"),   # not equal to
    lgt = c("l>", "lgt"),        # left greater than
    lgeq = c("l>=", "lgeq"),     # left greater than or equal to
    llt = c("l<", "llt"),        # left less than
    lleq = c("l<=", "lleq"),     # left less than or equal to
    rgt = c("r>", "rgt"),        # right greater than
    rgeq = c("r>=", "rgeq"),     # right greater than or equal to
    rlt = c("r<", "rlt"),        # right less than
    rleq = c("r<=", "rleq"),     # right less than or equal to
    eq = c("==", "eq")           # is equal to
  )

  if (op_list) {
    return(unlist(aliases, use.names = FALSE))
  }

  if (is.character(op) && length(op) == 1L) {
    for (code in names(aliases)) {
      if (op %in% aliases[[code]]) {
        return(code)
      }
    }
  }

  stop("unknown interval test operation: ", op)
}

#' Test an interval against a null value or region
#'
#' @param x Two-column matrix-like object of intervals (lower bound in column
#'   1, upper bound in column 2), or a plain length-2 vector `c(lower, upper)`.
#' @param type Operation name understood by `test_interval_type()`.
#' @param null Null value; for the `"eq"` (region) test a length-2 vector
#'   giving the lower and upper limits of the region.
#' @return Logical vector with one value per interval.
test_interval <- function(x, type, null = 0) {
  # a bare c(lower, upper) vector is treated as a single interval
  if (is.null(dim(x))) {
    if (length(x) != 2L) {
      stop("x must be a length-2 interval or a two-column set of intervals")
    }
    x <- matrix(x, ncol = 2)
  }

  lhs <- x[, 1]
  rhs <- x[, 2]

  # make sure lower and upper are sorted
  if (!all(rhs >= lhs)) {
    stop("x must be an interval with the first value <= the second value")
  }

  op <- test_interval_type(type)

  switch(op,
    "neq" = {
      rhs < null | lhs > null
    },
    "lgt" = {
      lhs > null
    },
    "lgeq" = {
      lhs >= null
    },
    "llt" = {
      lhs < null
    },
    "lleq" = {
      lhs <= null
    },
    "rgt" = {
      rhs > null
    },
    "rgeq" = {
      rhs >= null
    },
    "rlt" = {
      rhs < null
    },
    "rleq" = {
      rhs <= null
    },
    "eq" = {
      # ROPE
      if (length(null) != 2) {
        stop("null_value must be a lower and upper region for this option")
      }
      lhs >= null[1] & rhs <= null[2]
    }
  )
}

# stats ---------------------------------------------------------------------------------------

#' Proportion of values above and not above a null value
#'
#' @param x Numeric vector.
#' @param null Value to compare against.
#' @return Named numeric vector: `gt_null` is the proportion of `x` greater
#'   than `null`, and `leq_null` is one minus that proportion.
bayes_p <- function(x, null = 0) {
  n <- length(x)
  p <- sum(x > null) / n
  c(leq_null = 1 - p, gt_null = p)
}

#' Standardized effect size
#'
#' With only `x` supplied, returns the one-sample effect size
#' `(mu_x - one_sample_test_val) / sd_x`. When `mu_y` is supplied, returns
#' `(mu_x - mu_y)` divided by the pooled standard deviation computed by
#' `pooled_sd()`.
#'
#' NOTE(review): the original title read "Effect size using Hedges' *g*";
#' no small-sample correction is applied in this function itself.
#'
#' @param mu_x,sd_x,n_x Mean, standard deviation and sample size for `x`.
#' @param mu_y,sd_y,n_y Same for `y`. `mu_y = NULL` selects the one-sample
#'   form; `n_y` defaults to `n_x`.
#' @param one_sample_test_val Comparison value for the one-sample form.
#' @return Numeric effect size.
#' @export
cohens_d <- function(mu_x, sd_x, n_x, mu_y = NULL, sd_y = NULL,
                     n_y = NULL, one_sample_test_val = 0.0) {
  if (missing(mu_x) || missing(sd_x) || missing(n_x)) {
    stop("Missing mean, sd, or sample size for `x`")
  }

  if (is.null(n_x)) {
    stop("Must provide sample size value for `x`")
  }

  use_y <- !is.null(mu_y)

  # one-sample form: no `y` mean was given
  if (!use_y) {
    return((mu_x - one_sample_test_val) / sd_x)
  }

  if (is.null(sd_y)) {
    stop("Must provide standard deviation value for `y`")
  }

  # `n_x` is known to be non-NULL here, so it is a valid default for `n_y`
  if (is.null(n_y)) {
    n_y <- n_x
  }

  sd_vec <- cbind(sd_x, sd_y)
  n_vec <- cbind(n_x, n_y)
  sd_p <- pooled_sd(sd_vec, n_vec)
  (mu_x - mu_y) / sd_p
}

# example functions ---------------------------------------------------------------------------

#' Extract one posterior draw for the one-sample multivariate-t example
#'
#' @param stanfit Fitted model containing parameters `mu`, `Sigma` and `nu`.
#' @param i Index of the posterior draw to extract.
#' @return List with elements `mu`, `Sigma` and `nu` for draw `i`.
one_sample_mvt_draw <- function(stanfit, i) {
  post <- rstan::extract(stanfit, pars = c("mu", "Sigma", "nu"))
  pars <- list(mu = post$mu[i, ], Sigma = post$Sigma[i, , ], nu = post$nu[i])
  pars
}

#' One sample multivariate-t
#'
#' Simulation function for one sample multivariate t. Simulates an `N` by `K`
#' data matrix from the supplied parameters and packages it for refitting.
#'
#' @param pars list of parameters needed for simulation and to create objects
#'   to pass to stan. Uses `mu`, `Sigma` and `nu`, plus `N` when
#'   `sample_size_list` is `NULL`.
#' @param sample_size_list list with a single slot `n` giving the number of
#'   rows to simulate, or `NULL` to use `pars$N`.
#' @return list built from `refit_stan_template()` with `data` set to the
#'   simulated `Y`, `K`, `N` and `update_stan_args` set to an `init` function.
one_sample_mvt_make <- function(pars, sample_size_list = NULL) {
  if (is.null(sample_size_list)) {
    # use sample size from data if only determining current power
    N <- pars$N
  } else {
    # since this is a one sample, it'll have a single slot `n`
    N <- sample_size_list$n
  }

  if (is.null(N)) {
    stop("No sample size found in `sample_size_list$n` or `pars$N`")
  }

  # the rest of the code uses parameter values to simulate a new stan data list
  Sigma <- pars$Sigma
  K <- ncol(Sigma)
  nu <- pars$nu
  mu <- pars$mu

  # i.i.d student-t values
  mu_t_zero <- array(rt(N * K, df = nu), c(N, K))

  # reshape t-values according to Sigma and rescale by mu; the Cholesky
  # factor does not depend on the row, so it is computed once
  chol_sigma <- chol(Sigma)
  Y <- t(apply(
    mu_t_zero, 1,
    function(x) {
      mu + x %*% chol_sigma
    }))

  # use the template to check for some slots not used, such as initialization values
  stan_data <- refit_stan_template()
  stan_data$data <- nlist(Y, K, N)

  # NOTE(review): the initial values are hard-coded for three dimensions and
  # do not depend on `K` - confirm against the model.
  inits <- function(chain_id) {
    list(
      mu_adjustment = c(-0.25, 0.25, 0.25), nu_minus_one = 36,
      sigma = c(0.7, 0.88, 0.85), chol_corr = diag(3))
  }
  stan_data$update_stan_args <- list(init = inits)

  stan_data
}

#' Tests for the one-sample multivariate-t example
#'
#' Compares the three components of `mu` pairwise (labelled F, E and P in the
#' test names for components 1, 2 and 3) using 90% HDIs, posterior
#' proportions and standardized differences.
#'
#' @param refitted Refitted model containing parameters `mu` and `sigma`.
#' @return List of test items created by `add_test_item()`.
one_sample_mvt_tests <- function(refitted) {
  # extract parameters
  post <- rstan::extract(refitted, pars = c("mu", "sigma"))
  mu <- post$mu
  sigma <- post$sigma

  # mean differences
  emb_fam <- mu[, 2] - mu[, 1]
  par_fam <- mu[, 3] - mu[, 1]
  par_emb <- mu[, 3] - mu[, 2]

  # effect sizes
  # NOTE(review): the original definitions were commented out (marked "fix")
  # while still being used below. They are computed here as the mean
  # difference over the equal-n pooled standard deviation - confirm this is
  # the intended effect size.
  fx_size_EF <- emb_fam / sqrt((sigma[, 2]^2 + sigma[, 1]^2) / 2)
  fx_size_PF <- par_fam / sqrt((sigma[, 3]^2 + sigma[, 1]^2) / 2)

  # test values
  hdi_EF <- rbaes::hdi(emb_fam, 0.9)
  p_EF <- bayes_p(emb_fam)[1]
  hdi_D_EF <- rbaes::hdi(fx_size_EF, 0.9)
  hdi_PF <- rbaes::hdi(par_fam, 0.9)
  p_PF <- bayes_p(par_fam)[1]
  hdi_D_PF <- rbaes::hdi(fx_size_PF, 0.9)
  hdi_PE <- rbaes::hdi(par_emb, 0.9)
  p_PE <- bayes_p(par_emb)
  p_PE_test <- !as.logical(p_PE[1] < 0.05 | p_PE[2] < 0.05)

  # NOTE(review): `test_interval()` has no default `type`, and the original
  # calls omitted it. "l>" is used for the "> 0" tests and "!=" (negated) for
  # the "== 0" test, following the test names - confirm.
  list(
    add_test_item(
      test_name = "HDI E-F > 0",
      passed = test_interval(hdi_EF, type = "l>"), check = TRUE,
      test_set = "EvF", store_value = hdi_EF),
    add_test_item("p E-F < .05", p_EF < 0.05,
      check = TRUE, "EvF", p_EF),
    add_test_item(
      "HDI D > 0.1 | E,F",
      test_interval(hdi_D_EF, null = 0.1, type = "l>"),
      FALSE, "EvF", hdi_D_EF),
    add_test_item("HDI P-F > 0", test_interval(hdi_PF, type = "l>"),
      check = TRUE, "PvF", hdi_PF),
    add_test_item("p P-F < .05", p_PF < 0.05,
      check = TRUE, "PvF", p_PF),
    add_test_item(
      "HDI D > 0.1 | P,F",
      test_interval(hdi_D_PF, null = 0.1, type = "l>"),
      FALSE, "PvF", hdi_D_PF),
    add_test_item(
      "HDI P-E == 0", !test_interval(hdi_PE, type = "!="),
      FALSE, "PvE", hdi_PE),
    add_test_item("p P-E > .05", p_PE_test, FALSE, "PvE", p_PE))
}
# Source: iamamutt/rbaes documentation built on May 18, 2019, 1:27 a.m.