Nothing
#' Variable Importance plot
#'
#' Creates a input variable importance plot
#'@param model list containing a model of class "hs_rulefit".
#'@param top If a integer number is given only shows the top most important variables in the plot
#'@param var_names optional vector with the variable names to be shown in plot.
#'@import ggplot2
#'@examples
#' #Fit HorseRuleFit
#' x = matrix(rnorm(5000), ncol=10)
#' y = apply(x,1,function(x)sum(x[1:5])+rnorm(1))
#' hrres = HorseRuleFit(X = x, y=y,
#' thin=1, niter=100, burnin=10,
#' L=5, S=6, ensemble = "both", mix=0.3, ntree=100,
#' intercept=FALSE, linterms=1:10,
#' alpha=1, beta=2, linp = 1, restricted = 0)
#' Variable_importance(hrres)
#'@export
Variable_importance = function(model, top=NULL, var_names=NULL){
Importance = NULL
Input = NULL
if(class(model)!="HorseRulemodel"){
stop("Model must be of class HorseRulemodel")
}
if(is.null(var_names)){
var_names = colnames(model$X)
}
beta = (model$postdraws)$beta
rules = model$rules
inp = dim(model$X)[2]
lin = (model$modelstuff)$linterms
start = ifelse(model$modelstuff$intercept ==T, 1, 0) + length(lin)
p = dim(beta)[2]
if(is.null(var_names)){
var_names = sprintf("variable %i", 1:p)
}
samp = dim(beta)[1]
out = matrix(0, nrow=samp, ncol=inp)
for(i in ((start+1):p)) {
splitted = unlist(strsplit(rules[i-start], split = " & "))
len = length(splitted)
vars = unique(unlist(strsplit(splitted, "in|<=|>"))[seq(from=1, to=len*2, by=2)])
m = length(vars)
for(j in 1:m){
str = vars[j]
ind = as.numeric(regmatches(str, gregexpr("[0-9]+", str)))
out[, ind] = out[, ind] + abs(beta[,i])/m
}
}
if(length(lin)>0){
for(j in 1:length(lin)){
out[,lin[j]] = out[,lin[j]] + abs(beta[,j])
}
}
if(!is.null(top)){
normout = t(apply(out, 1, function(x)((x-min(x))/(max(x)-min(x)))))
meanimp = apply(normout,2,mean)
ind = order(meanimp, decreasing = T)[1:top]
topout = normout[,ind]
name = c()
val = c()
for(j in 1:top){
name = c(name, rep(var_names[ind][j], times = samp))
val = c(val, topout[,j])
}
} else {
normout = t(apply(out, 1, function(x)((x-min(x))/(max(x)-min(x)))))
meanimp = apply(normout,2,mean)
ind = order(meanimp, decreasing = T)
topout = normout[,ind]
name = c()
val = c()
for(j in 1:inp){
name = c(name, rep(var_names[ind][j], times = samp))
val = c(val, topout[,j])
}
}
frame = data.frame(factor(name, levels=unique(name)), val)
colnames(frame) =c("Input", "Importance")
p = ggplot(frame, aes(Input, Importance))
p + stat_boxplot(geom = "errorbar",colour = I("#3366FF"))+
geom_boxplot(colour = I("#3366FF")) +theme_bw() +theme(text = element_text(size=20, angle=90))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.