R/model-util.R

Defines functions get_response_units get_response_variable get_treatment_units get_treatment_variable prepare_prior eval_init eval_init_one_param rstan_default_init prepare_init

Documented in eval_init eval_init_one_param get_response_units get_response_variable get_treatment_units get_treatment_variable prepare_init prepare_prior rstan_default_init

#' Helper Function to Prepare an Init for a [brms] Model
#'
#' Normalizes a user supplied initial value into a zero-argument function, the
#' per-parameter init form used by \pkg{BayesPharma} model inits.
#'
#' @param init `function` returning a `numeric` `array` of length `1`, or a
#'   `numeric` value.
#' @returns `function` returning a `numeric` `array` of length `1`. If `init`
#'   is a function it is validated by calling it once and returned unchanged;
#'   if it is `numeric` it is wrapped so that it returns `as.array(init)`.
#'
#' @export
prepare_init <- function(init) {
  if (is.function(init)) {
    # Call the function once to check that it returns a one-dimensional
    # numeric array holding a single value.
    x <- init()
    assertthat::assert_that(inherits(x, "array"))
    assertthat::assert_that(length(dim(x)) == 1L && dim(x) == 1L)
    assertthat::assert_that(is.numeric(x))
    init_fn <- init
  } else if (is.numeric(init)) {
    init_fn <- function() {
      as.array(init)
    }
  } else {
    stop(
      "Initialization should either be a function that returns an array of ",
      "length one or numeric")
  }
  init_fn
}

#' Helper Function to Create the Default \pkg{rstan} Init
#'
#' @description By default, \pkg{rstan} will initialize parameters uniformly at
#' random in the range (-2, 2), on the unconstrained scale. How Stan transforms
#' parameters to satisfy constraints is described in the
#' [stan documentation](https://mc-stan.org/docs/reference-manual/variable-transforms.html)
#'
#' This helper is especially useful for running models using the \pkg{cmdstanr}
#' backend, which requires all parameters (including distributional parameters)
#' to be initialized.
#'
#' @param lb `numeric` lower bound for parameter, or `NULL` for no lower bound.
#' @param ub `numeric` upper bound for parameter, or `NULL` for no upper bound.
#' @param dim `numeric` dimension of parameter. `0` (the default) requests a
#'   scalar; any other value is used as the `dim` of the returned array.
#' @returns function that returns the default \pkg{rstan} initial value for a
#'   parameter. If `dim = 0`, then it will be a numeric scalar; if `dim = 1` or
#'   greater, then it returns an array with the given dimension whose elements
#'   are drawn independently.
#'
#' @seealso [rstan::stan]
#'
#' @examples
#' \dontrun{
#'
#' # Explicitly set the default initialization for the distributional parameter
#' # 'sigma' when family=gaussian().
#' init <- BayesPharma::sigmoid_antagonist_init(
#'   sigma = BayesPharma::rstan_default_init(lb = 0))
#'
#' }
#'
#' @export
rstan_default_init <- function(lb = NULL, ub = NULL, dim = 0) {
  # Map draws on the unconstrained scale onto the constrained scale using the
  # Stan variable transforms:
  # https://mc-stan.org/docs/reference-manual/logit-transform-jacobian.html
  # https://mc-stan.org/docs/reference-manual/lower-bound-transform.html
  # https://mc-stan.org/docs/reference-manual/upper-bounded-scalar.html
  if (!is.null(lb) && !is.null(ub)) {
    # lb + (ub - lb) * inv_logit(y), same as brms::inv_logit_scaled(y, lb, ub)
    constrain <- function(y) lb + (ub - lb) * stats::plogis(y)
  } else if (!is.null(lb)) {
    constrain <- function(y) exp(y) + lb
  } else if (!is.null(ub)) {
    constrain <- function(y) ub - exp(y)
  } else {
    constrain <- identity
  }

  if (length(dim) == 1L && dim == 0) {
    \() constrain(stats::runif(1, min = -2, max = 2))
  } else {
    # Stan initializes every element of a container independently, so draw
    # one value per element rather than recycling a single draw.
    n_values <- prod(dim)
    \() array(constrain(stats::runif(n_values, min = -2, max = 2)), dim = dim)
  }
}

#' Evaluate the Init for a Single Parameter (helper for `eval_init`)
#'
#' @param param_init `array`, `numeric`, or `function` to initialize parameter
#' @param param_name character parameter name defined in the stan model
#' @param sdata `list` generated by Stan for the model data info
#' @returns `numeric` or `array`.
#'   * An `array` `param_init` is returned unchanged.
#'   * If the parameter is of the form `b_<param>`, then \pkg{brms} expects an
#'     array of dimension `sdata$K_<param>`. A `numeric` `param_init` is
#'     repeated that many times, and a `function` is called once per element.
#'   * Otherwise, e.g. `sigma`, a `numeric` `param_init` is returned as is and
#'     a `function` is called once.
#'   An error is raised if a `b_<param>` has no matching `K_<param>` in `sdata`
#'   or if `param_init` has an unsupported class.
eval_init_one_param <- function(param_init, param_name, sdata) {
  if (isa(param_init, "array")) {
    # arrays are passed through as given
    return(param_init)
  }

  if (grepl("^b_", param_name)) {
    param_name_K <- sub("^b_", "K_", param_name)
    if (!(param_name_K %in% names(sdata))) {
      stop(
        "Unable to initialize parameter '", param_name, "'. ",
        "It starts with 'b_', so assumed to be a global parameter ",
        "value but in the stan data, doesn't have a record for ",
        "the dimension '", param_name_K, "'")
    }
    K <- sdata[[param_name_K]]
    if (is.numeric(param_init)) {
      array(rep(x = param_init, times = K))
    } else if (is.function(param_init)) {
      # call the function separately for each element so random inits differ
      array(unlist(lapply(seq_len(K), \(i) param_init())))
    } else {
      stop(
        "For initializing parameter '", param_name, "', ",
        "unrecognized class: '", paste(class(param_init), collapse = "', '"),
        "', expected one of [array, numeric, function]")
    }
  } else {
    # e.g. the parameter is 'sigma' and not a 'b_' parameter
    if (is.numeric(param_init)) {
      param_init
    } else if (is.function(param_init)) {
      param_init()
    } else {
      stop(
        "For initializing parameter '", param_name, "', ",
        "unrecognized class: '", paste(class(param_init), collapse = "', '"),
        "', expected one of [numeric, function]")
    }
  }
}

#' Evaluate an init
#'
#' @description How \pkg{brms} models can be initialized depends on the backend.
#' The method that all backends support is a list (one for each chain) of
#' lists (one for each variable) with numeric values. Since this requires
#' knowing how many chains are being run, which may not be available when the
#' model is being defined, and to support random initialization, the \pkg{rstan}
#' backend also supports initialization as a function returning a list of
#' functions (one for each parameter) returning a numeric array of length 1.
#' Also, to support the common use-case of initializing everything to zero or
#' randomly in the range (-2, 2) on the unconstrained scale, \pkg{rstan} also
#' supports initializing with `0` and `"random"`.
#'
#' To make \pkg{BayesPharma} more backend agnostic, this helper function takes
#' an init and the number of chains and reduces it to the list of lists
#' format.
#'
#' @param init One of
#'   * `NULL`, `numeric`, `character`: returned unchanged, so the backend's
#'     own handling (e.g. the default \pkg{rstan} init) applies.
#'   * `bpinit`: named `list` with one element for each parameter. The values
#'     can be either `array`, `numeric`, or a function returning a `numeric`
#'     value. It is evaluated once per chain.
#'   * `list` with one element per chain: returned unchanged after checking
#'     its length matches `chains`.
#' @param sdata result of running [brms::make_standata()], in particular it
#'   should be a list having elements `K_<parameter_name>` for each parameter
#'   in the model, where the value of these elements is the dimension of the
#'   parameter.
#' @param algorithm `character` string naming the estimation approach to use.
#'   See `brms::brm` for details. This is needed here because some algorithms
#'   can be run in parallel from different initialization points, which affects
#'   the dimension of the initial values if the `chains` parameter is `NULL`.
#'   If `NULL`, `getOption("brms.algorithm")` is used, falling back to
#'   `"sampling"`.
#' @param chains `numeric` number of chains for which to initialize. If
#'   `NULL`, it is deduced from `algorithm`.
#' @returns `list` of `list` form of model initialization, or `init` itself
#'   when it is `NULL`, `numeric`, or `character`.
#'
#' @export
eval_init <- function(
    init,
    sdata = NULL,
    algorithm = "sampling",
    chains = 4) {

  if (is.null(algorithm)) {
    # options("brms.algorithm") would return a list; getOption gives the value
    algorithm <- getOption("brms.algorithm", "sampling")
  }

  if (is.null(chains)) {
    if (algorithm %in% c("sampling", "pathfinder")) {
      chains <- 4
    } else if (algorithm %in% c("meanfield", "fullrank", "fixed_param")) {
      chains <- 1
    } else {
      chains <- 1
      warning(
        "Unrecognized algorithm '", algorithm, "' in trying to determine the ",
        "number of chains to initialize")
    }
  }

  if (is.null(init) || is.numeric(init) || is.character(init)) {
    # not sure if cmdstanr can support this type of initialization or not...
    init
  } else if (inherits(init, "bpinit")) {
    # evaluate separately for each chain so function-valued (random) inits
    # give each chain its own starting point
    purrr::map(
      .x = seq_len(chains),
      .f = \(chain_id) {
        purrr::imap(
          .x = init,
          .f = eval_init_one_param,
          sdata = sdata)
      })
  } else if (is.list(init)) {
    if (length(init) != chains) {
      stop(
        "The input init is a list with '", length(init), "' elements. It ",
        "should have instead '", chains, "' to match the number of chains.")
    }
    init
  } else {
    stop(
      "Unrecognized class for init ['",
      paste(class(init), collapse = "', '"),
      "'], expected one of [NULL, numeric, character, bpinit, list]")
  }
}


#' Helper Function to Prepare a Prior for a \pkg{brms} Model
#'
#' @description This extends [brms::prior()] by
#'
#'   1) allowing just taking a `numeric` value rather than `constant(<value>)`
#'      to specify a constant prior
#'   2) if [brms::brmsprior] is given, it checks that it has the specified
#'      arguments
#'
#' This is used in building \pkg{BayesPharma} models to allow user specified
#' priors but make sure they are for the right parameters to make sure the
#' model is well specified.
#'
#' @param prior [brms::brmsprior] or a single non-missing `numeric` value.
#'   `Inf` and `-Inf` become `constant(infinity())` and
#'   `constant(negative_infinity())` respectively.
#' @param ... additional arguments to [brms::prior_string()]. If `prior` is a
#'   [brms::brmsprior] then this will check that the fields have the given
#'   values. If prior is `numeric`, then these arguments are passed to
#'   [brms::prior_string()]
#'
#'
#' @examples
#' \dontrun{
#'   # user should specify a prior for hill, but they misspell it:
#'   user_hill_prior <- brms::prior(
#'     prior = normal(1, 1),
#'     nlpar = "hilll",
#'     ub = 0)
#'
#'   # in a script where we want to validate the user_hill_prior
#'   hill_prior <- BayesPharma:::prepare_prior(
#'     prior = user_hill_prior,
#'     nlpar = "hill")
#'
#'   # gives an assert error that nlpar is not set correctly
#'
#' }
#'
#' @returns [brms::brmsprior]
prepare_prior <- function(prior, ...) {
  if (inherits(prior, "brmsprior")) {
    args <- list(...)
    for (arg in names(args)) {
      field <- prior[[arg]]
      expected <- args[[arg]]
      # require the field to exist and every row to match, so that missing
      # fields, NAs, or multi-row priors give a clear assertion message
      assertthat::assert_that(
        length(field) > 0 && isTRUE(all(field == expected)),
        msg = paste0(
          "In the given prior, the field '", arg, "' is expected to be '",
          paste(expected, collapse = "', '"), "', but instead it is '",
          paste(field, collapse = "', '"), "'"))
    }
  } else if (is.numeric(prior)) {
    if (length(prior) != 1L || is.na(prior)) {
      stop("prior must be a single, non-missing numeric value")
    }
    if (prior == -Inf) {
      prior <- brms::prior_string("constant(negative_infinity())", ...)
    } else if (prior == Inf) {
      prior <- brms::prior_string("constant(infinity())", ...)
    } else {
      prior <- brms::prior_string(
        prior = paste0("constant(", prior, ")"),
        ...)
    }
  } else {
    stop("prior must be a brms::prior(...) or numeric value")
  }
  prior
}

#' Get the Treatment Variable from a BayesPharma Model
#'
#' Reads `model$bayes_pharma_info$formula_info$treatment_variable`. A warning
#' is given if `model` is not a `bpfit`.
#'
#' @param model `bpfit` object resulting from fitting a model with one of the
#'   model functions from the [BayesPharma] package.
#'
#' @returns `character` with the treatment variable. An error is raised if it
#'   is not recorded on the model or is not a column of `model$data`.
#'
#' @export
get_treatment_variable <- function(model) {
  if (!inherits(model, "bpfit")) {
    warning(
      "get_treatment_variable expects model to be of class 'bpfit',",
      " instead it is of class ", paste(class(model), collapse = ", "))
  }

  # use [[ ]] rather than $ to avoid partial name matching
  treatment_variable <-
    model[["bayes_pharma_info"]][["formula_info"]][["treatment_variable"]]

  if (is.null(treatment_variable)) {
    stop(paste0(
      "Expected treatment_variable to be defined in the ",
      "model$bayes_pharma_info"))
  }

  if (!(treatment_variable %in% names(model[["data"]]))) {
    stop(paste0(
      "Expected the treatment variable '", treatment_variable, "' to be a ",
      "column in the model$data, but instead it has columns ",
      "[", paste0(names(model[["data"]]), collapse = ", "), "]"))
  }

  treatment_variable
}

#' Get the Treatment Variable Units from a BayesPharma Model
#'
#' Reads `model$bayes_pharma_info$formula_info$treatment_units`. A warning is
#' given if `model` is not a `bpfit`.
#'
#' @param model `bpfit` object resulting from fitting a model with one of the
#'   model functions from the [BayesPharma] package.
#'
#' @returns `character` with the treatment variable units. An error is raised
#'   if they are not recorded on the model.
#'
#' @export
get_treatment_units <- function(model) {
  if (!inherits(model, "bpfit")) {
    warning(
      "get_treatment_units expects model to be of class 'bpfit',",
      " instead it is of class ", paste(class(model), collapse = ", "))
  }

  # use [[ ]] rather than $ to avoid partial name matching
  treatment_units <-
    model[["bayes_pharma_info"]][["formula_info"]][["treatment_units"]]

  if (is.null(treatment_units)) {
    stop(paste0(
      "Expected treatment_units to be defined in the ",
      "model$bayes_pharma_info"))
  }

  treatment_units
}

#' Get the Response Variable from a BayesPharma Model
#'
#' Reads `model$bayes_pharma_info$formula_info$response_variable`. A warning
#' is given if `model` is not a `bpfit`.
#'
#' @param model `bpfit` object resulting from fitting a model with one of the
#'   model functions from the [BayesPharma] package.
#'
#' @returns `character` with the response variable. An error is raised if it
#'   is not recorded on the model or is not a column of `model$data`.
#'
#' @export
get_response_variable <- function(model) {
  if (!inherits(model, "bpfit")) {
    warning(
      "get_response_variable expects model to be of class 'bpfit',",
      " instead it is of class ", paste(class(model), collapse = ", "))
  }

  # use [[ ]] rather than $ to avoid partial name matching
  response_variable <-
    model[["bayes_pharma_info"]][["formula_info"]][["response_variable"]]

  if (is.null(response_variable)) {
    stop(paste0(
      "Expected response_variable to be defined in the ",
      "model$bayes_pharma_info"))
  }

  if (!(response_variable %in% names(model[["data"]]))) {
    stop(paste0(
      "Expected the response variable '", response_variable, "' to be a ",
      "column in the model$data, but instead it has columns ",
      "[", paste0(names(model[["data"]]), collapse = ", "), "]"))
  }

  response_variable
}

#' Get the Response Variable Units from a BayesPharma Model
#'
#' Reads `model$bayes_pharma_info$formula_info$response_units`. A warning is
#' given if `model` is not a `bpfit`.
#'
#' @param model `bpfit` object resulting from fitting a model with one of the
#'   model functions from the [BayesPharma] package.
#'
#' @returns `character` with the response variable units, or `NULL` if they
#'   are not recorded on the model (no error is raised in that case).
#'
#' @export
get_response_units <- function(model) {
  if (!inherits(model, "bpfit")) {
    warning(
      "get_response_units expects model to be of class 'bpfit',",
      " instead it is of class ", paste(class(model), collapse = ", "))
  }

  # use [[ ]] rather than $ to avoid partial name matching
  model[["bayes_pharma_info"]][["formula_info"]][["response_units"]]
}
maomlab/BayesPharma documentation built on Aug. 24, 2024, 8:45 a.m.