# R/iucnn_cnn_train.R

# Defines functions iucnn_cnn_train

# Documented in iucnn_cnn_train

#' Train a CNN model
#'
#' Trains a CNN model based on a list of matrices with occurrence counts for a
#' set of species, generated by \code{\link{iucnn_cnn_features}}, and the
#' corresponding IUCN classes formatted as an iucnn_labels object with
#' \code{\link{iucnn_prepare_labels}}. Both inputs must contain the same taxa
#' in the same order (\code{names(x)} must equal the species listed in
#' \code{lab}); otherwise an error is raised. Preparing the labels with the
#' features as \code{y} (see example) produces matching inputs.
#'
#' @param x a named list of matrices containing the occurrence counts across a
#'   spatial grid for a set of species (one element per species).
#' @param lab an object of the class iucnn_labels, as generated by
#'   \code{\link{iucnn_prepare_labels}} containing the labels for all species.
#' @param path_to_output character string. The path to the location
#'   where the IUCNN model shall be saved. Defaults to the session temporary
#'   directory.
#' @param production_model an object of type iucnn_model (default=NULL).
#'   If a CNN iucnn_model is provided, \code{iucnn_cnn_train} will read the
#'   settings of this model and reproduce it, but use all available data for
#'   training, by automatically setting test_fraction to 0, cv_fold to 1 and
#'   patience to 0, and training for the (mean) number of epochs at which the
#'   provided model stopped. This is recommended before using the model for
#'   predicting the IUCN status of not evaluated species, as it generally
#'   improves the prediction accuracy of the model. Choosing this option will
#'   ignore all other provided settings below, except save_model, overwrite
#'   and verbose.
#' @param cv_fold integer (default=1). When setting cv_fold > 1,
#'   \code{iucnn_cnn_train} will perform k-fold cross-validation. In this case,
#'   the provided setting for test_fraction will be ignored, as the test size
#'   of each CV-fold is determined by the specified number provided here.
#' @param test_fraction numeric (default=0.2). The fraction of the input data
#'   used as test set.
#' @param seed integer (default=1234). Set a starting seed for
#'   reproducibility.
#' @param max_epochs integer (default=100). The maximum number of epochs.
#' @param patience integer (default=20). Number of epochs with no improvement
#'   after which training will be stopped.
#' @param randomize_instances logical (default=TRUE). When set to TRUE
#'   the instances will be shuffled before training (recommended).
#' @param balance_classes logical (default=TRUE). If set to TRUE,
#'   \code{iucnn_cnn_train} will perform supersampling of the training
#'   instances to account for uneven class distribution in the training data.
#' @param dropout_rate numeric (default=0). This will randomly turn off the
#'   specified fraction of nodes of the neural network during each epoch of
#'   training, making the NN more stable and less reliant on individual
#'   nodes/weights, which can prevent over-fitting. For models trained with a
#'   dropout fraction > 0, dropout is also applied to the predictions (MC
#'   dropout, see mc_dropout_reps), so the predictions (including the
#'   validation accuracy) will reflect the stochasticity introduced by the
#'   dropout method. This is e.g. required when wanting to predict with a
#'   specified accuracy threshold (see target_acc option in
#'   \code{\link{iucnn_predict_status}}).
#' @param mc_dropout_reps integer (default=100). The number of MC iterations
#'   to run when predicting validation accuracy and calculating the
#'   accuracy-threshold table required for making predictions with an
#'   accuracy threshold. Only used when dropout_rate > 0 (it is set to 1
#'   otherwise). The default of 100 is usually sufficient, larger values will
#'   lead to longer computation times, particularly during model testing with
#'   cross-validation.
#' @param optimize_for string. Default is "loss", which will train the model
#'   until optimal validation set loss is reached. Set to "accuracy" if you
#'   want to optimize for maximum validation accuracy instead.
#' @param pooling_strategy string. Pooling strategy after first convolutional
#'   layer. Choose between "average" (default) and "max".
#' @param save_model logical (default=TRUE). If TRUE the model is saved to
#'   disk.
#' @param overwrite logical (default=FALSE). If FALSE, the function stops when
#'   the directory path_to_output already exists (the session temporary
#'   directory, which is the default and always exists, is exempt). Set to
#'   TRUE to overwrite existing models.
#' @param verbose Default 0, set to 1 for \code{iucnn_cnn_train} to print
#'   additional info to the screen while training.
#'
#' @note See
#'   \code{vignette("Approximate_IUCN_Red_List_assessments_with_IUCNN")}
#'   for a tutorial on how to run IUCNN.
#'
#' @return outputs an \code{iucnn_model} object which can be used in
#'   \code{\link{iucnn_predict_status}} for predicting the conservation status
#'   of not evaluated species.
#'
#' @keywords Training
#' @examples
#' \dontrun{
#' data("training_occ") # geographic occurrences of species with IUCN assessment
#' data("training_labels") # the corresponding IUCN assessments
#'
#' cnn_training_features <- iucnn_cnn_features(training_occ)
#' cnn_labels <- iucnn_prepare_labels(x = training_labels,
#'                                    y = cnn_training_features)
#'
#' trained_model <- iucnn_cnn_train(cnn_training_features,
#'                                  cnn_labels,
#'                                  overwrite = TRUE,
#'                                  dropout_rate = 0.1)
#' summary(trained_model)
#' }
#'
#' @export
#' @importFrom reticulate py_get_attr source_python
#' @importFrom stats complete.cases
#' @importFrom checkmate assert_class assert_data_frame assert_character assert_logical assert_numeric
iucnn_cnn_train <- function(x,
                            lab,
                            path_to_output = tempdir(),
                            production_model = NULL,
                            cv_fold = 1,
                            test_fraction = 0.2,
                            seed = 1234,
                            max_epochs = 100,
                            patience = 20,
                            randomize_instances = TRUE,
                            balance_classes = TRUE,
                            dropout_rate = 0.0,
                            mc_dropout_reps = 100,
                            optimize_for = "loss",
                            pooling_strategy = "average",
                            save_model = TRUE,
                            overwrite = FALSE,
                            verbose = 0) {

  # Check input ----
  assert_class(lab, classes = "iucnn_labels")
  # path_to_output, dropout_rate and overwrite feed scalar `if` conditions
  # below, so they must be single, non-missing values.
  assert_character(path_to_output, any.missing = FALSE, len = 1)
  assert_numeric(test_fraction, lower = 0, upper = 1)
  assert_numeric(seed)
  assert_numeric(max_epochs)
  assert_numeric(patience)
  assert_logical(randomize_instances)
  assert_numeric(dropout_rate, lower = 0, upper = 1, any.missing = FALSE,
                 len = 1)
  assert_logical(overwrite, any.missing = FALSE, len = 1)

  # Production model ----
  # Reproduce the settings of a previously trained CNN but train on all data:
  # no test split, no cross-validation, patience 0, and the (mean) epoch at
  # which the provided model stopped as the number of epochs.
  if (inherits(production_model, "iucnn_model")) {
    # isTRUE() also rejects models lacking a `model` entry (NULL) instead of
    # failing with a length-zero condition.
    if (!isTRUE(production_model$model == "cnn")) {
      stop('Please provide CNN model as production model for iucnn_cnn_train.
           Use iucnn_train_model for other non CNN models.')
    }

    test_fraction <- 0
    cv_fold <- 1
    no_validation <- TRUE
    seed <- production_model$seed
    max_epochs <- round(mean(production_model$final_training_epoch))
    patience <- 0
    randomize_instances <- production_model$randomize_instances
    balance_classes <- production_model$balance_classes
    dropout_rate <- production_model$dropout_rate
    mc_dropout_reps <- production_model$mc_dropout_reps
    accthres_tbl_stored <- production_model$accthres_tbl
    optimize_for <- production_model$optimize_for
    pooling_strategy <- production_model$pooling_strategy
  } else {
    accthres_tbl_stored <- NaN
    no_validation <- FALSE
  }

  # MC dropout is enabled whenever dropout is active; without dropout a
  # single prediction pass is enough.
  if (dropout_rate > 0) {
    mc_dropout <- TRUE
  } else {
    mc_dropout <- FALSE
    mc_dropout_reps <- 1
  }

  # Activation functions handed to the Python trainer
  act_f <- "relu"
  act_f_out <- "softmax"

  # Output directory ----
  # Refuse to reuse an existing directory unless `overwrite = TRUE`. The
  # session temporary directory (the default) always exists, so it is exempt;
  # otherwise every call relying on the default would fail.
  using_session_tempdir <- identical(
    normalizePath(path_to_output, mustWork = FALSE),
    normalizePath(tempdir(), mustWork = FALSE)
  )
  if (dir.exists(path_to_output) && !overwrite && !using_session_tempdir) {
    stop(sprintf("Directory %s exists. Provide alternative 'path_to_output' or set `overwrite` to TRUE.",
                 path_to_output))
  }

  # Species consistency ----
  # Features and labels are passed to Python as separate objects, so the taxa
  # must line up exactly. Compare lengths first: `==` would otherwise recycle
  # silently, and an unnamed `x` (NULL names) would pass vacuously.
  x_species <- names(x)
  lab_species <- as.character(lab$labels$species)
  species_match <- length(x_species) == length(lab_species) &&
    isTRUE(all(x_species == lab_species))
  if (!species_match) {
    stop(sprintf("Mismatch in species list provided with input data and that
                 of the input label object. Make sure the taxon names in
                 both input objects are identical and in the same order."))
  }

  # Training ----
  # Defines `train_cnn_model()` in this function's environment.
  reticulate::source_python(system.file("python",
                                        "IUCNN_train_cnn.py",
                                        package = "IUCNN"))

  # Count-like settings are cast to integer so Python receives ints rather
  # than floats (R numerics convert to Python floats).
  res <- train_cnn_model(
    input_raw = x,
    labels = as.matrix(lab$labels$labels),
    max_epochs = as.integer(max_epochs),
    patience = as.integer(patience),
    test_fraction = test_fraction,
    path_to_output = path_to_output,
    act_f = act_f,
    act_f_out = act_f_out,
    seed = as.integer(seed),
    dropout = mc_dropout,
    dropout_rate = dropout_rate,
    mc_dropout_reps = as.integer(mc_dropout_reps),
    randomize_instances = as.integer(randomize_instances),
    optimize_for = optimize_for,
    pooling_strategy = pooling_strategy,
    verbose = as.integer(verbose),
    cv_k = as.integer(cv_fold),
    balance_classes = balance_classes,
    no_validation = no_validation,
    save_model = save_model
  )

  # Assemble the iucnn_model object ----
  named_res <- list()

  named_res$input_data <- c(res$input_data, lookup = data.frame(lab$lookup))

  named_res$rescale_labels_boolean <- res$rescale_labels_boolean
  named_res$label_rescaling_factor <- res$label_rescaling_factor
  named_res$min_max_label_rescaled <- as.vector(res$min_max_label)
  named_res$label_stretch_factor <- res$label_stretch_factor

  named_res$trained_model_path <- res$trained_model_path

  # Fall back to the stored threshold table (the production model's table, or
  # NaN otherwise) when training returned none (None/NULL or NaN).
  accthres_tbl <- res$accthres_tbl
  accthres_missing <- is.null(accthres_tbl) ||
    (is.numeric(accthres_tbl) && isTRUE(is.nan(accthres_tbl[1])))
  if (accthres_missing) {
    accthres_tbl <- accthres_tbl_stored
  }
  named_res$accthres_tbl <- accthres_tbl
  named_res$final_training_epoch <- res$stopping_point
  named_res$sampled_cat_freqs <- res$predicted_class_count
  named_res$true_cat_freqs <- res$true_class_count

  # Settings, stored so the model can be reproduced via `production_model`
  named_res$model <- "cnn"
  named_res$seed <- seed
  named_res$dropout_rate <- dropout_rate
  named_res$max_epochs <- max_epochs
  named_res$n_layers <- 1
  named_res$use_bias <- FALSE
  named_res$balance_classes <- balance_classes
  named_res$rescale_features <- FALSE
  named_res$act_f <- act_f
  named_res$act_f_out <- res$activation_function
  named_res$test_fraction <- test_fraction
  named_res$cv_fold <- cv_fold
  named_res$patience <- patience
  named_res$randomize_instances <- randomize_instances
  named_res$label_noise_factor <- NaN
  named_res$mc_dropout <- mc_dropout
  named_res$mc_dropout_reps <- mc_dropout_reps
  named_res$optimize_for <- optimize_for
  named_res$pooling_strategy <- pooling_strategy

  # Per-epoch training histories
  named_res$training_loss_history <- res$training_loss_history
  named_res$validation_loss_history <- res$validation_loss_history

  named_res$training_accuracy_history <- res$training_accuracy_history
  named_res$validation_accuracy_history <- res$validation_accuracy_history

  named_res$training_mae_history <- res$training_mae_history
  named_res$validation_mae_history <- res$validation_mae_history

  # Final losses and test-set predictions
  named_res$training_loss <- res$training_loss
  named_res$validation_loss <- res$validation_loss
  named_res$test_loss <- res$test_loss
  # softmax probs, posterior probs, or regressed values
  named_res$test_predictions_raw <- res$test_predictions_raw
  named_res$test_predictions <- as.vector(res$test_predictions)
  named_res$test_labels <- as.vector(res$test_labels)

  named_res$confusion_matrix <- res$confusion_matrix

  named_res$training_accuracy <- res$training_accuracy
  named_res$validation_accuracy <- res$validation_accuracy
  named_res$test_accuracy <- res$test_accuracy

  class(named_res) <- "iucnn_model"

  named_res
}
# azizka/IUCNN documentation built on March 29, 2024, 9:38 a.m.