# Load keras3 and select JAX as the Keras backend.
library(keras3)
use_backend("jax")
This example looks at the
Kaggle Credit Card Fraud Detection
dataset to demonstrate how
to train a classification model on data with highly imbalanced classes.
You can download the data by clicking "Download" at
the link, or if you're set up with a Kaggle API key at
"~/.kaggle/kaggle.json", you can run the following:
# Fetch the dataset with the Kaggle command-line client (needs a configured
# Kaggle API key), then extract only the CSV from the downloaded archive.
# Each call sits on its own line so the inline comment below cannot swallow
# the download and unzip steps.
reticulate::py_install("kaggle", pip = TRUE)
reticulate::py_available(TRUE) # ensure 'kaggle' is on the PATH
system("kaggle datasets download -d mlg-ulb/creditcardfraud")
zip::unzip("creditcardfraud.zip", files = "creditcard.csv")
library(readr)

# Parse every column as double, except the `Class` label, which is read as
# an integer.
df <- read_csv(
  "creditcard.csv",
  col_types = cols(
    Class = col_integer(),
    .default = col_double()
  )
)

# Print a compact overview of the columns.
tibble::glimpse(df)
# Hold out a random 20% of the rows for validation; the rest is used for
# training. No seed is set here, so the split differs between runs.
n_rows <- nrow(df)
val_idx <- sample.int(n_rows, round(n_rows * 0.2))

val_df <- df[val_idx, ]
train_df <- df[-val_idx, ]

cat("Number of training samples:", nrow(train_df), "\n")
cat("Number of validation samples:", nrow(val_df), "\n")
# Tabulate the class labels of the training set and print the table.
counts <- table(train_df$Class)
counts

# `[[` extracts plain integer scalars, dropping the table class and names.
n_negative <- counts[["0"]]
n_positive <- counts[["1"]]

cat(sprintf(
  "Number of positive samples in training data: %i (%.2f%% of total)",
  n_positive, 100 * n_positive / sum(counts)
))

# Weight each class by the inverse of its frequency so that the rare
# positive class contributes as much to the loss as the common one.
weight_for_0 <- 1 / n_negative
weight_for_1 <- 1 / n_positive
# Every column except the label is a model input.
feature_names <- setdiff(colnames(train_df), "Class")

train_features <- as.matrix(train_df[feature_names])
train_targets <- as.matrix(train_df$Class)

val_features <- as.matrix(val_df[feature_names])
val_targets <- as.matrix(val_df$Class)

# Standardize the training features, then apply the *training* center and
# scale to the validation features so no validation statistics leak in.
train_features <- scale(train_features)
val_features <- scale(
  val_features,
  center = attr(train_features, "scaled:center"),
  scale = attr(train_features, "scaled:scale")
)
# Binary classifier: three 256-unit ReLU layers, dropout after the second
# and third, and a single sigmoid output unit.
model <-
  keras_model_sequential(input_shape = ncol(train_features)) |>
  layer_dense(256, activation = "relu") |>
  layer_dense(256, activation = "relu") |>
  layer_dropout(0.3) |>
  layer_dense(256, activation = "relu") |>
  layer_dropout(0.3) |>
  layer_dense(1, activation = "sigmoid")

# Print the model summary.
model
# Train the model with the `class_weight` argument ----

# Confusion-matrix counts plus precision and recall, tracked during training.
metrics <- list(
  metric_false_negatives(name = "fn"),
  metric_false_positives(name = "fp"),
  metric_true_negatives(name = "tn"),
  metric_true_positives(name = "tp"),
  metric_precision(name = "precision"),
  metric_recall(name = "recall")
)

model |> compile(
  optimizer = optimizer_adam(learning_rate = 1e-2),
  loss = "binary_crossentropy",
  metrics = metrics
)

# Write a checkpoint file whose name includes the epoch number.
callbacks <- list(
  callback_model_checkpoint("fraud_model_at_epoch_{epoch}.keras")
)

# Per-class loss weights, keyed by class label.
class_weight <- list("0" = weight_for_0, "1" = weight_for_1)

model |> fit(
  train_features, train_targets,
  validation_data = list(val_features, val_targets),
  class_weight = class_weight,
  batch_size = 2048,
  epochs = 30,
  callbacks = callbacks,
  verbose = 2
)
# Threshold the predicted probabilities at 0.5 to get hard 0/1 labels.
val_prob <- predict(model, val_features)
val_pred <- as.integer(val_prob > 0.5)

pred_correct <- val_df$Class == val_pred
cat(sprintf("Validation accuracy: %.2f", mean(pred_correct)))

# Break the predictions down by true class.
fraudulent <- val_df$Class == 1

n_fraudulent_detected <- sum(fraudulent & pred_correct)
n_fraudulent_missed <- sum(fraudulent & !pred_correct)
n_legitimate_flagged <- sum(!fraudulent & !pred_correct)
At the end of training, out of `r prettyNum(nrow(val_df), big.mark = ",")` validation transactions, we are:

- Correctly identifying `r prettyNum(n_fraudulent_detected, big.mark = ",")` of them as fraudulent
- Missing `r prettyNum(n_fraudulent_missed, big.mark = ",")` fraudulent transactions
- At the cost of incorrectly flagging `r prettyNum(n_legitimate_flagged, big.mark = ",")` legitimate transactions

In the real world, one would put an even higher weight on class 1, so as to reflect that False Negatives are more costly than False Positives.
Next time your credit card gets declined in an online purchase -- this is why.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.