Description
This function attempts to predict from Multi-Grained Scanning using xgboost.
Usage
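A sketch of the call, assembled from the documented arguments below; the defaults shown are assumptions, not the package's verified signature:

MGScanning_pred(model,
                data,
                folds = NULL,        # assumed default: no folds for new data
                dimensions = 2,      # assumed default; 1 and 2 are supported
                multi_class = 2,     # binary classification / regression
                data_start = NULL)   # assumed default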
Arguments

model
Type: list. A model trained by MGScanning.

data
Type: data.table (or a list of matrices when dimensions = 2). The data to predict on. If you pass the training data without its folds, predictions are made in-sample rather than out of fold and will overfit; pass folds (or reuse the out-of-fold predictions already stored in model$preds). A short sketch of both calling modes follows this list.

folds
Type: list. The folds, as a list, used for cross-validation when predicting on the training data. Otherwise, leave this NULL.

dimensions
Type: numeric. The dimensionality of the data. Only 1 (tabular data) and 2 (lists of matrices, such as images) are supported.

multi_class
Type: numeric. How many classes you have. Set to 2 for binary classification or regression; set to the number of classes for multiclass classification.

data_start
Type: vector of numeric. The initial prediction labels. Leave NULL unless you know what you are doing.
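To make the data / folds interaction concrete, here is a hedged sketch of the two calling modes, using the object names from the Examples section below (the NULL default for folds is an assumption, not taken from the package source):

# Out-of-fold predictions on the training data: pass the same folds used to train
oof_preds <- MGScanning_pred(model, data = agaricus_data_train, folds = folds)

# Predictions on genuinely new data: no folds needed (assumed NULL default)
test_preds <- MGScanning_pred(model, data = agaricus_data_test)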
Details

For implementation details of Cascade Forest / Complete-Random Tree Forest / Multi-Grained Scanning / Deep Forest, see this write-up by Laurae: https://github.com/Microsoft/LightGBM/issues/331#issuecomment-283942390
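As a rough illustration of what the depth and stride arguments in the Examples control, here is a minimal sketch (not the package's internal code) of a 1D sliding window cutting a feature vector into the sub-views that Multi-Grained Scanning feeds to its forests:

# Cut a feature vector into windows of `depth` features, moving `stride`
# features at a time (illustrative only)
sliding_windows <- function(x, depth = 10, stride = 20) {
  starts <- seq(1, length(x) - depth + 1, by = stride)
  lapply(starts, function(i) x[i:(i + depth - 1)])
}

# e.g. 100 features with depth = 10 and stride = 20 give 5 windows of 10 features
windows <- sliding_windows(rnorm(100), depth = 10, stride = 20)
length(windows)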
Value

A data.table or a list, based on data, predicted using model.
Examples

## Not run:
# Load libraries
library(data.table)
library(Matrix)
library(xgboost)
# Create data
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
agaricus_data_train <- data.table(as.matrix(agaricus.train$data))
agaricus_data_test <- data.table(as.matrix(agaricus.test$data))
agaricus_label_train <- agaricus.train$label
agaricus_label_test <- agaricus.test$label
folds <- Laurae::kfold(agaricus_label_train, 5)
# Train a model (binary classification) - FAST VERSION
model <- MGScanning(data = agaricus_data_train, # Training data
labels = agaricus_label_train, # Training labels
folds = folds, # Folds for cross-validation
dimensions = 1, # Change this for 2 dimensions if needed
depth = 10, # Change this to change the sliding window size
stride = 20, # Change this to change the sliding window speed
nthread = 1, # Change this to use more threads
lr = 1, # Do not touch this unless you are expert
training_start = NULL, # Do not touch this unless you are expert
validation_start = NULL, # Do not touch this unless you are expert
n_forest = 2, # Number of forest models
n_trees = 30, # Number of trees per forest
random_forest = 1, # Use a Random Forest for 1 of the 2 forests
seed = 0,
objective = "binary:logistic",
eval_metric = Laurae::df_logloss,
multi_class = 2, # Modify this for multiclass problems
verbose = TRUE)
# Train a model (binary classification) - SLOW VERSION (stride = 1)
model <- MGScanning(data = agaricus_data_train, # Training data
labels = agaricus_label_train, # Training labels
folds = folds, # Folds for cross-validation
dimensions = 1, # Change this for 2 dimensions if needed
depth = 10, # Change this to change the sliding window size
stride = 1, # Change this to change the sliding window speed
nthread = 1, # Change this to use more threads
lr = 1, # Do not touch this unless you are expert
training_start = NULL, # Do not touch this unless you are expert
validation_start = NULL, # Do not touch this unless you are expert
n_forest = 2, # Number of forest models
n_trees = 30, # Number of trees per forest
random_forest = 1, # Use a Random Forest for 1 of the 2 forests
seed = 0,
objective = "binary:logistic",
eval_metric = Laurae::df_logloss,
multi_class = 2, # Modify this for multiclass problems
verbose = TRUE)
# Out-of-fold predictions stored in the model during training
data_predictions <- model$preds
# Real predictions on unseen (test) data
new_preds <- MGScanning_pred(model, data = agaricus_data_test)
# We can check whether we have equal predictions, it's all TRUE!
all.equal(model$preds, MGScanning_pred(model,
agaricus_data_train,
folds = folds))
# Example on fake pictures (matrices) and multiclass problem
# Generate fake images
new_data <- list(matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20))
# Generate fake labels
new_labels <- c(2, 1, 0, 2, 1, 0, 2, 1, 0, 0)
# Train a model (multiclass problem)
model <- MGScanning(data = new_data, # Training data
labels = new_labels, # Training labels
folds = list(1:3, 4:6, 7:10), # Folds for cross-validation
dimensions = 2,
depth = 10,
stride = 1,
nthread = 1, # Change this to use more threads
lr = 1, # Do not touch this unless you are expert
training_start = NULL, # Do not touch this unless you are expert
validation_start = NULL, # Do not touch this unless you are expert
n_forest = 2, # Number of forest models
n_trees = 10, # Number of trees per forest
random_forest = 1, # Use a Random Forest for 1 of the 2 forests
seed = 0,
objective = "multi:softprob",
eval_metric = Laurae::df_logloss,
multi_class = 3, # Modify this for multiclass problems
verbose = TRUE)
# Matrix output is 10x600
dim(model$preds)
# We can check whether we have equal predictions, it's all TRUE!
all.equal(model$preds, MGScanning_pred(model,
new_data,
folds = list(1:3, 4:6, 7:10)))
# Real predictions on new data
new_data <- list(matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20),
matrix(rnorm(n = 400), ncol = 20, nrow = 20))
new_preds <- MGScanning_pred(model, data = new_data)
## End(Not run)