# Nothing
## ----setup, include = FALSE---------------------------------------------------
# Global knitr chunk options: collapse source and output into one block,
# prefix printed output with "#>", and render figures 7 wide by 6 high,
# centered.
knitr::opts_chunk$set(
  fig.align = "center",
  fig.width = 7,
  fig.height = 6,
  comment = "#>",
  collapse = TRUE
)
## ---- message=FALSE, warning=FALSE--------------------------------------------
library(parsnip)
library(probably)
library(dplyr)
library(rsample)
library(modeldata)
# Load the lending_club example data set.
data("lending_club")

# Make "good" the first level of Class (by default it is the second), then
# keep only the few columns this example uses.
lending_club <- lending_club %>%
  mutate(Class = relevel(Class, ref = "good")) %>%
  select(Class, annual_inc, verification_status, sub_grade)

# Print the trimmed data set.
lending_club
## -----------------------------------------------------------------------------
# Split the data 75% train / 25% test. The seed is fixed so the same rows
# land in each partition on every run.
set.seed(123)
split <- lending_club %>%
  initial_split(prop = 0.75)
lending_test <- testing(split)
lending_train <- training(split)
## -----------------------------------------------------------------------------
# Tally the training rows by outcome class.
lending_train %>%
  count(Class)
## -----------------------------------------------------------------------------
# Model specification: a logistic regression fitted through the "glm" engine.
logi_reg <- logistic_reg()
logi_reg_glm <- set_engine(logi_reg, "glm")

# Printing the specification shows the type of model and the engine in use.
logi_reg_glm

# Fit the specification to the training data.
logi_reg_fit <- logi_reg_glm %>%
  fit(
    Class ~ annual_inc + verification_status + sub_grade,
    data = lending_train
  )

logi_reg_fit
## -----------------------------------------------------------------------------
# Class probabilities for the held-out test rows.
predictions <- predict(logi_reg_fit, new_data = lending_test, type = "prob")

# Peek at the first two rows of predictions.
head(predictions, n = 2)

# Place the probability columns in front of the test data, then print.
lending_test_pred <- bind_cols(predictions, lending_test)
lending_test_pred
## -----------------------------------------------------------------------------
# Convert the "good" class probability in `data` into a hard class prediction
# at `cutoff`, stored in a new `.pred` column. The result keeps the true Class
# plus every column whose name contains ".pred".
hard_pred_at <- function(data, cutoff) {
  data %>%
    mutate(
      .pred = make_two_class_pred(
        estimate = .pred_good,
        levels = levels(Class),
        threshold = cutoff
      )
    ) %>%
    select(Class, contains(".pred"))
}

# Hard predictions at a threshold of 0.5, cross-tabulated against the truth.
hard_pred_0.5 <- hard_pred_at(lending_test_pred, cutoff = .5)
count(hard_pred_0.5, .truth = Class, .pred)
## -----------------------------------------------------------------------------
# The same again with the threshold raised to 0.75.
hard_pred_0.75 <- hard_pred_at(lending_test_pred, cutoff = .75)
count(hard_pred_0.75, .truth = Class, .pred)
## ---- echo=FALSE--------------------------------------------------------------
# Number of test rows that are truly "bad" and were also predicted "bad" at
# the 0.75 threshold.
correct_bad <- hard_pred_0.75 %>%
  filter(Class == "bad", .pred == "bad") %>%
  nrow()
## -----------------------------------------------------------------------------
library(yardstick)
# Sensitivity and specificity of the hard predictions, first at the 0.5
# threshold and then at 0.75.
hard_pred_0.5 %>% sens(Class, .pred)
hard_pred_0.5 %>% spec(Class, .pred)
hard_pred_0.75 %>% sens(Class, .pred)
hard_pred_0.75 %>% spec(Class, .pred)
## -----------------------------------------------------------------------------
# J-index of the hard predictions at each of the two thresholds.
hard_pred_0.5 %>% j_index(Class, .pred)
hard_pred_0.75 %>% j_index(Class, .pred)
## -----------------------------------------------------------------------------
# Compute threshold-dependent performance metrics on the test predictions over
# a grid of thresholds from 0.5 to 1 in steps of 0.0025.
threshold_data <- lending_test_pred %>%
  threshold_perf(Class, .pred_good, thresholds = seq(0.5, 1, by = 0.0025))

# Show the rows for a few thresholds of interest. The grid values come out of
# floating-point arithmetic in seq(), so they are matched with a tolerance
# (dplyr::near()) instead of exact equality against the decimal literals.
thresholds_to_show <- c(0.5, 0.6, 0.7)
threshold_data %>%
  filter(
    vapply(
      .threshold,
      function(value) any(near(value, thresholds_to_show)),
      logical(1)
    )
  )
## -----------------------------------------------------------------------------
library(ggplot2)
# Drop the "distance" metric and tag every remaining row with a group used for
# line transparency: "sens" and "spec" get group "1", anything else gets "2".
threshold_data <- threshold_data %>%
  filter(.metric != "distance") %>%
  mutate(group = if_else(.metric %in% c("sens", "spec"), "1", "2"))

# Threshold value(s) at which the J-index is largest. The maximum is taken
# over the J-index rows only, so the subset is built before comparing.
j_index_data <- threshold_data %>%
  filter(.metric == "j_index")
max_j_index_threshold <- j_index_data %>%
  filter(.estimate == max(.estimate)) %>%
  pull(.threshold)

# One line per metric across the threshold grid, with a vertical marker at the
# threshold of maximum J-index.
threshold_data %>%
  ggplot(
    aes(x = .threshold, y = .estimate, color = .metric, alpha = group)
  ) +
  geom_line() +
  geom_vline(xintercept = max_j_index_threshold, color = "grey30", alpha = .6) +
  scale_color_viridis_d(end = 0.9) +
  scale_alpha_manual(values = c(.4, 1), guide = "none") +
  labs(
    title = "Balancing performance by varying the threshold",
    subtitle = "Sensitivity or specificity alone might not be enough!\nVertical line = Max J-Index",
    x = "'Good' Threshold\n(above this value is considered 'good')",
    y = "Metric Estimate"
  ) +
  theme_minimal()
## -----------------------------------------------------------------------------
# Metrics at the threshold with the highest J-index. `%in%` is used rather
# than `==` because max_j_index_threshold holds every threshold tied for the
# maximum: with more than one value, `==` would recycle the vector down the
# column and compare rows against the wrong threshold.
threshold_data %>%
  filter(.threshold %in% max_j_index_threshold)
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.