Multinomial Probit Bayesian Additive Regression Trees

Description

Fits a multinomial probit regression model via Bayesian Additive Regression Trees (BART), using partial marginal data augmentation.
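
As a rough illustration of the multinomial probit setup (not the package's internal code): a model with p classes is commonly written with p - 1 latent utilities per observation, which is consistent with the (p - 1) x (p - 1) matrices supplied for V and sigma0 in the arguments. The sketch below assumes the last class acts as the reference category; `latent_to_class` is a hypothetical helper and is not part of mpbart.

```r
# Hypothetical helper: map latent utilities (n x (p - 1) matrix) to classes 1..p.
# Assumption: class p is the reference category, chosen when all utilities are negative.
latent_to_class <- function(W) {
  W <- as.matrix(W)
  p <- ncol(W) + 1
  best <- max.col(W, ties.method = "first")      # column of the largest utility per row
  best_val <- W[cbind(seq_len(nrow(W)), best)]   # value of that largest utility
  ifelse(best_val > 0, best, p)
}

W <- rbind(c( 0.8, -0.2),   # utility 1 is largest and positive
           c(-0.5,  1.3),   # utility 2 is largest and positive
           c(-0.4, -0.9))   # all utilities negative, so the reference class is chosen
classes <- latent_to_class(W)
print(classes)
```

In a BART-based probit model the mean of each latent utility would come from a sum of regression trees; here the utilities are simply typed in by hand.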

Usage

rmpbart(x.train, y.train, x.test = NULL, Prior = NULL, Mcmc = NULL,
  seedvalue = NULL)

Arguments

x.train

Training data predictors.

y.train

Training data observed classes.

x.test

Test data predictors.

Prior

List of prior hyperparameters for MPBART, where p is the number of classes: e.g., Prior = list(nu = p + 2, V = diag(p - 1), ntrees = 200, kfac = 2.0, pbd = 1.0, pb = 0.5, beta = 2.0, alpha = 0.95, nc = 100, priorindep = 0, minobsnode = 10)
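
A minimal sketch of how such a list might be assembled for a given number of classes p. `make_prior` is a hypothetical convenience wrapper, not an mpbart function; its defaults are simply the example values listed above, and any entry can be overridden by name.

```r
# Hypothetical helper (not part of mpbart): build a Prior list for a p-class problem.
# Default values are copied from the example values in the Prior argument description.
make_prior <- function(p, ntrees = 200, ...) {
  stopifnot(p >= 2)
  prior <- list(nu = p + 2, V = diag(p - 1), ntrees = ntrees,
                kfac = 2.0, pbd = 1.0, pb = 0.5, beta = 2.0, alpha = 0.95,
                nc = 100, priorindep = 0, minobsnode = 10)
  modifyList(prior, list(...))   # let the caller override any entry by name
}

Prior <- make_prior(p = 3, ntrees = 50, alpha = 0.99)
str(Prior)
```

Keeping the dimension of V tied to p in one place avoids mismatches when the number of classes changes.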

Mcmc

List of MCMC settings (starting values, burn-in, ...): e.g., Mcmc = list(sigma0 = diag(p - 1), keep = 1, burn = 100, ndraws = 1000, keep_sigma_draws = FALSE)
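
The settings list can be built and sanity-checked before a long run. The sketch below is an assumption-laden illustration: `make_mcmc` is a hypothetical helper, and the checks (a symmetric, positive definite (p - 1) x (p - 1) starting matrix) reflect what one would usually expect of a covariance starting value, not documented requirements of rmpbart.

```r
# Hypothetical helper (not part of mpbart): build and sanity-check an Mcmc list.
make_mcmc <- function(p, burn = 100, ndraws = 1000, keep = 1, sigma0 = diag(p - 1)) {
  sigma0 <- as.matrix(sigma0)
  stopifnot(nrow(sigma0) == p - 1, ncol(sigma0) == p - 1,
            isSymmetric(sigma0), burn >= 0, ndraws > 0, keep >= 1)
  chol(sigma0)   # raises an error if sigma0 is not positive definite
  list(sigma0 = sigma0, keep = keep, burn = burn, ndraws = ndraws,
       keep_sigma_draws = FALSE)
}

Mcmc <- make_mcmc(p = 3, burn = 100, ndraws = 200)
str(Mcmc)
```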

seedvalue

Random seed value: e.g., seedvalue = 99

Examples

set.seed(64)
library(mpbart)
library(mlbench) # provides mlbench.waveform()
p = 3 # number of classes in the waveform data
train_wave = mlbench.waveform(50)
test_wave = mlbench.waveform(100)
traindata = data.frame(train_wave$x, y = train_wave$classes)
testdata = data.frame(test_wave$x, y = test_wave$classes)

x.train = data.frame(train_wave$x)
x.test = data.frame(test_wave$x)

y.train = train_wave$classes

sigma0 = diag(p-1)
burn = 100
ndraws = 200 # a higher number (>= 1000) is more appropriate

Mcmc1=list(sigma0=sigma0, burn = burn, ndraws = ndraws)
Prior1 = list(nu=p+2,
              V=(p+2)*diag(p-1),
              ntrees = 5, #typically 200 trees is good 
              kfac = 2.0, 
              pbd = 1.0, 
              pb = 0.5, 
              alpha = 0.99,  
              beta =  2.0, 
              nc = 200, 
              priorindep = FALSE)



out = rmpbart(x.train = x.train, y.train = y.train, x.test = x.test, 
            Prior = Prior1, Mcmc=Mcmc1, seedvalue = 99)

# confusion matrix (train)
table(y.train, out$predicted_class_train)
# proportion of training observations classified incorrectly (FALSE) and correctly (TRUE)
prop.table(table(y.train == out$predicted_class_train))

# confusion matrix (test)
table(test_wave$classes, out$predicted_class_test)

# test misclassification rate
test_err <- mean(test_wave$classes != out$predicted_class_test)

cat("test error:", test_err, "\n")
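
The evaluation step at the end of the example can also be tried without fitting a model. The following sketch uses made-up label vectors (not output from rmpbart) to show the confusion matrix and error rate computation in base R.

```r
# Made-up labels for illustration only; these are not output from rmpbart.
truth <- factor(c(1, 2, 3, 1, 2, 3, 1, 2), levels = 1:3)
pred  <- c(1, 2, 3, 2, 2, 3, 1, 3)

# Fix the levels of the predictions so the table stays square
# even if some class is never predicted.
conf_mat <- table(truth = truth, predicted = factor(pred, levels = levels(truth)))

# Misclassification rate: 1 minus the share of observations on the diagonal.
err_rate <- 1 - sum(diag(conf_mat)) / sum(conf_mat)

print(conf_mat)
cat("error rate:", err_rate, "\n")
```

Fixing the factor levels matters with few trees or few draws, where a class may never appear among the predictions and an unadjusted table would no longer be square.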
