Nothing
#' Generate a representative sample of the posterior distribution
#' @description
#' Generate a representative sample of the posterior distribution. The input graph object should be of class `causact_graph` and created using `dag_create()`. The specification of a completely consistent joint distribution is left to the user.
#'
#' @param graph a graph object of class `causact_graph` representing a complete and consistent specification of a joint distribution.
#' @param mcmc a logical value indicating whether to sample from the posterior distribution. When `mcmc=FALSE`, the numpyro code is printed to the console, but not executed. The user can cut and paste the code to another script for running line-by-line. This option is most useful for debugging purposes. When `mcmc=TRUE`, the code is executed and outputs a dataframe of posterior draws.
#' @param num_warmup an integer value for the number of initial steps that will be discarded while the Markov chain finds its way into the typical set.
#' @param num_samples an integer value for the number of samples.
#' @param seed an integer-valued random seed that serves as a starting point for a random number generator. By setting the seed to a specific value, you can ensure the reproducibility and consistency of your results.
#' @return If `mcmc=TRUE`, returns a dataframe of posterior distribution samples corresponding to the input `causact_graph`. Each column is a parameter and each row a draw from the posterior sample output. If `mcmc=FALSE`, running `dag_numpyro` returns a character string of code that would help the user generate the posterior distribution; useful for debugging.
#'
#' @examples
#' graph = dag_create() %>%
#' dag_node("Get Card","y",
#' rhs = bernoulli(theta),
#' data = carModelDF$getCard) %>%
#' dag_node(descr = "Card Probability by Car",label = "theta",
#' rhs = beta(2,2),
#' child = "y") %>%
#' dag_node("Car Model","x",
#' data = carModelDF$carModel,
#' child = "y") %>%
#' dag_plate("Car Model","x",
#' data = carModelDF$carModel,
#' nodeLabels = "theta")
#'
#' graph %>% dag_render()
#' numpyroCode = graph %>% dag_numpyro(mcmc=FALSE)
#' \dontrun{
#' ## default functionality returns a data frame
#' # below requires numpyro installation
#' drawsDF = graph %>% dag_numpyro()
#' drawsDF %>% dagp_plot()
#' }
#' @importFrom dplyr bind_rows tibble left_join rowwise select add_row as_tibble group_indices row_number mutate filter join_by
#' @importFrom DiagrammeR create_graph add_global_graph_attrs
#' @importFrom rlang enquo expr_text .data expr is_na eval_tidy parse_expr warn
#' @importFrom igraph graph_from_data_frame topo_sort
#' @importFrom tidyr gather
#' @importFrom stats na.omit
#' @import reticulate
#' @export
dag_numpyro <- function(graph,
                        mcmc = TRUE,
                        num_warmup = 1000,
                        num_samples = 4000,
                        seed = 111) {
  ## Overview
  ## --------
  ## Translates a `causact_graph` into a string of Python (numpyro) code and,
  ## when `mcmc = TRUE`, evaluates that code through reticulate inside
  ## `cacheEnv` and returns the resulting draws as a tibble.  When
  ## `mcmc = FALSE` the generated code is emitted with message() and returned
  ## invisibly as a character string.
  ##
  ## The generated script is assembled from these pieces, in order:
  ##   imports -> data -> plate index/dim/coord -> model function ->
  ##   MCMC setup -> rng key -> run -> arviz conversion -> dataframe export

  ## make sure reticulate autoconfigure is disabled when running this function - I do not think this is needed
  # ac_flag <- Sys.getenv("RETICULATE_AUTOCONFIGURE")
  # on.exit(
  #   Sys.setenv(
  #     RETICULATE_AUTOCONFIGURE = ac_flag
  #   )
  # )
  # Sys.setenv(RETICULATE_AUTOCONFIGURE = FALSE)

  ## initialize to pass devtools check (names used via non-standard evaluation)
  newPyName <- dataPy <- id <- auto_data <- dimID <- dec <- plateStmnt <-
    numTabsForNode <- plateLabelling <- varLabelling <- selLabelling <-
    forLoop <- newVar <- dimNum <- plateState <- plateLabelState <-
    varNameStmnt <- selStmnt <- NULL ## place holder to pass devtools::check

  ## get graph object name for label statement; must happen before `.` is
  ## shadowed below because a piped call passes the graph as `.`
  graphName = rlang::as_name(rlang::ensym(graph))
  if (graphName == ".") {graphName = get_name(graph)}
  . <- NULL ## place holder to pass devtools::check

  ## First validate that the first argument is indeed a causact_graph
  class_g <- class(graph)
  ## Any causact_graph will have class length of 1
  if (length(class_g) > 1) {
    ## This specific case is hard-coded as it has occurred often in early use by the author
    ## (plain string literal: no helper function is needed to build it)
    if (class_g[1] == "grViz" && class_g[2] == "htmlwidget") {
      errorMessage <- paste0("Given rendered Causact Graph. Check the declaration for a dag_render() call.")
    } else {
      errorMessage <- paste0("Cannot run dag_numpyro() on given object as it is not a Causact Graph.")
    }
    stop(errorMessage)
  }
  ## Now check the single class
  if (class_g != "causact_graph") {
    errorMessage <- paste0("Cannot run dag_numpyro() on given object as it is not a Causact Graph.")
    stop(errorMessage)
  }

  ## clear cache environment for storing mcmc results
  ## also verify numpyro is available
  if (mcmc) {
    rmExpr = rlang::expr(rm(list = ls()))
    eval(rmExpr, envir = cacheEnv)
    options("reticulate.engine.environment" = cacheEnv)
    pyPacks <- reticulate::py_list_packages()
    packs_to_check <- c("numpyro", "arviz", "xarray")
    ## TRUE for each required package found in the connected Python
    existVector = packs_to_check %in% pyPacks$package
    if (!(all(existVector))) {
      rlang::warn("It is likely you need to restart R for dag_numpyro() to make causact's required connection to Python; numpyro or other dependencies are missing from the currently connected Python. Please restart R, then load the causact package with library(causact).")
    }
  } ## clear cacheEnv

  ###get dimension information
  graphWithDim = graph %>% dag_dim()

  ### line to handle nested or intersecting
  ###update rhs information for labelling computer code
  graphWithDim = rhsPriorComposition(graphWithDim)
  graphWithDim = rhsOperationComposition(graphWithDim)

  ###retrieve nodeDF,edgeDF,argDF,plateIndexDF, and plateNodeDF
  nodeDF = graphWithDim$nodes_df
  edgeDF = graphWithDim$edges_df
  argDF = graphWithDim$arg_df
  plateDF = graphWithDim$plate_index_df
  plateNodeDF = graphWithDim$plate_node_df
  dimDF = graphWithDim$dim_df

  ###arrangeNodes in topological order -> top-down
  nodeIDOrder = igraph::graph_from_data_frame(edgeDF %>% dplyr::select(from,to)) %>%
    igraph::topo_sort(mode = "out") %>%
    names() %>%
    as.integer()

  ## append non-connected nodes into nodeIDOrder
  nodeIDOrder = union(nodeIDOrder,nodeDF$id)

  ## arrange nodeDF by nodeIDOrder
  nodeDF = nodeDF[match(nodeIDOrder,nodeDF$id) , ] %>%
    dplyr::mutate(nodeOrder = dplyr::row_number())

  ###Use DAPROPLIMOPO(DAta,PRior,OPeration,LIkelihood,MOdel,POsterior)
  ###Find all nodes that require data based on user input
  ###Err on the side of including a node

  ### Initialize all the code statements so that NULL
  ### values are skipped without Error
  nameChangeStatements = NULL
  importStatements = NULL
  dataStatements = NULL
  plateDataStatements = NULL
  dimStatements = NULL
  functionArguments = NULL
  coordLabelsStatements = NULL
  codeStatements = NULL
  modelStatement = NULL
  posteriorStatement = NULL

  ###IMPORT: Create code for import statements
  ## NOTE: the lines inside this string are Python source and must stay
  ## flush-left; do not re-indent them.
  importStatements = "import numpy as np
import numpyro as npo
import numpyro.distributions as dist
import pandas as pd
import arviz as az
from jax import random
from numpyro.infer import MCMC, NUTS
from jax.numpy import transpose as t
from jax.numpy import (exp, log, log1p, expm1, abs, mean,
sqrt, sign, round, concatenate, atleast_1d,
cos, sin, tan, cosh, sinh, tanh,
sum, prod, min, max, cumsum, cumprod )
## note that above is from JAX numpy package, not numpy.\n"

  ###DATA: Create Code for Data Lines (Nodes that are not in plates)
  ### replace any references to R-Objects with `.` in it to
  ### a renamedObj. Use new column called dataPy to store name
  ### also handle spaces and colons
  ##punctuation other than underscore
  pattern <- "[^[:^punct:]_$]"
  pattern2 <- "[[:space:]]" #whitespace
  nodeDF = nodeDF %>%
    mutate(newPyName = grepl(pattern, data, perl = TRUE) |
             grepl(pattern2, data, perl = TRUE)) %>%
    mutate(dataPy = ifelse(newPyName,
                           paste0("renameNodeForPy__",row_number()),
                           data))

  ## create new objects for access from Python if needed
  renameDF = nodeDF %>% filter(newPyName) %>% select(data,dataPy)
  if (NROW(renameDF) > 0) {
    # Loop through each row and copy objects using assign()
    for (i in seq_len(nrow(renameDF))) {
      old_name <- renameDF$data[i]
      new_name <- renameDF$dataPy[i]
      ## evaluate the user's data expression and store the value in cacheEnv
      ## under the Python-safe name
      assign(new_name,
             eval(rlang::parse_expr(old_name)),
             cacheEnv)
      nameChangeStatements = paste0(nameChangeStatements,
                                    new_name, " = ",
                                    old_name,"\n")
    }
  } ## end create new objects to overcome periods and hyphens

  lhsNodesDF = nodeDF %>%
    dplyr::filter(obs == TRUE | !is.na(data)) %>%
    dplyr::filter(!(label %in% plateDF$indexLabel)) %>%
    dplyr::mutate(codeLine = paste0(auto_label,
                                    " = ",
                                    "np.array(",
                                    paste0("r.",
                                           gsub("\\$", ".",dataPy),
                                           ")"))) %>%
    dplyr::mutate(codeLine = paste0(abbrevLabelPad(codeLine), " #DATA"))

  ###Aggregate Code Statements for DATA
  if (nrow(lhsNodesDF) > 0) {
    ## one element per data line; lines are joined with "\n" when all code
    ## statements are collapsed further below
    dataStatements = paste(lhsNodesDF$codeLine,
                           sep = "\n")
    functionArguments = paste(c(functionArguments,
                                lhsNodesDF$auto_label),
                              collapse = ",")
  }

  ###DIM: Create code for plate dimensions
  plateDimDF = plateDF %>% dplyr::filter(!is.na(dataNode)) %>%
    mutate(newPyName = grepl(pattern, dataNode, perl = TRUE) |
             grepl(pattern2, dataNode, perl = TRUE)) %>%
    mutate(dataPy = ifelse(newPyName,
                           paste0("renameDimForPy__",row_number()),
                           dataNode))

  if (nrow(plateDimDF) > 0) {
    ## integer codes of each plate's data (first element of pd.factorize)
    plateDataStatements = paste(paste0(
      abbrevLabelPad(paste0(plateDimDF$indexLabel)),# four spaces to have invis _dim
      " = ",
      "pd.factorize(r.",
      gsub("\\$", ".", plateDimDF$dataPy),
      ",use_na_sentinel=True)[0] #DIM"),
      sep = "\n")
    ###make labels for dim variables = to label_dim
    dimStatements = paste(
      paste0(abbrevLabelPad(paste0(plateDimDF$indexLabel,"_dim")),
             " = ",
             "len(np.unique(",
             plateDimDF$indexLabel,
             ")) #DIM"),
      sep = "\n"
    )
    ## unique coordinate labels of each plate (second element of pd.factorize)
    coordLabelsStatements = paste(paste0(
      abbrevLabelPad(paste0(plateDimDF$indexLabel,"_crd")),# four spaces to have invis _dim
      " = ",
      "pd.factorize(r.",
      gsub("\\$", ".", plateDimDF$dataPy),
      ",use_na_sentinel=True)[1] #DIM"),
      sep = "\n")
    functionArguments = paste(c(functionArguments,
                                plateDimDF$indexLabel),
                              collapse = ",")
  }

  ## create new objects for access from Python if needed
  renameDIMDF = plateDimDF %>% filter(newPyName) %>% select(dataNode,dataPy)
  if (NROW(renameDIMDF) > 0) {
    # Loop through each row and copy objects using assign()
    for (i in seq_len(nrow(renameDIMDF))) {
      old_name <- renameDIMDF$dataNode[i]
      new_name <- renameDIMDF$dataPy[i]
      ## this works, but might need to clean up env
      assign(new_name,eval(rlang::parse_expr(old_name)),cacheEnv)
      nameChangeStatements = paste0(nameChangeStatements,
                                    new_name, " = ",
                                    old_name, "\n")
    }
  } ## end create new objects to overcome periods and hyphens

  ### DEFINE NUMPYRO FUNCTION
  functionName = paste0(graphName,"_model")
  ## functionArguments is either NULL (-> "def name():") or a single
  ## comma-separated string built above
  numPyFunStartStatement = paste0(
    paste0("def ",functionName,"("),
    paste(functionArguments,sep = ","),
    "):")

  ### alter nodeDF to have numpyro code
  ### first, add unobserved distribution node code
  modelCodeDF = nodeDF %>%
    ## get rid of data only nodes - not part of likelihood
    filter(!(obs == TRUE & distr == FALSE)) %>%
    ## narrow down columns to useful ones
    select(id, rhs, obs, rhsID, distr, auto_label, auto_data, dimID, auto_rhs, dec, det, nodeOrder) %>%
    ## add in plate dimension labels
    left_join(getPlateStatements(graphWithDim), by = join_by(id == nodeID, auto_label == auto_label)) %>%
    ## rowwise so that parse_expr()/eval_tidy() see one auto_rhs at a time
    rowwise() %>%
    ## create code lines for unobserved RV's
    mutate(codeLine = NA) %>% ##init column
    mutate(codeLine =
             ifelse(distr==TRUE & obs == FALSE,
                    paste0(
                      auto_label,
                      " = npo.sample('",
                      auto_label, "', ",
                      rlang::eval_tidy(
                        rlang::parse_expr(auto_rhs)),
                      ")\n"),
                    codeLine)) %>%
    ## create code lines for observed RV's - likelihoods
    mutate(codeLine =
             ifelse(distr==TRUE & obs == TRUE,
                    paste0(
                      auto_label,
                      " = npo.sample('",
                      auto_label, "', ",
                      rlang::eval_tidy(
                        rlang::parse_expr(auto_rhs)),
                      ",obs=",auto_label,")\n"),
                    codeLine)) %>%
    ## create code lines for deterministic RV's -- operations
    mutate(codeLine =
             ifelse((distr==FALSE & obs==FALSE),
                    paste0(
                      auto_label,
                      " = npo.deterministic('",
                      auto_label, "', ",
                      ## replace R power(^) with python power(**) and matrix mult(%*%) with (@)
                      gsub("%\\*%", "@",
                           gsub("\\^", "**", auto_rhs)),
                      ")\n"),
                    codeLine)) %>%
    select(dimLabel = indexLabel,codeLine,auto_label,plateStmnt,numTabsForNode,plateLabelling,varLabelling,selLabelling)

  ## create code to handle concatenation in Python
  modelCodeDF = modelCodeDF %>%
    mutate(codeLine = replace_c(codeLine)) %>%
    ## nodes without a plate get a single tab of indentation
    mutate(numTabsForNode = ifelse(rlang::is_na(numTabsForNode),1,numTabsForNode))

  ###Create MODEL function BODY using codeLines from above
  # Using a for loop to iterate over rows
  prevDimLabel = NA
  modelStatement = "\t## Define random variables and their relationships"
  ## seq_len() so that an empty modelCodeDF adds nothing instead of
  ## iterating over c(1, 0)
  for (i in seq_len(nrow(modelCodeDF))) {
    currDimLabel = modelCodeDF$dimLabel[i]
    numTabs = modelCodeDF$numTabsForNode[i]
    ## node not on plate
    if (rlang::is_na(currDimLabel)) {
      ## for not null add to existing statements
      if (!(rlang::is_null(modelStatement))) {
        modelStatement =
          paste(modelStatement,paste0(paste(rep("\t",
                                                numTabs),
                                            collapse = ""),
                                      modelCodeDF$codeLine[i]),
                sep = "\n")
      } else { ## initialize modelStatement
        modelStatement =
          paste0(paste(rep("\t",
                           numTabs),
                       collapse = ""),
                 modelCodeDF$codeLine[i])
      }
    }
    ## additional line for starting plate (only when the plate changes)
    if (!rlang::is_na(currDimLabel) && !identical(currDimLabel,prevDimLabel)) {
      newLine = modelCodeDF$plateStmnt[i]
      modelStatement = paste0(modelStatement,"\n",
                              newLine)
    }
    ## node on plate
    if (!rlang::is_na(currDimLabel)) {
      modelStatement =
        paste0(modelStatement,
               paste(rep("\t",numTabs),collapse = ""),
               modelCodeDF$codeLine[i])
    }
    prevDimLabel = currDimLabel
  }

  # get all non-observed / non-formula nodes by default
  # that are not discrete distributions
  discDists = c("bernoulli","binomial","beta_binomial",
                "negative_binomial","hypergeometric",
                "poisson","multinomial","categorical")
  unobservedNodes = graphWithDim$nodes_df %>%
    dplyr::filter(obs == FALSE & distr == TRUE) %>%
    dplyr::filter(!(rhs %in% discDists))%>%
    dplyr::pull(auto_label)

  #group unobserved nodes by their rhs for later plotting by ggplot
  #all nodes sharing the same prior will be graphed on the same scale
  # this code should be moved out of dag_numpyro at some point
  priorGroupDF = graphWithDim$nodes_df %>%
    dplyr::filter(obs == FALSE & distr == TRUE)

  ## replaced dependency on dplyr::group_indices
  ## since its functionality changes from 0.85 to v1.0
  ## code returns unique group id's based on prior dist
  grpIndexDF = priorGroupDF %>%
    dplyr::select(auto_rhs) %>%
    dplyr::distinct() %>%
    dplyr::mutate(priorGroup = dplyr::row_number())

  ## add priorGroup column
  priorGroupDF = priorGroupDF %>%
    dplyr::left_join(grpIndexDF, by = "auto_rhs")

  ###Create POSTERIOR draws statement
  if (mcmc == TRUE) { ##clear cacheEnv make sure priorGrp is restored
    ## ensure expected numpyro environment is available
    # Check if the environment is set up
    if (!getOption("causact_env_setup", default = FALSE)) {
      message("In order to use dag_numpyro() for computational Bayesian inference, you must configure a conda Python environment called 'r-causact'.")
      message("To do this, run install_causact_deps().")
      return(invisible())
    }
    assign("priorGroupDF", priorGroupDF, envir = cacheEnv)
    meaningfulLabels(graphWithDim) ###assign meaningful labels in cacheEnv
  } # end if mcmc=TRUE

  posteriorStatement = paste0("\n# computationally get posterior\nmcmc = MCMC(NUTS(",functionName,"), num_warmup = ",num_warmup,", num_samples = ",num_samples,")")
  rngStatement = paste0("rng_key = random.PRNGKey(seed = ",
                        seed,")")
  if (!rlang::is_null(functionArguments)) {
    runStatement = paste0("mcmc.run(rng_key,",functionArguments,")")
  } else {
    runStatement = paste0("mcmc.run(rng_key)")
  }

  ## format posterior dataframe
  drawsStatement = "drawsDS = az.from_numpyro(mcmc"
  dimToKeepStatement = "dimensions_to_keep = ['chain','draw'"
  unstackToDFStatement = "# unstack plate variables to flatten dataframe as needed\n"

  ## check if any dimensions (i.e. plates)
  if (NROW(plateDimDF) > 0) {
    drawsStatement = paste0(drawsStatement,",\n\tcoords = {'")
    dLabels = unique(stats::na.omit(dimDF$dimLabel))
    numDlabels = length(dLabels)
    for (i in seq_len(numDlabels)) {
      dimToKeepStatement = paste0(dimToKeepStatement,
                                  ",'",
                                  dLabels[i],
                                  "_dim'")
      drawsStatement = paste0(drawsStatement,
                              dLabels[i],
                              "_dim': ",
                              dLabels[i],
                              "_crd")
      ## finish dimToKeep and draws statements below
      if (i == numDlabels) { ## closing bracket
        drawsStatement = paste0(drawsStatement,"},\n\tdims = {'")
        dimToKeepStatement = paste0(dimToKeepStatement,
                                    "]\n")
      } else {
        ## add comma to keep going
        drawsStatement = paste0(drawsStatement,",\n\t\t'")
      }
    }

    ## now loop over variables to add dimensionality
    ## (dplyr verbs are namespace-qualified: arrange/group_by/summarize are
    ## not among this function's @importFrom entries)
    plateNodes = dimDF %>% filter(dimType == "plate") %>%
      dplyr::arrange(nodeID,dimLabel) %>%
      mutate(dimLabel = paste0("'",dimLabel,"_dim","'")) %>%
      dplyr::group_by(nodeID) %>%
      dplyr::summarize(nodeID = dplyr::first(nodeID),
                       indexLabel = paste0("[",
                                           paste0(dimLabel, collapse = ","),
                                           "]")) %>% left_join(nodeDF %>% select(id,auto_label), by = join_by(nodeID == id))
    numNodes = NROW(plateNodes)
    for (i in seq_len(numNodes)) {
      drawsStatement = paste0(drawsStatement,
                              plateNodes$auto_label[i],
                              "': ",
                              plateNodes$indexLabel[i])
      if (i == numNodes) { ## closing bracket
        drawsStatement = paste0(drawsStatement,
                                "}\n\t")
      } else {
        ## add comma to keep going
        drawsStatement = paste0(drawsStatement,",\n\t\t'")
      }
    }
  } else { ## add closing bracket if no plates
    dimToKeepStatement = paste0(dimToKeepStatement,
                                "]\n")
  }

  ## for unstacking code
  unstackDF = modelCodeDF %>%
    select(auto_label, dimLabel, plateLabelling,
           varLabelling, selLabelling, numTabsForNode) %>%
    filter(!rlang::is_na(dimLabel)) %>%
    dplyr::group_by(dimLabel) %>%
    dplyr::summarize(forLoop = dplyr::first(plateLabelling),
                     newVar = paste0(strrep("\t",numTabsForNode-1),
                                     paste0("new_varname = f'",
                                            auto_label,"_",
                                            varLabelling,"'\n",
                                            strrep("\t",numTabsForNode-1),
                                            "drawsDS = drawsDS.assign(**{new_varname:drawsDS['",auto_label,"'].sel(",
                                            selLabelling,
                                            ")})"),
                                     collapse = "\n")) %>%
    ### create consolidated loop code for unstacking
    ## paste0() has no `sep` argument, so the newline is appended explicitly
    mutate(unStackLine = paste0(forLoop, newVar, "\n"))

  ## create unstack statement
  unstackToDFStatement = paste0(unstackToDFStatement,
                                unstackDF$unStackLine,
                                collapse = "\n")

  ## drop all plate dims
  if (NROW(plateDimDF) > 0) {
    dropState = paste0("drawsDS = drawsDS.drop_dims([",
                       paste0(paste0("'",unique(na.omit(dimDF$dimLabel)),
                                     "_dim'"),
                              collapse = ","),
                       "])")
  } else {
    dropState = ""
  }

  ## close the statement
  drawsStatement = paste0(drawsStatement,
                          ").posterior\n# prepare xarray dataset for export to R dataframe")
  drawsDFStatement =
    "drawsDS = drawsDS.squeeze(drop = True ).drop_dims([dim for dim in drawsDS.dims if dim not in dimensions_to_keep])\n"
  drawsDFStatement = paste0(dimToKeepStatement,
                            drawsDFStatement,
                            unstackToDFStatement,
                            dropState,
                            "\ndrawsDF = drawsDS.squeeze().to_dataframe()")

  ##########################################
  ###Aggregate all code
  ## NULL pieces (e.g. no data or no plates) drop out of c() automatically
  codeStatements = c(importStatements,
                     dataStatements,
                     plateDataStatements,
                     dimStatements,
                     coordLabelsStatements,
                     numPyFunStartStatement,
                     modelStatement,
                     posteriorStatement,
                     rngStatement,
                     runStatement,
                     drawsStatement,
                     drawsDFStatement)
  #codeStatements

  ## wrap python code in R to get posterior draws
  codeRun = paste0(nameChangeStatements,
                   'reticulate::py_run_string("\n',
                   paste(codeStatements, collapse = '\n'),
                   '"\n) ## END PYTHON STRING\n',
                   "drawsDF = reticulate::py$drawsDF")

  ###print out Code as text for user to use
  if (mcmc == FALSE) {
    codeForUser = paste0("\n## The below code will return a posterior distribution \n## for the given DAG. Use dag_numpyro(mcmc=TRUE) to return a\n## data frame of the posterior distribution: \n",codeRun)
    message(codeForUser)
  }

  ##EVALUATE CODE IN cacheEnv ENVIRONMENT
  ##make expression out of Code Statements
  ###BELOW LINE COMMENTED DURING DAG_NUMPYRO TESTING
  codeExpr = parse(text = codeRun)

  ##eval expression - use original graph without DIM
  if (mcmc == TRUE) {
    eval(codeExpr, envir = cacheEnv) ## evaluate in other env
    ###return data frame of posterior draws
    return(dplyr::as_tibble(py$drawsDF, .name_repair = "universal"))
  }

  return(invisible(codeForUser)) ## just print code
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.