# R/approach_vaeac.R

# SHAPR functions -------------------------------------------------------------------------------------------------
#' @rdname setup_approach
#'
#' @param vaeac.depth Positive integer (default is `3`). The number of hidden layers
#' in the neural networks of the masked encoder, full encoder, and decoder.
#' @param vaeac.width Positive integer (default is `32`). The number of neurons in each
#' hidden layer in the neural networks of the masked encoder, full encoder, and decoder.
#' @param vaeac.latent_dim Positive integer (default is `8`). The number of dimensions in the latent space.
#' @param vaeac.lr Positive numeric (default is `0.001`). The learning rate used in the [torch::optim_adam()] optimizer.
#' @param vaeac.activation_function A [torch::nn_module()] representing an activation function such as, e.g.,
#' [torch::nn_relu()] (default), [torch::nn_leaky_relu()], [torch::nn_selu()], or [torch::nn_sigmoid()].
#' @param vaeac.n_vaeacs_initialize Positive integer (default is `4`). The number of different vaeac models to
#'  initialize at the start. After `vaeac.extra_parameters$epochs_initiation_phase` epochs (default is `2`),
#'  the best performing model is kept and trained further.
#' @param vaeac.epochs Positive integer (default is `100`). The number of epochs to train the final vaeac model.
#' This includes `vaeac.extra_parameters$epochs_initiation_phase`, where the default is `2`.
#' @param vaeac.extra_parameters Named list with extra parameters to the `vaeac` approach. See
#'  [shapr::vaeac_get_extra_para_default()] for description of possible additional parameters and their default values.
#'
#' @inheritParams default_doc_explain
#'
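#' @examples
#' \dontrun{
#' # A minimal, hypothetical sketch: `model`, `x_explain`, `x_train`, and `p0` are
#' # placeholders for the user's own fitted model, data, and mean prediction.
#' explanation <- explain(
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "vaeac",
#'   prediction_zero = p0,
#'   vaeac.epochs = 10,
#'   vaeac.n_vaeacs_initialize = 2,
#'   vaeac.extra_parameters = list(vaeac.epochs_initiation_phase = 2)
#' )
#' }
#'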
#' @export
#' @author Lars Henry Berge Olsen
setup_approach.vaeac <- function(internal, # add default values for vaeac here.
                                 vaeac.depth = 3,
                                 vaeac.width = 32,
                                 vaeac.latent_dim = 8,
                                 vaeac.activation_function = torch::nn_relu,
                                 vaeac.lr = 0.001,
                                 vaeac.n_vaeacs_initialize = 4,
                                 vaeac.epochs = 100,
                                 vaeac.extra_parameters = list(),
                                 ...) {
  # Check that torch is installed
  if (!requireNamespace("torch", quietly = TRUE)) {
    stop("`torch` is not installed. Please run `install.packages('torch')`.")
  }
  if (!torch::torch_is_installed()) stop("`torch` is not properly installed. Please run `torch::install_torch()`.")

  # Extract the objects we will use later
  S <- internal$objects$S
  X <- internal$objects$X
  parameters <- internal$parameters

  # Small printout to user
  if (parameters$verbose == 2) message("Setting up the `vaeac` approach.")

  # Check if we are doing a combination of approaches
  combined_approaches <- length(parameters$approach) > 1

  # Ensure that `parameters$vaeac.extra_parameters` is a named list
  if (is.null(parameters$vaeac.extra_parameters)) parameters$vaeac.extra_parameters <- list()
  if (!is.list(parameters$vaeac.extra_parameters)) stop("`vaeac.extra_parameters` must be a list.")
  if (length(parameters$vaeac.extra_parameters) > 0) vaeac_check_extra_named_list(parameters$vaeac.extra_parameters)

  # Ensure that all vaeac parameters are in their right location
  parameters <- vaeac_update_para_locations(parameters = parameters)

  # Extract the default values defined for the vaeac parameters in this function
  vaeac_main_para_names <- methods::formalArgs(setup_approach.vaeac)
  vaeac_main_para_names <- vaeac_main_para_names[!vaeac_main_para_names %in% c("internal", "...")]
  vaeac_main_para <- mget(vaeac_main_para_names)

  # Add the default extra parameter values for the non-user specified extra parameters
  parameters$vaeac.extra_parameters <- utils::modifyList(vaeac_get_extra_para_default(),
    parameters$vaeac.extra_parameters,
    keep.null = TRUE
  )
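  # (Illustration: utils::modifyList(list(a = 1, b = 2), list(b = 3)) gives list(a = 1, b = 3), i.e.,
  # user-specified values override the defaults while unspecified entries keep their default values.)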

  # Add the default main parameter values for the non-user specified main parameters
  parameters <- utils::modifyList(vaeac_main_para, parameters, keep.null = TRUE)

  # Reorder them such that the vaeac parameters are at the end of the parameters list
  parameters <- c(parameters[(length(vaeac_main_para) + 1):length(parameters)], parameters[seq_along(vaeac_main_para)])

  # Check if vaeac is to be applied on a subset of coalitions.
  if (!parameters$exact || parameters$is_groupwise || combined_approaches) {
    # We have either:
    # 1) sampled `n_combinations` different subsets of coalitions (i.e., not exact),
    # 2) used the coalitions which respect the groups in group Shapley values, and/or
    # 3) used a combination of approaches where vaeac is only applied to a subset of the coalitions.
    # Here, objects$S contains the coalitions while objects$X contains the information about the approach.

    # Extract the coalitions / masks which are estimated using vaeac as a matrix
    parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions <-
      S[X[approach == "vaeac"]$id_combination, , drop = FALSE]

    # Extract the weights for the corresponding coalitions / masks.
    parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <-
      X$shapley_weight[X[approach == "vaeac"]$id_combination]

    # Normalize the weights/probabilities such that they sum to one.
    parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <-
      parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob /
        sum(parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob)
  } else {
    # We are going to use the MCAR(`masking_ratio`) masking scheme. Set the variables to `NULL` as we do not need them.
    parameters$vaeac.mask_gen_coalitions <- parameters$vaeac.mask_gen_coalitions_prob <- NULL
  }

  # Check if user provided a pre-trained vaeac model, otherwise, we train one from scratch.
  if (is.null(parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model)) {
    # We train a vaeac model with the parameters in `parameters`, as user did not provide pre-trained vaeac model
    if (parameters$verbose == 2) {
      message(paste0(
        "Training the `vaeac` model with the provided parameters from scratch on ",
        ifelse(parameters$vaeac.extra_parameters$vaeac.cuda, "GPU", "CPU"), "."
      ))
    }

    # Specify that a vaeac model was NOT provided
    parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model_provided <- FALSE

    # Extract all vaeac parameters and remove the "vaeac." prefix as the names need to match the parameters in "do.call"
    vaeac_all_parameters <- c(
      parameters$vaeac.extra_parameters,
      parameters[vaeac_main_para_names[vaeac_main_para_names != "vaeac.extra_parameters"]]
    )
    names(vaeac_all_parameters) <- sub("vaeac\\.", "", names(vaeac_all_parameters))
    vaeac_all_parameters <- c(vaeac_all_parameters, parameters[c("seed", "verbose")]) # Add seed and verbose

    # Fit/train the vaeac model with the provided model parameters
    vaeac_model <- do.call(vaeac_train_model, c(vaeac_all_parameters, list(x_train = internal$data$x_train)))

    # Add this to the explainer object
    parameters$vaeac <- list(
      models = vaeac_model[1:(grep("train_vlb", names(vaeac_model)) - 1)], # Models are all entries before `train_vlb`
      results = vaeac_model[c("train_vlb", "val_iwae", "val_iwae_running")], # The train & val results
      parameters = vaeac_model$parameters # List of all the parameters used to train the vaeac model
    )

    # Add `vaeac` as a class to the object. We use this to validate the input when
    # `vaeac.pretrained_vaeac_model` is given to the `shapr::explain()` function.
    class(parameters$vaeac) <- c(class(parameters$vaeac), "vaeac")
  } else {
    # User provided a pre-trained vaeac model. (Minimal checking for valid vaeac model is conducted.)
    # The pre-trained vaeac model is either:
    # 1. The explanation$internal$parameters$vaeac list of type "vaeac" from an earlier call to explain().
    # 2. A string containing the path to where the "vaeac" model is stored on disk.
    if (parameters$verbose == 2) message("Loading the provided `vaeac` model.")

    # Boolean representing that a pre-trained vaeac model was provided
    parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model_provided <- TRUE

    # Check some aspects of the pre-trained vaeac model and add it to the parameters list if it passes the checks
    parameters <- vaeac_update_pretrained_model(parameters = parameters)

    # Small printout informing about the location of the model
    if (parameters$verbose == 2) {
      message(paste0(
        "The `vaeac` model runs/is trained on ", ifelse(parameters$vaeac$parameters$cuda, "GPU", "CPU"), "."
      ))
    }
  }

  # Get which vaeac model we are to use, load it and then store the checkpoint
  checkpoint <- torch::torch_load(parameters$vaeac$models[[parameters$vaeac.extra_parameters$vaeac.which_vaeac_model]])
  parameters$vaeac.checkpoint <- checkpoint

  # Set up and store the vaeac model such that it is loaded before calling the `prepare_data.vaeac()` function.
  parameters$vaeac.model <-
    vaeac_get_model_from_checkp(checkpoint = checkpoint, cuda = checkpoint$cuda, mode_train = FALSE)

  # Extract and save the sampling method. That is, whether we sample randomly from the inferred generative
  # distributions or take the most likely values (the mean for continuous features and the class with the highest
  # probability for categorical features).
  parameters$vaeac.sampler <- if (parameters$vaeac.extra_parameters$vaeac.sample_random) {
    parameters$vaeac.model$sampler_random
  } else {
    parameters$vaeac.model$sampler_most_likely
  }

  # Update/overwrite the parameters list in the internal list.
  internal$parameters <- parameters

  # Small printout to user
  if (parameters$verbose == 2) message("Done with setting up the `vaeac` approach.\n")

  # Return the updated internal list.
  return(internal)
}

#' @inheritParams default_doc
#'
#' @rdname prepare_data
#' @export
#' @author Lars Henry Berge Olsen
prepare_data.vaeac <- function(internal, index_features = NULL, ...) {
  # If not provided, then set `index_features` to all non-trivial coalitions
  if (is.null(index_features)) index_features <- seq(2, internal$parameters$n_combinations - 1)

  # Extract objects we are going to need later
  S <- internal$objects$S
  seed <- internal$parameters$seed
  verbose <- internal$parameters$verbose
  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain
  n_samples <- internal$parameters$n_samples
  vaeac.model <- internal$parameters$vaeac.model
  vaeac.sampler <- internal$parameters$vaeac.sampler
  vaeac.checkpoint <- internal$parameters$vaeac.checkpoint
  vaeac.batch_size_sampling <- internal$parameters$vaeac.extra_parameters$vaeac.batch_size_sampling

  # Small printout to the user about which batch we are working on
  if (verbose == 2) vaeac_prep_message_batch(internal = internal, index_features = index_features)

  # Apply all coalitions to all explicands to get a data table where `vaeac` will impute the `NaN` values
  x_explain_extended <- vaeac_get_x_explain_extended(x_explain = x_explain, S = S, index_features = index_features)

  # Set the number of observations to generate the MC samples for at a time.
  n_explain_extended <- nrow(x_explain_extended)
  batch_size <- if (is.null(vaeac.batch_size_sampling)) n_explain_extended else vaeac.batch_size_sampling
  if (batch_size > n_explain_extended) batch_size <- n_explain_extended

  # Impute the missing entries using the vaeac approach.
  x_explain_with_MC_samples_dt <- vaeac_impute_missing_entries(
    x_explain_with_NaNs = x_explain_extended,
    n_explain = n_explain,
    n_samples = n_samples,
    vaeac_model = vaeac.model,
    checkpoint = vaeac.checkpoint,
    sampler = vaeac.sampler,
    batch_size = batch_size,
    verbose = verbose,
    seed = seed,
    index_features = index_features
  )

  # Return the generated conditional Monte Carlo samples
  return(x_explain_with_MC_samples_dt)
}
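
# Small illustration (not part of the package logic) of the masking idea behind `vaeac_get_x_explain_extended()`:
# features outside the coalition S are set to `NaN`, and `vaeac` later imputes exactly those entries. For the
# coalition S = (1, 0, 1) and the explicand x = (1.2, 0.5, -0.3), the masked row becomes
# ifelse(c(1, 0, 1) == 1, c(1.2, 0.5, -0.3), NaN), i.e., c(1.2, NaN, -0.3).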


# Train functions ======================================================================================================
#' Train the Vaeac Model
#'
#' @description Function that fits a vaeac model to the given dataset based on the provided parameters,
#' as described in \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}. Note that
#' all default parameters specified below originate from [shapr::setup_approach.vaeac()] and
#' [shapr::vaeac_get_extra_para_default()].
#'
#' @details
#' The vaeac model consists of three neural networks, i.e., a masked encoder, a full encoder, and a decoder.
#' The networks share the `depth`, `width`, and `activation_function` parameters. The encoders map `x_train`
#' to a latent representation of dimension `latent_dim`, while the decoder maps the latent representations
#' back to the feature space. See \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}
#' for more details. The function first initializes `n_vaeacs_initialize` vaeac models with different randomly
#' initialized network parameters to remedy poorly initialized values. After `epochs_initiation_phase` epochs, the
#' `n_vaeacs_initialize` vaeac models are compared and the function continues to train only the best performing
#' one for a total of `epochs` epochs. The networks are trained using the ADAM optimizer with learning rate `lr`.
#'
#' @param depth Positive integer (default is `3`). The number of hidden layers
#' in the neural networks of the masked encoder, full encoder, and decoder.
#' @param width Positive integer (default is `32`). The number of neurons in each
#' hidden layer in the neural networks of the masked encoder, full encoder, and decoder.
#' @param latent_dim Positive integer (default is `8`). The number of dimensions in the latent space.
#' @param lr Positive numeric (default is `0.001`). The learning rate used in the [torch::optim_adam()] optimizer.
#' @param activation_function A [torch::nn_module()] representing an activation function such as, e.g.,
#' [torch::nn_relu()] (default), [torch::nn_leaky_relu()], [torch::nn_selu()], or [torch::nn_sigmoid()].
#' @param n_vaeacs_initialize Positive integer (default is `4`). The number of different vaeac models to
#'  initialize at the start. After `epochs_initiation_phase` epochs (default is `2`),
#'  the best performing model is kept and trained further.
#' @param epochs Positive integer (default is `100`). The number of epochs to train the final vaeac model.
#' This includes `epochs_initiation_phase`, where the default is `2`.
#' @param x_train A data.table containing the training data. Categorical data must have class names \eqn{1,2,\dots,K}.
#' @param model_description String (default is `make.names(Sys.time())`). String containing, e.g., the name of the
#' data distribution or additional parameter information. Used in the save name of the fitted model. If not provided,
#' then a name will be generated based on [base::Sys.time()] to ensure a unique name. We use [base::make.names()] to
#' ensure a valid file name for all operating systems.
#' @param folder_to_save_model String (default is [base::tempdir()]). String specifying a path to a folder where
#' the function is to save the fitted vaeac model. Note that the path will be removed from the returned
#' [shapr::explain()] object if `vaeac.save_model = FALSE`.
#' @param cuda Logical (default is `FALSE`). If `TRUE`, then the `vaeac` model will be trained using cuda/GPU.
#' If [torch::cuda_is_available()] is `FALSE`, then we fall back to the CPU. If `FALSE`, we use the CPU. Using a GPU
#' for smaller tabular datasets often does not improve efficiency.
#' See \code{vignette("installation", package = "torch")} for help to enable running on the GPU (only Linux and Windows).
#' @param epochs_initiation_phase Positive integer (default is `2`). The number of epochs to run each of the
#' `n_vaeacs_initialize` `vaeac` models before continuing to train only the best performing model.
#' @param epochs_early_stopping Positive integer (default is `NULL`). The training stops if there has been no
#' improvement in the validation IWAE for `epochs_early_stopping` epochs. If the user wants the training process
#' to be solely based on this training criterion, then `epochs` in [shapr::explain()] should be set to a large
#' number. If `NULL`, then `shapr` will internally set `epochs_early_stopping = vaeac.epochs` such that early
#' stopping does not occur.
#' @param save_every_nth_epoch Positive integer (default is `NULL`). If provided, then the vaeac model after
#' every `save_every_nth_epoch`th epoch will be saved.
#' @param val_ratio Numeric (default is `0.25`). Scalar between `0` and `1` indicating the ratio of
#' instances from the input data which will be used as validation data. That is, `val_ratio = 0.25` means
#' that `75%` of the provided data is used as training data, while the remaining `25%` is used as validation data.
#' @param val_iwae_n_samples Positive integer (default is `25`). The number of generated samples used
#' to compute the IWAE criterion when validating the vaeac model on the validation data.
#' @param batch_size Positive integer (default is `64`). The number of samples to include in each batch
#' during the training of the vaeac model. Used in [torch::dataloader()].
#' @param skip_conn_layer Logical (default is `TRUE`). If `TRUE`, we apply identity skip connections in each
#' layer, see [shapr::skip_connection()]. That is, we add the input \eqn{X} to the outcome of each hidden layer,
#' so the output becomes \eqn{X + activation(WX + b)}.
#' @param skip_conn_masked_enc_dec Logical (default is `TRUE`). If `TRUE`, we apply concatenate skip
#' connections between the layers in the masked encoder and decoder. The first layer of the masked encoder will be
#' linked to the last layer of the decoder. The second layer of the masked encoder will be
#' linked to the second to last layer of the decoder, and so on.
#' @param batch_normalization Logical (default is `FALSE`). If `TRUE`, we apply batch normalization after the
#' activation function. Note that if `skip_conn_layer = TRUE`, then the normalization is applied after the
#' inclusion of the skip connection. That is, we batch normalize the whole quantity \eqn{X + activation(WX + b)}.
#' @param paired_sampling Logical (default is `TRUE`). If `TRUE`, we apply paired sampling to the training
#' batches. That is, the training observations in each batch will be duplicated, where the first instance will be masked
#' by \eqn{S} while the second instance will be masked by \eqn{\bar{S}}. This ensures that the training of the
#' `vaeac` model becomes more stable as the model has access to the full version of each training observation. However,
#' this will increase the training time due to more complex implementation and doubling the size of each batch. See
#' [shapr::paired_sampler()] for more information.
#' @param running_avg_n_values Positive integer (default is `5`). The number of previous IWAE values to include
#' when we compute the running means of the IWAE criterion.
#' @param masking_ratio Numeric (default is `0.5`). Probability of masking a feature in the
#' [shapr::mcar_mask_generator()] (MCAR = Missing Completely At Random). The MCAR masking scheme ensures that the
#' `vaeac` model can do arbitrary conditioning as all coalitions will be trained. `masking_ratio` will be overruled if
#' `mask_gen_coalitions` is specified.
#' @param mask_gen_coalitions Matrix (default is `NULL`). Matrix containing the coalitions that the
#' `vaeac` model will be trained on, see [shapr::specified_masks_mask_generator()]. This parameter is used internally
#' in `shapr` when we only consider a subset of coalitions/combinations, i.e., when
#' `n_combinations` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
#' when `group` is specified in [shapr::explain()].
#' @param mask_gen_coalitions_prob Numeric array (default is `NULL`). Array of length equal to the height
#' of `mask_gen_coalitions` containing the probabilities of sampling the corresponding coalitions in
#' `mask_gen_coalitions`.
#' @param sigma_mu Numeric (default is `1e4`). One of two hyperparameter values in the normal-gamma prior
#' used in the masked encoder, see Section 3.3.1 in
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#' @param sigma_sigma Numeric (default is `1e-4`). One of two hyperparameter values in the normal-gamma prior
#' used in the masked encoder, see Section 3.3.1 in
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#' @param save_data Logical (default is `FALSE`). If `TRUE`, then the data is stored together with
#' the model. Useful if one wants to continue training the model later using [shapr::vaeac_train_model_continue()].
#' @param log_exp_cont_feat Logical (default is `FALSE`). If `TRUE`, we \eqn{\log}-transform all
#' continuous features before sending the data to [shapr::vaeac()]. The `vaeac` model creates unbounded Monte Carlo
#' sample values. Thus, if the continuous features are strictly positive (as for, e.g., the Burr distribution and
#' Abalone data set), it can be advantageous to \eqn{\log} transform the data to unbounded form before using `vaeac`.
#' If `TRUE`, then [shapr::vaeac_postprocess_data()] will take the \eqn{\exp} of the results to get back to strictly
#' positive values when using the `vaeac` model to impute missing values/generate the Monte Carlo samples.
#' @param verbose Integer specifying the level of verbosity. Use `0` (default) for no verbosity,
#' `1` for low verbosity, and `2` for high verbosity.
#' @param seed Positive integer (default is `1`). Seed for reproducibility. The seed is set before any
#' randomness-based code is run.
#' @param which_vaeac_model String (default is `"best"`). The name of the `vaeac` model (snapshots from different
#' epochs) to use when generating the Monte Carlo samples. The standard choices are: `"best"` (epoch with the highest
#' validation IWAE), `"best_running"` (epoch with the highest running validation IWAE, see
#' `vaeac.running_avg_n_values`), and `"last"` (the last epoch).
#' Note that additional choices are available if `vaeac.save_every_nth_epoch` is provided. For example, if
#' `vaeac.save_every_nth_epoch = 5`, then `vaeac.which_vaeac_model` can also take the values `"epoch_5"`, `"epoch_10"`,
#' `"epoch_15"`, and so on.
#' @param ... List of extra parameters, currently not used.
#'
#' @return A list containing the training/validation errors and paths to where the vaeac models are saved on the disk.
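#'
#' @examples
#' \dontrun{
#' # Hypothetical sketch (not run): train a vaeac model directly on a data.table `dt_train`.
#' # In normal use, [shapr::explain()] calls this function for you; here we fill in the extra
#' # parameters from their defaults in [shapr::vaeac_get_extra_para_default()] after
#' # stripping the "vaeac." prefix, mirroring what setup_approach.vaeac() does internally.
#' extra_para <- vaeac_get_extra_para_default()
#' names(extra_para) <- sub("^vaeac\\.", "", names(extra_para))
#' vaeac_fit <- do.call(vaeac_train_model, c(
#'   list(
#'     x_train = dt_train, depth = 3, width = 32, latent_dim = 8, lr = 0.001,
#'     activation_function = torch::nn_relu, n_vaeacs_initialize = 2, epochs = 5,
#'     verbose = 0, seed = 1
#'   ),
#'   extra_para
#' ))
#' }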
#' @export
#' @author Lars Henry Berge Olsen
vaeac_train_model <- function(x_train,
                              model_description,
                              folder_to_save_model,
                              cuda,
                              n_vaeacs_initialize,
                              epochs_initiation_phase,
                              epochs,
                              epochs_early_stopping,
                              save_every_nth_epoch,
                              val_ratio,
                              val_iwae_n_samples,
                              depth,
                              width,
                              latent_dim,
                              lr,
                              batch_size,
                              running_avg_n_values,
                              activation_function,
                              skip_conn_layer,
                              skip_conn_masked_enc_dec,
                              batch_normalization,
                              paired_sampling,
                              masking_ratio,
                              mask_gen_coalitions,
                              mask_gen_coalitions_prob,
                              sigma_mu,
                              sigma_sigma,
                              save_data,
                              log_exp_cont_feat,
                              which_vaeac_model,
                              verbose,
                              seed,
                              ...) {
  # Set seed for reproducibility for both R and torch
  set.seed(seed)
  torch::torch_manual_seed(seed)

  # Set epochs_early_stopping to epochs to ensure that early stopping never occurs
  if (is.null(epochs_early_stopping)) epochs_early_stopping <- epochs

  # Check all the vaeac parameters
  do.call(vaeac_check_parameters, mget(methods::formalArgs(vaeac_train_model)))

  # Check if we can use cuda
  if (cuda) cuda <- vaeac_check_cuda(cuda)

  # Determine which mask generator to use
  mask_generator_name <- vaeac_get_mask_generator_name(
    mask_gen_coalitions = mask_gen_coalitions,
    mask_gen_coalitions_prob = mask_gen_coalitions_prob,
    masking_ratio = masking_ratio,
    verbose = verbose
  )

  # Set up the data loaders and get the save file names and load them into the local environment
  list2env(
    vaeac_get_data_objects(
      x_train = x_train,
      log_exp_cont_feat = log_exp_cont_feat,
      val_ratio = val_ratio,
      batch_size = batch_size,
      paired_sampling = paired_sampling,
      model_description = model_description,
      depth = depth,
      width = width,
      latent_dim = latent_dim,
      lr = lr,
      epochs = epochs,
      save_every_nth_epoch = save_every_nth_epoch,
      folder_to_save_model = folder_to_save_model,
      train_indices = NULL,
      val_indices = NULL
    ),
    envir = environment()
  )

  # Get information saved together with the vaeac model to make it possible to load the model from disk later.
  # Note that some of the parameters could be derived from others, but for simplicity we store all needed objects.
  state_list <- vaeac_get_full_state_list(environment())

  # Check if we are to add the training data to the state list
  if (save_data) state_list <- c(state_list, list(x_train = x_train, x_train_torch = x_train_torch))

  ## Initializing vaeac models
  # Initialize several vaeac models and keep the one with the best training variational lower bound
  # after a given number of epochs. Keep the version with highest vlb, denoted by "best_vlb".
  best_vlb <- -Inf

  # Create a `progressr::progressor()` to keep track of the overall training time of the vaeac approach
  if (requireNamespace("progressr", quietly = TRUE)) {
    progressr_bar <- progressr::progressor(steps = epochs_initiation_phase * (n_vaeacs_initialize - 1) + epochs)
  } else {
    progressr_bar <- NULL
  }

  # Iterate over the initializations.
  initialization_idx <- 1
  for (initialization_idx in seq(n_vaeacs_initialize)) {
    # Initialize a new vaeac model
    vaeac_model <- vaeac(
      one_hot_max_sizes = one_hot_max_sizes,
      width = width,
      depth = depth,
      latent_dim = latent_dim,
      activation_function = activation_function,
      skip_conn_layer = skip_conn_layer,
      skip_conn_masked_enc_dec = skip_conn_masked_enc_dec,
      batch_normalization = batch_normalization,
      paired_sampling = paired_sampling,
      mask_generator_name = mask_generator_name,
      masking_ratio = masking_ratio,
      mask_gen_coalitions = mask_gen_coalitions,
      mask_gen_coalitions_prob = mask_gen_coalitions_prob,
      sigma_mu = sigma_mu,
      sigma_sigma = sigma_sigma
    )

    # Send the model to the GPU, if we have access to it and user wants to
    if (cuda) vaeac_model$cuda()

    # Add the number of trainable parameters in the vaeac model to the state list
    if (initialization_idx == 1) {
      state_list$n_trainable_parameters <- vaeac_model$n_train_param
      if (verbose == 2) {
        message(paste0("The vaeac model contains ", vaeac_model$n_train_param[1, 1], " trainable parameters."))
      }
    }

    # Print which initialization vaeac the function is working on
    if (verbose == 2) {
      message(paste0("Initializing vaeac number ", initialization_idx, " of ", n_vaeacs_initialize, "."))
    }

    # Create the ADAM optimizer
    optimizer <- vaeac_get_optimizer(vaeac_model = vaeac_model, lr = lr, optimizer_name = "adam")

    # Train the current initialized vaeac model
    vaeac_model_now_list <- vaeac_train_model_auxiliary(
      vaeac_model = vaeac_model,
      optimizer = optimizer,
      epochs = epochs_initiation_phase,
      epochs_start = 1, # All the vaeacs should start from scratch
      train_dataloader = train_dataloader,
      val_dataloader = val_dataloader,
      val_iwae_n_samples = val_iwae_n_samples,
      running_avg_n_values = running_avg_n_values,
      epochs_early_stopping = FALSE, # Do not want to do early stopping during initialization
      verbose = verbose,
      cuda = cuda,
      progressr_bar = progressr_bar,
      save_every_nth_epoch = save_every_nth_epoch,
      initialization_idx = initialization_idx,
      n_vaeacs_initialize = n_vaeacs_initialize,
      train_vlb = NULL, # We start from scratch
      val_iwae = NULL, # We start from scratch
      val_iwae_running = NULL # We start from scratch
    )

    # If the new initialization has a higher training VLB than the previous initializations, then we keep it.
    if ((best_vlb <= vaeac_model_now_list$avg_vlb)$item()) {
      vaeac_model_best_list <- vaeac_model_now_list
      best_vlb <- vaeac_model_now_list$avg_vlb # Keep track of the highest VLB seen so far
    }
  } # Done with initial training of all vaeac models

  # Small printout to the user stating which initiated vaeac model was the best.
  if (verbose == 2) {
    message(paste0(
      "Best vaeac initialization was number ", vaeac_model_best_list$initialization_idx, " (of ", n_vaeacs_initialize,
      ") with a training VLB = ", round(as.numeric(vaeac_model_best_list$train_vlb[-1]$cpu()), 3),
      " after ", epochs_initiation_phase, " epochs. Continue to train this initialization."
    ))
  }

  return_list <- vaeac_train_model_auxiliary(
    vaeac_model = vaeac_model_best_list$vaeac_model,
    optimizer = vaeac_model_best_list$optimizer,
    train_dataloader = train_dataloader,
    val_dataloader = val_dataloader,
    val_iwae_n_samples = val_iwae_n_samples,
    running_avg_n_values = running_avg_n_values,
    verbose = verbose,
    cuda = cuda,
    progressr_bar = progressr_bar,
    epochs = epochs,
    epochs_start = epochs_initiation_phase + 1,
    epochs_early_stopping = epochs_early_stopping,
    save_every_nth_epoch = save_every_nth_epoch,
    vaeac_save_file_names = vaeac_save_file_names, # Provide the save names for the models
    state_list = state_list, # Need to provide the state list as it will be saved together with the models
    initialization_idx = NULL, # Do not need to specify it as we are not doing the initialization now
    n_vaeacs_initialize = NULL, # Do not need to specify it as we are not doing the initialization now
    train_vlb = vaeac_model_best_list$train_vlb, # Send in the array from the best initiated vaeac model
    val_iwae = vaeac_model_best_list$val_iwae,
    val_iwae_running = vaeac_model_best_list$val_iwae_running
  )

  # Return the paths where the models are saved and the training/validation errors.
  return(return_list)
}

#' Function used to train a `vaeac` model
#'
#' @description
#' This function can be applied both in the initialization phase, where we train several initiated `vaeac` models, and
#' to keep training the best performing `vaeac` model for the remaining number of epochs. We are in the former setting
#' when `initialization_idx` is provided and in the latter when it is `NULL`. When it is `NULL`, we save to disk the
#' `vaeac` models with the highest validation IWAE and running IWAE, the model at the last epoch, and, if
#' `save_every_nth_epoch` is given, the models at the corresponding epochs.
#'
#' @inheritParams vaeac_train_model
#' @param vaeac_model A [shapr::vaeac()] object. The `vaeac` model this function is to train.
#' @param optimizer A [torch::optimizer()] object. See [shapr::vaeac_get_optimizer()].
#' @param train_dataloader A [torch::dataloader()] containing the training data for the `vaeac` model.
#' @param val_dataloader A [torch::dataloader()] containing the validation data for the `vaeac` model.
#' @param train_vlb A [torch::torch_tensor()] (default is `NULL`)
#' of one dimension containing previous values for the training VLB.
#' @param val_iwae A [torch::torch_tensor()] (default is `NULL`)
#' of one dimension containing previous values for the validation IWAE.
#' @param val_iwae_running A [torch::torch_tensor()] (default is `NULL`)
#' of one dimension containing previous values for the running validation IWAE.
#' @param progressr_bar A [progressr::progressor()] object (default is `NULL`) to keep track of progress.
#' @param epochs_start Positive integer (default is `1`). The epoch at which the training starts.
#' @param vaeac_save_file_names Array of strings containing the save file names for the `vaeac` model.
#' @param state_list Named list containing the objects returned from [shapr::vaeac_get_full_state_list()].
#' @param initialization_idx Positive integer (default is `NULL`). The index
#' of the current `vaeac` model in the initialization phase.
#'
#' @return If we are in the initialization phase, the (partially) trained `vaeac` model together with the optimizer
#' and the evaluation arrays. Otherwise, a list of where the `vaeac` models are stored on disk and the parameters
#' of the model.
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_train_model_auxiliary <- function(vaeac_model,
                                        optimizer,
                                        train_dataloader,
                                        val_dataloader,
                                        val_iwae_n_samples,
                                        running_avg_n_values,
                                        verbose,
                                        cuda,
                                        epochs,
                                        save_every_nth_epoch,
                                        epochs_early_stopping,
                                        epochs_start = 1,
                                        progressr_bar = NULL,
                                        vaeac_save_file_names = NULL,
                                        state_list = NULL,
                                        initialization_idx = NULL,
                                        n_vaeacs_initialize = NULL,
                                        train_vlb = NULL,
                                        val_iwae = NULL,
                                        val_iwae_running = NULL) {
  # Check for valid input
  if (xor(is.null(initialization_idx), is.null(n_vaeacs_initialize))) {
    stop("Either none or both of `initialization_idx` and `n_vaeacs_initialize` must be given.")
  }

  if (is.null(state_list) && is.null(initialization_idx)) {
    stop("`state_list` must be provided when `initialization_idx = NULL` to properly save the `vaeac` model.")
  }

  if (is.null(vaeac_save_file_names) && is.null(initialization_idx)) {
    stop(paste0(
      "`vaeac_save_file_names` must be provided when `initialization_idx = NULL` ",
      "to know where to save the vaeac model."
    ))
  }

  if (!((is.null(train_vlb) && is.null(val_iwae) && is.null(val_iwae_running)) ||
    (!is.null(train_vlb) && !is.null(val_iwae) && !is.null(val_iwae_running)))) {
    stop("Either none or all of `train_vlb`, `val_iwae`, and `val_iwae_running` must be given.")
  }

  # Variable that we change to `TRUE` if early stopping is applied
  if (!is.null(state_list)) state_list$early_stopping_applied <- FALSE

  # Variable to store the epochs of the `vaeac` at the best epoch according to IWAE and IWAE_running
  if (is.null(initialization_idx)) best_epoch <- best_epoch_running <- NULL

  # Get the batch size
  batch_size <- train_dataloader$batch_size

  # Extract the mask generator and the variational lower bound scale factor from the vaeac model object.
  mask_generator <- vaeac_model$mask_generator
  vlb_scale_factor <- vaeac_model$vlb_scale_factor

  # Start the training loop
  for (epoch in seq(from = epochs_start, to = epochs)) {
    avg_vlb <- 0 # Set average variational lower bound to 0 for this epoch
    batch_index <- 1 # Index to keep track of which batch we are working on

    # Iterate over the training data
    coro::loop(for (batch in train_dataloader) {
      # If the batch is smaller than `batch_size`, extend it with observations from the beginning of the dataset
      if (batch$shape[1] < batch_size) {
        batch <- vaeac_extend_batch(batch = batch, dataloader = train_dataloader, batch_size = batch_size)
      }

      # Generate mask and do an optimizer step over the mask and the batch
      mask <- mask_generator(batch)

      # Send the batch and mask to GPU if we have access to it and user wants to
      if (cuda) {
        batch <- batch$cuda()
        mask <- mask$cuda()
      }

      # Set all previous gradients to zero.
      optimizer$zero_grad()

      # Compute the variational lower bound for the batch given the mask
      vlb <- vaeac_model$batch_vlb(batch, mask)$mean()

      # Backpropagation: minimize the negative vlb.
      vlb_loss <- (-vlb / vlb_scale_factor)
      vlb_loss$backward()

      # Update the vaeac_model parameters by using the optimizer
      optimizer$step()

      # Update running variational lower bound average using the recursive average formula/update.
      # a + (new - a)/(i+1) = {(i+1)a + new - a}/(i+1) = { a(i) + new}/(i+1) = a *i/(i+1) + new/(i+1)
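      # (e.g., with avg = 2 over the first two batches and a new vlb of 5: 2 + (5 - 2)/3 = 3 = mean(c(2, 2, 5)))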
      avg_vlb <- avg_vlb + (vlb$to(dtype = torch::torch_float())$clone()$detach() - avg_vlb) / batch_index

      # Update the batch index.
      batch_index <- batch_index + 1
    }) # Done with one new epoch

    ## Time to evaluate the vaeac_model on the validation data, potentially save it, and check for early stopping.

    # Store the VLB
    train_vlb <- torch::torch_cat(c(train_vlb, avg_vlb), -1)

    # Compute the validation IWAE
    val_iwae_now <- vaeac_get_val_iwae(
      val_dataloader = val_dataloader,
      mask_generator = mask_generator,
      batch_size = batch_size,
      vaeac_model = vaeac_model,
      val_iwae_n_samples = val_iwae_n_samples
    )
    val_iwae <- torch::torch_cat(c(val_iwae, val_iwae_now), -1)

    # Compute the running validation IWAE
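    # (the slice below selects the last `min(length(val_iwae), running_avg_n_values)` entries of `val_iwae`)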
    val_iwae_running_now <-
      val_iwae[
        (-min(length(val_iwae), running_avg_n_values) +
          length(val_iwae) + 1):(-1 + length(val_iwae) + 1),
        drop = FALSE
      ]$mean()$view(1)
    val_iwae_running <- torch::torch_cat(c(val_iwae_running, val_iwae_running_now), -1)

    # Check if we are to save the models
    if (is.null(initialization_idx)) {
      # Save if the current vaeac model has the highest validation IWAE so far
      if ((max(val_iwae) <= val_iwae_now)$item() || is.null(best_epoch)) {
        best_epoch <- epoch
        if (verbose == 2) message("Saving `best` vaeac model at epoch ", epoch, ".")
        vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[1])
      }

      # Save if the current vaeac model has the highest running validation IWAE so far
      if ((max(val_iwae_running) <= val_iwae_running_now)$item() || is.null(best_epoch_running)) {
        best_epoch_running <- epoch
        if (verbose == 2) message("Saving `best_running` vaeac model at epoch ", epoch, ".")
        vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[2])
      }

      # Save if the current epoch is a multiple of `save_every_nth_epoch` (when provided)
      if (is.numeric(save_every_nth_epoch) && epoch %% save_every_nth_epoch == 0) {
        if (verbose == 2) message("Saving `nth_epoch` vaeac model at epoch ", epoch, ".")
        vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[3 + epoch %/% save_every_nth_epoch])
      }
    }

    # Handle the message to the progress bar based on if we are doing initialization or final training
    if (!is.null(progressr_bar)) {
      update_message <- if (!is.null(initialization_idx)) {
        paste0(
          "Training vaeac (init. ", initialization_idx, " of ", n_vaeacs_initialize, "): Epoch: ", epoch,
          " | VLB: ", vaeac_get_n_decimals(avg_vlb$item()), " | IWAE: ", vaeac_get_n_decimals(val_iwae_now$item()), " |"
        )
      } else {
        paste0(
          "Training vaeac (final model): Epoch: ", epoch, " | best epoch: ", best_epoch,
          " | VLB: ", vaeac_get_n_decimals(avg_vlb$item()), " | IWAE: ", vaeac_get_n_decimals(val_iwae_now$item()), " |"
        )
      }
      progressr_bar(message = update_message)
    }

    # Check if we are to apply early stopping, i.e., no improvement in the IWAE for `epochs_early_stopping` epochs.
    if (is.numeric(epochs_early_stopping)) {
      if (epoch - best_epoch >= epochs_early_stopping) {
        if (verbose == 2) {
          message(paste0(
            "No IWAE improvement in ", epochs_early_stopping, " epochs. Apply early stopping at epoch ",
            epoch, "."
          ))
        }
        if (!is.null(progressr_bar)) progressr_bar("Training vaeac (early stopping)", amount = epochs - epoch)
        state_list$early_stopping_applied <- TRUE # Add that we did early stopping to the state list
        state_list$epochs <- epoch # Update the number of used epochs.
        break # Stop the training loop
      }
    }
  } # Done with all epochs in training phase

  # Find out what to return
  if (!is.null(initialization_idx)) {
    # Here we return the models and the optimizer which we will train further if this was the best initialization
    return_list <- list(
      vaeac_model = vaeac_model,
      optimizer = optimizer,
      train_vlb = train_vlb,
      val_iwae = val_iwae,
      val_iwae_running = val_iwae_running,
      avg_vlb = avg_vlb,
      initialization_idx = initialization_idx,
      state_list = state_list
    )
  } else {
    # Save the vaeac model at the last epoch
    if (verbose == 2) message("Saving `last` vaeac model at epoch ", epoch, ".")
    last_state <- vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[3], return_state = TRUE)

    # Summary printout
    if (verbose == 2) vaeac_print_train_summary(best_epoch, best_epoch_running, last_state)

    # Create a return list
    return_list <- list(
      best = vaeac_save_file_names[1],
      best_running = vaeac_save_file_names[2],
      last = vaeac_save_file_names[3],
      train_vlb = as.array(train_vlb$cpu()),
      val_iwae = as.array(val_iwae$cpu()),
      val_iwae_running = as.array(val_iwae_running$cpu()),
      parameters = last_state
    )

    # Add the potentially additional save names
    if (!is.null(vaeac_save_file_names) && length(vaeac_save_file_names) > 3) {
      return_list <- append(
        return_list,
        setNames(
          as.list(vaeac_save_file_names[-(1:3)]),
          paste0("epoch_", save_every_nth_epoch * seq(length(vaeac_save_file_names) - 3))
        ),
        3
      )
    }

    # Update the class of the returned object
    attr(return_list, "class") <- c("vaeac", class(return_list))
  }
  return(return_list)
}

#' Continue to Train the vaeac Model
#'
#' @description Function that loads a previously trained vaeac model and continue the training, either
#' on new data or on the same dataset as it was trained on before. If we are given a new dataset, then
#' we assume that the new dataset has the same distribution and `one_hot_max_sizes` as the original dataset.
#'
#' @inheritParams vaeac_train_model
#' @param explanation A [shapr::explain()] object where `vaeac` was the used approach.
#' @param epochs_new Positive integer. The number of extra epochs to conduct.
#' @param lr_new Positive numeric (default is `NULL`). If provided, it overwrites the old learning rate in the ADAM optimizer.
#'
#' @return A list containing the training/validation errors and paths to where the vaeac models are saved on the disk.
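#'
#' @examples
#' \dontrun{
#' # Hypothetical sketch (not run): continue training the vaeac model from a previous
#' # explain() call for 25 extra epochs with a new, lower learning rate. Here,
#' # `explanation` and `x_train` are placeholders for the user's own objects.
#' vaeac_updated <- vaeac_train_model_continue(
#'   explanation = explanation,
#'   epochs_new = 25,
#'   lr_new = 0.0005,
#'   x_train = x_train
#' )
#' }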
#' @export
#' @author Lars Henry Berge Olsen
vaeac_train_model_continue <- function(explanation,
                                       epochs_new,
                                       lr_new = NULL,
                                       x_train = NULL,
                                       save_data = FALSE,
                                       verbose = 0,
                                       seed = 1) {
  # Check the input
  if (!"shapr" %in% class(explanation)) stop("`explanation` must be a list of class `shapr`.")
  if (!"vaeac" %in% explanation$internal$parameters$approach) stop("`vaeac` is not an approach in `explanation`.")
  if (!is.null(lr_new)) vaeac_check_positive_numerics(list(lr_new = lr_new))
  if (!is.null(x_train) && !data.table::is.data.table(x_train)) stop("`x_train` must be a `data.table` object.")
  vaeac_check_verbose(verbose)
  vaeac_check_positive_integers(list(epochs_new = epochs_new, seed = seed))
  vaeac_check_logicals(list(save_data = save_data))

  # Set seed for reproducibility
  set.seed(seed)

  # Extract the vaeac list and load the model at the last epoch or the best (default 'best' when path is provided)
  vaeac_model <- explanation$internal$parameters$vaeac
  vaeac_model_path <- if (!is.null(vaeac_model$models$last)) vaeac_model$models$last else vaeac_model$models$best
  checkpoint <- torch::torch_load(vaeac_model_path)

  # Get which device we are to continue to train the model
  device <- ifelse(checkpoint$cuda, "cuda", "cpu")

  # If we applied early stopping before and are calling this function, then we turn early stopping off
  if (isTRUE(checkpoint$early_stopping_applied)) checkpoint$epochs_early_stopping <- epochs_new

  # Check for access to a single training data set and use the data from the checkpoint if `x_train` is not provided
  if (is.null(checkpoint$x_train) && is.null(x_train)) {
    stop("The `vaeac` model did not include data (set `vaeac.save_data = TRUE` in `explain()`) and `x_train = NULL`.")
  }
  if (!is.null(checkpoint$x_train) && !is.null(x_train)) {
    message("The `vaeac` model includes data and `x_train` was provided to this function. We only use `x_train`.")
  }
  if (is.null(x_train)) x_train <- checkpoint$x_train

  # Check that the provided vaeac model is trained on a dataset with the same feature names
  vaeac_check_x_colnames(feature_names_vaeac = checkpoint$feature_list$labels, feature_names_new = names(x_train))

  # Check if we can reuse the original validation and training indices
  if (!is.null(checkpoint$x_train) || nrow(x_train) == checkpoint$n_train) {
    val_indices <- checkpoint$val_indices
    train_indices <- checkpoint$train_indices
  } else {
    val_indices <- train_indices <- NULL
  }

  # Set up the data loaders and get the save file names and load them into the local environment
  list2env(
    vaeac_get_data_objects(
      x_train = x_train,
      log_exp_cont_feat = checkpoint$log_exp_cont_feat,
      val_ratio = checkpoint$val_ratio,
      batch_size = checkpoint$batch_size,
      paired_sampling = checkpoint$paired_sampling,
      model_description = checkpoint$model_description,
      depth = checkpoint$depth,
      width = checkpoint$width,
      latent_dim = checkpoint$latent_dim,
      lr = checkpoint$lr, # Use the old one as this parameter is used in the filenames
      epochs = checkpoint$epochs + epochs_new,
      save_every_nth_epoch = checkpoint$save_every_nth_epoch,
      folder_to_save_model = checkpoint$folder_to_save_model,
      train_indices = train_indices,
      val_indices = val_indices
    ),
    envir = environment()
  )

  # List of values saved to disk together with the vaeac models below.
  state_list_new <- list(
    norm_mean = as.array(x_train_preprocessed$norm_mean),
    norm_std = as.array(x_train_preprocessed$norm_std),
    n_train = n_train,
    epochs_new = epochs_new,
    train_indices = train_indices,
    val_indices = val_indices,
    lr_new = lr_new
  )

  # If we are also to save the data to state_list.
  if (save_data) {
    state_list_new <- c(state_list_new, list(x_train = x_train, x_train_torch = x_train_torch))

    # Give a message regarding disk usage
    vaeac_check_save_parameters(
      save_data = save_data,
      epochs = epochs_new,
      save_every_nth_epoch = checkpoint$save_every_nth_epoch,
      x_train_size = format(utils::object.size(x_train), units = "auto")
    )
  }

  # Add the new state list as a list to the checkpoint
  n_times_continued_trained <- sum(grepl("state_list_new", names(checkpoint)))
  state_list_new_name <- paste("state_list_new", n_times_continued_trained + 1, sep = "_")
  state_list <- checkpoint
  state_list[[state_list_new_name]] <- state_list_new

  # Set up the vaeac model in training mode and based on the parameters stored in the checkpoint
  vaeac_model <- vaeac_get_model_from_checkp(checkpoint = checkpoint, cuda = checkpoint$cuda, mode_train = TRUE)

  # Send the loaded optimizer parameters to GPU if necessary
  if (checkpoint$cuda) {
    checkpoint$optimizer_state_dict$state <- lapply(
      checkpoint$optimizer_state_dict$state,
      function(x) lapply(x, function(y) if ("torch_tensor" %in% class(y)) y$cuda() else y)
    )
  }

  # Specify the learning rate we will use, create an ADAM optimizer, and insert the stored optimizer state.
  lr_now <- if (!is.null(lr_new)) lr_new else checkpoint$lr
  optimizer <- vaeac_get_optimizer(vaeac_model = vaeac_model, lr = lr_now, optimizer_name = "adam")
  optimizer$load_state_dict(checkpoint$optimizer_state_dict)

  # Compute the new number of epochs
  epochs_old <- checkpoint$epochs
  epochs <- epochs_old + epochs_new
  state_list$epochs <- epochs

  # Create a `progressr::progressor()` to keep track of the new training
  if (requireNamespace("progressr", quietly = TRUE)) {
    progressr_bar <- progressr::progressor(steps = epochs_new)
  } else {
    progressr_bar <- NULL
  }

  # Train the vaeac model for `epochs_new` number of epochs
  vaeac_tmp <- vaeac_train_model_auxiliary(
    vaeac_model = vaeac_model,
    optimizer = optimizer,
    train_dataloader = train_dataloader,
    val_dataloader = val_dataloader,
    val_iwae_n_samples = checkpoint$val_iwae_n_samples,
    running_avg_n_values = checkpoint$running_avg_n_values,
    verbose = verbose,
    cuda = checkpoint$cuda,
    progressr_bar = progressr_bar,
    epochs = epochs,
    epochs_start = epochs_old + 1,
    epochs_early_stopping = checkpoint$epochs_early_stopping,
    save_every_nth_epoch = checkpoint$save_every_nth_epoch,
    vaeac_save_file_names = vaeac_save_file_names, # Provide the save names for the models
    state_list = state_list, # Need to provide the state list as it will be saved together with the models
    initialization_idx = NULL, # Do not need to specify it as we are not doing the initialization now
    n_vaeacs_initialize = NULL, # Do not need to specify it as we are not doing the initialization now
    train_vlb = checkpoint$train_vlb$to(device = device), # Send to correct device such that we can append new values
    val_iwae = checkpoint$val_iwae$to(device = device),
    val_iwae_running = checkpoint$val_iwae_running$to(device = device)
  )

  # Create the return list
  return_list <- list(
    models = vaeac_tmp[1:(grep("train_vlb", names(vaeac_tmp)) - 1)], # Models are all entries before `train_vlb`
    results = vaeac_tmp[c("train_vlb", "val_iwae", "val_iwae_running")], # The train & val results
    parameters = vaeac_tmp$parameters # List of all the parameters used to train the vaeac model
  )

  # Add `vaeac` as a class to the object. We use this to validate the input when
  # `vaeac.pretrained_vaeac_model` is given to the `shapr::explain()` function.
  class(return_list) <- c(class(return_list), "vaeac")

  # Return the paths where the models are saved and the training/validation errors.
  return(return_list)
}


# Imputation functions =================================================================================================
#' Impute Missing Values Using Vaeac
#'
#' @details Function that imputes the missing values in a 2D matrix where each row constitutes an individual.
#' The values are sampled from the conditional distribution estimated by a vaeac model.
#'
#' @inheritParams vaeac_train_model
#' @param x_explain_with_NaNs A 2D matrix, where the missing entries to impute are represented by `NaN`.
#' @param n_samples Integer. The number of imputed versions we create for each row in `x_explain_with_NaNs`.
#' @param index_features Optional integer vector. Used internally in the shapr package to index the coalitions.
#' @param n_explain Positive integer. The number of explicands.
#' @param vaeac_model An initialized `vaeac` model that we are going to use to generate the MC samples.
#' @param checkpoint List containing the parameters of the `vaeac` model.
#' @param sampler A sampler object used to sample the MC samples.
#'
#' @return A data.table where the missing values (`NaN`) in `x_explain_with_NaNs` have been imputed `n_samples` times.
#' The data table will contain extra id columns if `index_features` and `n_explain` are provided.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_impute_missing_entries <- function(x_explain_with_NaNs,
                                         n_samples,
                                         vaeac_model,
                                         checkpoint,
                                         sampler,
                                         batch_size,
                                         verbose = 0,
                                         seed = NULL,
                                         n_explain = NULL,
                                         index_features = NULL) {
  # We only need `n_explain` when `index_features` is provided
  if (xor(is.null(index_features), is.null(n_explain))) {
    stop("Either none or both of `index_features` and `n_explain` must be given.")
  }

  # Set seed for reproducibility if provided by the user. Both in R and torch.
  if (!is.null(seed)) {
    set.seed(seed)
    torch::torch_manual_seed(seed)
  }

  if (verbose == 2) message("Preprocessing the explicands.")

  # Preprocess `x_explain_with_NaNs`. Turn factor names into numerics 1,2,...,K, (vaeac only accepts numerics) and keep
  # track of the mapping of names. Optionally log-transform the continuous features. Then, finally, normalize the data
  # using the training means and standard deviations. I.e., we assume that the new data follow the same distribution as
  # the training data. If this is NOT the case, then vaeac will generate unreasonable imputations.
  x_explain_with_NaNs_processed <- vaeac_preprocess_data(
    data = x_explain_with_NaNs,
    log_exp_cont_feat = checkpoint$log_exp_cont_feat,
    normalize = TRUE,
    norm_mean = checkpoint$norm_mean, # Normalize using training data means
    norm_std = checkpoint$norm_std # Normalize using training data standard deviations
  )$data_normalized_torch

  # Create the data set object
  dataset <- vaeac_dataset(X = x_explain_with_NaNs_processed, one_hot_max_sizes = checkpoint$one_hot_max_sizes)

  # Create a data loader that loads/iterates over the data set in chronological order.
  dataloader <- torch::dataloader(dataset = dataset, batch_size = batch_size, shuffle = FALSE)

  if (verbose == 2) message("Generating the MC samples.")

  # Create an auxiliary list of lists to store the imputed values combined with the original values. The structure is
  # [[i'th MC sample]][[b'th batch]], where the entries are tensors of dimension batch_size x n_features.
  results <- lapply(seq(n_samples), function(k) list())

  # Generate the conditional Monte Carlo samples for the observations in `x_explain_with_NaNs`, one batch at a time.
  coro::loop(for (batch in dataloader) {
    # Make a deep copy of the batch and detach it from graph.
    batch_extended <- batch$clone()$detach()

    # If the batch is smaller than `batch_size`, extend it with observations from the beginning of the dataset.
    if (batch_extended$shape[1] < batch_size) {
      batch_extended <- vaeac_extend_batch(batch = batch_extended, dataloader = dataloader, batch_size = batch_size)
    }

    # Send the original and extended batch to GPU if applicable.
    if (checkpoint$cuda) {
      batch <- batch$cuda()
      batch_extended <- batch_extended$cuda()
    }

    # Compute the imputation mask, i.e., which entries we are to impute.
    mask_extended <- torch::torch_isnan(batch_extended)$to(dtype = torch::torch_float())

    # Do not need to keep track of the gradients, as we are not fitting the model.
    torch::with_no_grad({
      # Compute the distribution parameters for the generative models inferred by the masked encoder and decoder.
      # This is a tensor of shape [batch_size, n_samples, n_generative_parameters]. Note that, for only continuous
      # features we have that n_generative_parameters = 2*n_features, but for categorical data the number depends
      # on the number of categories.
      samples_params <- vaeac_model$generate_samples_params(batch = batch_extended, mask = mask_extended, K = n_samples)

      # Remove the parameters belonging to added instances in batch_extended.
      samples_params <- samples_params[1:batch$shape[1], , ]
    })

    # Make a deep copy of the batch with missing values set to zero.
    mask <- torch::torch_isnan(batch)
    batch_zeroed_nans <- batch$clone()$detach()
    batch_zeroed_nans[mask] <- 0

    # Iterate over the number of imputations and generate the imputed samples
    for (i in seq(n_samples)) {
      # Extract the i'th inferred generative parameters for the whole batch.
      # sample_params is a tensor of shape [batch_size, n_generative_parameters].
      sample_params <- samples_params[, i, ]

      # Generate the imputations using the generative distributions inferred by the decoder.
      sample <- sampler(sample_params)

      # Set the sampled values for the known features (those in S) to zero; the observed values are added back below.
      sample[torch::torch_logical_not(mask)] <- 0

      # Combine the imputations with the original data to fill in the missing values. Shape is [batch_size, n_features].
      sample <- sample + batch_zeroed_nans

      # Make a deep copy and add it to the correct location in the results list.
      results[[i]] <- append(results[[i]], sample$clone()$detach()$cpu())
    } # End of iterating over the n_samples
  }) # End of iterating over the batches. Done imputing.

  if (verbose == 2) message("Postprocessing the Monte Carlo samples.")

  # Order the MC samples into a tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features]. For the i'th MC
  # sample, the lapply function concatenates the batches into a tensor of shape [nrow(x_explain_with_NaNs), n_features]
  # and then unsqueezes a new singleton second dimension to get the shape [nrow(x_explain_with_NaNs), 1, n_features].
  # Outside the lapply function, we concatenate the n_samples tensors along the second dimension to form the final
  # tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features].
  result <- torch::torch_cat(lapply(seq(n_samples), function(i) torch::torch_cat(results[[i]])$unsqueeze(2)), dim = 2)

  # Get back to the original distribution by undoing the normalization by multiplying with the std and adding the mean
  result <- result * checkpoint$norm_std + checkpoint$norm_mean

  # Convert from a tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features]
  # to a matrix of shape [(nrow(x_explain_with_NaNs) * n_samples), n_features].
  result <- data.table::as.data.table(as.matrix(result$view(c(
    result$shape[1] * result$shape[2],
    result$shape[3]
  ))$detach()$cpu()))

  # Post-process the data.table such that the categorical features get back their original level names.
  result <- vaeac_postprocess_data(data = result, vaeac_model_state_list = checkpoint)

  # If the user provides `index_features`, then we add the columns needed for the shapr computations
  if (!is.null(index_features)) {
    # Add id, id_combination and weights (uniform for the `vaeac` approach) to the result.
    result[, c("id", "id_combination", "w") := list(
      rep(x = seq(n_explain), each = length(index_features) * n_samples),
      rep(x = index_features, each = n_samples, times = n_explain),
      1 / n_samples
    )]

    # Set the key in the data table
    data.table::setkeyv(result, c("id", "id_combination"))
  }

  return(result)
}
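
# Minimal usage sketch for `vaeac_impute_missing_entries()` (hypothetical objects, wrapped in `if (FALSE)` so it
# never runs when the package is loaded): `explanation` is assumed to be an earlier [shapr::explain()] result with
# `approach = "vaeac"`, and `sampler` an appropriate sampler object from the vaeac internals.
if (FALSE) {
  checkpoint <- torch::torch_load(explanation$internal$parameters$vaeac$models$best)
  vaeac_model <- vaeac_get_model_from_checkp(checkpoint = checkpoint, cuda = FALSE, mode_train = FALSE)
  x_with_NaNs <- as.matrix(explanation$internal$data$x_explain) # Hypothetical explicands
  x_with_NaNs[, 1] <- NaN # Impute the first feature conditioned on the remaining ones
  imputations <- vaeac_impute_missing_entries(
    x_explain_with_NaNs = x_with_NaNs,
    n_samples = 10,
    vaeac_model = vaeac_model,
    checkpoint = checkpoint,
    sampler = sampler,
    batch_size = 64
  )
}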

# Check functions ======================================================================================================
#' Check vaeac.extra_parameters list
#'
#' @param vaeac.extra_parameters List containing the extra parameters to the `vaeac` approach
#'
#' @author Lars Henry Berge Olsen
#' @keywords internal
vaeac_check_extra_named_list <- function(vaeac.extra_parameters) {
  names <- names(vaeac.extra_parameters)
  if (is.null(names)) stop("The parameter `vaeac.extra_parameters` is not a named list.")
  if (any(names == "")) stop("Not all parameters in the list `vaeac.extra_parameters` are named.")
}

#' Function that checks positive integers
#'
#' @param named_list_positive_integers List containing named entries, e.g., `list(a = 1, b = 2)`.
#'
#' @return The function does not return anything.
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_positive_integers <- function(named_list_positive_integers) {
  param_names <- names(named_list_positive_integers)
  for (idx in seq_len(length(named_list_positive_integers))) {
    param_name <- param_names[idx]
    value <- named_list_positive_integers[[param_name]]
    if (!is.numeric(value) || length(value) != 1 || value <= 0 || !is.finite(value) || value %% 1 != 0) {
      stop(paste0("'vaeac.", param_name, "' must be a positive integer."))
    }
  }
}
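
# Small illustration of the check-function pattern used throughout this section (hypothetical values): each entry
# must be a single positive whole number, and the offending parameter name is echoed back in the error message.
if (FALSE) {
  vaeac_check_positive_integers(list(epochs = 100, depth = 3)) # Passes silently
  vaeac_check_positive_integers(list(epochs = 10.5)) # Error: 'vaeac.epochs' must be a positive integer.
}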

#' Function that checks positive numerics
#'
#' @param named_list_positive_numerics List containing named entries, e.g., `list(a = 0.2, b = 10^3)`.
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_positive_numerics <- function(named_list_positive_numerics) {
  param_names <- names(named_list_positive_numerics)
  for (idx in seq_len(length(named_list_positive_numerics))) {
    param_name <- param_names[idx]
    value <- named_list_positive_numerics[[param_name]]
    if (!is.numeric(value) || length(value) != 1 || !is.finite(value) || value <= 0) {
      stop(paste0("'vaeac.", param_name, "' must be a positive numeric."))
    }
  }
}

#' Function that checks probabilities
#'
#' @param named_list_probabilities List containing named entries, e.g., `list(a = 0.2, b = 0.9)`.
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_probabilities <- function(named_list_probabilities) {
  # Trick needed for entries that can be vectors (i.e., `vaeac.masking_ratio`)
  named_list_probabilities_tmp <- as.list(unlist(named_list_probabilities))
  param_names <- names(named_list_probabilities_tmp)
  for (idx in seq_len(length(named_list_probabilities_tmp))) {
    param_name <- param_names[idx]
    value <- named_list_probabilities_tmp[[param_name]]
    if (!is.numeric(value) || length(value) != 1 || !is.finite(value) || value < 0 || value > 1) {
      stop(paste0("'vaeac.", param_name, "' must be a valid probability (a number between 0 and 1)."))
    }
  }
}

#' Function that checks logicals
#'
#' @param named_list_logicals List containing named entries, e.g., `list(a = TRUE, b = FALSE)`.
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_logicals <- function(named_list_logicals) {
  param_names <- names(named_list_logicals)
  for (idx in seq_len(length(named_list_logicals))) {
    param_name <- param_names[idx]
    value <- named_list_logicals[[param_name]]
    if (!is.logical(value) || length(value) != 1) {
      stop(paste0("'vaeac.", param_name, "' must be a boolean (i.e., `TRUE` or `FALSE`)."))
    }
  }
}

#' Function that checks for valid `vaeac` model name
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_which_vaeac_model <- function(which_vaeac_model, epochs, save_every_nth_epoch = NULL) {
  valid_names <- c("best", "best_running", "last")
  if (!is.null(save_every_nth_epoch)) {
    valid_names <- c(
      valid_names,
      paste0("epoch_", seq(
        from = save_every_nth_epoch,
        by = save_every_nth_epoch,
        length.out = floor(epochs / save_every_nth_epoch)
      ))
    )
  }

  if (!is.null(which_vaeac_model) && !is.character(which_vaeac_model)) {
    stop("`vaeac.which_vaeac_model` must be a string.")
  }

  if (!is.null(which_vaeac_model) && !which_vaeac_model %in% valid_names) {
    stop(paste0(
      "The provided `vaeac.which_vaeac_model` ('", which_vaeac_model, "') does not match any of the valid values: '",
      paste(valid_names, collapse = "', '"), "'."
    ))
  }
}
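
# Illustration (hypothetical values): with `epochs = 20` and `save_every_nth_epoch = 5`, the valid names are
# "best", "best_running", "last", "epoch_5", "epoch_10", "epoch_15", and "epoch_20".
if (FALSE) {
  vaeac_check_which_vaeac_model(which_vaeac_model = "epoch_15", epochs = 20, save_every_nth_epoch = 5) # OK
  vaeac_check_which_vaeac_model(which_vaeac_model = "epoch_3", epochs = 20, save_every_nth_epoch = 5) # Error
}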

#' Function that checks provided epoch arguments
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_epoch_values <- function(epochs, epochs_initiation_phase, epochs_early_stopping, save_every_nth_epoch) {
  if (epochs_initiation_phase >= epochs) {
    stop(paste0(
      "'vaeac.epochs_initiation_phase' (", epochs_initiation_phase, ") must be strictly less than ",
      "'vaeac.epochs' (", epochs, ")."
    ))
  }

  if (epochs_early_stopping > epochs) {
    message(paste0(
      "No early stopping as `vaeac.epochs_early_stopping` (", epochs_early_stopping, ") is larger than ",
      "`vaeac.epochs` (", epochs, ")."
    ))
  }

  # Ensure that the model is not set to be saved at an epoch after the training has ended.
  if (!is.null(save_every_nth_epoch) && save_every_nth_epoch > epochs) {
    stop(paste0("Number of 'epochs' (", epochs, ") is less than 'save_every_nth_epoch' (", save_every_nth_epoch, ")."))
  }
  # Ensure that the first checkpoint epoch lies after the initiation phase.
  if (!is.null(save_every_nth_epoch) && save_every_nth_epoch <= epochs_initiation_phase) {
    stop(paste0(
      "'save_every_nth_epoch' (", save_every_nth_epoch, ") must be strictly larger than ",
      "'epochs_initiation_phase' (", epochs_initiation_phase, ")."
    ))
  }
}

#' Function that checks the provided activation function
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_activation_func <- function(activation_function) {
  # In future, check that it is one of the activation functions and not just a nn_module
  # Check that activation function is an nn_module
  if (!inherits(activation_function, "nn_module")) stop("`vaeac.activation_function` is not an `nn_module`.")
}

#' Function that checks the specified masking scheme
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_mask_gen <- function(mask_gen_coalitions, mask_gen_coalitions_prob, x_train) {
  masks <- mask_gen_coalitions
  probs <- mask_gen_coalitions_prob

  if (!is.null(masks) || !is.null(probs)) {
    if (xor(is.null(masks), is.null(probs))) {
      stop(
        "Either both `vaeac.mask_gen_coalitions` and `vaeac.mask_gen_coalitions_prob` need to `NULL` ",
        "or both have to be specified."
      )
    }

    if (!is.matrix(masks)) stop("`vaeac.mask_gen_coalitions` must be a matrix.")
    if (!is.numeric(probs)) stop("`vaeac.mask_gen_coalitions_prob` must be a numeric vector.")

    if (nrow(masks) != length(probs)) {
      stop(
        "The number of rows in `vaeac.mask_gen_coalitions` must be equal to the length of ",
        "`vaeac.mask_gen_coalitions_prob`."
      )
    }

    if (ncol(masks) != ncol(x_train)) {
      stop(
        "The number of columns in `vaeac.mask_gen_coalitions` must be equal to the number of ",
        "columns in the `x_train`. That is, the number of features."
      )
    }
  }
}
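
# Sketch of the expected input format (hypothetical values): each row of `vaeac.mask_gen_coalitions` encodes one
# coalition over the features (columns), and `vaeac.mask_gen_coalitions_prob` holds one sampling probability per row.
if (FALSE) {
  masks <- matrix(c(
    1, 0, 0,
    0, 1, 1
  ), nrow = 2, byrow = TRUE)
  probs <- c(0.5, 0.5)
  x_train <- data.frame(a = 1:3, b = 4:6, c = 7:9) # Hypothetical training data with three features
  vaeac_check_mask_gen(mask_gen_coalitions = masks, mask_gen_coalitions_prob = probs, x_train = x_train) # Passes
}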

#' Function that checks the verbose parameter
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_verbose <- function(verbose) {
  if (!is.numeric(verbose) || !(verbose %in% c(0, 1, 2))) {
    stop("`vaeac.verbose` must be either `0` (no verbosity), `1` (low verbosity), or `2` (high verbosity).")
  }
}

#' Function that checks that the save folder exists and for a valid file name
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_save_names <- function(folder_to_save_model, model_description) {
  if (!is.character(folder_to_save_model)) stop("`vaeac.folder_to_save_model` must be a string.")
  if (!is.character(model_description)) stop("`vaeac.model_description` must be a string.")
  if (!dir.exists(folder_to_save_model)) {
    stop(paste0("the folder `vaeac.folder_to_save_model` ('", folder_to_save_model, "') does not exist."))
  }
  if (!grepl("^[A-Za-z0-9._-]+$", model_description)) {
    stop(paste0(
      "`vaeac.model_description` can only contain uppercase and lowercase letters, ",
      "digits, dots, underscores, and hyphens."
    ))
  }
}

#' Function that checks for access to CUDA
#'
#' @inheritParams vaeac_train_model
#'
#' @return The (potentially updated) `cuda` argument: set to `FALSE` if CUDA was requested but is not available.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_cuda <- function(cuda) {
  # Check if cuda/GPU is available on the current system
  cuda_available <- torch::cuda_is_available()

  # Give message to user if asked to run on cuda, but cuda is not available.
  if (isFALSE(cuda_available) && isTRUE(cuda)) {
    cuda <- FALSE
    message("Cuda/GPU is not available (`shapr` uses CPU instead).", immediate. = TRUE)
  }

  return(cuda)
}

#' Function that checks that the masking ratio argument is valid
#'
#' @inheritParams vaeac_train_model
#' @param n_features The number of features, i.e., the number of columns in the training data.
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_masking_ratio <- function(masking_ratio, n_features) {
  if (length(masking_ratio) > 1 && length(masking_ratio) != n_features) {
    stop(paste0(
      "'masking_ratio' contains masking ratios for '", length(masking_ratio), "' features, ",
      "but there are '", n_features, "' features in 'x_train'."
    ))
  }
}

#' Function that gives a warning about disk usage
#'
#' @param x_train_size The object size of the `x_train` object.
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_save_parameters <- function(save_data, epochs, save_every_nth_epoch, x_train_size) {
  if (save_data && !is.null(save_every_nth_epoch) && epochs / save_every_nth_epoch > 5) {
    message(paste0(
      "Having `save_data = TRUE` and `save_every_nth_epoch = ", save_every_nth_epoch, "` might requirer ",
      "a lot of disk storage if `x_train` (", x_train_size, ") is large."
    ))
  }
}

#' Function that checks the feature names of data and `vaeac` model
#'
#' @param feature_names_vaeac Array of strings containing the feature names of the `vaeac` model.
#' @param feature_names_new Array of strings containing the feature names to compare with.
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_x_colnames <- function(feature_names_vaeac, feature_names_new) {
  n_features_vaeac <- length(feature_names_vaeac)
  n_features_new <- length(feature_names_new)

  # Check that the feature names of x_train matches the names of the training data used to train the vaeac model
  if (!isTRUE(all.equal(feature_names_vaeac, feature_names_new))) {
    stop(paste0(
      "The current feature names (`", paste(feature_names_new, collapse = "`, `"), "`) do not match the ",
      "feature names in the provided `vaeac` model (`", paste(feature_names_vaeac, collapse = "`, `"), ")."
    ))
  }

  # Check for an equal number of features (this should never trigger, as the test above indirectly checks this too)
  if (n_features_new != n_features_vaeac) {
    stop(paste0(
      "The provided `vaeac` model is trained on a ", n_features_vaeac, "-dimensional dataset, but the current ",
      "dataset is ", n_features_new, "-dimensional."
    ))
  }
}

#' Function that calls all vaeac parameters check functions
#'
#' @inheritParams vaeac_train_model
#'
#' @return The function does not return anything.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_check_parameters <- function(x_train,
                                   model_description,
                                   folder_to_save_model,
                                   cuda,
                                   n_vaeacs_initialize,
                                   epochs_initiation_phase,
                                   epochs,
                                   epochs_early_stopping,
                                   save_every_nth_epoch,
                                   val_ratio,
                                   val_iwae_n_samples,
                                   depth,
                                   width,
                                   latent_dim,
                                   lr,
                                   batch_size,
                                   running_avg_n_values,
                                   activation_function,
                                   skip_conn_layer,
                                   skip_conn_masked_enc_dec,
                                   batch_normalization,
                                   paired_sampling,
                                   masking_ratio,
                                   mask_gen_coalitions,
                                   mask_gen_coalitions_prob,
                                   sigma_mu,
                                   sigma_sigma,
                                   save_data,
                                   log_exp_cont_feat,
                                   which_vaeac_model,
                                   verbose,
                                   seed,
                                   ...) {
  # Check verbose parameter
  vaeac_check_verbose(verbose = verbose)

  # Check that the activation function is valid torch::nn_module object
  vaeac_check_activation_func(activation_function = activation_function)

  # Check that the save folder exists and for a valid file name
  vaeac_check_save_names(folder_to_save_model = folder_to_save_model, model_description = model_description)

  # Check the probability parameters
  vaeac_check_probabilities(list(val_ratio = val_ratio, masking_ratio = masking_ratio))

  # Check the masking ratio
  vaeac_check_masking_ratio(masking_ratio = masking_ratio, n_features = ncol(x_train))

  # Check the positive numeric parameters
  vaeac_check_positive_numerics(list(lr = lr, sigma_mu = sigma_mu, sigma_sigma = sigma_sigma))

  # Check the mask_gen_coalitions and mask_gen_coalitions_prob parameters
  vaeac_check_mask_gen(
    mask_gen_coalitions = mask_gen_coalitions,
    mask_gen_coalitions_prob = mask_gen_coalitions_prob,
    x_train = x_train
  )

  # Check the logical parameters
  vaeac_check_logicals(list(
    cuda = cuda,
    skip_conn_layer = skip_conn_layer,
    skip_conn_masked_enc_dec = skip_conn_masked_enc_dec,
    batch_normalization = batch_normalization,
    paired_sampling = paired_sampling,
    save_data = save_data,
    log_exp_cont_feat = log_exp_cont_feat
  ))

  # Check the positive integer parameters
  unchecked_positive_integers <- list(
    n_vaeacs_initialize = n_vaeacs_initialize,
    epochs = epochs,
    epochs_early_stopping = epochs_early_stopping,
    epochs_initiation_phase = epochs_initiation_phase,
    val_iwae_n_samples = val_iwae_n_samples,
    depth = depth,
    width = width,
    latent_dim = latent_dim,
    batch_size = batch_size,
    running_avg_n_values = running_avg_n_values,
    seed = seed
  )
  if (!is.null(save_every_nth_epoch)) unchecked_positive_integers$save_every_nth_epoch <- save_every_nth_epoch
  vaeac_check_positive_integers(unchecked_positive_integers)

  # Check the epoch values
  vaeac_check_epoch_values(
    epochs = epochs,
    epochs_initiation_phase = epochs_initiation_phase,
    epochs_early_stopping = epochs_early_stopping,
    save_every_nth_epoch = save_every_nth_epoch
  )

  # Check the save parameters
  vaeac_check_save_parameters(
    save_data = save_data,
    epochs = epochs,
    save_every_nth_epoch = save_every_nth_epoch,
    x_train_size = format(utils::object.size(x_train), units = "auto")
  )

  # Check that the user wants to use the vaeac model at a valid checkpoint
  vaeac_check_which_vaeac_model(
    which_vaeac_model = which_vaeac_model,
    epochs = epochs,
    save_every_nth_epoch = save_every_nth_epoch
  )
}

# Get functions ========================================================================================================
#' Function to specify the extra parameters in the `vaeac` model
#'
#' @description In this function, we specify the default values for the extra parameters used in [shapr::explain()]
#' for `approach = "vaeac"`.
#'
#' @details
#' The `vaeac` model consists of three neural networks (a full encoder, a masked encoder, and a decoder) based
#' on the provided `vaeac.depth` and `vaeac.width`. The encoders map the full and masked input
#' representations to latent representations, respectively, where the dimension is given by `vaeac.latent_dim`.
#' The latent representations are sent to the decoder to go back to the real feature space and
#' provide a samplable probabilistic representation, from which the Monte Carlo samples are generated.
#' By default, we use the `vaeac` model at the epoch with the lowest validation error (IWAE), but
#' other possibilities are available by setting the `vaeac.which_vaeac_model` parameter. See
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)} for more details.
#'
#' @param vaeac.model_description String (default is `make.names(Sys.time())`) containing, e.g., the name of the
#' data distribution or additional parameter information. Used in the save name of the fitted model. If not provided,
#' then a name will be generated based on [base::Sys.time()] to ensure a unique name. We use [base::make.names()] to
#' ensure a valid file name for all operating systems.
#' @param vaeac.folder_to_save_model String (default is [base::tempdir()]). String specifying a path to a folder where
#' the function is to save the fitted vaeac model. Note that the path will be removed from the returned
#' [shapr::explain()] object if `vaeac.save_model = FALSE`. Furthermore, the model cannot be moved from its
#' original folder if we are to use the [shapr::vaeac_train_model_continue()] function to continue training the model.
#' @param vaeac.pretrained_vaeac_model List or String (default is `NULL`). 1) Either a list of class
#' `vaeac`, i.e., the list stored in `explanation$internal$parameters$vaeac` where `explanation` is the returned list
#' from an earlier call to the [shapr::explain()] function. 2) A string containing the path to where the `vaeac`
#' model is stored on disk, for example, `explanation$internal$parameters$vaeac$models$best`.
#' @param vaeac.cuda Logical (default is `FALSE`). If `TRUE`, then the `vaeac` model will be trained using cuda/GPU.
#' If [torch::cuda_is_available()] is `FALSE`, then we fall back to using the CPU. If `FALSE`, we use the CPU. Using a
#' GPU for smaller tabular datasets often does not improve the efficiency.
#' See \code{vignette("installation", package = "torch")} for help to enable running on the GPU (only Linux and Windows).
#' @param vaeac.epochs_initiation_phase Positive integer (default is `2`). The number of epochs to run each of the
#' `vaeac.n_vaeacs_initialize` `vaeac` models before continuing to train only the best performing model.
#' @param vaeac.epochs_early_stopping Positive integer (default is `NULL`). The training stops if there has been no
#' improvement in the validation IWAE for `vaeac.epochs_early_stopping` epochs. If the user wants the training process
#' to be solely based on this training criterion, then `vaeac.epochs` in [shapr::explain()] should be set to a large
#' number. If `NULL`, then `shapr` will internally set `vaeac.epochs_early_stopping = vaeac.epochs` such that early
#' stopping does not occur.
#' @param vaeac.save_every_nth_epoch Positive integer (default is `NULL`). If provided, then the vaeac model after
#' every `vaeac.save_every_nth_epoch`th epoch will be saved.
#' @param vaeac.val_ratio Numeric (default is `0.25`). Scalar between `0` and `1` indicating the ratio of
#' instances from the input data which will be used as validation data. That is, `vaeac.val_ratio = 0.25` means
#' that `75%` of the provided data is used as training data, while the remaining `25%` is used as validation data.
#' @param vaeac.val_iwae_n_samples Positive integer (default is `25`). The number of generated samples used
#' to compute the IWAE criterion when validating the vaeac model on the validation data.
#' @param vaeac.batch_size Positive integer (default is `64`). The number of samples to include in each batch
#' during the training of the vaeac model. Used in [torch::dataloader()].
#' @param vaeac.batch_size_sampling Positive integer (default is `NULL`). The number of samples to include in
#' each batch when generating the Monte Carlo samples. If `NULL`, then the function generates the Monte Carlo samples
#' for the provided coalitions/combinations and all explicands sent to [shapr::explain()] at the same time.
#' The number of coalitions is determined by `n_batches` in [shapr::explain()]. We recommend tweaking `n_batches`
#' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster, provided sufficient memory.
#' @param vaeac.running_avg_n_values Positive integer (default is `5`). The number of previous IWAE values to include
#' when we compute the running means of the IWAE criterion.
#' @param vaeac.skip_conn_layer Logical (default is `TRUE`). If `TRUE`, we apply identity skip connections in each
#' layer, see [shapr::skip_connection()]. That is, we add the input \eqn{X} to the outcome of each hidden layer,
#' so the output becomes \eqn{X + activation(WX + b)}.
#' @param vaeac.skip_conn_masked_enc_dec Logical (default is `TRUE`). If `TRUE`, we apply concatenate skip
#' connections between the layers in the masked encoder and decoder. The first layer of the masked encoder will be
#' linked to the last layer of the decoder. The second layer of the masked encoder will be
#' linked to the second to last layer of the decoder, and so on.
#' @param vaeac.batch_normalization Logical (default is `FALSE`). If `TRUE`, we apply batch normalization after the
#' activation function. Note that if `vaeac.skip_conn_layer = TRUE`, then the normalization is applied after the
#' inclusion of the skip connection. That is, we batch normalize the whole quantity \eqn{X + activation(WX + b)}.
#' @param vaeac.paired_sampling Logical (default is `TRUE`). If `TRUE`, we apply paired sampling to the training
#' batches. That is, the training observations in each batch will be duplicated, where the first instance will be masked
#' by \eqn{S} while the second instance will be masked by \eqn{\bar{S}}. This ensures that the training of the
#' `vaeac` model becomes more stable as the model has access to the full version of each training observation. However,
#' this will increase the training time due to the more complex implementation and the doubled batch size. See
#' [shapr::paired_sampler()] for more information.
#' @param vaeac.masking_ratio Numeric (default is `0.5`). Probability of masking a feature in the
#' [shapr::mcar_mask_generator()] (MCAR = Missing Completely At Random). The MCAR masking scheme ensures that the
#' `vaeac` model can do arbitrary conditioning, as all coalitions will be seen during training. `vaeac.masking_ratio`
#' will be overruled if
#' `vaeac.mask_gen_coalitions` is specified.
#' @param vaeac.mask_gen_coalitions Matrix (default is `NULL`). Matrix containing the coalitions that the
#' `vaeac` model will be trained on, see [shapr::specified_masks_mask_generator()]. This parameter is used internally
#' in `shapr` when we only consider a subset of coalitions/combinations, i.e., when
#' `n_combinations` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
#' when `group` is specified in [shapr::explain()].
#' @param vaeac.mask_gen_coalitions_prob Numeric array (default is `NULL`). Array of length equal to the height
#' of `vaeac.mask_gen_coalitions` containing the probabilities of sampling the corresponding coalitions in
#' `vaeac.mask_gen_coalitions`.
#' @param vaeac.sigma_mu Numeric (default is `1e4`). One of two hyperparameter values in the normal-gamma prior
#' used in the masked encoder, see Section 3.3.1 in
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#' @param vaeac.sigma_sigma Numeric (default is `1e-4`). One of two hyperparameter values in the normal-gamma prior
#' used in the masked encoder, see Section 3.3.1 in
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#' @param vaeac.save_data Logical (default is `FALSE`). If `TRUE`, then the data is stored together with
#' the model. Useful if one is to continue training the model later using [shapr::vaeac_train_model_continue()].
#' @param vaeac.log_exp_cont_feat Logical (default is `FALSE`). If `TRUE`, we \eqn{\log} transform all
#' continuous features before sending the data to [shapr::vaeac()]. The `vaeac` model creates unbounded Monte Carlo
#' sample values. Thus, if the continuous features are strictly positive (as for, e.g., the Burr distribution and
#' Abalone data set), it can be advantageous to \eqn{\log} transform the data to unbounded form before using `vaeac`.
#' If `TRUE`, then [shapr::vaeac_postprocess_data()] will take the \eqn{\exp} of the results to get back to strictly
#' positive values when using the `vaeac` model to impute missing values/generate the Monte Carlo samples.
#' @param vaeac.sample_random Logical (default is `TRUE`). If `TRUE`, the function generates random Monte Carlo samples
#' from the inferred generative distributions. If `FALSE`, the function uses the most likely values, i.e., the mean
#' and the class with the highest probability for continuous and categorical features, respectively.
#' @param vaeac.which_vaeac_model String (default is `"best"`). The name of the `vaeac` model (snapshots from
#' different epochs) to use when generating the Monte Carlo samples. The standard choices are: `"best"` (epoch with the
#' lowest IWAE), `"best_running"` (epoch with the lowest running IWAE, see `vaeac.running_avg_n_values`), and `"last"`
#' (the last epoch).
#' Note that additional choices are available if `vaeac.save_every_nth_epoch` is provided. For example, if
#' `vaeac.save_every_nth_epoch = 5`, then `vaeac.which_vaeac_model` can also take the values `"epoch_5"`, `"epoch_10"`,
#' `"epoch_15"`, and so on.
#' @param vaeac.save_model Logical (default is `TRUE`). If `TRUE`, the `vaeac` model will be saved either in a
#' [base::tempdir()] folder or in a user-specified location in `vaeac.folder_to_save_model`. If `FALSE`, then
#' the paths to the model and the model itself will be deleted from the returned object from [shapr::explain()].
#'
#' @return Named list of the default values for the `vaeac` extra parameter arguments specified in this function call.
#' Note that both `vaeac.model_description` and `vaeac.folder_to_save_model` will change with time and R session.
#'
#' @export
#' @author Lars Henry Berge Olsen
vaeac_get_extra_para_default <- function(vaeac.model_description = make.names(Sys.time()),
                                         vaeac.folder_to_save_model = tempdir(),
                                         vaeac.pretrained_vaeac_model = NULL,
                                         vaeac.cuda = FALSE,
                                         vaeac.epochs_initiation_phase = 2,
                                         vaeac.epochs_early_stopping = NULL,
                                         vaeac.save_every_nth_epoch = NULL,
                                         vaeac.val_ratio = 0.25,
                                         vaeac.val_iwae_n_samples = 25,
                                         vaeac.batch_size = 64,
                                         vaeac.batch_size_sampling = NULL,
                                         vaeac.running_avg_n_values = 5,
                                         vaeac.skip_conn_layer = TRUE,
                                         vaeac.skip_conn_masked_enc_dec = TRUE,
                                         vaeac.batch_normalization = FALSE,
                                         vaeac.paired_sampling = TRUE,
                                         vaeac.masking_ratio = 0.5,
                                         vaeac.mask_gen_coalitions = NULL,
                                         vaeac.mask_gen_coalitions_prob = NULL,
                                         vaeac.sigma_mu = 1e4,
                                         vaeac.sigma_sigma = 1e-4,
                                         vaeac.sample_random = TRUE,
                                         vaeac.save_data = FALSE,
                                         vaeac.log_exp_cont_feat = FALSE,
                                         vaeac.which_vaeac_model = "best",
                                         vaeac.save_model = TRUE) {
  # Return a named list with the extra parameters to the vaeac model
  return(mget(methods::formalArgs(vaeac_get_extra_para_default)))
}
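
# Usage sketch (hypothetical `model`, `x_explain`, `x_train`, and `p0` objects): the returned named list is the
# format that the `vaeac.extra_parameters` argument in [shapr::explain()] expects, so non-default extra parameters
# can be specified as follows.
if (FALSE) {
  explanation <- explain(
    model = model,
    x_explain = x_explain,
    x_train = x_train,
    approach = "vaeac",
    prediction_zero = p0,
    vaeac.epochs = 10,
    vaeac.extra_parameters = vaeac_get_extra_para_default(
      vaeac.paired_sampling = FALSE,
      vaeac.epochs_early_stopping = 5
    )
  )
}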

#' Function to load a `vaeac` model and set it in the right state and mode
#'
#' @inheritParams vaeac_train_model
#' @param checkpoint List. This must be a loaded `vaeac` save object. That is, `torch::torch_load('vaeac_save_path')`.
#' @param mode_train Logical. If `TRUE`, the returned `vaeac` model is set to be in training mode.
#' If `FALSE`, the returned `vaeac` model is set to be in evaluation mode.
#'
#' @return A `vaeac` model with the correct state (based on `checkpoint`), sent to the desired hardware (based on
#' `cuda`), and in the right mode (based on `mode_train`).
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_model_from_checkp <- function(checkpoint, cuda, mode_train) {
  # Check parameters
  vaeac_check_logicals(list(cuda = cuda, mode_train = mode_train))

  # Set up the model such that it is loaded before calling the `prepare_data.vaeac()` function.
  vaeac_model <- vaeac(
    one_hot_max_sizes = checkpoint$one_hot_max_sizes,
    width = checkpoint$width,
    depth = checkpoint$depth,
    latent_dim = checkpoint$latent_dim,
    activation_function = checkpoint$activation_function,
    skip_conn_layer = checkpoint$skip_conn_layer,
    skip_conn_masked_enc_dec = checkpoint$skip_conn_masked_enc_dec,
    batch_normalization = checkpoint$batch_normalization,
    paired_sampling = checkpoint$paired_sampling,
    mask_generator_name = checkpoint$mask_generator_name,
    masking_ratio = checkpoint$masking_ratio,
    mask_gen_coalitions = checkpoint$mask_gen_coalitions,
    mask_gen_coalitions_prob = checkpoint$mask_gen_coalitions_prob,
    sigma_mu = checkpoint$sigma_mu,
    sigma_sigma = checkpoint$sigma_sigma
  )

  # Set the state of the vaeac model (setting the weights and biases in the networks)
  vaeac_model$load_state_dict(checkpoint$model_state_dict)

  # Apply the mode. Evaluation mode affects certain modules by, e.g., deactivating dropout layers, changing how
  # batch normalization is conducted, and so on...
  if (mode_train) vaeac_model$train() else vaeac_model$eval()

  # Send the model to the GPU, if we are supposed to. Otherwise use CPU
  if (cuda) vaeac_model$cuda() else vaeac_model$cpu()

  # Return the model
  return(vaeac_model)
}
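
# Minimal sketch (hypothetical path): load a saved `vaeac` checkpoint from disk and recreate the model in
# evaluation mode on the CPU, ready for generating Monte Carlo samples.
if (FALSE) {
  checkpoint <- torch::torch_load("path/to/vaeac_model_epoch_best.pt")
  vaeac_model <- vaeac_get_model_from_checkp(checkpoint = checkpoint, cuda = FALSE, mode_train = FALSE)
}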

#' Function that determines which mask generator to use
#'
#' @inheritParams vaeac_train_model
#'
#' @return String containing the name of the mask generator to use when training the `vaeac` model.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_mask_generator_name <- function(mask_gen_coalitions,
                                          mask_gen_coalitions_prob,
                                          masking_ratio,
                                          verbose) {
  if (!is.null(mask_gen_coalitions) && !is.null(mask_gen_coalitions_prob)) {
    # The user has provided mask_gen_coalitions (and mask_gen_coalitions_prob),
    # so we want to use specified_masks_mask_generator
    mask_generator_name <- "specified_masks_mask_generator"

    # Small printout
    if (verbose == 2) {
      message(paste0("Using 'specified_masks_mask_generator' with '", nrow(mask_gen_coalitions), "' coalitions."))
    }
  } else if (length(masking_ratio) == 1) {
    # We are going to use 'mcar_mask_generator' as masking_ratio is a singleton.
    # I.e., all feature values are equally likely to be masked based on masking_ratio.
    mask_generator_name <- "mcar_mask_generator"

    # Small printout
    if (verbose == 2) message(paste0("Using 'mcar_mask_generator' with 'masking_ratio = ", masking_ratio, "'."))
  } else if (length(masking_ratio) > 1) {
    # We are going to use 'specified_prob_mask_generator' as masking_ratio is a vector (of the same length as
    # ncol(x_train)). I.e., masking_ratio[5] specifies the probability of masking 5 features.
    mask_generator_name <- "specified_prob_mask_generator"

    # We have an array of masking ratios. Then we are using the specified_prob_mask_generator.
    if (verbose == 2) {
      message(paste0(
        "Using 'specified_prob_mask_generator' mask generator with 'masking_ratio = [",
        paste(masking_ratio, collapse = ", "), "]'."
      ))
    }
  } else {
    stop("`vaeac` could not determine which masking scheme to use based on the givene parameter arguments.")
  }

  return(mask_generator_name)
}
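
# Sketch of the three masking schemes (hypothetical `masks` and `probs` as in the `vaeac_check_mask_gen()` sketch):
if (FALSE) {
  vaeac_get_mask_generator_name(NULL, NULL, masking_ratio = 0.5, verbose = 0) # "mcar_mask_generator"
  vaeac_get_mask_generator_name(NULL, NULL, masking_ratio = c(0.1, 0.3, 0.6), verbose = 0) # "specified_prob_mask_generator"
  vaeac_get_mask_generator_name(masks, probs, masking_ratio = 0.5, verbose = 0) # "specified_masks_mask_generator"
}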

#' Function that creates the save file names for the `vaeac` model
#'
#' @inheritParams vaeac_train_model
#'
#' @return Array of strings containing the save file names to use when training the `vaeac` model. The first three
#' names correspond to the best, best_running, and last epochs, in that order.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_save_file_names <- function(model_description,
                                      n_features,
                                      n_train,
                                      depth,
                                      width,
                                      latent_dim,
                                      lr,
                                      epochs,
                                      save_every_nth_epoch,
                                      folder_to_save_model = NULL) {
  file_names <- c("best", "best_running", "last") # The standard checkpoints at which we save the vaeac model

  # Add the optional epochs to save the model
  if (!is.null(save_every_nth_epoch)) {
    file_names <- c(file_names, seq(
      from = save_every_nth_epoch,
      by = save_every_nth_epoch,
      length.out = floor(epochs / save_every_nth_epoch)
    ))
  }

  # Create the file names
  file_names <- paste0(
    make.names(model_description), "_n_features_", n_features, "_n_train_", n_train, "_depth_", depth,
    "_width_", width, "_latent_", latent_dim, "_lr_", lr, "_epoch_", file_names, ".pt"
  )

  # Add the (optional) path to the folder to the name
  if (!is.null(folder_to_save_model)) file_names <- file.path(folder_to_save_model, file_names)

  return(file_names)
}
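
# Worked example (hypothetical values): the file names encode the architecture and training setup. The call below
# would return, among others,
# "my_model_n_features_10_n_train_500_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt".
if (FALSE) {
  vaeac_get_save_file_names(
    model_description = "my_model", n_features = 10, n_train = 500, depth = 3, width = 32,
    latent_dim = 8, lr = 0.001, epochs = 100, save_every_nth_epoch = NULL
  )
}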

#' Function to create the optimizer used to train `vaeac`
#'
#' @description
#' Only [torch::optim_adam()] is currently supported. But it is easy to add an additional option later.
#'
#' @inheritParams vaeac_train_model
#' @param vaeac_model A `vaeac` model created using [vaeac()].
#' @param optimizer_name String containing the name of the [torch::optimizer()] to use.
#'
#' @return A [torch::optim_adam()] optimizer connected to the parameters of the `vaeac_model`.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_optimizer <- function(vaeac_model, lr, optimizer_name = "adam") {
  if (optimizer_name == "adam") {
    # Create the adam optimizer with default parameters, except for the provided learning rate
    optimizer <- torch::optim_adam(
      params = vaeac_model$parameters,
      lr = lr,
      betas = c(0.9, 0.999),
      eps = 1e-08,
      weight_decay = 0,
      amsgrad = FALSE
    )
  } else {
    stop("Only the `adam` optimizer has been implemented for the `vaeac` approach.")
  }

  return(optimizer)
}



#' Function that extracts additional objects from the environment to the state list
#'
#' @description
#' The function extracts the objects that we are going to save together with the `vaeac` model to make it possible to
#' train the model further and to evaluate it.
#' The environment should be the local environment inside the [shapr::vaeac_train_model_auxiliary()] function.
#'
#' @inheritParams vaeac_get_full_state_list
#'
#' @return List containing the values of `epoch`, `train_vlb`, `val_iwae`, `val_iwae_running`,
#' and the `state_dict()` of the vaeac model and optimizer.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_current_save_state <- function(environment) {
  object_names <- c("epoch", "train_vlb", "val_iwae", "val_iwae_running")
  objects <- lapply(object_names, function(name) environment[[name]])
  names(objects) <- object_names
  objects$model_state_dict <- environment[["vaeac_model"]]$state_dict()
  objects$optimizer_state_dict <- environment[["optimizer"]]$state_dict()
  return(objects)
}

#' Function that extracts the state list objects from the environment
#'
#' @description
#' The function extracts the objects that we are going to save together with the `vaeac` model to make it possible to
#' train the model further and to evaluate it.
#' The environment should be the local environment inside the [shapr::vaeac_train_model_auxiliary()] function.
#'
#' @param environment The [base::environment()] where the objects are stored.
#'
#' @return List containing the values of `norm_mean`, `norm_std`, `model_description`, `folder_to_save_model`,
#' `n_train`, `n_features`, `one_hot_max_sizes`, `epochs`, `epochs_specified`, `epochs_early_stopping`,
#' `early_stopping_applied`, `running_avg_n_values`, `paired_sampling`, `mask_generator_name`, `masking_ratio`,
#' `mask_gen_coalitions`, `mask_gen_coalitions_prob`, `val_ratio`, `val_iwae_n_samples`,
#' `n_vaeacs_initialize`, `epochs_initiation_phase`, `width`, `depth`, `latent_dim`, `activation_function`,
#' `lr`, `batch_size`, `skip_conn_layer`, `skip_conn_masked_enc_dec`, `batch_normalization`, `cuda`,
#' `train_indices`, `val_indices`, `save_every_nth_epoch`, `sigma_mu`,
#' `sigma_sigma`, `feature_list`, `col_cat_names`, `col_cont_names`, `col_cat`, `col_cont`, `cat_in_dataset`,
#' `map_new_to_original_names`, `map_original_to_new_names`, `log_exp_cont_feat`, `save_data`, `verbose`,
#' `seed`, and `vaeac_save_file_names`.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_full_state_list <- function(environment) {
  object_names <- c(
    "norm_mean", "norm_std", "model_description", "folder_to_save_model", "n_train", "n_features", "one_hot_max_sizes",
    "epochs", "epochs_specified", "epochs_early_stopping", "early_stopping_applied", "running_avg_n_values",
    "paired_sampling", "mask_generator_name", "masking_ratio", "mask_gen_coalitions",
    "mask_gen_coalitions_prob", "val_ratio", "val_iwae_n_samples", "n_vaeacs_initialize",
    "epochs_initiation_phase", "width", "depth", "latent_dim", "activation_function",
    "lr", "batch_size", "skip_conn_layer", "skip_conn_masked_enc_dec", "batch_normalization", "cuda",
    "train_indices", "val_indices", "save_every_nth_epoch", "sigma_mu", "sigma_sigma", "feature_list", "col_cat_names",
    "col_cont_names", "col_cat", "col_cont", "cat_in_dataset", "map_new_to_original_names", "map_original_to_new_names",
    "log_exp_cont_feat", "save_data", "verbose", "seed", "vaeac_save_file_names"
  )
  objects <- lapply(object_names, function(name) environment[[name]])
  names(objects) <- object_names
  objects <- utils::modifyList(objects, environment[["x_train_preprocessed"]], keep.null = TRUE) # Flatten this list
  return(objects)
}



#' Function to extend the explicands and apply all relevant masks/coalitions
#'
#' @inheritParams explain
#' @inheritParams default_doc_explain
#' @param S The `internal$objects$S` matrix containing the possible coalitions.
#'
#' @return The extended version of `x_explain` where the masks from `S` with indices `index_features` have been applied.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_x_explain_extended <- function(x_explain, S, index_features) {
  n_coalitions <- length(index_features) # Get the number of active coalitions
  n_explain <- nrow(x_explain) # Get the number of explicands
  mask <- S[index_features, , drop = FALSE] # Get the masks/coalitions we are to generate MC samples for
  mask[mask == 0] <- NaN # Set zeros to `NaN` to indicate that they are missing and to be imputed by `vaeac`
  x_explain_extended <-
    x_explain[rep(seq_len(nrow(x_explain)), each = n_coalitions), ] # Extend the explicands `n_coalitions` times
  mask_extended <- mask[rep(seq(n_coalitions), times = n_explain), ] # Extend the masks `n_explain` times
  x_explain_extended[is.na(mask_extended)] <- NaN # Apply the mask; the NaNs are features outside the coalition S
  return(x_explain_extended)
}
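
# Worked example (hypothetical values): with two explicands and two active coalitions, each explicand is repeated
# twice, and the features outside the respective coalitions are set to `NaN` so that `vaeac` imputes them.
if (FALSE) {
  x_explain <- data.frame(a = c(1, 2), b = c(3, 4))
  S <- rbind(c(1, 0), c(0, 1)) # Coalition 1 conditions on `a`, coalition 2 conditions on `b`
  vaeac_get_x_explain_extended(x_explain, S = S, index_features = 1:2)
  # Four rows: (1, NaN), (NaN, 3), (2, NaN), (NaN, 4)
}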

#' Extract the Training VLB and Validation IWAE from a list of explanations objects using the vaeac approach
#'
#' @param explanation_list A list of [explain()] objects applied to the same data and model, where the
#' `vaeac` approach must have been used. If the entries in the list are named, then the function uses
#' these names. Otherwise, it defaults to the approach names (with an integer suffix for duplicates)
#' for the explanation objects in `explanation_list`.
#'
#' @return A data.table containing the training VLB, validation IWAE, and running validation IWAE at each epoch for
#' each vaeac model.
#' @export
#' @author Lars Henry Berge Olsen
vaeac_get_evaluation_criteria <- function(explanation_list) {
  # Check if user only provided a single explanation and did not put it in a list
  if ("shapr" %in% class(explanation_list)) explanation_list <- list(explanation_list)

  # Check that all explanation objects use the `vaeac` approach
  explanation_approaches <- sapply(explanation_list, function(explanation) explanation$internal$parameters$approach)
  if (any(explanation_approaches != "vaeac")) {
    stop(sprintf(
      "Explanation object number `%d` in the `explanation_list` does not use the `vaeac` approach.",
      seq_along(explanation_approaches)[explanation_approaches != "vaeac"][1]
    ))
  }

  # Name the elements in the explanation_list if no names have been provided
  if (is.null(names(explanation_list))) explanation_list <- MSEv_name_explanation_list(explanation_list)

  # Extract the evaluation criteria and put them into a data.table
  vaeac_VLB_IWAE_dt <- data.table::rbindlist(
    lapply(explanation_list, function(explanation) {
      data.table::data.table(do.call(cbind, explanation$internal$parameters$vaeac$results))[, Epoch := .I]
    }),
    use.names = TRUE,
    idcol = "Method"
  )
  names(vaeac_VLB_IWAE_dt)[2:4] <- c("VLB", "IWAE", "IWAE_running")
  vaeac_VLB_IWAE_dt$Method <- factor(vaeac_VLB_IWAE_dt$Method, levels = names(explanation_list))
  data.table::setkeyv(vaeac_VLB_IWAE_dt, c("Method", "Epoch"))
  data.table::setcolorder(vaeac_VLB_IWAE_dt, c("Method", "Epoch"))

  return(vaeac_VLB_IWAE_dt)
}
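
# Usage sketch (hypothetical `explanation_1` and `explanation_2` objects from earlier [explain()] calls with
# `approach = "vaeac"`): compare the training and validation criteria of two fitted vaeac models.
if (FALSE) {
  vaeac_get_evaluation_criteria(list(A = explanation_1, B = explanation_2))
  # Returns a data.table with the columns Method, Epoch, VLB, IWAE, and IWAE_running
}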

#' Function to set up data loaders and save file names
#'
#' @inheritParams vaeac_train_model
#' @param train_indices Numeric array (optional) containing the indices of the training observations.
#' No checks are conducted to validate the indices.
#' @param val_indices Numeric array (optional) containing the indices of the validation observations.
#' No checks are conducted to validate the indices.
#'
#' @return List of objects needed to train the `vaeac` model.
#'
#' @keywords internal
vaeac_get_data_objects <- function(x_train,
                                   log_exp_cont_feat,
                                   val_ratio,
                                   batch_size,
                                   paired_sampling,
                                   model_description,
                                   depth,
                                   width,
                                   latent_dim,
                                   lr,
                                   epochs,
                                   save_every_nth_epoch,
                                   folder_to_save_model,
                                   train_indices = NULL,
                                   val_indices = NULL) {
  if (xor(is.null(train_indices), is.null(val_indices))) {
    stop("Either none or both of `train_indices` and `val_indices` must be given.")
  }

  # Get the dimensions of the x_train
  n_train <- nrow(x_train)
  n_features <- ncol(x_train)

  # Preprocess x_train. Turn factor names into numerics 1,2,...,K, (vaeac only accepts numerics) and keep track
  # of the mapping of names. Optionally log-transform the continuous features. Then, finally, normalize the data.
  x_train_preprocessed <- vaeac_preprocess_data(data = x_train, log_exp_cont_feat = log_exp_cont_feat, normalize = TRUE)

  # Extract the preprocessed and normalized x_train as a torch tensor and the one-hot feature sizes (cont have size 1)
  x_train_torch <- x_train_preprocessed$data_normalized_torch
  one_hot_max_sizes <- x_train_preprocessed$one_hot_max_sizes

  # Split the input into training and validation data sets
  if (is.null(train_indices)) { # The val_indices will also be NULL due to the xor check above
    val_size <- ceiling(n_train * val_ratio) # Number of observations in the validation set
    val_indices <- sample(n_train, val_size, replace = FALSE) # Sample indices for the validation set
    train_indices <- seq(n_train)[-val_indices] # The remaining indices constitute the training set
  } else {
    val_size <- NULL
  }
  val_dataset <- vaeac_dataset(x_train_torch[val_indices], one_hot_max_sizes) # Create a torch::dataset() for vaeac
  train_dataset <- vaeac_dataset(x_train_torch[train_indices], one_hot_max_sizes) # Create a torch::dataset() for vaeac

  # Ensure a valid batch size
  if (batch_size > length(train_indices)) {
    message(paste0(
      "Decreasing `batch_size` (", batch_size, ") to the largest allowed value (", length(train_indices), "), ",
      "i.e., the number of training observations."
    ))
    batch_size <- length(train_indices)
  }

  # Create the Data Loader objects which iterate over the data in the Data Set objects
  train_dataloader <- torch::dataloader(
    dataset = train_dataset,
    batch_size = batch_size,
    shuffle = if (paired_sampling) FALSE else TRUE, # Must be `FALSE` when `sampler` is specified
    sampler = if (paired_sampling) paired_sampler(train_dataset, shuffle = TRUE) else NULL
  )

  val_dataloader <- torch::dataloader(
    dataset = val_dataset,
    batch_size = batch_size,
    shuffle = FALSE,
    sampler = if (paired_sampling) paired_sampler(val_dataset, shuffle = FALSE) else NULL
  )

  # Get all the file names for the vaeac objects we are going to save
  vaeac_save_file_names <- vaeac_get_save_file_names(
    model_description = model_description,
    n_features = n_features,
    n_train = n_train,
    depth = depth,
    width = width,
    latent_dim = latent_dim,
    lr = lr,
    epochs = epochs,
    save_every_nth_epoch = save_every_nth_epoch,
    folder_to_save_model = folder_to_save_model
  )

  return(list(
    n_train = n_train,
    n_features = n_features,
    x_train_preprocessed = x_train_preprocessed,
    x_train_torch = x_train_torch,
    one_hot_max_sizes = one_hot_max_sizes,
    val_size = val_size,
    val_indices = val_indices,
    train_indices = train_indices,
    batch_size = batch_size,
    train_dataloader = train_dataloader,
    val_dataloader = val_dataloader,
    vaeac_save_file_names = vaeac_save_file_names
  ))
}
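
# Sketch of the validation-split arithmetic (hypothetical `x_train` with 100 rows): `val_ratio = 0.25` gives
# `val_size = ceiling(100 * 0.25) = 25` validation observations, and the remaining 75 observations form the
# training set.
if (FALSE) {
  data_objects <- vaeac_get_data_objects(
    x_train = x_train, log_exp_cont_feat = FALSE, val_ratio = 0.25, batch_size = 64,
    paired_sampling = TRUE, model_description = "example", depth = 3, width = 32,
    latent_dim = 8, lr = 0.001, epochs = 100, save_every_nth_epoch = NULL,
    folder_to_save_model = tempdir()
  )
  length(data_objects$val_indices) # 25 when `nrow(x_train)` is 100
}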

#' Function to get string of values with specific number of decimals
#'
#' @param value Numeric. The value to round and format as a string with `n_decimals` decimals.
#' @param n_decimals Positive integer. The number of decimals. Default is three.
#'
#' @return String of `value` with `n_decimals` decimals.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_get_n_decimals <- function(value, n_decimals = 3) trimws(format(round(value, n_decimals), nsmall = n_decimals))
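# E.g., `vaeac_get_n_decimals(3.14159)` returns "3.142", and `vaeac_get_n_decimals(2, 1)` returns "2.0".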

# Update functions =====================================================================================================
#' Move `vaeac` parameters to correct location
#'
#' @description
#' This function ensures that the main and extra parameters for the `vaeac`
#' approach are located at their correct locations.
#'
#' @param parameters List. The `internal$parameters` list created inside the [shapr::explain()] function.
#'
#' @return Updated version of `parameters` where all `vaeac` parameters are located at the correct location.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_update_para_locations <- function(parameters) {
  # Get the name of the main parameters for the `vaeac` approach
  vaeac.main_para_default_names <- methods::formalArgs(setup_approach.vaeac)
  vaeac.main_para_default_names <-
    vaeac.main_para_default_names[!vaeac.main_para_default_names %in% c("internal", "vaeac.extra_parameters", "...")]

  # Get the default values for vaeac's main parameters defined above into a named list
  vaeac.main_para_default <- as.list(formals(sys.function(sys.parent())))
  vaeac.main_para_default <- vaeac.main_para_default[names(vaeac.main_para_default) %in% vaeac.main_para_default_names]

  # Get the names of the vaeac's main parameters provided by the user
  vaeac.main_para_user_names <- names(parameters)
  vaeac.main_para_user_names <- vaeac.main_para_user_names[grepl("^vaeac\\.", vaeac.main_para_user_names)]
  vaeac.main_para_user_names <- vaeac.main_para_user_names[!vaeac.main_para_user_names %in% "vaeac.extra_parameters"]

  # Get the default values for vaeac's extra parameters into a named list
  vaeac.extra_para_default <- vaeac_get_extra_para_default()
  vaeac.extra_para_default_names <- names(vaeac.extra_para_default)

  # Get the names of the extra parameters provided by the user
  vaeac.extra_para_user_names <- names(parameters$vaeac.extra_parameters)

  # Get the names of all the default parameters (main and extra)
  vaeac.all_para_default_names <- c(vaeac.main_para_default_names, vaeac.extra_para_default_names)

  # Check if any of the main parameters with the "vaeac." prefix are unknown (i.e., neither main nor extra parameters)
  not_extra_para_in_main_para <-
    vaeac.main_para_user_names[!vaeac.main_para_user_names %in% vaeac.all_para_default_names]
  if (length(not_extra_para_in_main_para) > 0) {
    # Warn the user about the unknown main parameters
    warning(paste0(
      "The following vaeac main parameters are not recognized (`shapr` removes them): ",
      paste(strsplit(paste(paste0("`", not_extra_para_in_main_para, "`"), collapse = ", "),
        ",(?=[^,]+$)",
        perl = TRUE
      )[[1]], collapse = " and"), ".\n"
    ))

    # Delete the unknown extra parameters
    parameters[not_extra_para_in_main_para] <- NULL
  }

  # Check if any of the extra parameters with the "vaeac." prefix are unknown (i.e., neither main nor extra parameters)
  not_main_para_in_extra_para <-
    vaeac.extra_para_user_names[!vaeac.extra_para_user_names %in% vaeac.all_para_default_names]
  if (length(not_main_para_in_extra_para) > 0) {
    # Give a message to the user about the unknown extra parameters
    warning(paste0(
      "The following vaeac extra parameters are not recognized (`shapr` removes them): ",
      paste(strsplit(paste(paste0("`", not_main_para_in_extra_para, "`"), collapse = ", "),
        ",(?=[^,]+$)",
        perl = TRUE
      )[[1]], collapse = " and"), ".\n"
    ))

    # Delete the unknown extra parameters
    parameters$vaeac.extra_parameters[not_main_para_in_extra_para] <- NULL
  }

  # Check for parameters that have been provided as both main and extra parameter
  both_main_and_extra_para <- vaeac.extra_para_user_names[vaeac.extra_para_user_names %in% vaeac.main_para_user_names]
  if (length(both_main_and_extra_para) > 0) {
    # Warn the user and tell them that we use the values at the correct location.
    warning(paste0(
      "The following vaeac parameters were given as both main and extra parameters (`shapr` uses the ",
      "values at the correct location ): ",
      paste(strsplit(paste(paste0("`", both_main_and_extra_para, "`"), collapse = ", "),
        ",(?=[^,]+$)",
        perl = TRUE
      )[[1]], collapse = " and"), ".\n"
    ))
    # Note that we do not move it here as the moving will be fixed in the next two if-clauses
  }

  # Check if any extra parameters have been given as main parameters
  extra_para_in_main_para <- vaeac.main_para_user_names[vaeac.main_para_user_names %in% vaeac.extra_para_default_names]
  if (length(extra_para_in_main_para) > 0) {
    warning(paste0(
      "The following vaeac parameters were given as main parameters but should have been extra ",
      "parameters (`shapr` fixes this): ",
      paste(strsplit(paste(paste0("`", extra_para_in_main_para, "`"), collapse = ", "),
        ",(?=[^,]+$)",
        perl = TRUE
      )[[1]], collapse = " and"), ".\n"
    ))

    # Move extra parameter from the main parameters to extra_parameters list if they have NOT been specified already
    parameters$vaeac.extra_parameters[extra_para_in_main_para[!extra_para_in_main_para %in%
      vaeac.extra_para_user_names]] <-
      parameters[extra_para_in_main_para[!extra_para_in_main_para %in% vaeac.extra_para_user_names]]

    # Remove the extra parameter from the main parameters
    parameters[extra_para_in_main_para] <- NULL
  }

  # Check if any main parameters have been given as extra parameters
  main_para_in_extra_para <- vaeac.extra_para_user_names[vaeac.extra_para_user_names %in% vaeac.main_para_default_names]
  if (length(main_para_in_extra_para) > 0) {
    # Give a message to the user about the misplaced main parameters in the extra list
    warning(paste0(
      "The following vaeac parameters were given as extra parameters but should have been main ",
      "parameters (`shapr` fixes this): ",
      paste(strsplit(paste(paste0("`", main_para_in_extra_para, "`"), collapse = ", "),
        ",(?=[^,]+$)",
        perl = TRUE
      )[[1]], collapse = " and"), ".\n"
    ))

    # Move main parameters from the extra_parameters list to main parameters if they have NOT been specified already
    parameters[main_para_in_extra_para[!main_para_in_extra_para %in% vaeac.main_para_user_names]] <-
      parameters$vaeac.extra_parameters[main_para_in_extra_para[!main_para_in_extra_para
      %in% vaeac.main_para_user_names]]

    # Remove the main parameter from the extra list
    parameters$vaeac.extra_parameters[main_para_in_extra_para] <- NULL
  }

  # Return the fixed parameters list
  return(parameters)
}
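# A minimal sketch of the relocation logic above (hypothetical user input; the parameter
# names are taken from the defaults documented in this file):
#   parameters <- list(
#     vaeac.depth = 3, # Correctly placed main parameter
#     vaeac.paired_sampling = TRUE, # Extra parameter misplaced as a main parameter
#     vaeac.extra_parameters = list(vaeac.width = 16) # Main parameter misplaced as an extra parameter
#   )
#   parameters <- vaeac_update_para_locations(parameters)
#   # With warnings, `vaeac.paired_sampling` is moved into `vaeac.extra_parameters`,
#   # and `vaeac.width` is moved up to the main parameter list.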

#' Function that checks and adds a pre-trained `vaeac` model
#'
#' @param parameters List containing the parameters used within [shapr::explain()].
#'
#' @return The `parameters` list updated with a validated pre-trained `vaeac` model.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_update_pretrained_model <- function(parameters) {
  # Extract the provided pre-trained vaeac model
  vaeac_object <- parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model

  # Check that it is either a list or string
  if (!(is.list(vaeac_object) || is.character(vaeac_object))) {
    stop("The `vaeac.pretrained_vaeac_model` parameter must be either a list or a string. Read the documentation.")
  }

  # Check if we are given a list
  if (is.list(vaeac_object)) {
    # Check for list of type vaeac
    if (!("vaeac" %in% class(vaeac_object))) stop("The `vaeac.pretrained_vaeac_model` list is not of type `vaeac`.")
    vaeac_check_x_colnames(
      feature_names_vaeac = vaeac_object$parameters$feature_list$labels,
      feature_names_new = parameters$feature_names
    )

    # Add the pre-trained valid vaeac model to the parameters list
    parameters$vaeac <- parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model

    # Remove the pre-trained vaeac model as it has been approved as a vaeac model
    parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model <- NULL
  }

  # Check if we are given a string
  if (is.character(vaeac_object)) {
    # Check that the file exists
    if (!file.exists(vaeac_object)) {
      stop(paste0("The `vaeac.pretrained_vaeac_model` file ('", vaeac_object, "') does not exist."))
    }

    # Read in the vaeac model from the disk
    vaeac_model <- torch::torch_load(vaeac_object)

    # Some basic sanity checks that we have read in a vaeac model
    if (is.null(vaeac_model$model_state_dict)) {
      stop("The provided file is not a vaeac model as it is missing, e.g., the `model_state_dict` entry.")
    }
    if (is.null(vaeac_model$optimizer_state_dict)) {
      stop("The provided file is not a vaeac model as it is missing, e.g., the `optimizer_state_dict` entry.")
    }

    # Check that the provided vaeac model is trained on a dataset with the same feature names
    vaeac_check_x_colnames(
      feature_names_vaeac = vaeac_model$feature_list$labels,
      feature_names_new = parameters$feature_names
    )

    # Extract the training/validation results
    evaluation_criterions <- c("train_vlb", "val_iwae", "val_iwae_running")
    vaeac_model_results <- lapply(vaeac_model[evaluation_criterions], as.array)

    # Save the path to the vaeac model used to generate the MC samples.
    parameters$vaeac <- list(
      models = list(best = vaeac_object),
      results = vaeac_model_results,
      parameters = vaeac_model[!names(vaeac_model) %in% evaluation_criterions]
    )

    # Add `vaeac` as a class to the object. We use this to validate the input when
    # `vaeac.pretrained_vaeac_model` is given to the `shapr::explain()` function.
    class(parameters$vaeac) <- c(class(parameters$vaeac), "vaeac")
  }

  # Return the updated parameters list
  return(parameters)
}
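# A minimal sketch of the two accepted input forms (hypothetical objects and paths):
#   # 1) A list of class `vaeac` from a previous `explain()` call:
#   parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model <-
#     previous_explanation$internal$parameters$vaeac
#   # 2) A path to a `vaeac` checkpoint saved to disk with `vaeac_save_state()`:
#   parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model <- "path/to/vaeac_model.pt"
#   parameters <- vaeac_update_pretrained_model(parameters)
#   # Either way, the validated model now resides in `parameters$vaeac`.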

# Save functions =======================================================================================================
#' Function that saves the state list and the current save state of the `vaeac` model
#'
#' @param state_list List containing all the parameters in the state.
#' @param file_name String containing the file path.
#' @param return_state Logical. If `TRUE`, the function returns the state list.
#'
#' @return The state list if `return_state = TRUE`; otherwise, nothing is returned.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_save_state <- function(state_list, file_name, return_state = FALSE) {
  state <- modifyList(state_list, vaeac_get_current_save_state(parent.frame()), keep.null = TRUE)
  class(state) <- c(class(state), "vaeac")
  torch::torch_save(state, file_name)
  if (return_state) {
    return(state)
  }
}
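# A standalone sketch of the merge semantics used above: `modifyList()` overwrites the
# entries in `state_list` with those captured from the calling frame by
# `vaeac_get_current_save_state()`, and `keep.null = TRUE` retains entries whose new
# value is `NULL` instead of silently dropping them.
#   modifyList(list(epoch = 1, optimizer = "adam"), list(epoch = 10, optimizer = NULL), keep.null = TRUE)
#   # -> list(epoch = 10, optimizer = NULL); with `keep.null = FALSE`, `optimizer` would be dropped.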

# Print functions ======================================================================================================
#' Function to printout a training summary for the `vaeac` model
#'
#' @param best_epoch Positive integer. The epoch with the lowest validation error.
#' @param best_epoch_running Positive integer. The epoch with the lowest running validation error.
#' @param last_state The state list (i.e., the saved `vaeac` object)
#' of the `vaeac` model at the epoch with the lowest IWAE.
#'
#' @return This function only prints out a message.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_print_train_summary <- function(best_epoch, best_epoch_running, last_state) {
  message(sprintf(
    "\nResults of the `vaeac` training process:
Best epoch:             %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f
Best running avg epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f
Last epoch:             %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f\n",
    best_epoch,
    last_state$train_vlb[best_epoch]$cpu(),
    last_state$val_iwae[best_epoch]$cpu(),
    last_state$val_iwae_running[best_epoch]$cpu(),
    best_epoch_running,
    last_state$train_vlb[best_epoch_running]$cpu(),
    last_state$val_iwae[best_epoch_running]$cpu(),
    last_state$val_iwae_running[best_epoch_running]$cpu(),
    last_state$epoch,
    last_state$train_vlb[-1]$cpu(),
    last_state$val_iwae[-1]$cpu(),
    last_state$val_iwae_running[-1]$cpu()
  ))
}

#' Produce message about which batch prepare_data is working on
#' @inheritParams default_doc
#' @inheritParams default_doc_explain
#' @author Lars Henry Berge Olsen
#' @keywords internal
vaeac_prep_message_batch <- function(internal, index_features) {
  id_batch <- internal$objects$X[id_combination == index_features[1]]$batch
  n_batches <- internal$parameters$n_batches
  message(paste0("Generating Monte Carlo samples using `vaeac` for batch ", id_batch, " of ", n_batches, "."))
}
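# The resulting message looks like, e.g. (illustrative values):
#   "Generating Monte Carlo samples using `vaeac` for batch 3 of 10."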

# Plot functions =======================================================================================================
#' Plot the training VLB and validation IWAE for `vaeac` models
#'
#' @description This function makes [ggplot2::ggplot()] figures of the training VLB and the validation IWAE for a list
#' of [shapr::explain()] objects with `approach = "vaeac"`. See [setup_approach()] for more information about the
#' `vaeac` approach. The function can return two figures: in the first figure, each object in `explanation_list` gets
#' its own facet, while in the second figure, we plot the criteria in each facet for all objects.
#'
#' @details See \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)} or the
#' \href{https://borea17.github.io/paper_summaries/iwae/}{blog post} for a summary of the VLB and IWAE.
#'
#' @param explanation_list A list of [explain()] objects applied to the same data and model,
#' all with `approach = "vaeac"`. If the entries in the list are named, then the function uses
#' these names. Otherwise, it defaults to the approach names (with integer suffixes for duplicates)
#' for the explanation objects in `explanation_list`.
#' @param plot_from_nth_epoch Integer. If we are only to plot the results from the nth epoch onwards.
#' The values in the first epochs can be large in absolute value, making the rest of the plot difficult to interpret.
#' @param plot_every_nth_epoch Integer. If we are only to plot every nth epoch. Useful to illustrate
#' the overall trend, as there can be a lot of fluctuation and oscillation in the values between each epoch.
#' @param facet_wrap_scales String. Should the scales be fixed ("`fixed`", the default),
#' free ("`free`"), or free in one dimension ("`free_x`", "`free_y`").
#' @param facet_wrap_ncol Integer. Number of columns in the facet wrap.
#' @param criteria Character vector. The possible options are "VLB", "IWAE", "IWAE_running". Default is the first two.
#' @param plot_type Character vector. The possible options are "method" and "criterion". Default is to plot both.
#'
#' @return Either a single [ggplot2::ggplot()] object or a list of [ggplot2::ggplot()] objects based on the
#' `plot_type` parameter.
#'
#' @examples
#' \dontrun{
#' library(xgboost)
#' library(data.table)
#' library(shapr)
#'
#' data("airquality")
#' data <- data.table::as.data.table(airquality)
#' data <- data[complete.cases(data), ]
#'
#' x_var <- c("Solar.R", "Wind", "Temp", "Month")
#' y_var <- "Ozone"
#'
#' ind_x_explain <- 1:6
#' x_train <- data[-ind_x_explain, ..x_var]
#' y_train <- data[-ind_x_explain, get(y_var)]
#' x_explain <- data[ind_x_explain, ..x_var]
#'
#' # Fitting a basic xgboost model to the training data
#' model <- xgboost(data = as.matrix(x_train), label = y_train, nround = 100, verbose = FALSE)
#'
#' # Specifying the phi_0, i.e. the expected prediction without any features
#' p0 <- mean(y_train)
#'
#' # Train vaeac with and without paired sampling
#' explanation_paired <- explain(
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "vaeac",
#'   prediction_zero = p0,
#'   n_samples = 1, # As we are only interested in the training of the vaeac
#'   vaeac.epochs = 10, # Should be higher in applications.
#'   vaeac.n_vaeacs_initialize = 1,
#'   vaeac.width = 16,
#'   vaeac.depth = 2,
#'   vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE)
#' )
#'
#' explanation_regular <- explain(
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "vaeac",
#'   prediction_zero = p0,
#'   n_samples = 1, # As we are only interested in the training of the vaeac
#'   vaeac.epochs = 10, # Should be higher in applications.
#'   vaeac.width = 16,
#'   vaeac.depth = 2,
#'   vaeac.n_vaeacs_initialize = 1,
#'   vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE)
#' )
#'
#' # Collect the explanation objects in a named list
#' explanation_list <- list(
#'   "Regular sampling" = explanation_regular,
#'   "Paired sampling" = explanation_paired
#' )
#'
#' # Call the function with the named list; it will use the provided names
#' vaeac_plot_eval_crit(explanation_list = explanation_list)
#'
#' # The function also works if we have only one method,
#' # but then one should only look at the method plot.
#' vaeac_plot_eval_crit(
#'   explanation_list = explanation_list[2],
#'   plot_type = "method"
#' )
#'
#' # Can alter the plot
#' vaeac_plot_eval_crit(
#'   explanation_list = explanation_list,
#'   plot_from_nth_epoch = 2,
#'   plot_every_nth_epoch = 2,
#'   facet_wrap_scales = "free"
#' )
#'
#' # If we only want the VLB
#' vaeac_plot_eval_crit(
#'   explanation_list = explanation_list,
#'   criteria = "VLB",
#'   plot_type = "criterion"
#' )
#'
#' # If we only want the criterion version
#' tmp_fig_criterion <-
#'   vaeac_plot_eval_crit(explanation_list = explanation_list, plot_type = "criterion")
#'
#' # Since tmp_fig_criterion is a ggplot2 object, we can alter it
#' # by, e.g., adding points or smooths with se bands
#' tmp_fig_criterion + ggplot2::geom_point(shape = "circle", size = 1, ggplot2::aes(col = Method))
#' tmp_fig_criterion$layers[[1]] <- NULL
#' tmp_fig_criterion + ggplot2::geom_smooth(method = "loess", formula = y ~ x, se = TRUE) +
#'   ggplot2::scale_color_brewer(palette = "Set1") +
#'   ggplot2::theme_minimal()
#' }
#'
#' @author Lars Henry Berge Olsen
#' @export
vaeac_plot_eval_crit <- function(explanation_list,
                                 plot_from_nth_epoch = 1,
                                 plot_every_nth_epoch = 1,
                                 criteria = c("VLB", "IWAE"),
                                 plot_type = c("method", "criterion"),
                                 facet_wrap_scales = "fixed",
                                 facet_wrap_ncol = NULL) {
  ## Checks
  # Check that ggplot2 is installed
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 is not installed. Please run install.packages('ggplot2')")
  }

  # Check for valid criteria argument
  unknown_criteria <- criteria[!(criteria %in% c("VLB", "IWAE", "IWAE_running"))]
  if (length(unknown_criteria) > 0) {
    stop(paste0(
      "The `criteria` must be one (or several) of 'VLB', 'IWAE', and 'IWAE_running'. ",
      "Do not recognise: '", paste(unknown_plot_type, collapse = "', '"), "'."
    ))
  }

  # Check for valid plot type argument
  unknown_plot_type <- plot_type[!(plot_type %in% c("method", "criterion"))]
  if (length(unknown_plot_type) > 0) {
    stop(paste0(
      "The `plot_type` must be one (or several) of 'method' and 'criterion'. ",
      "Do not recognise: '", paste(unknown_plot_type, collapse = "', '"), "'."
    ))
  }

  # Ensure that even a single explanation object is in a list
  if ("shapr" %in% class(explanation_list)) explanation_list <- list(explanation_list)

  ## Create data.tables
  # Extract the VLB and IWAE
  vaeac_VLB_IWAE_dt <- vaeac_get_evaluation_criteria(explanation_list)

  # Get the relevant criteria
  keep_these_columns <- c("Method", "Epoch", criteria)
  vaeac_VLB_IWAE_dt <- vaeac_VLB_IWAE_dt[, keep_these_columns, with = FALSE]

  # Check for valid `plot_from_nth_epoch`
  max_epoch <- max(vaeac_VLB_IWAE_dt$Epoch)
  if (plot_from_nth_epoch > max_epoch) {
    stop(sprintf(
      "`plot_from_nth_epoch` (%d) is larger than the number of epochs (%d)",
      plot_from_nth_epoch, max_epoch
    ))
  }

  # Remove entries with too low epoch
  vaeac_VLB_IWAE_dt <- vaeac_VLB_IWAE_dt[Epoch >= plot_from_nth_epoch, ]

  # If we are only to plot every nth epoch
  vaeac_VLB_IWAE_dt <- vaeac_VLB_IWAE_dt[Epoch %% plot_every_nth_epoch == 0]
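  # E.g., with `plot_from_nth_epoch = 2` and `plot_every_nth_epoch = 2`, the two
  # filters above keep epochs 2, 4, 6, and so on.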

  # Convert it from wide to long
  vaeac_VLB_IWAE_dt_long <- data.table::melt(
    data = vaeac_VLB_IWAE_dt,
    id.vars = c("Method", "Epoch"),
    variable.name = "Criterion",
    variable.factor = TRUE,
    value.name = "Value"
  )

  ## Plot
  return_object <- list()

  # Make the figure where each explanation object has its own facet
  if ("method" %in% plot_type) {
    return_object$figure_each_method <-
      ggplot2::ggplot(vaeac_VLB_IWAE_dt_long, ggplot2::aes(x = Epoch, y = Value, col = Criterion)) +
      ggplot2::labs(title = "The evaluation criterions for different vaeac models") +
      ggplot2::geom_line(ggplot2::aes(group = Criterion, col = Criterion)) +
      ggplot2::facet_wrap(ggplot2::vars(Method), ncol = facet_wrap_ncol, scales = facet_wrap_scales)
  }

  # Make the figure where each criterion has its own facet
  if ("criterion" %in% plot_type) {
    return_object$figure_each_criterion <-
      ggplot2::ggplot(vaeac_VLB_IWAE_dt_long, ggplot2::aes(x = Epoch, y = Value, col = Method)) +
      ggplot2::labs(title = "The evaluation criterions for different vaeac models") +
      ggplot2::geom_line(ggplot2::aes(group = Method, col = Method)) +
      ggplot2::facet_wrap(ggplot2::vars(Criterion), ncol = facet_wrap_ncol, scales = facet_wrap_scales)
  }

  # If we only made one figure, return that object directly instead of a list
  if (length(return_object) == 1) return_object <- return_object[[1]]

  return(return_object)
}

#' Plot Pairwise Plots for Imputed and True Data
#'
#' @description A function that creates a matrix of plots ([GGally::ggpairs()]) from
#' generated imputations from the unconditioned distribution \eqn{p(\boldsymbol{x})} estimated by
#' a `vaeac` model, and then compares the imputed values with data from the true distribution (if provided).
#' See \href{https://www.blopig.com/blog/2019/06/a-brief-introduction-to-ggpairs/}{ggpairs} for an
#' introduction to [GGally::ggpairs()], and the corresponding
#' \href{https://ggobi.github.io/ggally/articles/ggally_plots.html}{vignette}.
#'
#' @param explanation Shapr list. The output list from the [shapr::explain()] function.
#' @param which_vaeac_model String. Indicating which `vaeac` model to use when generating the samples.
#' Possible options are always `'best'`, `'best_running'`, and `'last'`. All possible options can be obtained
#' by calling `names(explanation$internal$parameters$vaeac$models)`.
#' @param x_true Data.table containing the data from the distribution that the `vaeac` model is fitted to.
#' @param upper_cont String. Type of plot to use in upper triangle for continuous features, see [GGally::ggpairs()].
#' Possible options are: `'cor'` (default), `'points'`, `'smooth'`, `'smooth_loess'`, `'density'`, and `'blank'`.
#' @param upper_cat String. Type of plot to use in upper triangle for categorical features, see [GGally::ggpairs()].
#' Possible options are: `'count'` (default), `'cross'`, `'ratio'`, `'facetbar'`, and `'blank'`.
#' @param upper_mix String. Type of plot to use in upper triangle for mixed features, see [GGally::ggpairs()].
#' Possible options are: `'box'` (default), `'box_no_facet'`, `'dot'`, `'dot_no_facet'`, `'facethist'`,
#'  `'facetdensity'`, `'denstrip'`, and `'blank'`
#' @param lower_cont String. Type of plot to use in lower triangle for continuous features, see [GGally::ggpairs()].
#' Possible options are: `'points'` (default), `'smooth'`, `'smooth_loess'`, `'density'`, `'cor'`, and `'blank'`.
#' @param lower_cat String. Type of plot to use in lower triangle for categorical features, see [GGally::ggpairs()].
#' Possible options are: `'facetbar'` (default), `'ratio'`, `'count'`, `'cross'`, and `'blank'`.
#' @param lower_mix String. Type of plot to use in lower triangle for mixed features, see [GGally::ggpairs()].
#' Possible options are: `'facetdensity'` (default), `'box'`, `'box_no_facet'`, `'dot'`, `'dot_no_facet'`,
#'  `'facethist'`, `'denstrip'`, and `'blank'`.
#' @param diag_cont String. Type of plot to use on the diagonal for continuous features, see [GGally::ggpairs()].
#' Possible options are: `'densityDiag'` (default), `'barDiag'`, and `'blankDiag'`.
#' @param diag_cat String. Type of plot to use on the diagonal for categorical features, see [GGally::ggpairs()].
#' Possible options are: `'barDiag'` (default) and `'blankDiag'`.
#' @param cor_method String. Type of correlation measure, see [GGally::ggpairs()].
#' Possible options are: `'pearson'` (default), `'kendall'`, and `'spearman'`.
#' @param add_title Logical. If `TRUE`, then a title is added to the plot based on the internal description
#' of the `vaeac` model specified in `which_vaeac_model`.
#' @param alpha Numeric between `0` and `1` (default is `0.5`). The degree of color transparency.
#'
#' @return A [GGally::ggpairs()] figure.
#' @export
#' @author Lars Henry Berge Olsen
#'
#' @examples
#' \dontrun{
#' library(xgboost)
#' library(data.table)
#' library(shapr)
#'
#' data("airquality")
#' data <- data.table::as.data.table(airquality)
#' data <- data[complete.cases(data), ]
#'
#' x_var <- c("Solar.R", "Wind", "Temp", "Month")
#' y_var <- "Ozone"
#'
#' ind_x_explain <- 1:6
#' x_train <- data[-ind_x_explain, ..x_var]
#' y_train <- data[-ind_x_explain, get(y_var)]
#' x_explain <- data[ind_x_explain, ..x_var]
#'
#' # Fitting a basic xgboost model to the training data
#' model <- xgboost(
#'   data = as.matrix(x_train),
#'   label = y_train,
#'   nround = 100,
#'   verbose = FALSE
#' )
#'
#' explanation <- explain(
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "vaeac",
#'   prediction_zero = mean(y_train),
#'   n_samples = 1,
#'   vaeac.epochs = 10,
#'   vaeac.n_vaeacs_initialize = 1
#' )
#'
#' # Plot the results
#' figure <- vaeac_plot_imputed_ggpairs(
#'   explanation = explanation,
#'   which_vaeac_model = "best",
#'   x_true = x_train,
#'   add_title = TRUE
#' )
#' figure
#'
#' # Note that this is a ggplot2 object which we can alter, e.g., by changing the colors.
#' figure +
#'   ggplot2::scale_color_manual(values = c("#E69F00", "#999999")) +
#'   ggplot2::scale_fill_manual(values = c("#E69F00", "#999999"))
#' }
vaeac_plot_imputed_ggpairs <- function(
    explanation,
    which_vaeac_model = "best",
    x_true = NULL,
    add_title = TRUE,
    alpha = 0.5,
    upper_cont = c("cor", "points", "smooth", "smooth_loess", "density", "blank"),
    upper_cat = c("count", "cross", "ratio", "facetbar", "blank"),
    upper_mix = c("box", "box_no_facet", "dot", "dot_no_facet", "facethist", "facetdensity", "denstrip", "blank"),
    lower_cont = c("points", "smooth", "smooth_loess", "density", "cor", "blank"),
    lower_cat = c("facetbar", "ratio", "count", "cross", "blank"),
    lower_mix = c("facetdensity", "box", "box_no_facet", "dot", "dot_no_facet", "facethist", "denstrip", "blank"),
    diag_cont = c("densityDiag", "barDiag", "blankDiag"),
    diag_cat = c("barDiag", "blankDiag"),
    cor_method = c("pearson", "kendall", "spearman")) {
  # Check that ggplot2 and GGally are installed
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 is not installed. Please run install.packages('ggplot2')")
  }
  if (!requireNamespace("GGally", quietly = TRUE)) {
    stop("GGally is not installed. Please run install.packages('GGally')")
  }

  # Check all input parameters except `which_vaeac_model`
  if (!"shapr" %in% class(explanation)) stop("`explanation` must be an object of type `shapr`.")
  if (!is.null(x_true) && !is.data.table(x_true)) stop("`x_true` must be an object of type `data.table`.")
  vaeac_check_logicals(list(add_title = add_title))
  vaeac_check_probabilities(list(alpha = alpha))
  upper_cont <- match.arg(upper_cont)
  upper_cat <- match.arg(upper_cat)
  upper_mix <- match.arg(upper_mix)
  lower_cont <- match.arg(lower_cont)
  lower_cat <- match.arg(lower_cat)
  lower_mix <- match.arg(lower_mix)
  diag_cont <- match.arg(diag_cont)
  diag_cat <- match.arg(diag_cat)
  cor_method <- match.arg(cor_method)

  # Check if the vaeac model is expected to give a reasonable figure.
  if (!explanation$internal$parameters$exact || explanation$internal$parameters$is_groupwise) {
    message(
      "The vaeac model has not been trained on the empty colition, hence, the figure can be missleading. ",
      "The figure is only reasonable if 'n_combintations = NULL' and 'group = NULL' in the explanation call."
    )
  }

  # Extract the vaeac list from the explanation list
  vaeac_list <- explanation$internal$parameters$vaeac

  # Check that `which_vaeac_model` is a valid vaeac model name and then load the vaeac checkpoint
  if (!is.character(which_vaeac_model) || !which_vaeac_model %in% names(vaeac_list$models)) {
    stop(paste0(
      "The parameter `which_vaeac_model` ('", which_vaeac_model, "') must be one of the following: '",
      paste(names(vaeac_list$models), collapse = "', '"), "'."
    ))
  }
  vaeac_model_path <- vaeac_list$models[[which_vaeac_model]]
  checkpoint <- torch::torch_load(vaeac_model_path)

  # Get the number of observations in `x_true` and the number of features
  n_samples <- if (is.null(x_true)) 500 else nrow(x_true)
  n_features <- checkpoint$n_features

  # Checking for valid dimension
  if (!is.null(x_true) && ncol(x_true) != n_features) {
    stop(paste0(
      "Different number of columns in the vaeac model (", n_features, ") and `x_true` (", ncol(x_true), ")."
    ))
  }

  # Set up the vaeac model
  vaeac_model <- vaeac_get_model_from_checkp(checkpoint = checkpoint, cuda = checkpoint$cuda, mode_train = FALSE)

  # Impute the missing entries using the vaeac approach. Here we generate x from p(x), so no conditioning.
  imputed_values <- vaeac_impute_missing_entries(
    x_explain_with_NaNs = matrix(NaN, n_samples, checkpoint$n_features),
    n_samples = 1,
    vaeac_model = vaeac_model,
    checkpoint = checkpoint,
    sampler = explanation$internal$parameters$vaeac.sampler,
    batch_size = n_samples,
    verbose = explanation$internal$parameters$verbose,
    seed = explanation$internal$parameters$seed
  )

  # Combine the true (if there are any) and imputed data and ensure that the categorical features are marked as factors.
  combined_data <- data.table(rbind(x_true, imputed_values))
  col_cat_names <- checkpoint$col_cat_names
  if (length(col_cat_names) > 0) combined_data[, (col_cat_names) := lapply(.SD, as.factor), .SDcols = col_cat_names]

  # Add a type variable representing whether the observations are imputed samples or from `x_true`
  combined_data$type <-
    factor(rep(c("True", "Imputed"), times = c(ifelse(is.null(nrow(x_true)), 0, nrow(x_true)), n_samples)))

  # Create the ggpairs figure and potentially add title based on the description of the used vaeac model
  figure <- GGally::ggpairs(
    combined_data,
    columns = seq(n_features),
    mapping = ggplot2::aes(color = type),
    diag = list(continuous = GGally::wrap(diag_cont, alpha = alpha), discrete = diag_cat),
    upper = list(combo = upper_mix, discrete = upper_cat, continuous = GGally::wrap(upper_cont, method = cor_method)),
    lower = list(combo = lower_mix, discrete = lower_cat, continuous = GGally::wrap(lower_cont, alpha = alpha))
  )
  if (add_title) figure <- figure + ggplot2::ggtitle(tools::file_path_sans_ext(basename(vaeac_model_path)))

  return(figure)
}