View source: R/bart_package_cross_validation.R
k_fold_cv | R Documentation |
Builds a BART model using a specified set of arguments to build_bart_machine
and estimates the out-of-sample performance by using k-fold cross validation.
k_fold_cv(X, y, k_folds = 5, folds_vec = NULL, verbose = FALSE, ...)
X |
Data frame of predictors. Factors are automatically converted to dummies internally. |
y |
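Vector of response variable. If y is numeric or integer, a BART model for regression is built. If y is a factor with two levels, a BART model for classification is built. |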
k_folds |
Number of folds to cross-validate over. This argument is ignored if folds_vec is non-null. |
folds_vec |
An integer vector of indices specifying which fold each observation belongs to. |
verbose |
Prints information about progress of the algorithm to the screen. |
... |
Additional arguments to be passed to build_bart_machine. |
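The folds_vec argument expects one fold index per observation. As a minimal sketch of how such a vector might be built in base R, the helper below (make_folds_vec is a hypothetical name, not part of bartMachine) assigns observations to near-equal folds in random order:

```r
# Hypothetical helper, not part of bartMachine: assign each of n observations
# to one of k folds of near-equal size, in random order.
make_folds_vec <- function(n, k) {
  # rep() cycles 1..k until n labels exist; sample() shuffles them
  sample(rep(seq_len(k), length.out = n))
}

set.seed(11)
folds_vec <- make_folds_vec(n = 200, k = 5)
table(folds_vec)  # each of the 5 folds holds 40 observations
```

A vector like this could then be supplied as k_fold_cv(X, y, folds_vec = folds_vec), which makes the fold assignment reproducible across different model settings.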
For each fold, a new BART model is trained (using the same set of arguments) and its performance is evaluated on the holdout piece of that fold.
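The procedure described above can be sketched in base R. This is an illustration only: lm() stands in for the BART model so that the example runs without bartMachine, and the data and variable names are assumptions, not the package's internals.

```r
# Sketch of k-fold cross-validation, with lm() standing in for BART.
set.seed(11)
n <- 100
dat <- data.frame(x = runif(n))
dat$y <- 2 * dat$x + rnorm(n, sd = 0.1)

k_folds <- 5
folds <- sample(rep(seq_len(k_folds), length.out = n))
y_hat <- numeric(n)

for (fold in seq_len(k_folds)) {
  holdout <- folds == fold
  # Train on every observation outside the current fold
  fit <- lm(y ~ x, data = dat[!holdout, ])
  # Predict only the held-out observations
  y_hat[holdout] <- predict(fit, newdata = dat[holdout, ])
}
```

Each entry of y_hat therefore comes from a model that never saw that observation during training.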
For regression models, a list with the following components is returned:
y_hat |
Predictions for the observations computed on the fold for which the observation was omitted from the training set. |
L1_err |
Aggregate L1 error across the folds. |
L2_err |
Aggregate L2 error across the folds. |
rmse |
Aggregate RMSE across the folds. |
folds |
Vector of indices specifying which fold each observation belonged to. |
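To make the regression error measures concrete, the sketch below computes them from a small made-up pair of y and y_hat vectors. It assumes that "aggregate" means a sum over all held-out observations and that the RMSE is the root of the mean squared error; the package source should be consulted for the exact definitions.

```r
# Toy values, invented for illustration only
y     <- c(3.0, 1.5, 4.0, 2.0)
y_hat <- c(2.5, 1.5, 5.0, 2.0)

# Assumed definitions: sums of absolute and squared out-of-sample errors
L1_err <- sum(abs(y - y_hat))     # 0.5 + 0 + 1 + 0 = 1.5
L2_err <- sum((y - y_hat)^2)      # 0.25 + 0 + 1 + 0 = 1.25
rmse   <- sqrt(L2_err / length(y))  # sqrt(1.25 / 4), about 0.559
```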
For classification models, a list with the following components is returned:
y_hat |
Class predictions for the observations computed on the fold for which the observation was omitted from the training set. |
p_hat |
Probability estimates for the observations computed on the fold for which the observation was omitted from the training set. |
confusion_matrix |
Aggregate confusion matrix across the folds. |
misclassification_error |
Total misclassification error across the folds. |
folds |
Vector of indices specifying which fold each observation belonged to. |
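For the classification case, a confusion matrix and misclassification rate can be derived from the true and predicted classes along the following lines. The labels below are made up, and the layout of the matrix returned by k_fold_cv may differ from this plain table() output.

```r
# Toy labels, invented for illustration only
y     <- factor(c("a", "a", "b", "b", "a", "b"), levels = c("a", "b"))
y_hat <- factor(c("a", "b", "b", "b", "a", "a"), levels = c("a", "b"))

# Rows are true classes, columns are predicted classes
confusion <- table(actual = y, predicted = y_hat)

# Share of held-out observations assigned to the wrong class: 2 of 6
misclassification_error <- mean(y != y_hat)
```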
This function is parallelized by the number of cores set in set_bart_machine_num_cores.
Adam Kapelner and Justin Bleich
bartMachine
## Not run:
#generate Friedman data
set.seed(11)
n = 200
p = 5
X = data.frame(matrix(runif(n * p), ncol = p))
y = 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - 0.5)^2 + 10 * X[, 4] + 5 * X[, 5] + rnorm(n)
#evaluate default BART on 5 folds
k_fold_val = k_fold_cv(X, y)
print(k_fold_val$rmse)
## End(Not run)