#' @import data.tree
### Node Object ###
VEBBoostNode <- R6::R6Class(
"VEBBoostNode",
public = list(
operator = NULL, # either "+" or "*" for internal nodes, NULL for terminal nodes
currentFit = NULL, # current fit for fitting function
fitFunction = NULL, # function that takes in predictors X, response Y, variances sigma2, and an initialization `init` (the current fit), and returns the updated fit
# the fit must have fields mu1 (first moment), mu2 (second moment), and KL_div (KL divergence from q to g)
predFunction = NULL, # function that takes in new data X_new, the current fit, and a moment (1 or 2), and returns the prediction for that moment
constCheckFunction = NULL, # function that takes in the current fit and returns TRUE if the fit is essentially constant
family = "gaussian",
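# For reference, the calls this class makes to the user-supplied functions are:
#   fitFunction(X, Y, sigma2, init = currentFit)  -> list with at least mu1, mu2, KL_div
#   predFunction(X_new, currentFit, moment)       -> predicted moment (1 or 2) on X_new
#   constCheckFunction(currentFit)                -> TRUE/FALSE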
AddChildVEB = function(name, check = c("check", "no-warn", "no-check"), ...) { # add VEB node as child
child = VEBBoostNode$new(as.character(name), check, ...)
return(invisible(self$AddChildNode(child)))
},
updateMoments = function() { # recompute this node's moments from its children (a child calls this on its parent after updating its fit)
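# For children a and b that are independent under q:
#   "+" : E[a + b] = E[a] + E[b],  E[(a + b)^2] = E[a^2] + E[b^2] + 2*E[a]*E[b]
#   "*" : E[a * b] = E[a] * E[b],  E[(a * b)^2] = E[a^2] * E[b^2]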
if (!self$isLeaf) { # if not at a leaf, update moments
if (self$operator == "+") {
children_mu1 = sapply(self$children, function(x) x$mu1)
private$.mu1 = as.numeric(apply(children_mu1, MARGIN = 1, sum))
private$.mu2 = as.numeric(apply(sapply(self$children, function(x) x$mu2), MARGIN = 1, sum) + 2*apply(children_mu1, MARGIN = 1, prod))
} else {
private$.mu1 = as.numeric(apply(sapply(self$children, function(x) x$mu1), MARGIN = 1, prod))
private$.mu2 = as.numeric(apply(sapply(self$children, function(x) x$mu2), MARGIN = 1, prod))
}
}
return(invisible(self))
},
updateMomentsAll = function() { # recompute this node's moments, then recurse upward so every ancestor's moments are refreshed
if (!self$isLeaf) { # if not at a leaf, update moments
if (self$operator == "+") {
children_mu1 = sapply(self$children, function(x) x$mu1)
private$.mu1 = as.numeric(apply(children_mu1, MARGIN = 1, sum))
private$.mu2 = as.numeric(apply(sapply(self$children, function(x) x$mu2), MARGIN = 1, sum) + 2*apply(children_mu1, MARGIN = 1, prod))
} else {
private$.mu1 = as.numeric(apply(sapply(self$children, function(x) x$mu1), MARGIN = 1, prod))
private$.mu2 = as.numeric(apply(sapply(self$children, function(x) x$mu2), MARGIN = 1, prod))
}
}
if (self$isRoot) { # if at root, stop
return(invisible(self))
} else { # else, update parent's moments
return(invisible(self$parent$updateMomentsAll()))
}
},
updateFit = function() { # function to update currentFit
if (self$isLeaf) {
self$currentFit = self$fitFunction(X = self$X, Y = self$Y, sigma2 = self$sigma2, init = self$currentFit)
}
if (!self$isRoot) {
self$parent$updateMoments()
}
return(invisible(self))
},
update_sigma2 = function() { # function to update sigma2
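# variational update for the Gaussian noise variance: sigma2 <- E_q[ ||Y - mu||^2 ] / n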
self$sigma2 = ((sum(self$root$Y^2) - 2*sum(self$root$Y*self$root$mu1) + sum(self$root$mu2))) / length(self$root$Y)
return(invisible(self))
},
convergeFit = function(tol = 1e-3, update_sigma2 = FALSE, update_ELBO_progress = TRUE, verbose = TRUE) {
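# coordinate ascent: update every node's fit in post-order (and optionally sigma2),
# repeating until the change in the ELBO falls below `tol`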
ELBOs = numeric(1000)
ELBOs[1] = -Inf
ELBOs[2] = self$root$ELBO
i = 2
while (abs(ELBOs[i] - ELBOs[i-1]) > tol) {
self$root$Do(function(x) try({x$updateFit()}, silent = TRUE), traversal = 'post-order')
if (update_sigma2) {
self$root$update_sigma2()
}
i = i+1
if (i > length(ELBOs)) { # double size of ELBOs for efficiency rather than making it bigger each iteration
ELBOs = c(ELBOs, rep(0, i))
}
ELBOs[i] = self$root$ELBO
if (verbose && ((i %% 100) == 0)) {
cat(paste("ELBO: ", ELBOs[i], sep = ""))
cat("\n")
}
}
ELBOs = ELBOs[2:i]
if (update_ELBO_progress) {
self$ELBO_progress = ELBOs
}
return(invisible(self$root))
},
AddSiblingVEB = function(learner, operator = c("+", "*"), combine_name) { # function to add subtree as sibling to given node, combining with operator
# learner is tree to add to self
# operator is how to combine them
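# before: parent -> self (a leaf); after: parent -> self (internal node with `operator`)
#         whose children are a clone carrying self's old fit and `learner`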
# combine_name is what to call the new combining node
operator = match.arg(operator) # guard against the unmatched default c("+", "*")
learner_copy = self$clone()
learner_copy$X = NULL
self$fitFunction = NULL
self$currentFit = NULL
self$predFunction = NULL
self$constCheckFunction = NULL
self$operator = operator
self$name = combine_name
self$AddChildNode(learner_copy)
self$AddChildNode(learner)
learner_copy$parent$updateMoments()
return(invisible(self$root)) # return root, so we can assign entire tree to new value (in case we're splitting at root)
},
lockSelf = function() { # if this leaf's fit is essentially constant, lock it and swap in the constant-component fit/predict/check functions
if (self$isConstant) {
self$isLocked = TRUE
self$fitFunction = private$.fitFnConstComp
self$predFunction = private$.predFnConstComp
self$constCheckFunction = private$.constCheckFnConstComp
self$updateFit()
try({self$updateMomentsAll()}, silent = TRUE) # needed to avoid "attempt to apply non-function" error
}
return(invisible(self))
},
lockLearners = function() { # lock leaves that are essentially constant, and leaves whose directly-connected learners are all locked
self$root$Do(function(node) node$lockSelf(), filterFun = function(node) node$isLeaf & !node$isLocked, traversal ='post-order')
base_learners = Traverse(self$root, filterFun = function(x) x$isLeaf & !x$isLocked)
for (learner in base_learners) {
if (learner$isRoot) { # if root, not locked, do this to avoid errors in next if statement
next
}
# only lock when both directly-connected learners exist and are locked; a child of the root
# has no parent-sibling, so guard against that case (as with the isRoot check above)
if (learner$siblings[[1]]$isLocked && !learner$parent$isRoot && learner$parent$siblings[[1]]$isLocked) {
learner$isLocked = TRUE
}
}
return(invisible(self$root))
},
addLearnerAll = function() { # to each unlocked leaf, attach a new "+" learner and a new "*" learner
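# each unlocked leaf L is replaced by (L * mu_mult) + mu_add, with mu_add initialized to the
# zero fit and mu_mult to the unit fit, so the expanded tree starts exactly at the old fit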
self$root$lockLearners() # lock learners
base_learners = Traverse(self$root, filterFun = function(x) x$isLeaf & !x$isLocked)
for (learner in base_learners) {
fitFn = learner$fitFunction
predFn = learner$predFunction
constCheckFn = learner$constCheckFunction
learner_name = paste("mu_", learner$root$leafCount, sep = '')
combine_name = paste("combine_", learner$root$leafCount, sep = '')
add_fit = list(mu1 = rep(0, length(learner$Y)), mu2 = rep(0, length(learner$Y)), KL_div = 0)
add_node = VEBBoostNode$new(learner_name, fitFunction = fitFn, predFunction = predFn, constCheckFunction = constCheckFn, currentFit = add_fit)
learner$AddSiblingVEB(add_node, "+", combine_name)
learner_name = paste("mu_", learner$root$leafCount, sep = '')
combine_name = paste("combine_", learner$root$leafCount, sep = '')
mult_fit = list(mu1 = rep(1, length(learner$Y)), mu2 = rep(1, length(learner$Y)), KL_div = 0)
mult_node = VEBBoostNode$new(learner_name, fitFunction = fitFn, predFunction = predFn, constCheckFunction = constCheckFn, currentFit = mult_fit)
learner$children[[1]]$AddSiblingVEB(mult_node, "*", combine_name)
}
return(invisible(self$root))
},
convergeFitAll = function(tol = 1e-3, update_sigma2 = FALSE, update_ELBO_progress = TRUE, verbose = FALSE) {
self$addLearnerAll()
self$convergeFit(tol, update_sigma2, update_ELBO_progress, verbose)
return(invisible(self$root))
},
predict.veb = function(X_new, moment = c(1, 2)) { # function to get prediction on new data
self$root$Do(function(node) {
if (1 %in% moment) {
node$pred_mu1 = node$predFunction(X_new, node$currentFit, 1)
}
if (2 %in% moment) {
node$pred_mu2 = node$predFunction(X_new, node$currentFit, 2)
}
}, filterFun = function(x) x$isLeaf)
return(invisible(self$root))
}
),
private = list(
.X = NULL, # predictors for node (if NA, use value of parent)
.Y = NA, # response, only not NA at root
.sigma2 = NA, # variance, only not NA at root
.mu1 = NA, # current first moment
.mu2 = NA, # current second moment
.ELBO_progress = list(-Inf), # list of ELBOs for each iteration of growing and fitting the tree
.pred_mu1 = NULL, # prediction based on predFunction and given new data (first moment)
.pred_mu2 = NULL, # prediction based on predFunction and given new data (second moment)
.isLocked = FALSE, # locked <=> V < V_tol, or both learners directly connected (sibling and parent's sibling) are constant
.alpha = 0, # used in multi-class learner
.ensemble = NULL, # used in multi-class learner, either reference to self, or multi-class learner object
.fitFnConstComp = function(X, Y, sigma2, init) { # constant fit function
if (length(sigma2) == 1) {
sigma2 = rep(sigma2, length(Y))
}
intercept = weighted.mean(Y, 1/sigma2)
KL_div = 0
mu1 = intercept
mu2 = intercept^2
return(list(mu1 = mu1, mu2 = mu2, intercept = intercept, KL_div = KL_div))
},
.predFnConstComp = function(X_new, currentFit, moment = c(1, 2)) { # constant prediction function
if (moment == 1) {
return(currentFit$intercept)
} else if (moment == 2) {
return(currentFit$intercept^2)
} else {
stop("`moment` must be either 1 or 2")
}
},
.constCheckFnConstComp = function(currentFit) {
return(TRUE)
},
.g = function(x) { # inverse logit function
1 / (1 + exp(-x))
}
),
active = list(
ensemble = function(value) {
if (missing(value)) {
if (self$isRoot) {
if (is.null(private$.ensemble)) {
return(invisible(self))
} else {
return(private$.ensemble)
}
} else {
return(invisible(self$root$ensemble))
}
}
if (!self$isRoot) {
stop("`$ensemble` cannot be modified except at the root", call. = FALSE)
} else {
private$.ensemble = value
}
},
isEnsemble = function(value) { # TRUE if is ensemble (root AND not part of higher ensemble), else FALSE
if (!missing(value)) {
stop("`$isEnsemble` cannot be modified directly", call. = FALSE)
}
return((self$isRoot) && (is.null(private$.ensemble)))
},
mu1 = function(value) {
if (!missing(value)) {
stop("`$mu1` cannot be modified directly", call. = FALSE)
}
if (self$isLeaf) {
mu1 = self$currentFit$mu1
} else {
mu1 = private$.mu1
}
if (length(mu1) == 1) {
mu1 = rep(mu1, length(self$root$raw_Y))
}
return(mu1)
},
mu2 = function(value) {
if (!missing(value)) {
stop("`$mu2` cannot be modified directly", call. = FALSE)
}
if (self$isLeaf) {
mu2 = self$currentFit$mu2
} else {
mu2 = private$.mu2
}
if (length(mu2) == 1) {
mu2 = rep(mu2, length(self$root$raw_Y))
}
return(mu2)
},
KL_div = function(value) { # KL divergence from q to g of learner defined by sub-tree with this node as the root
if (!missing(value)) {
stop("`$KL_div` cannot be modified directly", call. = FALSE)
}
if (self$isLeaf) {
return(self$currentFit$KL_div)
}
return(sum(sapply(self$children, function(x) x$KL_div)))
},
X = function(value) { # predictor for node
if (missing(value)) {
if (self$isRoot) {
if (self$isEnsemble) {
return(private$.X)
} else {
return(self$ensemble$X)
}
}
if (is.null(private$.X)) {
return(self$parent$X)
} else {
return(private$.X)
}
}
private$.X = value
},
Y = function(value) { # response for sub-tree
if (missing(value)) {
if (self$isRoot) {
if (self$root$family == "gaussian") {
return(private$.Y)
} else if (self$root$family == "binomial") {
d = self$d
return((private$.Y - .5 + (self$alpha * d)) / d) # alpha*d needed for multi-class case
} else {
stop("family must be either 'gaussian' or 'binomial'")
}
}
if (self$parent$operator == "+") {
return(self$parent$Y - self$siblings[[1]]$mu1)
}
if (self$parent$operator == "*") {
return(self$parent$Y * (self$siblings[[1]]$mu1 / self$siblings[[1]]$mu2))
}
} else {
if (self$isRoot) {
private$.Y = value
} else {
stop("`$Y` cannot be modified directly except at the root node", call. = FALSE)
}
}
},
raw_Y = function(value) { # raw value of private$.Y, only needed for logistic ELBO
if (!missing(value)) {
stop("`$raw_Y` cannot be modified directly", call. = FALSE)
}
if (self$isRoot) {
return(private$.Y)
} else {
return(self$parent$raw_Y)
}
},
sigma2 = function(value) { # variance for sub-tree
if (missing(value)) {
if (self$isRoot) {
if (self$root$family == "gaussian") {
return(private$.sigma2)
} else if (self$root$family == "binomial") {
return(1 / self$d)
} else {
stop("family must be either 'gaussian' or 'binomial'")
}
}
if (self$parent$operator == "+") {
s2 = self$parent$sigma2
if (length(s2) == 1) {
s2 = rep(s2, length(self$raw_Y))
}
return(s2)
}
if (self$parent$operator == "*") {
s2 = self$parent$sigma2 / self$siblings[[1]]$mu2
if (length(s2) == 1) {
s2 = rep(s2, length(self$raw_Y))
}
return(s2)
}
} else {
if (self$isRoot) {
private$.sigma2 = value
} else {
stop("`$sigma2` cannot be modified directly except at the root node", call. = FALSE)
}
}
},
ELBO_progress = function(value) { # ELBO progress of tree
if (missing(value)) {
if (self$isRoot) {
return(private$.ELBO_progress)
} else {
return(self$root$ELBO_progress)
}
} else {
if (self$isRoot) { # append ELBOs in list
private$.ELBO_progress[[length(private$.ELBO_progress) + 1]] = value
} else {
stop("`$ELBO_progress` cannot be modified directly except at the root node", call. = FALSE)
}
}
},
ELBO = function(value) { # ELBO for entire tree
if (!missing(value)) {
stop("`$ELBO` cannot be modified directly", call. = FALSE)
}
# expand sigma2 into a vector of variances; ifelse() is not used here because it is vectorized
# over its (length-1) test and so would return a length-1 result rather than the full vector
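# Gaussian: ELBO = E_q[ log N(Y | mu, sigma2) ] - KL(q || g)
# binomial: a Jaakkola-Jordan-style quadratic lower bound on the Bernoulli log-likelihood,
#           evaluated at the variational parameters xi, minus the same KL term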
if (self$root$family == "gaussian") {
s2 = self$sigma2
if (length(s2) == 1) {
s2 = rep(s2, length(self$raw_Y))
}
return(
(-.5 * sum(log(2*pi*s2))) -
(.5 * (sum(self$Y^2 / s2) - 2*sum(self$Y * self$mu1 / s2) + sum(self$mu2 / s2))) -
self$KL_div
)
} else if (self$root$family == "binomial") {
d = self$root$d
xi = self$root$xi
return(
sum(log(private$.g(xi))) + sum((xi / 2)*(d*xi - 1)) + sum((self$root$raw_Y - .5) * self$root$mu1) -
.5*sum(self$root$mu2 * d) - self$KL_div
)
} else {
stop("family must be either 'gaussian' or 'binomial'")
}
},
alpha = function(value) { # used in multi-class learner
if (!missing(value)) {
stop("`$alpha` cannot be modified directly except at the ensemble level", call. = FALSE)
}
if (self$isEnsemble) { # if isEnsemble, e.g. if we're in linear or logistic case, return root$.alpha (SHOULD BE 0 IN THIS CASE)
return(private$.alpha)
}
return(self$ensemble$alpha) # otherwise, return ensemble's alpha
},
xi = function(value) { # optimal variational parameters, set to +sqrt(E[(mu - alpha)^2]) = sqrt(mu2 - 2*alpha*mu1 + alpha^2)
if (!missing(value)) {
stop("`$xi` cannot be modified directly", call. = FALSE)
}
return(sqrt(self$root$mu2 + self$alpha^2 - 2*self$alpha*self$root$mu1))
},
d = function(value) { # d = (g(xi) - .5) / xi (an n-vector here; an n x K matrix in the multi-class learner)
if (!missing(value)) {
stop("'$d' cannot be modified directly", call. = FALSE)
}
xi = self$xi
g_xi = private$.g(xi) # matrix of g(xi_i,k), pre-compute once
d = ((g_xi - .5) / xi) # matrix of (g(xi_i,k) - .5) / xi_i,k, pre-compute once
d[xi == 0] = .25 # case of 0/0 (e.g. x_i is all 0), use L'Hopital
return(d)
},
isConstant = function(value) { # uses node's constant check function and current fit to see if it is essentially a constant function
if (!missing(value)) {
stop("'$isConstant' cannot be modified directly", call. = FALSE)
}
return(self$constCheckFunction(self$currentFit))
},
isLocked = function(value) { # locked <=> V < V_tol, or both learners directly connected (sibling and parent's sibling) are constant
if (missing(value)) {
if (self$isLeaf) {
return(private$.isLocked)
} else {
return(all(sapply(self$children, function(x) x$isLocked)))
}
} else {
if (self$isLeaf) {
private$.isLocked = value
} else {
stop("`$isLocked` cannot be modified directly except at leaf nodes", call. = FALSE)
}
}
},
pred_mu1 = function(value) { # predicted first moment given new data
if (!missing(value)) {
if (self$isLeaf) {
private$.pred_mu1 = value
return(invisible(self))
} else {
stop("`$pred_mu1` cannot be modified directly except at leaf nodes by calling `$predict.veb`", call. = FALSE)
}
}
if (self$isLeaf) {
return(private$.pred_mu1)
} else if (self$operator == "+") {
return(apply(do.call(cbind, lapply(self$children, function(x) x$pred_mu1)), MARGIN = 1, sum))
} else {
return(apply(do.call(cbind, lapply(self$children, function(x) x$pred_mu1)), MARGIN = 1, prod))
}
},
pred_mu2 = function(value) { # predicted second moment given new data
if (!missing(value)) {
if (self$isLeaf) {
private$.pred_mu2 = value
return(invisible(self))
} else {
stop("`$pred_mu2` cannot be modified directly except at leaf nodes by calling `$predict.veb`", call. = FALSE)
}
}
if (self$isLeaf) {
return(private$.pred_mu2)
} else if (self$operator == "+") {
return(apply(do.call(cbind, lapply(self$children, function(x) x$pred_mu2)), MARGIN = 1, sum) + 2*apply(do.call(cbind, lapply(self$children, function(x) x$pred_mu1)), MARGIN = 1, prod))
} else {
return(apply(do.call(cbind, lapply(self$children, function(x) x$pred_mu2)), MARGIN = 1, prod))
}
}
),
inherit = data.tree::Node,
lock_objects = FALSE
)
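### Usage sketch (illustrative only, not run) ###
# A minimal example of how this node class is intended to be driven. The base-learner
# functions below (`fitFnIntercept`, `predFnIntercept`, `constCheckFnIntercept`) are
# hypothetical stand-ins written for this sketch; any functions with the interfaces
# described at the top of the class can be substituted.
if (FALSE) {
  library(data.tree) # Traverse() is used unqualified inside the class

  # toy data
  set.seed(1)
  n = 100
  X = matrix(rnorm(2 * n), nrow = n)
  Y = X[, 1] + rnorm(n)

  # trivial base learner: a precision-weighted intercept (mirrors the private constant component)
  fitFnIntercept = function(X, Y, sigma2, init) {
    if (length(sigma2) == 1) sigma2 = rep(sigma2, length(Y))
    b = weighted.mean(Y, 1 / sigma2)
    list(mu1 = rep(b, length(Y)), mu2 = rep(b^2, length(Y)), intercept = b, KL_div = 0)
  }
  predFnIntercept = function(X_new, currentFit, moment = c(1, 2)) {
    m = if (moment == 1) currentFit$intercept else currentFit$intercept^2
    rep(m, nrow(X_new))
  }
  constCheckFnIntercept = function(currentFit) FALSE # never lock, so the tree keeps growing

  # build a single-leaf tree, attach the data at the root, and fit
  veb = VEBBoostNode$new("mu_0",
                         fitFunction = fitFnIntercept,
                         predFunction = predFnIntercept,
                         constCheckFunction = constCheckFnIntercept,
                         currentFit = list(mu1 = rep(0, n), mu2 = rep(0, n), KL_div = 0))
  veb$X = X
  veb$Y = Y
  veb$sigma2 = 1
  veb = veb$convergeFitAll(tol = 1e-3, update_sigma2 = TRUE, verbose = FALSE)

  # first-moment predictions on new data are accumulated through the tree via `pred_mu1`
  veb$predict.veb(X, moment = 1)
  head(veb$pred_mu1)
  tail(unlist(veb$ELBO_progress))
}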