# hytreew.R
# ::rtemis::
# 2018 E.D. Gennatas www.lambdamd.org
#' `rtemis internal`: Low-level Hybrid Tree procedure
#'
#' Train a Hybrid Tree for Regression
#'
#' An initial linear model is fit on `(x, y)`; [hytw] then recursively adds a
#' stump plus a pair of weighted linear models to the residual at each node.
#' The rule and cumulative coefficient vector of every leaf are collected and
#' returned.
#'
#' Note that lambda is treated differently by `glmnet::glmnet` and `MASS::lm.ridge`
#'
#' @inheritParams s_LIHAD
#' @param x data.frame
#' @param y Numeric vector: Outcome
#' @param max.depth Integer: Nodes at this depth are not split further
#' @param weights Numeric vector: Case weights handed to the root node
#' @param init Numeric: Initial value added to every prediction
#' @param gamma Numeric: At each split, the weights of the cases falling on the
#' other side of the split are multiplied by this value
#' @param shrinkage Numeric: Multiplier applied to each additive update
#' @param lin.type Character: "glmnet", "cv.glmnet" use the equivalent `glmnet` functions.
#' "lm.ridge" uses the MASS function of that name, "glm" uses `lm.fit`,
#' "forward.stagewise" and "stepwise" use `lars::lars` with `type` defined accordingly
#' @param save.fitted Logical: If TRUE, add the fitted values to the output as
#' `fitted`
#' @param n.cores Not used by this function
#'
#' @return List of class `hytreew` with elements `init`, `shrinkage`, `rules`
#' (one per leaf) and `coefs` (list with one coefficient vector per leaf), plus
#' `fitted` if `save.fitted = TRUE`
#'
#' @author E.D. Gennatas
#' @keywords internal
#' @noRd
hytreew <- function(x, y,
                    max.depth = 5,
                    weights = rep(1, NROW(y)),
                    init = mean(y),
                    gamma = .1,
                    shrinkage = 1,
                    # lincoef --
                    lin.type = "glmnet",
                    cv.glmnet.nfolds = 5,
                    which.cv.glmnet.lambda = "lambda.min",
                    alpha = 1,
                    lambda = .05,
                    lambda.seq = NULL,
                    # rpart --
                    minobsinnode = 2,
                    minobsinnode.lin = 10,
                    part.minsplit = 2,
                    part.xval = 0,
                    part.max.depth = 1,
                    part.cp = 0,
                    part.minbucket = 5,
                    #
                    save.fitted = FALSE,
                    verbose = TRUE,
                    trace = 0,
                    n.cores = rtCores) {
  # Check y is not constant ----
  # A constant outcome needs no tree: return a single always-true rule whose
  # coefficients are all zero, so that predictions equal `init`.
  if (is_constant(y)) {
    feat.names <- colnames(x)
    if (is.null(feat.names)) feat.names <- paste0("V", seq_len(NCOL(x)))
    zero.coef <- rep(0, NCOL(x) + 1)
    # Name the coefficient vector itself (not the enclosing list)
    names(zero.coef) <- c("(Intercept)", feat.names)
    .mod <- list(
      init = init,
      shrinkage = shrinkage,
      rules = "TRUE",
      coefs = list(zero.coef)
    )
    class(.mod) <- c("hytreew", "list")
    if (save.fitted) .mod$fitted <- predict(.mod, x)
    return(.mod)
  }

  # Globals ----
  # This function's own environment is handed down the recursion so that
  # hytw/partLmw can read the data and append leaf rules/coefficients to it.
  .env <- environment()
  .env$x <- x
  .env$y <- y
  .env$df <- data.frame(x, y)
  .env$dm <- data.matrix(x)
  .env$gamma <- gamma

  # lin1 ----
  if (verbose) msg2("Training Hybrid Tree (max depth = ", max.depth, ")...", sep = "")
  if (trace > 0) msg2("Training lin1...", color = red)
  coef.c <- lincoef(x, y,
    method = lin.type,
    alpha = alpha, lambda = lambda, lambda.seq = lambda.seq,
    cv.glmnet.nfolds = cv.glmnet.nfolds,
    which.cv.glmnet.lambda = which.cv.glmnet.lambda
  )
  Fval <- init + shrinkage * (data.matrix(cbind(1, x)) %*% coef.c)[, 1] # n
  if (trace > 0) msg2("hytreew Fval is", head(Fval), color = red)

  # Run hytw ----
  root <- list(
    x = x,
    y = y,
    Fval = Fval,
    weights = weights,
    index = rep(1, length(y)),
    depth = 0,
    partlin = NULL, # To hold the output of partLmw
    left = NULL, # \ To hold the left and right nodes,
    right = NULL, # / if partLmw splits
    lin = NULL,
    part = NULL,
    coef.c = coef.c,
    terminal = FALSE,
    type = NULL,
    rule = "TRUE"
  )
  # The returned tree is kept only for parity with the original code; the
  # model is assembled from what hytw appends to .env$leaf.rule/.env$leaf.coef
  mod <- hytw(
    node = root,
    max.depth = max.depth,
    minobsinnode = minobsinnode,
    minobsinnode.lin = minobsinnode.lin,
    shrinkage = shrinkage,
    alpha = alpha,
    lambda = lambda,
    lambda.seq = lambda.seq,
    cv.glmnet.nfolds = cv.glmnet.nfolds,
    which.cv.glmnet.lambda = which.cv.glmnet.lambda,
    coef.c = coef.c,
    part.minsplit = part.minsplit,
    part.xval = part.xval,
    part.max.depth = part.max.depth,
    part.cp = part.cp,
    part.minbucket = part.minbucket,
    .env = .env,
    keep.x = FALSE,
    simplify = TRUE,
    lin.type = lin.type,
    verbose = verbose,
    trace = trace
  )

  # Outro ----
  .mod <- list(
    init = init,
    shrinkage = shrinkage,
    rules = .env$leaf.rule,
    coefs = .env$leaf.coef
  )
  class(.mod) <- c("hytreew", "list")
  if (save.fitted) .mod$fitted <- predict(.mod, x)
  .mod
} # rtemis::hytreew
# Recursive function
# Grows one node of the hybrid tree.
#
# If the node is below `max.depth` and holds at least `minobsinnode` index
# entries, partLmw is run on the current residual (`.env$y - node$Fval`).
# When partLmw splits, a left and a right child are created - each with its own
# weights, updated fitted values, cumulative coefficients and rule - and hytw
# is called on both (left first). Otherwise the node becomes a leaf and its
# rule and cumulative coefficients are appended to `.env$leaf.rule` and
# `.env$leaf.coef`.
#
# node: list describing the node (see default for the expected fields)
# coef.c: cumulative coefficient vector (intercept first) leading to this node
# .env: environment holding `x`, `y`, `dm`, `gamma` and the leaf accumulators
# keep.x: not used in the body other than being passed down the recursion
# simplify: if TRUE, drop bulky fields from nodes once they are processed
# Remaining arguments are passed on to partLmw.
#
# Returns the (possibly simplified) node, with `left`/`right` populated if it
# was split.
hytw <- function(node = list(
                   x = NULL,
                   y = NULL,
                   Fval = NULL,
                   weights = NULL,
                   index = NULL,
                   depth = NULL,
                   partlin = NULL, # To hold the output of partLmw
                   left = NULL, # \ To hold the left and right nodes,
                   right = NULL, # / if partLmw splits
                   lin = NULL,
                   part = NULL,
                   coef.c = NULL,
                   terminal = NULL,
                   type = NULL,
                   rule = NULL
                 ),
                 coef.c = 0,
                 max.depth = 5,
                 minobsinnode = 2,
                 minobsinnode.lin = 5,
                 shrinkage = 1,
                 # lincoef --
                 lin.type = "glmnet",
                 alpha = 1,
                 lambda = .1,
                 lambda.seq = NULL,
                 cv.glmnet.nfolds = 5,
                 which.cv.glmnet.lambda = "lambda.min",
                 # rpart --
                 part.minsplit = 2,
                 part.xval = 0,
                 part.max.depth = 1,
                 part.minbucket = 5,
                 part.cp = 0,
                 .env = NULL,
                 keep.x = FALSE,
                 simplify = TRUE,
                 verbose = TRUE,
                 trace = 0) {
  # Exit ----
  if (node$terminal) {
    return(node)
  }

  # Data always comes from .env: every node sees all cases and differs only in
  # its weights and current fitted values
  x <- .env$x
  y <- .env$y
  depth <- node$depth
  Fval <- node$Fval # n
  resid <- y - Fval # n
  if (trace > 0) msg2("hytw Fval is", head(Fval), color = red)
  if (trace > 0) msg2("hytw resid is", head(resid), color = red)
  nobsinnode <- length(node$index)

  # Local helpers ----
  # Append a leaf's rule and cumulative coefficients to the accumulators in .env
  record.leaf <- function(leaf) {
    .env$leaf.rule <- c(.env$leaf.rule, leaf$rule)
    .env$leaf.coef <- c(.env$leaf.coef, list(leaf$coef.c))
  }
  # Fields removed from a leaf when simplify = TRUE
  leaf.drop <- c("x", "y", "Fval", "index", "depth", "type", "partlin")
  # Build a fresh, unprocessed child node
  new.child <- function(child.weights, child.Fval, child.index, child.coef, child.rule) {
    list(
      x = .env$x,
      y = .env$y,
      weights = child.weights,
      Fval = child.Fval,
      index = child.index,
      depth = depth + 1,
      coef.c = child.coef,
      partlin = NULL, # To hold the output of partLmw
      left = NULL, # \ To hold the left and right nodes,
      right = NULL, # / if partLmw splits
      terminal = FALSE,
      type = NULL,
      rule = paste0(node$rule, " & ", child.rule)
    )
  }
  # Recurse into a child with the same settings as this call
  grow <- function(child, child.coef) {
    hytw(child,
      coef.c = child.coef,
      max.depth = max.depth,
      minobsinnode = minobsinnode,
      minobsinnode.lin = minobsinnode.lin,
      shrinkage = shrinkage,
      # lincoef --
      lin.type = lin.type,
      alpha = alpha,
      lambda = lambda,
      lambda.seq = lambda.seq,
      cv.glmnet.nfolds = cv.glmnet.nfolds,
      which.cv.glmnet.lambda = which.cv.glmnet.lambda,
      # rpart --
      part.minsplit = part.minsplit,
      part.xval = part.xval,
      part.max.depth = part.max.depth,
      part.cp = part.cp,
      part.minbucket = part.minbucket,
      .env = .env,
      keep.x = keep.x,
      simplify = simplify,
      verbose = verbose,
      trace = trace
    )
  }

  if (!(node$depth < max.depth && nobsinnode >= minobsinnode)) {
    # max.depth or minobsinnode reached ----
    node$terminal <- TRUE
    record.leaf(node)
    if (node$depth == max.depth) {
      if (trace > 0) msg2("STOP: max.depth")
      node$type <- "max.depth"
    } else if (nobsinnode < minobsinnode) {
      if (trace > 0) msg2("STOP: minobsinnode")
      node$type <- "minobsinnode"
    }
    if (simplify) node[leaf.drop] <- NULL
    return(node)
  }

  # Add partlin to node ----
  node$partlin <- partLmw(
    x1 = x, y1 = resid,
    weights = node$weights,
    .env = .env,
    minobsinnode.lin = minobsinnode.lin,
    # lincoef --
    lin.type = lin.type,
    alpha = alpha,
    lambda = lambda,
    lambda.seq = lambda.seq,
    cv.glmnet.nfolds = cv.glmnet.nfolds,
    which.cv.glmnet.lambda = which.cv.glmnet.lambda,
    # rpart --
    part.minsplit = part.minsplit,
    part.xval = part.xval,
    part.max.depth = part.max.depth,
    part.cp = part.cp,
    part.minbucket = part.minbucket,
    verbose = verbose,
    trace = trace
  )
  pl <- node$partlin

  if (pl$terminal) {
    # partLmw did not split ----
    node$terminal <- TRUE
    record.leaf(node)
    node$type <- "nosplit"
    if (trace > 0) msg2("STOP: nosplit")
    if (simplify) node[leaf.drop] <- NULL
    return(node)
  }

  # '- Node split ----
  node$type <- "split"
  if (trace > 1) msg2("Depth:", depth, "left.index:", pl$left.index)

  # Each side adds the stump prediction plus its own linear fit
  Fval.left <- Fval + shrinkage * (pl$part.val + pl$lin.val.left)
  Fval.right <- Fval + shrinkage * (pl$part.val + pl$lin.val.right)

  # Cumulative coefficients: the stump's constant is folded into the intercept
  coef.c.left <- coef.c + c(
    pl$lin.coef.left[1] + pl$part.c.left,
    pl$lin.coef.left[-1]
  )
  coef.c.right <- coef.c + c(
    pl$lin.coef.right[1] + pl$part.c.right,
    pl$lin.coef.right[-1]
  )
  if (trace > 1) msg2("coef.c.left is", coef.c.left, "coef.c.right is", coef.c.right)

  # Init Left and Right nodes
  node$left <- new.child(
    pl$weights.left, Fval.left, pl$left.index, coef.c.left, pl$rule.left
  )
  node$right <- new.child(
    pl$weights.right, Fval.right, pl$right.index, coef.c.right, pl$rule.right
  )
  node$split.rule <- pl$split.rule
  if (simplify) {
    node[c("y", "Fval", "index", "depth", "lin", "part", "type", "partlin")] <- NULL
  }

  # Run Left and Right nodes
  # Left ----
  if (trace > 0) msg2("Depth = ", depth + 1, "; Working on Left node...", sep = "")
  node$left <- grow(node$left, coef.c.left)
  # Right ----
  if (trace > 0) msg2("Depth = ", depth + 1, "; Working on Right node...", sep = "")
  node$right <- grow(node$right, coef.c.right)
  if (simplify) node$coef.c <- NULL

  node
} # rtemis::hytw
#' \pkg{rtemis} internal: Ridge and Stump
#'
#' Fit a tree (`rpart`) on `(x1, y1)`, then fit two weighted linear models -
#' one per side of the first split - on the residual `y1 - predict(tree)`.
#' All cases take part in both linear fits; the weights of the cases that fall
#' on the other side of the split are multiplied by `.env$gamma`.
#'
#' @param x1 data.frame of features
#' @param y1 Numeric vector: outcome to fit (the caller passes the current
#' residual)
#' @param weights Numeric vector of case weights
#' @param .env Environment providing `gamma` and `dm` (numeric matrix of
#' features used to compute the linear predictions)
#' @param minobsinnode.lin Integer: No linear models are fit if there are fewer
#' residuals than this
#'
#' @return List with the per-side weights, linear coefficients and linear
#' predictions, the tree's fitted values and per-side constants, the split
#' feature/point/categories, the row indices and rules of each side, and
#' `terminal`, which is TRUE if the tree did not split.
#'
#' @keywords internal
#' @noRd
partLmw <- function(x1, y1,
                    weights,
                    .env,
                    minobsinnode.lin,
                    # lincoef --
                    lin.type,
                    alpha,
                    lambda,
                    lambda.seq,
                    cv.glmnet.nfolds,
                    which.cv.glmnet.lambda,
                    # rpart --
                    part.minsplit,
                    part.xval,
                    part.max.depth,
                    part.cp,
                    part.minbucket,
                    verbose,
                    trace) {
  # Part ----
  if (trace > 1) msg2("partLmw")
  dat <- data.frame(x1, y1)
  part <- rpart::rpart(y1 ~ ., dat,
    weights = weights,
    control = rpart::rpart.control(
      minsplit = part.minsplit,
      xval = part.xval,
      maxdepth = part.max.depth,
      minbucket = part.minbucket,
      cp = part.cp
    )
  )
  part.val <- predict(part) # n

  # Defaults ----
  # These describe the "no split" outcome, so that every element of the
  # returned list exists whichever branch is taken below.
  n.coef <- NCOL(x1) + 1
  n.cases <- NROW(x1)
  terminal <- TRUE
  cutFeat.name <- cutFeat.point <- cutFeat.category <- NULL
  part.c.left <- part.c.right <- 0
  left.index <- right.index <- split.rule.left <- split.rule.right <- NULL
  lin.val.left <- lin.val.right <- 0
  lin.coef.left <- lin.coef.right <- rep(0, n.coef)
  weights.left <- weights.right <- weights

  if (is.null(part$splits)) {
    if (trace > 0) msg2("Note: rpart did not split")
  } else {
    # Column 2 of `splits` is `ncat`; a value of 1 means that cases below the
    # cut point are sent to the second child row of `frame`
    if (part$splits[1, 2] == 1) {
      left.yval.row <- 3
      right.yval.row <- 2
    } else {
      left.yval.row <- 2
      right.yval.row <- 3
    }
    part.c.left <- part$frame$yval[left.yval.row]
    part.c.right <- part$frame$yval[right.yval.row]
    cutFeat.name <- rownames(part$splits)[1]

    if (!is.null(cutFeat.name)) {
      terminal <- FALSE
      feat <- x1[[cutFeat.name]]
      if (is.numeric(feat)) {
        cutFeat.point <- part$splits[1, "index"]
        if (trace > 0) {
          msg2("Split Feature is \"", cutFeat.name,
            "\"; Cut Point = ", cutFeat.point,
            sep = ""
          )
        }
        split.rule.left <- paste(cutFeat.name, "<", cutFeat.point)
        split.rule.right <- paste(cutFeat.name, ">=", cutFeat.point)
        left.index <- which(feat < cutFeat.point)
      } else {
        # Levels flagged 1 in `csplit` go to the left child
        cutFeat.category <- levels(as.factor(feat))[which(part$csplit[1, ] == 1)]
        if (trace > 0) {
          msg2("Split Feature is \"", cutFeat.name,
            "\"; Cut Category is \"", cutFeat.category,
            "\"",
            sep = ""
          )
        }
        # Quote each level (escaping backslashes and quotes) and close the
        # c(...) call so that the rule is a parseable R expression
        quoted <- paste0("'", gsub("(\\\\|')", "\\\\\\1", cutFeat.category), "'")
        cat.set <- paste0("c(", paste(quoted, collapse = ", "), ")")
        split.rule.left <- paste0(cutFeat.name, " %in% ", cat.set)
        split.rule.right <- paste0("!", cutFeat.name, " %in% ", cat.set)
        left.index <- which(feat %in% cutFeat.category)
      }
      # setdiff, not negative indexing: x[-integer(0)] would select nothing
      right.index <- setdiff(seq_len(n.cases), left.index)

      # Down-weight, rather than drop, the cases on the other side of the split
      weights.left[right.index] <- weights.left[right.index] * .env$gamma
      weights.right[left.index] <- weights.right[left.index] * .env$gamma
    } # !is.null(cutFeat.name)
  }

  # Lin ----
  resid <- y1 - part.val # n
  if (!is.null(cutFeat.name)) {
    skip.lin <- is_constant(resid) || length(resid) < minobsinnode.lin
    # Fit one side's weighted linear model on the residual; returns zero
    # coefficients and predictions when no line should be fit
    fit.side <- function(side.weights) {
      if (skip.lin) {
        if (trace > 0) msg2("Not fitting any more lines here")
        return(list(coef = rep(0, n.coef), val = rep(0, length(y1))))
      }
      side.coef <- lincoef(x1, resid,
        weights = side.weights,
        method = lin.type,
        alpha = alpha, lambda = lambda, lambda.seq = lambda.seq,
        cv.glmnet.nfolds = cv.glmnet.nfolds,
        which.cv.glmnet.lambda = which.cv.glmnet.lambda
      )
      list(coef = side.coef, val = (cbind(1, .env$dm) %*% side.coef)[, 1])
    }
    lin.left <- fit.side(weights.left)
    lin.coef.left <- lin.left$coef
    lin.val.left <- lin.left$val
    lin.right <- fit.side(weights.right)
    lin.coef.right <- lin.right$coef
    lin.val.right <- lin.right$val
  } # if (!is.null(cutFeat.name))

  list(
    weights.left = weights.left,
    weights.right = weights.right,
    lin.coef.left = lin.coef.left,
    lin.coef.right = lin.coef.right,
    part.c.left = part.c.left,
    part.c.right = part.c.right,
    lin.val.left = lin.val.left,
    lin.val.right = lin.val.right,
    part.val = part.val,
    cutFeat.name = cutFeat.name,
    cutFeat.point = cutFeat.point,
    cutFeat.category = cutFeat.category,
    left.index = left.index,
    right.index = right.index,
    split.rule = split.rule.left,
    rule.left = split.rule.left,
    rule.right = split.rule.right,
    terminal = terminal
  )
} # rtemis::partLmw
#' Predict method for `hytreew` object
#'
#' Each case is matched to the leaf rule(s) it satisfies, the coefficient
#' vector of the matching leaf is looked up, and the prediction is
#' `init + shrinkage * sum(c(1, case) * coefficients)`.
#'
#' @method predict hytreew
#' @param object `hytreew`
#' @param newdata Data frame of predictors
#' @param n.feat (internal use) Integer: Use first `n.feat` columns of newdata to predict.
#' Defaults to all
#' @param fixed.cxr (internal use) Matrix: Cases by rules to use instead of matching cases to rules using
#' `newdata`
#' @param cxr.newdata (internal use) Data frame: Use these values to match cases by rules
#' @param cxr Logical: If TRUE, return list which includes cases-by-rules matrix along with predicted values
#' @param cxrcoef Logical: If TRUE, return cases-by-rules * coefficients matrix along with predicted values
#' @param verbose Logical: If TRUE, print messages to console
#' @param ... Not used
#'
#' @return Numeric vector of predicted values, or, if `cxr` or `cxrcoef` is
#' TRUE, a list with element `yhat` plus the requested matrices
#'
#' @export
#' @author E.D. Gennatas
predict.hytreew <- function(object, newdata,
                            n.feat = NCOL(newdata),
                            fixed.cxr = NULL,
                            cxr.newdata = NULL,
                            cxr = FALSE,
                            cxrcoef = FALSE,
                            verbose = FALSE, ...) {
  # newdata colnames ----
  if (is.null(colnames(newdata))) colnames(newdata) <- paste0("V", seq_len(NCOL(newdata)))

  # Predict ----
  newdata <- newdata[, seq_len(n.feat), drop = FALSE]

  # Cases-by-rules matrix ----
  if (is.null(fixed.cxr)) {
    rules <- as.character(unlist(object$rules, use.names = FALSE))
    cases <- if (is.null(cxr.newdata)) newdata else cxr.newdata
    .cxr <- matchCasesByRules(cases, rules, verbose = verbose)
  } else {
    .cxr <- fixed.cxr
  }

  # Leaves-by-coefficients matrix; rbind keeps it a matrix even for one leaf
  coefs <- do.call(rbind, lapply(object$coefs, c))

  # Match coefficients to each case by matrix multiplication
  # Each case only has "1" in .cxr for the corresponding leaf it belongs
  .cxrcoef <- .cxr %*% coefs

  # Add column of ones for intercept
  newdata <- data.matrix(cbind(1, newdata))

  # Row-wise dot product of each case with its own coefficient vector
  yhat <- object$init + object$shrinkage * unname(rowSums(newdata * .cxrcoef))

  if (!cxrcoef && !cxr) {
    out <- yhat
  } else {
    out <- list(yhat = yhat)
    if (cxrcoef) out$cxrcoef <- .cxrcoef
    if (cxr) out$cxr <- .cxr
  }
  out
} # rtemis:: predict.hytreew
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.