# R/longitree.R
#
# Defines functions: plot.threetrees predict.threetrees print.threetrees summary.threetrees threetrees selectionplot longitrees treeplot .build_tree_plot plot.longitree predict.longitree print.longitree summary.longitree longitree
#
# Documented in: longitree longitrees plot.longitree plot.threetrees predict.longitree predict.threetrees print.longitree print.threetrees selectionplot summary.longitree summary.threetrees threetrees treeplot

#' Construction of a Decision Tree for Longitudinal Data
#'
#' @description
#' Constructs a single decision tree for longitudinal data.
#' The method evaluates both the main effect of a covariate and its
#' interaction with time, incorporating a weighting mechanism to balance
#' the two effects.  Three single-tree construction procedures (ST1, ST2,
#' ST3) are available; see Details.  For the underlying methodology, refer
#' to Obata and Sugimoto (2026).
#'
#' @param formula A formula specifying the model.
#'   The response variable should be on the left side and covariates on the
#'   right side.  Use \code{response ~ .} to include all covariates except the
#'   time variable and the random effect, or select specific covariates such as
#'   \code{response ~ x1 + x2}.  Time-invariant (baseline) covariates are
#'   assumed.
#' @param time Character string giving the column name of the time variable.
#'   All individuals are assumed to be observed at the same time points.
#' @param random Character string giving the column name of the random effect
#'   (subject identifier).
#' @param weight Weight for balancing the main effect of a covariate and
#'   its interaction with time.  A value in
#'   \eqn{\{0.0, 0.1, \ldots, 1.0\}}: \code{1.0} evaluates only the
#'   mean difference in the response variable between the two groups and
#'   \code{0.0} evaluates only the difference in change over time of the
#'   response variable between the two groups.
#'   Set \code{weight = "w"} (the default) to select the optimal weight
#'   from the same grid at each node.
#' @param data A data frame containing the variables in \code{formula} together
#'   with the time and random-effect variables.
#' @param alpha Significance level used as the stopping rule for tree
#'   growth.  A smaller value produces a more conservative (smaller) tree.
#'   Specify a numeric value or \code{"no"} (default) if not used.
#'   Corresponds to ST2.
#' @param gamma Complexity parameter for pruning.  A larger value prunes
#'   more aggressively, yielding a smaller and simpler tree; a smaller
#'   value retains more branches.  Specify a numeric value or \code{"no"}
#'   (default) if not used.  Corresponds to ST3.
#' @param cv Set \code{"yes"} to construct the decision tree using
#'   cross-validation, or \code{"no"} (default) otherwise.
#'   Corresponds to ST1.
#' @param maxdepth Maximum depth of the tree (default 5).
#' @param minbucket Minimum number of subjects in a terminal node (default 5).
#' @param minsplit Minimum number of subjects required to attempt a split
#'   (default 20).
#' @param xval Number of cross-validation folds (default 10).  Used to
#'   compute the cross-validated coefficient of determination
#'   (\eqn{R^2_{\mathrm{CV}}}); when \code{cv = "yes"}, also used for
#'   final tree selection.
#'
#' @details
#' Exactly one of \code{alpha}, \code{gamma}, or \code{cv} must be specified.
#' Specifying more than one will result in an error.  These correspond to the
#' three single-tree construction procedures:
#' \describe{
#'   \item{ST1 (\code{cv = "yes"})}{Tree growth, pruning, and final tree
#'     selection via cross-validation.}
#'   \item{ST2 (\code{alpha})}{Tree growth with a significance threshold.
#'     No pruning or final tree selection via cross-validation.}
#'   \item{ST3 (\code{gamma})}{Tree growth followed by pruning with a
#'     pre-specified complexity parameter.  No final tree selection via
#'     cross-validation.}
#' }
#'
#' Since the time variable is not used as a splitting variable, each terminal
#' node (leaf) contains the full longitudinal responses for every subject
#' assigned to it, allowing direct evaluation of longitudinal trajectories
#' within each leaf.
#'
#' @return An object of class \code{"longitree"}.  Use
#'   \code{\link{summary.longitree}}, \code{\link{predict.longitree}},
#'   or \code{\link{plot.longitree}} to inspect the results.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{treeplot}}, \code{\link{longitrees}}
#'
#' @examples
#' data(ltreedata)
#' # ST1: tree construction via cross-validation
#' result_st1 <- longitree(y ~ ., time = "time", random = "subject",
#'                            weight = 0.7, data = ltreedata, cv = "yes")
#' summary(result_st1)
#' predict(result_st1)
#' plot(result_st1)
#' 
#' # ST2: tree growth with a significance threshold
#' result_st2 <- longitree(y ~ ., time = "time", random = "subject",
#'                            weight = 0.1, data = ltreedata, alpha = 0.05)
#' summary(result_st2)
#' predict(result_st2)
#' plot(result_st2)
#' 
#' # ST3: pruning with a complexity parameter
#' result_st3 <- longitree(y ~ ., time = "time", random = "subject",
#'                            weight = "w", data = ltreedata, gamma = 3)
#' summary(result_st3)
#' predict(result_st3)
#' plot(result_st3)
#'
#' @export
longitree <- function(formula, time, random, weight = "w", data,
                         alpha = "no", gamma = "no", cv = "no",
                         maxdepth = 5, minbucket = 5, minsplit = 20,
                         xval = 10) {
  # "no" is the sentinel for "procedure not requested"; exactly one of
  # alpha (ST2), gamma (ST3) and cv (ST1) has to be switched on.
  if (alpha == "no" && gamma == "no" && cv == "no") {
    stop("At least one of the values for alpha, gamma, or cv must be specified")
  }
  specified_count <- (alpha != "no") + (gamma != "no") + (cv != "no")
  if (specified_count > 1) {
    stop("Only one of alpha, gamma, or cv can be specified at a time")
  }

  # Remember the user-supplied settings: the working copies below are recoded
  # to numeric flags before they are handed to the Fortran routines.
  weight_orig <- weight
  alpha_orig  <- alpha
  gamma_orig  <- gamma
  cv_orig     <- cv
  # Must be captured before `data` is reassigned; collapsed because deparse()
  # can return several strings for a long expression.
  data_name   <- paste(deparse(substitute(data)), collapse = " ")
  fortran_seed <- sample.int(.Machine$integer.max, 1)

  # Tibbles / data.tables do not return a plain vector from `[, j]`, so work
  # on a base data frame from here on.
  data <- as.data.frame(data)

  terms_obj    <- terms.formula(formula, data = data)
  response_pos <- attr(terms_obj, "response")
  if (response_pos == 0) {
    stop("formula must have a response variable on the left side")
  }
  response_name <- paste(attr(terms_obj, "variables")[[response_pos + 1]])

  # Detect the `.` shorthand on the parsed formula.  A textual search for "."
  # would also fire on variable names such as `x.1`, and would break on long
  # formulas that deparse() to more than one string.
  if ("." %in% all.vars(stats::as.formula(formula))) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }
  if (length(predictor_names) == 0) {
    stop("At least one covariate must be supplied")
  }

  # Column layout expected downstream: response, subject, time, covariates.
  model_cols <- c(response_name, paste(random), paste(time),
                  paste(predictor_names))
  model_data <- data[model_cols]
  d <- model_data
  covariate_cols <- seq_along(model_cols)[-seq_len(3)]

  # Covariate type codes passed to Fortran: 1 = numeric, 2 = factor.
  datatype <- numeric(length(covariate_cols))
  for (col in covariate_cols) {
    column <- d[[col]]
    if (is.numeric(column)) {
      datatype[col - 3] <- 1
    } else if (is.factor(column)) {
      datatype[col - 3] <- 2
    } else {
      stop("Covariate data types must be either numeric or factor")
    }
  }

  # Subject and time become consecutive integer codes.  They are recoded from
  # their ORIGINAL values (character identifiers would turn into NA if they
  # went through as.double() first); factor() also drops unused levels so the
  # codes stay consecutive.
  d[[2]] <- as.double(factor(d[[2]]))
  d[[3]] <- as.double(factor(d[[3]]))
  for (col in c(1L, covariate_cols)) {
    if (!is.double(d[[col]])) d[[col]] <- as.double(d[[col]])
  }

  # weight = "w" asks for node-wise selection of the weight and is encoded as
  # -1; otherwise the weight has to be a single number in [0, 1].
  if (length(weight) == 1 && !is.na(weight) && weight == "w") {
    weight <- -1
  } else {
    weight_num <- suppressWarnings(as.numeric(weight))
    if (length(weight_num) != 1 || is.na(weight_num) ||
        weight_num < 0 || weight_num > 1) {
      stop("The weight value is invalid")
    }
    weight <- weight_num
  }
  if (alpha == "no") alpha <- 1

  # Output buffers filled in by the Fortran tree-growing routine.
  nodenummat <- matrix(0L, nrow(d), maxdepth)
  Rsplitmat  <- matrix(0, 2^maxdepth - 1, 10)
  allfval    <- rep(0, 2^maxdepth - 1)
  prunind    <- 0L
  d <- as.matrix(d)

  # Tables of beta-function values precomputed in R and passed to Fortran.
  subject_count <- length(unique(d[, 2]))
  timecount     <- length(unique(d[, 3]))
  beta1len <- subject_count - 2
  beta2len <- (subject_count - 2) * (timecount - 1)
  beta1 <- vapply(seq_len(beta1len), function(i) beta(1/2, i/2), numeric(1))
  beta2 <- vapply(seq_len(beta2len),
                  function(i) beta((timecount - 1)/2, i/2),
                  numeric(1))

  # .Fortran() returns its (possibly modified) arguments as a list, so the
  # results below are picked out by argument position.
  res1 <- .Fortran("treegrowth",
    as.double(weight), as.integer(maxdepth), as.integer(minbucket),
    as.integer(minsplit), as.double(alpha), as.integer(nrow(d)),
    as.integer(ncol(d)), as.integer(timecount), as.double(d),
    as.integer(datatype), as.double(beta1), as.double(beta2),
    as.integer(beta1len), as.integer(beta2len), as.integer(nodenummat),
    as.double(allfval), as.integer(prunind), as.double(Rsplitmat),
    PACKAGE = "longitree")

  if (gamma == "no") gamma <- -1
  if (cv == "yes") cv <- 1
  if (cv == "no")  cv <- -1

  allgammaval <- rep(0, (2^(maxdepth - 1)) - 1)
  prunind     <- res1[[17]]
  nodenummat  <- matrix(res1[[15]], nrow(d), maxdepth)
  allfval     <- res1[[16]]

  res2 <- .Fortran("cvtreepruning",
    as.integer(prunind), as.integer(nrow(d)), as.double(alpha),
    as.double(gamma), as.integer(cv), as.integer(maxdepth),
    as.integer(nodenummat), as.double(allfval), as.double(allgammaval),
    PACKAGE = "longitree")

  bestgammaval <- 0
  r2cvval      <- 0
  allgammaval  <- res2[[9]]

  res3 <- .Fortran("crossvalidation",
    as.double(gamma), as.double(alpha), as.integer(cv), as.double(d),
    as.integer(nrow(d)), as.integer(ncol(d)), as.integer(xval),
    as.integer(timecount), as.integer(maxdepth), as.double(beta1),
    as.double(beta2), as.integer(beta1len), as.integer(beta2len),
    as.integer(minbucket), as.integer(minsplit), as.double(weight),
    as.integer(datatype), as.double(allgammaval), as.double(bestgammaval),
    as.double(r2cvval), as.integer(fortran_seed),
    PACKAGE = "longitree")

  r2cvval      <- res3[[20]]
  bestgammaval <- res3[[19]]

  prunenodenummat <- matrix(0L, nrow(d), maxdepth)
  res4 <- .Fortran("gammatreepruning",
    as.integer(prunind), as.integer(nrow(d)), as.double(bestgammaval),
    as.integer(cv), as.integer(maxdepth), as.integer(nodenummat),
    as.double(allfval), as.integer(prunenodenummat),
    PACKAGE = "longitree")

  prunenodenummat <- matrix(res4[[8]], nrow(d), maxdepth)
  # Every entry beyond the first depth column being zero means that no split
  # survived.  The plain (unclassed) list is kept for backward compatibility.
  if (sum(prunenodenummat == 0) == nrow(d) * (maxdepth - 1)) {
    warning("Warning: No splitting", call. = FALSE)
    return(list(rsplitmat = "No Splitting", rpredict = "No Splitting",
                nmat = "No Splitting", r2cv = r2cvval,
                cvgamma = bestgammaval))
  }

  predvec        <- rep(0, nrow(d))
  datatnnodenum  <- rep(0L, nrow(d))
  res5 <- .Fortran("treepredict",
    as.integer(nrow(d)), as.integer(timecount), as.integer(maxdepth),
    as.integer(ncol(d)), as.integer(prunenodenummat), as.double(d),
    as.double(predvec), as.integer(datatnnodenum),
    PACKAGE = "longitree")

  # Rows whose first column is still zero were never filled in by Fortran.
  rsplit    <- matrix(res1[[18]], 2^maxdepth - 1, 10)
  rsplitmat <- rsplit[rsplit[, 1] != 0, , drop = FALSE]
  rpredict  <- data.frame(predict = res5[[7]], terminalnode = res5[[8]])

  # The selected complexity parameter is only reported for the CV procedure.
  cvgamma_out <- if (cv_orig == "yes") bestgammaval else NULL

  # Only the modelling columns are stored, so code that re-derives the
  # covariate set from names(data) sees exactly the covariates used here.
  result <- list(
    rsplitmat = rsplitmat, rpredict = rpredict, nmat = nodenummat,
    r2cv = r2cvval, cvgamma = cvgamma_out,
    .meta = list(formula = formula, time = time, random = random,
                 data = model_data, data_name = data_name,
                 weight = weight_orig,
                 alpha = alpha_orig, gamma = gamma_orig, cv = cv_orig,
                 maxdepth = maxdepth, datatype = datatype))
  class(result) <- "longitree"
  result
}


#' @describeIn longitree Print a brief summary of a \code{longitree}
#'   object.
#' @param object A \code{longitree} object.
#' @export
summary.longitree <- function(object, ...) {
  # Prints the reconstructed call, the construction procedure, the weight and
  # the cross-validated R^2 of a fitted tree; returns `object` invisibly.
  meta <- object$.meta

  auto_weight <- meta$weight == "w"
  weight_val  <- if (auto_weight) '"w"' else as.character(meta$weight)
  weight_str  <- if (auto_weight) {
    "w (optimal weight)"
  } else {
    as.character(meta$weight)
  }

  data_str <- if (!is.null(meta$data_name)) {
    paste(meta$data_name, collapse = " ")
  } else {
    "data"
  }

  # deparse() breaks long formulas into several strings; without collapsing,
  # paste0() would recycle them and the call would be printed once per piece.
  formula_str <- paste(trimws(deparse(meta$formula, width.cutoff = 500L)),
                       collapse = " ")

  # Exactly one of alpha / gamma / cv is active ("no" marks the unused ones).
  method <- if (meta$alpha != "no") {
    "alpha"
  } else if (meta$gamma != "no") {
    "gamma"
  } else {
    "cv"
  }
  method_str <- switch(method,
    alpha = paste0("alpha = ", meta$alpha),
    gamma = paste0("gamma = ", meta$gamma),
    cv    = 'cv = "yes"')

  cat("longitree:\n")
  cat(paste0("longitree(", formula_str, ", time = \"",
             meta$time, "\", random = \"", meta$random, "\",\n",
             "             weight = ", weight_val, ", data = ", data_str,
             ", ", method_str, ")\n"))
  if (method == "alpha") {
    cat("\nMethod: alpha =", meta$alpha, "\n")
  } else if (method == "gamma") {
    cat("\nMethod: gamma =", meta$gamma, "\n")
  } else {
    cat("\nMethod: cv\n")
  }
  cat("Weight:", weight_str, "\n")
  cat("\nCross-validated coefficient of determination:", object$r2cv, "\n")
  # The complexity parameter is only stored when it was chosen by CV.
  if (!is.null(object$cvgamma)) {
    cat("Complexity parameter:", object$cvgamma, "\n")
  }
  invisible(object)
}


#' @describeIn longitree Print method (calls \code{summary}).
#' @param x A \code{longitree} object.
#' @export
print.longitree <- function(x, ...) {
  # Printing is delegated to the summary method, whose (invisible) return
  # value is handed back unchanged.
  shown <- summary.longitree(x, ...)
  invisible(shown)
}


#' @describeIn longitree Extract predicted values and terminal node
#'   assignments from a \code{longitree} object.  Returns a data frame
#'   with columns \code{predict} (predicted values) and
#'   \code{terminalnode} (terminal node assignments).
#' @param object A \code{longitree} object.
#' @export
predict.longitree <- function(object, ...) {
  # The fitted values and terminal-node assignments are stored on the object
  # at fit time; nothing is recomputed here and `...` is ignored.
  fitted_table <- object$rpredict
  fitted_table
}


#' @describeIn longitree Plot a \code{longitree} object.
#'   A convenience wrapper around \code{\link{treeplot}}.
#' @param x A \code{longitree} object.
#' @param ... Additional arguments passed to \code{\link{treeplot}}.
#' @export
plot.longitree <- function(x, ...) {
  # Thin S3 wrapper: all arguments are forwarded to treeplot() and its result
  # is returned as is.
  tree_figure <- treeplot(x, ...)
  tree_figure
}


# ---- internal helper: build a ggparty tree plot ----------------------------
# Builds the ggparty plot for one fitted tree.
#
# Arguments:
#   formula, time, random, data  model specification and the raw data frame
#   rsplitmat   split table (one row per split node, 10 columns)
#   rpredict    data frame with columns `predict` and `terminalnode`
#   nmat        node-membership matrix (rows of the data x depth)
#   weight      user-supplied weight; "w" adds the p[OR] line to split labels
#   maxdepth    maximum tree depth used when fitting
#   datatype    covariate type codes (1 = numeric, 2 = factor) or NULL to
#               infer them from `data`
#   use_threetrees_fortran  TRUE selects the "threetrees*" Fortran routines
#   snsize, spsize, plotsize, linesize1, linesize2, tnsize  size settings
#
# Returns the ggparty/ggplot object.
#
# NOTE: part of the plot is still assembled from strings and evaluated in
# this function's frame; those strings refer to the local names `ch`, `pvec`,
# `porv`, `wvec`, `idnum`, `tnidnum`, `d` and `g`, which must not be renamed.
.build_tree_plot <- function(formula, time, random, data, rsplitmat, rpredict,
                             nmat, weight, maxdepth, datatype,
                             use_threetrees_fortran,
                             snsize, spsize, plotsize,
                             linesize1, linesize2, tnsize) {
  # `[, j]` has to return a plain vector below, which tibbles do not do.
  data <- as.data.frame(data)

  terms_obj     <- terms.formula(formula, data = data)
  response_var  <- attr(terms_obj, "response")
  response_name <- attr(terms_obj, "variables")[[response_var + 1]]
  # deparse() returns several strings for a long formula; any() keeps the
  # condition length-one (a longer condition is an error in R >= 4.2).
  # NOTE(review): this is a textual test, so a dotted variable name also
  # selects "all columns"; it is kept so that the covariate set matches the
  # one the fitting routine derived for the same formula/data pair.
  if (any(grepl("\\.", deparse(formula)))) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }
  # Column layout: response, subject, time, covariates.
  plot_cols <- c(paste(response_name), paste(random), paste(time),
                 paste(predictor_names))

  # ---- numeric copy of the data for the Fortran helpers --------------------
  dmat <- data[plot_cols]
  covariate_cols <- seq_len(ncol(dmat))[-seq_len(3)]

  if (is.null(datatype)) {
    datatype <- numeric(length(covariate_cols))
    for (col in covariate_cols) {
      if (is.numeric(dmat[[col]])) {
        datatype[col - 3] <- 1
      } else if (is.factor(dmat[[col]])) {
        datatype[col - 3] <- 2
      } else {
        stop("Covariate data types must be either numeric or factor")
      }
    }
  }
  # Subject and time become consecutive integer codes, recoded from their
  # original values so that character identifiers do not turn into NA.
  dmat[[2]] <- as.double(factor(dmat[[2]]))
  dmat[[3]] <- as.double(factor(dmat[[3]]))
  for (col in c(1L, covariate_cols)) {
    if (!is.double(dmat[[col]])) dmat[[col]] <- as.double(dmat[[col]])
  }
  # Factor covariates are shifted to zero-based level codes.
  for (col in covariate_cols) {
    if (datatype[col - 3] == 2) dmat[[col]] <- dmat[[col]] - 1
  }
  dmat <- as.matrix(dmat)

  # A single split arrives as a plain vector; turn it into a one-row matrix.
  if (!is.matrix(rsplitmat)) rsplitmat <- t(as.matrix(rsplitmat))

  # ---- describe every split (variable name, left/right rule) ---------------
  split_routine <- if (use_threetrees_fortran) {
    "threetreessplitvec"
  } else {
    "splitvec"
  }
  k <- 0; spch1 <- c(); spch2 <- c(); namevec <- c()
  for (i in unique(rsplitmat[, 1])) {
    k <- k + 1
    # Depth of node i: the j with 2^(j-1) <= i < 2^j (1 if none is found).
    hie <- 1
    for (j in seq_len(maxdepth)) {
      if (2^(j - 1) <= i && i < 2^j) { hie <- j; break }
    }
    ms <- sum(nmat == i)
    nd <- matrix(0, ms, ncol(dmat))
    if (use_threetrees_fortran) {
      resp1 <- .Fortran("threetreesnodedata",
        as.double(dmat), as.integer(nrow(dmat)), as.integer(ncol(dmat)),
        as.integer(i), as.integer(hie), as.integer(nmat),
        as.integer(ms), as.double(nd), PACKAGE = "longitree")
      nd <- matrix(resp1[[8]], ms, ncol(dmat))
    } else {
      resp1 <- .Fortran("nodedata",
        as.double(dmat), as.integer(nrow(dmat)), as.integer(ncol(dmat)),
        as.integer(i), as.integer(hie), as.integer(maxdepth),
        as.integer(nmat), as.integer(ms), as.double(nd),
        PACKAGE = "longitree")
      nd <- matrix(resp1[[9]], ms, ncol(dmat))
    }
    split_col    <- rsplitmat[k, 4]
    ndsize       <- length(unique(nd[, split_col]))
    splitnumber  <- rep(0, ndsize)
    splitvector  <- rep(0L, ndsize)
    splitvec_args <- list(
      as.double(nd), as.integer(ms), as.integer(ncol(dmat)),
      as.integer(nrow(dmat)), as.integer(maxdepth), as.integer(datatype),
      as.integer(nmat), as.integer(split_col),
      as.integer(rsplitmat[k, 5]), as.integer(ndsize),
      as.double(splitnumber), as.integer(splitvector),
      PACKAGE = "longitree")
    resp2 <- do.call(.Fortran, c(list(split_routine), splitvec_args))
    splitnumber <- resp2[[11]]
    splitvector <- resp2[[12]]
    if (datatype[split_col - 3] == 1) {
      # Numeric split: the cut point is the first value at which the side
      # indicator changes.
      for (l in seq_len(ndsize - 1)) {
        if (abs(splitvector[l + 1] - splitvector[l]) > 0) {
          spch1[k] <- paste0("<=", round(splitnumber[l], 3))
          spch2[k] <- paste0(">",  round(splitnumber[l], 3))
          break
        }
      }
    } else if (datatype[split_col - 3] == 2) {
      # Factor split: list the level codes sent to each side ("=a,b,...").
      spch1[k] <- "="
      spch2[k] <- "="
      for (l in seq_len(ndsize)) {
        if (splitvector[l] == 0) {
          spch1[k] <- if (spch1[k] == "=") paste0("=", splitnumber[l])
                      else paste(spch1[k], splitnumber[l], sep = ",")
        } else if (splitvector[l] == 1) {
          spch2[k] <- if (spch2[k] == "=") paste0("=", splitnumber[l])
                      else paste(spch2[k], splitnumber[l], sep = ",")
        }
      }
    }
    namevec[k] <- colnames(dmat)[split_col]
  }

  rsmat <- data.frame(
    nodenumber   = as.numeric(rsplitmat[, 1]),
    F_Gweight    = as.numeric(rsplitmat[, 2]),
    F_GTweight   = as.numeric(rsplitmat[, 3]),
    name         = as.character(namevec),
    leftsplit    = as.character(spch1),
    rightsplit   = as.character(spch2),
    F_G          = as.numeric(rsplitmat[, 6]),
    F_GT         = as.numeric(rsplitmat[, 7]),
    weightedFval = as.numeric(rsplitmat[, 8]),
    weightedpval = as.numeric(rsplitmat[, 9]),
    afmpval      = as.numeric(rsplitmat[, 10]),
    stringsAsFactors = FALSE
  )

  # ---- raw data used for plotting -------------------------------------------
  d <- data[plot_cols]
  tname   <- as.name(time)
  # A subject column called "id" is renamed to "ID" before plotting.
  if (colnames(d)[2] == "id") {
    colnames(d)[2] <- "ID"
    ranname <- as.name("ID")
  } else {
    ranname <- as.name(random)
  }

  # Keep only the splits that lie on a path from the root to a terminal node
  # of the final tree (ancestors are found by repeated halving).
  terminal_nodes <- unique(rpredict$terminalnode)
  find_branch_nodes <- function(tn) {
    bn <- c()
    for (node in tn) {
      while (node > 1) { node <- floor(node / 2); bn <- c(bn, node) }
    }
    unique(bn)
  }
  branch_nodes <- find_branch_nodes(terminal_nodes)
  spcou <- sort(vapply(branch_nodes,
                       function(b) match(b, rsmat$nodenumber), integer(1)))
  rsmat <- rsmat[spcou, , drop = FALSE]
  rownames(rsmat) <- seq_len(nrow(rsmat))

  d$pred <- rpredict$predict
  d$tn   <- rpredict$terminalnode
  rsmat$name <- trimws(rsmat$name)

  # ---- partykit split objects, keyed by node number -------------------------
  left_tokens  <- strsplit(gsub("=|>|<", "", rsmat$leftsplit), ",")
  right_tokens <- strsplit(gsub("=|>|<", "", rsmat$rightsplit), ",")
  node_splits <- list()
  for (i in seq_len(nrow(rsmat))) {
    col_idx <- match(rsmat$name[i], colnames(d))
    column  <- d[, col_idx]
    if (is.factor(column)) {
      numeric_left  <- as.numeric(left_tokens[[i]])
      numeric_right <- as.numeric(right_tokens[[i]])
      # Level code c (zero-based) goes to kid 1 (left) or kid 2 (right);
      # codes that appear on neither side stay NA.
      indexsp <- rep(0L, length(unique(column)))
      indexsp[numeric_left + 1]  <- 1L
      indexsp[numeric_right + 1] <- 2L
      indexsp[indexsp == 0] <- NA
      node_splits[[as.character(rsmat$nodenumber[i])]] <-
        partysplit(as.integer(col_idx), index = as.integer(indexsp))
    } else if (is.numeric(column)) {
      numeric_val <- as.numeric(left_tokens[[i]])
      node_splits[[as.character(rsmat$nodenumber[i])]] <-
        partysplit(as.integer(col_idx), breaks = numeric_val)
    }
  }

  # ---- assemble the party object --------------------------------------------
  tnsort <- sort(unique(d$tn))
  insort <- sort(rsmat$nodenumber)
  # Node n has kids 2n and 2n + 1; a kid is included when it is a retained
  # split node or a terminal node.  The node number is also stored in `info`
  # so that it can be looked up from the party object afterwards.
  build_node <- function(node) {
    kids <- lapply(c(node * 2, node * 2 + 1), function(child) {
      if (child %in% insort || child %in% tnsort) build_node(child) else NULL
    })
    kids <- kids[!vapply(kids, is.null, logical(1))]
    partynode(as.integer(node),
              split = node_splits[[as.character(node)]],
              kids = if (length(kids) > 0) kids else NULL,
              info = list(original_id = as.integer(node)))
  }
  pn <- build_node(1)
  py <- party(pn, d)
  original_ids <- lapply(nodeapply(py, ids = nodeids(py)),
                         function(n) n$info[["original_id"]])
  id_lookup <- unlist(original_ids, use.names = FALSE)

  # ---- split-node labels ----------------------------------------------------
  ch <- c("Node", "p", "w:", "black", "bold")
  # Position of every split node within the party object.
  idnum <- match(rsmat$nodenumber, id_lookup)
  wvec  <- rsmat$F_Gweight
  pvec  <- vapply(rsmat$weightedpval, function(p) {
    if (p < 0.001) "<0.001" else paste0("=", round(p, 3))
  }, character(1))
  porv <- c()
  if (weight == "w") {
    porv <- vapply(rsmat$afmpval, function(p) {
      if (p < 0.001) {
        "expression(paste(p[OR],'<',0.001,sep=''))"
      } else {
        paste0("expression(paste(p[OR],'=',", round(p, 3), ",sep=''))")
      }
    }, character(1))
  }

  # One geom_node_label() call per split node, built as text because the
  # aes() expressions have to carry a literal index into pvec/porv/wvec.
  label_gpar <- paste0("list(size = ", snsize,
                       "/length(idnum), col = ch[4], fontface = ch[5])")
  p_labels <- character(length(idnum))
  for (i in seq_along(idnum)) {
    label_lines <- c(
      "aes(label = paste(ch[1], id))",
      "aes(label = splitvar)",
      paste0("aes(label = paste(ch[2], pvec[", i, "]))"),
      if (weight == "w") {
        paste0("aes(label = eval(parse(text = porv[", i, "])))")
      },
      paste0("aes(label = paste(ch[3], wvec[", i, "]))"))
    p_labels[i] <- paste0(
      "geom_node_label(line_list = list(",
      paste(label_lines, collapse = ", "),
      "), line_gpar = list(",
      paste(rep(label_gpar, length(label_lines)), collapse = ", "),
      "), ids = idnum[", i, "])")
  }

  # ---- per-terminal-node mixed model (intercept and slope) ------------------
  sorted_tn  <- sort(unique(d$tn))
  node_model <- stats::as.formula(
    paste0(response_name, "~", tname, "+(1|", ranname, ")"))
  beta0_vec <- numeric(length(sorted_tn))
  beta1_vec <- numeric(length(sorted_tn))
  for (i in seq_along(sorted_tn)) {
    subset_data <- d[which(d$tn == sorted_tn[i]), , drop = FALSE]
    fix <- round(fixef(lmer(node_model, data = subset_data)), 2)
    beta0_vec[i] <- fix[1]
    beta1_vec[i] <- fix[2]
  }

  # Position of every terminal node within the party object.
  tnidnum <- match(sorted_tn, id_lookup)

  # ---- draw -----------------------------------------------------------------
  g <- ggparty(py)
  g <- g + geom_edge(colour = "gray", linewidth = 1.5)
  g <- g + geom_edge_label(colour = "red", size = spsize)
  for (i in seq_along(idnum)) {
    g <- g + eval(parse(text = p_labels[i]))
  }

  # Terminal-node panels: dashed per-subject trajectories plus the solid line
  # of predicted values.
  gtex <- paste0(
    "g <- g + geom_node_plot(gglist = list(",
    "geom_line(aes(x=", tname, ", y=", response_name, ", group=", ranname,
    "), linewidth=", linesize1, ", linetype='dashed', color='#f26651'), ",
    "geom_line(aes(x=", tname, ", y=pred), color='red', linewidth=", linesize2, "), ",
    "theme_classic(base_size=", plotsize, "/length(tnidnum)), ",
    "theme(axis.title.y = element_text(margin = margin(l = -10)), ",
    "axis.title.x = element_text(margin = margin(t = -5)))",
    "), scales='fixed', ids='terminal', ",
    "shared_axis_labels=TRUE, shared_legend=TRUE, legend_separator=TRUE)")
  eval(parse(text = gtex))

  # Terminal nodes are numbered after the split nodes, in order of their
  # position in the party object.
  rank_num    <- rank(tnidnum, ties.method = "first")
  new_tnidnum <- length(idnum) + rank_num
  for (i in seq_along(tnidnum)) {
    beta01_i <- paste0(
      "expression(paste(beta[0],'=',", beta0_vec[i],
      ",' ',beta[1],'=',", beta1_vec[i], ",sep=''))")
    ptn <- paste0(
      "geom_node_label(line_list = list(",
      "aes(label = paste0('Node',", new_tnidnum[i],
      ",', N = ', nodesize/length(unique(d$", tname, ")))), ",
      "aes(label = eval(parse(text = ", deparse(beta01_i), ")))), ",
      "line_gpar = list(",
      "list(size=", tnsize, "/length(tnidnum), col=ch[4], fontface=ch[5]), ",
      "list(size=", tnsize, "/length(tnidnum), col=ch[4], fontface=ch[5])), ",
      "fontface='bold', ids=tnidnum[", i, "], size=5, nudge_y=0.01)")
    g <- g + eval(parse(text = ptn))
  }
  g
}


#' Decision Tree Plot Visualisation for Longitudinal Data
#'
#' @description
#' Visualises the structure of a decision tree for longitudinal
#' data.  Built on \pkg{ggparty}.  Each split node displays the node
#' number, split variable, \eqn{p}-value, and weight \eqn{w}.  Each
#' terminal node displays the node number, sample size \eqn{N}, and the
#' intercept (\eqn{\hat\beta_0}) and slope (\eqn{\hat\beta_1}) from a
#' linear mixed-effects model fitted within that node.  Individual
#' longitudinal trajectories are shown as dashed lines; the predicted
#' values (average at each time point) are shown as solid lines, with the
#' response variable on the vertical axis and time on the horizontal axis.
#'
#' @param x A \code{longitree} or \code{threetrees} object.
#' @param tree Integer 1, 2, or 3 selecting which tree to plot when \code{x}
#'   is a \code{threetrees} object.
#' @param snsize Split-node label size (default 50).
#' @param spsize Split-point label size (default 5).
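#' @param plotsize Base font size of the plots drawn in the terminal nodes;
#'   it is divided by the number of terminal nodes (default 80).
#' @param linesize1 Line width of the dashed individual trajectories
#'   (default 0.3).
#' @param linesize2 Line width of the solid predicted-value line (default 1).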
#' @param tnsize Terminal-node label size (default 60).
#'
#' @return A \code{ggplot2}/\code{ggparty} object.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitree}}, \code{\link{threetrees}}
#'
#' @export
treeplot <- function(x, tree = NULL,
                     snsize = 50, spsize = 5, plotsize = 80,
                     linesize1 = 0.3, linesize2 = 1, tnsize = 60) {
  # Work out which kind of fit we were given; anything else is rejected.
  from_single <- inherits(x, "longitree")
  from_triple <- !from_single && inherits(x, "threetrees")
  if (!(from_single || from_triple)) {
    stop("x must be a longitree or threetrees object")
  }

  settings <- x$.meta
  if (from_triple) {
    # A threetrees object holds three trees; its per-tree components are
    # stored under names that end in the tree number.
    if (is.null(tree) || !(tree %in% 1:3)) {
      stop("tree must be 1, 2, or 3 when using threetrees object")
    }
    component   <- function(stem) x[[paste0(stem, tree)]]
    plot_data   <- component("bootdata")
    split_table <- component("rsplitmat")
    fitted_tab  <- component("rpredict")
    node_matrix <- component("nmat")
    type_codes  <- NULL
  } else {
    # A longitree object carries its data and covariate type codes in .meta.
    plot_data   <- settings$data
    split_table <- x$rsplitmat
    fitted_tab  <- x$rpredict
    node_matrix <- x$nmat
    type_codes  <- settings$datatype
  }

  .build_tree_plot(
    formula = settings$formula, time = settings$time,
    random = settings$random, data = plot_data,
    rsplitmat = split_table, rpredict = fitted_tab, nmat = node_matrix,
    weight = settings$weight, maxdepth = settings$maxdepth,
    datatype = type_codes,
    use_threetrees_fortran = from_triple,
    snsize = snsize, spsize = spsize, plotsize = plotsize,
    linesize1 = linesize1, linesize2 = linesize2, tnsize = tnsize)
}


#' Construction of Multiple Decision Trees for Longitudinal Data
#'
#' @description
#' Generates multiple trees from bootstrap samples and evaluates
#' all three-tree combinations based on two criteria: cross-validated
#' prediction error and tree diversification measured by the adjusted Rand
#' index (ARI).  Bootstrap sampling is performed at the subject level to
#' preserve longitudinal structure.
#'
#' @inheritParams longitree
#' @param bootsize Number of subjects in each bootstrap sample.
#' @param trees Number of bootstrap trees to grow (default 100).
#' @param mins Number of top-ranking candidate three-tree subsets to retain
#'   (default 40).
#'
#' @details
#' See \code{\link{longitree}} for a description of the three single-tree
#' construction procedures (ST1, ST2, ST3) corresponding to \code{cv},
#' \code{alpha}, and \code{gamma}.
#'
#' @return An object of class \code{"longitrees"}.  Pass to
#'   \code{\link{selectionplot}} to select the optimal three-tree combination.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitree}}, \code{\link{selectionplot}},
#'   \code{\link{threetrees}}, \code{\link{treeplot}}
#'
#' @export
longitrees <- function(formula, time, random, weight = "w", data,
                          alpha = "no", gamma = "no", cv = "no",
                          maxdepth = 5, minbucket = 5, minsplit = 20,
                          xval = 10, bootsize, trees = 100, mins = 40) {
  # Exactly one of alpha, gamma and cv may differ from "no".
  if (alpha == "no" && gamma == "no" && cv == "no") {
    stop("At least one of the values for alpha, gamma, or cv must be specified")
  }
  specified_count <- (alpha != "no") + (gamma != "no") + (cv != "no")
  if (specified_count > 1) {
    stop("Only one of alpha, gamma, or cv can be specified at a time")
  }

  # Keep the values as supplied; the working copies are recoded to numeric
  # flags before they are handed to the compiled routines.
  weight_orig <- weight
  alpha_orig  <- alpha
  gamma_orig  <- gamma
  cv_orig     <- cv
  data_name   <- deparse(substitute(data))
  # One seed per tree, drawn before the bootstrap samples so the order in
  # which R's random number stream is consumed stays fixed.
  fortran_seeds <- sample.int(.Machine$integer.max, trees)

  terms_obj     <- terms.formula(formula, data = data)
  response_var  <- attr(terms_obj, "response")
  response_name <- attr(terms_obj, "variables")[[response_var + 1]]
  # deparse() returns several strings for a long formula, so the test is
  # reduced with any() to keep the condition scalar.
  # NOTE(review): any literal dot in the formula text, including one inside a
  # variable name such as x.1, selects the all-covariates branch.
  if (any(grepl("\\.", deparse(formula)))) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }

  # Working layout: column 1 response, 2 subject, 3 time, 4+ covariates.
  d <- data[c(paste(response_name), paste(random), paste(time),
              paste(predictor_names))]
  if (ncol(d) < 4L) {
    stop("At least one covariate is required")
  }

  # Covariate type codes: 1 for numeric columns, 2 for factors.
  covariate_cols <- seq_len(ncol(d) - 3L) + 3L
  datatype <- numeric(length(covariate_cols))
  for (col in covariate_cols) {
    # `[[` extracts the column itself, also when `data` is a tibble.
    if (is.numeric(d[[col]])) {
      datatype[col - 3] <- 1
    } else if (is.factor(d[[col]])) {
      datatype[col - 3] <- 2
    } else {
      stop("Covariate data types must be either numeric or factor")
    }
  }
  for (col in seq_len(ncol(d))) {
    if (!is.double(d[[col]])) d[[col]] <- as.double(d[[col]])
  }
  # Recode time and subject to consecutive codes 1, 2, ...
  d[[3]] <- as.double(as.factor(d[[3]]))
  d[[2]] <- as.double(as.factor(d[[2]]))

  # Test for the "w" sentinel first so a numeric weight is only ever compared
  # numerically.
  if (weight != "w" && (weight < 0 || weight > 1)) {
    stop("The weight value is invalid")
  }
  if (weight == "w") weight <- -1
  d <- as.matrix(d)

  timecount    <- length(unique(d[, 3]))
  bootdatasize <- bootsize * timecount
  if (alpha == "no") alpha <- 1

  # Tables of beta-function values passed to the compiled routines.
  beta1len <- bootsize - 2
  beta2len <- (bootsize - 2) * (timecount - 1)
  beta1 <- vapply(seq_len(beta1len), function(i) beta(1/2, i/2), numeric(1))
  beta2 <- vapply(seq_len(beta2len),
                  function(i) beta((timecount - 1)/2, i/2), numeric(1))

  if (gamma == "no") gamma <- -1
  if (cv == "yes") cv <- 1
  if (cv == "no")  cv <- -1

  # Per-tree outputs, one column per bootstrap tree.
  predvecmat       <- matrix(0, nrow(d), trees)
  datatnnodenummat <- matrix(0L, nrow(d), trees)
  oridatamat       <- matrix(0L, bootdatasize, trees)
  cvsplitindmat    <- matrix(0L, bootdatasize, trees)

  # Subjects are resampled with replacement; column i holds tree i's sample.
  unique_subjects <- unique(d[, 2])
  boot_mat <- replicate(trees,
                        sample(unique_subjects, size = bootsize, replace = TRUE))

  for (i in seq_len(trees)) {
    # Row indices of the sampled subjects in `d`, in sampling order.  The
    # assignment below requires exactly `timecount` rows per subject.
    subset_idx <- unlist(lapply(boot_mat[, i], function(b) which(d[, 2] == b)))
    oridatamat[, i] <- subset_idx
    bootdata <- d[subset_idx, , drop = FALSE]
    # Give every draw its own subject id so repeated subjects stay distinct.
    bootdata[, 2] <- rep(1:bootsize, each = timecount)

    nodenummatori  <- matrix(0L, nrow(d), maxdepth)
    nodenummatboot <- matrix(0L, bootdatasize, maxdepth)
    allfval  <- rep(0, 2^maxdepth - 1)
    prunind  <- 0L

    # .Fortran() returns its arguments in call order, so results are read
    # back by position.
    res2 <- .Fortran("threetreesboottreegrowth",
      as.integer(nodenummatboot), as.integer(nodenummatori),
      as.double(allfval), as.integer(prunind), as.integer(bootdatasize),
      as.integer(ncol(d)), as.double(bootdata), as.integer(datatype),
      as.double(weight), as.double(beta1), as.double(beta2),
      as.integer(beta1len), as.integer(beta2len), as.integer(timecount),
      as.integer(minbucket), as.double(d), as.integer(nrow(d)),
      as.double(alpha), as.integer(maxdepth), as.integer(minsplit),
      PACKAGE = "longitree")

    allgammaval    <- rep(0, (2^(maxdepth - 1)) - 1)
    prunind        <- res2[[4]]
    nodenummatori  <- matrix(res2[[2]], nrow(d), maxdepth)
    nodenummatboot <- matrix(res2[[1]], bootdatasize, maxdepth)
    allfval        <- res2[[3]]

    res3 <- .Fortran("threetreesbootcvtreepruning",
      as.integer(prunind), as.integer(bootdatasize), as.double(alpha),
      as.double(gamma), as.integer(cv), as.integer(maxdepth),
      as.integer(nodenummatboot), as.double(allfval),
      as.double(allgammaval),
      PACKAGE = "longitree")

    bestgammaval <- 0
    allgammaval  <- res3[[9]]
    cvsplitind   <- rep(0L, bootdatasize)

    res4 <- .Fortran("threetreesbootcrossvaridation",
      as.double(gamma), as.double(alpha), as.integer(cv),
      as.double(bootdata), as.integer(bootdatasize), as.integer(ncol(d)),
      as.integer(xval), as.integer(timecount), as.integer(maxdepth),
      as.double(beta1), as.double(beta2), as.integer(beta1len),
      as.integer(beta2len), as.integer(minbucket), as.integer(minsplit),
      as.double(weight), as.integer(datatype), as.double(allgammaval),
      as.double(bestgammaval), as.integer(cvsplitind),
      as.integer(fortran_seeds[i]),
      PACKAGE = "longitree")

    bestgammaval <- res4[[19]]
    cvsplitind   <- res4[[20]]
    cvsplitindmat[, i] <- cvsplitind

    prunenodenummatboot <- matrix(0L, bootdatasize, maxdepth)
    prunenodenummatori  <- matrix(0L, nrow(d), maxdepth)

    res5 <- .Fortran("threetreesbootgammatreepruning",
      as.integer(prunind), as.integer(bootdatasize), as.integer(nrow(d)),
      as.double(bestgammaval), as.integer(cv), as.integer(maxdepth),
      as.integer(nodenummatboot), as.integer(nodenummatori),
      as.double(allfval), as.integer(prunenodenummatboot),
      as.integer(prunenodenummatori),
      PACKAGE = "longitree")

    prunenodenummatboot <- matrix(res5[[10]], bootdatasize, maxdepth)
    prunenodenummatori  <- matrix(res5[[11]], nrow(d), maxdepth)

    predvec       <- rep(0, nrow(d))
    datatnnodenum <- rep(0L, nrow(d))

    res6 <- .Fortran("threetreesboottreepredict",
      as.integer(nrow(d)), as.integer(bootdatasize),
      as.integer(timecount), as.integer(maxdepth),
      as.integer(prunenodenummatboot), as.integer(prunenodenummatori),
      as.double(bootdata), as.double(d), as.integer(ncol(d)),
      as.double(predvec), as.integer(datatnnodenum),
      PACKAGE = "longitree")

    predvecmat[, i]       <- res6[[10]]
    datatnnodenummat[, i] <- res6[[11]]
  }

  bettertreelossval <- matrix(0, mins, 5)
  bettertreemaxval  <- matrix(0, mins, 5)
  bettertreenum     <- matrix(0L, mins, 6)

  res7 <- .Fortran("threetreestreechoice",
    as.double(predvecmat), as.integer(datatnnodenummat),
    as.integer(nrow(d)), as.integer(trees), as.integer(timecount),
    as.double(d), as.integer(ncol(d)), as.integer(mins),
    as.integer(bettertreenum), as.double(bettertreelossval),
    as.double(bettertreemaxval),
    PACKAGE = "longitree")

  bettertreelossval <- matrix(res7[[10]], mins, 5)
  bettertreemaxval  <- matrix(res7[[11]], mins, 5)
  bettertreenum     <- matrix(res7[[9]], mins, 6)

  # Turn one mins x 5 score matrix and its mins x 3 tree-number matrix into
  # the two data frames kept in the result.  Column 1 of `scores` is reported
  # as `loss`, columns 3 to 5 as the pairwise ARIs.  Rows whose three ARI
  # entries are not all below 1.5 are dropped.
  # NOTE(review): presumably such rows are unfilled candidates flagged by the
  # compiled routine -- confirm against the Fortran source.
  collect_candidates <- function(scores, tree_ids) {
    keep   <- apply(scores[, 3:5, drop = FALSE], 1, function(r) all(r < 1.5))
    scores <- scores[keep, , drop = FALSE]
    # Tree numbers are stored as doubles in the returned data frame.
    storage.mode(tree_ids) <- "double"
    list(
      stats = data.frame(loss = scores[, 1],
                         rand = apply(scores[, 3:5, drop = FALSE], 1, max),
                         ari12 = scores[, 3], ari13 = scores[, 4],
                         ari23 = scores[, 5]),
      trees = data.frame(tree_ids[keep, , drop = FALSE]))
  }
  # drop = FALSE throughout keeps the matrix shape when mins == 1.
  pe_part  <- collect_candidates(bettertreelossval,
                                 bettertreenum[, 1:3, drop = FALSE])
  ari_part <- collect_candidates(bettertreemaxval,
                                 bettertreenum[, 4:6, drop = FALSE])

  plottree <- rbind(pe_part$stats, ari_part$stats)
  plot(plottree$rand, plottree$loss, col = "gray", cex = 2, pch = 19,
       cex.axis = 1.7, xlab = "maximum of the three ARIs",
       ylab = "cross-validated prediction error")

  treenumber <- rbind(pe_part$trees, ari_part$trees)
  names(treenumber) <- c("treenumber1", "treenumber2", "treenumber3")

  result <- list(
    oridatamat = oridatamat, plottree = plottree,
    cvsplitindmat = cvsplitindmat, treenumber = treenumber,
    datatype = datatype,
    .meta = list(
      formula = formula, time = time, random = random, weight = weight_orig,
      data = data, data_name = data_name, alpha = alpha_orig,
      gamma = gamma_orig, cv = cv_orig,
      bootsize = bootsize, maxdepth = maxdepth, minbucket = minbucket,
      minsplit = minsplit, xval = xval, mins = mins,
      # Number of leading rows of `plottree` / `treenumber` that come from the
      # first (loss-based) candidate list; the two lists are filtered
      # separately, so the boundary need not be the midpoint.
      n_pe = nrow(pe_part$stats)))
  class(result) <- "longitrees"
  result
}


#' Select Optimal Three-Tree Combination
#'
#' @description
#' Plots the cross-validated prediction error against the maximum pairwise
#' adjusted Rand index (ARI) for candidate three-tree subsets, and selects
#' a subset based on either prediction performance or tree diversification.
#' The selected combination is indicated by a red point on the plot, which
#' corresponds to the three trees used in the subsequent
#' \code{\link{threetrees}} step.
#'
#' @param longitrees A \code{longitrees} object.
#' @param metric \code{"PE"} to select the subset with the smallest
#'   cross-validated prediction error, or \code{"ARI"} to select the subset
#'   with the smallest maximum pairwise ARI (greatest tree diversification).
#' @param nth Rank of the tree subset to select (1 = best).  Must not exceed
#'   the number of candidate subsets available for the chosen \code{metric}.
#'
#' @return An object of class \code{"selectionplot"}: a list with elements
#'   \code{threetreenumber} (one-row data frame with the three tree numbers),
#'   \code{selected_pe}, \code{selected_ari} and \code{selected_ari_pair}
#'   (named vector \code{ari12}, \code{ari13}, \code{ari23}).  Pass to
#'   \code{\link{threetrees}} to refit and evaluate the selected trees.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitrees}}, \code{\link{threetrees}}
#'
#' @export
selectionplot <- function(longitrees, metric, nth) {
  if (!inherits(longitrees, "longitrees"))
    stop("longitrees must be a longitrees object")
  # Reject an unknown metric before anything is drawn.
  if (!is.character(metric) || length(metric) != 1L ||
      !(metric %in% c("PE", "ARI")))
    stop("metric must be 'PE' or 'ARI'")

  plottree   <- longitrees$plottree
  treenumber <- longitrees$treenumber
  n_all      <- nrow(plottree)

  # The candidate table is the PE-ranked rows followed by the ARI-ranked rows.
  # Use the recorded length of the first part when the object carries one
  # (`.meta$n_pe`); otherwise assume both parts are equally long, which is the
  # only information an object without that field provides.
  n_pe <- longitrees$.meta$n_pe
  if (is.null(n_pe)) n_pe <- n_all %/% 2

  # seq_len() keeps an empty part empty instead of counting backwards.
  rows <- if (metric == "PE") seq_len(n_pe) else n_pe + seq_len(n_all - n_pe)
  if (length(rows) == 0L)
    stop("No candidate tree subsets are available for this metric")
  if (!is.numeric(nth) || length(nth) != 1L || is.na(nth) ||
      nth < 1 || nth > length(rows))
    stop("nth must be a single number between 1 and ", length(rows))

  plot(plottree$rand, plottree$loss, col = "gray", cex = 2, pch = 19,
       cex.axis = 1.7, xlab = "maximum of the three ARIs",
       ylab = "cross-validated prediction error")

  pt <- data.frame(plottree[rows, ], treenumber[rows, ])
  # PE: smallest error first, ties broken by larger maximum ARI.
  # ARI: smallest maximum ARI first, ties broken by smaller error.
  ranking <- if (metric == "PE") {
    order(pt$loss, -pt$rand)
  } else {
    order(pt$rand, pt$loss)
  }
  chosen <- pt[ranking, ][nth, ]

  # Highlight the selected candidate on the scatter plot.
  points(chosen[["rand"]], chosen[["loss"]], col = "red", cex = 2, pch = 19)

  ari_names <- c("ari12", "ari13", "ari23")
  result <- list(
    threetreenumber = chosen[c("treenumber1", "treenumber2", "treenumber3")],
    selected_pe     = unname(chosen[["loss"]]),
    selected_ari    = unname(chosen[["rand"]]),
    selected_ari_pair = setNames(
      as.numeric(unlist(chosen[ari_names], use.names = FALSE)), ari_names))
  class(result) <- "selectionplot"
  result
}


#' Fit and Evaluate Three Selected Trees
#'
#' @description
#' Refits the three trees selected by \code{\link{selectionplot}} on their
#' original bootstrap samples.
#'
#' @param x A \code{longitrees} object.
#' @param selection A \code{selectionplot} object.
#'
#' @return An object of class \code{"threetrees"}.  Use
#'   \code{\link{summary.threetrees}}, \code{\link{predict.threetrees}},
#'   or \code{\link{plot.threetrees}} to inspect the results.
#'
#' @references
#' Obata, R. and Sugimoto, T. (2026). A decision tree analysis for
#' longitudinal measurement data and its applications.
#' \emph{Advances in Data Analysis and Classification}.
#' \doi{10.1007/s11634-025-00665-2}
#'
#' @seealso \code{\link{longitrees}}, \code{\link{selectionplot}},
#'   \code{\link{treeplot}}
#'
#' @examples
#' data(ltreedata)
#' set.seed(10)
#' trees_res <- longitrees(y ~ ., time = "time", random = "subject",
#'                            weight = 0.5, data = ltreedata, alpha = 0.01,
#'                            bootsize = 50, mins = 40)
#' sel <- selectionplot(trees_res, metric = "PE", nth = 1)
#' tt <- threetrees(trees_res, selection = sel)
#' summary(tt)
#' predict(tt, tree = 1)
#' predict(tt, tree = 2)
#' predict(tt, tree = 3)
#' plot(tt, tree = 1)
#' plot(tt, tree = 2)
#' plot(tt, tree = 3)
#'
#' @export
threetrees <- function(x, selection) {
  if (!inherits(x, "longitrees"))
    stop("x must be a longitrees object")
  if (!inherits(selection, "selectionplot"))
    stop("selection must be a selectionplot object")

  # Settings stored by longitrees(); the trees are refitted with the same ones.
  meta      <- x$.meta
  formula   <- meta$formula
  time      <- meta$time
  random    <- meta$random
  weight    <- meta$weight
  data      <- meta$data
  data_name <- if (!is.null(meta$data_name)) meta$data_name else "data"
  alpha     <- meta$alpha
  gamma     <- meta$gamma
  cv        <- meta$cv
  maxdepth  <- meta$maxdepth
  minbucket <- meta$minbucket
  minsplit  <- meta$minsplit
  xval      <- meta$xval
  bootsize  <- meta$bootsize
  oridatamat    <- x$oridatamat
  cvsplitindmat <- x$cvsplitindmat
  selected_pe       <- selection$selected_pe
  selected_ari      <- selection$selected_ari
  selected_ari_pair <- selection$selected_ari_pair
  threetreenumber   <- selection$threetreenumber

  # The tree numbers index columns of `oridatamat`; refuse missing or
  # out-of-range values before they reach the compiled code.
  tree_ids <- c(threetreenumber$treenumber1, threetreenumber$treenumber2,
                threetreenumber$treenumber3)
  if (length(tree_ids) != 3L || anyNA(tree_ids) ||
      any(tree_ids < 1 | tree_ids > ncol(oridatamat))) {
    stop("selection does not identify three trees of x")
  }

  if (alpha == "no" && gamma == "no" && cv == "no") {
    stop("At least one of the values for alpha, gamma, or cv must be specified")
  }
  specified_count <- (alpha != "no") + (gamma != "no") + (cv != "no")
  if (specified_count > 1) {
    stop("Only one of alpha, gamma, or cv can be specified at a time")
  }
  fortran_seed <- sample.int(.Machine$integer.max, 1)

  terms_obj     <- terms.formula(formula, data = data)
  response_var  <- attr(terms_obj, "response")
  response_name <- attr(terms_obj, "variables")[[response_var + 1]]
  # deparse() returns several strings for a long formula, so the test is
  # reduced with any() to keep the condition scalar.
  # NOTE(review): any literal dot in the formula text, including one inside a
  # variable name such as x.1, selects the all-covariates branch.
  if (any(grepl("\\.", deparse(formula)))) {
    predictor_names <- setdiff(names(data), c(response_name, time, random))
  } else {
    predictor_names <- attr(terms_obj, "term.labels")
  }

  # Working layout: column 1 response, 2 subject, 3 time, 4+ covariates.
  # `d_raw` keeps the original column types for the data returned to the
  # caller; `d` is the all-double copy used for fitting.
  d_raw <- data[c(paste(response_name), paste(random), paste(time),
                  paste(predictor_names))]
  d <- d_raw

  # Covariate type codes: 1 for numeric columns, 2 for factors.
  covariate_cols <- seq_len(ncol(d) - 3L) + 3L
  datatype <- numeric(length(covariate_cols))
  for (col in covariate_cols) {
    # `[[` extracts the column itself, also when `data` is a tibble.
    if (is.numeric(d[[col]])) {
      datatype[col - 3] <- 1
    } else if (is.factor(d[[col]])) {
      datatype[col - 3] <- 2
    } else {
      stop("Covariate data types must be either numeric or factor")
    }
  }
  for (col in seq_len(ncol(d))) {
    if (!is.double(d[[col]])) d[[col]] <- as.double(d[[col]])
  }
  # Recode time and subject to consecutive codes 1, 2, ...
  d[[3]] <- as.double(as.factor(d[[3]]))
  d[[2]] <- as.double(as.factor(d[[2]]))

  weight_orig <- weight
  alpha_orig  <- alpha
  gamma_orig  <- gamma
  cv_orig     <- cv
  # Test for the "w" sentinel first so a numeric weight is only ever compared
  # numerically.
  if (weight != "w" && (weight < 0 || weight > 1)) {
    stop("The weight value is invalid")
  }
  if (weight == "w") weight <- -1

  timecount    <- length(unique(d[[3]]))
  bootdatasize <- bootsize * timecount

  # Rebuild the bootstrap sample of the k-th selected tree from the stored row
  # indices and give every draw its own subject id.
  take_boot <- function(src, k) {
    out <- src[oridatamat[, tree_ids[k]], ]
    out[, 2] <- rep(1:bootsize, each = timecount)
    out
  }
  d1 <- as.matrix(take_boot(d, 1))
  d2 <- as.matrix(take_boot(d, 2))
  d3 <- as.matrix(take_boot(d, 3))

  # Cross-validation fold assignments recorded when the trees were first grown.
  cvsplitind1 <- cvsplitindmat[, tree_ids[1]]
  cvsplitind2 <- cvsplitindmat[, tree_ids[2]]
  cvsplitind3 <- cvsplitindmat[, tree_ids[3]]

  if (alpha == "no") alpha <- 1
  # Tables of beta-function values passed to the compiled routines.
  beta1len <- bootsize - 2
  beta2len <- (bootsize - 2) * (timecount - 1)
  beta1 <- vapply(seq_len(beta1len), function(i) beta(1/2, i/2), numeric(1))
  beta2 <- vapply(seq_len(beta2len),
                  function(i) beta((timecount - 1)/2, i/2), numeric(1))

  # Zero-filled output buffers for one tree-growing call.
  make_mats <- function() {
    list(nmat = matrix(0L, bootdatasize, maxdepth),
         rsplit = matrix(0, 2^maxdepth - 1, 10),
         allfval = rep(0, 2^maxdepth - 1), prunind = 0L)
  }

  # .Fortran() returns its arguments in call order, so results are read back
  # by position: [[15]] node matrix, [[16]] allfval, [[17]] prunind,
  # [[18]] split table.
  grow_tree <- function(dd, mm) {
    .Fortran("threetreestreegrowth",
      as.double(weight), as.integer(maxdepth), as.integer(minbucket),
      as.integer(minsplit), as.double(alpha), as.integer(nrow(dd)),
      as.integer(ncol(dd)), as.integer(timecount), as.double(dd),
      as.integer(datatype), as.double(beta1), as.double(beta2),
      as.integer(beta1len), as.integer(beta2len), as.integer(mm$nmat),
      as.double(mm$allfval), as.integer(mm$prunind), as.double(mm$rsplit),
      PACKAGE = "longitree")
  }
  res11 <- grow_tree(d1, make_mats())
  res12 <- grow_tree(d2, make_mats())
  res13 <- grow_tree(d3, make_mats())

  if (gamma == "no") gamma <- -1
  if (cv == "yes") cv <- 1
  if (cv == "no")  cv <- -1

  allgammaval1 <- rep(0, (2^(maxdepth - 1)) - 1)
  allgammaval2 <- rep(0, (2^(maxdepth - 1)) - 1)
  allgammaval3 <- rep(0, (2^(maxdepth - 1)) - 1)
  prunind1 <- res11[[17]]
  prunind2 <- res12[[17]]
  prunind3 <- res13[[17]]
  nodenummat1 <- matrix(res11[[15]], bootdatasize, maxdepth)
  nodenummat2 <- matrix(res12[[15]], bootdatasize, maxdepth)
  nodenummat3 <- matrix(res13[[15]], bootdatasize, maxdepth)
  allfval1 <- res11[[16]]
  allfval2 <- res12[[16]]
  allfval3 <- res13[[16]]

  cv_prune <- function(pi, dd, al, gm, cv_flag, nm, af, ag) {
    .Fortran("threetreescvtreepruning",
      as.integer(pi), as.integer(nrow(dd)), as.double(al),
      as.double(gm), as.integer(cv_flag), as.integer(maxdepth),
      as.integer(nm), as.double(af), as.double(ag),
      PACKAGE = "longitree")
  }
  res21 <- cv_prune(prunind1, d1, alpha, gamma, cv,
                    nodenummat1, allfval1, allgammaval1)
  res22 <- cv_prune(prunind2, d2, alpha, gamma, cv,
                    nodenummat2, allfval2, allgammaval2)
  res23 <- cv_prune(prunind3, d3, alpha, gamma, cv,
                    nodenummat3, allfval3, allgammaval3)

  bestgammaval1 <- 0
  bestgammaval2 <- 0
  bestgammaval3 <- 0
  r2cvval1 <- 0
  r2cvval2 <- 0
  r2cvval3 <- 0
  r2cvvalthree <- 0
  allgammaval1 <- res21[[9]]
  allgammaval2 <- res22[[9]]
  allgammaval3 <- res23[[9]]

  res30 <- .Fortran("threetreesthreetreecv",
    as.double(gamma), as.double(alpha), as.integer(cv),
    as.double(d1), as.double(d2), as.double(d3),
    as.integer(bootdatasize), as.integer(ncol(d)),
    as.integer(xval), as.integer(timecount), as.integer(maxdepth),
    as.double(beta1), as.double(beta2), as.integer(beta1len),
    as.integer(beta2len), as.integer(minbucket), as.integer(minsplit),
    as.double(weight), as.integer(datatype),
    as.double(allgammaval1), as.double(allgammaval2),
    as.double(allgammaval3), as.double(bestgammaval1),
    as.double(bestgammaval2), as.double(bestgammaval3),
    as.double(r2cvval1), as.double(r2cvval2), as.double(r2cvval3),
    as.double(r2cvvalthree),
    as.integer(cvsplitind1), as.integer(cvsplitind2),
    as.integer(cvsplitind3), as.integer(fortran_seed),
    PACKAGE = "longitree")

  # Positions 23-25: selected gamma per tree; 26-29: cross-validated R^2 for
  # each tree and for the three trees together.
  bestgammaval1 <- res30[[23]]
  bestgammaval2 <- res30[[24]]
  bestgammaval3 <- res30[[25]]
  r2cvval1      <- res30[[26]]
  r2cvval2      <- res30[[27]]
  r2cvval3      <- res30[[28]]
  r2cvvalthree  <- res30[[29]]

  g_prune <- function(pi, bds, bg, cv_flag, nm, af) {
    pnm <- matrix(0L, bds, maxdepth)
    .Fortran("threetreesgammatreepruning",
      as.integer(pi), as.integer(bds), as.double(bg),
      as.integer(cv_flag), as.integer(maxdepth), as.integer(nm),
      as.double(af), as.integer(pnm),
      PACKAGE = "longitree")
  }
  res41 <- g_prune(prunind1, bootdatasize, bestgammaval1, cv,
                   nodenummat1, allfval1)
  res42 <- g_prune(prunind2, bootdatasize, bestgammaval2, cv,
                   nodenummat2, allfval2)
  res43 <- g_prune(prunind3, bootdatasize, bestgammaval3, cv,
                   nodenummat3, allfval3)

  prunenodenummat1 <- matrix(res41[[8]], bootdatasize, maxdepth)
  prunenodenummat2 <- matrix(res42[[8]], bootdatasize, maxdepth)
  prunenodenummat3 <- matrix(res43[[8]], bootdatasize, maxdepth)

  # A pruned node matrix whose zero count equals all entries outside one
  # column is reported as a tree without any split.
  if (sum(prunenodenummat1 == 0) == bootdatasize * (maxdepth - 1))
    stop("No splitting (The first tree)")
  if (sum(prunenodenummat2 == 0) == bootdatasize * (maxdepth - 1))
    stop("No splitting (The second tree)")
  if (sum(prunenodenummat3 == 0) == bootdatasize * (maxdepth - 1))
    stop("No splitting (The third tree)")

  tree_predict <- function(bds, dd, pnm) {
    pv <- rep(0, bds)
    tn <- rep(0L, bds)
    .Fortran("threetreestreepredict",
      as.integer(bds), as.integer(timecount), as.integer(maxdepth),
      as.integer(ncol(dd)), as.integer(pnm), as.double(dd),
      as.double(pv), as.integer(tn),
      PACKAGE = "longitree")
  }
  res51 <- tree_predict(bootdatasize, d1, prunenodenummat1)
  res52 <- tree_predict(bootdatasize, d2, prunenodenummat2)
  res53 <- tree_predict(bootdatasize, d3, prunenodenummat3)

  # Keep only the rows of the split table whose first column is non-zero.
  extract_split <- function(res) {
    rs <- matrix(res[[18]], 2^maxdepth - 1, 10)
    rs[rs[, 1] != 0, , drop = FALSE]
  }
  rsplitmat1 <- extract_split(res11)
  rsplitmat2 <- extract_split(res12)
  rsplitmat3 <- extract_split(res13)

  rpredict1 <- data.frame(predict = res51[[7]], terminalnode = res51[[8]])
  rpredict2 <- data.frame(predict = res52[[7]], terminalnode = res52[[8]])
  rpredict3 <- data.frame(predict = res53[[7]], terminalnode = res53[[8]])

  # Bootstrap data returned to the caller keep the original column types.
  bootdata1 <- take_boot(d_raw, 1)
  bootdata2 <- take_boot(d_raw, 2)
  bootdata3 <- take_boot(d_raw, 3)

  # The selected gamma values are only reported when cv = "yes" was used.
  cvgamma_list <- if (cv_orig == "yes") {
    list(cvgamma1 = bestgammaval1, cvgamma2 = bestgammaval2,
         cvgamma3 = bestgammaval3)
  } else {
    list(cvgamma1 = NULL, cvgamma2 = NULL, cvgamma3 = NULL)
  }

  result <- c(list(
    bootdata1 = bootdata1, bootdata2 = bootdata2, bootdata3 = bootdata3,
    rsplitmat1 = rsplitmat1, rsplitmat2 = rsplitmat2, rsplitmat3 = rsplitmat3,
    rpredict1 = rpredict1, rpredict2 = rpredict2, rpredict3 = rpredict3,
    nmat1 = nodenummat1, nmat2 = nodenummat2, nmat3 = nodenummat3,
    r2cvtree1 = r2cvval1, r2cvtree2 = r2cvval2, r2cvtree3 = r2cvval3,
    r2cv3trees = r2cvvalthree,
    selected_pe = selected_pe, selected_ari = selected_ari,
    selected_ari_pair = selected_ari_pair),
    cvgamma_list,
    list(.meta = list(formula = formula, time = time, random = random,
                      weight = weight_orig, data_name = data_name,
                      alpha = alpha_orig, gamma = gamma_orig,
                      cv = cv_orig, maxdepth = maxdepth)))
  class(result) <- "threetrees"
  result
}


#' @describeIn threetrees Print a brief summary of a \code{threetrees} object.
#' @param object A \code{threetrees} object.
#' @export
summary.threetrees <- function(object, ...) {
  meta <- object$.meta

  # A missing weight is reported like the "w" (optimal weight) setting.
  optimal_weight <- is.null(meta$weight) || meta$weight == "w"
  weight_val <- if (optimal_weight) '"w"' else as.character(meta$weight)
  data_str   <- if (!is.null(meta$data_name)) {
    paste(meta$data_name, collapse = "")
  } else {
    "data"
  }

  # Which of the three pruning controls was active; cv is the fallback.
  has_alpha <- !is.null(meta$alpha) && meta$alpha != "no"
  has_gamma <- !is.null(meta$gamma) && meta$gamma != "no"
  method_str <- if (has_alpha) {
    paste0("alpha = ", meta$alpha)
  } else if (has_gamma) {
    paste0("gamma = ", meta$gamma)
  } else {
    'cv = "yes"'
  }

  # deparse() splits long formulas into several strings; join them so the
  # call line is printed once rather than once per fragment.
  formula_str <- paste(trimws(deparse(meta$formula)), collapse = " ")

  cat("threetrees:\n")
  cat(paste0("threetrees(", formula_str, ", time = \"",
             meta$time, "\", random = \"", meta$random, "\",\n",
             "           weight = ", weight_val, ", data = ", data_str,
             ", ", method_str, ")\n"))
  if (has_alpha) {
    cat("\nMethod: alpha =", meta$alpha, "\n")
  } else if (has_gamma) {
    cat("\nMethod: gamma =", meta$gamma, "\n")
  } else {
    cat("\nMethod: cv\n")
  }
  weight_str <- if (optimal_weight) {
    "w (optimal weight)"
  } else {
    as.character(meta$weight)
  }
  cat("Weight:", weight_str, "\n")
  cat("\n")

  # Selection statistics are printed only when they are stored on the object.
  if (!is.null(object$selected_pe))
    cat("Cross-validated prediction error:", object$selected_pe, "\n")
  if (!is.null(object$selected_ari))
    cat("Maximum of the three ARIs:", object$selected_ari, "\n")
  if (!is.null(object$selected_ari_pair)) {
    cat("ARI Pairs:\n")
    cat("  ARI(Tree1, Tree2):", object$selected_ari_pair["ari12"], "\n")
    cat("  ARI(Tree1, Tree3):", object$selected_ari_pair["ari13"], "\n")
    cat("  ARI(Tree2, Tree3):", object$selected_ari_pair["ari23"], "\n")
  }

  cat("\nCross-validated coefficient of determination:\n")
  cat("  Tree 1:", object$r2cvtree1, "\n")
  cat("  Tree 2:", object$r2cvtree2, "\n")
  cat("  Tree 3:", object$r2cvtree3, "\n")
  cat("  3Trees:", object$r2cv3trees, "\n")

  # The complexity parameters are present only when they were stored.
  if (!is.null(object$cvgamma1)) {
    cat("\nComplexity parameter:\n")
    cat("  Tree 1:", object$cvgamma1, "\n")
    cat("  Tree 2:", object$cvgamma2, "\n")
    cat("  Tree 3:", object$cvgamma3, "\n")
  }
  invisible(object)
}


#' @describeIn threetrees Print method; displays the same text as
#'   \code{summary}.
#' @param x A \code{threetrees} object.
#' @export
print.threetrees <- function(x, ...) {
  # Printing is delegated to the summary method, whose return value is handed
  # back invisibly.
  shown <- summary.threetrees(x, ...)
  invisible(shown)
}


#' @describeIn threetrees Return the stored predictions of one tree: a data
#'   frame with columns \code{predict} (predicted values) and
#'   \code{terminalnode} (terminal node assignments).
#' @param object A \code{threetrees} object.
#' @param tree Integer 1, 2, or 3 selecting which tree's predictions to
#'   return.
#' @export
predict.threetrees <- function(object, tree = 1, ...) {
  valid_trees <- 1:3
  if (!is.element(tree, valid_trees)) {
    stop("tree must be 1, 2, or 3")
  }
  # The predictions are stored as elements rpredict1, rpredict2, rpredict3.
  element_name <- paste0("rpredict", tree)
  object[[element_name]]
}


#' @describeIn threetrees Draw one of the three trees by handing the object
#'   to \code{\link{treeplot}}.
#' @param x A \code{threetrees} object.
#' @param tree Integer 1, 2, or 3 selecting which tree to plot.
#' @param ... Additional arguments passed to \code{\link{treeplot}}.
#' @export
plot.threetrees <- function(x, tree = 1, ...) {
  # Thin wrapper: the tree index and any extra arguments go to treeplot()
  # unchanged.
  which_tree <- tree
  treeplot(x, tree = which_tree, ...)
}

Try the longitree package in your browser

Any scripts or data that you put into this service are public.

longitree documentation built on May 16, 2026, 5:06 p.m.