ss_one_sample <- function(n = NULL, min_n = 30, max_n = 100, increments = 10) {
if (is.null(n)) {
n <- seq(min_n, max_n, increments)
}
return(sapply(
n,
function(x) {
list(list(n = x))
}))
}
ss_two_sample <- function(n1 = NULL, n2 = NULL, min_n = 30,
max_n = 100, increments = 10) {
# number of attempts at reaching power given sample size at certain increments
if (is.null(n1)) {
n1 <- seq(min_n, max_n, increments)
}
if (is.null(n2)) {
n2 <- seq(min_n, max_n, increments)
}
n1_l <- length(n1)
n2_l <- length(n2)
# if number of sample size attempts is not equal for both groups then repeat
# for the one with less tries
if (n1_l != n2_l) {
if (n1_l < n2_l) {
n1 <- c(n1, rep(rev(n1), length.out = n2_l)[(n1_l + 1):n2_l])
} else {
n2 <- c(n2, rep(rev(n2), length.out = n1_l)[(n2_l + 1):n1_l])
}
}
ss_attempts <- sapply(
1:max(c(n1_l, n2_l)),
function(x) {
list(list(n1 = n1[x], n2 = n2[x]))
})
return(ss_attempts)
}
refit_stan_template <- function() {
list(data = NULL, update_stan_args = NULL)
}
add_test_item <- function(test_name, passed, check = FALSE,
test_set = "pwr_test", store_value = NULL) {
if (!is.null(store_value) & is.recursive(store_value)) {
stop("store_value must be an atomic type as seen in is.atomic")
}
if (!is.logical(passed)) {
stop(
"passed must be a logical value indicating ",
"if the test passed or not")
}
if (!is.logical(check)) {
stop(
"check must be a logical value indicating ",
"if the test should be evaluated for stopping")
}
slot <- data.table(
set = test_set, test = test_name, passed = passed,
check = check, value = store_value)
return(slot)
}
bayesian_power <- function(stanfit, draw_fun, make_fun, test_fun,
sample_sizes = NULL, stan_args = NULL,
power_goal = 0.8, max_draws = 250,
min_draws = 25, seed = 19851905,
print_progress = 25, fileprint = NULL) {
set.seed(seed)
mcmc_samples <- sum(unlist(lapply(stanfit@sim$permutation, length)))
n_passes <- length(sample_sizes)
if (is.null(min_draws) | is.na(min_draws)) {
min_draws <- max_draws
}
if (min_draws > max_draws) {
stop("min_draws must be less than max_draws")
}
draw_order <- sample(
mcmc_samples, max_draws,
ifelse(max_draws > mcmc_samples, TRUE, FALSE))
pwr_record <- power_record()
for (pass in 1:n_passes) {
# seed is reset to append simulation values from last pass
set.seed(seed)
sample_size <- sample_sizes[[pass]]
for (draw in 1:max_draws) {
pars <- draw_fun(stanfit, draw_order[draw])
stan_data <- make_fun(pars, sample_size)
refitted <- power_stan(stan_data, stan_args)
test_results <- test_fun(refitted)
# save record of results from current draw
pwr_record <- power_record(
record = pwr_record, draw = draw, pass = pass,
power_goal = power_goal, sample_size = sample_size,
test_results = test_results,
print_progress = print_progress
)
# print updated status given print_progress
if (!is.null(print_progress)) {
if (pwr_record[[pass]]$counter %% print_progress == 0) {
print_power_status(record, fileprint)
}
}
# power reached early, return from function
if (pwr_record[[pass]]$lower_reached) {
print_power_status(pwr_record, fileprint)
message(
sprintf("Power stoping point reached early after %d draws."),
draw)
return(pwr_record)
}
# start checking if need to quit and move on to next sample size
if (draw > min_draws) {
upper_less_than_goal <- !pwr_record[[pass]]$upper_reached
will_fail <- pwr_record[[pass]]$p_goal_met < .05
if (upper_less_than_goal | will_fail) {
print_power_status(pwr_record, fileprint)
message(sprintf(
"Likelihood of reaching goal is <5%% after %d draws. Trying next.",
draw))
break
}
}
}
# lower not reached but center power value reached
if (pwr_record[[pass]]$center_reached) {
print_power_status(pwr_record, fileprint)
message("Power stoping point reached.")
return(pwr_record)
}
}
# all attempts failed
print_power_status(pwr_record, fileprint)
warning("Power was never reached!")
return(pwr_record)
}
power_stan <- function(data, stan_args = NULL) {
arg_updates <- data$update_stan_args
if (is.null(stan_args)) {
stan_args <- arg_updates
} else {
if (!is.null(arg_updates)) {
which_updates <- names(arg_updates) %in% names(stan_args)
# first overwrite stan_args from args_updates
if (any(which_updates)) {
to_update <- names(arg_updates)[which_updates]
stan_args[names(stan_args) %in% to_update] <- arg_updates[
which_updates
]
}
# then append any additional args from args_updates
if (any(!which_updates)) {
stan_args <- c(stan_args, arg_updates[!which_updates])
}
}
}
if (is.null(stan_args)) {
stop("No stan arguments found in stan_args or data$update_stan_args. Need at least `file`` or `model_code`")
}
stan_args$data <- data$data
old_opt <- rstan::rstan_options(auto_write = TRUE)
stanfit <- do.call(rstan::stan, stan_args)
rstan::rstan_options(auto_write = old_opt)
return(stanfit)
}
power_record_blank <- function() {
list(list(
pass = 0, draw = 0, counter = 0, total = NA,
sample_size = NA, power_goal = NA, data = list()))
}
power_record <- function(record = NULL, draw, pass, power_goal, sample_size,
test_results, print_progress = NULL) {
if (is.null(record)) {
record <- list()
return(record)
}
r <- length(record)
if (r > 0) {
counter <- record[[r]]$counter
} else {
counter <- 0
}
if (draw == 1) {
record <- c(record, power_record_blank())
r <- length(record)
record[[r]]$counter <- counter
record[[r]]$pass <- pass
record[[r]]$sample_size <- sample_size
record[[r]]$power_goal <- power_goal
}
if (missing(test_results)) {
stop("results argument empty")
}
if (is.data.table(test_results)) {
# test function returned a single test
test_results[, test_id := 1]
} else {
# test function returned a list of tests
test_results <- data.table::rbindlist(test_results, idcol = "test_id")
}
test_results[, value_id := 1:.N, .(test_id, set)]
if (draw == 1) {
total <- test_results[, numeric(length(unique(test_id)))]
} else {
total <- record[[r]]$total
}
total <- total + test_results[, as.integer(passed[1]), test_id][[2]]
successes <- total + 1
failures <- 1 + draw - total
test_results[
, c(
"power", "pow_lower",
"pow_upper") := compute_power_interval(
successes[test_id], failures[test_id],
mass = 0.95
),
test_id
]
power_goals_met <- power_results_tests(test_results, power_goal, draw)
# update records
record[[r]]$draw <- draw
record[[r]]$counter <- record[[r]]$counter + 1
record[[r]]$total <- total
record[[r]]$data[[draw]] <- as.data.frame(test_results)
record[[r]][names(power_goals_met)] <- power_goals_met
return(record)
}
power_results_tests <- function(results, power_goal, draw) {
max_set_pwr <- results[
check == TRUE, .(row = .I[which.max(pow_lower)]), .(set)
]
last_best <- results[max_set_pwr$row, .(test, power, pow_lower, pow_upper)]
lower_reached <- last_best[, all(pow_lower >= power_goal)]
center_reached <- last_best[, all(power >= power_goal)]
upper_reached <- last_best[, min(pow_upper) > power_goal]
p_goal_met <- last_best[, pbeta(
power_goal, pow_upper * draw + 1,
(1 - pow_upper) * draw + 1)]
p_goal_met <- max(1 - p_goal_met)
nlist(last_best, lower_reached, center_reached, upper_reached, p_goal_met)
}
compute_power_interval <- function(successes, failures, mass = 0.95) {
if (mass <= 0 | mass >= 1) {
stop("HDI interval must have a mass between 0 and 1")
}
opt_result <- optimize(beta_optimization_fn, c(0, 1 - mass),
a = successes,
b = failures, mass = mass, tol = 1e-09)
left_tail_p <- opt_result$minimum
data.table(
power = successes / (successes + failures),
pow_lb = qbeta(left_tail_p, successes, failures),
pow_ub = qbeta(mass + left_tail_p, successes, failures))
}
beta_optimization_fn <- function(x, a, b, mass = 0.95) {
qbeta(x + mass, a, b) - qbeta(x, a, b)
}
print_power_status <- function(record, filename = NULL) {
if (!is.null(filename)) {
plog_file <- file(filename, open = "wt")
sink(plog_file)
sink(plog_file, type = "message")
}
n_recs <- length(record)
sample_size <- paste(unlist(record[[n_recs]]$sample_size), collapse = ",")
pass <- record[[n_recs]]$pass
draw <- record[[n_recs]]$draw
counter <- record[[n_recs]]$counter
last_best <- record[[n_recs]]$last_best
power_goal <- record[[n_recs]]$power_goal
p_goal_met <- record[[n_recs]]$p_goal_met
message(sprintf(
paste0(
"Status:\n Attempt %d, Draw %d, Iter %d, Sample Size %s\n",
" prob. of achieving power of %.2f is currently %.2f"),
pass, draw, counter, sample_size, power_goal, p_goal_met))
last_best[, message(sprintf(
" Power: %.2f [%.2f, %.2f]\t%s",
power, pow_lower, pow_upper, test)), test]
if (!is.null(filename)) {
sink(type = "message")
sink()
}
return(invisible())
}
# power tests ---------------------------------------------------------------------------------
test_interval_type <- function(op, op_list = FALSE) {
if (op_list) {
return(c(
"!", "!=", "neq", "l>", "lgt", "l>=", "lgeq",
"l<", "llt", "l<=", "lleq", "r>", "rgt", "r>=",
"rgeq", "r<", "rlt", "r<=", "rleq", "==", "eq"))
}
# not equal to
if (op %in% c("!", "!=", "neq")) {
return("neq")
}
# left greater than
if (op %in% c("l>", "lgt")) {
return("lgt")
}
# left greater than or equal to
if (op %in% c("l>=", "lgeq")) {
return("lgeq")
}
# left less than
if (op %in% c("l<", "llt")) {
return("llt")
}
# left less than or equal to
if (op %in% c("l<=", "lleq")) {
return("lleq")
}
# right greater than
if (op %in% c("r>", "rgt")) {
return("rgt")
}
# right greater than or equal to
if (op %in% c("r>=", "rgeq")) {
return("rgeq")
}
# right less than
if (op %in% c("r<", "rlt")) {
return("rlt")
}
# right less than or equal to
if (op %in% c("r<=", "rleq")) {
return("rleq")
}
# is equal to
if (op %in% c("==", "eq")) {
return("eq")
}
stop("unknown interval test operation: ", op)
}
test_interval <- function(x, type, null = 0) {
t <- test_interval_type(type)
# if (is.vector(x)) {
# x <- matrix(x, ncol = 2)
# }
# make sure lower and upper are sorted
if (!all(apply(
x,
1,
function(i) {
i[2] >= i[1]
}))) {
stop("x must be an interval with the first value <= the second value")
}
lhs <- x[, 1]
rhs <- x[, 2]
switch(t,
"neq" = {
rhs < null | lhs > null
},
"lgt" = {
lhs > null
},
"lgeq" = {
lhs >= null
},
"llt" = {
lhs < null
},
"lleq" = {
lhs <= null
},
"rgt" = {
rhs > null
},
"rgeq" = {
rhs >= null
},
"rlt" = {
rhs < null
},
"rleq" = {
rhs <= null
},
"eq" = {
# ROPE
if (length(null) != 2) {
stop("null_value must be a lower and upper region for this option")
}
lhs >= null[1] & rhs <= null[2]
},
NULL)
}
# stats ---------------------------------------------------------------------------------------
bayes_p <- function(x, null = 0) {
n <- length(x)
p <- sum(x > null) / n
return(c(leq_null = 1 - p, gt_null = p))
}
#' Effect size using Hedges' *g*
#' @export
cohens_d <- function(mu_x, sd_x, n_x, mu_y = NULL, sd_y = NULL,
n_y = NULL, one_sample_test_val = 0.0) {
if (missing(mu_x) || missing(sd_x) || missing(n_x)) {
stop("Missing mean, sd, or sample size for `x`")
}
if (is.null(n_x)) {
stop("Must provide sample size value for `x`")
}
use_y <- !is.null(mu_y)
if (use_y && is.null(sd_y)) {
stop("Must provide standard deviation value for `y`")
}
if (is.null(n_y)) {
if (use_y && is.null(n_x)) {
stop("Sample sizes for `x` and `y` cannot both be `NULL`")
}
n_y <- n_x
}
if (use_y) {
(mu_x - one_sample_test_val) / sd_x
} else {
sd_vec <- cbind(sd_x, sd_y)
n_vec <- cbind(n_x, n_y)
sd_p <- pooled_sd(sd_vec, n_vec)
(mu_x - mu_y) / sd_p
}
}
# example functions ---------------------------------------------------------------------------
one_sample_mvt_draw <- function(stanfit, i) {
post <- rstan::extract(stanfit, pars = c("mu", "Sigma", "nu"))
pars <- list(mu = post$mu[i, ], Sigma = post$Sigma[i, , ], nu = post$nu[i])
return(pars)
}
#' One sample multivariate-t
#'
#' Simulation function for one sample multivariate t
#'
#'
#' @param pars list of parameters needed for simulation and to create objects to pass to stan
#'
#' @return sim_fit_template
one_sample_mvt_make <- function(pars, sample_size_list = NULL) {
if (is.null(sample_size_list)) {
# use sample size from data if only determining current power
N <- pars$N
} else {
# since this is a one sample, it'll have a single slot `n`
N <- sample_size_list$n
}
# the rest of the code uses parameter values to simulate a new stan data list
Sigma <- pars$Sigma
K <- ncol(Sigma)
nu <- pars$nu
mu <- pars$mu
# i.i.d student-t values
mu_t_zero <- array(rt(N * K, df = nu), c(N, K))
# reshape t-values according to Sigma and rescale by mu
Y <- t(apply(
mu_t_zero,
1,
function(x) {
mu + x %*% chol(Sigma)
}))
# use the template to check for some slots not used, such as initialization values
stan_data <- refit_stan_template()
stan_data$data <- nlist(Y, K, N)
inits <- function(chain_id) {
list(
mu_adjustment = c(-0.25, 0.25, 0.25), nu_minus_one = 36,
sigma = c(0.7, 0.88, 0.85), chol_corr = diag(3))
}
stan_data$update_stan_args <- list(init = inits)
return(stan_data)
}
one_sample_mvt_tests <- function(refitted) {
# extract paramters
post <- rstan::extract(refitted, pars = c("mu", "sigma"))
mu <- post$mu
sigma <- post$sigma
# mean differences
emb_fam <- mu[, 2] - mu[, 1]
par_fam <- mu[, 3] - mu[, 1]
par_emb <- mu[, 3] - mu[, 2]
# effect sizes
#fx_size_EF <- cohens_d(mu[, 2], sigma[, 2], mu[, 1], sigma[, 1]) fix
#fx_size_PF <- cohens_d(mu[, 3], sigma[, 3], mu[, 1], sigma[, 1])
# test values
hdi_EF <- rbaes::hdi(emb_fam, 0.9)
p_EF <- bayes_p(emb_fam)[1]
hdi_D_EF <- rbaes::hdi(fx_size_EF, 0.9)
hdi_PF <- rbaes::hdi(par_fam, 0.9)
p_PF <- bayes_p(par_fam)[1]
hdi_D_PF <- rbaes::hdi(fx_size_PF, 0.9)
hdi_PE <- rbaes::hdi(par_emb, 0.9)
p_PE <- bayes_p(par_emb)
p_PE_test <- !as.logical(p_PE[1] < 0.05 | p_PE[2] < 0.05)
return(list(
add_test_item(
test_name = "HDI E-F > 0",
passed = test_interval(hdi_EF), check = TRUE,
test_set = "EvF", store_value = hdi_EF),
add_test_item("p E-F < .05", p_EF < 0.05,
check = TRUE, "EvF", p_EF),
add_test_item(
"HDI D > 0.1 | E,F",
test_interval(hdi_D_EF, null = 0.1, type = "l>"),
FALSE, "EvF", hdi_D_EF),
add_test_item("HDI P-F > 0", test_interval(hdi_PF),
check = TRUE, "PvF", hdi_PF),
add_test_item("p P-F < .05", p_PF < 0.05,
check = TRUE, "PvF", p_PF),
add_test_item(
"HDI D > 0.1 | P,F",
test_interval(hdi_D_PF, null = 0.1, type = "l>"),
FALSE, "PvF", hdi_D_PF),
add_test_item(
"HDI P-E == 0", !test_interval(hdi_PE),
FALSE, "PvE", hdi_PE),
add_test_item("p P-E > .05", p_PE_test, FALSE, "PvE", p_PE)))
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.