#' Train a CNN model
#'
#'Trains a CNN model based on a list of matrices with occurrence counts for a
#'set of species, generated by \code{\link{iucnn_cnn_features}}, and the
#'corresponding IUCN classes formatted as an iucnn_labels object with
#'\code{\link{iucnn_prepare_labels}}. Note that taxa for which information is
#'only present in one of the two input objects will be removed from further
#'processing.
#'
#'
#'@param x a list of matrices containing the occurrence counts across a spatial
#'grid for a set of species.
#'@param lab an object of the class iucnn_labels, as generated by
#' \code{\link{iucnn_prepare_labels}} containing the labels for all species.
#'@param path_to_output character string. The path to the location
#'where the IUCNN model will be saved.
#'@param production_model an object of type iucnn_model (default=NULL).
#'If an iucnn_model is provided, \code{iucnn_cnn_train} will read the settings of
#'this model and reproduce it, but use all available data for training, by
#'automatically setting the validation set to 0 and cv_fold to 1. This is
#'recommended before using the model for predicting the IUCN status of
#'not evaluated species, as it generally improves the prediction
#'accuracy of the model. Choosing this option causes all other provided
#'settings below to be ignored.
#'@param cv_fold integer (default=1). When setting cv_fold > 1,
#' \code{iucnn_cnn_train} will perform k-fold cross-validation. In this case,
#' the provided test_fraction setting is ignored, as the test size of
#' each CV fold is determined by the number of folds specified here.
#'@param test_fraction numeric. The fraction of the input data used as test set.
#'@param seed integer. Set a starting seed for reproducibility.
#'@param max_epochs integer. The maximum number of epochs.
#'@param patience integer. Number of epochs with no improvement after which
#' training will be stopped.
#'@param randomize_instances logical (default=TRUE). If TRUE,
#'the training instances will be shuffled before training (recommended).
#'@param balance_classes logical (default=TRUE). If set to TRUE,
#'\code{iucnn_cnn_train} will perform supersampling of the training instances to
#'account for uneven class distribution in the training data.
#'@param dropout_rate numeric. Randomly turns off the specified
#'fraction of nodes of the neural network during each training epoch,
#'making the network more stable and less reliant on individual nodes/weights,
#'which can prevent over-fitting. For models trained with a dropout_rate > 0,
#'dropout is also applied to the predictions (MC dropout), so the predictions
#'(including the validation accuracy) will reflect the stochasticity
#'introduced by the dropout method (see mc_dropout_reps below). This is
#'required, e.g., when predicting with a specified accuracy threshold
#'(see the target_acc option in \code{\link{iucnn_predict_status}}).
#'@param mc_dropout_reps integer. The number of MC iterations to run when
#'predicting validation accuracy and calculating the accuracy-threshold
#'table required for making predictions with an accuracy threshold.
#'The default of 100 is usually sufficient, larger values will lead to longer
#'computation times, particularly during model testing with cross-validation.
#'@param optimize_for string. Default is "loss", which will train the model
#'until optimal validation set loss is reached. Set to "accuracy" if you want
#'to optimize for maximum validation accuracy instead.
#'@param pooling_strategy string. Pooling strategy after first convolutional
#'layer. Choose between "average" (default) and "max".
#'@param save_model logical. If TRUE the model is saved to disk.
#'@param overwrite logical. If TRUE existing models are
#'overwritten. Default is set to FALSE.
#'@param verbose integer (default=0). Set to 1 for \code{iucnn_cnn_train} to print
#'additional information to the screen while training.
#'
#'@note See \code{vignette("Approximate_IUCN_Red_List_assessments_with_IUCNN")}
#'for a tutorial on how to run IUCNN.
#'
#'@return an \code{iucnn_model} object which can be used in
#'\code{\link{iucnn_predict_status}} for predicting the conservation status
#'of not evaluated species.
#'
#'@keywords Training
#' @examples
#'\dontrun{
#'data("training_occ") #geographic occurrences of species with IUCN assessment
#'data("training_labels")# the corresponding IUCN assessments
#'
#'cnn_training_features <- iucnn_cnn_features(training_occ)
#'cnn_labels <- iucnn_prepare_labels(x = training_labels,
#'                                   y = cnn_training_features)
#'
#'trained_model <- iucnn_cnn_train(cnn_training_features,
#'                                 cnn_labels,
#'                                 overwrite = TRUE,
#'                                 dropout_rate = 0.1)
#'summary(trained_model)
#'}
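#'
#'# Continuing from the objects above: a hedged sketch (cv_model and
#'# prod_model are illustrative names) of k-fold cross-validation followed
#'# by a production model that reuses the settings but trains on all data.
#'\dontrun{
#'cv_model <- iucnn_cnn_train(cnn_training_features,
#'                            cnn_labels,
#'                            cv_fold = 5,
#'                            overwrite = TRUE)
#'
#'prod_model <- iucnn_cnn_train(cnn_training_features,
#'                              cnn_labels,
#'                              production_model = cv_model,
#'                              overwrite = TRUE)
#'}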
#'
#' @export
#' @importFrom reticulate py_get_attr source_python
#' @importFrom stats complete.cases
#' @importFrom checkmate assert_class assert_character assert_logical assert_numeric
iucnn_cnn_train <- function(x,
                            lab,
                            path_to_output = tempdir(),
                            production_model = NULL,
                            cv_fold = 1,
                            test_fraction = 0.2,
                            seed = 1234,
                            max_epochs = 100,
                            patience = 20,
                            randomize_instances = TRUE,
                            balance_classes = TRUE,
                            dropout_rate = 0.0,
                            mc_dropout_reps = 100,
                            optimize_for = 'loss',
                            pooling_strategy = 'average',
                            save_model = TRUE,
                            overwrite = FALSE,
                            verbose = 0) {
  # Check input
  ## assertion
  assert_class(lab, classes = "iucnn_labels")
  assert_character(path_to_output)
  assert_numeric(test_fraction, lower = 0, upper = 1)
  assert_numeric(seed)
  assert_numeric(max_epochs)
  assert_numeric(patience)
  assert_logical(randomize_instances)
  assert_numeric(dropout_rate, lower = 0, upper = 1)
  assert_logical(overwrite)
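
  # if a production model is provided, reuse its settings and train on
  # all available data (no validation set, no cross-validation)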
  provided_model <- production_model
  if (inherits(provided_model, "iucnn_model")) {
    mode <- provided_model$model
    if (mode != 'cnn') {
      stop("Please provide a CNN model as production_model for iucnn_cnn_train. ",
           "Use iucnn_train_model for non-CNN models.")
    }
    test_fraction <- 0.
    cv_fold <- 1
    no_validation <- TRUE
    seed <- provided_model$seed
    max_epochs <- round(mean(provided_model$final_training_epoch))
    patience <- 0
    randomize_instances <- provided_model$randomize_instances
    balance_classes <- provided_model$balance_classes
    dropout_rate <- provided_model$dropout_rate
    mc_dropout_reps <- provided_model$mc_dropout_reps
    accthres_tbl_stored <- provided_model$accthres_tbl
    optimize_for <- provided_model$optimize_for
    pooling_strategy <- provided_model$pooling_strategy
  } else {
    accthres_tbl_stored <- NaN
    no_validation <- FALSE
  }
  if (dropout_rate > 0) {
    mc_dropout <- TRUE
  } else {
    mc_dropout <- FALSE
    mc_dropout_reps <- 1
  }
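  # activation functions: relu for the network layers, softmax for the output layer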
act_f = "relu"
act_f_out = "softmax"
  # check if the model directory already exists
  if (dir.exists(file.path(path_to_output)) && !overwrite) {
    stop(sprintf("Directory %s exists. Provide an alternative 'path_to_output' or set `overwrite` to TRUE.",
                 path_to_output))
  }
  # check that the same species are in the input data and the labels
  if (!all(names(x) == lab$labels$species)) {
    stop("Mismatch between the species in the input data and those in the ",
         "label object. Make sure the taxon names in both input objects ",
         "are identical and in the same order.")
  }
  # source python function
  reticulate::source_python(system.file("python",
                                        "IUCNN_train_cnn.py",
                                        package = "IUCNN"))
  res <- train_cnn_model(
    input_raw = x,
    labels = as.matrix(lab$labels$labels),
    max_epochs = as.integer(max_epochs),
    patience = patience,
    test_fraction = test_fraction,
    path_to_output = path_to_output,
    act_f = act_f,
    act_f_out = act_f_out,
    seed = as.integer(seed),
    dropout = mc_dropout,
    dropout_rate = dropout_rate,
    mc_dropout_reps = mc_dropout_reps,
    randomize_instances = as.integer(randomize_instances),
    optimize_for = optimize_for,
    pooling_strategy = pooling_strategy,
    verbose = verbose,
    cv_k = cv_fold,
    balance_classes = balance_classes,
    no_validation = no_validation,
    save_model = save_model
  )
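
  # unpack the results returned by the Python training routine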
  test_labels <- as.vector(res$test_labels)
  test_predictions <- as.vector(res$test_predictions)
  test_predictions_raw <- res$test_predictions_raw
  training_accuracy <- res$training_accuracy
  validation_accuracy <- res$validation_accuracy
  test_accuracy <- res$test_accuracy
  training_loss <- res$training_loss
  validation_loss <- res$validation_loss
  test_loss <- res$test_loss
  training_loss_history <- res$training_loss_history
  validation_loss_history <- res$validation_loss_history
  training_accuracy_history <- res$training_accuracy_history
  validation_accuracy_history <- res$validation_accuracy_history
  training_mae_history <- res$training_mae_history
  validation_mae_history <- res$validation_mae_history
  rescale_labels_boolean <- res$rescale_labels_boolean
  label_rescaling_factor <- res$label_rescaling_factor
  min_max_label <- as.vector(res$min_max_label)
  label_stretch_factor <- res$label_stretch_factor
  act_f_out <- res$activation_function
  trained_model_path <- res$trained_model_path
  confusion_matrix <- res$confusion_matrix
  accthres_tbl <- res$accthres_tbl
  stopping_point <- res$stopping_point
  input_data <- res$input_data
  sampled_cat_freqs <- res$predicted_class_count
  true_cat_freqs <- res$true_class_count
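
  # assemble the iucnn_model output object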
  named_res <- list()
  named_res$input_data <- c(input_data, lookup = data.frame(lab$lookup))
  named_res$rescale_labels_boolean <- rescale_labels_boolean
  named_res$label_rescaling_factor <- label_rescaling_factor
  named_res$min_max_label_rescaled <- min_max_label
  named_res$label_stretch_factor <- label_stretch_factor
  named_res$trained_model_path <- trained_model_path
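
  # if no new accuracy-threshold table was computed (e.g. for a production
  # model), reuse the table stored in the provided model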
  if (is.nan(accthres_tbl[1])) {
    accthres_tbl <- accthres_tbl_stored
  }
  named_res$accthres_tbl <- accthres_tbl
  named_res$final_training_epoch <- stopping_point
  named_res$sampled_cat_freqs <- sampled_cat_freqs
  named_res$true_cat_freqs <- true_cat_freqs
  named_res$model <- 'cnn'
  named_res$seed <- seed
  named_res$dropout_rate <- dropout_rate
  named_res$max_epochs <- max_epochs
  named_res$n_layers <- 1
  named_res$use_bias <- FALSE
  named_res$balance_classes <- balance_classes
  named_res$rescale_features <- FALSE
  named_res$act_f <- act_f
  named_res$act_f_out <- act_f_out
  named_res$test_fraction <- test_fraction
  named_res$cv_fold <- cv_fold
  named_res$patience <- patience
  named_res$randomize_instances <- randomize_instances
  named_res$label_noise_factor <- NaN
  named_res$mc_dropout <- mc_dropout
  named_res$mc_dropout_reps <- mc_dropout_reps
  named_res$optimize_for <- optimize_for
  named_res$pooling_strategy <- pooling_strategy
  named_res$training_loss_history <- training_loss_history
  named_res$validation_loss_history <- validation_loss_history
  named_res$training_accuracy_history <- training_accuracy_history
  named_res$validation_accuracy_history <- validation_accuracy_history
  named_res$training_mae_history <- training_mae_history
  named_res$validation_mae_history <- validation_mae_history
  named_res$training_loss <- training_loss
  named_res$validation_loss <- validation_loss
  named_res$test_loss <- test_loss
  # softmax probs, posterior probs, or regressed values
  named_res$test_predictions_raw <- test_predictions_raw
  named_res$test_predictions <- test_predictions
  named_res$test_labels <- test_labels
  named_res$confusion_matrix <- confusion_matrix
  named_res$training_accuracy <- training_accuracy
  named_res$validation_accuracy <- validation_accuracy
  named_res$test_accuracy <- test_accuracy

  class(named_res) <- "iucnn_model"
  return(named_res)
}