# R/iucnn_train_model.R

#' Train an IUCNN Model
#'
#'Trains an IUCNN model based on a data.frame of features for a set of species,
#'generated by \code{\link{iucnn_prepare_features}},
#'and the corresponding IUCN classes formatted as a iucnn_labels object
#'with \code{\link{iucnn_prepare_labels}}. Note
#'that NAs are not allowed in the features, and taxa with NAs will
#'automatically be removed! Taxa, for which information is only present in one
#' of the two input objects will be removed as well.
#'
#'
#'@param x a data.frame, containing a column "species"
#'with the species names, and
#'subsequent columns with different features.
#'@param lab an object of the class iucnn_labels, as generated by
#' \code{\link{iucnn_prepare_labels}} containing the labels for all species.
#'@param path_to_output character string. The path to the location
#'where the IUCNN model shall be saved
#'@param production_model an object of type iucnn_model (default=NULL).
#'If an iucnn_model is provided, \code{iucnn_train_model} will read the settings of
#'this model and reproduce it, but use all available data for training, by
#'automatically setting the validation set to 0 and cv_fold to 1. This is
#'recommended before using the model for predicting the IUCN status of
#'not evaluated species, as it generally improves the prediction
#'accuracy of the model. Choosing this option will ignore all other provided
#'settings (below).
#'@param mode character string. Choose between the IUCNN models
#'"nn-class" (default, tensorflow neural network classifier),
#'"nn-reg" (tensorflow neural network regression), or
#'"bnn-class" (Bayesian neural network classifier)
#'@param test_fraction numeric. The fraction of the input data used as
#'test set.
#'@param cv_fold integer (default=1). When setting cv_fold > 1,
#'\code{iucnn_train_model} will perform k-fold cross-validation. In this
#'case, the provided setting for test_fraction will be ignored, as the test
#'set size of each CV fold is determined by the number of folds set here.
#'@param seed integer. Set a starting seed for reproducibility.
#'@param max_epochs integer. The maximum number of epochs.
#'@param patience integer. Number of epochs with no improvement
#' after which training will be stopped.
#'@param n_layers character string. Define the number of nodes per layer by
#'providing a character string in which the node counts of the layers are
#'separated by underscores. E.g. '50_30_10' (default) will train a model
#'with 3 hidden layers with
#'50, 30, and 10 nodes respectively. Note that the number of nodes in the output
#' layer is automatically determined based on
#'the number of unique labels in the training set.
#'@param use_bias logical (default=TRUE). Specifies if a bias node is used in
#'the first hidden layer.
#'@param balance_classes logical (default=FALSE). If set to TRUE,
#'\code{iucnn_train_model} will perform supersampling of the training instances to
#'account for uneven class distribution in the training data. In case of
#'training a bnn-class model, choosing this option will add the estimation
#'of class weights instead, to account for class imbalances.
#'@param act_f character string. Specifies the activation
#'function to be used in the hidden layers.
#'Available options are: "relu", "tanh", "sigmoid", or "swish" (latter only for
#'bnn-class). If set to 'auto' (default), \code{iucnn_train_model} will pick a reasonable
#'default ('relu' for nn-class or nn-reg, and 'swish' for bnn-class).
#'@param act_f_out character string. Similar to act_f, this specifies
#'the activation function for the output
#'layer. Available options are "softmax" (nn-class, bnn-class), "tanh" (nn-reg),
#'"sigmoid" (nn-reg), or no activation function "" (nn-reg). When set to "auto"
#'(default), a suitable output activation function will be chosen based on the
#'chosen mode ('softmax' for nn-class or bnn-class, 'tanh' for nn-reg).
#'@param label_stretch_factor numeric (only for mode nn-reg). The provided
#'value will be applied as a factor to stretch or compress the labels before
#'training a regression model. A factor < 1.0 will compress the range
#'of labels, while a factor > 1 will stretch the range.
#'@param randomize_instances logical. When set to TRUE (default) the
#'instances will be shuffled before training (recommended).
#'@param dropout_rate numeric. This will randomly turn off the specified
#'fraction of nodes of the neural network during each epoch of training,
#'making the NN more stable and less reliant on individual nodes/weights,
#'which can prevent over-fitting (only available for modes nn-class and
#'nn-reg). See the mc_dropout setting below if dropout shall also be
#'applied to the predictions.
#'@param mc_dropout logical. If set to TRUE, the predictions (including the
#'validation accuracy) based on a model trained with a dropout fraction > 0
#'will reflect the stochasticity introduced by the dropout method (MC dropout
#'predictions). This is e.g. required when wanting to predict with a specified
#'accuracy threshold (see target_acc option in \code{\link{iucnn_predict_status}}).
#'This option is activated by default when choosing a dropout_rate > 0, unless
#'it is manually set to FALSE here.
#'@param mc_dropout_reps integer. The number of MC iterations to run when
#'predicting validation accuracy and calculating the accuracy-threshold
#'table required for making predictions with an accuracy threshold.
#'The default of 100 is usually sufficient, larger values will lead to longer
#'computation times, particularly during model testing with cross-validation.
#'@param label_noise_factor numeric (only for mode nn-reg). Adds the specified
#'amount of random noise to the input labels to give the categorical labels a
#'more continuous spread before training the regression model. E.g. a value of 0.2
#'will redraw a label of a species categorized as Vulnerable (class=2) randomly
#'between 1.8 and 2.2, based on a uniform probability distribution.
#'@param rescale_features logical. Set to TRUE if all feature values shall
#'be rescaled to values between 0 and 1 prior to training (default=FALSE).
#'@param save_model logical. If TRUE the model is saved to disk.
#'@param overwrite logical. If TRUE existing models are
#'overwritten. Default is set to FALSE.
#'@param verbose integer. Default 1, which makes \code{iucnn_train_model}
#'print additional info to the screen while training; set to 0 to suppress it.
#'
#'@note See \code{vignette("Approximate_IUCN_Red_List_assessments_with_IUCNN")}
#'for a tutorial on how to run IUCNN.
#'
#'@return outputs an \code{iucnn_model} object which can be used in
#'\code{\link{iucnn_predict_status}} for predicting the conservation status
#'of not evaluated species.
#'
#'@keywords Training
#'
#' @examples
#'\dontrun{
#'data("training_occ") #geographic occurrences of species with IUCN assessment
#'data("training_labels")# the corresponding IUCN assessments
#'
#'# 1. Feature and label preparation
#'features <- iucnn_prepare_features(training_occ, type = "geographic") # Training features
#'labels_train <- iucnn_prepare_labels(training_labels, features) # Training labels
#'
#'# 2. Model training
#'m1 <- iucnn_train_model(x = features, lab = labels_train, overwrite = TRUE)
#'
#'summary(m1)
#'plot(m1)
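#'
#'# 3. Further sketches based on the argument descriptions above:
#'# k-fold cross-validation for model testing ...
#'m_cv <- iucnn_train_model(x = features, lab = labels_train,
#'                          cv_fold = 5, overwrite = TRUE)
#'
#'# ... and a production model that reuses the settings of m1 but trains
#'# on all available data, before predicting not evaluated species
#'m_prod <- iucnn_train_model(x = features, lab = labels_train,
#'                            production_model = m1, overwrite = TRUE)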
#'}
#'
#'
#' @export
#' @importFrom reticulate py_get_attr source_python
#' @importFrom stats complete.cases
#' @importFrom checkmate assert_class assert_data_frame assert_character assert_logical assert_numeric
#' @importFrom utils read.table

iucnn_train_model <- function(x,
                        lab,
                        path_to_output = tempdir(),
                        production_model = NULL,
                        mode = 'nn-class',
                        test_fraction = 0.2,
                        cv_fold = 1,
                        seed = 1234,
                        max_epochs = 1000,
                        patience = 200,
                        n_layers = '50_30_10',
                        use_bias = TRUE,
                        balance_classes = FALSE,
                        act_f = "auto",
                        act_f_out = "auto",
                        label_stretch_factor = 1.0,
                        randomize_instances = TRUE,
                        dropout_rate = 0.0,
                        mc_dropout = TRUE,
                        mc_dropout_reps = 100,
                        label_noise_factor = 0.0,
                        rescale_features = FALSE,
                        save_model = TRUE,
                        overwrite = FALSE,
                        verbose = 1){

  # Check input
  ## assertion
  assert_data_frame(x)
  assert_class(lab, classes = "iucnn_labels")
  assert_character(path_to_output)
  assert_numeric(test_fraction, lower = 0, upper = 1)
  assert_numeric(seed)
  assert_numeric(max_epochs)
  assert_character(n_layers)
  assert_logical(use_bias)
  assert_character(act_f)
  assert_character(act_f_out)
  assert_numeric(label_stretch_factor, lower = 0, upper = 2)
  assert_numeric(patience)
  assert_logical(randomize_instances)
  assert_numeric(dropout_rate, lower = 0, upper = 1)
  assert_numeric(label_noise_factor, lower = 0, upper = 1)
  assert_logical(rescale_features)
  assert_logical(overwrite)
  match.arg(mode, choices = c("nn-class", "nn-reg", "bnn-class", "cnn-class"))

  # with a single data split and no test set there is no held-out data to
  # monitor, so early stopping is disabled
  if (cv_fold == 1 && test_fraction == 0) {
    patience <- 0
  }

  if (mode == "bnn-class") {
    cat("Please add the number of MCMC generations. Should be at least 1M:\n")
    max_epochs <- as.integer(readline())
  }

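  # if a production model is provided, reuse its settings but retrain on
  # all available data (no test set, no cross-validation)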
  provided_model <- production_model

  if (inherits(provided_model, "iucnn_model")) {
    mode <- provided_model$model
    test_fraction <- 0.
    cv_fold <- 1
    seed <- provided_model$seed
    max_epochs <- round(mean(provided_model$final_training_epoch))
    patience <- 0
    n_layers <- paste(provided_model$n_layers,collapse = '_')
    use_bias <- provided_model$use_bias
    balance_classes <- provided_model$balance_classes
    act_f <- provided_model$act_f
    act_f_out <- provided_model$act_f_out
    label_stretch_factor <- provided_model$label_stretch_factor
    randomize_instances <- provided_model$randomize_instances
    dropout_rate <- provided_model$dropout_rate
    mc_dropout <- provided_model$mc_dropout
    mc_dropout_reps <- provided_model$mc_dropout_reps
    label_noise_factor <- provided_model$label_noise_factor
    rescale_features <- provided_model$rescale_features
    # save accthres_tbl to output, since this will be needed to predict
    accthres_tbl_stored <- provided_model$accthres_tbl
    no_validation <- TRUE
  }else{
    accthres_tbl_stored <- NaN
    no_validation <- FALSE
  }

  # check if the model directory already exists
  if (dir.exists(file.path(path_to_output)) && !overwrite) {
    stop(sprintf("Directory %s exists. Provide alternative 'path_to_output' or set `overwrite` to TRUE.",
                 path_to_output))
  }

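  # match features and labels by species; taxa with NA features or present
  # in only one of the two inputs are dropped (see description above)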
  data_out <- process_iucnn_input(x,
                                 lab = lab,
                                 mode = mode,
                                 outpath = '.',
                                 write_data_files = FALSE,
                                 verbose = verbose)

  dataset <- data_out[[1]]
  labels <- data_out[[2]]
  instance_names <- data_out[[3]]

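  # parse the underscore-separated layer specification,
  # e.g. '50_30_10' -> c(50, 30, 10)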
  n_layers <- as.numeric(strsplit(n_layers,'_')[[1]])

  # set act fun if chosen auto
  if (act_f_out == 'auto') {
    if (mode == 'nn-reg') {
      act_f_out  <- 'tanh'
    }else{
      act_f_out  <- 'softmax'
    }
  }
  if (act_f == 'auto') {
    if (mode == 'bnn-class') {
      act_f  <- 'swish'
    }else{
      act_f  <- 'relu'
    }
  }

  if (mode == 'bnn-class') {
    warning('
The following settings are currently not supported for BNN models and are being ignored:
cv_fold, patience, act_f_out.
Instead of applying the chosen settings for dropout_rate, mc_dropout, and mc_dropout_reps,
the BNN will provide posterior estimates of the class labels for each instance.\n')
    if (max_epochs < 1000000) {
      message('For proper convergence of bnn-class models it is recommended to set max_epochs=1000000 or more.')
      overwrite_prompt <- readline(prompt = 'Do you want to continue with the current max_epochs setting (not recommended)? [Y/n]: ')
      # default [Y/n] convention: empty input or y/Y continues
      if (!tolower(overwrite_prompt) %in% c('y', '')) {
        stop('Training aborted. Increase max_epochs to >= 1000000 before relaunching.')
      }
    }

    # transform the data into BNN compatible format
    bnn_data <- bnn_load_data(dataset,
                             labels,
                             seed = as.integer(seed),
                             testsize = test_fraction,
                             all_class_in_testset = FALSE,
                             randomize_order = randomize_instances,
                             header = TRUE, # input data has a header
                             # input data includes names of instances
                             instance_id = TRUE,
                             from_file = FALSE
    )

    # define number of layers and nodes per layer for BNN
    # define the BNN model
    if (use_bias) {
      bias_node_setting <- 3
    }else{
      bias_node_setting <- 0
    }
    bnn_model <- create_BNN_model(bnn_data,
                                 n_layers,
                                 seed = as.integer(seed),
                                 use_class_weight = balance_classes,
                                 use_bias_node = bias_node_setting,
                                 actfun = act_f
    )

    # set up the MCMC environment
    update_frequencies <- rep(0.05, length(n_layers) + 1)
    update_window_sizes <- rep(0.075, length(n_layers) + 1)
    adapt_f <- 0.3
    adapt_fM <- 0.6
    sampling_f <- 10
    n_post_samples <- 1000
    print_f <- 100
    mcmc_object <- MCMC_setup(bnn_model,
                             update_frequencies,
                             update_window_sizes,
                             adapt_f,
                             adapt_fM,
                             n_iteration = as.integer(max_epochs),
                             n_post_samples = n_post_samples,
                             print_f = print_f,
                             sampling_f = sampling_f
    )

    # run the MCMC and write output to file
    logger <- run_MCMC(bnn_model,
                      mcmc_object,
                      filename_stem = paste0(path_to_output,'/','BNN')
    )

    # calculate test accuracy
    post_pr_test <- calculate_accuracy(bnn_data,
                                       logger,
                                       bnn_model,
                                       post_summary_mode = 0
    )

    input_data <- bnn_data

    logfile_path <- as.character(py_get_attr(logger, '_logfile'))
    log_file_content <- read.table(logfile_path, sep = '\t', header = TRUE)
    pklfile_path <- as.character(py_get_attr(logger, '_pklfile'))

    test_labels <- bnn_data$test_labels
    test_predictions <- apply(post_pr_test$post_prob_predictions,
                              1,
                              which.max) - 1
    test_predictions_raw <- post_pr_test$post_prob_predictions
    # instance names of the test set are not tracked by the BNN backend
    test_instance_names <- NaN

    confusion_matrix <- post_pr_test$confusion_matrix
    # remove the marginal sum row and column
    confusion_matrix <- confusion_matrix[seq_len(nrow(confusion_matrix) - 1),
                                         seq_len(ncol(confusion_matrix) - 1)]

    training_accuracy <- log_file_content$accuracy[length(log_file_content$accuracy)]
    test_accuracy <- post_pr_test$mean_accuracy
    validation_accuracy <- NaN

    training_loss <-
      (-log_file_content$likelihood[length(log_file_content$likelihood)]) /
      length(bnn_data$labels)
    test_loss <- NaN
    validation_loss <- NaN

    training_loss_history <- list(
      (-log_file_content$likelihood) / length(bnn_data$labels))
    names(training_loss_history) <- 'train_rep_0'
    validation_loss_history <- NaN

    training_accuracy_history <- list(log_file_content$accuracy)
    names(training_accuracy_history) <- 'train_rep_0'
    validation_accuracy_history <- list(log_file_content$test_accuracy)
    names(validation_accuracy_history) <- 'train_rep_0'

    training_mae_history <- NaN
    validation_mae_history <- NaN

    rescale_labels_boolean <- FALSE
    label_rescaling_factor <- as.integer(max(labels$labels))
    min_max_label <- as.vector(c(min(labels$labels), max(labels$labels)))

    trained_model_path <- pklfile_path
    patience <- 0

    # source python function
    reticulate::source_python(system.file("python",
                                          "IUCNN_helper_functions.py",
                                          package = "IUCNN"))
    acctbl_catsample <- get_acctbl_and_catsample_bnn(pklfile_path)

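    # the MCMC always runs for the full number of iterations (no early stopping)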
    stopping_point <- max_epochs

    accthres_tbl <- acctbl_catsample[[1]]
    sampled_cat_freqs <- acctbl_catsample[[2]]
    true_cat_freqs <- acctbl_catsample[[3]]

    mc_dropout <- FALSE
    mc_dropout_reps <- 0
    dropout_rate <- 0
    cv_fold <- 1


  }else{

    # source python function
    reticulate::source_python(system.file("python",
                                          "IUCNN_train.py",
                                          package = "IUCNN"))

    # run model via python script
    res <- iucnn_train(dataset = as.matrix(dataset),
                      labels = as.matrix(labels),
                      mode = mode,
                      path_to_output = path_to_output,
                      test_fraction = test_fraction,
                      cv_k = as.integer(cv_fold),
                      seed = as.integer(seed),
                      instance_names = as.matrix(instance_names),
                      feature_names = names(dataset),
                      verbose = 0,
                      max_epochs = as.integer(max_epochs),
                      patience = patience,
                      n_layers = as.list(n_layers),
                      use_bias = use_bias,
                      balance_classes = balance_classes,
                      act_f = act_f,
                      act_f_out = act_f_out,
                      stretch_factor_rescaled_labels = label_stretch_factor,
                      randomize_instances = as.integer(randomize_instances),
                      rescale_features = rescale_features,
                      dropout_rate = dropout_rate,
                      dropout_reps = mc_dropout_reps,
                      mc_dropout = mc_dropout,
                      label_noise_factor = label_noise_factor,
                      no_validation = no_validation,
                      save_model = save_model
    )

    test_labels <- as.vector(res$test_labels)
    test_predictions <- as.vector(res$test_predictions)
    test_instance_names <- as.vector(res$test_instance_names)
    test_predictions_raw <- res$test_predictions_raw

    training_accuracy <- res$training_accuracy
    validation_accuracy <- res$validation_accuracy
    test_accuracy <- res$test_accuracy

    training_loss <- res$training_loss
    validation_loss <- res$validation_loss
    test_loss <- res$test_loss

    training_loss_history <- res$training_loss_history
    validation_loss_history <- res$validation_loss_history

    training_accuracy_history <- res$training_accuracy_history
    validation_accuracy_history <- res$validation_accuracy_history

    training_mae_history <- res$training_mae_history
    validation_mae_history <- res$validation_mae_history

    rescale_labels_boolean <- res$rescale_labels_boolean
    label_rescaling_factor <- res$label_rescaling_factor
    min_max_label <- as.vector(res$min_max_label)
    label_stretch_factor <- res$label_stretch_factor

    act_f_out <- res$activation_function
    trained_model_path <- res$trained_model_path

    confusion_matrix <- res$confusion_matrix
    accthres_tbl <- res$accthres_tbl
    stopping_point <- res$stopping_point

    input_data <- res$input_data
    sampled_cat_freqs <- res$predicted_class_count
    true_cat_freqs <- res$true_class_count


    }

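  # assemble the iucnn_model output object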
  named_res <- list()

  named_res$input_data <- c(input_data, lookup = data.frame(lab$lookup))

  named_res$rescale_labels_boolean <- rescale_labels_boolean
  named_res$label_rescaling_factor <- label_rescaling_factor
  named_res$min_max_label_rescaled <- min_max_label
  named_res$label_stretch_factor <- label_stretch_factor

  named_res$trained_model_path <- trained_model_path

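  # for production models no new accuracy-threshold table is computed,
  # so reuse the one stored in the provided model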
  if (is.nan(accthres_tbl[1])) {
    accthres_tbl <- accthres_tbl_stored
  }

  named_res$accthres_tbl <- accthres_tbl
  named_res$final_training_epoch <- stopping_point
  named_res$sampled_cat_freqs <- sampled_cat_freqs
  named_res$true_cat_freqs <- true_cat_freqs

  named_res$model <- mode
  named_res$seed <- seed
  named_res$dropout_rate <- dropout_rate
  named_res$max_epochs <- max_epochs
  named_res$n_layers <- n_layers
  named_res$use_bias <- use_bias
  named_res$balance_classes <- balance_classes
  named_res$rescale_features <- rescale_features
  named_res$act_f <- act_f
  named_res$act_f_out <- act_f_out
  named_res$test_fraction <- test_fraction
  named_res$cv_fold <- cv_fold
  named_res$patience <- patience
  named_res$randomize_instances <- randomize_instances
  named_res$label_noise_factor <- label_noise_factor
  named_res$mc_dropout <- mc_dropout
  named_res$mc_dropout_reps <- mc_dropout_reps

  named_res$training_loss_history <- training_loss_history
  named_res$validation_loss_history <- validation_loss_history

  named_res$training_accuracy_history <- training_accuracy_history
  named_res$validation_accuracy_history <- validation_accuracy_history

  named_res$training_mae_history <- training_mae_history
  named_res$validation_mae_history <- validation_mae_history

  named_res$training_loss <- training_loss
  named_res$validation_loss <- validation_loss
  named_res$test_loss <- test_loss
  #softmax probs, posterior probs, or regressed values
  named_res$test_predictions_raw <- test_predictions_raw
  named_res$test_predictions <- test_predictions
  named_res$test_labels <- test_labels
  named_res$test_instance_names <- test_instance_names

  named_res$confusion_matrix <- confusion_matrix

  named_res$training_accuracy <- training_accuracy
  named_res$validation_accuracy <- validation_accuracy
  named_res$test_accuracy <- test_accuracy

  class(named_res) <- "iucnn_model"

  if (mode == "bnn-class") {
    warning("Remember to check MCMC convergence in the log file")
  }

  return(named_res)
}