R/set_prior_distribution.R

Defines functions set_prior_distribution get_prior_distribution make_prior_distribution

Documented in get_prior_distribution make_prior_distribution set_prior_distribution

#' Make a prior distribution from priors
#'
#' Draw `n_draws` parameter vectors from the Dirichlet priors stored in
#' `model$parameters_df`. Each parameter set (`param_set`) is drawn
#' independently from a Dirichlet distribution whose concentration
#' parameters are that set's `priors`.
#'
#' @inheritParams CausalQueries_internal_inherit_params
#' @param n_draws A scalar. Number of draws.
#' @return A `data.frame` with `n_draws` rows (one per draw) and one column
#'   per parameter. Columns are named by `model$parameters_df$param_names`
#'   and ordered like the rows of `model$parameters_df`.
#' @export
#' @importFrom dirmult rdirichlet
#' @family prior_distribution
#' @examples
#' make_model('X -> Y') %>% make_prior_distribution(n_draws = 5)
#'
make_prior_distribution <- function(model, n_draws = 4000) {

    pars <- model$parameters_df

    # Preallocate the draws x parameters matrix. Its columns follow the row
    # order of parameters_df, so every column keeps its own parameter name.
    draws <- matrix(NA_real_,
                    nrow = n_draws,
                    ncol = nrow(pars),
                    dimnames = list(NULL, pars$param_names))

    for (set in unique(pars$param_set)) {
        # Write each set's Dirichlet draws into that set's own columns, so
        # labels stay aligned even when a set's rows are not contiguous.
        idx <- which(pars$param_set == set)
        draws[, idx] <- rdirichlet(n_draws, pars$priors[idx])
    }

    as.data.frame(draws)

}


#' Get a prior distribution from priors
#'
#' Return the prior distribution attached to `model` (for example by
#' `set_prior_distribution()`). If none is attached, generate a fresh one with
#' `make_prior_distribution()`, emit a message, and return it. The model
#' passed in is not modified.
#'
#' @inheritParams CausalQueries_internal_inherit_params
#' @param n_draws A scalar. Number of draws. Only used when the model has no
#'   attached prior distribution; an attached one is returned as is.
#' @return The attached prior distribution, or a newly generated
#'   `data.frame` with one row per draw and one column per parameter.
#' @export
#' @family prior_distribution
#' @examples
#' make_model('X -> Y') %>% set_prior_distribution(n_draws = 5) %>% get_prior_distribution()
#' make_model('X -> Y') %>% get_prior_distribution(3)
#'
get_prior_distribution <- function(model, n_draws = 4000) {

    # `[[` matches the name exactly. `$` would partially match, e.g., a lone
    # `prior_distribution_*` element and return the wrong object.
    attached <- model[["prior_distribution"]]

    if (!is.null(attached)) {
        return(attached)
    }

    message("The model does not have an attached prior distribution; generated on the fly")

    make_prior_distribution(model, n_draws)
}

#' Add prior distribution draws
#'
#' Attach to the model a `data.frame` of `n_draws` prior draws, with one row
#' per draw and one column per parameter, generated by
#' `make_prior_distribution()`. Any prior distribution already attached to
#' the model is replaced.
#'
#' @inheritParams CausalQueries_internal_inherit_params
#' @param n_draws A scalar. Number of draws.
#' @return An object of class \code{causal_model}. It essentially returns a list containing the elements comprising
#' a model (e.g. 'statement', 'nodal_types' and 'DAG') with the `prior_distribution` attached to it.
#' @export
#' @family prior_distribution
#' @examples
#' make_model('X -> Y') %>% set_prior_distribution(n_draws = 5) %>% get_prior_distribution()
#'
set_prior_distribution <- function(model, n_draws = 4000) {

    # Overwrites any existing `prior_distribution` element; every other
    # element and the model's class attribute are left untouched.
    model$prior_distribution <- make_prior_distribution(model, n_draws = n_draws)

    model

}

Try the CausalQueries package in your browser

Any scripts or data that you put into this service are public.

CausalQueries documentation built on Oct. 20, 2023, 1:06 a.m.