Source file: R/softbart_probit.R
softbart_probit: R Documentation
Fits a nonparametric probit regression model in which the nonparametric function is modeled using a SoftBart model. Specifically, the model takes the form

$$\Pr(Y = 1 \mid X = x) = \Phi\{a + r(x)\},$$

where $a$ is an offset, $r(x)$ is a Soft BART ensemble, and $\Phi$ is the standard normal cumulative distribution function.
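A minimal base-R sketch of the link function, not the package's internal code: the success probability is obtained by adding the offset to the ensemble output and applying the standard normal CDF `pnorm`, and `qnorm` inverts that map. The values of `a` and `r_x` are made-up numbers for illustration only.

```r
## Hypothetical offset a and ensemble outputs r(x) at three inputs (illustrative only)
a <- 0.25
r_x <- c(-1.0, 0.0, 1.5)

## Probit link: Pr(Y = 1 | X = x) = Phi(a + r(x)), with Phi = pnorm
prob <- pnorm(a + r_x)
print(round(prob, 3))

## qnorm is the inverse of pnorm, so it recovers the linear predictor a + r(x)
eta <- qnorm(prob)
stopifnot(isTRUE(all.equal(eta, a + r_x)))
```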
softbart_probit(
  formula,
  data,
  test_data,
  num_tree = 20,
  k = 1,
  hypers = NULL,
  opts = NULL,
  verbose = TRUE
)
formula
: A model formula with a binary factor on the left-hand side and predictors on the right-hand side.
data
: A data frame consisting of the training data.
test_data
: A data frame consisting of the testing data.
num_tree
: The number of trees in the ensemble.
k
: Determines the standard deviation of the leaf node parameters.
hypers
: A list of hyperparameters constructed from the Hypers() function.
opts
: A list of options for running the chain constructed from the Opts() function.
verbose
: If TRUE, progress of the chain is printed to the console.
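The response in `formula` must be a two-level factor. The base-R sketch below prepares such a data frame and computes the proportion of the second factor level; the description of `pnorm_offset` among the return values indicates that the offset is matched to this proportion, so `qnorm()` of it is presumably what the offset corresponds to (an inference, not the package's code). The column names `x1`, `x2`, `Y` and the simulated data are illustrative.

```r
set.seed(1)
n <- 20

## Illustrative predictors and a 0/1 outcome drawn from a probit model
x1 <- runif(n)
x2 <- runif(n)
y01 <- rbinom(n, 1, pnorm(2 * x1 - 1))

## Encode the outcome as a two-level factor; fixing levels = c(0, 1) keeps "1"
## as the second level even if one class is absent from a small sample
df_example <- data.frame(x1 = x1, x2 = x2, Y = factor(y01, levels = c(0, 1)))
stopifnot(is.factor(df_example$Y), nlevels(df_example$Y) == 2)

## Proportion of observations in the second ("success") level
success_level <- levels(df_example$Y)[2]
success_rate <- mean(df_example$Y == success_level)

## An offset whose pnorm() matches that proportion (infinite if the rate is 0 or 1)
offset_hat <- qnorm(success_rate)
print(c(success_rate = success_rate, offset = offset_hat))
```

Such a data frame would then be passed as `data` with a formula like `Y ~ x1 + x2`.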
Returns a list with the following components:
sigma_mu
: samples of the standard deviation of the leaf node parameters
var_counts
: a matrix with a column for each predictor group containing the number of times each predictor is used in the ensemble at each iteration.
mu_train
: samples of the nonparametric function evaluated on the training set; pnorm(mu_train) gives the success probabilities.
mu_test
: samples of the nonparametric function evaluated on the test set; pnorm(mu_test) gives the success probabilities.
p_train
: samples of probabilities on training set.
p_test
: samples of probabilities on test set.
mu_train_mean
: posterior mean of mu_train.
mu_test_mean
: posterior mean of mu_test.
p_train_mean
: posterior mean of p_train.
p_test_mean
: posterior mean of p_test.
offset
: the model is fit in the form offset + BART, with the offset estimated empirically before running the chain.
pnorm_offset
: pnorm() of the offset, which is chosen to match the empirical probability of the second factor level.
formula
: the formula specified by the user.
ecdfs
: empirical distribution functions, used by the predict function.
opts
: the options used when running the chain.
forest
: a forest object; see the MakeForest documentation for more details.
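The `*_mean` components are posterior averages over saved iterations. The sketch below reproduces such summaries on a simulated stand-in for `mu_test`, assuming (not verified against the package source) that draws are stored with one row per saved iteration and one column per observation; `mu_draws` and `centers` are hypothetical. It also shows why `p_test_mean` is worth reporting separately: averaging `pnorm(mu)` over draws is not the same as applying `pnorm` to the averaged `mu`.

```r
set.seed(42)
num_save <- 200
n_test <- 5

## Hypothetical stand-in for mu_test: one row per saved iteration, one column
## per test observation (column j is centered at centers[j])
centers <- c(-1, -0.5, 0, 0.5, 1)
mu_draws <- matrix(rnorm(num_save * n_test, mean = rep(centers, each = num_save), sd = 0.8),
                   nrow = num_save, ncol = n_test)

## Probability-scale draws, analogous to p_test = pnorm(mu_test)
p_draws <- pnorm(mu_draws)

## Posterior means per observation, analogous to mu_test_mean and p_test_mean
mu_mean <- colMeans(mu_draws)
p_mean <- colMeans(p_draws)

## Pointwise 95% credible intervals for the success probabilities
p_ci <- apply(p_draws, 2, quantile, probs = c(0.025, 0.975))

## Averaging pnorm(mu) over draws differs from pnorm() of the averaged mu
print(round(rbind(p_mean = p_mean, pnorm_of_mu_mean = pnorm(mu_mean)), 3))
```

Under that storage assumption, the same `colMeans()` and `apply()` calls would apply directly to the sample matrices returned by the fit.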
## NOTE: SET NUMBER OF BURN IN AND SAMPLE ITERATIONS HIGHER IN PRACTICE
num_burn <- 10 ## Should be ~ 5000
num_save <- 10 ## Should be ~ 5000
set.seed(1234)
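## Friedman's benchmark function (uses only the first five columns of x)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
  10 * x[,4] + 5 * x[,5]
## Simulate training and test sets; mu is a rescaled Friedman function on the
## probit scale (the sigma argument is not used)
gen_data <- function(n_train, n_test, P, sigma) {
  X <- matrix(runif(n_train * P), nrow = n_train)
  mu <- (f_fried(X) - 14) / 5
  X_test <- matrix(runif(n_test * P), nrow = n_test)
  mu_test <- (f_fried(X_test) - 14) / 5
  ## Binary outcomes with Pr(Y = 1) = pnorm(mu), stored as factors with levels 0, 1
  Y <- factor(rbinom(n_train, 1, pnorm(mu)), levels = c(0,1))
  Y_test <- factor(rbinom(n_test, 1, pnorm(mu_test)), levels = c(0,1))
  return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test,
              mu_test = mu_test))
}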
## Simulate dataset
sim_data <- gen_data(250, 250, 100, 1)
df <- data.frame(X = sim_data$X, Y = sim_data$Y)
df_test <- data.frame(X = sim_data$X_test, Y = sim_data$Y_test)
## Fit the model
opts <- Opts(num_burn = num_burn, num_save = num_save)
fitted_probit <- softbart_probit(Y ~ ., df, df_test, opts = opts)
## Plot results
plot(fitted_probit$mu_test_mean, sim_data$mu_test)
abline(a = 0, b = 1)
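Beyond plotting the estimated function against the truth, test-set predictions can be scored directly. The helper `probit_test_metrics` below is hypothetical (it is not part of SoftBart) and computes the Brier score, log loss, and 0.5-threshold accuracy from predicted success probabilities and an observed two-level factor, treating the second level as the success. It is checked here on made-up toy values.

```r
## Hypothetical helper (not part of SoftBart): score predicted success
## probabilities p against an observed two-level factor y
probit_test_metrics <- function(p, y) {
  stopifnot(is.factor(y), nlevels(y) == 2, length(p) == length(y))
  y01 <- as.integer(y == levels(y)[2])   # 1 for the second ("success") level
  p <- pmin(pmax(p, 1e-12), 1 - 1e-12)   # guard against log(0) in the log loss
  c(brier    = mean((p - y01)^2),
    log_loss = -mean(y01 * log(p) + (1 - y01) * log(1 - p)),
    accuracy = mean((p > 0.5) == (y01 == 1)))
}

## Toy check with made-up probabilities and outcomes
p_toy <- c(0.9, 0.2, 0.6, 0.4)
y_toy <- factor(c(1, 0, 0, 1), levels = c(0, 1))
print(probit_test_metrics(p_toy, y_toy))
```

With the objects from the example above, this could be applied as `probit_test_metrics(fitted_probit$p_test_mean, sim_data$Y_test)`.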