# R/stan-prior.R
#
# Defines functions stan_lpdf_name stan_extract_bounds stan_is_constant_prior stan_rngprior stan_unchecked_prior stan_special_prior_local stan_special_prior_global stan_constant_prior stan_target_prior stan_base_prior stan_prior

# unless otherwise specified, functions return a single character
# string defining (part of) the priors of the model in Stan language

# Define priors for parameters in Stan language
# @param prior an object of class 'brmsprior'
# @param class the parameter class
# @param coef the coefficients of this class; if a matrix is passed,
#   coefficients are addressed by a two-dimensional [i, j] index
# @param group the name of a grouping factor
# @param type Stan type used in the definition of the parameter
#   if type is empty the parameter is not initialized inside 'stan_prior'
# @param dim stan array dimension to be specified after the parameter name
#   cannot be merged with 'suffix' as the latter should apply to
#   individual coefficients while 'dim' should not
#   TODO: decide whether to support arrays for parameters at all
#   an alternative would be to specify elements directly as parameters
# @param coef_type Stan type used in the definition of individual parameter
#   coefficients; only relevant when mixing estimated and fixed coefficients
# @param prefix a prefix to put at the parameter class
# @param suffix a suffix to put at the parameter class
# @param broadcast Stan type to which the prior should be broadcasted
#   in order to handle vectorized prior statements
#   supported values are 'vector' or 'matrix'
# @param header_type if non-empty, the parameter is additionally appended
#   to 'pll_args' as ", {header_type} {par}"
# @param comment character string containing a comment for the parameter
# @param px list or data.frame after which to subset 'prior'
# @param normalize passed on to 'stan_target_prior'; if FALSE the
#   unnormalized '_lupdf' suffix is used in prior statements
# @return a named list of character strings in Stan language; elements that
#   may be set are 'par' (parameters block declarations), 'tpar_def' and
#   'tpar_prior' (transformed parameters definitions and assignments),
#   'prior' (prior statements prefixed by tp()) and 'pll_args'
# Errors if a special prior is used here, if a parameter must be fixed but
# 'type' is empty, or if only prior draws are requested but the parameter
# has no proper prior.
stan_prior <- function(prior, class, coef = NULL, group = NULL,
                       type = "real", dim = "", coef_type = "real",
                       prefix = "", suffix = "", broadcast = "vector",
                       header_type = "", comment = "", px = list(),
                       normalize = TRUE) {
  # read before subsetting (subset2 may not keep attributes -- unverified)
  prior_only <- isTRUE(attr(prior, "sample_prior") == "only")
  prior <- subset2(
    prior, class = class, coef = c(coef, ""),
    group = c(group, ""), ls = px
  )
  # special priors cannot be passed literally to Stan
  is_special_prior <- is_special_prior(prior$prior)
  if (any(is_special_prior)) {
    special_prior <- prior$prior[is_special_prior]
    stop2("Prior ", collapse_comma(special_prior), " is used in an invalid ",
          "context. See ?set_prior for details on how to use special priors.")
  }

  px <- as.data.frame(px, stringsAsFactors = FALSE)
  upx <- unique(px)
  if (nrow(upx) > 1L) {
    # TODO: find a better solution to handle this case
    # can only happen for SD parameters of the same ID
    base_prior <- rep(NA, nrow(upx))
    for (i in seq_rows(upx)) {
      # 'drop = FALSE' keeps column names even if 'upx' has a single column
      sub_upx <- lapply(upx[i, , drop = FALSE], function(x) c(x, ""))
      sub_prior <- subset2(prior, ls = sub_upx)
      base_prior[i] <- stan_base_prior(sub_prior)
    }
    if (length(unique(base_prior)) > 1L) {
      # define prior for single coefficients manually
      # as there is not single base_prior anymore;
      # only fill coefficients without their own prior so that
      # user-specified coefficient priors are not overwritten
      take_coef <- nzchar(prior$coef) & !nzchar(prior$prior)
      prior_of_coefs <- prior[take_coef, vars_prefix()]
      take <- match_rows(prior_of_coefs, upx)
      prior[take_coef, "prior"] <- base_prior[take]
    }
    # the first base prior serves as fallback for coefficients still
    # lacking a prior; no global bound is used in this case
    base_prior <- base_prior[1]
    bound <- ""
  } else {
    base_prior <- stan_base_prior(prior)
    bound <- prior[!nzchar(prior$coef), "bound"]
  }

  # generate stan prior statements
  out <- list()
  par <- paste0(prefix, class, suffix)
  has_constant_priors <- FALSE
  has_coef_prior <- any(with(prior, nzchar(coef) & nzchar(prior)))
  if (has_coef_prior || nzchar(dim) && length(coef)) {
    # priors on individual coefficients are also individually set
    # priors are always set on individual coefficients for arrays
    index_two_dims <- is.matrix(coef)
    coef <- as.matrix(coef)
    prior <- subset2(prior, coef = coef)
    estimated_coef_indices <- list()
    used_base_prior <- FALSE
    for (i in seq_rows(coef)) {
      for (j in seq_cols(coef)) {
        index <- i
        if (index_two_dims) {
          c(index) <- j
        }
        prior_ij <- subset2(prior, coef = coef[i, j])
        if (NROW(px) > 1L) {
          # disambiguate priors of coefficients with the same name
          # coming from different model components
          stopifnot(NROW(px) == NROW(coef))
          # 'drop = FALSE' keeps a named one-row data.frame for 'ls'
          prior_ij <- subset2(prior_ij, ls = px[i, , drop = FALSE])
        }
        # zero rows can happen if only global priors present
        stopifnot(nrow(prior_ij) <= 1L)
        coef_prior <- prior_ij$prior
        if (!isTRUE(nzchar(coef_prior))) {
          used_base_prior <- TRUE
          coef_prior <- base_prior
        }
        if (!stan_is_constant_prior(coef_prior)) {
          # all parameters with non-constant priors are estimated
          c(estimated_coef_indices) <- list(index)
        }
        if (nzchar(coef_prior)) {
          # implies a proper prior or constant
          if (type == coef_type && !nzchar(dim)) {
            # the single coefficient of that parameter equals the parameter
            stopifnot(all(index == 1L))
            par_ij <- par
          } else {
            par_ij <- paste0(par, collapse("[", index, "]"))
          }
          if (stan_is_constant_prior(coef_prior)) {
            coef_prior <- stan_constant_prior(
              coef_prior, par_ij, broadcast = broadcast
            )
            str_add(out$tpar_prior) <- paste0(coef_prior, ";\n")
          } else {
            coef_prior <- stan_target_prior(
              coef_prior, par_ij, broadcast = broadcast,
              bound = bound, resp = px$resp[1], normalize = normalize
            )
            str_add(out$prior) <- paste0(tp(), coef_prior, ";\n")
          }
        }
      }
    }
    # the base prior may be improper flat in which no Stan code is added
    # but we still have estimated coefficients if the base prior is used
    has_estimated_priors <- isTRUE(nzchar(out$prior)) ||
      used_base_prior && !stan_is_constant_prior(base_prior)
    has_constant_priors <- isTRUE(nzchar(out$tpar_prior))
    if (has_estimated_priors && has_constant_priors) {
      # need to mix definition in the parameters and transformed parameters block
      if (!nzchar(coef_type)) {
        stop2("Can either estimate or fix all values of parameter '", par, "'.")
      }
      # estimated coefficients get their own scalar parameter which is
      # then copied into the transformed parameter
      for (i in seq_along(estimated_coef_indices)) {
        index <- estimated_coef_indices[[i]]
        iu <- paste0(index, collapse = "_")
        str_add(out$par) <- glue(
          "  {coef_type} par_{par}_{iu};\n"
        )
        ib <- collapse("[", index, "]")
        str_add(out$tpar_prior) <- cglue(
          "  {par}{ib} = par_{par}_{iu};\n"
        )
      }
    }
  } else if (nzchar(base_prior)) {
    # only a global prior is present and will be broadcasted
    ncoef <- length(coef)
    has_constant_priors <- stan_is_constant_prior(base_prior)
    if (has_constant_priors) {
      constant_base_prior <- stan_constant_prior(
        base_prior, par = par, ncoef = ncoef, broadcast = broadcast
      )
      str_add(out$tpar_prior) <- paste0(constant_base_prior, ";\n")
    } else {
      target_base_prior <- stan_target_prior(
        base_prior, par = par, ncoef = ncoef, bound = bound,
        broadcast = broadcast, resp = px$resp[1], normalize = normalize
      )
      str_add(out$prior) <- paste0(tp(), target_base_prior, ";\n")
    }
  }

  if (nzchar(type)) {
    # only define the parameter here if type is non-empty
    comment <- stan_comment(comment)
    par_definition <- glue("  {type} {par}{dim};{comment}\n")
    if (has_constant_priors) {
      # parameter must be defined in the transformed parameters block
      str_add(out$tpar_def) <- par_definition
    } else {
      # parameter can be defined in the parameters block
      str_add(out$par) <- par_definition
    }
    if (nzchar(header_type)) {
      str_add(out$pll_args) <- glue(", {header_type} {par}")
    }
  } else {
    if (has_constant_priors) {
      stop2("Cannot fix parameter '", par, "' in this model.")
    }
  }
  # a declared parameter without any prior statement has an improper prior
  has_improper_prior <- !is.null(out$par) && is.null(out$prior)
  if (prior_only && has_improper_prior) {
    stop2("Sampling from priors is not possible as ",
          "some parameters have no proper priors. ",
          "Error occurred for parameter '", par, "'.")
  }
  out
}

# Determine the base prior shared by all coefficients of a parameter class,
# i.e. the most specific prior that is not tied to a single coefficient.
# Rows with an empty 'coef' and a non-empty 'prior' are candidates; the
# candidate set is then narrowed step by step to rows with a non-empty
# 'group', 'nlpar', 'dpar', 'resp' and 'class' (whenever such rows exist).
# @param prior a brmsprior object containing at most one class
# @return a character string defining the base prior ("" if none exists)
stan_base_prior <- function(prior) {
  stopifnot(length(unique(prior$class)) <= 1)
  is_candidate <- !nzchar(prior$coef) & nzchar(prior$prior)
  candidates <- prior[is_candidate, ]
  if (!NROW(candidates)) {
    return("")
  }
  for (column in c("group", "nlpar", "dpar", "resp", "class")) {
    is_specific <- nzchar(candidates[[column]])
    if (any(is_specific)) {
      candidates <- candidates[is_specific, ]
    }
  }
  stopifnot(NROW(candidates) == 1)
  candidates$prior
}

# Stan prior in target += notation
# @param prior character string defining the prior, e.g. "normal(0, 1)";
#   whitespace directly before '(' is removed first
# @param par name of the parameter on which to set the prior
# @param ncoef number of coefficients in the parameter
# @param broadcast Stan type to which the prior should be broadcasted;
#   for "matrix" (with ncoef > 0) 'par' is wrapped in 'to_vector()'
# @param bound bounds of the parameter in Stan language
# @param resp name of the response variable
# @param normalize if TRUE, use the normalized '_lpdf' suffix and append
#   truncation corrections ('_lccdf' / '_lcdf' terms) for bounded
#   parameters; if FALSE, use '_lupdf' and add no correction
# @return a character string defining the prior in Stan language
stan_target_prior <- function(prior, par, ncoef = 0, broadcast = "vector",
                              bound = "", resp = "", normalize = TRUE) {
  prior <- gsub("[[:space:]]+\\(", "(", prior)
  # distribution name: everything before the first '('
  prior_name <- get_matches(
    "^[^\\(]+(?=\\()", prior, perl = TRUE, simplify = FALSE
  )
  for (i in seq_along(prior_name)) {
    if (length(prior_name[[i]]) != 1L) {
      stop2("The prior '", prior[i], "' is invalid.")
    }
  }
  prior_name <- unlist(prior_name)
  # raw arguments: strip the leading 'name(' and the trailing ')'
  prior_args <- rep(NA, length(prior))
  for (i in seq_along(prior)) {
    prior_args[i] <- sub(glue("^{prior_name[i]}\\("), "", prior[i])
    prior_args[i] <- sub(")$", "", prior_args[i])
  }
  if (broadcast == "matrix" && ncoef > 0) {
    # apply a scalar prior to all elements of a matrix
    par <- glue("to_vector({par})")
  }
  
  # NOTE(review): 'if' needs a single value here, so effectively only one
  # prior is supported per call (R >= 4.2 errors on longer conditions)
  if (nzchar(prior_args)) {
    str_add(prior_args, start = TRUE) <- " | "
  }
  lpdf <- stan_lpdf_name(normalize)
  out <- glue("{prior_name}_{lpdf}({par}{prior_args})")
  # NOTE(review): if 'par' was wrapped in 'to_vector(' above, this regex
  # yields "to" rather than the parameter class; confirm that par_bounds
  # handles this as intended
  par_class <- unique(get_matches("^[^_]+", par))
  par_bound <- par_bounds(par_class, bound, resp = resp)
  prior_bound <- prior_bounds(prior_name)
  # truncation applies if a bound is a character (Stan) expression or is
  # tighter than the corresponding bound reported by prior_bounds()
  trunc_lb <- is.character(par_bound$lb) || par_bound$lb > prior_bound$lb
  trunc_ub <- is.character(par_bound$ub) || par_bound$ub < prior_bound$ub
  if (normalize) {
    # obtain correct normalization constants for truncated priors
    if (trunc_lb || trunc_ub) {
      wsp <- wsp(nsp = 4)
      # scalar parameters are of length 1 but have no coefficients
      ncoef <- max(1, ncoef)
      if (trunc_lb && !trunc_ub) {
        str_add(out) <- glue(
          "\n{wsp}- {ncoef} * {prior_name}_lccdf({par_bound$lb}{prior_args})"
        )
      } else if (!trunc_lb && trunc_ub) {
        str_add(out) <- glue(
          "\n{wsp}- {ncoef} * {prior_name}_lcdf({par_bound$ub}{prior_args})"
        )
      } else if (trunc_lb && trunc_ub) {
        # log of the probability mass between both bounds
        str_add(out) <- glue(
          "\n{wsp}- {ncoef} * log_diff_exp(", 
          "{prior_name}_lcdf({par_bound$ub}{prior_args}), ",
          "{prior_name}_lcdf({par_bound$lb}{prior_args}))"
        )
      }
    }
  }
  out
}

# Fix a parameter to a constant value in Stan language.
# The value inside 'constant(...)' is assigned to 'par', broadcast with
# rep_vector / rep_matrix / rep_row_vector depending on 'broadcast' and
# whether the whole parameter (ncoef > 0) or a single coefficient is fixed.
# @param prior character string of the form "constant(<value>)"
# @param par name of the parameter on which to set the prior
# @param ncoef number of coefficients in the parameter
# @param broadcast Stan type to which the prior should be broadcasted
# @return a character string assigning the constant in Stan language
stan_constant_prior <- function(prior, par, ncoef = 0, broadcast = "vector") {
  stopifnot(grepl("^constant\\(", prior))
  value <- gsub("(^constant\\()|(\\)$)", "", prior)
  if (broadcast == "matrix") {
    if (ncoef > 0) {
      # the whole parameter matrix is filled with the scalar
      value <- glue("rep_matrix({value}, rows({par}), cols({par}))")
    } else {
      # a single coefficient corresponds to a row of the parameter matrix
      value <- glue("rep_row_vector({value}, cols({par}))")
    }
  } else if (broadcast == "vector" && ncoef > 0) {
    # the whole parameter vector is filled with the scalar;
    # single vector coefficients need no broadcasting
    value <- glue("rep_vector({value}, rows({par}))")
  }
  glue("  {par} = {value}")
}

# Stan code for global parameters of special priors
# currently implemented are horseshoe, R2D2 and lasso
# @param bterms object from which the parameter prefix is derived
#   via 'check_prefix'
# @param data not referenced in this function body
# @param prior a brmsprior object
# @param normalize include normalizing constants in prior statements?
# @param ... not referenced in this function body
# @return a named list of Stan code snippets with possible elements
#   'data', 'par', 'tpar_def', 'tpar_comp' and 'prior'
stan_special_prior_global <- function(bterms, data, prior, normalize, ...) {
  out <- list()
  tp <- tp()
  lpdf <- stan_lpdf_name(normalize)
  px <- check_prefix(bterms)
  p <- usc(combine_prefix(px))
  special <- get_special_prior(prior, px)
  if (!is.null(special$horseshoe)) {
    str_add(out$data) <- glue(
      "  // data for the horseshoe prior\n",
      "  real<lower=0> hs_df{p};  // local degrees of freedom\n",
      "  real<lower=0> hs_df_global{p};  // global degrees of freedom\n",
      "  real<lower=0> hs_df_slab{p};  // slab degrees of freedom\n",
      "  real<lower=0> hs_scale_global{p};  // global prior scale\n",
      "  real<lower=0> hs_scale_slab{p};  // slab prior scale\n"
    )
    str_add(out$par) <- glue(
      "  // horseshoe shrinkage parameters\n",
      "  real<lower=0> hs_global{p};  // global shrinkage parameters\n",
      "  real<lower=0> hs_slab{p};  // slab regularization parameter\n"
    )
    # optionally scale the global prior scale by the residual SD
    hs_scale_global <- glue("hs_scale_global{p}")
    if (isTRUE(special$horseshoe$autoscale)) {
      str_add(hs_scale_global) <- glue(" * sigma{usc(px$resp)}")
    }
    # the '- 1 * log(0.5)' term accounts for truncating the Student-t
    # prior at zero ('hs_global' is declared with lower=0)
    str_add(out$prior) <- glue(
      "{tp}student_t_{lpdf}(hs_global{p} | hs_df_global{p}, 0, {hs_scale_global})",
      str_if(normalize, "\n    - 1 * log(0.5)"), ";\n",
      "{tp}inv_gamma_{lpdf}(hs_slab{p} | 0.5 * hs_df_slab{p}, 0.5 * hs_df_slab{p});\n"
    )
  }
  if (!is.null(special$R2D2)) {
    str_add(out$data) <- glue(
      "  // data for the R2D2 prior\n",
      "  real<lower=0> R2D2_mean_R2{p};  // mean of the R2 prior\n",
      "  real<lower=0> R2D2_prec_R2{p};  // precision of the R2 prior\n"
    )
    str_add(out$par) <- glue(
      "  // R2D2 shrinkage parameters\n",
      "  real<lower=0,upper=1> R2D2_R2{p};  // R2 parameter\n"
    )
    # 'var_mult' is always interpolated below and must therefore be defined
    # even without autoscaling (in which case tau2 is not multiplied by
    # the residual variance)
    var_mult <- ""
    if (isTRUE(special$R2D2$autoscale)) {
      var_mult <- glue("sigma{usc(px$resp)}^2 * ")
    }
    str_add(out$tpar_def) <- glue(
      "  real R2D2_tau2{p};  // global R2D2 scale parameter\n"
    )
    str_add(out$tpar_comp) <- glue(
      "  R2D2_tau2{p} = {var_mult}R2D2_R2{p} / (1 - R2D2_R2{p});\n"
    )
    str_add(out$prior) <- glue(
      "{tp}beta_{lpdf}(R2D2_R2{p} | R2D2_mean_R2{p} * R2D2_prec_R2{p}, ",
      "(1 - R2D2_mean_R2{p}) * R2D2_prec_R2{p});\n"
    )
  }
  if (!is.null(special$lasso)) {
    str_add(out$data) <- glue(
      "  // data for the lasso prior\n",
      "  real<lower=0> lasso_df{p};  // prior degrees of freedom\n",
      "  real<lower=0> lasso_scale{p};  // prior scale\n"
    )
    str_add(out$par) <- glue(
      "  // lasso shrinkage parameter\n",
      "  real<lower=0> lasso_inv_lambda{p};\n"
    )
    str_add(out$prior) <- glue(
      "{tp}chi_square_{lpdf}(lasso_inv_lambda{p} | lasso_df{p});\n"
    )
  }
  out
}

# Stan code for local parameters of special priors
# currently implemented are 'horseshoe' and 'R2D2'
# @param prior a brmsprior object
# @param class name of the parameter class; one of "b" or "bsp"
#   (R2D2 currently supports "b" only)
# @param ncoef number of coefficients in the parameter
#   (not referenced in this function body)
# @param px named list to subset 'prior'
# @param center_X is the design matrix centered?
# @param suffix optional suffix of the 'b' coefficient vector
# @param normalize include normalizing constants in prior statements?
# @return a named list of Stan code snippets with possible elements
#   'data', 'par', 'tpar_reg_prior' and 'prior'
stan_special_prior_local <- function(prior, class, ncoef, px,
                                     center_X = FALSE, suffix = "",
                                     normalize = TRUE) {
  class <- as_one_character(class)
  stopifnot(class %in% c("b", "bsp"))
  out <- list()
  lpdf <- stan_lpdf_name(normalize)
  p <- usc(combine_prefix(px))
  # class-specific suffix: "" + p for 'b' and "sp" + p for 'bsp'
  sp <- paste0(sub("^b", "", class), p)
  ct <- str_if(center_X, "c")
  tp <- tp()
  special <- get_special_prior(prior, px)
  if (!is.null(special$horseshoe)) {
    str_add(out$par) <- glue(
      "  // local parameters for horseshoe prior\n",
      "  vector[K{ct}{sp}] zb{sp};\n",
      "  vector<lower=0>[K{ct}{sp}] hs_local{sp};\n"
    )
    hs_args <- sargs(
      glue("zb{sp}"), glue("hs_local{sp}"), glue("hs_global{p}"),
      glue("hs_scale_slab{p}^2 * hs_slab{p}")
    )
    str_add(out$tpar_reg_prior) <- glue(
      "  // compute actual regression coefficients\n",
      "  b{sp}{suffix} = horseshoe({hs_args});\n"
    )
    # the log(0.5) term accounts for 'hs_local' being declared with lower=0
    str_add(out$prior) <- glue(
      "{tp}std_normal_{lpdf}(zb{sp});\n",
      "{tp}student_t_{lpdf}(hs_local{sp} | hs_df{p}, 0, 1)",
      str_if(normalize, "\n    - rows(hs_local{sp}) * log(0.5)"), ";\n"
    )
  }
  if (!is.null(special$R2D2)) {
    if (class != "b") {
      stop2("The R2D2 prior does not yet support special coefficient classes.")
    }
    m1 <- str_if(center_X, " -1")
    str_add(out$data) <- glue(
      "  // concentration vector of the D2 prior\n",
      "  vector<lower=0>[K{sp}{m1}] R2D2_cons_D2{sp};\n"
    )
    str_add(out$par) <- glue(
      "  // local parameters for the R2D2 prior\n",
      "  vector[K{ct}{sp}] zb{sp};\n",
      "  simplex[K{ct}{sp}] R2D2_phi{sp};\n"
    )
    R2D2_args <- sargs(
      glue("zb{sp}"), glue("R2D2_phi{sp}"), glue("R2D2_tau2{p}")
    )
    str_add(out$tpar_reg_prior) <- glue(
      "  // compute actual regression coefficients\n",
      "  b{sp}{suffix} = R2D2({R2D2_args});\n"
    )
    # reference the concentration vector by the same name as declared
    # in the data block above ('R2D2_cons_D2{sp}')
    str_add(out$prior) <- glue(
      "{tp}std_normal_{lpdf}(zb{sp});\n",
      "{tp}dirichlet_{lpdf}(R2D2_phi{sp} | R2D2_cons_D2{sp});\n"
    )
  }
  out
}

# Combine unchecked priors (rows with an empty 'class') for use in Stan.
# Each such prior is emitted verbatim as its own indented statement.
# @param prior a brmsprior object
# @return a single character string in Stan language, or "" if every
#   prior has a non-empty class
stan_unchecked_prior <- function(prior) {
  stopifnot(is.brmsprior(prior))
  lacks_class <- !nzchar(prior$class)
  if (!any(lacks_class)) {
    return("")
  }
  unchecked <- subset2(prior, class = "")
  collapse("  ", unchecked$prior, ";\n")
}

# Stan code to sample separately from priors
# @param prior character string taken from stan_prior; all spaces and
#   newlines are removed and it is then split into statements at ';'
# @param par_declars the parameters block of the Stan code
#     required to extract boundaries
# @param gen_quantities Stan code from the generated quantities block
#   NOTE(review): not referenced in the body of this function
# @param prior_special a list of values pertaining to special priors
#   such as horseshoe or lasso
#   NOTE(review): not referenced in the body of this function
# @param sample_prior take draws from priors? Only "yes" produces code
# @return a named list with elements 'gen_def' (definitions of the
#   'prior_<par>' draws) and, if any sampled parameter has bounds,
#   'gen_comp' (rejection sampling loops); an empty list if no prior
#   can be sampled
stan_rngprior <- function(prior, par_declars, gen_quantities, 
                          prior_special, sample_prior = "yes") {
  if (!is_equal(sample_prior, "yes")) {
    return(list())
  }
  prior <- strsplit(gsub(" |\\n", "", prior), ";")[[1]]
  # D will contain all relevant information about the priors
  # (one row per prior statement)
  D <- data.frame(prior = prior[nzchar(prior)])
  # parameter name: text after '_lpdf(' up to the first '|'
  pars_regex <- "(?<=(_lpdf\\())[^|]+" 
  D$par <- get_matches(pars_regex, D$prior, perl = TRUE, first = TRUE)
  # 'std_normal' has no '|' and thus the above regex matches too much
  np <- !grepl("\\|", D$prior)
  np_regex <- ".+(?=\\)$)"
  D$par[np] <- get_matches(np_regex, D$par[np], perl = TRUE, first = TRUE)
  # 'to_vector' should be removed from the parameter names
  has_tv <- grepl("^to_vector\\(", D$par)
  # removes the 'to_vector(' prefix and its closing ')' (which may be
  # followed by a single numeric index)
  tv_regex <- "(^to_vector\\()|(\\)(?=((\\[[[:digit:]]+\\])?)$))"
  D$par[has_tv] <- gsub(tv_regex, "", D$par[has_tv], perl = TRUE)
  # do not sample from some auxiliary parameters
  excl_regex <- c("z", "zs", "zb", "zgp", "Xn", "Y", "hs", "tmp")
  excl_regex <- paste0("(", excl_regex, ")", collapse = "|")
  excl_regex <- paste0("^(", excl_regex, ")(_|$)")
  D <- D[!grepl(excl_regex, D$par), ]
  if (!NROW(D)) return(list())
  
  # rename parameters containing indices
  # (e.g. 'x[2]' becomes 'x_2')
  has_ind <- grepl("\\[[[:digit:]]+\\]", D$par)
  D$par[has_ind] <- ulapply(D$par[has_ind], function(par) {
    ind_regex <- "(?<=\\[)[[:digit:]]+(?=\\])"
    ind <- get_matches(ind_regex, par, perl = TRUE)
    gsub("\\[[[:digit:]]+\\]", paste0("_", ind), par)
  })
  # cannot handle priors on variable transformations
  D <- D[D$par %in% stan_all_vars(D$par), ]
  if (!NROW(D)) return(list())
  
  # Cholesky factors 'L_*' / 'Lrescor' are reported as correlations
  class_old <- c("^L_", "^Lrescor")
  class_new <- c("cor_", "rescor")
  D$par <- rename(D$par, class_old, class_new, fixed = FALSE)
  # distribution name between 'target+=' and '_lpdf('
  dis_regex <- "(?<=target\\+=)[^\\(]+(?=_lpdf\\()"
  D$dist <- get_matches(dis_regex, D$prior, perl = TRUE, first = TRUE)
  D$dist <- sub("corr_cholesky$", "corr", D$dist)
  # distribution arguments after '|' up to the closing ')' which is either
  # at the end or followed by '-' (a truncation correction term)
  args_regex <- "(?<=\\|)[^$\\|]+(?=\\)($|-))"
  D$args <- get_matches(args_regex, D$prior, perl = TRUE, first = TRUE)
  # 'std_normal_rng' does not exist in Stan
  has_std_normal <- D$dist == "std_normal"
  D$dist[has_std_normal] <- "normal"
  D$args[has_std_normal] <- "0,1"
  
  # extract information from the initial parameter definition
  par_declars <- unlist(strsplit(par_declars, "\n", fixed = TRUE))
  par_declars <- gsub("^[[:blank:]]*", "", par_declars)
  par_declars <- par_declars[!grepl("^//", par_declars)]
  all_pars_regex <- "(?<= )[^[:blank:]]+(?=;)"
  all_pars <- get_matches(all_pars_regex, par_declars, perl = TRUE)
  all_pars <- rename(all_pars, class_old, class_new, fixed = FALSE)
  all_bounds <- get_matches("<.+>", par_declars, first = TRUE)
  all_types <- get_matches("^[^[:blank:]]+", par_declars)
  all_dims <- get_matches(
    "(?<=\\[)[^\\]]*", par_declars, first = TRUE, perl = TRUE
  )
  
  # define parameter types and boundaries
  D$dim <- D$bounds <- ""
  D$type <- "real"
  for (i in seq_along(all_pars)) {
    # NOTE(review): this is a prefix match without a word boundary and
    # 'all_pars[i]' is used unescaped as a regex; a declared name that is
    # a prefix of another parameter also matches it, with later entries
    # of 'all_pars' overwriting earlier ones
    k <- which(grepl(paste0("^", all_pars[i]), D$par))
    D$dim[k] <- all_dims[i]
    D$bounds[k] <- all_bounds[i]
    # for these classes the declared Stan type is reused for the draw;
    # all other draws are declared as 'real'
    if (grepl("^((simo_)|(theta)|(R2D2_phi))", all_pars[i])) {
      D$type[k] <- all_types[i]
    }
  }

  # exclude priors which depend on other priors
  # TODO: enable sampling from these priors as well
  found_vars <- lapply(D$args, find_vars, dot = FALSE, brackets = FALSE)
  contains_other_pars <- ulapply(found_vars, function(x) any(x %in% all_pars))
  D <- D[!contains_other_pars, ]
  if (!NROW(D)) return(list())
  
  out <- list()
  # sample priors in the generated quantities block
  # for LKJ priors, the dimension is prepended to the arguments and the
  # [1, 2] element of the drawn matrix is stored
  D$lkj <- grepl("^lkj_corr$", D$dist)
  D$args <- paste0(ifelse(D$lkj, paste0(D$dim, ","), ""), D$args)
  D$lkj_index <- ifelse(D$lkj, "[1, 2]", "")
  D$prior_par <- glue("prior_{D$par}")
  str_add(out$gen_def) <- "  // additionally sample draws from priors\n"
  str_add(out$gen_def) <- cglue(
    "  {D$type} {D$prior_par} = {D$dist}_rng({D$args}){D$lkj_index};\n"
  )
  
  # sample from truncated priors using rejection sampling
  D$lb <- stan_extract_bounds(D$bounds, bound = "lower")
  D$ub <- stan_extract_bounds(D$bounds, bound = "upper")
  Ibounds <- which(nzchar(D$bounds))
  if (length(Ibounds)) {
    str_add(out$gen_comp) <- "  // use rejection sampling for truncated priors\n"
    for (i in Ibounds) {
      # redraw while the draw lies outside the declared bounds
      wl <- if (nzchar(D$lb[i])) glue("{D$prior_par[i]} < {D$lb[i]}")
      wu <- if (nzchar(D$ub[i])) glue("{D$prior_par[i]} > {D$ub[i]}")
      prior_while <- paste0(c(wl, wu), collapse = " || ")
      str_add(out$gen_comp) <- glue(
        "  while ({prior_while}) {{\n",
        "    {D$prior_par[i]} = {D$dist[i]}_rng({D$args[i]}){D$lkj_index[i]};\n",
        "  }}\n"
      )
    }
  }
  out
}

# Detect constant priors of the form 'constant(<value>)'.
# @param prior a vector of character strings
# @return a logical vector of the same length as 'prior' that is TRUE
#   for elements starting with 'constant(' (elementwise, not 'any')
stan_is_constant_prior <- function(prior) {
  constant_pattern <- "^constant\\("
  grepl(constant_pattern, prior)
}

# Extract a Stan boundary expression from declaration strings.
# Whitespace is removed first; the expression following 'lower=' or
# 'upper=' is taken up to the next ',' or '>'.
# @param x character vector of bound specifications, e.g. "<lower=0>"
# @param bound which bound to extract: "lower" or "upper"
# @return the extracted expressions as returned by get_matches(first = TRUE)
stan_extract_bounds <- function(x, bound = c("lower", "upper")) {
  bound <- match.arg(bound)
  compact <- rm_wsp(x)
  pattern <- glue("(?<={bound}=)[^,>]*")
  get_matches(pattern, compact, perl = TRUE, first = TRUE)
}

# Choose the right suffix for Stan probability density / mass functions.
# @param normalize should normalizing constants be included? If FALSE,
#   the unnormalized suffixes 'lupdf' / 'lupmf' are returned
# @param int is the distribution defined on integers? If TRUE, the mass
#   function suffixes 'lpmf' / 'lupmf' are returned
# @return a single character string
stan_lpdf_name <- function(normalize, int = FALSE) {
  # both arguments are single flags, so plain if/else is used instead of
  # the vectorized ifelse()
  if (normalize) {
    if (int) "lpmf" else "lpdf"
  } else {
    if (int) "lupmf" else "lupdf"
  }
}

# Try the brms package in your browser
#
# Any scripts or data that you put into this service are public.
#
# brms documentation built on Aug. 23, 2021, 5:08 p.m.