# Get node names from a bn or bn.fit object
bnlearn_nodes <- function(bn){
stopifnot(class(bn)[1] %in% c("bn", "bn.fit"))
if (class(bn)[1] == "bn"){
return(names(bn$nodes))
} else{
return(names(bn))
}
}
# Convert bn.fit (or bn_list) to edgeList
bn.fit2edgeList <- function(bn.fit){
eL <- lapply(bn.fit, function(x) match(x$parents, names(bn.fit)))
return(sparsebnUtils::edgeList(eL))
}
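## Example (a sketch; assumes bnlearn and sparsebnUtils are installed):
# dag <- bnlearn::model2network("[A][B|A][C|A:B]")
# fit <- bnlearn::bn.fit(dag, bnlearn::learning.test[, c("A", "B", "C")])
# bn.fit2edgeList(fit) # component j holds the indices of node j's parents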
# Convert bn_list to bn.fit
bn_list2bn.fit <- function(bn_list){
if (class(bn_list)[1] == "bn.fit"){
return(bn_list)
}
nodes <- unname(sapply(bn_list, `[[`, "node"))
bn <- sparsebnUtils::to_bn(bn.fit2edgeList(bn_list))
bnlearn::nodes(bn) <- nodes
if (!is.null(bn_list[[1]]$prob)){
## bn.fit.dnet if has prob
bn.fit <- bnlearn::custom.fit(bn,
lapply(bn_list, function(x) x$prob))
} else if (!is.null(bn_list[[1]]$coefficients)){
## bn.fit.gnet if has coefficients
bn.fit <- bnlearn::custom.fit(bn,
lapply(bn_list, function(x)
list(coef = x$coefficients, sd = x$sd)))
}
return(bn.fit)
}
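## Example (a sketch): subsetting a bn.fit drops its class;
## bn_list2bn.fit() rebuilds the fitted object from the plain list:
# fit <- load_bn.fit("chain_3", coefs = c(0.5, 1), vars = c(0.5, 1))
# bn_list <- fit[seq_len(length(fit))] # plain list of node objects
# class(bn_list2bn.fit(bn_list))       # c("bn.fit", "bn.fit.gnet")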
# Reorder bn.fit
reorder_bn.fit <- function(bn.fit,
ordering = TRUE){
## if bn_list, convert to bn.fit
if (is.list(bn.fit) && !"bn.fit" %in% class(bn.fit))
bn.fit <- bn_list2bn.fit(bn.fit)
nodes <- bnlearn_nodes(bn.fit)
## get ordering
if (is.logical(ordering)){
ordered_nodes <- bnlearn::node.ordering(bn.fit)
} else if (is.numeric(ordering)){
ordered_nodes <- nodes[ordering]
} else if (is.character(ordering)){
ordered_nodes <- ordering
}
debug_cli_sprintf(!is.character(ordered_nodes) ||
length(ordered_nodes) != length(nodes) ||
!setequal(ordered_nodes, nodes),
"abort", "Supplied ordering not compatible with nodes")
## reverse ordering
ordering <- match(ordered_nodes, nodes)
revert <- match(seq_len(length(ordering)), ordering)
## reorder bn.fit
if (any(seq_len(length(nodes)) != ordering)){
bn_list <- bn.fit[ordering]
bn.fit <- bn_list2bn.fit(bn_list)
}
attr(bn.fit, "ordering") <- ordering
attr(bn.fit, "revert") <- revert
return(bn.fit)
}
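## Example (a sketch): reorder by name, then undo with the stored
## "revert" permutation:
# fit <- load_bn.fit("chain_3", coefs = c(0.5, 1), vars = c(0.5, 1))
# flipped <- reorder_bn.fit(fit, ordering = c("V3", "V2", "V1"))
# reverted <- reorder_bn.fit(flipped, ordering = attr(flipped, "revert"))
# identical(bnlearn_nodes(reverted), bnlearn_nodes(fit))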
# Rename bn.fit
rename_bn.fit <- function(bn.fit,
nodes = "V",
categories = TRUE){
## default names
if (is.null(nodes))
nodes <- "V"
if (is.character(nodes) && length(nodes) == 1)
nodes <- sprintf("%s%s", nodes[1], seq_len(length(bn.fit)))
## rename nodes
original <- bnlearn_nodes(bn.fit)
bnlearn::nodes(bn.fit) <- nodes
## if discrete, rename discrete categorical levels
if (categories && "bn.fit.dnet" %in% class(bn.fit)){
## convert to bn_list
bn_list <- bn.fit[seq_len(length(bn.fit))]
for (node in nodes){
## rename categories in prob to 0 to (r - 1)
dim_prob <- dim(bn_list[[node]]$prob)
dim_nms <- lapply(dim_prob, function(x) seq(0, x-1))
names(dim_nms) <- c(node, bn_list[[node]]$parents)
dimnames(bn_list[[node]]$prob) <- dim_nms
}
bn.fit <- bn_list2bn.fit(bn_list = bn_list)
}
attr(bn.fit, "original") <- original
return(bn.fit)
}
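## Example (a sketch; assumes "asia" is available via phsl::bnrepository()):
# fit <- load_bn.fit("asia", rename = FALSE)
# fit <- rename_bn.fit(fit, nodes = "X") # nodes become X1, ..., X8
# attr(fit, "original")                  # the original node names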
# Get weighted adjacency matrix
wamat <- function(bn.fit){
debug_cli_sprintf(!"bn.fit.gnet" %in% class(bn.fit),
"abort", "bn.fit must be of class bn.fit.gnet")
wa <- bnlearn::amat(bn.fit)
for (node in bn.fit){
wa[node$parents, node$node] <- node$coefficients[-1] # exclude intercept
}
return(wa)
}
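## Example (a sketch): entry (i, j) is the coefficient of node i in
## node j's local regression:
# fit <- load_bn.fit("chain_3", coefs = c(0.5, 1), vars = c(0.5, 1))
# wamat(fit) # nonzero only at (V1, V2) and (V2, V3)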
# Convert bn to a "default" bn.fit.gnet object
bn2gnet <- function(bn,
seed,
coefs = c(0, 0),
vars = c(0, 0),
normalize = TRUE,
intercept = FALSE){
## TODO: check arguments
bnlearn:::check.bn.or.fit(bn)
if (!missing(seed) && is.numeric(seed) && !is.na(seed))
set.seed(seed)
gnet <- bnlearn::empty.graph(nodes = bnlearn_nodes(bn))
bnlearn::amat(gnet) <- bnlearn::amat(bn)
## generate parameters for gnet
dist <- lapply(gnet$nodes, function(node){
## sample coefficients and standard deviations
params <- list(coef = c(sample(c(-1, 1), # negative or positive
length(node$parents) + 1, replace = TRUE) *
runif(length(node$parents) + 1, # magnitudes
coefs[1], coefs[2])),
sd = runif(1, sqrt(vars[1]), sqrt(vars[2])))
if (! intercept)
params$coef[1] <- 0
names(params$coef) <- c("(Intercept)", node$parents)
return(params)
})
## normalize variances
if (normalize &&
all(sapply(dist, `[[`, "sd") > 0)){
beta <- wamat(bnlearn::custom.fit(gnet, dist = dist)) # coefficient matrix
I <- diag(length(dist)) # identity matrix
Omega <- diag(sapply(dist, `[[`, "sd")^2) # error variances
rownames(I) <- colnames(I) <-
rownames(Omega) <- colnames(Omega) <- names(dist)
## visit nodes topologically
for (node in bnlearn::node.ordering(bn)){
if (length(bnlearn::parents(bn, node)) == 0){
## if no parents, variance 1
Omega[node, node] <- 1
dist[[node]]$sd <- 1
} else{
## if there are parents, scale coefficients and error
## variances such that the variable variance is 1
## TODO: other normalizing strategies
## estimate covariance matrix
Sigma <- solve(t(I - beta)) %*% Omega %*% solve(I - beta)
## scale coefficients
beta[, node] <- beta[, node] / sqrt(Sigma[node, node])
dist[[node]]$coef[-1] <-
dist[[node]]$coef[-1] / sqrt(Sigma[node, node])
## scale error variance
Omega[node, node] <- Omega[node, node] / Sigma[node, node]
dist[[node]]$sd <- dist[[node]]$sd / sqrt(Sigma[node, node])
}
}
}
gnet <- bnlearn::custom.fit(gnet, dist = dist)
return(gnet)
}
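## Example (a sketch): with normalize = TRUE, coefficients and error
## variances are rescaled so that every node has marginal variance 1:
# gnet <- bn2gnet(chain_bn(p = 3), seed = 1,
#                 coefs = c(0.3, 0.7), vars = c(0.5, 1))
# data <- bnlearn::rbn(gnet, n = 1e5)
# sapply(data, var) # all approximately 1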
# Create parallel graph structure
parallel_bn <- function(p = 3){
nodes <- sprintf("V%s", seq_len(p))
bn <- bnlearn::empty.graph(nodes = nodes)
a <- bnlearn::amat(bn)
a[seq_len(p - 1), p] <- 1
bnlearn::amat(bn) <- a
return(bn)
}
# Create chain graph structure
chain_bn <- function(p){
nodes <- sprintf("V%s", seq_len(p))
bn <- bnlearn::empty.graph(nodes = nodes)
bnlearn::arcs(bn) <- t(sapply(seq_len(p - 1),
function(x) nodes[c(x, x + 1)]))
return(bn)
}
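## Example (a sketch):
# bnlearn::modelstring(parallel_bn(p = 3)) # "[V1][V2][V3|V1:V2]"
# bnlearn::modelstring(chain_bn(p = 3))    # "[V1][V2|V1][V3|V2]"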
# Create random graph structure
random_bn <- function(p,
d,
seed,
...){
if (!missing(seed) && is.numeric(seed) && !is.na(seed))
set.seed(seed)
nodes <- sprintf("V%s", seq_len(p))
repeat{
nel <- pcalg::randDAG(n = p, d = d, weighted = FALSE, ...)
a <- as(nel, "matrix")
## TODO: control with argument
# if (all(colSums(a) <= 4)) break
break
}
rownames(a) <- colnames(a) <- nodes
bn <- bnlearn::empty.graph(nodes = nodes)
bnlearn::amat(bn) <- a
return(bn)
}
# Create sink graph structure
sink_bn <- function(p, # number of parents of target node
d){ # number of parents per parent
nodes <- sprintf("V%s", seq_len(1 + p * (1 + d)))
sink <- length(nodes)
ancestors <- seq_len(p * d)
parents <- setdiff(seq_len(length(nodes)), c(sink,
ancestors))
bn <- bnlearn::empty.graph(nodes = nodes)
a <- bnlearn::amat(bn)
a[parents, sink] <- 1L
a[cbind(ancestors, rep(parents, each = d))] <- 1L
bnlearn::amat(bn) <- a
return(bn)
}
# Create random ER graph structure terminating in a parallel structure
randpar_bn <- function(p1, # number of parents of target node
p2, # number of additional nodes
d, # expected degree
seed){
nodes <- sprintf("V%s", seq_len(p1 + p2 + 1))
sink <- length(nodes)
bn <- bnlearn::empty.graph(nodes = nodes)
rand_bn <- random_bn(p = p1 + p2, d = d,
seed = seed)
bnlearn::amat(bn)[-sink, -sink] <- bnlearn::amat(rand_bn)
bnlearn::amat(bn)[tail(bnlearn::node.ordering(rand_bn), p1),
sink] <- 1L
return(bn)
}
# Create random graph structure as in de Kroon (2021) terminating in a parallel structure
dkpar_bn <- function(p1, # number of parents of target node
p2, # number of additional nodes
d, # maximum in-degree
seed){
if (!missing(seed) && !is.na(seed) && is.numeric(seed))
set.seed(seed)
nodes <- sprintf("V%s", seq_len(p1 + p2 + 1))
sink <- length(nodes)
bn <- bnlearn::empty.graph(nodes = nodes)
amat <- bnlearn::amat(bn)
for (i in rev(seq_len(length(nodes) - 1))){
f <- sample(length(nodes) - i, size = 1)
to <- unique(i + sample(length(nodes) - i,
size = f, replace = TRUE))
to <- to[colSums(amat)[to] < d] # restrict maximum in-degree
amat[i, to] <- 1L
}
if (sum(amat[,sink]) < p1){
from <- sample(setdiff(seq_len(length(nodes) - 1),
which(amat[,sink] > 0)),
p1 - sum(amat[,sink]))
amat[from, sink] <- 1L
}
bnlearn::amat(bn) <- amat
return(bn)
}
# Load bnrepository structure terminating in a parallel structure
bnrpar_bn <- function(x, # name of network in bnrepository
p, # number of parents of target node
seed){
if (!missing(seed) && !is.na(seed) && is.numeric(seed))
set.seed(seed)
bn.fit <- load_bn.fit(x = x)
nodes <- sprintf("V%s", seq_len(length(bn.fit) + 1))
sink <- length(nodes)
bn <- bnlearn::empty.graph(nodes = nodes)
bnlearn::amat(bn)[-sink, -sink] <- bnlearn::amat(bn.fit)
bnlearn::amat(bn)[sample(nodes[-sink], size = p),
sink] <- 1L
return(bn)
}
# Load bn.fit
# An extension of phsl::bnrepository() that includes
# functionality for reordering and renaming, as well
# as parallel, chain, and random graphs
#' @export
load_bn.fit <- function(x,
reorder = TRUE,
rename = TRUE,
...){
if ("bn.fit" %in% class(x)){
## ignore
} else if (x %in% avail_bnrepository){
bn.fit <- phsl::bnrepository(x = x)
} else if (grepl("parallel", x)){
p <- as.numeric(strsplit(x, "_")[[1]][2])
bn <- parallel_bn(p = p)
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("chain", x)){
p <- as.numeric(strsplit(x, "_")[[1]][2])
bn <- chain_bn(p = p)
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("sink", x)){
p_d <- as.numeric(strsplit(x, "_")[[1]][seq_len(2) + 1])
bn <- sink_bn(p = p_d[1], d = p_d[2])
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("random", x)){
p_d_seed <- as.numeric(strsplit(x, "_")[[1]][seq_len(3) + 1])
bn <- random_bn(p = p_d_seed[1],
d = p_d_seed[2],
seed = p_d_seed[3])
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("randpar", x)){
p1_p2_d_seed <- as.numeric(strsplit(x, "_")[[1]][seq_len(4) + 1])
bn <- randpar_bn(p1 = p1_p2_d_seed[1],
p2 = p1_p2_d_seed[2],
d = p1_p2_d_seed[3],
seed = p1_p2_d_seed[4])
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("dkpar", x)){
p1_p2_d_seed <- as.numeric(strsplit(x, "_")[[1]][seq_len(4) + 1])
bn <- dkpar_bn(p1 = p1_p2_d_seed[1],
p2 = p1_p2_d_seed[2],
d = p1_p2_d_seed[3],
seed = p1_p2_d_seed[4])
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("bnrpar", x)){
x_p_seed <- strsplit(x, "_")[[1]][seq_len(3) + 1]
bn <- bnrpar_bn(x = x_p_seed[1],
p = as.numeric(x_p_seed[2]),
seed = as.numeric(x_p_seed[3]))
bn.fit <- bn2gnet(bn = bn, ...)
} else if (grepl("cytometry", x)){
## load from sparsebn
requireNamespace("sparsebn", quietly = TRUE)
data(cytometryDiscrete, package = "sparsebn")
dag <- as.matrix(sparsebnUtils::get.adjacency.matrix(cytometryDiscrete$dag))
rownames(dag) <- colnames(dag) <- cytometry_nodes
bn <- bnlearn::empty.graph(nodes = cytometry_nodes)
bnlearn::amat(bn) <- dag
bn.fit <- bn2gnet(bn = bn, ...)
} else{
## TODO: default structures in globals.R
stop(sprintf("unrecognized network specification: %s", x))
}
if (reorder)
bn.fit <- reorder_bn.fit(bn.fit = bn.fit, ordering = TRUE)
if (rename)
bn.fit <- rename_bn.fit(bn.fit = bn.fit, nodes = "V", categories = TRUE)
return(bn.fit)
}
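## Examples (a sketch; string specifications are parsed by splitting on "_",
## and extra arguments are passed on to bn2gnet()):
# load_bn.fit("asia")                                         # repository network
# load_bn.fit("chain_5", coefs = c(0.5, 1), vars = c(0.5, 1))
# load_bn.fit("random_10_2_1", coefs = c(0.5, 1), vars = c(0.5, 1)) # p = 10, d = 2, seed = 1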
# Get causal effect lower bound from dnet
dnet2ce_lb <- function(dnet){
bn_list <- add_j_m_pt(bn.fit = dnet)
marginals <- sapply(bn_list, `[[`, "mpt")
min(do.call(c, lapply(bn_list, function(node){
unlist(lapply(node$parents, function(x){
## compute local jpt
jpt <- get_jpt(bn.fit = dnet,
nodes = unique(c(node$node, x, bn_list[[x]]$parents)))
pt <- query_jpt(jpt = jpt, target = node$node,
given = x, adjust = bn_list[[x]]$parents)
max(abs(pt - reshape_dim(pt = node$mpt,
new_dimnames = dimnames(pt))))
}))
})))
}
# Extract information from bn.fit for data_row
bn.fit2data_row <- function(bn.fit,
data_row,
effects_array = bn.fit2effects(bn.fit = bn.fit)){
in_deg <- sapply(bn.fit, function(x) length(x$parents))
out_deg <- sapply(bn.fit, function(x) length(x$children))
data_row$avg_deg <- mean(in_deg + out_deg)
data_row$max_in_deg <- max(in_deg)
data_row$max_out_deg <- max(out_deg)
data_row$n_node <- bnlearn::nnodes(bn.fit)
data_row$n_edge <- nrow(bnlearn::directed.arcs(bn.fit))
data_row$n_within <- data_row$n_edge
data_row$n_between <- 0
data_row$n_compelled <- nrow(bnlearn::compelled.arcs(bn.fit))
data_row$n_reversible <- data_row$n_edge - data_row$n_compelled
data_row$n_params <- bnlearn::nparams(bn.fit)
if (is.numeric(data_row$target) &&
data_row$target %% 1 == 0 &&
data_row$target >= 1 &&
data_row$target <= length(bn.fit)){
## provided index
data_row$target <- bnlearn_nodes(bn.fit)[data_row$target]
} else if (! data_row$target %in% bnlearn_nodes(bn.fit)){
## default to leaf which has the most parents
nodes <- rev(bnlearn::node.ordering(bn.fit))
nodes <- nodes[sapply(nodes, function(x) length(bn.fit[[x]]$children)) == 0]
data_row$target <- nodes[[which.max(sapply(nodes,
function(x) length(bn.fit[[x]]$parents)))]]
}
if ("bn.fit.gnet" %in% class(bn.fit)){
## effects on target
effects <- effects_array[setdiff(names(bn.fit), data_row$target),
data_row$target, 1]
effects <- sort(abs(effects[effects != 0]), decreasing = TRUE)
betas <- abs(wamat(bn.fit)[bnlearn::amat(bn.fit) > 0])
vars <- sapply(bn.fit, `[[`, "sd")^2
data_row$ce_lb <- min(betas)
data_row$ri_lb <- effects[1]
data_row$coef_lb <- min(betas)
data_row$coef_ub <- max(betas)
data_row$var_lb <- min(vars)
data_row$var_ub <- max(vars)
effects <- effects / effects[1]
} else if ("bn.fit.dnet" %in% class(bn.fit)){
## effects on target
success <- 1 # TODO: generalize
effects <- effects_array[setdiff(names(bn.fit), data_row$target),
data_row$target, success]
effects <- sort(abs(effects[effects != 0]), decreasing = TRUE)
n_lev <- sapply(bn.fit,
function(node) dim(node$prob)[1])
data_row$ce_lb <- dnet2ce_lb(dnet = bn.fit)
data_row$ri_lb <- max(effects - get_jpt(bn.fit = bn.fit,
nodes = data_row$target)[success])
data_row$var_lb <- min(n_lev)
data_row$var_ub <- max(n_lev)
}
data_row$reg_lb <- effects[1] - effects[2]
return(data_row)
}
# Extract pairwise causal effects
bn.fit2effects <- function(bn.fit){
bnlearn:::check.bn.or.fit(bn.fit)
debug_cli_sprintf(! class(bn.fit)[2] %in% c("bn.fit.gnet", "bn.fit.dnet"),
"abort", "Currently only bn.fit.gnet and bn.fit.dnet supported for determining true causal effects")
nodes <- bnlearn_nodes(bn.fit)
effects <- array(0, dim = c(length(nodes), length(nodes), 2))
if ("bn.fit.gnet" %in% class(bn.fit)){
Beta <- wamat(bn.fit = bn.fit)
effects_list <- vector(mode = "list",
length = length(bn.fit))
for (i in seq_len(length(bn.fit))){
effects_list[[i]] <- expm::`%^%`(Beta, i)
if (!any(effects_list[[i]] != 0)) break
}
effects[,,1] <- Reduce(`+`, effects_list[seq_len(i)])
intercepts <- unname(sapply(bn.fit, function(x) x$coefficients[1]))
if (any(intercepts != 0)){
## TODO: scalable E(Y | do(X = 0))
bn_list0 <- bn.fit[seq_len(length(bn.fit))]
## zero sds
for (node in nodes){
bn_list0[[node]]$sd <- 0
}
bn.fit0 <- bn_list2bn.fit(bn_list0)
## initialize interventions
intervene <- lapply(nodes, function(node){
int <- list(node = 0, n = 1)
names(int) <- c(node, "n")
return(int)
})
for (int in intervene){
data <- ribn(x = bn.fit0, fix = TRUE,
intervene = list(int), debug = 0)
node <- intersect(names(int), nodes)
others <- -match(node, nodes)
effects[node, others, 2] <- unlist(data[others])
}
}
dimnames(effects) <- list(nodes, nodes, c(1, # effect: beta; E(Y | do(X = 1)) - E(Y | do(X = 0))
0)) # intercept: beta0; E(Y | do(X = 0))
} else if ("bn.fit.dnet" %in% class(bn.fit)){
node_values <- bn.fit2values(bn.fit = bn.fit)
success <- 1 # TODO: generalize
for (node in names(bn.fit)){
for (value in node_values[[node]]){
debug_cli(TRUE, cli::cli_alert,
c("computing effects of intervention {node} = {value}"),
.envir = environment())
if (length(bn.fit[[node]]$parents)){
bn.fit0 <-
remove_arcs(bn.fit = bn.fit,
arcs = as.matrix(data.frame(from = bn.fit[[node]]$parents,
to = node)))
} else{
bn.fit0 <- bn.fit
}
bn_list0 <- bn.fit0[seq_len(length(bn.fit0))]
bn_list0[[node]]$prob[] <- 0
bn_list0[[node]]$prob[value] <- 1
bn_list0 <- add_j_m_pt(bn.fit = bn_list0)
mpt <- sapply(bn_list0, `[[`, "mpt", simplify = FALSE)
node_i <- match(node, names(bn.fit))
effects[node_i, -node_i, value] <- sapply(mpt[-node_i], `[[`, success)
}
}
dimnames(effects) <- list(nodes, nodes, c(1, # E(Y | do(X = 1))
2)) # E(Y | do(X = 2))
}
return(effects)
}
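## Example (a sketch): in the chain V1 -> V2 -> V3, the total effect of
## V1 on V3 is the product of the two path coefficients, which the
## matrix power series above recovers:
# gnet <- load_bn.fit("chain_3", coefs = c(0.5, 1), vars = c(0.5, 1))
# B <- wamat(gnet)
# effects <- bn.fit2effects(gnet)
# all.equal(unname(effects["V1", "V3", 1]), B["V1", "V2"] * B["V2", "V3"])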
# bn.fit to intervention values
bn.fit2values <- function(bn.fit){
bnlearn:::check.bn.or.fit(bn.fit)
debug_cli_sprintf(! class(bn.fit)[2] %in% c("bn.fit.gnet", "bn.fit.dnet"),
"abort", "Currently only bn.fit.gnet and bn.fit.dnet supported for determining true causal effects")
nodes <- bnlearn_nodes(bn.fit)
if ("bn.fit.gnet" %in% class(bn.fit)){
node_values <- sapply(nodes, function(node){
c(-1, 1)
}, simplify = FALSE)
} else if ("bn.fit.dnet" %in% class(bn.fit)){
node_values <- sapply(nodes, function(node){
c(1, 2)
}, simplify = FALSE)
}
return(node_values)
}
# Obtain ancestor relation matrix from dag
dag2arp <- function(dag = bnlearn::amat(bn.fit),
nodes = colnames(dag),
bn.fit){
if (!missing(bn.fit)){
dag <- bnlearn::amat(bn.fit)
nodes <- names(bn.fit)
}
mode(dag) <- "integer"
if (is.null(nodes)){
nodes <- sprintf("V%g", seq_len(ncol(dag)))
}
ancestors <- diag(length(nodes))
rownames(ancestors) <- colnames(ancestors) <- nodes
for (i in seq_len(length(nodes))){
for (j in seq_len(length(nodes))[-i]){
if (phsl:::has_path(i = i, j = j, amat = dag,
nodes = nodes)){
ancestors[i, j] <- 1
}
}
}
return(ancestors)
}
# Zero intercepts of bn.fit
zero_bn.fit <- function(bn.fit, keep_attrib = TRUE){
if (class(bn.fit)[2] != "bn.fit.gnet")
return(bn.fit)
if (keep_attrib){
attrib <- attributes(bn.fit)
}
## already zeroed
if (!is.null(attr(bn.fit, "obs_means")))
return(bn.fit)
bn_list <- bn.fit[seq_len(length(bn.fit))]
for (node in names(bn.fit)){
bn_list[[node]]$sd <- 0
}
obs_means <- unlist(ribn(x = bn_list2bn.fit(bn_list), n = 1))
bn_list <- bn.fit[seq_len(length(bn.fit))]
for (node in names(bn.fit)){
bn_list[[node]]$coefficients[1] <- 0
}
bn.fit <- bn_list2bn.fit(bn_list)
attr(bn.fit, "obs_means") <- obs_means
if (keep_attrib){
attributes(bn.fit) <- attrib
}
return(bn.fit)
}
# Compute joint probability table from bn.fit (accumulated in log space)
bn.fit2jpt <- function(bn.fit){
## initialize
# nodes <- bnlearn:::topological.ordering(x = bn.fit)
nodes <- names(bn.fit)
## initialize joint probability table (jpt)
new_dimnames <- sapply(nodes, function(node){
dimnames(bn.fit[[node]]$prob)[[1]]
}, simplify = FALSE)
log_jpt <- reshape_dim(0, new_dimnames = new_dimnames, repl = TRUE)
for (node in nodes){
prob <- reshape_dim(bn.fit[[node]]$prob,
new_dimnames = dimnames(log_jpt))
log_jpt <- log_jpt + log(prob)
}
return(exp(log_jpt))
}
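## Example (a sketch; the full table has prod(levels) cells, so this
## only scales to small discrete networks):
# dnet <- load_bn.fit("asia")
# jpt <- bn.fit2jpt(dnet)
# sum(jpt) # 1
# dim(jpt) # 2 x 2 x ... for eight binary nodes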
# Marginalize a probability table onto the given nodes
sum_pt <- function(pt, nodes){
if (length(nodes) == 0){
return(sum(pt))
}
pt_ <- apply(pt, MARGIN = match(nodes,
names(dimnames(pt))), sum)
dim(pt_) <- dim(pt)[match(nodes,
names(dimnames(pt)))]
dimnames(pt_) <- dimnames(pt)[nodes]
return(pt_)
}
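## Example (a sketch): marginalize a joint table onto selected nodes:
# jpt <- bn.fit2jpt(load_bn.fit("asia"))
# sum_pt(jpt, "V1")          # marginal distribution of V1
# sum_pt(jpt, c("V1", "V2")) # joint distribution of V1 and V2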
# Reshape dimension of pt to match that of new_dimnames
reshape_dim <- function(pt, new_dimnames, repl = TRUE){
## permute dimensions of pt according to new_dimnames
if (!is.null(dimnames(pt))){
perm <- order(match(names(dimnames(pt)),
names(new_dimnames)))
pt <- aperm(pt, perm)
}
## manipulate dimension
dim_pt <- rep(1, length(new_dimnames))
names(dim_pt) <- names(new_dimnames)
dim_pt[names(dimnames(pt))] <- dim(pt)
dim(pt) <- dim_pt
new_dim <- sapply(new_dimnames, length)
if (repl){
debug_cli(prod(new_dim) > 2e8, cli::cli_abort,
"probability table will have {prod(new_dim)} > 2e8 elements")
## replicate to match
dims <- which(dim(pt) == 1)
idx <- lapply(new_dim[dims], function(x) rep(1, x))
pt <- abind::asub(pt, idx, dims, drop = FALSE)
dimnames(pt) <- new_dimnames
} else{
dimnames(pt) <- sapply(names(new_dimnames), function(node){
if (dim(pt)[node] == new_dim[node])
new_dimnames[[node]]
else
paste(new_dimnames[[node]], collapse = "")
}, simplify = FALSE)
}
class(pt) <- "table"
return(pt)
}
# Validate conditional probability table
validate_cpt <- function(cpt,
given = character(0)){
## normalize, if necessary
sums <- sum_pt(cpt,
given)
if (length(sums) == 0 ||
any(is.na(sums)) ||
any(sums == 0)){
## TODO: conditional probabilities when probability of condition is 0
debug_cli(FALSE, cli::cli_alert_warning,
"cpt has sums = {sums}")
# browser()
# sums <- ifelse(is.na(sums) | sums == 0, 1, sums)
attr(cpt, "invalid") <- TRUE
} else{ # if (all(sums > 0)){
cpt <- exp(log(cpt) - reshape_dim(log(sums), dimnames(cpt)))
}
class(cpt) <- "table"
return(cpt)
}
# Query a joint probability table for marginal, conditional, or adjusted probabilities
query_jpt <- function(jpt,
target,
given = character(0),
adjust = character(0)){
## TODO: test adjusting parents with length(given) > 1
nodes <- names(dimnames(jpt))
if (is.numeric(target))
target <- nodes[target]
if (is.numeric(given))
given <- nodes[given]
if (is.numeric(adjust))
adjust <- nodes[adjust]
adjust <- setdiff(adjust, given)
given <- c(given, adjust)
debug_cli(length(intersect(target, given)) ||
length(intersect(target, adjust)) ||
!all(c(target, given, adjust) %in% nodes),
cli::cli_abort, "invalid target, given, or adjust")
## sort according to jpt
target <- nodes[sort(match(target, nodes))]
given <- nodes[sort(match(given, nodes))]
adjust <- nodes[sort(match(adjust, nodes))]
tgp <- c(target, union(given, adjust))
## sum across irrelevant dimensions
log_pt <- log(sum_pt(jpt, nodes = tgp))
## divide by conditioning dimensions
if (length(given)){
## joint probability table of conditioning variables
gpt <- sum_pt(exp(log_pt), given)
## manipulate dimension to element-wise divide
## by joint distribution of conditioning variables
gpt <- reshape_dim(gpt, dimnames(log_pt))
if (length(adjust)){
## joint probability table of adjustment variables
ppt <- sum_pt(exp(log_pt), adjust)
}
## element-wise divide
log_gpt <- ifelse((log_gpt <- log(gpt)) == -Inf, 1, log_gpt)
log_pt <- log_pt - log_gpt
if (length(adjust)){
## manipulate dimension to element-wise multiply
## by joint distribution of adjustment variables
ppt <- reshape_dim(ppt, dimnames(log_pt))
## element-wise multiply
log_pt <- log_pt + log(ppt)
## sum over adjust
log_pt <- log(apply(exp(log_pt),
MARGIN = match(c(target, setdiff(given, adjust)),
names(dimnames(log_pt))), sum))
}
}
log_pt <- log(validate_cpt(cpt = exp(log_pt),
given = setdiff(given, adjust)))
return(exp(log_pt))
}
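## Example (a sketch): conditional and adjusted queries:
# jpt <- bn.fit2jpt(load_bn.fit("asia"))
# query_jpt(jpt, target = "V2", given = "V1") # P(V2 | V1)
# ## adjusting for V3: sum_v P(V2 | V1, V3 = v) P(V3 = v)
# query_jpt(jpt, target = "V2", given = "V1", adjust = "V3")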
# Return intermediate nodes along directed paths from one node to another
nodes_along_path <- function(from,
to,
bn.fit){
ig <- igraph::graph.adjacency(bnlearn::amat(bn.fit))
paths <- igraph::all_simple_paths(graph = ig,
from = from, to = to)
nodes <- setdiff(Reduce(union, lapply(paths,
function(x) x$name)),
c(from, to))
if (is.null(nodes))
nodes <- character(0)
return(nodes)
}
# Recursively get joint probability table
# WARNING: not guaranteed to be correct in every case; use with care
get_jpt <- function(bn.fit,
nodes = names(bn.fit)){
## initialize joint probability table (jpt)
parents <- Reduce(union, sapply(bn.fit[nodes], `[[`, "parents"))
between <- Reduce(union, lapply(parents, function(from){
Reduce(union, lapply(setdiff(parents, from), function(to){
nodes_along_path(from, to, bn.fit)
}))
}))
parents <- union(parents,
between)
new_dimnames <- sapply(union(nodes, parents), function(node){
dimnames(bn.fit[[node]]$prob)[[1]]
}, simplify = FALSE)
log_jpt <- reshape_dim(0, new_dimnames = new_dimnames, repl = TRUE)
for (node in nodes){
prob <- reshape_dim(bn.fit[[node]]$prob,
new_dimnames = dimnames(log_jpt))
log_jpt <- log_jpt + log(prob)
}
if (length(unresolved <- setdiff(parents, nodes))){
## group parents that have a directed path between them
g <- sapply(match(unresolved, names(bn.fit)), function(i){
sapply(match(unresolved, names(bn.fit)), function(j){
phsl:::has_path(i = i, j = j,
amat = bnlearn::amat(bn.fit),
nodes = names(bn.fit)) * 1
})
})
g <- as.matrix(g)
rownames(g) <- colnames(g) <- unresolved
groups <- lapply(igraph::decompose.graph(igraph::graph.adjacency(g)),
function(x) igraph::V(x)$name)
log_probs <- lapply(groups, function(group){
prob <- get_jpt(bn.fit, nodes = group)
log(reshape_dim(prob,
new_dimnames = dimnames(log_jpt)))
})
log_jpt <- log_jpt + Reduce(`+`, log_probs)
}
jpt <- sum_pt(exp(log_jpt),
nodes = nodes)
if (sum(jpt) != 1){
debug_cli(abs(sum(jpt) - 1) > .Machine$double.eps, cli::cli_warn,
"numerical error of {abs(sum(jpt) - 1)} > { .Machine$double.eps}")
jpt <- jpt / sum(jpt)
}
return(jpt)
}
# Get joint and marginal probability tables
add_j_m_pt <- function(bn.fit,
nodes = names(bn.fit),
ignore_if_present = TRUE,
attempt_global = FALSE){
## if mpt and jpt already present for each node, ignore
if (ignore_if_present &&
!any(sapply(bn.fit[nodes], function(node) is.null(node$mpt) || is.null(node$jpt)))){
return(bn.fit)
}
bn.fit <- bn_list2bn.fit(bn.fit) # ensure of class bn.fit
if (setequal(names(bn.fit), nodes)){
## attempt using bn.fit2jpt()
bn_list <- tryCatch({
debug_cli(!attempt_global, cli::cli_abort,
"skipping {.fn bn.fit2jpt}")
jpt <- bn.fit2jpt(bn.fit = bn.fit)
lapply(bn.fit, function(node){
node <- node[names(node)]
node$jpt <- query_jpt(jpt = jpt, target = c(node$node, node$parents))
node$mpt <- query_jpt(jpt = node$jpt, target = node$node, given = character(0))
return(node)
})
},
error = function(err){
debug_cli(attempt_global, cli::cli_alert_danger,
c("error using {.fn bn.fit2jpt} for {length(bn.fit)} nodes: ",
"{gsub('\\n', ' ', as.character(err))}"),
.envir = environment())
return(NULL)
})
} else{
nodes <- unique(nodes)
bn_list <- NULL
}
## reattempt using get_jpt() and then empirical jpt
if (length(bn_list) == 0){
debug_cli(attempt_global && setequal(names(bn.fit), nodes), cli::cli_alert,
"reattempting by applying {.fn get_jpt} to each node")
data <- NULL
envir <- environment()
bn_list <- bn.fit[seq_len(length(bn.fit))]
bn_list[nodes] <- lapply(bn.fit[nodes], function(node){
node <- node[names(node)]
node$jpt <- tryCatch({
get_jpt(bn.fit = bn.fit,
nodes = c(node$node, node$parents))
},
error = function(err){
debug_cli(TRUE, cli::cli_alert_danger,
c("error in {.fn get_jpt} for node {node$node}: ",
"{gsub('\\n', ' ', as.character(err))}"),
.envir = environment())
debug_cli(TRUE, cli::cli_alert,
"resorting to empirical jpt for node {node$node}",
.envir = environment())
if (is.null(data)){
debug_cli(TRUE, cli::cli_alert,
"generating {1e7} samples of observational data",
.envir = environment())
assign(x = "data", value = ribn(x = bn.fit,
n = 1e7, seed = 1), envir = envir)
}
do.call(table, sapply(Reduce(union, node[c("node", "parents")]),
function(x) data[[x]], simplify = FALSE)) / nrow(data)
})
node$mpt <- query_jpt(jpt = node$jpt, target = node$node)
return(node)
})
}
return(bn_list)
}
# Remove arcs from bn.fit object
remove_arcs <- function(bn.fit,
arcs = matrix(character(0), ncol = 2),
debug = 0){
as_bn.fit <- class(bn.fit)[1] == "bn.fit"
amat <- bnlearn::amat(bn_list2bn.fit(bn.fit))
if (length(arcs) == 0 || # no arcs specified
all(amat[arcs] == 0)){ # specified arcs are not present
return(bn.fit)
}
if (!is.null(bn.fit[[1]]$coefficients)){ # gnet
bn_list <- bn.fit[seq_len(length(bn.fit))]
} else if (!is.null(bn.fit[[1]]$prob)){ # dnet
bn_list <- add_j_m_pt(bn.fit = bn.fit,
nodes = unique(arcs[, 2]),
ignore_if_present = TRUE)
}
for (node in names(bn_list)){
## parents of node to be removed
remove_nodes <- arcs[arcs[, 2] == node, 1]
if (length(remove_nodes)){
debug_cli(debug, cli::cli_alert,
"removing {paste(remove_nodes, collapse = ',')} -> {node}")
## remove as parents and children
bn_list[[node]]$parents <- setdiff(bn_list[[node]]$parents,
remove_nodes)
bn_list[remove_nodes] <- lapply(bn_list[remove_nodes], function(parent){
parent$children <- setdiff(parent$children,
node)
return(parent)
})
if (!is.null(bn.fit[[1]]$coefficients)){
## remove coefficients
bn_list[[node]]$coefficients <- bn_list[[node]]$coefficients[c("(Intercept)",
bn_list[[node]]$parents)]
} else if (!is.null(bn.fit[[1]]$prob)){
## marginalize out nodes to be removed
bn_list[[node]]$prob <- query_jpt(jpt = bn_list[[node]]$jpt, target = node,
given = bn_list[[node]]$parents) # reduced parents
if (!as_bn.fit){
## also reduce joint probability table
bn_list[[node]]$jpt <- query_jpt(jpt = bn_list[[node]]$jpt,
target = c(node, bn_list[[node]]$parents))
}
}
}
}
if (as_bn.fit){
return(bn_list2bn.fit(bn_list = bn_list))
} else{
return(bn_list)
}
}
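## Example (a sketch): removing an arc drops the parent's coefficient
## (Gaussian case) or marginalizes it out of the CPT (discrete case):
# fit <- load_bn.fit("chain_3", coefs = c(0.5, 1), vars = c(0.5, 1))
# fit0 <- remove_arcs(fit, arcs = cbind("V1", "V2"))
# bnlearn::arcs(fit0) # only V2 -> V3 remains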
# Restrict the number of categorical levels in a discrete Bayesian network
process_dnet <- function(bn.fit,
min_levels = 2,
max_levels = Inf,
merge_order = c("increasing", "decreasing", "random"),
max_in_deg = max(sapply(bn.fit, function(x) length(x$parents))),
max_out_deg = max(sapply(bn.fit, function(x) length(x$children))),
remove_order = c("decreasing", "random"),
min_cp = 0, # minimum conditional probability
ce_lb = 1e-2,
rename = TRUE,
debug = 3){
if (class(bn.fit)[2] == "bn.fit.gnet"){
return(bn.fit)
}
bn_list <- add_j_m_pt(bn.fit = bn.fit)
## restrict levels by merging levels
merge_order <- match.arg(merge_order)
bool_update <- FALSE
debug_cli(debug >= 2, cli::cli_alert_info,
c("restricting {sum(sapply(bn.fit, function(x) dim(x$prob)[1] < min_levels || dim(x$prob)[1] > max_levels))} ",
"nodes to between {min_levels} and {max_levels} levels in {merge_order} order"))
## check each node
for (i in bnlearn::node.ordering(bn.fit)){
node <- bn_list[[i]]
if (length(node$mpt) >= min_levels &&
length(node$mpt) <= max_levels &&
all(node$mpt > 0)) next
bool_update <- TRUE
## determine order to assign
ord <- switch(merge_order,
random = sample(length(node$mpt)),
order(node$mpt, decreasing = (merge_order ==
"decreasing")))
## initialize groups with levels with the highest/lowest/random
## non-zero marginal probability
groups <- lapply(tail(ord[node$mpt[ord] > 0], max_levels), function(x) x)
if (length(groups) < min_levels)
groups <- list(Reduce(union, groups)) # single group
## merge smaller
for (j in head(ord, length(ord) - length(groups))){
## add to group based on merge order
marginals <- sapply(groups, function(x) sum(node$mpt[x]))
k <- switch(merge_order,
random = sample(x = length(groups), size = 1),
increasing = which.min(marginals),
decreasing = which.max(marginals))
groups[[k]] <- c(groups[[k]], j)
}
debug_cli(debug >= 3, cli::cli_alert,
c("for node {i}, merging {length(ord)} levels ({sum(node$mpt > 0)} non-zero) ",
"into {length(groups)} levels"))
## create and replace with new cpt
node_prob <- do.call(abind::abind, c(
## along first dimension corresponding to node
list(along = 1),
## for each group
lapply(groups, function(lvls){
## reshape and add each level in the group
reshape_dim(Reduce(`+`, lapply(lvls, function(idx){
abind::asub(node$prob, idx = idx, dims = 1, drop = FALSE)
})), new_dimnames = dimnames(node$prob), repl = FALSE)
})
))
dimnames(node_prob)[[1]] <- do.call(c, lapply(groups, function(lvls){
paste(dimnames(node$prob)[[1]][lvls], collapse = "|")
}))
names(dimnames(node_prob)) <- c(node$node,
node$parents)
bn_list[[i]]$prob <- validate_cpt(cpt = node_prob,
given = node$parents)
## adjust each child
for (j in node$children){
debug_cli(debug >= 3, cli::cli_alert,
"adjusting child {j} of node {i}")
child <- bn_list[[j]]
along <- match(node$node, names(dimnames(child$prob)))
## create and replace with new cpt
child_prob <- do.call(abind::abind, c(
## along dimension corresponding to node
list(along = along),
## for each group
lapply(groups, function(lvls){
## reshape and add each level in the group
reshape_dim(Reduce(`+`, lapply(lvls, function(idx){
abind::asub(child$prob, idx = idx, dims = along, drop = FALSE) *
node$mpt[idx] # multiply by corresponding marginal
})), new_dimnames = dimnames(child$prob), repl = FALSE) /
sum(node$mpt[lvls]) # normalize by sum of marginals
})
))
dimnames(child_prob)[[along]] <- do.call(c, lapply(groups, function(lvls){
paste(dimnames(child$prob)[[along]][lvls], collapse = "|")
}))
names(dimnames(child_prob)) <- c(child$node,
child$parents)
bn_list[[j]]$prob <- validate_cpt(cpt = child_prob,
given = child$parents)
## cut off parents if fewer than min_levels
if (length(groups) < min_levels){
debug_cli(debug >= 3, cli::cli_alert,
"removing {i} as a parent of {j}")
bn_list[[j]]$parents <- setdiff(bn_list[[j]]$parents,
node$node)
new_dimnames <- dimnames(bn_list[[j]]$prob)[-along]
bn_list[[j]]$prob <- abind::asub(bn_list[[j]]$prob,
idx = 1, dims = along, drop = TRUE)
if (is.null(dim(bn_list[[j]]$prob))){
dim(bn_list[[j]]$prob) <- length(bn_list[[j]]$prob)
dimnames(bn_list[[j]]$prob) <- new_dimnames
}
}
} # end for j in children
## remove node if fewer than min_levels
if (length(groups) < min_levels){
debug_cli(debug >= 3, cli::cli_alert,
"removing node {i}, which has fewer than {min_levels} levels")
## remove as a child of parents
for (j in bn_list[[i]]$parents){
bn_list[[j]]$children <- setdiff(bn_list[[j]]$children,
node$node)
}
## remove node
bn_list <- bn_list[-match(node$node, names(bn_list))]
}
} # end for i in nodes
## update bn.fit
if (bool_update){
bn.fit <- bn_list2bn.fit(bn_list)
bn_list <- add_j_m_pt(bn.fit = bn.fit)
}
## restrict in-degree and out-degree
in_deg <- sapply(bn.fit, function(x) length(x$parents))
out_deg <- sapply(bn.fit, function(x) length(x$children))
if (any(in_deg > max_in_deg) || any(out_deg > max_out_deg)){
## order by which to remove arcs
remove_order <- match.arg(remove_order)
## add jpt and mpt; need jpt for remove_arcs()
bn_list <- add_j_m_pt(bn.fit = bn.fit)
## first, restrict in-degree by removing only the excess arcs
arcs <- do.call(rbind, lapply(bn.fit, function(node){
if (length(node$parents) > max_in_deg){
n_excess <- length(node$parents) - max_in_deg
remove_parents <- switch(remove_order,
decreasing = node$parents[order(out_deg[node$parents],
decreasing = TRUE)[seq_len(n_excess)]],
random = sample(node$parents,
size = n_excess))
return(unname(cbind(remove_parents,
node$node)))
}
return(NULL)
}))
## remove arcs and update degrees
bn_list <- remove_arcs(bn.fit = bn_list,
arcs = arcs)
in_deg <- sapply(bn_list, function(x) length(x$parents))
out_deg <- sapply(bn_list, function(x) length(x$children))
## second, restrict out-degree by removing only the excess arcs
arcs <- do.call(rbind, lapply(bn.fit, function(node){
if (length(node$children) > max_out_deg){
n_excess <- length(node$children) - max_out_deg
remove_children <- switch(remove_order,
decreasing = node$children[order(in_deg[node$children],
decreasing = TRUE)[seq_len(n_excess)]],
random = sample(node$children,
size = n_excess))
return(unname(cbind(node$node,
remove_children)))
}
return(NULL)
}))
## remove arcs and convert to bn.fit
bn_list <- remove_arcs(bn.fit = bn_list,
arcs = arcs)
bn.fit <- bn_list2bn.fit(bn_list = bn_list)
## TODO: remove; temporary for checking
in_deg <- sapply(bn_list, function(x) length(x$parents))
out_deg <- sapply(bn_list, function(x) length(x$children))
if (any(in_deg > max_in_deg) || any(out_deg > max_out_deg)){
browser()
}
}
## minimum conditional probability
small <- names(bn.fit)[which(sapply(bn.fit, function(node){
sum(node$prob < min_cp) > 0
}))]
if (length(small)){
ordering <- bnlearn:::topological.ordering(bn.fit)
ordered <- identical(ordering, bnlearn_nodes(bn.fit))
for (node in small){
bn.fit <- solve_cpt_dnet(bn.fit = bn.fit, target = node,
parents = bn.fit[[node]]$parents,
ce_lb = ce_lb, ordered = ordered,
n_attempts = 100, time_limit = 60,
debug = debug)
}
}
## rename categorical levels of bn.fit
if (rename){
bn.fit <- rename_bn.fit(bn.fit = bn.fit, nodes = "V", categories = TRUE)
}
return(bn.fit)
}
# Generate and solve discrete conditional probability table for new parents using
# Tsamardinos, 2006b: Generating realistic large Bayesian networks by tiling
solve_cpt_dnet <- function(bn.fit,
target,
parents,
ce_lb = 1e-2,
ordered = FALSE,
n_attempts = 100,
time_limit = 60,
debug = 3){
## remove parents of target
if (length(bn.fit[[target]]$parents)){
bn.fit <- remove_arcs(bn.fit = bn.fit, arcs = cbind(bn.fit[[target]]$parents, target))
}
## if no parents, return as is
if (length(parents) == 0)
return(bn.fit)
## reorder bn.fit and new parents
if (!ordered){
bn.fit <- reorder_bn.fit(bn.fit = bn.fit, ordering = TRUE)
revert <- attr(bn.fit, "revert")
}
parents <- intersect(names(bn.fit), parents)
new_dimnames <- sapply(bn.fit[c(target, parents)],
function(x) dimnames(x$prob)[[1]], simplify = FALSE)
parent_dimnames <- list(apply(do.call(expand.grid, new_dimnames[-1]),
1, paste, collapse = "_"))
names(parent_dimnames) <- paste(names(new_dimnames[-1]),
collapse = "_")
temp_dimnames <- c(new_dimnames[1],
parent_dimnames)
dim_tp <- sapply(temp_dimnames, length)
debug_cli(prod(dim_tp) > 2^12, cli::cli_abort,
"new cpt contains {prod(dim_tp)} > {2^12} elements")
debug_cli(debug >= 2, cli::cli_alert_info,
c("attempting to add parent(s) {paste(parents, collapse = ',')} ",
"({dim_tp[2]} configurations) to target {target} ({dim_tp[1]} levels)"),
.envir = environment())
## P(Pa_T = p) where Pa_T is the parents of target variable T
a_p <- c(get_jpt(bn.fit = bn.fit, nodes = parents))
dim(a_p) <- length(a_p)
dimnames(a_p) <- temp_dimnames[-1]
a_p <- reshape_dim(a_p,
new_dimnames = temp_dimnames)
## P(T = t) where T is the target variable
b_t <- bn.fit[[target]]$prob
## objective and constraint functions
obj <- function(par){
return(obj_cpp(par, dim_tp, x_tp_, b_t, a_p))
}
con <- function(par){
return(lapply(con_cpp(par, dim_tp, x_tp_, b_t), c))
}
par2cpt <- function(par){
par <- c(par)
r_t <- par[seq_len(dim_tp[1])]
c_p <- par[-seq_len(dim_tp[1])]
x_tp <- x_tp_ * (as.matrix(r_t) %*% t(c_p))
dimnames(x_tp) <- temp_dimnames
x_tp <- validate_cpt(cpt = x_tp,
given = names(temp_dimnames)[-1])
dim(x_tp) <- sapply(new_dimnames, length)
dimnames(x_tp) <- new_dimnames
return(x_tp)
}
par2dnet <- function(par){
bn_list <- bn.fit[seq_len(length(bn.fit))]
bn_list[[target]]$parents <- parents
bn_list[[target]]$prob <- par2cpt(par)
for (parent in bn_list[[target]]$parents){
bn_list[[parent]]$children <- union(bn_list[[parent]]$children,
target)
}
bn.fit <- bn_list2bn.fit(bn_list)
return(bn.fit)
}
## initialize parameters
par0 <- rep(1, sum(dim_tp))
best <- list(par = par0, fn = Inf, ce = 0)
tolX <- 1e-8
## begin attempts
success <- FALSE
start_time <- Sys.time()
for (i in seq_len(n_attempts)){
null <- tryCatch({
setTimeLimit(time_limit, transient = TRUE)
debug_cli(debug >= 3, cli::cli_alert,
"attempt {i} of adding parent(s) {paste(parents, collapse = ',')} to target {target}",
.envir = environment())
x_tp_ <- runif(n = prod(dim_tp))
dim(x_tp_) <- dim_tp
soln <- NlcOptim::solnl(par0, obj, con,
tolX = tolX, maxIter = 100)
if (ce_lb <= 0 &&
soln$fn < best$fn){
best <- soln
} else if (ce_lb > 0){
if ((soln$ce <- dnet2ce_lb(par2dnet(soln$par))) >= ce_lb ||
is.null(best$grad) ||
soln$ce > best$ce){
debug_cli(debug >= 2 && !is.null(best$grad), cli::cli_alert,
"accepted attempt with ce = {soln$ce}")
best <- soln
} else{
debug_cli(debug >= 2, cli::cli_alert,
"failed attempt with ce = {soln$ce}")
}
}
if (best$fn < 1e-6 &&
(ce_lb <= 0 || best$ce >= ce_lb)){
debug_cli(debug >= 2, cli::cli_alert_success,
c("achieved desired tolerance on attempt {i} with value {best$fn} ",
"after {prettyunits::pretty_sec(as.numeric(Sys.time() - start_time, unit = 'secs'))}"),
.envir = environment())
success <- TRUE
setTimeLimit(Inf, transient = TRUE)
break
}
}, error = function(err){
debug_cli(TRUE, cli::cli_alert_danger,
"error in attempt {i}: {as.character(err)}",
.envir = environment())
})
}
debug_cli(!success, cli::cli_abort,
"failed to add parent(s) {paste(parents, collapse = ',')}
({dim_tp[2]} configurations) to target {target} ({dim_tp[1]} levels)
after {n_attempts} attempts with {best$fn} achieved",
.envir = environment())
## prepare bn.fit
bn.fit <- par2dnet(best$par)
## reverse ordering, if not supplied as ordered
if (!ordered){
bn.fit <- reorder_bn.fit(bn.fit = bn.fit, ordering = revert)
}
return(bn.fit)
}
# Convert bn to a "default" bn.fit.dnet object
bn2dnet <- function(bn,
seed,
min_levels = 2,
max_levels = 2,
marginal_lb = 1e-2,
ce_lb = 1e-2,
n_attempts = 100,
time_limit = 60,
debug = 3){
## TODO: check arguments
bnlearn:::check.bn.or.fit(bn)
if (!missing(seed) && is.numeric(seed) && !is.na(seed))
set.seed(seed)
## attempt
success <- FALSE
start_time <- Sys.time()
for (i in seq_len(n_attempts)){
null <- tryCatch({
setTimeLimit(time_limit, transient = TRUE)
debug_cli(debug >= 3, cli::cli_alert,
c("attempt {i} of generating dnet with {length(bnlearn_nodes(bn))} ",
"nodes and {sum(bnlearn::amat(bn))} edges"),
.envir = environment())
## initialize an empty graph with random marginal probabilities
dist <- sapply(bnlearn_nodes(bn), function(node){
prob <- -1
while (any(prob < marginal_lb)){
n_levels <- floor(runif(1, min = min_levels,
max = max_levels + 1 - .Machine$double.eps))
prob <- runif(n_levels)
prob <- prob / sum(prob)
}
dim(prob) <- length(prob)
dimnames(prob) <- list(seq_len(length(prob)) - 1)
names(dimnames(prob)) <- node
return(prob)
}, simplify = FALSE)
dnet_i <- bnlearn::custom.fit(bnlearn::empty.graph(nodes = bnlearn_nodes(bn)), dist)
## add parents and generate cpts
bn_list <- lapply(dnet_i[seq_len(length(dnet_i))], function(node){
node$children <- bnlearn::children(x = bn, node = node$node)
if (length(bnlearn::parents(x = bn, node = node$node))){
node$parents <- bnlearn::parents(x = bn, node = node$node)
new_dimnames <- c(dimnames(node$prob),
do.call(c, lapply(node$parents, function(x)
dimnames(dnet_i[[x]]$prob)[1])))
node$prob <- reshape_dim(pt = node$prob,
new_dimnames = new_dimnames)
node$prob[] <- runif(length(node$prob))
node$prob <- validate_cpt(cpt = node$prob,
given = node$parents)
}
return(node)
})
bn_list <- add_j_m_pt(bn_list)
dnet_i <- bn_list2bn.fit(bn_list = bn_list)
attr(dnet_i, "marginal_lb") <- min(sapply(bn_list, `[[`, "mpt"))
attr(dnet_i, "ce_lb") <- dnet2ce_lb(dnet = dnet_i)
if (i == 1 ||
(attr(dnet_i, "marginal_lb") >= marginal_lb &&
attr(dnet_i, "ce_lb") > attr(dnet, "ce_lb"))){
dnet <- dnet_i
}
if (attr(dnet, "marginal_lb") >= marginal_lb &&
attr(dnet, "ce_lb") >= ce_lb){
## successfully generated dnet satisfying constraints
debug_cli(debug >= 2, cli::cli_alert_success,
c("successfully generated dnet on attempt {i} ",
"after {prettyunits::pretty_sec(as.numeric(Sys.time() - start_time, unit = 'secs'))}"),
.envir = environment())
success <- TRUE
setTimeLimit(Inf, transient = TRUE)
break
}
},
error = function(err){
debug_cli(TRUE, cli::cli_alert_danger,
"error in attempt {i}: {as.character(err)}",
.envir = environment())
})
}
debug_cli(!success && debug >= 2, cli::cli_alert_danger,
c("failed to generate dnet with {length(bn)} nodes and ",
"{sum(bnlearn::amat(bn))} edges after {n_attempts} attempts, ",
"resulting in marginal_lb = {attr(dnet, 'marginal_lb')} ",
"and ce_lb = {attr(dnet, 'ce_lb')}"),
.envir = environment())
attr(dnet, "marginal_lb") <- attr(dnet,
"ce_lb") <- NULL
return(dnet)
}