R/data.R

Defines functions createRandomEffectsDataset createOutcome createForestDataset

Documented in createForestDataset createOutcome createRandomEffectsDataset

#' @title Forest Dataset C++ Wrapper
#'
#' @description
#' Wrapper around a C++ dataset class used to sample a forest.
#' A dataset consists of three matrices / vectors: covariates,
#' bases, and variance weights. Both the basis matrix and variance
#' weights are optional.
#'
#' This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures.
#' Minimal input validation and error checks are performed -- users are responsible for providing the correct inputs.
#' For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at <https://stochtree.ai/>

ForestDataset <- R6::R6Class(
  classname = "ForestDataset",
  cloneable = FALSE,
  public = list(
    #' @field data_ptr External pointer to a C++ ForestDataset class
    data_ptr = NULL,

    #' @description
    #' Create a new ForestDataset object.
    #' @param covariates Matrix of covariates
    #' @param basis (Optional) Matrix of bases used to define a leaf regression
    #' @param variance_weights (Optional) Vector of observation-specific variance weights
    #' @return A new `ForestDataset` object.
    initialize = function(
      covariates,
      basis = NULL,
      variance_weights = NULL
    ) {
      self$data_ptr <- create_forest_dataset_cpp()
      forest_dataset_add_covariates_cpp(self$data_ptr, covariates)
      # Optional components are only pushed to C++ when supplied, so
      # has_basis() / has_variance_weights() reflect what was provided here.
      if (!is.null(basis)) {
        forest_dataset_add_basis_cpp(self$data_ptr, basis)
      }
      if (!is.null(variance_weights)) {
        forest_dataset_add_weights_cpp(self$data_ptr, variance_weights)
      }
    },

    #' @description
    #' Update basis matrix in a dataset. The dataset must already contain a
    #' basis matrix (e.g. one supplied at construction); otherwise an error is raised.
    #' @param basis Updated matrix of bases used to define a leaf regression
    update_basis = function(basis) {
      if (!self$has_basis()) {
        stop(
          "cannot update basis: this ForestDataset has no basis matrix",
          call. = FALSE
        )
      }
      forest_dataset_update_basis_cpp(self$data_ptr, basis)
    },

    #' @description
    #' Update variance_weights in a dataset. The dataset must already contain
    #' variance weights; otherwise an error is raised.
    #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
    #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: `FALSE`.
    update_variance_weights = function(variance_weights, exponentiate = FALSE) {
      if (!self$has_variance_weights()) {
        stop(
          "cannot update variance weights: this ForestDataset has no variance weights",
          call. = FALSE
        )
      }
      forest_dataset_update_var_weights_cpp(
        self$data_ptr,
        variance_weights,
        exponentiate
      )
    },

    #' @description
    #' Return number of observations in a `ForestDataset` object
    #' @return Observation count
    num_observations = function() {
      return(dataset_num_rows_cpp(self$data_ptr))
    },

    #' @description
    #' Return number of covariates in a `ForestDataset` object
    #' @return Covariate count
    num_covariates = function() {
      return(dataset_num_covariates_cpp(self$data_ptr))
    },

    #' @description
    #' Return number of bases in a `ForestDataset` object
    #' @return Basis count
    num_basis = function() {
      return(dataset_num_basis_cpp(self$data_ptr))
    },

    #' @description
    #' Return covariates as an R matrix
    #' @return Covariate data
    get_covariates = function() {
      return(forest_dataset_get_covariates_cpp(self$data_ptr))
    },

    #' @description
    #' Return bases as an R matrix
    #' @return Basis data
    get_basis = function() {
      return(forest_dataset_get_basis_cpp(self$data_ptr))
    },

    #' @description
    #' Return variance weights as an R vector
    #' @return Variance weight data
    get_variance_weights = function() {
      return(forest_dataset_get_variance_weights_cpp(self$data_ptr))
    },

    #' @description
    #' Whether or not a dataset has a basis matrix
    #' @return `TRUE` if a basis matrix is loaded, `FALSE` otherwise
    has_basis = function() {
      return(dataset_has_basis_cpp(self$data_ptr))
    },

    #' @description
    #' Whether or not a dataset has variance weights
    #' @return `TRUE` if variance weights are loaded, `FALSE` otherwise
    has_variance_weights = function() {
      return(dataset_has_variance_weights_cpp(self$data_ptr))
    },

    #' @description
    #' Whether or not a dataset has auxiliary data stored at the dimension indicated
    #' @param dim_idx Dimension of auxiliary data
    #' @return `TRUE` if auxiliary data has been allocated for `dim_idx`, `FALSE` otherwise
    has_auxiliary_dimension = function(dim_idx) {
      return(forest_dataset_has_auxiliary_dimension_cpp(self$data_ptr, dim_idx))
    },

    #' @description
    #' Initialize a new dimension / lane of auxiliary data and allocate data in its place
    #' @param dim_size Size of the new vector of data to allocate
    #' @return None
    add_auxiliary_dimension = function(dim_size) {
      return(forest_dataset_add_auxiliary_dimension_cpp(
        self$data_ptr,
        dim_size
      ))
    },

    #' @description
    #' Retrieve auxiliary data value
    #' @param dim_idx Dimension from which data value to be retrieved
    #' @param element_idx Element to retrieve from dimension `dim_idx`
    #' @return Floating point value stored in the requested auxiliary data space
    get_auxiliary_data_value = function(dim_idx, element_idx) {
      return(forest_dataset_get_auxiliary_data_value_cpp(
        self$data_ptr,
        dim_idx,
        element_idx
      ))
    },

    #' @description
    #' Set auxiliary data value
    #' @param dim_idx Dimension in which data value to be set
    #' @param element_idx Element to set within dimension `dim_idx`
    #' @param value Data value to set at auxiliary data dimension `dim_idx` and element `element_idx`
    #' @return None
    set_auxiliary_data_value = function(dim_idx, element_idx, value) {
      return(forest_dataset_set_auxiliary_data_value_cpp(
        self$data_ptr,
        dim_idx,
        element_idx,
        value
      ))
    },

    #' @description
    #' Retrieve entire auxiliary data vector
    #' @param dim_idx Dimension to retrieve
    #' @return Vector of all of the auxiliary data stored at dimension `dim_idx`
    get_auxiliary_data_vector = function(dim_idx) {
      return(forest_dataset_get_auxiliary_data_vector_cpp(
        self$data_ptr,
        dim_idx
      ))
    }
  )
)

#' @title Outcome Data C++ Wrapper
#'
#' @description
#' Outcome / partial residual used to sample an additive model.
#' The outcome class is a wrapper around a vector of (mutable)
#' outcomes for ML tasks (supervised learning, causal inference).
#' When an additive tree ensemble is sampled, the outcome used to
#' sample a specific model term is the "partial residual" consisting
#' of the outcome minus the predictions of every other model term
#' (trees, group random effects, etc...).
#'
#' This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures.
#' Minimal input validation and error checks are performed -- users are responsible for providing the correct inputs.
#' For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at <https://stochtree.ai/>

Outcome <- R6::R6Class(
  classname = "Outcome",
  cloneable = FALSE,
  public = list(
    #' @field data_ptr External pointer to a C++ `Outcome` class
    data_ptr = NULL,

    #' @description
    #' Create a new Outcome object.
    #' @param outcome Vector of outcome values
    #' @return A new `Outcome` object.
    initialize = function(outcome) {
      self$data_ptr <- create_column_vector_cpp(outcome)
    },

    #' @description
    #' Extract raw data in R from the underlying C++ object
    #' @return R vector containing (copy of) the values in `Outcome` object
    get_data = function() {
      return(get_residual_cpp(self$data_ptr))
    },

    #' @description
    #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector`
    #' @param update_vector Numeric vector (or 2d matrix, flattened in column-major order) to be added to outcome
    #' @return None
    add_vector = function(update_vector) {
      update_vector <- private$as_flat_numeric(update_vector, "update_vector")
      add_to_column_vector_cpp(self$data_ptr, update_vector)
    },

    #' @description
    #' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector`
    #' @param update_vector Numeric vector (or 2d matrix, flattened in column-major order) to be subtracted from outcome
    #' @return None
    subtract_vector = function(update_vector) {
      update_vector <- private$as_flat_numeric(update_vector, "update_vector")
      subtract_from_column_vector_cpp(self$data_ptr, update_vector)
    },

    #' @description
    #' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector`
    #' @param new_vector Numeric vector (or 2d matrix, flattened in column-major order) from which to overwrite the current data
    #' @return None
    update_data = function(new_vector) {
      new_vector <- private$as_flat_numeric(new_vector, "new_vector")
      overwrite_column_vector_cpp(self$data_ptr, new_vector)
    }
  ),
  private = list(
    # Shared input check for add_vector / subtract_vector / update_data.
    # Errors unless `x` is numeric with at most two dimensions; any `dim`
    # attribute is then dropped via as.numeric() (column-major order) so the
    # C++ layer always receives a plain vector. `arg_name` is used only to
    # name the offending argument in the error message.
    as_flat_numeric = function(x, arg_name) {
      if (!is.numeric(x)) {
        stop(
          paste0(arg_name, " must be a numeric vector or 2d matrix"),
          call. = FALSE
        )
      }
      dim_vec <- dim(x)
      if (!is.null(dim_vec)) {
        if (length(dim_vec) > 2) {
          stop(
            paste0("if ", arg_name, " is provided as a matrix, it must be 2d"),
            call. = FALSE
          )
        }
        x <- as.numeric(x)
      }
      x
    }
  )
)

#' @title Random Effects Dataset C++ Wrapper
#'
#' @description
#' Dataset used to sample a random effects model.
#' A random effects dataset consists of three matrices / vectors: group labels,
#' bases, and variance weights. Variance weights are optional.
#'
#' This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures.
#' Minimal input validation and error checks are performed -- users are responsible for providing the correct inputs.
#' For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at <https://stochtree.ai/>

RandomEffectsDataset <- R6::R6Class(
  classname = "RandomEffectsDataset",
  cloneable = FALSE,
  public = list(
    #' @field data_ptr External pointer to a C++ RandomEffectsDataset class
    data_ptr = NULL,

    #' @description
    #' Create a new RandomEffectsDataset object.
    #' @param group_labels Vector of group labels
    #' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones)
    #' @param variance_weights (Optional) Vector of observation-specific variance weights
    #' @return A new `RandomEffectsDataset` object.
    initialize = function(group_labels, basis, variance_weights = NULL) {
      self$data_ptr <- create_rfx_dataset_cpp()
      rfx_dataset_add_group_labels_cpp(self$data_ptr, group_labels)
      rfx_dataset_add_basis_cpp(self$data_ptr, basis)
      # Variance weights are the only optional component; only push them to
      # C++ when supplied.
      if (!is.null(variance_weights)) {
        rfx_dataset_add_weights_cpp(self$data_ptr, variance_weights)
      }
    },

    #' @description
    #' Update basis matrix in a dataset. The dataset must already contain a
    #' basis matrix; otherwise an error is raised.
    #' @param basis Updated matrix of bases used to define random slopes / intercepts
    update_basis = function(basis) {
      if (!self$has_basis()) {
        stop(
          "cannot update basis: this RandomEffectsDataset has no basis matrix",
          call. = FALSE
        )
      }
      rfx_dataset_update_basis_cpp(self$data_ptr, basis)
    },

    #' @description
    #' Update variance_weights in a dataset. The dataset must already contain
    #' variance weights; otherwise an error is raised.
    #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights
    #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: `FALSE`.
    update_variance_weights = function(variance_weights, exponentiate = FALSE) {
      if (!self$has_variance_weights()) {
        stop(
          "cannot update variance weights: this RandomEffectsDataset has no variance weights",
          call. = FALSE
        )
      }
      rfx_dataset_update_var_weights_cpp(
        self$data_ptr,
        variance_weights,
        exponentiate
      )
    },

    #' @description
    #' Return number of observations in a `RandomEffectsDataset` object
    #' @return Observation count
    num_observations = function() {
      return(rfx_dataset_num_rows_cpp(self$data_ptr))
    },

    #' @description
    #' Return dimension of the basis matrix in a `RandomEffectsDataset` object
    #' @return Basis vector count
    num_basis = function() {
      return(rfx_dataset_num_basis_cpp(self$data_ptr))
    },

    #' @description
    #' Return group labels as an R vector
    #' @return Group label data
    get_group_labels = function() {
      return(rfx_dataset_get_group_labels_cpp(self$data_ptr))
    },

    #' @description
    #' Return bases as an R matrix
    #' @return Basis data
    get_basis = function() {
      return(rfx_dataset_get_basis_cpp(self$data_ptr))
    },

    #' @description
    #' Return variance weights as an R vector
    #' @return Variance weight data
    get_variance_weights = function() {
      return(rfx_dataset_get_variance_weights_cpp(self$data_ptr))
    },

    #' @description
    #' Whether or not a dataset has group label indices
    #' @return `TRUE` if a group label vector is loaded, `FALSE` otherwise
    has_group_labels = function() {
      return(rfx_dataset_has_group_labels_cpp(self$data_ptr))
    },

    #' @description
    #' Whether or not a dataset has a basis matrix
    #' @return `TRUE` if a basis matrix is loaded, `FALSE` otherwise
    has_basis = function() {
      return(rfx_dataset_has_basis_cpp(self$data_ptr))
    },

    #' @description
    #' Whether or not a dataset has variance weights
    #' @return `TRUE` if variance weights are loaded, `FALSE` otherwise
    has_variance_weights = function() {
      return(rfx_dataset_has_variance_weights_cpp(self$data_ptr))
    }
  )
)

#' @title Create ForestDataset Object
#' @description
#' Create a forest dataset object
#'
#' This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures.
#' Minimal input validation and error checks are performed -- users are responsible for providing the correct inputs.
#' For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at <https://stochtree.ai/>
#'
#' @param covariates Matrix of covariates
#' @param basis (Optional) Matrix of bases used to define a leaf regression
#' @param variance_weights (Optional) Vector of observation-specific variance weights
#'
#' @return `ForestDataset` object (returned invisibly)
#' @export
#'
#' @examples
#' covariate_matrix <- matrix(runif(10*100), ncol = 10)
#' basis_matrix <- matrix(rnorm(3*100), ncol = 3)
#' # variance weights scale observation-level variances, so keep them positive
#' weight_vector <- runif(100)
#' forest_dataset <- createForestDataset(covariate_matrix)
#' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix)
#' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector)
createForestDataset <- function(
  covariates,
  basis = NULL,
  variance_weights = NULL
) {
  invisible(ForestDataset$new(
    covariates = covariates,
    basis = basis,
    variance_weights = variance_weights
  ))
}

#' @title Create Outcome Object
#' @description
#' Construct an `Outcome` wrapper around a vector of outcome values.
#'
#' This function targets advanced workflows in which users need fine-grained control over sampling algorithms and data structures.
#' Little input validation is performed, so callers are responsible for supplying well-formed inputs.
#' Vignettes describing stochtree's advanced workflow are available at <https://stochtree.ai/>
#'
#' @param outcome Vector of outcome values
#'
#' @return `Outcome` object (returned invisibly)
#' @export
#'
#' @examples
#' X <- matrix(runif(10*100), ncol = 10)
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
#' outcome <- createOutcome(y)
createOutcome <- function(outcome) {
  outcome_obj <- Outcome$new(outcome = outcome)
  invisible(outcome_obj)
}

#' @title Create RandomEffectsDataset Object
#' @description
#' Create a random effects dataset object
#'
#' This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures.
#' Minimal input validation and error checks are performed -- users are responsible for providing the correct inputs.
#' For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at <https://stochtree.ai/>
#'
#' @param group_labels Vector of group labels
#' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones)
#' @param variance_weights (Optional) Vector of observation-specific variance weights
#'
#' @return `RandomEffectsDataset` object (returned invisibly)
#' @export
#'
#' @examples
#' rfx_group_ids <- sample(1:2, size = 100, replace = TRUE)
#' rfx_basis <- matrix(rnorm(3*100), ncol = 3)
#' # variance weights scale observation-level variances, so keep them positive
#' weight_vector <- runif(100)
#' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis)
#' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector)
createRandomEffectsDataset <- function(
  group_labels,
  basis,
  variance_weights = NULL
) {
  invisible(RandomEffectsDataset$new(
    group_labels = group_labels,
    basis = basis,
    variance_weights = variance_weights
  ))
}

Try the stochtree package in your browser

Any scripts or data that you put into this service are public.

stochtree documentation built on April 29, 2026, 1:08 a.m.