R/trunc_gamma_para.R

Defines functions trunc_gamma_para

Documented in trunc_gamma_para

# Register `L`, `U` and the foreach loop index `i` as known global variables so
# that R CMD check does not emit "no visible binding" notes for them (`i` is
# bound by foreach() rather than by ordinary assignment).
# utils::globalVariables() is only available from R 2.15.1 onwards, hence the
# version guard.
if (getRversion() >= package_version("2.15.1")) {
  utils::globalVariables(names = c("L", "U", "i"))
}
#' Estimate Shape and Scale Parameters for Truncated Gamma Distribution
#'
#' This function estimates the shape and scale parameters of a Gamma
#' distribution truncated to \code{[L, U]}. The estimate is based on
#' expert-provided summary statistics, including the mean, median, standard
#' deviation and selected quantiles.
#'
#' Estimation is a grid search over 40 x 40 candidate parameter pairs: shape in
#' \code{[0.1, 20]} and rate in \code{[0.1, 10]}. For each candidate, 1000
#' draws are simulated from the truncated Gamma distribution. The draws are
#' scored against every expert's statistics with a weighted least-squares
#' criterion, and the per-expert scores are summed. The candidate with the
#' smallest total error is returned. The grid is evaluated in parallel on a
#' local cluster. The cluster is always shut down when the function exits,
#' even on error, and the foreach backend is then reset to sequential.
#' @import doParallel
#' @importFrom foreach foreach %dopar%
#' @importFrom truncdist rtrunc
#' @importFrom stats qnorm median sd quantile
#' @importFrom utils tail
#' @importFrom parallel makeCluster clusterSetRNGStream stopCluster
#' @importFrom invgamma rinvgamma
#' @param L Numeric scalar. Lower bound of the truncated Gamma distribution.
#' @param U Numeric scalar. Upper bound of the truncated Gamma distribution;
#'   must be strictly greater than \code{L}.
#' @param expert_data A non-empty list of named lists, where each inner list
#'   represents one expert's input. Each expert can provide any subset of the
#'   following named elements:
#'   \describe{
#'     \item{\code{mean}}{Numeric. The expected mean of the distribution.}
#'     \item{\code{median}}{Numeric. The expected median of the distribution.}
#'     \item{\code{sd}}{Numeric. The expected standard deviation of the distribution.}
#'     \item{\code{q25}}{Numeric. The 2.5th percentile.}
#'     \item{\code{q975}}{Numeric. The 97.5th percentile.}
#'   }
#' @param weights Numeric vector of length 5. Specifies the relative importance
#'   of each summary statistic in the optimization procedure. The order
#'   corresponds to \code{c(mean, median, sd, q25, q975)}. Default is
#'   \code{c(10, 10, 2, 1, 1)}.
#' @param num_cores Integer. Number of worker processes to use for parallel
#'   computation (at least 1). Default is \code{4}.
#' @param seed Optional integer. If provided, \code{set.seed(seed)} is called in
#'   the calling session and the cluster RNG streams are seeded with it. Grid
#'   point \code{i} is then simulated after \code{set.seed(seed + i)}, which
#'   makes the result reproducible.
#' @return A list with the following components:
#' \describe{
#'   \item{shape}{Numeric. Estimated shape parameter of the Gamma distribution.}
#'   \item{scale}{Numeric. Estimated scale parameter of the Gamma distribution
#'     (the reciprocal of the selected rate).}
#' }
#' @examples
#' # Define expert-provided summary data
#' expert_data_correct <- list(
#'   list(mean = 2.2, median = 2.27, sd = NULL, q25 = NULL, q975 = NULL),  # Expert A
#'   list(mean = 2.1, median = 2.3,  sd = NULL, q25 = NULL, q975 = NULL),  # Expert B
#'   list(mean = NULL, median = 2.31, sd = NULL, q25 = NULL, q975 = NULL)  # Expert C
#' )
#' \donttest{
#'   # Estimate parameters using truncated gamma prior
#'   trunc_gamma_para(L = 2,U = 2.5,expert_data = expert_data_correct,num_cores = 4)
#' }
#' @export
trunc_gamma_para <- function(L, U, expert_data, weights = c(10, 10, 2, 1, 1),
                             num_cores = 4, seed = NULL) {
  # ---- Input validation ----
  # Validate before any side effect (set.seed, spawning workers) so bad input
  # fails fast and leaves the session untouched.
  if (!is.list(expert_data) ||
      !all(vapply(expert_data, is.list, logical(1)))) {
    stop("Error: expert_data should be a list of lists, where each list contains only 'mean', 'median', 'sd', 'q25', and 'q975'.")
  }
  # An empty list passes the check above (all() of nothing is TRUE) but would
  # later fail inside the workers with an unhelpful sum(list()) error.
  if (length(expert_data) == 0L) {
    stop("Error: expert_data must contain at least one expert.")
  }

  valid_keys <- c("mean", "median", "sd", "q25", "q975")

  for (expert in expert_data) {
    if (!all(names(expert) %in% valid_keys)) {
      stop("Error: Each expert's data should only contain 'mean', 'median', 'sd', 'q25', and 'q975'.")
    }
    is_num_or_null <- vapply(expert, function(x) is.numeric(x) || is.null(x),
                             logical(1))
    if (!all(is_num_or_null)) {
      stop("Error: All values in expert_data should be numeric or NULL.")
    }
  }

  if (!is.numeric(L) || !is.numeric(U) || length(L) != 1L ||
      length(U) != 1L || is.na(L) || is.na(U) || L >= U) {
    stop("Error: L and U must be single numeric values with L < U.")
  }
  if (!is.numeric(weights) || length(weights) != 5L) {
    stop("Error: weights should be a numeric vector of length 5.")
  }
  if (!is.numeric(num_cores) || length(num_cores) != 1L ||
      is.na(num_cores) || num_cores < 1) {
    stop("Error: num_cores should be a single positive integer.")
  }

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # ---- Candidate grid: every (shape, rate) combination ----
  shape_grid <- seq(0.1, 20, length.out = 40)
  rate_grid <- seq(0.1, 10, length.out = 40)
  param_grid <- expand.grid(shape = shape_grid, rate = rate_grid)

  # ---- Parallel backend ----
  cl <- makeCluster(num_cores)
  # Always release the workers, even if simulation or scoring fails. Also stop
  # the foreach backend from pointing at a stopped cluster, which would make
  # the caller's later %dopar% calls fail.
  on.exit(stopCluster(cl), add = TRUE)
  clusterSetRNGStream(cl, iseed = seed)
  registerDoParallel(cl)
  on.exit(foreach::registerDoSEQ(), add = TRUE)

  # ---- Score every grid point ----
  # Each row of `results` is c(shape, rate, total weighted error).
  results <- foreach(i = seq_len(nrow(param_grid)), .combine = rbind,
                     .packages = "truncdist",
                     .export = c("compute_wls_error")) %dopar% {
    if (!is.null(seed)) {
      set.seed(seed + i)  # per-grid-point seed so each i is reproducible
    }
    shape <- param_grid$shape[i]
    rate <- param_grid$rate[i]

    # Simulate from the Gamma distribution truncated to [L, U].
    sim_data <- rtrunc(1000, "gamma", a = L, b = U, shape = shape, rate = rate)

    # Sum the weighted discrepancy over all experts.
    total_wls_error <- sum(sapply(expert_data, function(expert) {
      compute_wls_error(sim_data, expert, weights)
    }))

    c(shape, rate, total_wls_error)
  }

  # ---- Pick the best candidate ----
  # which.min() ignores NA/NaN and returns integer(0) when nothing is left.
  best_index <- which.min(results[, 3])
  if (length(best_index) == 0L) {
    stop("Error: no grid point produced a non-missing error; check expert_data and the bounds L and U.")
  }
  best_params <- results[best_index, 1:2]

  # The grid is parameterised by rate; report scale = 1 / rate.
  list(shape = best_params[1], scale = 1 / best_params[2])
}

Try the DTEBOP2 package in your browser

Any scripts or data that you put into this service are public.

DTEBOP2 documentation built on June 8, 2025, 1:24 p.m.