# R/calculate_power_curves.R
#
# Defines functions: get_power_curve_output, calculate_power_curves
#
# Documented in: calculate_power_curves, get_power_curve_output

#' @title Calculate Power Curves
#'
#' @description Calculate and optionally plot power curves for different effect sizes and
#' trial counts. This function takes a set of trial counts and effect sizes, generates a design
#' for each trial count with `gen_design()`, evaluates the power of that design with the chosen
#' evaluation function for each effect size, and returns the combined results.
#'
#' @param trials A numeric vector indicating the trial(s) used when computing the power curve. If a single
#' value, this will be fixed and only `effectsize` will be varied. Values less than 2 are removed with a warning.
#' @param effectsize  Default `1`. A numeric vector indicating the effect size(s) used when computing the power curve. If a single
#' value, this will be fixed and only `trials` will be varied. If using a length-2 effect size with `eval_design_mc()` (such as
#' a binomial probability interval), the effect size pairs can be input as entries in a list.
#' @param candidateset Default `NULL`. The candidate set (see `gen_design()` documentation for more information). Provided to aid code completion: can also
#' be provided in `gen_args`.
#' @param model Default `NULL`. The model (see `gen_design()` and `eval_design()` documentation for more information). Provided to aid code completion: can also
#' be provided in `gen_args`/`eval_args`.
#' @param alpha Default `0.05`. The allowable Type-I error rate (see `eval_design()` documentation for more information). Provided to aid code completion: can also
#' be provided in `eval_args`.
#' @param gen_args Default `list()`. A list of argument/value pairs to specify the design generation parameters for `gen_design()`.
#' @param eval_function Default `"eval_design"`. A string (or unquoted function name) specifying the skpr power evaluation function.
#' Can also be `"eval_design_mc"`, `"eval_design_survival_mc"`, and `"eval_design_custom_mc"`.
#' @param eval_args Default `list()`. A list of argument/value pairs to specify the design power evaluation parameters for `eval_function`.
#' @param random_seed Default `123`. The random seed used to generate and then evaluate the design. The seed is set right before design generation.
#' @param iterate_seed Default `FALSE`. This will iterate the random seed with each new design. Set this to `TRUE` to add more variability to the design generation process.
#' @param plot_results Default `TRUE`. Whether to print out a plot of the power curves in addition to the data frame of results. Requires `ggplot2` and `gridExtra`.
#' @param auto_scale Default `TRUE`. Whether to automatically scale the y-axis to 0 and 1.
#' @param ggplot_elements Default `list()`. Extra `ggplot2` elements to customize the plot, passed in as elements in a list.
#' @param x_breaks Default `NULL`, automatically generated by ggplot2.
#' @param y_breaks Default `seq(0,1,by=0.1)`. Y-axis breaks.
#'
#'@return A data.frame of power values with design generation information. Rows with missing power are
#'dropped. Warnings and errors captured during design generation and evaluation are attached to the
#'output and can be retrieved with `get_power_curve_output()`.
#'@export
#'@examples
#'if(skpr:::run_documentation()) {
#'cand_set = expand.grid(brew_temp = c(80, 85, 90),
#'                       altitude = c(0, 2000, 4000),
#'                       bean_sun = c("low", "partial", "high"))
#'#Plot power for a linear model with all interactions
#'calculate_power_curves(trials=seq(10,60,by=5),
#'                       candidateset = cand_set,
#'                       model = ~.*.,
#'                       alpha = 0.05,
#'                       effectsize = 1,
#'                       eval_function = "eval_design") |>
#'  head(30)
#'
#'}
#'if(skpr:::run_documentation()) {
#'#Add multiple effect sizes
#'calculate_power_curves(trials=seq(10,60,by=1),
#'                       candidateset = cand_set,
#'                       model = ~.*.,
#'                       alpha = 0.05,
#'                       effectsize = c(1,2),
#'                       eval_function = "eval_design") |>
#'  head(30)
#'}
#'if(skpr:::run_documentation()) {
#'#Generate power curve for a binomial model
#'calculate_power_curves(trials=seq(50,150,by=10),
#'                       candidateset = cand_set,
#'                       model = ~.,
#'                       effectsize = c(0.6,0.9),
#'                       eval_function = "eval_design_mc",
#'                       eval_args = list(nsim = 100, glmfamily = "binomial")) |>
#'  head(30)
#'}
#'if(skpr:::run_documentation()) {
#'#Generate power curve for a binomial model and multiple effect sizes
#'calculate_power_curves(trials=seq(50,150,by=10),
#'                       candidateset = cand_set,
#'                       model = ~.,
#'                       effectsize = list(c(0.5,0.9),c(0.6,0.9)),
#'                       eval_function = "eval_design_mc",
#'                       eval_args = list(nsim = 100, glmfamily = "binomial")) |>
#'  head(30)
#'}
calculate_power_curves = function(
  trials,
  effectsize = 1,
  candidateset = NULL,
  model = NULL,
  alpha = 0.05,
  gen_args = list(),
  eval_function = "eval_design",
  eval_args = list(),
  random_seed = 123,
  iterate_seed = FALSE,
  plot_results = TRUE,
  auto_scale = TRUE,
  x_breaks = NULL,
  y_breaks = seq(0, 1, by = 0.1),
  ggplot_elements = list()
) {
  # ---- Argument validation -------------------------------------------------
  if (!inherits(ggplot_elements, "list")) {
    stop("`ggplot_elements` must be a list of ggplot2 objects")
  }
  # `substitute()` lets the caller pass either a string or a bare function
  # name; both end up as the function's name in a character vector.
  eval_function = as.character(substitute(eval_function))

  if (
    !eval_function %in%
      c(
        "eval_design",
        "eval_design_mc",
        "eval_design_survival_mc",
        "eval_design_custom_mc"
      )
  ) {
    stop(
      "skpr: eval_function `",
      eval_function,
      "` not recognized as one of `eval_design`, `eval_design_mc`, `eval_design_survival_mc`, or `eval_design_custom_mc`"
    )
  }
  if (!inherits(eval_args, "list") || !inherits(gen_args, "list")) {
    stop("skpr: Both `eval_args` and `gen_args` must be of class `list`")
  }

  # Any non-gaussian `glmfamily` is treated as needing a (low, high) pair of
  # effect sizes rather than a single value.
  glm_family = eval_args[["glmfamily"]]
  length_2_effect = !is.null(glm_family) && glm_family != "gaussian"

  # Effect power is plotted unless the caller turned it off via `calceffect`.
  calc_effect = TRUE
  if (!is.null(eval_args[["calceffect"]])) {
    calc_effect = eval_args[["calceffect"]]
  }

  if (
    length_2_effect &&
      inherits(effectsize, "list") &&
      !all(lengths(effectsize) == 2)
  ) {
    stop(
      "skpr: If passing in a list of effectsizes for evaluation functions that require two values, all effect sizes must be length-2 vectors in the list"
    )
  }
  # A bare length-2 numeric is a single (low, high) pair, not two effect sizes.
  if (length_2_effect && length(effectsize) == 2 && is.numeric(effectsize)) {
    effectsize = list(effectsize)
  }

  # Drop invalid trial counts BEFORE anything is sized from `trials`, so the
  # progress bar total and the plot layout reflect the designs actually run.
  if (any(trials < 2)) {
    trials = trials[trials > 1]
    warning(
      "skpr: removing trials less than 2 runs--those values can never result in a valid design."
    )
  }
  if (length(trials) == 0) {
    stop(
      "skpr: no valid trial counts to evaluate; `trials` must contain at least one value of 2 or more."
    )
  }

  only_effect = length(trials) == 1
  grid_plot = length(trials) > 1 && length(effectsize) > 1

  n_designs = length(trials) * length(effectsize)
  pb = progress::progress_bar$new(
    format = "Calculating Power Curve :current/:total [:bar] ETA::eta",
    total = n_designs,
    width = 120
  )

  # ---- Merge convenience arguments into gen_args / eval_args ---------------
  # Replaces any existing `arg_name` entry with `arg_value` (appended at the
  # end of the list); a NULL `arg_value` leaves the list untouched.
  override_arg = function(arg_list, arg_name, arg_value) {
    if (is.null(arg_value)) {
      return(arg_list)
    }
    arg_list[[arg_name]] = NULL
    new_arg = list(arg_value)
    names(new_arg) = arg_name
    append(arg_list, new_arg)
  }
  gen_args = override_arg(gen_args, "candidateset", candidateset)
  gen_args = override_arg(gen_args, "model", model)
  eval_args = override_arg(eval_args, "model", model)
  eval_args = override_arg(eval_args, "alpha", alpha)
  eval_args = override_arg(eval_args, "effectsize", effectsize)

  # Silence the per-call progress output; this function shows its own bar.
  gen_args[["progress"]] = FALSE
  if (
    eval_function %in%
      c("eval_design_mc", "eval_design_survival_mc", "eval_design_custom_mc")
  ) {
    eval_args[["progress"]] = FALSE
  }

  # Capture the output from the simulation to provide a table of errors.
  # Wraps `fun` so that an error yields NA (message stored in attribute "err")
  # and warnings are muffled and accumulated in attribute "warn". Both
  # attributes start as "" so "no message" rows can be filtered out later.
  capture_output = function(fun) {
    function(...) {
      warn = err = ""
      res = withCallingHandlers(
        tryCatch(fun(...), error = function(e) {
          err <<- conditionMessage(e)
          NA
        }),
        warning = function(w) {
          warn <<- append(warn, conditionMessage(w))
          invokeRestart("muffleWarning")
        }
      )
      attr(res, "warn") = warn
      attr(res, "err") = err
      return(res)
    }
  }

  # Placeholder row used when generation or evaluation fails; it is removed
  # again by the `!is.na(power)` filter after all results are combined.
  missing_power = data.frame(
    parameter = NA_character_,
    type = NA_character_,
    power = NA_real_
  )

  # ---- Generate and evaluate one design per (effect size, trials) pair -----
  result_dataframe = vector("list", n_designs)
  gen_errors = vector("list", n_designs)
  gen_warnings = vector("list", n_designs)
  eval_errors = vector("list", n_designs)
  eval_warnings = vector("list", n_designs)
  counter = 1
  for (ef in effectsize) {
    for (trial in trials) {
      pb$tick()
      if (iterate_seed) {
        recorded_seed = random_seed + counter
      } else {
        recorded_seed = random_seed
      }
      set.seed(recorded_seed)
      gen_args[["trials"]] = trial

      temp_design = capture_output(function() {
        do.call("gen_design", gen_args)
      })()

      gen_errors[[counter]] = data.frame(
        trials = trial,
        seed = recorded_seed,
        err = attr(temp_design, "err")
      )
      gen_warnings[[counter]] = data.frame(
        trials = trial,
        seed = recorded_seed,
        warn = attr(temp_design, "warn")
      )

      power_output = missing_power
      # A failed generation returns NA, so only evaluate real designs.
      if (all(!is.na(temp_design))) {
        attr(temp_design, "err") = NULL
        attr(temp_design, "warn") = NULL

        # Remove-then-set keeps `design` as the trailing entry of eval_args.
        eval_args[["design"]] = NULL
        eval_args[["design"]] = temp_design
        eval_args[["effectsize"]] = ef

        power_output = capture_output(function() {
          do.call(eval_function, eval_args)
        })()

        if (!length_2_effect) {
          eval_errors[[counter]] = data.frame(
            effect = ef,
            trials = trial,
            seed = recorded_seed,
            err = attr(power_output, "err")
          )
          eval_warnings[[counter]] = data.frame(
            effect = ef,
            trials = trial,
            seed = recorded_seed,
            warn = attr(power_output, "warn")
          )
        } else {
          eval_errors[[counter]] = data.frame(
            effect_low = ef[1],
            effect_high = ef[2],
            trials = trial,
            seed = recorded_seed,
            err = attr(power_output, "err")
          )
          eval_warnings[[counter]] = data.frame(
            effect_low = ef[1],
            effect_high = ef[2],
            trials = trial,
            seed = recorded_seed,
            warn = attr(power_output, "warn")
          )
        }
        attr(power_output, "err") = NULL
        attr(power_output, "warn") = NULL
        if (all(is.na(power_output))) {
          power_output = missing_power
        }
      }

      power_output$trials = trial
      # Effect size columns are omitted when `anticoef` is supplied in
      # `eval_args`.
      if (!"anticoef" %in% names(eval_args)) {
        if (length(ef) == 1) {
          power_output$effectsize = ef
        } else {
          power_output$effectsize_low = ef[1]
          power_output$effectsize_high = ef[2]
        }
      }
      power_output$random_seed = recorded_seed
      result_dataframe[[counter]] = power_output
      counter = counter + 1
    }
  }

  # ---- Combine results and attach diagnostics ------------------------------
  all_results = do.call("rbind", result_dataframe)
  all_results = all_results[!is.na(all_results$power), ]
  attr(all_results, "gen_args") = gen_args
  attr(all_results, "eval_args") = eval_args
  attr(all_results, "gen_errors") = do.call("rbind", gen_errors)
  attr(all_results, "gen_warnings") = do.call("rbind", gen_warnings)
  attr(all_results, "eval_errors") = do.call("rbind", eval_errors)
  attr(all_results, "eval_warnings") = do.call("rbind", eval_warnings)
  class(all_results) = "data.frame"

  # ---- Plotting ------------------------------------------------------------
  # Both packages are needed below (ggplot2 builds the plots, gridExtra
  # arranges them), so plotting is skipped when EITHER one is missing.
  has_package = function(pkg) {
    length(find.package(pkg, quiet = TRUE)) > 0
  }
  if (plot_results && !(has_package("ggplot2") && has_package("gridExtra"))) {
    warning("{ggplot2} and {gridExtra} package required for plotting results")
    plot_results = FALSE
  }
  if (plot_results) {
    effect_results = all_results[
      all_results$type %in% c("effect.power", "effect.power.mc"),
    ]
    parameter_results = all_results[
      all_results$type %in% c("parameter.power", "parameter.power.mc"),
    ]
    effect_title = list(ggplot2::labs(title = "Effect Power"))
    parameter_title = list(ggplot2::labs(title = "Parameter Power"))
    color_scale_name = list(ggplot2::scale_color_discrete("Model Term"))

    # Choose the x variable and the facets from what is being varied:
    #   * trials and effect sizes both vary -> x = trials, facet by effect size
    #   * only effect sizes vary            -> x = effect size
    #   * only trials vary                  -> x = trials, no facets
    x_variable = "trials"
    facet_element = list()
    if (length_2_effect && (grid_plot || only_effect)) {
      n_low = length(unique(all_results$effectsize_low))
      n_high = length(unique(all_results$effectsize_high))
    }
    if (grid_plot) {
      if (!length_2_effect) {
        facet_element = list(ggplot2::facet_grid(
          effectsize ~ .,
          labeller = ggplot2::labeller(effectsize = ggplot2::label_both)
        ))
      } else if (n_low == 1) {
        # Lower bound is constant: facet on the upper bound only.
        facet_element = list(ggplot2::facet_grid(
          effectsize_high ~ .,
          labeller = ggplot2::labeller(
            effectsize_high = ggplot2::label_both,
            type = ggplot2::label_value
          )
        ))
      } else if (n_high == 1) {
        # Upper bound is constant: facet on the lower bound only.
        facet_element = list(ggplot2::facet_grid(
          effectsize_low ~ .,
          labeller = ggplot2::labeller(effectsize_low = ggplot2::label_both)
        ))
      } else {
        facet_element = list(ggplot2::facet_grid(
          effectsize_low + effectsize_high ~ .,
          labeller = ggplot2::labeller(
            effectsize_low = ggplot2::label_both,
            effectsize_high = ggplot2::label_both
          )
        ))
      }
    } else if (only_effect) {
      if (!length_2_effect) {
        x_variable = "effectsize"
      } else if (n_low == 1) {
        x_variable = "effectsize_high"
      } else if (n_high == 1) {
        x_variable = "effectsize_low"
      } else {
        # Both bounds vary: plot against the upper bound, one facet per lower
        # bound.
        x_variable = "effectsize_high"
        facet_element = list(ggplot2::facet_grid(
          effectsize_low ~ .,
          labeller = ggplot2::labeller(
            effectsize_low = ggplot2::label_value
          )
        ))
      }
    }

    # `switch()` only evaluates the selected arm, so just one mapping is built.
    line_aes = switch(
      x_variable,
      trials = ggplot2::aes(x = trials, y = power, color = parameter),
      effectsize = ggplot2::aes(x = effectsize, y = power, color = parameter),
      effectsize_high = ggplot2::aes(
        x = effectsize_high,
        y = power,
        color = parameter
      ),
      effectsize_low = ggplot2::aes(
        x = effectsize_low,
        y = power,
        color = parameter
      )
    )
    # Title the x-axis after the variable actually plotted on it.
    x_axis_title = switch(
      x_variable,
      trials = "Trials",
      effectsize = "Effect Size",
      effectsize_high = "Effect Size (High)",
      effectsize_low = "Effect Size (Low)"
    )

    if (auto_scale) {
      if (!is.null(x_breaks)) {
        x_breaks_gg = x_breaks
      } else {
        x_breaks_gg = ggplot2::waiver()
      }
      ggscale_element = list(
        ggplot2::scale_y_continuous(
          "Power",
          limits = c(0, 1),
          breaks = y_breaks
        ),
        ggplot2::scale_x_continuous(
          x_axis_title,
          breaks = x_breaks_gg,
          expand = c(0, 0)
        )
      )
    } else {
      ggscale_element = list()
    }

    # Assembles one power plot; user-supplied elements are added before the
    # line layer, facets, scales and title.
    build_power_plot = function(plot_data, plot_title) {
      ggplot2::ggplot(plot_data) +
        ggplot_elements +
        ggplot2::geom_line(mapping = line_aes) +
        facet_element +
        ggscale_element +
        color_scale_name +
        plot_title
    }

    parameter_plot = build_power_plot(parameter_results, parameter_title)
    if (calc_effect) {
      effect_plot = build_power_plot(effect_results, effect_title)
      gridExtra::grid.arrange(effect_plot, parameter_plot, nrow = 1)
    } else {
      gridExtra::grid.arrange(parameter_plot, nrow = 1)
    }
  }

  # ---- Finalize output -----------------------------------------------------
  class(all_results) = c("skpr_power_curve_output", "data.frame")
  # Bundle the diagnostics in one attribute; `get_power_curve_output()` reads
  # them from here.
  attr(all_results, "output") = list(
    "gen_errors" = attr(all_results, "gen_errors"),
    "gen_warnings" = attr(all_results, "gen_warnings"),
    "eval_errors" = attr(all_results, "eval_errors"),
    "eval_warnings" = attr(all_results, "eval_warnings")
  )

  return(all_results)
}

# Column names referenced through non-standard evaluation in the ggplot2
# aesthetics above; registering them here declares them as known globals.
globalVariables(
  c(
    "parameter",
    "effectsize_high",
    "effectsize_low"
  )
)

#'@title Get Power Curve Warnings and Errors
#'
#'@description Gets the warnings and errors from `calculate_power_curves()` output.
#'Only entries with a non-empty message are returned.
#'
#'@param power_curve The output from `calculate_power_curves()`
#'@return A list of data.frames containing warning/error information
#'
#'@export
#'@examples
#'#Generate sample
#'if(skpr:::run_documentation()) {
#'power_curve = calculate_power_curves(trials=seq(50,150,by=20),
#'                       candidateset = expand.grid(x=c(-1,1),y=c(-1,1)),
#'                       model = ~.,
#'                       effectsize = list(c(0.5,0.9),c(0.6,0.9)),
#'                       eval_function = eval_design_mc,
#'                       eval_args = list(nsim = 100, glmfamily = "binomial"))
#'#Retrieve the recorded warnings and errors
#'get_power_curve_output(power_curve)
#'}
get_power_curve_output = function(power_curve) {
  stopifnot(inherits(power_curve, "skpr_power_curve_output"))
  diagnostics = attr(power_curve, "output")
  # Each diagnostic table and the column that holds its message text.
  message_columns = c(
    gen_errors = "err",
    gen_warnings = "warn",
    eval_errors = "err",
    eval_warnings = "warn"
  )
  for (table_name in names(message_columns)) {
    diagnostic_table = diagnostics[[table_name]]
    message_text = diagnostic_table[[message_columns[[table_name]]]]
    # An empty string marks "no message recorded"; keep only real messages.
    diagnostics[[table_name]] = diagnostic_table[message_text != "", ]
  }
  return(diagnostics)
}
# tylermorganwall/skpr documentation built on April 13, 2025, 5:35 p.m.