Nothing
#' Construction of a Decision Tree for Longitudinal Data
#'
#' @description
#' Constructs a single decision tree for longitudinal data.
#' The method evaluates both the main effect of a covariate and its
#' interaction with time, incorporating a weighting mechanism to balance
#' the two effects. Three single-tree construction procedures (ST1, ST2,
#' ST3) are available; see Details. For the underlying methodology, refer
#' to Obata and Sugimoto (2026).
#'
#' @param formula A formula specifying the model.
#' The response variable should be on the left side and covariates on the
#' right side. Use \code{response ~ .} to include all covariates except the
#' time variable and the random effect, or select specific covariates such as
#' \code{response ~ x1 + x2}. Time-invariant (baseline) covariates are
#' assumed.
#' @param time Character string giving the column name of the time variable.
#' All individuals are assumed to be observed at the same time points.
#' @param random Character string giving the column name of the random effect
#' (subject identifier).
#' @param weight Weight for balancing the main effect of a covariate and
#' its interaction with time. A value in
#' \eqn{\{0.0, 0.1, \ldots, 1.0\}}: \code{1.0} evaluates only the
#' mean difference in the response variable between the two groups and
#' \code{0.0} evaluates only the difference in change over time of the
#' response variable between the two groups.
#' Set \code{weight = "w"} (the default) to select the optimal weight
#' from the same grid at each node.
#' @param data A data frame containing the variables in \code{formula} together
#' with the time and random-effect variables.
#' @param alpha Significance level used as the stopping rule for tree
#' growth. A smaller value produces a more conservative (smaller) tree.
#' Specify a numeric value or \code{"no"} (default) if not used.
#' Corresponds to ST2.
#' @param gamma Complexity parameter for pruning. A larger value prunes
#' more aggressively, yielding a smaller and simpler tree; a smaller
#' value retains more branches. Specify a numeric value or \code{"no"}
#' (default) if not used. Corresponds to ST3.
#' @param cv Set \code{"yes"} to construct the decision tree using
#' cross-validation, or \code{"no"} (default) otherwise.
#' Corresponds to ST1.
#' @param maxdepth Maximum depth of the tree (default 5).
#' @param minbucket Minimum number of subjects in a terminal node (default 5).
#' @param minsplit Minimum number of subjects required to attempt a split
#' (default 20).
#' @param xval Number of cross-validation folds (default 10). Used to
#' compute the cross-validated coefficient of determination
#' (\eqn{R^2_{\mathrm{CV}}}); when \code{cv = "yes"}, also used for
#' final tree selection.
#'
#' @details
#' Exactly one of \code{alpha}, \code{gamma}, or \code{cv} must be specified.
#' Specifying more than one will result in an error. These correspond to the
#' three single-tree construction procedures:
#' \describe{
#' \item{ST1 (\code{cv = "yes"})}{Tree growth, pruning, and final tree
#' selection via cross-validation.}
#' \item{ST2 (\code{alpha})}{Tree growth with a significance threshold.
#' No pruning or final tree selection via cross-validation.}
#' \item{ST3 (\code{gamma})}{Tree growth followed by pruning with a
#' pre-specified complexity parameter. No final tree selection via
#' cross-validation.}
#' }
#'
#' Since the time variable is not used as a splitting variable, each terminal
#' node (leaf) contains the full longitudinal responses for every subject
#' assigned to it, allowing direct evaluation of longitudinal trajectories
#' within each leaf.
#'
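#' @return An object of class \code{"longitree"}. Use
#' \code{\link{summary.longitree}}, \code{\link{predict.longitree}},
#' or \code{\link{plot.longitree}} to inspect the results.
#' If the final tree contains no split, a warning is issued and a plain
#' list (not of class \code{"longitree"}) is returned instead. In that list,
#' \code{rsplitmat}, \code{rpredict} and \code{nmat} are set to
#' \code{"No Splitting"}. It also holds the cross-validated coefficient of
#' determination (\code{r2cv}) and complexity parameter (\code{cvgamma}).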
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{treeplot}}, \code{\link{longitrees}}
#'
#' @examples
#' data(ltreedata)
#' # ST1: tree construction via cross-validation
#' result_st1 <- longitree(y ~ ., time = "time", random = "subject",
#' weight = 0.7, data = ltreedata, cv = "yes")
#' summary(result_st1)
#' predict(result_st1)
#' plot(result_st1)
#'
#' # ST2: tree growth with a significance threshold
#' result_st2 <- longitree(y ~ ., time = "time", random = "subject",
#' weight = 0.1, data = ltreedata, alpha = 0.05)
#' summary(result_st2)
#' predict(result_st2)
#' plot(result_st2)
#'
#' # ST3: pruning with a complexity parameter
#' result_st3 <- longitree(y ~ ., time = "time", random = "subject",
#' weight = "w", data = ltreedata, gamma = 3)
#' summary(result_st3)
#' predict(result_st3)
#' plot(result_st3)
#'
#' @export
longitree <- function(formula, time, random, weight = "w", data,
                      alpha = "no", gamma = "no", cv = "no",
                      maxdepth = 5, minbucket = 5, minsplit = 20,
                      xval = 10) {
  # ---- validate the construction procedure (ST1 / ST2 / ST3) ----
  if (alpha == "no" && gamma == "no" && cv == "no") {
    stop("At least one of the values for alpha, gamma, or cv must be specified")
  }
  specified_count <- (alpha != "no") + (gamma != "no") + (cv != "no")
  if (specified_count > 1) {
    stop("Only one of alpha, gamma, or cv can be specified at a time")
  }
  # Keep the user-facing values for .meta. The working copies are recoded
  # below into the numeric sentinels passed to the Fortran routines.
  weight_orig <- weight
  alpha_orig <- alpha
  gamma_orig <- gamma
  cv_orig <- cv
  data_name <- paste(deparse(substitute(data)), collapse = " ")
  # Seed handed to the Fortran cross-validation routine. It is drawn from
  # R's RNG, so set.seed() controls it.
  fortran_seed <- sample.int(.Machine$integer.max, 1)

  # ---- select response, subject, time and covariate columns ----
  terms_obj <- terms.formula(formula, data = data)
  response_var <- attr(terms_obj, "response")
  response_name <- attr(terms_obj, "variables")[[response_var + 1]]
  # deparse() returns several strings for a long formula. Collapse them so
  # the if() below receives a single condition.
  formula_str <- paste(deparse(formula), collapse = " ")
  # NOTE: any "." in the formula text selects every column except response,
  # time and random. This includes a dot inside a variable name such as x.1.
  # .build_tree_plot() and longitrees() use the same rule, so it must be
  # changed in all of them at once.
  if (grepl("\\.", formula_str)) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }
  if (length(predictor_names) == 0) {
    stop("At least one covariate is required", call. = FALSE)
  }
  # Internal column order: 1 = response, 2 = subject, 3 = time,
  # 4.. = covariates.
  d <- data[c(paste(response_name), paste(random), paste(time),
              paste(predictor_names))]
  # Covariate type codes for the Fortran routines: 1 = numeric, 2 = factor.
  datatype <- unname(vapply(d[-(1:3)], function(v) {
    if (is.numeric(v)) {
      1
    } else if (is.factor(v)) {
      2
    } else {
      stop("Covariate data types must be either numeric or factor")
    }
  }, numeric(1)))
  for (col in seq_along(d)) {
    if (!is.double(d[[col]])) d[[col]] <- as.double(d[[col]])
  }
  # Recode time points and subject identifiers as consecutive integers.
  d[[3]] <- as.double(as.factor(d[[3]]))
  d[[2]] <- as.double(as.factor(d[[2]]))
  if ((weight < 0 || weight > 1) && weight != "w") {
    stop("The weight value is invalid")
  }
  # Fortran sentinels: weight -1 means "choose the weight at each node";
  # alpha 1 is used when no significance threshold was requested.
  if (weight == "w") weight <- -1
  if (alpha == "no") alpha <- 1
  nodenummat <- matrix(0L, nrow(d), maxdepth)
  Rsplitmat <- matrix(0, 2^maxdepth - 1, 10)
  allfval <- rep(0, 2^maxdepth - 1)
  prunind <- 0L
  d <- as.matrix(d)
  # Beta-function constants handed to Fortran. Their lengths depend on the
  # number of subjects and time points (presumably one per
  # degrees-of-freedom value; confirm against the Fortran source).
  nsubject <- length(unique(d[, 2]))
  timecount <- length(unique(d[, 3]))
  beta1len <- nsubject - 2
  beta2len <- (nsubject - 2) * (timecount - 1)
  beta1 <- vapply(seq_len(beta1len), function(i) beta(1/2, i/2), numeric(1))
  beta2 <- vapply(seq_len(beta2len),
                  function(i) beta((timecount - 1)/2, i/2), numeric(1))

  # ---- 1. grow the full tree ----
  res1 <- .Fortran("treegrowth",
                   as.double(weight), as.integer(maxdepth), as.integer(minbucket),
                   as.integer(minsplit), as.double(alpha), as.integer(nrow(d)),
                   as.integer(ncol(d)), as.integer(timecount), as.double(d),
                   as.integer(datatype), as.double(beta1), as.double(beta2),
                   as.integer(beta1len), as.integer(beta2len), as.integer(nodenummat),
                   as.double(allfval), as.integer(prunind), as.double(Rsplitmat),
                   PACKAGE = "longitree")

  # ---- 2. pruning sequence (complexity values) ----
  if (gamma == "no") gamma <- -1
  if (cv == "yes") cv <- 1
  if (cv == "no") cv <- -1
  allgammaval <- rep(0, (2^(maxdepth - 1)) - 1)
  prunind <- res1[[17]]
  nodenummat <- matrix(res1[[15]], nrow(d), maxdepth)
  allfval <- res1[[16]]
  res2 <- .Fortran("cvtreepruning",
                   as.integer(prunind), as.integer(nrow(d)), as.double(alpha),
                   as.double(gamma), as.integer(cv), as.integer(maxdepth),
                   as.integer(nodenummat), as.double(allfval), as.double(allgammaval),
                   PACKAGE = "longitree")

  # ---- 3. cross-validation: R^2_CV and the complexity value to prune with ----
  bestgammaval <- 0
  r2cvval <- 0
  allgammaval <- res2[[9]]
  res3 <- .Fortran("crossvalidation",
                   as.double(gamma), as.double(alpha), as.integer(cv), as.double(d),
                   as.integer(nrow(d)), as.integer(ncol(d)), as.integer(xval),
                   as.integer(timecount), as.integer(maxdepth), as.double(beta1),
                   as.double(beta2), as.integer(beta1len), as.integer(beta2len),
                   as.integer(minbucket), as.integer(minsplit), as.double(weight),
                   as.integer(datatype), as.double(allgammaval), as.double(bestgammaval),
                   as.double(r2cvval), as.integer(fortran_seed),
                   PACKAGE = "longitree")
  r2cvval <- res3[[20]]
  bestgammaval <- res3[[19]]

  # ---- 4. prune with the chosen complexity value ----
  prunenodenummat <- matrix(0L, nrow(d), maxdepth)
  res4 <- .Fortran("gammatreepruning",
                   as.integer(prunind), as.integer(nrow(d)), as.double(bestgammaval),
                   as.integer(cv), as.integer(maxdepth), as.integer(nodenummat),
                   as.double(allfval), as.integer(prunenodenummat),
                   PACKAGE = "longitree")
  prunenodenummat <- matrix(res4[[8]], nrow(d), maxdepth)
  # Exactly nrow(d) * (maxdepth - 1) zero entries: presumably only the root
  # column is filled, i.e. no split survived pruning.
  if (sum(prunenodenummat == 0) == nrow(d) * (maxdepth - 1)) {
    warning("Warning: No splitting", call. = FALSE)
    return(list(rsplitmat = "No Splitting", rpredict = "No Splitting",
                nmat = "No Splitting", r2cv = r2cvval,
                cvgamma = bestgammaval))
  }

  # ---- 5. in-sample predictions and terminal-node assignments ----
  predvec <- rep(0, nrow(d))
  datatnnodenum <- rep(0L, nrow(d))
  res5 <- .Fortran("treepredict",
                   as.integer(nrow(d)), as.integer(timecount), as.integer(maxdepth),
                   as.integer(ncol(d)), as.integer(prunenodenummat), as.double(d),
                   as.double(predvec), as.integer(datatnnodenum),
                   PACKAGE = "longitree")
  # Keep only the rows of the split table that hold a node number.
  rsplit <- matrix(res1[[18]], 2^maxdepth - 1, 10)
  rsplitmat <- rsplit[rsplit[, 1] != 0, , drop = FALSE]
  rpredict <- data.frame(predict = res5[[7]], terminalnode = res5[[8]])
  cvgamma_out <- if (cv_orig == "yes") bestgammaval else NULL
  result <- list(
    rsplitmat = rsplitmat, rpredict = rpredict, nmat = nodenummat,
    r2cv = r2cvval, cvgamma = cvgamma_out,
    .meta = list(formula = formula, time = time, random = random,
                 data = data, data_name = data_name, weight = weight_orig,
                 alpha = alpha_orig, gamma = gamma_orig, cv = cv_orig,
                 maxdepth = maxdepth, datatype = datatype))
  class(result) <- "longitree"
  result
}
#' @describeIn longitree Print a brief summary of a \code{longitree}
#' object.
#' @param object A \code{longitree} object.
#' @export
summary.longitree <- function(object, ...) {
  # Prints the reconstructed call, the construction method (alpha, gamma or
  # cv), the weight, the cross-validated R^2 and, when present, the
  # complexity parameter chosen by cross-validation. Returns `object`
  # invisibly.
  meta <- object$.meta
  optimal_weight <- meta$weight == "w"
  weight_val <- if (optimal_weight) '"w"' else as.character(meta$weight)
  data_str <- if (!is.null(meta$data_name)) {
    paste(meta$data_name, collapse = " ")
  } else {
    "data"
  }
  # deparse() splits long formulas across several strings. Join them so the
  # call is printed once rather than once per piece.
  formula_str <- paste(trimws(deparse(meta$formula)), collapse = " ")
  method_str <- if (meta$alpha != "no") {
    paste0("alpha = ", meta$alpha)
  } else if (meta$gamma != "no") {
    paste0("gamma = ", meta$gamma)
  } else {
    'cv = "yes"'
  }
  cat("longitree:\n")
  cat(paste0("longitree(", formula_str, ", time = \"",
             meta$time, "\", random = \"", meta$random, "\",\n",
             " weight = ", weight_val, ", data = ", data_str,
             ", ", method_str, ")\n"))
  if (meta$alpha != "no") {
    cat("\nMethod: alpha =", meta$alpha, "\n")
  } else if (meta$gamma != "no") {
    cat("\nMethod: gamma =", meta$gamma, "\n")
  } else {
    cat("\nMethod: cv\n")
  }
  weight_str <- if (optimal_weight) "w (optimal weight)" else as.character(meta$weight)
  cat("Weight:", weight_str, "\n")
  cat("\nCross-validated coefficient of determination:", object$r2cv, "\n")
  if (!is.null(object$cvgamma)) {
    cat("Complexity parameter:", object$cvgamma, "\n")
  }
  invisible(object)
}
#' @describeIn longitree Print method (calls \code{summary}).
#' @param x A \code{longitree} object.
#' @export
print.longitree <- function(x, ...) {
  # Printing shows the same overview as summary(). summary.longitree()
  # hands the object back invisibly, and that value is passed through.
  shown <- summary.longitree(x, ...)
  invisible(shown)
}
#' @describeIn longitree Extract predicted values and terminal node
#' assignments from a \code{longitree} object. Returns a data frame
#' with columns \code{predict} (predicted values) and
#' \code{terminalnode} (terminal node assignments). Only the in-sample
#' predictions stored in the fit are available; a \code{newdata} argument
#' is ignored with a warning.
#' @export
predict.longitree <- function(object, ...) {
  # The fit stores predictions for the training rows only. Warn instead of
  # silently returning them when the caller asked for new data.
  if ("newdata" %in% names(list(...))) {
    warning("predict.longitree() ignores 'newdata'; ",
            "returning fitted values for the training data", call. = FALSE)
  }
  object[["rpredict"]]
}
#' @describeIn longitree Plot a \code{longitree} object.
#' A convenience wrapper around \code{\link{treeplot}}.
#' @param ... For \code{plot}, additional arguments passed to
#' \code{\link{treeplot}} (for example label and line sizes).
#' @export
plot.longitree <- function(x, ...) {
  # All plotting work lives in treeplot(); its ggparty object is returned.
  tree_plot <- treeplot(x, ...)
  tree_plot
}
# ---- internal helper: build a ggparty tree plot ----------------------------
# Shared plotting engine behind treeplot(). It works in five steps:
#   1. Re-derive the covariate matrix from `formula`/`data`, as the tree
#      constructors do.
#   2. Ask the Fortran routines for each split node's data and split rule.
#   3. Rebuild the tree as a partykit `party`.
#   4. Fit a per-node lmer() model for the terminal-node labels.
#   5. Draw the result with ggparty.
#
# Arguments:
#   formula, time, random, data
#       Model specification and the data the tree was grown on (for
#       threetrees objects, the bootstrap sample).
#   rsplitmat
#       Split table from the fit, one row per split node. Column 1 is the
#       node number and columns 2-3 the weights. Column 4 is the covariate
#       column in the internal matrix. Column 5 is passed to the split
#       routine. Columns 6-10 are copied into F_G, F_GT, weightedFval,
#       weightedpval and afmpval.
#   rpredict
#       Data frame with `predict` and `terminalnode` for every row of `data`.
#   nmat
#       Node-number matrix from the fit (one row per observation).
#   weight
#       "w" adds a p_OR line to the split-node labels.
#   maxdepth
#       Maximum depth used to grow the tree.
#   datatype
#       Covariate type codes (1 numeric, 2 factor), or NULL to recompute
#       them from `data`.
#   use_threetrees_fortran
#       TRUE to call the threetrees* Fortran variants.
#   snsize, spsize, plotsize, linesize1, linesize2, tnsize
#       Label, text and line sizes (see treeplot()).
# Returns the ggparty/ggplot object.
.build_tree_plot <- function(formula, time, random, data, rsplitmat, rpredict,
                             nmat, weight, maxdepth, datatype,
                             use_threetrees_fortran,
                             snsize, spsize, plotsize,
                             linesize1, linesize2, tnsize) {
  # ---- rebuild the numeric matrix used by the Fortran routines ----
  terms_obj <- terms.formula(formula, data = data)
  response_var <- attr(terms_obj, "response")
  response_name <- attr(terms_obj, "variables")[[response_var + 1]]
  # Collapse deparse() output so a long formula yields a single condition.
  # The "any dot selects all covariates" rule must match the tree
  # constructors, because rsplitmat[, 4] indexes this column layout.
  formula_str <- paste(deparse(formula), collapse = " ")
  if (grepl("\\.", formula_str)) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }
  d <- data[c(paste(response_name), paste(random), paste(time),
              paste(predictor_names))]
  if (is.null(datatype)) {
    datatype <- c()
    for (col in 4:ncol(d)) {
      if (is.numeric(d[, col])) {
        datatype[col - 3] <- 1
      } else if (is.factor(d[, col])) {
        datatype[col - 3] <- 2
      } else {
        stop("Covariate data types must be either numeric or factor")
      }
    }
  }
  for (col in seq_len(ncol(d))) {
    if (!is.double(d[, col])) d[, col] <- as.double(d[, col])
  }
  d[, 3] <- as.double(as.factor(d[, 3]))
  d[, 2] <- as.double(as.factor(d[, 2]))
  # Shift factor codes to start at 0. The partysplit index built below
  # treats split code c as level c + 1.
  for (col in 4:ncol(d)) {
    if (datatype[col - 3] == 2) d[, col] <- d[, col] - 1
  }
  d <- as.matrix(d)
  if (!is.matrix(rsplitmat)) rsplitmat <- t(as.matrix(rsplitmat))

  # ---- split-node text: variable name and "<=x"/">x" or "=a,b" ----
  k <- 0
  spch1 <- c()
  spch2 <- c()
  namevec <- c()
  for (i in unique(rsplitmat[, 1])) {
    k <- k + 1
    # Depth of node i in the binary numbering (root 1, children 2i, 2i + 1).
    hie <- 1
    for (j in seq_len(maxdepth)) {
      if (2^(j - 1) <= i && i < 2^j) {
        hie <- j
        break
      }
    }
    ms <- sum(nmat == i)
    nd <- matrix(0, ms, ncol(d))
    if (use_threetrees_fortran) {
      resp1 <- .Fortran("threetreesnodedata",
                        as.double(d), as.integer(nrow(d)), as.integer(ncol(d)),
                        as.integer(i), as.integer(hie), as.integer(nmat),
                        as.integer(ms), as.double(nd), PACKAGE = "longitree")
      nd <- matrix(resp1[[8]], ms, ncol(d))
    } else {
      resp1 <- .Fortran("nodedata",
                        as.double(d), as.integer(nrow(d)), as.integer(ncol(d)),
                        as.integer(i), as.integer(hie), as.integer(maxdepth),
                        as.integer(nmat), as.integer(ms), as.double(nd),
                        PACKAGE = "longitree")
      nd <- matrix(resp1[[9]], ms, ncol(d))
    }
    ndsize <- length(unique(nd[, rsplitmat[k, 4]]))
    splitnumber <- rep(0, ndsize)
    splitvector <- rep(0L, ndsize)
    splitvec_args <- list(
      as.double(nd), as.integer(ms), as.integer(ncol(d)),
      as.integer(nrow(d)), as.integer(maxdepth), as.integer(datatype),
      as.integer(nmat), as.integer(rsplitmat[k, 4]),
      as.integer(rsplitmat[k, 5]), as.integer(ndsize),
      as.double(splitnumber), as.integer(splitvector),
      PACKAGE = "longitree")
    splitvec_routine <- if (use_threetrees_fortran) "threetreessplitvec" else "splitvec"
    resp2 <- do.call(.Fortran, c(splitvec_routine, splitvec_args))
    splitnumber <- resp2[[11]]
    splitvector <- resp2[[12]]
    if (datatype[rsplitmat[k, 4] - 3] == 1) {
      # Numeric split: use splitnumber at the first position where the
      # left/right indicator changes.
      for (l in 1:(ndsize - 1)) {
        if (abs(splitvector[l + 1] - splitvector[l]) > 0) {
          spch1[k] <- paste0("<=", round(splitnumber[l], 3))
          spch2[k] <- paste0(">", round(splitnumber[l], 3))
          break
        }
      }
    } else if (datatype[rsplitmat[k, 4] - 3] == 2) {
      # Factor split: list the codes sent left (0) and right (1).
      spch1[k] <- "="
      spch2[k] <- "="
      for (l in seq_len(ndsize)) {
        if (splitvector[l] == 0) {
          spch1[k] <- if (spch1[k] == "=") {
            paste0("=", splitnumber[l])
          } else {
            paste(spch1[k], splitnumber[l], sep = ",")
          }
        } else if (splitvector[l] == 1) {
          spch2[k] <- if (spch2[k] == "=") {
            paste0("=", splitnumber[l])
          } else {
            paste(spch2[k], splitnumber[l], sep = ",")
          }
        }
      }
    }
    namevec[k] <- colnames(d)[rsplitmat[k, 4]]
  }
  rsmat <- data.frame(
    nodenumber = as.numeric(rsplitmat[, 1]),
    F_Gweight = as.numeric(rsplitmat[, 2]),
    F_GTweight = as.numeric(rsplitmat[, 3]),
    name = as.character(namevec),
    leftsplit = as.character(spch1),
    rightsplit = as.character(spch2),
    F_G = as.numeric(rsplitmat[, 6]),
    F_GT = as.numeric(rsplitmat[, 7]),
    weightedFval = as.numeric(rsplitmat[, 8]),
    weightedpval = as.numeric(rsplitmat[, 9]),
    afmpval = as.numeric(rsplitmat[, 10]),
    stringsAsFactors = FALSE
  )

  # ---- original data (factors kept) for partykit/ggparty ----
  d <- data[c(paste(response_name), paste(random), paste(time),
              paste(predictor_names))]
  tname <- as.name(time)
  # NOTE(review): a subject column named "id" is renamed to "ID",
  # presumably to avoid clashing with ggparty's `id` used in the labels.
  if (colnames(d)[2] == "id") {
    colnames(d)[2] <- "ID"
    ranname <- as.name("ID")
  } else {
    ranname <- as.name(random)
  }
  # Keep only the split nodes that are ancestors of a terminal node of the
  # final (possibly pruned) tree.
  terminal_nodes <- unique(rpredict$terminalnode)
  find_branch_nodes <- function(tn) {
    bn <- c()
    for (node in tn) {
      while (node > 1) {
        node <- floor(node / 2)
        bn <- c(bn, node)
      }
    }
    unique(bn)
  }
  branch_nodes <- find_branch_nodes(terminal_nodes)
  spcou <- sort(vapply(branch_nodes,
                       function(b) match(b, rsmat$nodenumber), integer(1)))
  rsmat <- rsmat[spcou, , drop = FALSE]
  rownames(rsmat) <- seq_len(nrow(rsmat))
  d$pred <- rpredict$predict
  d$tn <- rpredict$terminalnode
  rsmat$name <- trimws(rsmat$name)
  cleaned_left <- gsub("=|>|<", "", rsmat$leftsplit)
  cleaned_right <- gsub("=|>|<", "", rsmat$rightsplit)

  # ---- partykit split objects, stored as sp_<node> in this frame ----
  nameind <- c()
  for (i in seq_len(nrow(rsmat))) {
    nameind[i] <- match(rsmat$name[i], colnames(d))
    if (is.factor(d[, nameind[i]])) {
      numeric_left <- as.numeric(strsplit(cleaned_left, ",")[[i]])
      numeric_right <- as.numeric(strsplit(cleaned_right, ",")[[i]])
      # Level code + 1 goes to kid 1 (left) or kid 2 (right); others NA.
      indexsp <- rep(0L, length(unique(d[, nameind[i]])))
      for (j in seq_along(numeric_left)) indexsp[numeric_left[j] + 1] <- 1L
      for (j in seq_along(numeric_right)) indexsp[numeric_right[j] + 1] <- 2L
      indexsp[indexsp == 0] <- NA
      ps <- partysplit(as.integer(nameind[i]), index = indexsp)
      assign(paste0("sp_", rsmat$nodenumber[i]), ps)
    } else if (is.numeric(d[, nameind[i]])) {
      numeric_val <- as.numeric(strsplit(cleaned_left, ",")[[i]])
      ps <- partysplit(as.integer(nameind[i]), breaks = numeric_val)
      assign(paste0("sp_", rsmat$nodenumber[i]), ps)
    }
  }
  tnsort <- sort(unique(d$tn))
  insort <- sort(rsmat$nodenumber)
  splitValues <- stats::setNames(
    lapply(insort, function(n) get(paste0("sp_", n))),
    as.character(insort)
  )
  # Builds the source text of a nested partynode() call. It refers to the
  # sp_<node> objects above, so it is evaluated in this function's frame.
  createNodeString <- function(node, insort, tnsort, splitValues) {
    ns <- paste0("partynode(", node, "L, info = list(original_id=", node, "L)")
    if (!is.null(splitValues[[as.character(node)]])) {
      ns <- paste0(ns, ", split = sp_", node)
    }
    children <- c(node * 2, node * 2 + 1)
    cs <- lapply(children, function(ch) {
      if (ch %in% insort || ch %in% tnsort) {
        createNodeString(ch, insort, tnsort, splitValues)
      } else {
        NULL
      }
    })
    cs <- cs[!vapply(cs, is.null, logical(1))]
    if (length(cs) > 0) {
      ns <- paste0(ns, ", kids = list(", paste0(cs, collapse = ", "), ")")
    }
    paste0(ns, ")")
  }
  finalString <- createNodeString(1, insort, tnsort, splitValues)
  pn <- eval(parse(text = finalString))
  py <- party(pn, d)
  original_ids <- lapply(nodeapply(py, ids = nodeids(py)),
                         function(n) n$info[["original_id"]])

  # ---- split-node labels ----
  ch <- c("Node", "p", "w:", "black", "bold")
  idnum <- c()
  pvec <- c()
  wvec <- c()
  porv <- c()
  for (i in seq_len(nrow(rsmat))) {
    result_idx <- which(vapply(original_ids,
                               function(x) x == rsmat$nodenumber[i], logical(1)))
    idnum[i] <- result_idx
    pvec[i] <- rsmat$weightedpval[i]
    wvec[i] <- rsmat$F_Gweight[i]
    if (weight == "w") porv[i] <- rsmat$afmpval[i]
  }
  pvec <- vapply(pvec, function(p) {
    if (p < 0.001) "<0.001" else paste0("=", round(p, 3))
  }, character(1))
  if (weight == "w") {
    porv <- vapply(porv, function(p) {
      if (p < 0.001) {
        "expression(paste(p[OR],'<',0.001,sep=''))"
      } else {
        paste0("expression(paste(p[OR],'=',", round(p, 3), ",sep=''))")
      }
    }, character(1))
  }
  p_labels <- c()
  if (weight == "w") {
    for (i in seq_along(idnum)) {
      p_labels[i] <- paste0(
        "geom_node_label(line_list = list(",
        "aes(label = paste(ch[1], id)), ",
        "aes(label = splitvar), ",
        "aes(label = paste(ch[2], pvec[", i, "])), ",
        "aes(label = eval(parse(text = porv[", i, "]))), ",
        "aes(label = paste(ch[3], wvec[", i, "]))), ",
        "line_gpar = list(",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5])), ",
        "ids = idnum[", i, "])")
    }
  } else {
    for (i in seq_along(idnum)) {
      p_labels[i] <- paste0(
        "geom_node_label(line_list = list(",
        "aes(label = paste(ch[1], id)), ",
        "aes(label = splitvar), ",
        "aes(label = paste(ch[2], pvec[", i, "])), ",
        "aes(label = paste(ch[3], wvec[", i, "]))), ",
        "line_gpar = list(",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5]), ",
        "list(size = ", snsize, "/length(idnum), col = ch[4], fontface = ch[5])), ",
        "ids = idnum[", i, "])")
    }
  }

  # ---- per-terminal-node LMM: response ~ time + (1 | subject) ----
  sorted_tn <- sort(unique(d$tn))
  lmm_formula <- stats::as.formula(
    paste0(response_name, "~", tname, "+(1|", ranname, ")"))
  beta0_vec <- c()
  beta1_vec <- c()
  for (i in seq_along(sorted_tn)) {
    # Plain indexing: subset()'s non-standard evaluation would resolve `d`
    # to a data column if the data happened to contain one named "d".
    subset_data <- d[d$tn == sorted_tn[i], , drop = FALSE]
    model <- lmer(lmm_formula, data = subset_data)
    fix <- round(fixef(model), 2)
    beta0_vec[i] <- fix[1]
    beta1_vec[i] <- fix[2]
  }
  tnidnum <- c()
  for (i in seq_along(sorted_tn)) {
    tnidnum[i] <- which(vapply(original_ids,
                               function(x) x == sorted_tn[i], logical(1)))
  }

  # ---- assemble the plot ----
  g <- ggparty(py)
  g <- g + geom_edge(colour = "gray", linewidth = 1.5)
  g <- g + geom_edge_label(colour = "red", size = spsize)
  for (i in seq_along(idnum)) {
    g <- g + eval(parse(text = p_labels[i]))
  }
  gtex <- paste0(
    "g <- g + geom_node_plot(gglist = list(",
    "geom_line(aes(x=", tname, ", y=", response_name, ", group=", ranname,
    "), linewidth=", linesize1, ", linetype='dashed', color='#f26651'), ",
    "geom_line(aes(x=", tname, ", y=pred), color='red', linewidth=", linesize2, "), ",
    "theme_classic(base_size=", plotsize, "/length(tnidnum)), ",
    "theme(axis.title.y = element_text(margin = margin(l = -10)), ",
    "axis.title.x = element_text(margin = margin(t = -5)))",
    "), scales='fixed', id='terminal', ",
    "shared_axis_labels=TRUE, shared_legend=TRUE, legend_separator=TRUE,)")
  eval(parse(text = gtex))
  # Terminal nodes are labelled after the split nodes, in ggparty id order.
  rank_num <- rank(tnidnum, ties.method = "first")
  new_tnidnum <- length(idnum) + rank_num
  for (i in seq_along(tnidnum)) {
    beta01_i <- paste0(
      "expression(paste(beta[0],'=',", beta0_vec[i],
      ",' ',beta[1],'=',", beta1_vec[i], ",sep=''))")
    ptn <- paste0(
      "geom_node_label(line_list = list(",
      "aes(label = paste0('Node',", new_tnidnum[i],
      ",', N = ', nodesize/length(unique(d$", tname, ")))), ",
      "aes(label = eval(parse(text = ", deparse(beta01_i), ")))), ",
      "line_gpar = list(",
      "list(size=", tnsize, "/length(tnidnum), col=ch[4], fontface=ch[5]), ",
      "list(size=", tnsize, "/length(tnidnum), col=ch[4], fontface=ch[5])), ",
      "fontface='bold', ids=tnidnum[", i, "], size=5, nudge_y=0.01)")
    g <- g + eval(parse(text = ptn))
  }
  g
}
#' Decision Tree Plot Visualisation for Longitudinal Data
#'
#' @description
#' Draws a fitted decision tree for longitudinal data using \pkg{ggparty}.
#'
#' Each split node is labelled with its node number, split variable,
#' \eqn{p}-value and weight \eqn{w}. When the weight was chosen
#' automatically (\code{weight = "w"}), the label also shows
#' \eqn{p_{OR}}.
#'
#' Each terminal node shows its node number and the number of subjects
#' \eqn{N}. It also shows the intercept (\eqn{\hat\beta_0}) and slope
#' (\eqn{\hat\beta_1}) of a linear mixed-effects model fitted within that
#' node. The model has the response on time plus a random intercept per
#' subject.
#'
#' Inside each terminal node, individual longitudinal trajectories are
#' drawn as dashed lines and the predicted values as a solid line. The
#' response is on the vertical axis and time on the horizontal axis.
#'
#' @param x A \code{longitree} or \code{threetrees} object.
#' @param tree Integer 1, 2, or 3 selecting which tree to plot when \code{x}
#' is a \code{threetrees} object. Ignored for \code{longitree} objects.
#' @param snsize Text size of the split-node labels (default 50). It is
#' divided by the number of split nodes.
#' @param spsize Text size of the split-point labels on the edges
#' (default 5).
#' @param plotsize Base font size of the terminal-node panels (default 80).
#' It is divided by the number of terminal nodes.
#' @param linesize1 Line width of the individual (dashed) trajectories
#' (default 0.3).
#' @param linesize2 Line width of the predicted-value (solid) line
#' (default 1).
#' @param tnsize Text size of the terminal-node labels (default 60). It is
#' divided by the number of terminal nodes.
#'
#' @return A \code{ggplot2}/\code{ggparty} object.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitree}}, \code{\link{threetrees}}
#'
#' @export
treeplot <- function(x, tree = NULL,
                     snsize = 50, spsize = 5, plotsize = 80,
                     linesize1 = 0.3, linesize2 = 1, tnsize = 60) {
  single_tree <- inherits(x, "longitree")
  if (!single_tree && !inherits(x, "threetrees")) {
    stop("x must be a longitree or threetrees object")
  }
  meta <- x$.meta
  if (single_tree) {
    # A single tree carries its own data and covariate type codes.
    plot_data <- meta$data
    split_tab <- x$rsplitmat
    pred_tab <- x$rpredict
    node_mat <- x$nmat
    type_codes <- meta$datatype
  } else {
    if (is.null(tree) || !(tree %in% 1:3)) {
      stop("tree must be 1, 2, or 3 when using threetrees object")
    }
    # Each of the three trees was grown on its own bootstrap sample; the
    # type codes are recomputed from that sample.
    fields <- paste0(c("bootdata", "rsplitmat", "rpredict", "nmat"), tree)
    plot_data <- x[[fields[1]]]
    split_tab <- x[[fields[2]]]
    pred_tab <- x[[fields[3]]]
    node_mat <- x[[fields[4]]]
    type_codes <- NULL
  }
  .build_tree_plot(
    formula = meta$formula, time = meta$time, random = meta$random,
    data = plot_data, rsplitmat = split_tab, rpredict = pred_tab,
    nmat = node_mat, weight = meta$weight, maxdepth = meta$maxdepth,
    datatype = type_codes, use_threetrees_fortran = !single_tree,
    snsize = snsize, spsize = spsize, plotsize = plotsize,
    linesize1 = linesize1, linesize2 = linesize2, tnsize = tnsize)
}
#' Construction of Multiple Decision Trees for Longitudinal Data
#'
#' @description
#' Generates multiple trees from bootstrap samples and evaluates
#' all three-tree combinations based on two criteria: cross-validated
#' prediction error and tree diversification measured by the adjusted Rand
#' index (ARI). Bootstrap sampling is performed at the subject level to
#' preserve longitudinal structure.
#'
#' @inheritParams longitree
#' @param bootsize Number of subjects in each bootstrap sample.
#' @param trees Number of bootstrap trees to grow (default 100).
#' @param mins Number of top-ranking candidate three-tree subsets to retain
#' (default 40).
#'
#' @details
#' See \code{\link{longitree}} for a description of the three single-tree
#' construction procedures (ST1, ST2, ST3) corresponding to \code{cv},
#' \code{alpha}, and \code{gamma}.
#'
#' @return An object of class \code{"longitrees"}. Pass to
#' \code{\link{selectionplot}} to select the optimal three-tree combination.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitree}}, \code{\link{selectionplot}},
#' \code{\link{threetrees}}, \code{\link{treeplot}}
#'
#' @export
longitrees <- function(formula, time, random, weight = "w", data,
alpha = "no", gamma = "no", cv = "no",
maxdepth = 5, minbucket = 5, minsplit = 20,
xval = 10, bootsize, trees = 100, mins = 40) {
if (alpha == "no" && gamma == "no" && cv == "no") {
stop("At least one of the values for alpha, gamma, or cv must be specified")
}
specified_count <- (alpha != "no") + (gamma != "no") + (cv != "no")
if (specified_count > 1) {
stop("Only one of alpha, gamma, or cv can be specified at a time")
}
weight_orig <- weight
alpha_orig <- alpha
gamma_orig <- gamma
cv_orig <- cv
data_name <- deparse(substitute(data))
fortran_seeds <- sample.int(.Machine$integer.max, trees)
terms_obj <- terms.formula(formula, data = data)
response_var <- attr(terms_obj, "response")
response_name <- attr(terms_obj, "variables")[[response_var + 1]]
formula_str <- deparse(formula)
if (grepl("\\.", formula_str)) {
predictor_names <- setdiff(names(data), c(response_name, time, random))
} else {
predictor_names <- attr(terms_obj, "term.labels")
}
d <- data[c(paste(response_name), paste(random), paste(time),
paste(predictor_names))]
datatype <- c()
for (col in 4:ncol(d)) {
if (is.numeric(d[, col])) {
datatype[col - 3] <- 1
} else if (is.factor(d[, col])) {
datatype[col - 3] <- 2
} else {
stop("Covariate data types must be either numeric or factor")
}
}
for (col in 1:ncol(d)) {
if (!is.double(d[, col])) d[, col] <- as.double(d[, col])
}
d[, 3] <- as.double(as.factor(d[, 3]))
d[, 2] <- as.double(as.factor(d[, 2]))
if ((weight < 0 || weight > 1) && weight != "w") {
stop("The weight value is invalid")
}
if (weight == "w") weight <- -1
d <- as.matrix(d)
timecount <- length(unique(d[, 3]))
bootdatasize <- bootsize * timecount
if (alpha == "no") alpha <- 1
beta1len <- bootsize - 2
beta2len <- (bootsize - 2) * (timecount - 1)
beta1 <- vapply(seq_len(beta1len), function(i) beta(1/2, i/2), numeric(1))
beta2 <- vapply(seq_len(beta2len),
function(i) beta((timecount - 1)/2, i/2), numeric(1))
if (gamma == "no") gamma <- -1
if (cv == "yes") cv <- 1
if (cv == "no") cv <- -1
predvecmat <- matrix(0, nrow(d), trees)
datatnnodenummat <- matrix(0L, nrow(d), trees)
oridatamat <- matrix(0L, bootdatasize, trees)
cvsplitindmat <- matrix(0L, bootdatasize, trees)
unique_subjects <- unique(d[, 2])
boot_mat <- replicate(trees,
sample(unique_subjects, size = bootsize, replace = TRUE))
for (i in 1:trees) {
subset_idx <- unlist(lapply(boot_mat[, i], function(b) which(d[, 2] == b)))
oridatamat[, i] <- subset_idx
bootdata <- do.call(rbind,
lapply(boot_mat[, i], function(b) d[d[, 2] == b, ]))
bootdata[, 2] <- rep(1:bootsize, each = timecount)
nodenummatori <- matrix(0L, nrow(d), maxdepth)
nodenummatboot <- matrix(0L, bootdatasize, maxdepth)
allfval <- rep(0, 2^maxdepth - 1)
prunind <- 0L
res2 <- .Fortran("threetreesboottreegrowth",
as.integer(nodenummatboot), as.integer(nodenummatori),
as.double(allfval), as.integer(prunind), as.integer(bootdatasize),
as.integer(ncol(d)), as.double(bootdata), as.integer(datatype),
as.double(weight), as.double(beta1), as.double(beta2),
as.integer(beta1len), as.integer(beta2len), as.integer(timecount),
as.integer(minbucket), as.double(d), as.integer(nrow(d)),
as.double(alpha), as.integer(maxdepth), as.integer(minsplit),
PACKAGE = "longitree")
allgammaval <- rep(0, (2^(maxdepth - 1)) - 1)
prunind <- res2[[4]]
nodenummatori <- matrix(res2[[2]], nrow(d), maxdepth)
nodenummatboot <- matrix(res2[[1]], bootdatasize, maxdepth)
allfval <- res2[[3]]
res3 <- .Fortran("threetreesbootcvtreepruning",
as.integer(prunind), as.integer(bootdatasize), as.double(alpha),
as.double(gamma), as.integer(cv), as.integer(maxdepth),
as.integer(nodenummatboot), as.double(allfval),
as.double(allgammaval),
PACKAGE = "longitree")
bestgammaval <- 0
allgammaval <- res3[[9]]
cvsplitind <- rep(0L, bootdatasize)
res4 <- .Fortran("threetreesbootcrossvaridation",
as.double(gamma), as.double(alpha), as.integer(cv),
as.double(bootdata), as.integer(bootdatasize), as.integer(ncol(d)),
as.integer(xval), as.integer(timecount), as.integer(maxdepth),
as.double(beta1), as.double(beta2), as.integer(beta1len),
as.integer(beta2len), as.integer(minbucket), as.integer(minsplit),
as.double(weight), as.integer(datatype), as.double(allgammaval),
as.double(bestgammaval), as.integer(cvsplitind),
as.integer(fortran_seeds[i]),
PACKAGE = "longitree")
bestgammaval <- res4[[19]]
cvsplitind <- res4[[20]]
cvsplitindmat[, i] <- cvsplitind
prunenodenummatboot <- matrix(0L, bootdatasize, maxdepth)
prunenodenummatori <- matrix(0L, nrow(d), maxdepth)
res5 <- .Fortran("threetreesbootgammatreepruning",
as.integer(prunind), as.integer(bootdatasize), as.integer(nrow(d)),
as.double(bestgammaval), as.integer(cv), as.integer(maxdepth),
as.integer(nodenummatboot), as.integer(nodenummatori),
as.double(allfval), as.integer(prunenodenummatboot),
as.integer(prunenodenummatori),
PACKAGE = "longitree")
prunenodenummatboot <- matrix(res5[[10]], bootdatasize, maxdepth)
prunenodenummatori <- matrix(res5[[11]], nrow(d), maxdepth)
predvec <- rep(0, nrow(d))
datatnnodenum <- rep(0L, nrow(d))
res6 <- .Fortran("threetreesboottreepredict",
as.integer(nrow(d)), as.integer(bootdatasize),
as.integer(timecount), as.integer(maxdepth),
as.integer(prunenodenummatboot), as.integer(prunenodenummatori),
as.double(bootdata), as.double(d), as.integer(ncol(d)),
as.double(predvec), as.integer(datatnnodenum),
PACKAGE = "longitree")
predvecmat[, i] <- res6[[10]]
datatnnodenummat[, i] <- res6[[11]]
}
bettertreelossval <- matrix(0, mins, 5)
bettertreemaxval <- matrix(0, mins, 5)
bettertreenum <- matrix(0L, mins, 6)
res7 <- .Fortran("threetreestreechoice",
as.double(predvecmat), as.integer(datatnnodenummat),
as.integer(nrow(d)), as.integer(trees), as.integer(timecount),
as.double(d), as.integer(ncol(d)), as.integer(mins),
as.integer(bettertreenum), as.double(bettertreelossval),
as.double(bettertreemaxval),
PACKAGE = "longitree")
bettertreelossval <- matrix(res7[[10]], mins, 5)
bettertreemaxval <- matrix(res7[[11]], mins, 5)
bettertreenum <- matrix(res7[[9]], mins, 6)
besttreemat <- matrix(NA, (mins * 4) + 2, 8)
besttreemat[1, 1:3] <- -100
besttreemat[(mins * 2 + 2), 1:3] <- -200
j <- 1
for (i in 1:mins) {
j <- j + 1
besttreemat[j, 1:3] <- bettertreenum[i, 1:3]
j <- j + 1
besttreemat[j, 1:3] <- 0
besttreemat[j, 4:8] <- bettertreelossval[i, 1:5]
}
j <- j + 1
for (i in 1:mins) {
j <- j + 1
besttreemat[j, 1:3] <- bettertreenum[i, 4:6]
j <- j + 1
besttreemat[j, 1:3] <- 0
besttreemat[j, 4:8] <- bettertreemaxval[i, 1:5]
}
loss <- besttreemat[seq(3, mins * 2 + 1, 2), ]
rand <- besttreemat[seq(mins * 2 + 4, mins * 4 + 2, 2), ]
loss_valid <- apply(loss[, 6:8, drop = FALSE], 1, function(r) all(r < 1.5))
rand_valid <- apply(rand[, 6:8, drop = FALSE], 1, function(r) all(r < 1.5))
loss <- loss[loss_valid, , drop = FALSE]
rand <- rand[rand_valid, , drop = FALSE]
lossy <- loss[, 4]
lossx <- apply(loss[, 6:8, drop = FALSE], 1, max)
loss <- data.frame(loss = lossy, rand = lossx,
ari12 = loss[, 6], ari13 = loss[, 7],
ari23 = loss[, 8])
randy <- rand[, 4]
randx <- apply(rand[, 6:8, drop = FALSE], 1, max)
rand <- data.frame(loss = randy, rand = randx,
ari12 = rand[, 6], ari13 = rand[, 7],
ari23 = rand[, 8])
plottree <- rbind(loss, rand)
plot(plottree$rand, plottree$loss, col = "gray", cex = 2, pch = 19,
cex.axis = 1.7, xlab = "maximum of the three ARIs",
ylab = "cross-validated prediction error")
loss_treenum_raw <- besttreemat[seq(2, mins * 2, 2), 1:3, drop = FALSE]
rand_treenum_raw <- besttreemat[seq(mins * 2 + 3, mins * 4 + 1, 2),
1:3, drop = FALSE]
treenumber_loss <- data.frame(loss_treenum_raw[loss_valid, , drop = FALSE])
treenumber_rand <- data.frame(rand_treenum_raw[rand_valid, , drop = FALSE])
treenumber <- rbind(treenumber_loss, treenumber_rand)
names(treenumber) <- c("treenumber1", "treenumber2", "treenumber3")
result <- list(
oridatamat = oridatamat, plottree = plottree,
cvsplitindmat = cvsplitindmat, treenumber = treenumber,
datatype = datatype,
.meta = list(
formula = formula, time = time, random = random, weight = weight_orig,
data = data, data_name = data_name, alpha = alpha_orig,
gamma = gamma_orig, cv = cv_orig,
bootsize = bootsize, maxdepth = maxdepth, minbucket = minbucket,
minsplit = minsplit, xval = xval, mins = mins))
class(result) <- "longitrees"
result
}
#' Select Optimal Three-Tree Combination
#'
#' @description
#' Plots the cross-validated prediction error against the maximum pairwise
#' adjusted Rand index (ARI) for candidate three-tree subsets, and selects
#' a subset based on either prediction performance or tree diversification.
#' The selected combination is indicated by a red point on the plot, which
#' corresponds to the three trees used in the subsequent
#' \code{\link{threetrees}} step.
#'
#' @param longitrees A \code{longitrees} object.
#' @param metric \code{"PE"} to select the subset with the smallest
#' cross-validated prediction error, or \code{"ARI"} to select the subset
#' with the smallest maximum pairwise ARI (greatest tree diversification).
#' @param nth Rank of the tree subset to select (1 = best). Must be a single
#' whole number between 1 and the number of candidate subsets available for
#' the chosen \code{metric}.
#'
#' @return An object of class \code{"selectionplot"}: a list with
#' \code{threetreenumber} (one-row data frame holding the indices of the
#' three selected trees), \code{selected_pe} (their cross-validated
#' prediction error), \code{selected_ari} (their maximum pairwise ARI) and
#' \code{selected_ari_pair} (named vector \code{ari12}, \code{ari13},
#' \code{ari23}). Pass it to \code{\link{threetrees}} to refit and evaluate
#' the selected trees.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitrees}}, \code{\link{threetrees}}
#'
#' @export
selectionplot <- function(longitrees, metric, nth) {
  if (!inherits(longitrees, "longitrees"))
    stop("longitrees must be a longitrees object")
  # Validate the criterion before drawing, so a bad value leaves no plot.
  if (!is.character(metric) || length(metric) != 1L ||
      !(metric %in% c("PE", "ARI")))
    stop("metric must be 'PE' or 'ARI'")

  plottree <- longitrees$plottree
  treenumber <- longitrees$treenumber

  # Both tables stack the candidates kept for low prediction error (first
  # part) on top of those kept for low ARI (second part).
  # NOTE(review): the boundary is taken as the midpoint, but longitrees()
  # drops invalid rows from each part separately, so the parts may differ
  # in size -- confirm, or store the split point in the longitrees object.
  n_all <- nrow(plottree)
  half <- n_all %/% 2
  rows <- if (metric == "PE") seq_len(half) else half + seq_len(n_all - half)
  pt <- data.frame(plottree[rows, , drop = FALSE],
                   treenumber[rows, , drop = FALSE])

  # Rank by the chosen criterion. Ties on PE are broken toward the larger
  # ARI; ties on ARI are broken toward the smaller PE.
  ord <- if (metric == "PE") {
    order(pt$loss, -pt$rand)
  } else {
    order(pt$rand, pt$loss)
  }
  df_sorted <- pt[ord, , drop = FALSE]

  # An out-of-range rank would otherwise select an all-NA row silently.
  if (!is.numeric(nth) || length(nth) != 1L || is.na(nth) || nth < 1 ||
      nth != round(nth) || nth > nrow(df_sorted))
    stop("nth must be a whole number between 1 and ", nrow(df_sorted))

  plot(plottree$rand, plottree$loss, col = "gray", cex = 2, pch = 19,
       cex.axis = 1.7, xlab = "maximum of the three ARIs",
       ylab = "cross-validated prediction error")
  points(df_sorted[nth, "rand"], df_sorted[nth, "loss"],
         col = "red", cex = 2, pch = 19)

  ari_cols <- c("ari12", "ari13", "ari23")
  result <- list(
    threetreenumber = df_sorted[nth, c("treenumber1", "treenumber2",
                                       "treenumber3")],
    selected_pe = unname(df_sorted[nth, "loss"]),
    selected_ari = unname(df_sorted[nth, "rand"]),
    selected_ari_pair = setNames(
      as.numeric(unlist(df_sorted[nth, ari_cols], use.names = FALSE)),
      ari_cols))
  class(result) <- "selectionplot"
  result
}
#' Fit and Evaluate Three Selected Trees
#'
#' @description
#' Refits the three trees selected by \code{\link{selectionplot}} on their
#' original bootstrap samples.
#'
#' @details
#' The modelling settings (formula, weight, stopping/pruning rule, depth
#' limits, bootstrap size) are read from the \code{.meta} element of
#' \code{x}, so the refit uses the configuration recorded by
#' \code{\link{longitrees}}. Subject and time columns stored as character
#' are converted to factors before being coded numerically.
#'
#' @param x A \code{longitrees} object.
#' @param selection A \code{selectionplot} object.
#'
#' @return An object of class \code{"threetrees"}. Use
#' \code{\link{summary.threetrees}}, \code{\link{predict.threetrees}},
#' or \code{\link{plot.threetrees}} to inspect the results.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitrees}}, \code{\link{selectionplot}},
#' \code{\link{treeplot}}
#'
#' @examples
#' data(ltreedata)
#' set.seed(10)
#' trees_res <- longitrees(y ~ ., time = "time", random = "subject",
#' weight = 0.5, data = ltreedata, alpha = 0.01,
#' bootsize = 50, mins = 40)
#' sel <- selectionplot(trees_res, metric = "PE", nth = 1)
#' tt <- threetrees(trees_res, selection = sel)
#' summary(tt)
#' predict(tt, tree = 1)
#' predict(tt, tree = 2)
#' predict(tt, tree = 3)
#' plot(tt, tree = 1)
#' plot(tt, tree = 2)
#' plot(tt, tree = 3)
#'
#' @export
threetrees <- function(x, selection) {
  if (!inherits(x, "longitrees"))
    stop("x must be a longitrees object")
  if (!inherits(selection, "selectionplot"))
    stop("selection must be a selectionplot object")

  # ---- Settings recorded by longitrees() -------------------------------
  meta <- x$.meta
  formula <- meta$formula
  time <- meta$time
  random <- meta$random
  weight <- meta$weight
  data <- meta$data
  data_name <- if (!is.null(meta$data_name)) meta$data_name else "data"
  alpha <- meta$alpha
  gamma <- meta$gamma
  cv <- meta$cv
  maxdepth <- meta$maxdepth
  minbucket <- meta$minbucket
  minsplit <- meta$minsplit
  xval <- meta$xval
  bootsize <- meta$bootsize
  oridatamat <- x$oridatamat
  cvsplitindmat <- x$cvsplitindmat
  selected_pe <- selection$selected_pe
  selected_ari <- selection$selected_ari
  selected_ari_pair <- selection$selected_ari_pair
  threetreenumber <- selection$threetreenumber

  if (alpha == "no" && gamma == "no" && cv == "no") {
    stop("At least one of the values for alpha, gamma, or cv must be specified")
  }
  specified_count <- (alpha != "no") + (gamma != "no") + (cv != "no")
  if (specified_count > 1) {
    stop("Only one of alpha, gamma, or cv can be specified at a time")
  }
  # Seed for the Fortran cross-validation routine, drawn from R's RNG so
  # that set.seed() makes the refit reproducible.
  fortran_seed <- sample.int(.Machine$integer.max, 1)

  # ---- Model frame ------------------------------------------------------
  terms_obj <- terms.formula(formula, data = data)
  response_var <- attr(terms_obj, "response")
  response_name <- attr(terms_obj, "variables")[[response_var + 1]]
  # `response ~ .` means every column except response, time and subject.
  # any(): deparse() may split a long formula into several strings.
  # NOTE(review): this also fires on dotted names such as x.1; presumably
  # it mirrors longitrees(), so confirm there before tightening it.
  if (any(grepl("\\.", deparse(formula)))) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }
  if (length(predictor_names) == 0L) {
    stop("At least one covariate is required")
  }
  # Column order used below: 1 = response, 2 = subject, 3 = time,
  # 4+ = covariates.
  model_cols <- c(paste(response_name), paste(random), paste(time),
                  paste(predictor_names))
  d <- data[model_cols]

  # Covariate type codes passed to Fortran: 1 = numeric, 2 = factor.
  covariate_type <- function(v) {
    if (is.numeric(v)) return(1)
    if (is.factor(v)) return(2)
    stop("Covariate data types must be either numeric or factor")
  }
  datatype <- vapply(d[-(1:3)], covariate_type, numeric(1),
                     USE.NAMES = FALSE)

  # Character subject/time identifiers would become NA under as.double(),
  # so make them factors first; the coercion loop then maps them to codes.
  for (col in 2:3) {
    if (is.character(d[[col]])) d[[col]] <- factor(d[[col]])
  }
  for (col in seq_len(ncol(d))) {
    if (!is.double(d[, col])) d[, col] <- as.double(d[, col])
  }
  # Recode time and subject as 1, 2, ... in sorted order of their values.
  d[, 3] <- as.double(as.factor(d[, 3]))
  d[, 2] <- as.double(as.factor(d[, 2]))

  weight_orig <- weight
  alpha_orig <- alpha
  gamma_orig <- gamma
  if ((weight < 0 || weight > 1) && weight != "w") {
    stop("The weight value is invalid")
  }
  # weight = "w" (optimal weight per node) is passed to Fortran as -1.
  if (weight == "w") weight <- -1
  timecount <- length(unique(d[, 3]))
  bootdatasize <- bootsize * timecount

  # ---- Bootstrap samples of the three selected trees -------------------
  tree_ids <- c(threetreenumber$treenumber1, threetreenumber$treenumber2,
                threetreenumber$treenumber3)
  # Rows of `src` in the bootstrap sample of tree `tree_id`, with subjects
  # relabelled 1..bootsize (each label repeated timecount times); the
  # original IDs presumably repeat under resampling.
  make_boot <- function(src, tree_id) {
    boot <- src[oridatamat[, tree_id], ]
    boot[, 2] <- rep(seq_len(bootsize), each = timecount)
    boot
  }
  bootdata1 <- make_boot(d, tree_ids[1])
  bootdata2 <- make_boot(d, tree_ids[2])
  bootdata3 <- make_boot(d, tree_ids[3])
  cvsplitind1 <- cvsplitindmat[, tree_ids[1]]
  cvsplitind2 <- cvsplitindmat[, tree_ids[2]]
  cvsplitind3 <- cvsplitindmat[, tree_ids[3]]

  # ---- Grow the three trees --------------------------------------------
  if (alpha == "no") alpha <- 1
  # Beta-function tables handed to the Fortran growth routine.
  beta1len <- bootsize - 2
  beta2len <- (bootsize - 2) * (timecount - 1)
  beta1 <- vapply(seq_len(beta1len), function(i) beta(1/2, i/2), numeric(1))
  beta2 <- vapply(seq_len(beta2len),
                  function(i) beta((timecount - 1)/2, i/2), numeric(1))
  # Zero-initialised output buffers for threetreestreegrowth.
  buffers <- list(nmat = matrix(0L, bootdatasize, maxdepth),
                  rsplit = matrix(0, 2^maxdepth - 1, 10),
                  allfval = rep(0, 2^maxdepth - 1), prunind = 0L)
  d1 <- as.matrix(bootdata1)
  d2 <- as.matrix(bootdata2)
  d3 <- as.matrix(bootdata3)
  grow_tree <- function(dd) {
    .Fortran("threetreestreegrowth",
             as.double(weight), as.integer(maxdepth), as.integer(minbucket),
             as.integer(minsplit), as.double(alpha), as.integer(nrow(dd)),
             as.integer(ncol(dd)), as.integer(timecount), as.double(dd),
             as.integer(datatype), as.double(beta1), as.double(beta2),
             as.integer(beta1len), as.integer(beta2len),
             as.integer(buffers$nmat), as.double(buffers$allfval),
             as.integer(buffers$prunind), as.double(buffers$rsplit),
             PACKAGE = "longitree")
  }
  res11 <- grow_tree(d1)
  res12 <- grow_tree(d2)
  res13 <- grow_tree(d3)

  # "no" is encoded as -1 for the Fortran pruning / CV routines.
  if (gamma == "no") gamma <- -1
  cv_orig <- cv
  if (cv == "yes") cv <- 1
  if (cv == "no") cv <- -1
  allgammaval1 <- rep(0, (2^(maxdepth - 1)) - 1)
  allgammaval2 <- allgammaval1
  allgammaval3 <- allgammaval1
  prunind1 <- res11[[17]]; prunind2 <- res12[[17]]; prunind3 <- res13[[17]]
  nodenummat1 <- matrix(res11[[15]], bootdatasize, maxdepth)
  nodenummat2 <- matrix(res12[[15]], bootdatasize, maxdepth)
  nodenummat3 <- matrix(res13[[15]], bootdatasize, maxdepth)
  allfval1 <- res11[[16]]; allfval2 <- res12[[16]]; allfval3 <- res13[[16]]

  # ---- Pruning sequence and cross-validation ---------------------------
  cv_prune <- function(pind, dd, al, gm, cv_flag, nm, af, ag) {
    .Fortran("threetreescvtreepruning",
             as.integer(pind), as.integer(nrow(dd)), as.double(al),
             as.double(gm), as.integer(cv_flag), as.integer(maxdepth),
             as.integer(nm), as.double(af), as.double(ag),
             PACKAGE = "longitree")
  }
  res21 <- cv_prune(prunind1, d1, alpha, gamma, cv, nodenummat1, allfval1,
                    allgammaval1)
  res22 <- cv_prune(prunind2, d2, alpha, gamma, cv, nodenummat2, allfval2,
                    allgammaval2)
  res23 <- cv_prune(prunind3, d3, alpha, gamma, cv, nodenummat3, allfval3,
                    allgammaval3)
  bestgammaval1 <- 0; bestgammaval2 <- 0; bestgammaval3 <- 0
  r2cvval1 <- 0; r2cvval2 <- 0; r2cvval3 <- 0; r2cvvalthree <- 0
  allgammaval1 <- res21[[9]]
  allgammaval2 <- res22[[9]]
  allgammaval3 <- res23[[9]]
  res30 <- .Fortran("threetreesthreetreecv",
                    as.double(gamma), as.double(alpha), as.integer(cv),
                    as.double(d1), as.double(d2), as.double(d3),
                    as.integer(bootdatasize), as.integer(ncol(d)),
                    as.integer(xval), as.integer(timecount),
                    as.integer(maxdepth),
                    as.double(beta1), as.double(beta2), as.integer(beta1len),
                    as.integer(beta2len), as.integer(minbucket),
                    as.integer(minsplit),
                    as.double(weight), as.integer(datatype),
                    as.double(allgammaval1), as.double(allgammaval2),
                    as.double(allgammaval3), as.double(bestgammaval1),
                    as.double(bestgammaval2), as.double(bestgammaval3),
                    as.double(r2cvval1), as.double(r2cvval2),
                    as.double(r2cvval3), as.double(r2cvvalthree),
                    as.integer(cvsplitind1), as.integer(cvsplitind2),
                    as.integer(cvsplitind3), as.integer(fortran_seed),
                    PACKAGE = "longitree")
  r2cvval1 <- res30[[26]]; r2cvval2 <- res30[[27]]
  r2cvval3 <- res30[[28]]; r2cvvalthree <- res30[[29]]
  bestgammaval1 <- res30[[23]]; bestgammaval2 <- res30[[24]]
  bestgammaval3 <- res30[[25]]

  # ---- Prune with the selected complexity parameters --------------------
  g_prune <- function(pind, bds, bg, cv_flag, nm, af) {
    pnm <- matrix(0L, bds, maxdepth)
    .Fortran("threetreesgammatreepruning",
             as.integer(pind), as.integer(bds), as.double(bg),
             as.integer(cv_flag), as.integer(maxdepth), as.integer(nm),
             as.double(af), as.integer(pnm),
             PACKAGE = "longitree")
  }
  res41 <- g_prune(prunind1, bootdatasize, bestgammaval1, cv, nodenummat1,
                   allfval1)
  res42 <- g_prune(prunind2, bootdatasize, bestgammaval2, cv, nodenummat2,
                   allfval2)
  res43 <- g_prune(prunind3, bootdatasize, bestgammaval3, cv, nodenummat3,
                   allfval3)
  prunenodenummat1 <- matrix(res41[[8]], bootdatasize, maxdepth)
  prunenodenummat2 <- matrix(res42[[8]], bootdatasize, maxdepth)
  prunenodenummat3 <- matrix(res43[[8]], bootdatasize, maxdepth)
  # TRUE when exactly maxdepth - 1 columns' worth of entries are zero,
  # presumably meaning only the root level is filled (no split at all).
  no_split <- function(pnm) sum(pnm == 0) == bootdatasize * (maxdepth - 1)
  if (no_split(prunenodenummat1)) stop("No splitting (The first tree)")
  if (no_split(prunenodenummat2)) stop("No splitting (The second tree)")
  if (no_split(prunenodenummat3)) stop("No splitting (The third tree)")

  # ---- Predictions and split tables ------------------------------------
  tree_predict <- function(bds, dd, pnm) {
    pv <- rep(0, bds); tn <- rep(0L, bds)
    .Fortran("threetreestreepredict",
             as.integer(bds), as.integer(timecount), as.integer(maxdepth),
             as.integer(ncol(dd)), as.integer(pnm), as.double(dd),
             as.double(pv), as.integer(tn),
             PACKAGE = "longitree")
  }
  res51 <- tree_predict(bootdatasize, d1, prunenodenummat1)
  res52 <- tree_predict(bootdatasize, d2, prunenodenummat2)
  res53 <- tree_predict(bootdatasize, d3, prunenodenummat3)
  # Keep only the rows of the split table whose first column is non-zero.
  extract_split <- function(res) {
    rs <- matrix(res[[18]], 2^maxdepth - 1, 10)
    rs[rs[, 1] != 0, , drop = FALSE]
  }
  rsplitmat1 <- extract_split(res11)
  rsplitmat2 <- extract_split(res12)
  rsplitmat3 <- extract_split(res13)
  rpredict1 <- data.frame(predict = res51[[7]], terminalnode = res51[[8]])
  rpredict2 <- data.frame(predict = res52[[7]], terminalnode = res52[[8]])
  rpredict3 <- data.frame(predict = res53[[7]], terminalnode = res53[[8]])

  # Return the bootstrap samples with columns in their original types, not
  # the numerically coded copies passed to Fortran.
  d <- data[model_cols]
  bootdata1 <- make_boot(d, tree_ids[1])
  bootdata2 <- make_boot(d, tree_ids[2])
  bootdata3 <- make_boot(d, tree_ids[3])
  cvgamma_list <- if (cv_orig == "yes") {
    list(cvgamma1 = bestgammaval1, cvgamma2 = bestgammaval2,
         cvgamma3 = bestgammaval3)
  } else {
    list(cvgamma1 = NULL, cvgamma2 = NULL, cvgamma3 = NULL)
  }
  result <- c(list(
    bootdata1 = bootdata1, bootdata2 = bootdata2, bootdata3 = bootdata3,
    rsplitmat1 = rsplitmat1, rsplitmat2 = rsplitmat2, rsplitmat3 = rsplitmat3,
    rpredict1 = rpredict1, rpredict2 = rpredict2, rpredict3 = rpredict3,
    nmat1 = nodenummat1, nmat2 = nodenummat2, nmat3 = nodenummat3,
    r2cvtree1 = r2cvval1, r2cvtree2 = r2cvval2, r2cvtree3 = r2cvval3,
    r2cv3trees = r2cvvalthree,
    selected_pe = selected_pe, selected_ari = selected_ari,
    selected_ari_pair = selected_ari_pair),
    cvgamma_list,
    list(.meta = list(formula = formula, time = time, random = random,
                      weight = weight_orig, data_name = data_name,
                      alpha = alpha_orig, gamma = gamma_orig,
                      cv = cv_orig, maxdepth = maxdepth)))
  class(result) <- "threetrees"
  result
}
#' @describeIn threetrees Print a brief summary of a \code{threetrees} object.
#' @param object A \code{threetrees} object.
#' @export
summary.threetrees <- function(object, ...) {
  meta <- object$.meta
  optimal_weight <- is.null(meta$weight) || meta$weight == "w"
  weight_val <- if (optimal_weight) '"w"' else as.character(meta$weight)
  data_str <- if (!is.null(meta$data_name)) meta$data_name else "data"
  # The rule reported: alpha if set, otherwise gamma, otherwise cv.
  use_alpha <- !is.null(meta$alpha) && meta$alpha != "no"
  use_gamma <- !use_alpha && !is.null(meta$gamma) && meta$gamma != "no"
  method_str <- if (use_alpha) {
    paste0("alpha = ", meta$alpha)
  } else if (use_gamma) {
    paste0("gamma = ", meta$gamma)
  } else {
    'cv = "yes"'
  }
  # deparse() splits long formulas into several strings; collapse them so
  # the call line below is printed once instead of once per fragment.
  formula_str <- paste(deparse(meta$formula, width.cutoff = 500L),
                       collapse = " ")
  cat("threetrees:\n")
  cat(paste0("threetrees(", formula_str, ", time = \"",
             meta$time, "\", random = \"", meta$random, "\",\n",
             " weight = ", weight_val, ", data = ", data_str,
             ", ", method_str, ")\n"))
  if (use_alpha) {
    cat("\nMethod: alpha =", meta$alpha, "\n")
  } else if (use_gamma) {
    cat("\nMethod: gamma =", meta$gamma, "\n")
  } else {
    cat("\nMethod: cv\n")
  }
  weight_str <- if (optimal_weight) {
    "w (optimal weight)"
  } else {
    as.character(meta$weight)
  }
  cat("Weight:", weight_str, "\n")
  cat("\n")
  if (!is.null(object$selected_pe))
    cat("Cross-validated prediction error:", object$selected_pe, "\n")
  if (!is.null(object$selected_ari))
    cat("Maximum of the three ARIs:", object$selected_ari, "\n")
  if (!is.null(object$selected_ari_pair)) {
    cat("ARI Pairs:\n")
    cat(" ARI(Tree1, Tree2):", object$selected_ari_pair["ari12"], "\n")
    cat(" ARI(Tree1, Tree3):", object$selected_ari_pair["ari13"], "\n")
    cat(" ARI(Tree2, Tree3):", object$selected_ari_pair["ari23"], "\n")
  }
  cat("\nCross-validated coefficient of determination:\n")
  cat(" Tree 1:", object$r2cvtree1, "\n")
  cat(" Tree 2:", object$r2cvtree2, "\n")
  cat(" Tree 3:", object$r2cvtree3, "\n")
  cat(" 3Trees:", object$r2cv3trees, "\n")
  # Complexity parameters are only stored when cv = "yes" was used.
  if (!is.null(object$cvgamma1)) {
    cat("\nComplexity parameter:\n")
    cat(" Tree 1:", object$cvgamma1, "\n")
    cat(" Tree 2:", object$cvgamma2, "\n")
    cat(" Tree 3:", object$cvgamma3, "\n")
  }
  invisible(object)
}
#' @describeIn threetrees Print method (calls \code{summary}).
#' @param x A \code{threetrees} object.
#' @export
print.threetrees <- function(x, ...) {
  # Printing shows the same report as summary.threetrees(); the value it
  # hands back (the object, invisibly) is passed straight through.
  invisible(summary.threetrees(object = x, ...))
}
#' @describeIn threetrees Extract predicted values and terminal node
#' assignments from a \code{threetrees} object. Returns a data frame
#' with columns \code{predict} (predicted values) and
#' \code{terminalnode} (terminal node assignments).
#' @param object A \code{threetrees} object.
#' @param tree Integer 1, 2, or 3 selecting which tree's predictions to
#' return.
#' @export
predict.threetrees <- function(object, tree = 1, ...) {
  # Require exactly one value, so vector or NULL input gets this message
  # instead of R's generic "condition has length > 1" / "length zero" error.
  if (length(tree) != 1L || !(tree %in% 1:3)) {
    stop("tree must be 1, 2, or 3")
  }
  # Predictions are stored as rpredict1, rpredict2, rpredict3.
  object[[paste0("rpredict", tree)]]
}
#' @describeIn threetrees Plot one of the three trees.
#' A convenience wrapper around \code{\link{treeplot}}, which receives the
#' whole object together with the tree index.
#' @param x A \code{threetrees} object.
#' @param tree Integer 1, 2, or 3 selecting which tree to plot.
#' @param ... Additional arguments passed to \code{\link{treeplot}}.
#' @export
plot.threetrees <- function(x, tree = 1, ...) treeplot(x, tree = tree, ...)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.