Nothing
#' Plot for machine learning models
#'
#' Plots the mean absolute SHAP value of each feature for a \code{ReSurvFit}
#' object whose hazard model is a machine learning model (\code{"XGB"} or
#' \code{"NN"}). SHAP values are computed with \code{SHAPforxgboost} for
#' \code{"XGB"} fits and with the Python \code{shap} package's
#' \code{KernelExplainer} (through \code{reticulate}) for \code{"NN"} fits.
#'
#' @param x \code{ReSurvFit} object.
#' @param nsamples \code{integer}, number of observations to sample when
#'   computing SHAP values for a neural network (\code{"NN"}) fit. When
#'   \code{NULL} (the default) all observations are explained. Ignored for
#'   \code{"XGB"} fits.
#' @param ... Currently unused; accepted for compatibility with the
#'   \code{plot} generic.
#'
#' @return \code{ggplot2} bar chart of the mean absolute SHAP value of each
#'   feature (bars ordered by importance) for an \code{"XGB"} model or a
#'   \code{"NN"} model. An error is raised for any other hazard model.
#'
#' @import SHAPforxgboost
#' @import ggplot2
#' @importFrom tibble rownames_to_column
#'
#' @export
#' @method plot ReSurvFit
plot.ReSurvFit <- function(x,
                           nsamples = NULL,
                           ...) {
  hazard_model <- x$hazard_model
  output.fit <- x$model.out

  # Fail early with a clear message: for any other model (or a missing one)
  # neither branch below would run and the plot data would be undefined.
  if (!isTRUE(hazard_model %in% c("XGB", "NN"))) {
    stop("plot.ReSurvFit() supports only the \"XGB\" and \"NN\" hazard ",
         "models, not ", paste(deparse(hazard_model), collapse = " "), ".",
         call. = FALSE)
  }

  if (hazard_model == "XGB") {
    shap_values <- shap.values(xgb_model = output.fit$model.out,
                               X_train = as.matrix(output.fit$data))
    mean_abs_shap <- colMeans(abs(shap_values$shap_score))
    plot.color <- "royalblue"
  } else {
    shap <- reticulate::import("shap")
    x_fc <- reticulate::np_array(as.matrix(output.fit$data), dtype = "float32")
    # The full data set is the explainer's background distribution;
    # `nsamples` only subsamples the observations whose SHAP values are
    # computed.
    explainer <- shap$KernelExplainer(output.fit$model.out$predict, x_fc)
    if (!is.null(nsamples)) {
      x_fc <- shap$sample(x_fc, as.integer(nsamples))
      x_fc <- reticulate::np_array(as.matrix(x_fc), dtype = "float32")
    }
    shap_nn <- explainer$shap_values(x_fc)
    # Keep the SHAP values of the first model output. The original code
    # expected a list with one matrix per output.
    # NOTE(review): newer shap releases may return an
    # (observations x features x outputs) array instead; handled here too.
    if (is.list(shap_nn)) {
      shap_nn <- shap_nn[[1]]
    } else if (length(dim(shap_nn)) == 3L) {
      d <- dim(shap_nn)
      shap_nn <- matrix(shap_nn[, , 1L], nrow = d[1L], ncol = d[2L])
    }
    colnames(shap_nn) <- colnames(output.fit$data)
    mean_abs_shap <- colMeans(abs(shap_nn))
    plot.color <- "#a71429"
  }

  features <- names(mean_abs_shap)
  if (is.null(features)) {
    # Same fallback labels as row names of an unnamed vector.
    features <- as.character(seq_along(mean_abs_shap))
  }
  # Order bars by importance so the most important feature ends up on top
  # after coord_flip().
  df.2.plot <- data.frame(
    feature = factor(features,
                     levels = unique(features[order(mean_abs_shap)])),
    value = unname(mean_abs_shap)
  )

  ggplot(df.2.plot, aes(x = feature, y = value)) +
    geom_bar(stat = "identity", fill = plot.color) +
    coord_flip() +
    labs(title = " ",
         x = "",
         y = "mean(|SHAP|)") +
    theme_bw()
}
#
# library(SHAPforxgboost)
# dataX=resurv.fit.xgb$model.out$data
# #
#
# library(xgboost)
# featImp_RBNS <- xgb.importance(model=resurv.fit.xgb$model.out$model.out)
# xgb.plot.importance(featImp_RBNS, main="Feature Importance - RBNS")
#
# shap_values <- shap.values(xgb_model = resurv.fit.xgb$model.out$model.out, X_train = dataX)
# shap_values <- shap.values(xgb_model = resurv.fit.xgb$model.out$model.out, X_train = as.matrix(dataX))
# #
# shap_long <- shap.prep(shap_contrib = shap_values$shap_score, X_train = dataX)
# shap_long <- shap.prep(xgb_model = resurv.fit.xgb$model.out$model.out, X_train = as.matrix(resurv.fit.xgb$model.out$data))
#
# #
# # Return the SHAP values and ranked features by mean|SHAP|
# shap_values <- shap.values(xgb_model = xgb_RBNS_Fit, X_train = as.matrix(df.RBNS_train))
#
# # Prepare the long-format data:
# shap_long <- shap.prep(shap_contrib = shap_values$shap_score, X_train = as.matrix(df.RBNS_train))
#
# # **SHAP summary plot**
# shap.plot.summary(shap_long)
#
#
# shap <- reticulate::import("shap")
#
#
# Kernel
# compute SHAP values
# explainer = shap$DeepExplainer(output.fit$model.out$predict,
# x_fc2)
# shap_values = explainer.shap_values(x_fc)
#
# x_fc2 <- shap$sample(x_fc,as.integer(5))
#
# x_fc2 <- reticulate::np_array(as.matrix(x_fc2), dtype = "float32")
#
#
# shap_values = explainer$shap_values(x_fc2)
#
# shap$summary_plot(shap_values[[1]],x_fc2)
#
# resurv.fit.deepsurv$model.out$model.out
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.