R/calculate_power_curves.R

Defines functions get_power_curve_output calculate_power_curves

Documented in calculate_power_curves get_power_curve_output

#' @title Calculate Power Curves
#'
#' @description Calculate and optionally plot power curves for different effect sizes and
#' trial counts. This function takes a set of trial counts and effect sizes, generates a design
#' with `gen_design()` for each combination, and evaluates the power of that design with the chosen
#' evaluation function. Errors and warnings raised during generation/evaluation are captured
#' (rather than stopping the calculation) and can be retrieved with `get_power_curve_output()`.
#'
#' @param trials A numeric vector indicating the trial(s) used when computing the power curve. If a single
#' value, this will be fixed and only `effectsize` will be varied. Values less than 2 are removed with a warning.
#' @param effectsize  Default `1`. A numeric vector indicating the effect size(s) used when computing the power curve. If a single
#' value, this will be fixed and only `trials` will be varied. If using a length-2 effect size with `eval_design_mc()` (such as
#' a binomial probability interval), the effect size pairs can be input as entries in a list.
#' @param candidateset Default `NULL`. The candidate set (see `gen_design()` documentation for more information). Provided to aid code completion: can also
#' be provided in `gen_args`.
#' @param model Default `NULL`. The model (see `gen_design()` and `eval_design()` documentation for more information). Provided to aid code completion: can also
#' be provided in `gen_args`/`eval_args`.
#' @param alpha Default `0.05`. The allowable Type-I error rate (see `eval_design()` documentation for more information). Provided to aid code completion: can also
#' be provided in `eval_args`.
#' @param gen_args Default `list()`. A list of argument/value pairs to specify the design generation parameters for `gen_design()`.
#' @param eval_function Default `"eval_design"`. A string (or function) specifying the skpr power evaluation function.
#' Can also be `"eval_design_mc"`, `"eval_design_survival_mc"`, and `"eval_design_custom_mc"`. A namespaced function
#' (e.g. `skpr::eval_design_mc`) or a variable holding one of these names is also accepted.
#' @param eval_args Default `list()`. A list of argument/value pairs to specify the design power evaluation parameters for `eval_function`.
#' @param random_seed Default `123`. The random seed used to generate and then evaluate the design. The seed is set right before design generation.
#' @param iterate_seed Default `FALSE`. This will iterate the random seed with each new design. Set this to `TRUE` to add more variability to the design generation process.
#' @param plot_results Default `TRUE`. Whether to print out a plot of the power curves in addition to the data frame of results. Requires `ggplot2` and `gridExtra`.
#' @param auto_scale Default `TRUE`. Whether to automatically scale the y-axis to 0 and 1.
#' @param ggplot_elements Default `list()`. Extra `ggplot2` elements to customize the plot, passed in as elements in a list.
#' @param x_breaks Default `NULL`, automatically generated by ggplot2.
#' @param y_breaks Default `seq(0,1,by=0.1)`. Y-axis breaks.
#'
#'@return A data.frame of power values with design generation information. The returned object has class
#'`skpr_power_curve_output`; the captured generation/evaluation errors and warnings are stored in its attributes
#'and can be extracted with `get_power_curve_output()`.
#'@export
#'@examples
#'if(skpr:::run_documentation()) {
#'cand_set = expand.grid(brew_temp = c(80, 85, 90),
#'                       altitude = c(0, 2000, 4000),
#'                       bean_sun = c("low", "partial", "high"))
#'#Plot power for a linear model with all interactions
#'calculate_power_curves(trials=seq(10,60,by=1),
#'                       candidateset = cand_set,
#'                       model = ~.*.,
#'                       alpha = 0.05,
#'                       effectsize = 1,
#'                       eval_function = "eval_design")
#'}
#'if(skpr:::run_documentation()) {
#'#Add multiple effect sizes
#'calculate_power_curves(trials=seq(10,60,by=1),
#'                       candidateset = cand_set,
#'                       model = ~.*.,
#'                       alpha = 0.05,
#'                       effectsize = c(1,2),
#'                       eval_function = "eval_design")
#'}
#'if(skpr:::run_documentation()) {
#'#Generate power curve for a binomial model
#'calculate_power_curves(trials=seq(50,150,by=10),
#'                       candidateset = cand_set,
#'                       model = ~.,
#'                       effectsize = c(0.6,0.9),
#'                       eval_function = "eval_design_mc",
#'                       eval_args = list(nsim = 100, glmfamily = "binomial"))
#'}
#'if(skpr:::run_documentation()) {
#'#Generate power curve for a binomial model and multiple effect sizes
#'calculate_power_curves(trials=seq(50,150,by=10),
#'                       candidateset = cand_set,
#'                       model = ~.,
#'                       effectsize = list(c(0.5,0.9),c(0.6,0.9)),
#'                       eval_function = "eval_design_mc",
#'                       eval_args = list(nsim = 100, glmfamily = "binomial"))
#'}
calculate_power_curves = function(trials,
                                  effectsize = 1,
                                  candidateset = NULL,
                                  model = NULL,
                                  alpha = 0.05,
                                  gen_args = list(),
                                  eval_function = "eval_design",
                                  eval_args = list(),
                                  random_seed = 123,
                                  iterate_seed = FALSE,
                                  plot_results = TRUE,
                                  auto_scale = TRUE,
                                  x_breaks = NULL,
                                  y_breaks = seq(0,1,by=0.1),
                                  ggplot_elements = list()) {
  if(!inherits(ggplot_elements, "list")) {
    stop("`ggplot_elements` must be a list of ggplot2 objects")
  }
  #Resolve the evaluation function to a single name. The argument may be a string, a bare
  #function name, a namespaced function (`skpr::eval_design_mc`), or a variable holding the name.
  valid_eval_functions = c("eval_design", "eval_design_mc", "eval_design_survival_mc", "eval_design_custom_mc")
  eval_function_expr = substitute(eval_function)
  eval_function_name = as.character(eval_function_expr)
  if(length(eval_function_name) == 3 && eval_function_name[1] %in% c("::", ":::")) {
    #`pkg::fun` is converted to c("::", "pkg", "fun"): keep only the function name
    eval_function_name = eval_function_name[3]
  }
  if(length(eval_function_name) != 1 || !eval_function_name %in% valid_eval_functions) {
    #Not a literal name: check whether the expression evaluates to a single string
    eval_function_value = tryCatch(eval_function, error = function(e) NULL)
    if(is.character(eval_function_value) && length(eval_function_value) == 1) {
      eval_function_name = eval_function_value
    } else {
      eval_function_name = paste(deparse(eval_function_expr), collapse = "")
    }
  }
  eval_function = eval_function_name

  if(!eval_function %in% valid_eval_functions) {
    stop("skpr: eval_function `",eval_function,"` not recognized as one of `eval_design`, `eval_design_mc`, `eval_design_survival_mc`, or `eval_design_custom_mc`")
  }
  if(!inherits(eval_args,"list") || !inherits(gen_args, "list")) {
    stop("skpr: Both `eval_args` and `gen_args` must be of class `list`")
  }
  #Any non-gaussian `glmfamily` is treated as requiring a length-2 (low, high) effect size
  glmfamily = eval_args[["glmfamily"]]
  length_2_effect = !is.null(glmfamily) && glmfamily != "gaussian"

  if(length_2_effect && inherits(effectsize,"list") && !all(unlist(lapply(effectsize, length)) == 2)) {
    stop("skpr: If passing in a list of effectsizes for evaluation functions that require two values, all effect sizes must be length-2 vectors in the list")
  }
  if(length_2_effect && length(effectsize) == 2 && is.numeric(effectsize) ) {
    #A bare length-2 vector is a single (low, high) pair, not two separate effect sizes
    effectsize = list(effectsize)
  }
  #Drop invalid trial counts first, so the plot layout and the progress bar total
  #are based on the designs that will actually be generated.
  if(any(trials < 2)) {
    trials = trials[trials >= 2]
    warning("skpr: removing trials less than 2 runs--those values can never result in a valid design.")
  }
  if(length(trials) == 0) {
    stop("skpr: `trials` must contain at least one value of 2 or greater")
  }
  only_effect = length(trials) == 1
  grid_plot = length(trials) > 1 && length(effectsize) > 1
  n_designs = length(trials) * length(effectsize)
  pb = progress::progress_bar$new(format = "Calculating Power Curve :current/:total [:bar] ETA::eta",
                                  total = n_designs, width=120)

  #Explicit arguments take precedence over entries of the same name in `gen_args`/`eval_args`
  override_arg = function(arg_list, arg_name, arg_value) {
    arg_list[[arg_name]] = NULL
    new_entry = list(arg_value)
    names(new_entry) = arg_name
    append(arg_list, new_entry)
  }
  if(!is.null(candidateset)) {
    gen_args = override_arg(gen_args, "candidateset", candidateset)
  }
  if(!is.null(model)) {
    gen_args = override_arg(gen_args, "model", model)
    eval_args = override_arg(eval_args, "model", model)
  }
  if(!is.null(alpha)) {
    eval_args = override_arg(eval_args, "alpha", alpha)
  }
  if(!is.null(effectsize)) {
    eval_args = override_arg(eval_args, "effectsize", effectsize)
  }
  gen_args$timer = FALSE
  if(eval_function %in% c("eval_design_mc", "eval_design_survival_mc", "eval_design_custom_mc")) {
    eval_args$progress = FALSE
  }
  #Capture the output from the simulation to provide a table of errors.
  #The wrapped function returns `NA` on error; messages are attached as the
  #"err" and "warn" attributes ("" meaning nothing was raised).
  capture_output = function(fun) {
    function(...) {
      warn = err = ""
      res = withCallingHandlers(
        tryCatch(fun(...), error=function(e) {
          err <<- conditionMessage(e)
          NA
        }), warning=function(w) {
          warn <<- append(warn, conditionMessage(w))
          invokeRestart("muffleWarning")
        })
      attr(res, "warn") = warn
      attr(res, "err") = err
      return(res)
    }
  }
  #Placeholder row for runs that failed; removed from the results after the loop
  failed_power_output = data.frame(parameter = NA_character_, type = NA_character_, power = NA_real_)
  result_dataframe = vector("list", n_designs)
  gen_errors = vector("list", n_designs)
  gen_warnings = vector("list", n_designs)
  eval_errors = vector("list", n_designs)
  eval_warnings = vector("list", n_designs)
  counter = 1
  for(ef in effectsize) {
    for(trial in trials) {
      pb$tick()
      if(iterate_seed) {
        recorded_seed = random_seed + counter
      } else {
        recorded_seed = random_seed
      }
      set.seed(recorded_seed)
      gen_args[["trials"]] = trial

      gen_fun = function() {
        do.call("gen_design", gen_args)
      }
      temp_design =  capture_output(gen_fun)()

      gen_errors[[counter]] = data.frame(trials=trial, seed = recorded_seed,err=attr(temp_design,"err"))
      gen_warnings[[counter]] = data.frame(trials=trial, seed = recorded_seed,warn=attr(temp_design,"warn"))
      if(all(!is.na(temp_design))) {
        attr(temp_design,"err") = NULL
        attr(temp_design,"warn") = NULL

        eval_args[["design"]] = temp_design
        eval_args[["effectsize"]] = ef

        eval_fun = function() {
          do.call(eval_function, eval_args)
        }
        power_output =  capture_output(eval_fun)()
        if(!length_2_effect) {
          eval_errors[[counter]] = data.frame(effect=ef, trials = trial, seed = recorded_seed, err=attr(power_output,"err"))
          eval_warnings[[counter]] = data.frame(effect=ef, trials = trial, seed = recorded_seed, warn=attr(power_output,"warn"))
        } else {
          eval_errors[[counter]] = data.frame(effect_low=ef[1],effect_high=ef[2], trials = trial, seed = recorded_seed, err=attr(power_output,"err"))
          eval_warnings[[counter]] = data.frame(effect_low=ef[1],effect_high=ef[2], trials = trial, seed = recorded_seed, warn=attr(power_output,"warn"))
        }
        attr(power_output,"err") = NULL
        attr(power_output,"warn") = NULL
        if(all(is.na(power_output))) {
          power_output = failed_power_output
        }
      } else {
        power_output = failed_power_output
      }
      result_dataframe[[counter]] = power_output
      result_dataframe[[counter]]$trials = trial
      if(!"anticoef" %in% names(eval_args)) {
        if(length(ef) == 1) {
          result_dataframe[[counter]]$effectsize = ef
        } else {
          result_dataframe[[counter]]$effectsize_low = ef[1]
          result_dataframe[[counter]]$effectsize_high = ef[2]
        }
      }
      result_dataframe[[counter]]$random_seed = recorded_seed
      counter = counter + 1
    }
  }
  all_results = do.call("rbind",result_dataframe)
  all_results = all_results[!is.na(all_results$power),]
  attr(all_results, "gen_args") = gen_args
  attr(all_results, "eval_args") = eval_args
  attr(all_results, "gen_errors") = do.call("rbind", gen_errors)
  attr(all_results, "gen_warnings") = do.call("rbind", gen_warnings)
  attr(all_results, "eval_errors") = do.call("rbind", eval_errors)
  attr(all_results, "eval_warnings") = do.call("rbind", eval_warnings)
  class(all_results) = "data.frame"
  #Both packages are needed to plot: skip plotting if either one is missing
  if(plot_results &&
     (!(length(find.package("ggplot2", quiet = TRUE)) > 0) ||
      !(length(find.package("gridExtra", quiet = TRUE)) > 0))) {
    warning("{ggplot2} and {gridExtra} package required for plotting results")
    plot_results = FALSE
  }
  if(plot_results) {
    #Defaults: power against the number of trials, without facets. The branches
    #below only change the x aesthetic, the x-axis title, and the facets.
    line_mapping = ggplot2::aes(x=trials, y = power, color=parameter)
    x_axis_title = "Trials"
    facet_element = list()
    low_is_fixed = length(unique(all_results$effectsize_low)) == 1
    high_is_fixed = length(unique(all_results$effectsize_high)) == 1

    if(grid_plot) {
      #Several trial counts and several effect sizes: one facet row per effect size
      if(!length_2_effect) {
        facet_element = list(ggplot2::facet_grid(effectsize~.,
                                                 labeller = ggplot2::labeller(effectsize = ggplot2::label_both)))
      } else if(low_is_fixed) {
        facet_element = list(ggplot2::facet_grid(effectsize_high~.,
                                                 labeller = ggplot2::labeller(effectsize_high = ggplot2::label_both)))
      } else if(high_is_fixed) {
        facet_element = list(ggplot2::facet_grid(effectsize_low~.,
                                                 labeller = ggplot2::labeller(effectsize_low = ggplot2::label_both)))
      } else {
        facet_element = list(ggplot2::facet_grid(effectsize_low + effectsize_high~.,
                                                 labeller = ggplot2::labeller(effectsize_low = ggplot2::label_both,
                                                                              effectsize_high = ggplot2::label_both)))
      }
    } else if(only_effect) {
      #A single trial count: power is plotted against the effect size
      if(!length_2_effect) {
        line_mapping = ggplot2::aes(x=effectsize, y = power, color=parameter)
        x_axis_title = "Effect Size"
      } else if(low_is_fixed) {
        line_mapping = ggplot2::aes(x=effectsize_high, y = power, color=parameter)
        x_axis_title = "Effect Size (High)"
      } else if(high_is_fixed) {
        line_mapping = ggplot2::aes(x=effectsize_low, y = power, color=parameter)
        x_axis_title = "Effect Size (Low)"
      } else {
        line_mapping = ggplot2::aes(x=effectsize_high, y = power, color=parameter)
        x_axis_title = "Effect Size (High)"
        facet_element = list(ggplot2::facet_grid(effectsize_low~.,
                                                 labeller = ggplot2::labeller(effectsize_low = ggplot2::label_value)))
      }
    }

    if(auto_scale) {
      #`breaks = NULL` means "no breaks" in ggplot2; `waiver()` requests the automatic breaks
      if(is.null(x_breaks)) {
        x_axis_breaks = ggplot2::waiver()
      } else {
        x_axis_breaks = x_breaks
      }
      ggscale_element = list(ggplot2::scale_y_continuous("Power", limits=c(0,1), breaks = y_breaks),
                             ggplot2::scale_x_continuous(x_axis_title, breaks = x_axis_breaks, expand=c(0,0)))
    } else {
      ggscale_element = list()
    }

    #Effect and parameter power share the same layers and differ only in data and title
    build_power_plot = function(plot_data, plot_title) {
      ggplot2::ggplot(plot_data) + ggplot_elements +
        ggplot2::geom_line(line_mapping) +
        facet_element +
        ggscale_element +
        ggplot2::scale_color_discrete("Model Term") +
        ggplot2::labs(title = plot_title)
    }
    effect_results = all_results[all_results$type %in% c("effect.power","effect.power.mc"),]
    parameter_results = all_results[all_results$type %in% c("parameter.power","parameter.power.mc"),]
    effect_plot = build_power_plot(effect_results, "Effect Power")
    parameter_plot = build_power_plot(parameter_results, "Parameter Power")
    gridExtra::grid.arrange(effect_plot, parameter_plot, nrow=1)
  }

  class(all_results) = c("skpr_power_curve_output", "data.frame")
  attr(all_results,"output") = list("gen_errors"=attr(all_results, "gen_errors"),
                                    "gen_warnings"=attr(all_results, "gen_warnings"),
                                    "eval_errors"=attr(all_results, "eval_errors"),
                                    "eval_warnings"=attr(all_results, "eval_warnings"))

  return(all_results)
}

# Column names referenced as bare symbols inside `ggplot2::aes()` in this file.
# Registering them stops `R CMD check` from flagging them as undefined globals.
globalVariables(
  c(
    "parameter",
    "effectsize_high",
    "effectsize_low"
  )
)

#'@title Get Power Curve Warnings and Errors
#'
#'@description Gets the warnings and errors from `calculate_power_curves()` output.
#'Only rows with a non-empty message are returned, so runs that completed without
#'a warning/error do not appear in the tables.
#'
#'@param power_curve The output from `calculate_power_curves()`
#'@return A list of data.frames containing warning/error information, with entries
#'`gen_errors`, `gen_warnings`, `eval_errors`, and `eval_warnings`.
#'
#'@export
#'@examples
#'#Generate sample
#'if(skpr:::run_documentation()) {
#'power_curve = calculate_power_curves(trials=seq(50,150,by=20),
#'                       candidateset = expand.grid(x=c(-1,1),y=c(-1,1)),
#'                       model = ~.,
#'                       effectsize = list(c(0.5,0.9),c(0.6,0.9)),
#'                       eval_function = eval_design_mc,
#'                       eval_args = list(nsim = 100, glmfamily = "binomial"))
#'#Extract the warnings and errors captured while computing the curve
#'get_power_curve_output(power_curve)
#'}
get_power_curve_output = function(power_curve) {
  stopifnot(inherits(power_curve,"skpr_power_curve_output"))
  recorded_output = attr(power_curve, "output")
  #Each table stores its message in either an `err` or a `warn` column
  message_column = c(gen_errors = "err",
                     gen_warnings = "warn",
                     eval_errors = "err",
                     eval_warnings = "warn")
  for(table_name in names(message_column)) {
    recorded_table = recorded_output[[table_name]]
    #An empty string marks a run that raised nothing
    has_message = recorded_table[[message_column[[table_name]]]] != ""
    recorded_output[[table_name]] = recorded_table[has_message, ]
  }
  return(recorded_output)
}

Try the skpr package in your browser

Any scripts or data that you put into this service are public.

skpr documentation built on July 9, 2023, 7:23 p.m.