sample_ctree: Sample ctree variables from a given conditional inference...

View source: R/sampling.R

sample_ctree    R Documentation

Sample ctree variables from a given conditional inference tree

Description

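Sample ctree variables from a given conditional inference tree: the features not conditioned upon are drawn from the training observations that fall in the same terminal node of the tree as the test observation, while the conditioned features keep their test values.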

Usage

sample_ctree(tree, n_samples, x_test, x_train, p, sample)

Arguments

tree

List. Contains tree, a conditional inference tree fitted with party::ctree(), and given_ind, the indices of the features to condition upon. The example below also supplies dependent_ind, the indices of the remaining features.

n_samples

Positive integer. The number of samples to draw for the Monte Carlo integration.

x_test

Matrix, data.frame or data.table with the features of the observation whose predictions ought to be explained (test data). Dimension 1xp or px1.

x_train

Matrix, data.frame or data.table with training data.

p

Positive integer. The number of features.

sample

Logical. If TRUE, the method samples n_samples observations from the terminal node of the tree. If FALSE, the method uses all observations in the terminal node when there are fewer than n_samples of them, and samples n_samples observations otherwise.
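The rule for the sample argument can be sketched in base R. The helper select_node_rows() below is hypothetical and is not part of shapr; it assumes the terminal-node id of every training row and of the test observation are already known (for a party tree, party::where() returns such ids), and it follows the documented "fewer than n_samples" boundary, which may differ in detail from shapr's own code.

```r
# Hypothetical helper (not shapr's code): choose training rows from the test
# observation's terminal node according to the documented `sample` rule.
# fit_nodes: terminal-node id of every training row
# test_node: terminal-node id of the test observation (assumed to contain
#            at least one training row, as every leaf of a fitted tree does)
select_node_rows <- function(fit_nodes, test_node, n_samples, sample) {
  rowno <- seq_along(fit_nodes)
  in_node <- rowno[fit_nodes == test_node]  # training rows sharing the test leaf
  if (!sample && length(in_node) < n_samples) {
    in_node  # small leaf and sample = FALSE: use every row exactly once
  } else {
    # Draw n_samples rows with replacement; indexing through sample.int()
    # avoids sample()'s special case when given a single number.
    in_node[sample.int(length(in_node), n_samples, replace = TRUE)]
  }
}
```

For example, with fit_nodes = c(3, 5, 3, 3, 5, 6) and test_node = 3, calling it with n_samples = 10 and sample = FALSE returns rows 1, 3 and 4, whereas sample = TRUE returns 10 row indices drawn from those three rows.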

Value

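data.table with n_samples (conditional) samples. When sample = FALSE and the terminal node holds fewer than n_samples observations, all of them are returned, so the table has fewer than n_samples rows.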
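Assuming each returned row combines the dependent features of a selected training row with the test observation's values for the conditioned features, the assembly step might look like the base-R sketch below. assemble_samples() is a hypothetical helper, and it returns a data.frame rather than the data.table that sample_ctree returns.

```r
# Hypothetical helper (not shapr's code): build the sample table from the
# selected training rows, fixing the conditioned features at their test values.
# x_train:   data.frame of training data
# x_test:    one-row data.frame with the same columns as x_train
# rows:      selected training row indices (duplicates allowed)
# given_ind: indices of the conditioned features
assemble_samples <- function(x_train, x_test, rows, given_ind) {
  # Dependent features come from the selected training rows
  out <- x_train[rows, , drop = FALSE]
  # Conditioned features are overwritten with the test observation's values
  for (j in given_ind) {
    out[[j]] <- rep(x_test[[j]][1], nrow(out))
  }
  rownames(out) <- NULL  # drop the "3", "3.1", ... names from repeated rows
  out
}

x_train <- data.frame(V1 = 1:6, V2 = 11:16, V3 = 21:26)
x_test <- data.frame(V1 = 100, V2 = 200, V3 = 300)
assemble_samples(x_train, x_test, rows = c(1, 3, 3, 4), given_ind = 2)
```

The call returns four rows whose V1 and V3 values come from training rows 1, 3, 3 and 4, and whose V2 equals 200 throughout.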

Author(s)

Annabelle Redelmeier

Examples

if (requireNamespace("MASS", quietly = TRUE) && requireNamespace("party", quietly = TRUE)) {
  set.seed(1)
  m <- 10 # number of features
  n <- 40 # number of training observations
  n_samples <- 50
  mu <- rep(1, m)
  cov_mat <- cov(matrix(rnorm(n * m), n, m))
  # Simulated training data (columns V1, ..., V10) and one test observation
  x_train <- data.table::data.table(MASS::mvrnorm(n, mu, cov_mat))
  x_test <- MASS::mvrnorm(1, mu, cov_mat)
  x_test_dt <- data.table::setDT(as.list(x_test))
  # Condition on features 4 and 7; the remaining features are sampled
  given_ind <- c(4, 7)
  dependent_ind <- seq_len(ncol(x_train))[-given_ind]
  x <- x_train[, given_ind, with = FALSE]
  y <- x_train[, dependent_ind, with = FALSE]
  # Fit a multivariate ctree of the dependent features (Y...) on the given ones (V4, V7)
  df <- data.table::data.table(cbind(y, x))
  colnames(df) <- c(paste0("Y", seq_len(ncol(y))), paste0("V", given_ind))
  ynam <- paste0("Y", seq_len(ncol(y)))
  fmla <- as.formula(paste(paste(ynam, collapse = "+"), "~ ."))
  datact <- party::ctree(fmla, data = df, controls = party::ctree_control(
    minbucket = 7,
    mincriterion = 0.95
  ))
  tree <- list(tree = datact, given_ind = given_ind, dependent_ind = dependent_ind)
  # sample_ctree is not exported, hence the ::: accessor
  shapr:::sample_ctree(
    tree = tree, n_samples = n_samples, x_test = x_test_dt, x_train = x_train,
    p = length(x_test), sample = TRUE
  )
}

shapr documentation built on May 4, 2023, 5:10 p.m.