#' Interactive Plots for Interpretable Machine Learning
#' @description
#' The function \code{imlplots()} creates an interactive shiny based dashboard
#' for visualizing the effects of statistical models.
#' The utilization of mlr (Machine Learning in R) is necessary.
#' For more infos go to \url{https://github.com/mlr-org}
#'
#' There are three types of plots: Partial Dependence Plots (PDP), Individual Conditional
#' Expectation (ICE) plots and Accumulated Local Effects (ALE) plots.
#' @param data Input data frame.
#' Has to contain exactly the same variables as the training data.
#' @param task The mlr task the models were being trained on,
#' e.g. \code{iris.task = makeClassifTask(data = iris, target = "Species")}.
#' Classification and regression tasks are supported.
#' @param models A list of mlr trained models, e.g. \code{list(rf.mod, glm.mod)}. \cr
#' You can provide differently tuned models of the same learner by assigning
#' a unique ID to the learner, e.g.
#' \code{makeLearner("regr.randomForest", id = "ownId")}
#'
#' @param model.check A string. A model check is performed upon initialization,
#' whether the provided models can be used to properly predict. \cr
#' 'all.features' iteratively checks all model/feature combinations. 'sample.feature' randomly
#' selects a single feature from the feature space and checks all models with
#' it.
#'
#' @examples
#' tsk = makeRegrTask(data = boston, target = "medv")
#' mod.rf = train("regr.randomForest", task = tsk)
#' mod.glm = train("regr.glm", task = tsk)
#' imlplots(boston, tsk, list(mod.rf, mod.glm))
#'
#' @note
#' The plots display combinations of different inputs and outputs/ predictions.
#' Therefore they are highly sensitive to the trained and provided models.
#'
#' The variable of interest provides variations of different inputs, while all other
#' variables are held constant. You can look at how the predictions change,
#' if you had provided different test data, by either filtering/ subsetting
#' the data or manually setting a variable to a fixed value for all observations.
#'
#' The function performs a basic check upon initialization,
#' whether the provided models can be used to properly predict.
#' If the check fails, it is recommended to manually test the model with the
#' \code{marginalPrediction()} function of the mmpf package.
#'
#' @author Julia Fried, Tobias Riebe, Christian Scholbeck; in cooperation with
#' the working group for computational statistics at
#' Ludwigs-Maximilians-University Munich.
#'
#' @references
#'
#' Apley (2016). "Visualizing the Effects of Predictor Variables in Black Box Supervised
#' Learning Models"
#'
#' Bischl et. al (2016). "mlr: Machine Learning in R." Journal of Machine Learning
#' Research, 17(170), pp.
#'
#' Friedman, J.H. (2001). "Greedy Function Approximation: A Gradient Boosting
#' Machine." Annals of Statistics 29: 1189 - 1232.
#'
#' Goldstein et al. (2013). "Peeking Inside the Black Box: Visualizing Statistical Learning with Plots of
#' Individual Conditional Expectation"
#'
#' Jones (2017). "mmpf: Monte-Carlo Methods for Prediction Functions "The R Journal Vol. XX/YY, AAAA 20ZZ
#' @export
imlplots = function(data, task, models, model.check = "all.features") {
if (!(is.vector(models))) {models = list(models)}
assertDataFrame(data)
assertClass(task, classes = "Task")
lapply(models, FUN = function(elem) assertClass(elem, class = "WrappedModel"))
learner.models = lapply(models, function(x) x[["learner.model"]])
learner.models.names = lapply(models, function(x) x[["learner"]][["id"]])
target = getTaskDesc(task)$target
task.type = getTaskDesc(task)$type
if (task.type == "regr") {
target.levels = levels(data[[target]])
} else {
target.levels = NULL
}
features = names(data)[!names(data) %in% target]
features.numeric = features[sapply(data[!names(data) %in% target], is.numeric)]
features.factor = features[sapply(data[!names(data) %in% target], is.factor)]
do.call(
modelCheck,
list(data = data, models = models, features = features, model.check))
# basic check whether provided models throw error when using
# marginalPrediction()
app.ui = dashboardPage(
shinyjs::useShinyjs(),
header = dashboardHeader(
title = "LMU Data Science Innovationslabor",
titleWidth = 350,
tags$li(
class = "dropdown",
actionButton(
"reload", "Reload application",
width = "100%",
icon("refresh"),
style = "font-size: 16px; color: #fff; background-color: #337ab7;
border-color: #2e6da4; padding: 13px"))),
sidebar = dashboardSidebar(disable = TRUE),
body = dashboardBody(
h2("Interactive plots for interpretable machine learning"),
fluidRow(
tabBox(
width = 12,
height = NULL,
tabPanel(title = "Data", {
# data table with filters
fluidRow(
column(
width = 12,
box(
width = NULL,
status = "primary",
title = "Data used for plotting ICE curves",
selectInput("data_selection_mode",
"Select observations to sample from",
choices = c(
"Plot all sampled observations",
paste("Plot individual observations",
"(only influences computations of ICE and PDP curves)")
)
),
div(style = "overflow-x: scroll",
DT::dataTableOutput("table")
)
)
)
)
}),
tabPanel(
# plot settings and preview
title = "Settings",
fluidRow(
column(
width = 3,
# left column with plot settings
box(
title = "Plot settings",
width = NULL,
status = "primary",
selectInput("gfx.package", "Select graphics package",
choices = c("ggplot2",
"plotly (resource intensive)")
),
selectInput("models", "Choose predictive model",
choices = learner.models.names,
selected = 1, multiple = FALSE),
selectInput(
"plot_type",
"Choose plot type",
choices = c(
"Individual Conditional Expectation",
"Partial Dependence",
"Accumulated Local Effects"),
selected = 1, multiple = FALSE),
conditionalPanel(
condition =
"input.plot_type ==
'Individual Conditional Expectation'",
selectInput(
"iceplot_mode", "Ice plot mode",
choices = c("Regular",
"Centered"),
selected = "Regular")
),
conditionalPanel(
condition =
"input.plot_type == 'Individual Conditional Expectation'
& input.iceplot_mode == 'Centered'",
uiOutput("iceplot_center")
),
conditionalPanel(
condition =
"input.plot_type == 'Accumulated Local Effects'",
selectInput(
"aleplot_mode", "ALE Plot Mode",
choices = c("Main Effects",
"Second Order Effects"),
selected = "ALE Main Effects")
),
selectInput(
"var", "Variable of interest",
choices = features,
selected = NULL, multiple = FALSE),
conditionalPanel(
condition =
"input.plot_type == 'Accumulated Local Effects' &
input.aleplot_mode == 'Second Order Effects'",
uiOutput("ale_interaction")
),
uiOutput("checkbox"), # checkbox for adjusting the data
uiOutput("knots"),
uiOutput("lines")
),
box(
title = "Adjust feature values",
width = NULL,
uiOutput("sliders"),
uiOutput("selectors")
)
),
column(
width = 9,
# right column with plot preview
fluidRow(
style = "position:fixed; width:70%;",
uiOutput("plot")
)
)
)
),
tabPanel(
title = "Plots",
tabsetPanel(
tabPanel(
title = "Zoomed Plot",
fluidRow(
column(
width = 12,
box(
width = NULL,
height = "600px",
status = "primary",
uiOutput("zoomed_plot")
)
)
)
),
tabPanel(
title = "Scatterplot",
fluidRow(
column(
width = 12,
box(
width = NULL,
status = "primary",
h2("Filtered data"),
uiOutput("scatter_filtered"),
HTML('<hr style="color: purple;">'),
h2("Unfiltered data"),
uiOutput("scatter_unfiltered")
)
)
)
)
)
),
tabPanel(
title = "Learner Summary",
fluidRow(
column(
width = 12,
box(
width = NULL,
status = "primary",
verbatimTextOutput("learner_summary")
)
)
)
)
)
)
)
)
app.server = function(input, output, session) {
# --------------------------------------------------------------------------
# reactive values
df = reactiveValues(
# reactive values for current (adjusted) data frame, available features
# and prediction values
values.adj = data, features = NULL, pred = NULL,
table.rows.selected = NULL)
selected = reactiveValues(
# reactive values only for selected values
knots = 30, lines = 30)
# --------------------------------------------------------------------------
# ui outputs
output$iceplot_center = renderUI({
knot.values = df$pred[, 1, with = FALSE]
# sampled values appearing on the horizontal axis in ICE and PDP
selectInput(
"iceplot_center_x",
"Select horizontal axis value to center ICE curves around
(depends on knots)",
choices = knot.values, selected = selected$iceplot.center.x)
})
output$ale_interaction = renderUI({
variable.options = features[!features %in% c(selected$var)]
selectInput("ale_interaction_var",
"ALE interaction variable",
choices = variable.options)
})
output$checkbox = renderUI({
# checkbox for adjustable features
checkboxGroupInput("checks", "Select adjustable features",
choices = df$features.unused)
})
sliderList = reactive({
# list of strings with input$[feature] for numeric features
lapply(selected$features.numeric, FUN = function(feature) {
text = paste("input$", feature, sep = "")})
})
selectorList = reactive({
# list of strings with input$[feature] for factor features
lapply(selected$features.factor, FUN = function(feature) {
text = paste("input$", feature, sep = "")})
})
featureSliders = reactive({
# make reactive list of slider input expressions (for numeric variables)
# by parsing strings from slider.list
# use this function to capture inputs from feature sliders
lapply(sliderList(), FUN = function(x) eval(parse(text = x)))
})
featureSelectors = reactive({
# make reactive list of selector input expressions (for factor variables)
# by parsing strings from selector.list
# use this function to capture inputs from feature selectors
lapply(selectorList(), FUN = function(x) eval(parse(text = x)))
})
observeEvent({
# rendering feature sliders for numeric features
selected$features.numeric},
{
output$sliders = renderUI({
if (is.null(selected$features.numeric)) {
} else {
sliders = lapply(1:length(selected$features.numeric), function(i) {
input.name = selected$features.numeric[i]
min = min(data[[input.name]])
max = max(data[[input.name]])
decimal.places = max(getDecimalPlaces(min), getDecimalPlaces(max))
step.length = createSliderStepLength(decimal.places)
mean = mean(data[[input.name]])
sliderInput(
inputId = input.name,
label = input.name,
min = as.numeric(min),
max = as.numeric(max),
value = mean,
step = step.length,
sep = "")
})
do.call(tagList, sliders)
# create HTML tags for all sliders
}
})
}
)
observeEvent({
# render feature selectors
selected$features.factor},
{
output$selectors = renderUI({
if (is.null(selected$features.factor)) {
} else {
selectors = lapply(1:length(selected$features.factor), function(i) {
input.name = selected$features.factor[i]
factor.levels = levels(data[[input.name]])
selectInput(input.name, input.name, choices = factor.levels)})
do.call(tagList, selectors)
# create HTML tags for all selectors
}
})
}
)
output$knots = renderUI({
if (is.null(selected$knots)) {
selected$knots = 30
}
# setting to 30 upon initialization; setting init value in selected
# does not work
if (selected$plot.type == "ale") {
sliderInput(
"knots",
"Number of intervals into which the predictor range is divided",
min = 1,
max = 100,
value = selected$knots,
step = 1)
} else {
sliderInput(
"knots",
"Number of knots for each line",
min = 1,
max = nrow(df$values.filtered),
value = selected$knots,
step = 1)
}
})
output$lines = renderUI({
if (is.null(selected$lines)) {
selected$lines = 30
}
# setting to 30 upon initialization; setting init value in selected
# does not work
if (selected$data.selection.mode == "individual" ||
selected$plot.type == "ale") {
selected$lines = 10
# makes lines have full width
return(invisible(NULL))
} else {
sliderInput(
"lines",
"Number of individual observations (lines) to sample from data",
min = 1,
max = nrow(df$values.filtered),
value = selected$lines,
step = 1)
}
})
output$learner_summary = renderPrint({
capture.output(selected$model)
})
output$table = DT::renderDataTable({
shiny::req(!is.null(selected$datatable.select))
DT::datatable(
df$values.adj,
filter = list(position = "top", clear = TRUE, plain = TRUE),
selection = selected$datatable.select
)},
server = TRUE
)
# --------------------------------------------------------------------------
# plot outputs and event reactive functions
output$scatter_unfiltered_basic = renderPlot({
# render ggplot2 object of unfiltered scatter plot
scatterPlotUnfiltered()},
width = 800,
height = 400
)
output$scatter_unfiltered_plotly = renderPlotly({
# render plotly object of unfiltered scatter plot
p = plotly_build(scatterPlotUnfiltered())
p$elementId = NULL
p$x$layout$width = 800
p$x$layout$height = 400
p$width = NULL
p$height = NULL
return(p)
})
# both ggplot2 and plotly versions are ready to be rendered upon
# changing the UI selections; lazy loading keeps one version from not
# being rendered
output$scatter_filtered_basic = renderPlot({
# render ggplot2 object of filtered scatter plot
scatterPlotFiltered()},
width = 800,
height = 400
)
output$scatter_filtered_plotly = renderPlotly({
# render plotly object of filtered scatter plot
p = plotly_build(scatterPlotFiltered())
p$elementId = NULL
p$x$layout$width = 800
p$x$layout$height = 400
p$width = NULL
p$height = NULL
return(p)
})
# both ggplot2 and plotly versions are ready to be rendered upon
# changing the UI selections; lazy loading keeps one version from not
# being rendered
output$scatter_unfiltered = renderUI({
# decide which rendered unfiltered scatter plot to display in UI based on
# selected$gfx.package
if (selected$gfx.package == "plotly") {
plotlyOutput("scatter_unfiltered_plotly")
} else if (selected$gfx.package == "ggplot2") {
plotOutput("scatter_unfiltered_basic")
}
})
output$scatter_filtered = renderUI({
# decide which rendered filtered scatter plot to display in UI based on
# selected$gfx.package
if (selected$gfx.package == "plotly") {
plotlyOutput("scatter_filtered_plotly")
} else if (selected$gfx.package == "ggplot2") {
plotOutput("scatter_filtered_basic")
}
})
scatterPlotUnfiltered = eventReactive({
# plot function for unfiltered scatter plot
selected$table.rows
selected$plot.type
selected$aleplot.mode
selected$gfx.package
selected$var
selected$ale.interaction},
ignoreNULL = FALSE,
{
if (selected$plot.type == "ale" &&
selected$aleplot.mode == "Second Order Effects" &&
selected$gfx.package == "plotly") {
scatterPlot3D(
data = data, target = target,
var = c(selected$var, selected$ale.interaction),
highlighted = selected$table.rows)
} else {
scatterPlot(
data = data, target = target,
var = selected$var,
highlighted = selected$table.rows)
}
}
)
scatterPlotFiltered = eventReactive({
# plot function for filtered scatter plot
df$values.filtered
selected$table.rows
selected$plot.type
selected$aleplot.mode
selected$gfx.package
selected$var
selected$ale.interaction},
ignoreNULL = FALSE,
{
if (selected$plot.type == "ale" &&
selected$aleplot.mode == "Second Order Effects" &&
selected$gfx.package == "plotly") {
scatterPlot3D(
data = df$values.filtered, target = target,
var = c(selected$var, selected$ale.interaction),
highlighted = selected$table.rows)
} else {
scatterPlot(
data = df$values.filtered, target = target,
var = selected$var,
highlighted = selected$table.rows)
}
})
output$iml_plotly_plot = renderPlotly({
# rendering plotly version of imlplot output
p = plotly_build(imlPlot())
p$elementId = NULL
p$x$layout$width = 800
p$x$layout$height = 400
p$width = NULL
p$height = NULL
return(p)
})
output$iml_basic_plot = renderPlot({
# rendering ggplot2 version of imlplot output
imlPlot()},
width = 800,
height = 400
)
output$plot = renderUI({
# decide which plot to display in UI based on selected$gfx.package
if (selected$gfx.package == "plotly") {
plotlyOutput("iml_plotly_plot")
} else if (selected$gfx.package == "ggplot2") {
plotOutput("iml_basic_plot")
}
})
output$zoomed_plotly_plot = renderPlotly({
# render plotly version of zoomed imlplot
p = plotly_build(imlPlot())
p$elementId = NULL
p$x$layout$width = 1200
p$x$layout$height = 600
p$width = NULL
p$height = NULL
return(p)
})
output$zoomed_basic_plot = renderPlot({
# render ggplot2 version of zoomed imlplot
imlPlot()},
width = 1200,
height = 600
)
output$zoomed_plot = renderUI({
# decide which zoomed plot version to display in UI based on
# selected$gfx.package
if (selected$gfx.package == "plotly") {
plotlyOutput("zoomed_plotly_plot")
} else if (selected$gfx.package == "ggplot2") {
plotOutput("zoomed_basic_plot")
}
})
imlPlot = eventReactive({
# plots the predicted values by calling predefined plot functions with
# current reactive values
df$pred
df$values.filtered
selected$table.rows
selected$data.selection.mode
selected$plot.type
selected$iceplot.center.x
selected$gfx.package
selected$var
selected$ale.interaction
selected$knots
selected$lines},
ignoreInit = FALSE,
ignoreNULL = FALSE,
{
shiny::req(!is.null(df$pred))
shiny::req((selected$var %in% names(df$pred)) ||
"error" %in% df$pred)
shiny::req(selected$lines)
if (!"error" %in% df$pred) {
shiny::req(
!(TRUE %in% apply(df$pred, MARGIN = 2,
function(column) {NA %in% column})))
}
withProgress(
message = "Rendering plot..",
detail = "Please wait.",
min = 0, max = 100, value = 100,
{
if (nrow(df$values.filtered) == 0) {
plot = placeholderPlot()
return(plot)
} else if (
((selected$data.selection.mode == "individual") &&
(is.null(selected$table.rows)) &&
(!selected$plot.type == "ale"))) {
plot = placeholderPlot()
return(plot)
} else {
if (task.type == "regr") {
if (selected$plot.type == "ice") {
plot = regrIcePlot(
pred = df$pred,
var = selected$var,
target = target,
knots = selected$knots,
lines = selected$lines,
centered = selected$centered,
center.x = selected$iceplot.center.x)
return(plot)
} else if (selected$plot.type == "pdp") {
plot = regrPartialDependencePlot(
pred = df$pred,
var = selected$var,
target = target,
knots = selected$knots)
return(plot)
} else if (selected$plot.type == "ale") {
plot = regrAlePlot(
data = df$pred,
target = target,
var1 = selected$var,
var2 = selected$ale.interaction,
knots = selected$knots,
gfx.package = selected$gfx.package)
return(plot)
}
} else if (task.type == "classif") {
if (selected$plot.type == "ice") {
plot = classifIcePlot(
pred = df$pred,
var = selected$var,
knots = selected$knots,
lines = selected$lines,
centered = selected$centered,
center.x = selected$iceplot.center.x)
return(plot)
} else if (selected$plot.type == "pdp") {
plot = classifPartialDependencePlot(
pred = df$pred,
var = selected$var,
target = target,
knots = selected$knots)
return(plot)
} else if (selected$plot.type == "ale") {
plot = classifAlePlot(
data = df$pred,
target = target,
target.levels = target.levels,
var1 = selected$var,
var2 = selected$ale.interaction)
return(plot)
}
}
}
}
) # ending: withProgress(...)
})
# --------------------------------------------------------------------------
# Observers
observeEvent({
# observer for calculating predictions
df$values.filtered
selected$table.rows
selected$data.selection.mode
selected$iceplot.center.x
selected$centered
selected$model
selected$plot.type
selected$var
selected$ale.interaction
selected$knots
selected$lines},
{
shiny::req(nrow(df$values.filtered) > 0)
shiny::req(!(TRUE %in% apply(df$values.filtered, MARGIN = 2,
function(column) {NA %in% column})))
if (selected$plot.type == "ale") {
# use ALEPlot::ALEPlot(..) to predict for ale
shiny::withProgress(
message = "Calculating Predictions..",
detail = "Please wait.",
min = 0, max = 100, value = 100,
{
df$pred = makePredictionsAle(
data = df$values.filtered,
target = target,
model = selected$model,
var1 = selected$var,
var2 = selected$ale.interaction,
knots = selected$knots,
task.type = task.type)
}
)
} else {
# (selected$plot.type == "ice" || selected$plot.type == "pdp")
shiny::req(nrow(df$values.filtered) >= selected$knots)
shiny::req(nrow(df$values.filtered) >= selected$lines)
if (selected$data.selection.mode == "individual") {
# mmpf::marginalPrediction(...) marginalizes only over selected
# observations
shiny::req(!is.null(selected$table.rows))
shiny::req(selected$table.rows %in% as.numeric(
row.names(df$values.filtered)))
} else {}
shiny::withProgress(
message = "Calculating predictions..",
detail = "Please wait.",
min = 0, max = 100, value = 100,
{
prediction = makePredictionsIce(
data = df$values.filtered,
var = selected$var,
model = selected$model,
knots = selected$knots,
lines = selected$lines,
task.type = task.type,
selected.rows = selected$table.rows,
data.selection.mode = selected$data.selection.mode)
}
)
if (selected$centered == TRUE) {
# shiny::req(
# selected$iceplot.center.x %in% prediction[, 1, with = FALSE])
shiny::req(!is.null(selected$iceplot.center.x))
shiny::req(!is.na(selected$iceplot.center.x))
shiny::req(selected$var %in% names(prediction))
shiny::withProgress(
message = "Centering predictions..",
detail = "Please wait.",
min = 0, max = 100, value = 100,
{
df$pred = centerPredictions(
predictions = prediction,
center.x = selected$iceplot.center.x,
var = selected$var)
}
)
} else if (selected$centered == FALSE) {
df$pred = prediction
}
}
}
)
observeEvent({
# reload button action on top right corner
input$reload},
{
session$reload()
}
)
#
# observeEvent({
# # line sampling not necessary when individual observations are selected
# # or in ale plot mode
# selected$data.selection.mode
# selected$plot.type},
# {
# req(!is.null(selected$data.selection.mode))
# req(!is.null(selected$plot.type))
# if (selected$data.selection.mode == "individual") {
# shinyjs::disable("lines")
# } else {
# shinyjs::enable("lines")
# }
# }
# )
observeEvent({
input$iceplot_mode
selected$plot.type},
ignoreNULL = FALSE,
{
if (input$iceplot_mode == "Centered" && selected$plot.type == "ice") {
selected$centered = TRUE
} else {
selected$centered = FALSE
}
}
)
observeEvent({
input$gfx.package},
{
if (input$gfx.package == "ggplot2") {
selected$gfx.package = "ggplot2"
} else if (input$gfx.package == "plotly (resource intensive)") {
selected$gfx.package = "plotly"
}
}
)
observeEvent({
selected$plot.type},
{
if (task.type == "classif" && selected$plot.type == "ale") {
updateSelectInput(
session = session,
inputId = "aleplot_mode",
label = "For classification ALE plots only main effects are
supported",
choices = "Main Effects")
}
selected$ale.interaction = NULL
})
observeEvent({
# select model for marginal prediction function based on selected string
# in UI
input$models},
{
for (i in 1:length(learner.models.names)) {
model = learner.models[[i]]
if (input$models == learner.models.names[[i]]) {
selected$model = model
} else {}
}
}
)
observeEvent({
input$ale_interaction_var
selected$plot.type
input$aleplot_mode
selected$gfx.package},
ignoreNULL = FALSE,
{
if (selected$plot.type == "ale" &&
input$aleplot_mode == "Second Order Effects" &&
selected$gfx.package == "ggplot2") {
selected$ale.interaction = input$ale_interaction_var
} else if (selected$plot.type == "ale" &&
input$aleplot_mode == "Second Order Effects" &&
selected$gfx.package == "plotly") {
selected$ale.interaction = input$ale_interaction_var
} else {
selected$ale.interaction = NULL
}
}
)
observeEvent({
# reset the reactive data values when variable of interest changes
input$var},
{
df$values.adj = data
}
)
observeEvent({
# differentiates checked input features into numeric and factor features;
# important, because factor features are adjusted with selectors and need
# to be provided the right values when adjusting
input$checks},
ignoreNULL = FALSE,
{
numerics = c()
factors = c()
if (is.null(input$checks)) {
selected$features.numeric = NULL
selected$features.factor = NULL
} else {
for (elem in input$checks) {
if (elem %in% features.numeric) {
numerics = c(numerics, elem)
} else if (elem %in% features.factor) {
factors = c(factors, elem)
}
}
selected$features.numeric = numerics
selected$features.factor = factors
}
}
)
observeEvent({
# output$knots is a dynamic server side rendered UI and could be caught
# in endless loop when setting init value to old reactive value;
# hence the input is disabled upon init so that UI can finish rendering
input$knots},
{
shinyjs::disable("knots")
selected$knots = input$knots
}
)
observeEvent({
# output$lines is a dynamic server side rendered UI and could be caught
# in endless loop when setting init value to old reactive value;
# hence the input is disabled upon init so that UI can finish rendering
input$lines},
{
shinyjs::disable("lines")
selected$lines = input$lines
}
)
observeEvent({
# set reactive values to UI values
input$var
input$plot_type
input$aleplot_mode
input$data_selection_mode
input$checks
input$iceplot_center_x},
ignoreNULL = FALSE,
{
selected$var = input$var
selected$aleplot.mode = input$aleplot_mode
selected$features = input$checks
selected$iceplot.center.x = input$iceplot_center_x
if (input$data_selection_mode == "Plot all sampled observations") {
selected$data.selection.mode = "sampling"
} else if (input$data_selection_mode ==
paste("Plot individual observations",
"(only influences computations of ICE and PDP curves)")
) {
selected$data.selection.mode = "individual"
}
if (input$plot_type == "Individual Conditional Expectation") {
selected$plot.type = "ice"
} else if (input$plot_type == "Partial Dependence") {
selected$plot.type = "pdp"
} else if (input$plot_type == "Accumulated Local Effects") {
selected$plot.type = "ale"
}
}
)
observeEvent({
selected$data.selection.mode
input$table_rows_all
input$table_rows_selected},
ignoreNULL = FALSE,
{
if (selected$data.selection.mode == "sampling") {
selected$datatable.select = list(mode = "none")
df$table.rows.filtered = input$table_rows_all
# use all filtered observations from data table
selected$table.rows = NULL
} else if (selected$data.selection.mode == "individual") {
# only the selected observations in data are marginalized over,
# so sampling observations/lines is not necessary
selected$datatable.select = list(mode = "multiple")
df$table.rows.filtered = input$table_rows_all
# use all filtered observations from data table
selected$table.rows = input$table_rows_selected
# row indicies of selected observations in data table
}
}
)
observeEvent({
# available features for adjusting the data cannot contain variable
# of interest and (if ale selected) ale interaction variable
selected$var
selected$ale.interaction},
ignoreNULL = FALSE,
{
df$features.unused = names(data)[!names(data) %in% c(
target, selected$var,
selected$ale.interaction)]
}
)
proxy = DT::dataTableProxy("table")
observeEvent({
df$table.rows.filtered},
{
DT::selectRows(proxy, list())
}
)
observeEvent({
df$table.rows.filtered
df$values.adj},
ignoreNULL = FALSE,
ignoreInit = FALSE,
{
df$values.filtered = df$values.adj[df$table.rows.filtered, ]
}
)
observeEvent({
# capture slider and selector values
# set reactive values to NULL if input contains NA or NULL values
# this sometimes happens when the input value is evaluated before the UI
# has finished rendering
featureSliders()
featureSelectors()},
{
if ((TRUE %in% lapply(featureSliders(), is.null)) ||
(TRUE %in% lapply(featureSliders(), is.na))) {
selected$values.numeric = NULL
} else {
selected$values.numeric = unlist(
featureSliders(),
function(x) return(as.numeric(x)))
}
# now do the same with selectors
if ((TRUE %in% lapply(featureSelectors(), is.null)) ||
(TRUE %in% lapply(featureSelectors(), is.na))) {
selected$values.factor = NULL
} else {
selected$values.factor = unlist(
featureSelectors(),
function(x) return(as.numeric(x)))
}
}
)
observeEvent({
# update numeric feature data when user adjusts sliders
selected$values.numeric
selected$features.numeric},
ignoreNULL = FALSE,
{
if (!(is.null(selected$values.numeric))) {
for (i in 1:length(selected$features.numeric)) {
numeric.feature = selected$features.numeric[i]
numeric.value = selected$values.numeric[i]
df$values.adj[numeric.feature] = numeric.value
}
}
unselected.features = df$features.unused[
!df$features.unused %in% c(selected$features)]
df$values.adj[unselected.features] = data[unselected.features]
}
)
observeEvent({
# update factor feature data when user adjusts selectors
selected$values.factor
selected$features.factor},
ignoreNULL = FALSE,
{
if (!(is.null(selected$features.factor))) {
for (j in 1:length(selected$features.factor)) {
factor.feature = selected$features.factor[j]
factor.value = selected$values.factor[j]
factor.levels = levels(data[[factor.feature]])
df$values.adj[factor.feature] = factor(
x = factor.value,
levels = factor.levels)
}
}
unselected.features =
df$features.unused[!df$features.unused %in% c(selected$features)]
df$values.adj[unselected.features] = data[unselected.features]
}
)
} # end server
shinyApp(ui = app.ui, server = app.server)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.