R/veb_boost.R


#' Performs VEB-Boosting
#'
#' Solves the VEB-Boost regression problem using the supplied inputs
#'
#' @details
#'
#' Given a pre-specified arithmetic tree structure \deqn{T(\mu_1, \dots, \mu_L)}, \deqn{\mu_l := h_l(\beta_l)},
#' priors \deqn{\beta_l \sim g_l(\cdot)}, and inputs for the response, VEB-Boosting is performed.
#'
#' A cyclic CAVI scheme is used, where we cycle over the leaf nodes and update the approximation
#' to the posterior distribution at each node in turn.
#'
#' We start with the arithmetic tree structure \deqn{T(\mu_1, \dots, \mu_L) = \sum_{i=1}^k \prod_{j=1}^{d_i} \mu_{i, j}}
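#' For example, with \code{k = 2} and \code{d = c(2, 1)}, this initial structure is \deqn{T = \mu_{1,1} \mu_{1,2} + \mu_{2,1}}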
#'
#'
#' @param learners is either a single "learner" object, or a list of "learner" objects
#' A learner object is composed of:
#' 1. a fit function $fitFunction: (X, Y, sigma2, currentFit) -> newFit (where a fit is a list that must contain $mu1, $mu2, and $KL_div)
#' 2. a prediction function $predFunction: (X_test, currentFit, moment) -> posterior moment (1 or 2)
#' 3. a constant check function $constCheckFunction: (currentFit) -> (TRUE/FALSE) to check if a fit is essentially constant
#' 4. a current fit $currentFit: must contain $mu1 (first posterior moments), $mu2 (second posterior moments), and $KL_div (KL-divergence from q to prior) (can be NULL, at least to start)
#' 5. a predictor object $X (whatever the $fitFunction and $predFunction take in), used for training (can be NULL, e.g. if using constLearner)
#' 6. a predictor object $X_test (whatever the $fitFunction and $predFunction take in), used for testing (can be NULL)
#' 7. a string $growMode for if the learner should be grown (or not)
#' If \code{"+*"}, we grow mu_0 -> (mu_0 * mu_2) + mu_1
#' If \code{"+"}, we grow mu_0 -> (mu_0 + mu_1)
#' If \code{"*"}, we grow mu_0 -> (mu_0 * mu_1) (NOTE: Not recommended if we start with \code{k = 1})
#' If \code{"NA"}, we do not grow this learner
#' 8. a logical $changeToConstant for if the learner should be changed to constant if constCheckFunction evaluates to TRUE
#'
#' @param Y is a numeric response. For all but the 'aft.loglogistic' family, this should be an n-vector.
#' For the 'aft.loglogistic' family, this should be an n x 2 matrix, with the first column being the left end-point of survival time
#' and the second column being the right end-point of survival time (used for interval-censored observations).
#' In the case of left-censored data, the left end-point should be 'NA', and in the case of right-censored data, the right end-point should be 'NA'.
#' If the observation is uncensored, both end-points should be equal to the observed survival time.
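#' For example, a response with one uncensored, one right-censored, one left-censored, and one interval-censored
#' observation could be encoded as
#' \preformatted{
#' Y = rbind(
#'   c(2.0, 2.0), # uncensored at t = 2
#'   c(5.0,  NA), # right-censored at t = 5
#'   c( NA, 1.5), # left-censored at t = 1.5
#'   c(3.0, 4.0)  # interval-censored on [3, 4]
#' )
#' }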
#'
#' @param k is an integer, or a vector of integers of length \code{length(learners)}, for how many terms are in the sum of nodes (for each learner)
#'
#' @param d is either an integer, or an integer vector of length \code{k}, or a list of integer vectors of length \code{length(learners)}
#' (each element either an integer, or a vector of length \code{k}) for the multiplicative depth of each of the k terms
#' NOTE: This can be dangerous. For example, if the fit starts out too large, then entire branches will be fit to be exactly
#' zero. When this happens, we end up dividing by 0 in places, which results in NAs, -Inf, etc. USE AT YOUR OWN RISK
#'
#' @param sigma2 is a scalar/n-vector specifying a fixed residual variance. If not NULL, then the residual variance will be
#' fixed to this value/vector. If NULL, then it will be initialized and updated automatically.
#' This should be left as NULL unless you really know what you're doing. For safety, this can only be non-NULL if family is gaussian
#'
#' @param family is the distributional family of the response
#'
#' @param tol is a positive scalar specifying the level of convergence to be used
#'
#' @param weights is a scalar or a vector of the same length as Y weighting the observations in the log-likelihood
#'
#' @param scaleWeights is a logical for if the weights should be scaled to have mean 1 (recommended). If you choose not to scale the weights, then
#' the relative importance of the KL-divergence term will change (possibly desirable, e.g. to increase or decrease shrinkage towards the prior)
#'
#' @param exposure is a scalar or a vector used for the Poisson regression, NB, and AFT cases. For Poisson, we assume that the response
#' \deqn{Y_i \sim Pois(c_i \lambda_i)} for given exposure \deqn{c_i}, and we model \deqn{\lambda_i}
#' For AFT, exposure is 1 for non-censored observations, and 0 for right-censored observations
#'
#' @param verbose is a logical flag specifying whether we should report convergence information as we go
#'
#' @param maxit is the maximum number of iterations for each version of the VEB-Boost tree
#'
#' @param backfit is a logical. If TRUE, then after the algorithm is done, it will run through once more with the
#' current tree and fit to convergence. Useful when, e.g., maxit = 1.
#'
#'
#' @return A \code{VEB_Boost_Node} object with the fit
#'
#'
#' @examples
#'
#' set.seed(1)
#' n = 1000
#' p = 1000
#' X = matrix(runif(n * p), nrow = n, ncol = p)
#' Y = rnorm(n, 5*sin(3*X[, 1]) + 2*(X[, 2]^2) + 3*X[, 3]*X[, 4])
#'
#' # For input X_test and list currentFit returned from fitFn (encoding the variational posterior for b),
#' # computes either the 1st or 2nd posterior moment of the response (depending on if moment is 1 or 2)
#' predFn = function(X_test, currentFit, moment = c(1, 2)) {
#'   if (moment == 1) {
#'     res = E_fit[Xb]     # first posterior moment (pseudo-code)
#'   } else {
#'     res = E_fit[(Xb)^2] # second posterior moment (pseudo-code)
#'   }
#'   return(res)
#' }
#'
#' # For a given prior g(b), a function that approximates the posterior of b, q(b), using Variational Inference
#' fitFn = function(X, Y, sigma2, currentFit) {
#'   fit = list(whatever is needed to encode the variational posterior)
#'   KL_div = D_KL(q || g) # KL divergence from variational posterior to prior
#'   mu1 = predFn(X, fit, 1)
#'   mu2 = predFn(X, fit, 2)
#'   # add mu1, mu2, and KL_div to the fit, MUST BE CALLED $mu1, $mu2, and $KL_div
#'   fit$mu1 = mu1
#'   fit$mu2 = mu2
#'   fit$KL_div = KL_div
#'   return(fit)
#' }
#'
#' # For a given currentFit, returns TRUE if the variational posterior is close enough to a constant, else FALSE
#' constCheckFn = function(currentFit) {
#'   if (currentFit is close to constant) {
#'     return(TRUE)
#'   } else {
#'     return(FALSE)
#'   }
#' }
#'
#' learner = list(fitFunction = fitFn, predFunction = predFn, constCheckFunction = constCheckFn, currentFit = NULL, X = X, X_test = NULL, growMode = "+*")
#' veb.fit = veb_boost(learner, family = "gaussian")
#'
#' @export
#'

veb_boost = function(learners, Y, k = 1, d = 1, sigma2 = NULL,
                     family = c("gaussian", "binomial", "multinomial.bouchard", "multinomial.titsias", "negative.binomial", "poisson.log1pexp", "aft.loglogistic", "ordinal.logistic"),
                     weights = 1, scaleWeights = TRUE, exposure = NULL, tol = NROW(Y) / 10000, verbose = TRUE, maxit = Inf, backfit = FALSE) {

  ### Check Inputs ###
  # Family
  family = match.arg(family)
  if (grepl("multinomial", family)) {
    Y = as.character(Y)
  }
  if (!(is.vector(Y) || ((family == 'aft.loglogistic') && (is_valid_matrix(Y)) && (ncol(Y) == 2)))) {
    stop("'Y' must be a vector (or a matrix with 2 columns for the 'aft.loglogistic' case)")
  }
  if (!is.null(sigma2)) {
    if (family != "gaussian") {
      stop("'sigma2' can only be not NULL if 'family' is set to 'gaussian'")
    }
    if (!(length(sigma2) %in% c(1, nrow(as.matrix(Y))))) {
      stop("'sigma2' must have length 1 or length of Y")
    }
    if (any(sigma2 <= 0)) {
      stop("'sigma2' must contain only positive values")
    }
  }
  # Scalars
  if (any(k < 1) || any(k %% 1 != 0)) {
    stop("'k' must be a positive whole number")
  }
  if (!(length(k) %in% c(1, length(learners)))) {
    stop("'k' must be of length 1 or 'length(learners)'")
  }
  if (any(unlist(d) < 1) || any(unlist(d) %% 1 != 0)) {
    stop("'d' must be a positive whole number")
  }
  # Logical Flags
  if (!(verbose %in% c(TRUE, FALSE))) {
    stop("'verbose' must be either TRUE or FALSE")
  }
  if (!(backfit %in% c(TRUE, FALSE))) {
    stop("'backfit' must be either TRUE or FALSE")
  }
  if (!(scaleWeights %in% c(TRUE, FALSE))) {
    stop("'scaleWeights' must be either TRUE or FALSE")
  }
  # learner object
  if (!inherits(learners, "list")) {
    stop("'learners' must be a list of learners")
  }
  for (learner in learners) {
    if (!inherits(learner$fitFunction, "function")) {
      stop("'$fitFunction' must be a function")
    }
    if (!identical(tolower(names(formals(learner$fitFunction))), c("x", "y", "sigma2", "currentfit"))) {
      stop("'$fitFunction' must take in arguments 'X', 'Y', 'sigma2', and 'currentfit' (in that order, case insensitive)")
    }
    if (!inherits(learner$predFunction, "function")) {
      stop("'$predFunction' must be a function")
    }
    if (!identical(tolower(names(formals(learner$predFunction))), c("x_test", "currentfit", "moment"))) {
      stop("'$predFunction' must take in arguments 'X_test', 'currentFit', and 'moment' (in that order, case insensitive)")
    }
    if (!inherits(learner$constCheckFunction, "function")) {
      stop("'$constCheckFunction' must be a function")
    }
    if (!identical(tolower(names(formals(learner$constCheckFunction))), "currentfit")) {
      stop("'$constCheckFunction' must take in only the argument 'currentFit' (case insensitive)")
    }
    if (!is.null(learner$growMode) && !(learner$growMode %in% c("+*", "+", "*", "NA"))) {
      stop("'$growMode' must be either NULL, or in c('+*', '+', '*', 'NA')")
    }
  }
  # weights
  if (any(weights <= 0) || any(is.na(weights))) {
    stop("All weights must be positive and not missing")
  }
  if (!(length(weights) %in% c(1, nrow(as.matrix(Y))))) {
    stop("Weights must be of the same length as Y")
  }
  if (length(weights) == 1) {
    weights = rep(weights, nrow(as.matrix(Y)))
  }
  if (scaleWeights) { # scale so the mean is 1
    weights = weights / mean(weights)
  }
  # Exposure
  if (!(grepl("poisson", family, ignore.case = TRUE) | family == "negative.binomial")) {
    if (!is.null(exposure)) {
      warning("Argument 'exposure' is only used in the Poisson, or negative binomial cases, ignoring supplied input")
    }
  } else {
    if (is.null(exposure)) {
      warning("Argument 'exposure' must be supplied in the Poisson, or negative binomial cases, setting to 1")
      exposure = rep(1, length(Y))
    }
  }
  if (!is.null(exposure)) {
    if (any(exposure < 0)) {
      stop("When using the poisson on NB families, 'exposure' must be positive")
    }
  }
  # Tolerance
  if (!is.numeric(tol) || (length(tol) > 1) || (tol <= 0)) {
    stop("'tol' must be a positive number")
  }
  # Response
  if ((family == "gaussian") && !is.numeric(Y)) {
    stop("'Y' must be numeric when using gaussian response")
  }
  if ((family == "binomial") && any(!(Y %in% c(0, 1)))) {
    stop("'Y' must be 0/1 when using binomial response")
  }
  if ((grepl("poisson", family, ignore.case = TRUE) | family == "negative.binomial") && (any(Y < 0) || any(Y %% 1 != 0))) {
    stop("'Y' must be positive integers when using Poisson or negative binomial response")
  }
  if ((family == "aft.loglogistic")) {
    if (any(Y <= 0, na.rm = TRUE)) {
      stop("Non-NA 'Y' must be positive when using the AFT model")
    }
    if (any(Y[, 2] - Y[, 1] < 0, na.rm = TRUE)) {
      stop("The first column of 'Y' must be <= the second column of 'Y'")
    }
    if (any(rowSums(is.na(Y)) == 2)) {
      stop("'Y' cannot have rows that are all NA")
    }
  }
  if (family == "ordinal.logistic") {
    if (any(!(Y %in% (1:max(Y))))) {
      stop("'Y' must contain only integer values from 1 to max(Y)")
    }
  }
  # maxit
  if (maxit < 1) {
    stop("'maxit' must be a positive integer")
  }
  maxit = ceiling(maxit)


  # get censoring types for AFT model
  if (family == "aft.loglogistic") {
    attr(Y, 'left_cens') = (is.na(Y[, 1]) & !is.na(Y[, 2]))
    attr(Y, 'right_cens') = (!is.na(Y[, 1]) & is.na(Y[, 2]))
    attr(Y, 'int_cens') = (!is.na(rowSums(Y)) & (Y[, 1] < Y[, 2]))
    attr(Y, 'not_cens') = (!is.na(rowSums(Y)) & (Y[, 1] == Y[, 2]))
    attr(Y, 'log_Y') = log(Y)
  }
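  # get censoring structure for the ordinal model: treat the lowest category as left-censored, the highest as right-censored, and all others as interval-censored (on the latent scale)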
  if (family == "ordinal.logistic") {
    attr(Y, 'left_cens') = (Y == 1)
    attr(Y, 'right_cens') = (Y == max(Y))
    attr(Y, 'int_cens') = !(attr(Y, 'left_cens') | attr(Y, 'right_cens'))
  }
  if (grepl("multinomial", family)) {
    classes = sort(unique(Y))
    attr(Y, 'which') = sapply(Y, function(j) which(j == classes))
  }

  ### Run Non-multinomial Cases ###
  # if (family != "multinomial") {
  if (!grepl("multinomial", family)) {
    # initialize tree
    veb_boost_tree = initialize_veb_boost_tree(learners = learners, Y = Y, k = k, d = d, weights = weights, family = family, exposure = exposure)
    if (family == "gaussian") {
      veb_boost_tree$sigma2 = if (is.null(sigma2)) var(Y) else sigma2
    } else if (family == "aft.loglogistic") {
      veb_boost_tree$sigma2 = var(rowMeans(attr(Y, 'log_Y'), na.rm = TRUE))
    } else if (family == "ordinal.logistic") {
      veb_boost_tree$cutpoints = seq(from = -max(Y), to = max(Y), length.out = max(Y) - 1)
      attr(veb_boost_tree$cutpoints, "log_Y") = matrix(NA, nrow = length(Y), ncol = 2)
      attr(veb_boost_tree$cutpoints, "log_Y")[attr(Y, 'left_cens'), 2] = veb_boost_tree$cutpoints[1]
      for (k in 2:(length(veb_boost_tree$cutpoints))) {
        attr(veb_boost_tree$cutpoints, "log_Y")[which(Y == k), 1] = veb_boost_tree$cutpoints[k-1]
        attr(veb_boost_tree$cutpoints, "log_Y")[which(Y == k), 2] = veb_boost_tree$cutpoints[k]
      }
      attr(veb_boost_tree$cutpoints, "log_Y")[attr(Y, 'right_cens'), 1] = tail(veb_boost_tree$cutpoints, 1)
    }
    update_sigma2 = ifelse(is.null(sigma2), family %in% c("gaussian", "aft.loglogistic", "ordinal.logistic"), FALSE) # only update the variance for the Gaussian, AFT, and ordinal models

    # converge fit once
    veb_boost_tree$convergeFit(tol, update_sigma2 = update_sigma2, verbose = FALSE, maxit = maxit)

    # if growing tree, continue w/ growing tree & fitting to convergence
    if (verbose) {
      cat(paste("ELBO: ", round(veb_boost_tree$ELBO, 3), sep = ""))
      cat("\n")
    }

    veb_boost_tree$convergeFitAll(tol = tol, update_sigma2 = update_sigma2, verbose = FALSE, maxit = maxit)

    # while ((tail(tail(veb_boost_tree$ELBO_progress, 1)[[1]], 1) - tail(tail(veb_boost_tree$ELBO_progress, 2)[[1]], 1) > tol) &&
    #        (length(Traverse(veb_boost_tree, filterFun = function(node) node$isLeaf & !node$isLocked)) > 0) && (tail(tail(veb_boost_tree$ELBO_progress, 1)[[1]], 1) != Inf)) {
    while ((mean(diff(sapply(tail(veb_boost_tree$ELBO_progress, 3), tail, 1))) > 10*tol) &&
           (length(Traverse(veb_boost_tree, filterFun = function(node) node$isLeaf & !node$isLocked)) > 0) && (tail(tail(veb_boost_tree$ELBO_progress, 1)[[1]], 1) != Inf)) {
      if (verbose) {
        cat(paste("ELBO: ", round(veb_boost_tree$ELBO, 3), sep = ""))
        cat("\n")
      }
      veb_boost_tree$convergeFitAll(tol = tol, update_sigma2 = update_sigma2, verbose = FALSE, maxit = maxit)
    }

    veb_boost_tree$lockLearners() # change to constant at the end if needed
    if (backfit) { # converge again, if backfitting
      veb_boost_tree$convergeFit(tol, update_sigma2 = update_sigma2, verbose = FALSE, maxit = Inf)
    }

    veb_boost_tree$predict(1)

    return(veb_boost_tree)

  } else { # else, multinomial case
    classes = sort(unique(Y))
    learner_multiclass = VEBBoostMultiClassLearner$new()
    learner_multiclass$family = family
    learner_multiclass$Y = Y
    learner_multiclass$classes = classes
    fam = ifelse(family == "multinomial.bouchard", "binomial", "multinomial.titsias")

    learnerList = list()
    for (j in 1:length(classes)) { # add learners to learnerList
      classLearner = initialize_veb_boost_tree(learners = learners, Y = 1 * (Y == classes[j]), k = k, d = d, weights = weights, family = fam, my_class_index = j)
      learnerList = c(learnerList, classLearner)
    }
    names(learnerList) = paste0("classLearner_", classes)
    learner_multiclass$addclassLearners(learnerList) # add learner list to multi-class learner

    # converge fit once
    learner_multiclass$convergeFit(tol = tol)

    # if growing tree, continue w/ growing tree & fitting to convergence
    if (verbose) {
      cat(paste("ELBO: ", round(learner_multiclass$ELBO, 3), sep = ""))
      cat("\n")
    }

    learner_multiclass = learner_multiclass$convergeFitAll(tol = tol)

    while ((abs(tail(tail(learner_multiclass$ELBO_progress, 1)[[1]], 1) - tail(tail(learner_multiclass$ELBO_progress, 2)[[1]], 1)) > tol) &&
           (sum(sapply(learner_multiclass$classLearners, function(x) length(Traverse(x, filterFun = function(node) node$isLeaf & !node$isLocked)))) > 0)) {
      if (verbose) {
        cat(paste("ELBO: ", round(learner_multiclass$ELBO, 3), sep = ""))
        cat("\n")
      }
      learner_multiclass$convergeFitAll(tol = tol)
    }

    learner_multiclass$predict(1)

    return(learner_multiclass)
  }

}


#' Wrapper for using VEB-Boost with the SER prior w/ stumps and/or linear terms
#'
#' @details
#'
#' This function performs VEB-Boosting, where the prior to be used is the SER prior, and our predictors are either
#' i) the linear terms of X;
#' ii) the stumps made from the columns of X; or
#' iii) Both (i) and (ii)
#'
#' @param X An (n x p) numeric matrix to be used as the predictors (currently, this wrapper forces all nodes to use the same X)
#'
#' @param Y is a numeric vector response
#'
#' @param X_test is an optional (m x p) matrix to be used as the testing data. Posterior mean response is saved in the output's field \code{$pred_mu1}
#'
#' @param learners is a list of other learners to be used in \code{\link{veb_boost}}
#'
#' @param include_linear is a logical of length 1 or p specifying which columns of X we should include as linear terms.
#' If the length is 1, this value gets recycled for all columns of X.
#' If NULL is supplied, then all valid linear terms are used.
#'
#' @param include_stumps is a logical of length 1 or p specifying which columns of X we should include as stump terms.
#' If the length is 1, this value gets recycled for all columns of X.
#' If NULL is supplied, then all valid stump terms are used.
#'
#' @param num_cuts is a whole number of length 1 or p specifying how many cuts to make when making the stumps terms.
#' If the length is 1, this value gets recycled for all columns of X.
#' For entries corresponding to the indices where \code{include_stumps} is FALSE, these values are ignored.
#' We use the quantiles from each predictor when making the stumps splits, using \code{num_cuts} of them.
#' If \code{num_cuts = Inf}, then all values of the variables are used as split points.
#'
#' @param k is an integer, or a vector of integers of length \code{length(learners)}, for how many terms are in the sum of nodes (for each learner)
#'
#' @param d is either an integer, or an integer vector of length \code{k}, or a list of integer vectors of length \code{length(learners)}
#' (each element either an integer, or a vector of length \code{k}) for the multiplicative depth of each of the k terms
#' NOTE: This can be dangerous. For example, if the fit starts out too large, then entire branches will be fit to be exactly
#' zero. When this happens, we end up dividing by 0 in places, which results in NAs, -Inf, etc. USE AT YOUR OWN RISK
#'
#' @param use_quants is a logical for if the cut-points should be based off of the quantiles (\code{use_quants = TRUE}), or if the
#' cut points should be evenly spaced in the range of the variable (\code{use_quants = FALSE}).
#'
#' @param scale_X is a string for if/how the columns of X should be scaled.
#' 'sd' scales by the standard deviations of the variables.
#' 'max' scales by the maximum absolute value (so variables are on the [-1, +1] scale).
#' 'NA' performs no scaling.
#'
#' @param growMode is a string for if the learner should be grown (or not)
#' If \code{"+*"}, we grow mu_0 -> (mu_0 * mu_2) + mu_1
#' If \code{"+"}, we grow mu_0 -> (mu_0 + mu_1)
#' If \code{"*"}, we grow mu_0 -> (mu_0 * mu_1) (NOTE: Not recommended if we start with \code{k = 1})
#' If \code{"NA"}, we do not grow this learner
#'
#' @param changeToConstant is a logical for if the learner should be changed to a constant fit when its constCheckFunction evaluates to TRUE
#'
#' @param max_log_prior_var is a scalar for the maximum that the estimated log-prior variance for each weak learner can be.
#' The idea is that setting this to be small limits the "size" of each weak learner, similar-ish to the learning rate in boosting.
#' The maximum allowed value is 35, which essentially allows the unrestricted MLE to be estimated.
#' The minimum allowed value is -5.
#'
#' @param use_optim is a logical. If TRUE, then the prior variance is optimized using the Brent method.
#' If FALSE, then a single EM step is taken to optimize over V
#'
#' @param lin_prior_prob is a number between 0 and 1 that gives the prior probability that the effect variable is a linear term.
#' This means that (1 - lin_prior_prob) is the probability that the effect variable is a stump term.
#' Within linear terms and stump terms, all variables have the same prior variance.
#'
#' @param reverse_learners is a logical for if the order of learners should be reversed. If FALSE, all additional learners in `learners`
#' will come first. If TRUE, the stumps learners will be first.
#'
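#' @param nthreads is the number of threads passed to the underlying C++ routines that construct the stumps/linear predictor matrices.
#' Defaults to half the number of logical cores detected, i.e. \code{ceiling(parallel::detectCores(logical = TRUE) / 2)}.
#'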
#' @param ... Other arguments to be passed to \code{\link{veb_boost}}
#'
#' @return A \code{VEB_Boost_Node} object with the fit
#'
#' @examples
#' set.seed(1)
#' n = 1000
#' p = 1000
#' X = matrix(runif(n * p), nrow = n, ncol = p)
#' Y = rnorm(n, 5*sin(3*X[, 1]) + 2*(X[, 2]^2) + 3*X[, 3]*X[, 4])
#' veb.stumps.fit = veb_boost_stumps(X, Y, include_linear = TRUE, family = "gaussian")
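#'
#' # posterior mean predictions for supplied test data are stored in the fit's $pred_mu1 field
#' X_new = matrix(runif(n * p), nrow = n, ncol = p)
#' veb.stumps.fit.test = veb_boost_stumps(X, Y, X_test = X_new, include_linear = TRUE, family = "gaussian")
#' head(veb.stumps.fit.test$pred_mu1)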
#'
#'
#' @useDynLib VEB.Boost, .registration=TRUE
#'
#' @export
#'

veb_boost_stumps = function(X, Y, X_test = NULL, learners = NULL, include_linear = NULL, include_stumps = NULL,
                            num_cuts = ceiling(min(NROW(Y) / 5, max(100, sqrt(NROW(Y))))), k = 1, d = 1,
                            use_quants = TRUE, scale_X = c("sd", "max", "NA"), growMode = c("+*", "+", "*", "NA"), changeToConstant = FALSE,
                            max_log_prior_var = 0, use_optim = TRUE, lin_prior_prob = 0.5, reverse_learners = FALSE, nthreads = ceiling(parallel::detectCores(logical = TRUE) / 2), ...) {
  # # Set higher priority for this process
  # psnice_init = tools::psnice()
  # on.exit({tools::psnice(tools::psnice(value = psnice_init))})
  # tools::psnice(value = 0 - 10*(Sys.info()["sysname"] == "Windows"))

  ### Check Inputs ###
  # Check X
  if (!is_valid_matrix(X)) {
    stop("'X' must be a numeric matrix")
  }
  if (!is.null(X_test) && !is_valid_matrix(X_test)) {
    stop("'X_test' must be a numeric matrix")
  }
  if (any(is.na(X))) {
    stop("'X' cannot have any NAs. Please handle them accordingly first")
  }
  if (!is.null(X_test) && any(is.na(X_test))) {
    stop("'X_test' cannot have any NAs. Please handle them accordingly first")
  }
  if (NROW(X) != NROW(Y)) {
    stop("'X' and 'Y' must have the same number of observations/rows")
  }
  p = ncol(X)
  # check other learners
  if (!is.null(learners) && !is.list(learners)) {
    stop("'learners' must be either NULL or a list of other learners to use")
  }
  # Check logicals
  if (!(is.null(include_linear) || all(include_linear %in% c(TRUE, FALSE)))) {
    stop("'include_linear' must be either TRUE or FALSE, or be NULL")
  }
  if (!(is.null(include_linear) || (length(include_linear) %in% c(1, p)))) {
    stop("'include_linear' must have length 1 or ncol(X), or be NULL")
  }
  if (!(is.null(include_stumps) || all(include_stumps %in% c(TRUE, FALSE)))) {
    stop("'include_stumps' must be either TRUE or FALSE, or be NULL")
  }
  if (!(is.null(include_stumps) || (length(include_stumps) %in% c(1, p)))) {
    stop("'include_stumps' must have length 1 or ncol(X), or be NULL")
  }
  # Make sure at least one include is true
  if (!is.null(include_linear) && !is.null(include_stumps) && all(include_linear == FALSE) && all(include_stumps == FALSE)) {
    stop("At least one of 'include_linear' or 'include_stumps' must be TRUE or NULL")
  }
  # Make sure num_cuts is a positive whole number
  if (any(num_cuts < 1) || any(num_cuts[is.finite(num_cuts)] %% 1 != 0)) {
    stop("'num_cuts' must be a positive whole number or Inf")
  }
  # if (!(length(num_cuts) %in% c(1, p))) {
  #   stop("'num_cuts' must have length 1 or ncol(X)")
  # }
  if (length(num_cuts) != 1) {
    stop("'num_cuts' must have length 1")
  }
  if (!(use_quants %in% c(TRUE, FALSE))) {
    stop("'use_quants' must be either TRUE or FALSE")
  }
  scale_X = match.arg(scale_X)
  growMode = match.arg(growMode)
  if (!(changeToConstant %in% c(TRUE, FALSE))) {
    stop("'changeToConstant' must be either TRUE or FALSE")
  }
  # Check max_log_prior_var
  if (!is.numeric(max_log_prior_var) || (length(max_log_prior_var) != 1)) {
    stop("'max_log_prior_var' must be a single number")
  }
  if ((max_log_prior_var < -5) | (max_log_prior_var > 35)) {
    warning("Restricting 'max_log_prior_var' to be in the range [-5, 35]")
    max_log_prior_var = min(35, max(-5, max_log_prior_var))
  }
  if (!(use_optim %in% c(TRUE, FALSE))) {
    stop("'use_optim' must be either TRUE or FALSE")
  }
  # Check lin_prior_prob
  if (!is.numeric(lin_prior_prob) || (length(lin_prior_prob) != 1)) {
    stop("'lin_prior_prob' must be a single number")
  }
  if ((lin_prior_prob < 0) | (lin_prior_prob > 1)) {
    warning("Restricting 'lin_prior_prob' to be in the range [0, 1]")
    lin_prior_prob = min(1, max(0, lin_prior_prob))
  }
  if (!(reverse_learners %in% c(TRUE, FALSE))) {
    stop("'reverse_learners' must be either TRUE or FALSE")
  }

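  # encode scale_X as an integer code for the C++ matrix-construction routines: 1 = scale by sd, 2 = scale by max absolute value, 0 = no scaling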
  scale_X = ifelse(scale_X == "sd", 1, ifelse(scale_X == "max", 2, 0))

  if (!is.null(include_linear) && length(include_linear) == 1) {
    include_linear = rep(include_linear, p)
  }
  if (!is.null(include_stumps) && length(include_stumps) == 1) {
    include_stumps = rep(include_stumps, p)
  }

  if (is.infinite(num_cuts) | (!is.null(include_stumps) && all(!include_stumps))) {
    num_cuts = 0
  }
  if (inherits(X, "dgCMatrix")) {
    X_stumps = make_stumps_matrix_sp_cpp(X, include_linear, include_stumps, num_cuts, use_quants, scale_X, nthreads)
  } else {
    X_stumps = make_stumps_matrix_cpp(X, include_linear, include_stumps, num_cuts, use_quants, scale_X, nthreads)
  }

  # set up testing data, if any
  X_test_stumps = NULL
  if (!is.null(X_test)) {
    # if (inherits(X_test, "dgCMatrix") != inherits(X, "dgCMatrix")) {
    #   stop("`X` and `X_test` must both either be sparse (i.e. `dgCMatrix`) or dense. Please convert types to match")
    # }
    if (inherits(X_test, "dgCMatrix")) {
      X_test_stumps = make_stumps_test_matrix_sp_cpp(X_test, X_stumps)
    } else {
      X_test_stumps = make_stumps_test_matrix_cpp(X_test, X_stumps)
    }
  }

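  # wrap the SER fit function in a closure so that max_log_prior_var, lin_prior_prob, and use_optim are fixed at the values chosen above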
  fitFnSusieStumps_maxlV = function(X, Y, sigma2, currentFit) {
    return(weighted_SER_cpp(X, Y, sigma2, currentFit, max_log_prior_var, lin_prior_prob, use_optim))
  }

  stumps_learner = list(fitFunction = fitFnSusieStumps_maxlV, predFunction = predFnSusieStumps_cpp, constCheckFunction = constCheckFnSusieStumps,
                        currentFit = NULL, X = X_stumps, X_test = X_test_stumps, growMode = growMode, changeToConstant = changeToConstant)

  learners = c(learners, list(stumps_learner))
  if (reverse_learners) { # have to explicitly use k and d here, otherwise it doesn't get changed in the `...`
    learners = rev(learners)
    k = rev(k)
    d = rev(d)
  }
  # Run
  veb.fit = veb_boost(learners = learners, Y = Y, k = k, d = d, ...)

  return(veb.fit)
}