knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
Visualization package for ML models.
This package provides four functions that let users plot common visualizations and compare the performance of different classifier models. Its purpose is to reduce the time data scientists spend building visualizations and comparing models, and so speed up model development. The four functions perform the following tasks:
- Compare the performance of various models
- Plot the confusion matrix based on the input data
- Plot train/validation errors vs. parameter values
- Plot the ROC curve and calculate the AUC
| Function Name | Input | Output | Description |
|-------------|-----|------|-----------|
| model_comparison_table | train_data, test_data, ... | Dataframe of model scores | Takes in training data and testing data, with the target as the last column, plus fitted models passed under meaningful names, then outputs a table comparing the scores of the different models. |
| confusion_matrix | actual_y, predicted_y, labels, title | Confusion matrix plot, dataframe of various scores (recall, F1, etc.) | Takes in actual and predicted target values and produces a confusion matrix visual. When the predicted_y array is passed in, other evaluation metrics such as recall and precision are also produced. |
| plot_train_valid_error | model_name, X_train, y_train, X_test, y_test, param_name, param_vec | Train/validation errors vs. parameter values plot | Takes in a model name, train/validation data sets, a parameter name and a vector of parameter values to try, then plots train/validation errors vs. parameter values. |
| plot_roc | y_label, predict_proba | ROC plot | Takes in a vector of true labels and a vector of prediction probabilities and plots the ROC curve. The plot also reports the AUC score. |
1. `model_comparison_table`

``` {r 1, message=FALSE, warning=FALSE}
library(RMLViz)
library(mlbench)
data(Sonar)

# Keep five predictors, with the target (Class) as the last column
toy_classification_data <- dplyr::select(dplyr::as_tibble(Sonar), V1, V2, V3, V4, V5, Class)

# 90/10 train/test split
train_ind <- caret::createDataPartition(toy_classification_data$Class, p = 0.9, list = FALSE)
train_set_cf <- toy_classification_data[train_ind, ]
test_set_cf <- toy_classification_data[-train_ind, ]

# Fit two classifiers to compare
gbm <- caret::train(Class ~ ., train_set_cf, method = "gbm", verbose = FALSE)
lm_cf <- caret::train(Class ~ ., train_set_cf, method = "LogitBoost", verbose = FALSE)

model_comparison_table(train_set_cf, test_set_cf, gbm_mod = gbm, log_mod = lm_cf)
```
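To make the idea of a score table concrete, here is a minimal base R sketch of the same pattern: named models go in through `...`, and one row of train/test scores comes out per model. It is not RMLViz's implementation. The helper `compare_models`, the use of plain prediction functions as "models", and accuracy as the score are all assumptions made for illustration.

```r
# Hypothetical helper (not part of RMLViz): each "model" is a function that
# maps a data frame to a vector of predicted labels. Accuracy is an assumed score.
compare_models <- function(train, test, ...) {
  models <- list(...)
  target <- ncol(train)  # target is the last column, as in the text above
  rows <- lapply(names(models), function(nm) {
    predict_fn <- models[[nm]]
    data.frame(
      model = nm,
      train_accuracy = mean(predict_fn(train) == as.character(train[[target]])),
      test_accuracy = mean(predict_fn(test) == as.character(test[[target]]))
    )
  })
  do.call(rbind, rows)
}

# Two-class subset of iris with a deterministic odd/even split
dat <- iris[51:150, c("Petal.Length", "Petal.Width", "Species")]
dat$Species <- droplevels(dat$Species)
train <- dat[seq(1, 100, by = 2), ]
test <- dat[seq(2, 100, by = 2), ]

# Two toy "models": a constant baseline and a single-threshold rule
baseline <- function(d) rep("versicolor", nrow(d))
petal_rule <- function(d) ifelse(d$Petal.Width > 1.6, "virginica", "versicolor")

scores <- compare_models(train, test, baseline = baseline, petal_rule = petal_rule)
print(scores)
```

Passing models as named arguments keeps the row labels meaningful, which is the same reason the call above uses `gbm_mod = gbm` and `log_mod = lm_cf`.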
2. `confusion_matrix`

```r
library(RMLViz)
data(iris)

set.seed(123)
split <- caTools::sample.split(iris$Species, SplitRatio = 0.75)

training_set <- subset(iris, split == TRUE)
valid_set <- subset(iris, split == FALSE)

X_train <- training_set[, -5]
y_train <- training_set[, 5]
X_valid <- valid_set[, -5]
y_valid <- valid_set[, 5]

# Predict on the training set with 5-nearest neighbours
# (named `predicted` to avoid masking stats::predict)
predicted <- class::knn(X_train, X_train, y_train, k = 5)

confusion_matrix(y_train, predicted)
```
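For readers who want to see what a confusion matrix and its derived scores contain, the following base R sketch builds one with `table()` and computes per-class precision, recall and F1. It uses made-up labels and is only an illustration under those assumptions; it does not reproduce the plot or the exact output format of `confusion_matrix`.

```r
# Toy labels (made up for illustration)
classes <- c("cat", "dog")
actual <- factor(c("cat", "cat", "cat", "dog", "dog", "dog", "dog", "cat"), levels = classes)
predicted <- factor(c("cat", "dog", "cat", "dog", "dog", "cat", "dog", "cat"), levels = classes)

# Rows are predicted classes, columns are actual classes
cm <- table(Predicted = predicted, Actual = actual)
print(cm)

# Per-class metrics derived from the matrix
true_pos <- as.numeric(diag(cm))
precision <- true_pos / as.numeric(rowSums(cm))  # correct / all predicted as class
recall <- true_pos / as.numeric(colSums(cm))     # correct / all actually in class
f1 <- 2 * precision * recall / (precision + recall)

metrics <- data.frame(class = classes, precision = precision, recall = recall, f1 = f1)
print(metrics)
```

Fixing the factor levels up front keeps the table square even when a class is never predicted, so the diagonal always lines up with the class order.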
3. `plot_train_valid_error`

``` {r 3, message=FALSE}
library(RMLViz)
data(iris)

set.seed(123)
split <- caTools::sample.split(iris$Species, SplitRatio = 0.75)

training_set <- subset(iris, split == TRUE)
valid_set <- subset(iris, split == FALSE)

X_train <- training_set[, -5]
y_train <- training_set[, 5]
X_valid <- valid_set[, -5]
y_valid <- valid_set[, 5]

plot_train_valid_error('knn', X_train, y_train, X_valid, y_valid, 'k', seq(50))
```
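The sweep behind a train/validation error plot can be sketched in base R as follows: for each candidate `k`, fit on the training set, then record the misclassification rate on both the training and the validation set. This is an illustrative sketch, not how `plot_train_valid_error` is implemented. The helper `knn_predict` is a hypothetical, simplified nearest-neighbour classifier (Euclidean distance, ties broken by factor level order) written so the example needs no extra packages.

```r
# Hypothetical helper: a tiny k-nearest-neighbour classifier in base R
knn_predict <- function(X_train, y_train, X_new, k) {
  X_train <- as.matrix(X_train)
  X_new <- as.matrix(X_new)
  apply(X_new, 1, function(x) {
    d <- sqrt(colSums((t(X_train) - x)^2))   # distance to every training row
    nearest <- y_train[order(d)[seq_len(k)]] # labels of the k closest rows
    names(which.max(table(nearest)))         # majority vote
  })
}

# Deterministic odd/even split of iris
train_idx <- seq(1, nrow(iris), by = 2)
X_train <- iris[train_idx, 1:4]
y_train <- iris$Species[train_idx]
X_valid <- iris[-train_idx, 1:4]
y_valid <- iris$Species[-train_idx]

# Misclassification rate on both sets for each candidate k
k_values <- c(1, 5, 15, 30)
errors <- as.data.frame(t(sapply(k_values, function(k) c(
  k = k,
  train_error = mean(knn_predict(X_train, y_train, X_train, k) != as.character(y_train)),
  valid_error = mean(knn_predict(X_train, y_train, X_valid, k) != as.character(y_valid))
))))
print(errors)

# Plot both error curves against k
y_max <- max(errors$train_error, errors$valid_error) + 0.05
plot(errors$k, errors$valid_error, type = "b", col = "red", ylim = c(0, y_max),
     xlab = "k", ylab = "Misclassification rate")
lines(errors$k, errors$train_error, type = "b", col = "blue")
legend("topleft", legend = c("validation", "train"), col = c("red", "blue"), lty = 1)
```

Reading the two curves together is the point of the plot: a training error far below the validation error suggests overfitting, while both errors rising together suggests the model is too simple.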
4. `plot_roc`

``` {r 4, message=FALSE}
library(RMLViz)

set.seed(420)
num.samples <- 100

weight <- sort(rnorm(n = num.samples, mean = 172, sd = 29))
obese <- ifelse(test = (runif(n = num.samples) < (rank(weight) / num.samples)), yes = 1, no = 0)

glm.fit <- glm(obese ~ weight, family = binomial)
obese_proba <- glm.fit$fitted.values

plot_roc(obese, obese_proba)
```
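As a rough illustration of what an ROC curve and its AUC are, the base R sketch below sweeps a decision threshold over the predicted probabilities, records the true and false positive rates at each step, and integrates the curve with the trapezoid rule. The helpers `roc_points` and `auc_trapezoid` and the toy data are assumptions for this example. `plot_roc` may compute these values differently.

```r
# Hypothetical helpers (not part of RMLViz)
roc_points <- function(labels, scores) {
  # Start above every score so the curve begins at (0, 0)
  thresholds <- c(Inf, sort(unique(scores), decreasing = TRUE))
  tpr <- sapply(thresholds, function(t) sum(scores >= t & labels == 1) / sum(labels == 1))
  fpr <- sapply(thresholds, function(t) sum(scores >= t & labels == 0) / sum(labels == 0))
  data.frame(threshold = thresholds, fpr = fpr, tpr = tpr)
}

auc_trapezoid <- function(roc) {
  # Area under the piecewise-linear curve through the (fpr, tpr) points
  sum(diff(roc$fpr) * (head(roc$tpr, -1) + tail(roc$tpr, -1)) / 2)
}

# Toy data: 1 = positive class, 0 = negative class
labels <- c(1, 1, 0, 1, 0, 0)
scores <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.2)

roc <- roc_points(labels, scores)
auc <- auc_trapezoid(roc)
print(auc)  # 8/9, about 0.889: 8 of the 9 positive/negative pairs are ranked correctly

plot(roc$fpr, roc$tpr, type = "l", xlab = "False positive rate",
     ylab = "True positive rate", main = sprintf("ROC curve (AUC = %.3f)", auc))
abline(0, 1, lty = 2)  # chance line
```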