# R/internal.R

# Defines functions rnd get_mc_dropout_cat_counts get_confusion_matrix get_cat_count get_weighted_errors evaluate_iucnn rank_models process_iucnn_input log_results subsample_n_per_class bnn_predict calculate_accuracy run_MCMC MCMC_setup create_BNN_model bnn_load_data get_footp impute_missing_values cat_bool

#' @importFrom reticulate import
#' @importFrom utils read.table
#' @importFrom magrittr %>%
#' @importFrom dplyr select left_join mutate
#' @importFrom utils write.table
#' @importFrom stats na.omit
#' @importFrom missForest missForest


cat_bool <- function(x) {
  # Classify a feature column: 1 if every value is a whole number
  # (treated as categorical), 0 otherwise (continuous).
  # The result is numeric; it is NA when x contains NA values.
  is_whole <- round(x) == x
  prod(is_whole)
}

impute_missing_values <- function(df) {
  # Impute missing feature values with missForest.
  #
  # df: data frame whose first column identifies the instance (it is neither
  #   imputed nor converted); the remaining columns are the features.
  #
  # Feature columns whose complete-case values are all whole numbers
  # (see cat_bool()) are imputed as categorical (factor) variables.
  # Returns `df` with every feature column imputed and converted back to
  # numeric. Categorical columns keep their original values, not factor
  # level indices.
  features <- df
  feature_cols <- names(features)[-1]

  # Decide per column (1 = categorical, 0 = continuous) using rows that are
  # complete across all features; vapply() avoids apply()'s matrix coercion
  # and drop = FALSE keeps a single feature column as a data frame.
  complete_features <- na.omit(features[, feature_cols, drop = FALSE])
  categorical_boolean <- vapply(complete_features, cat_bool, numeric(1))

  # convert categorical columns to factors so missForest classifies them
  for (name in feature_cols[categorical_boolean == 1]) {
    features[, name] <- as.factor(unlist(features[, name]))
  }

  # impute missing values with missForest
  xmis_ob <- as.data.frame(features[, feature_cols, drop = FALSE])
  missForest_imputation <- missForest(xmis = xmis_ob,
                                      maxiter = 10,
                                      ntree = 100)

  # as.numeric() on a factor returns the level index, so factors are mapped
  # through their levels, which hold the original (numeric) values.
  to_numeric <- function(col) {
    if (is.factor(col)) {
      as.numeric(levels(col))[col]
    } else {
      as.numeric(col)
    }
  }

  # update the feature df with imputed, numeric values
  new_features_df <- features
  new_features_df[, feature_cols] <- lapply(missForest_imputation$ximp,
                                            to_numeric)
  new_features_df
}


# Download Human footprint data
get_footp <- function(x, file_path) {
  # Download and unzip the human footprint raster "HFP<x>" into `file_path`,
  # unless "HFP<x>.tif" already exists there.
  #
  # x: identifier pasted into the file name (e.g. a year).
  # file_path: directory holding the .tif and used for the temporary .zip.
  #
  # Returns NULL invisibly; called for its side effect on `file_path`.
  tif_path <- file.path(file_path, paste0("HFP", x, ".tif"))
  if (file.exists(tif_path)) {
    return(invisible(NULL))
  }

  zip_url <- paste0("https://wcshumanfootprint.org/data/HFP", x, ".zip")
  zip_path <- file.path(file_path, paste0("HFP", x, ".zip"))

  # remove the archive afterwards, also if download or extraction fails
  on.exit(unlink(zip_path), add = TRUE)

  # binary mode so the archive is not corrupted on Windows
  utils::download.file(zip_url, destfile = zip_path, mode = "wb")
  utils::unzip(zip_path, exdir = file_path)
  invisible(NULL)
}

# BNN helpers
bnn_load_data <- function(features,
                          labels,
                          seed = 1234,
                          testsize = 0.1, # 10% test set
                          all_class_in_testset = TRUE,
                          randomize_order = TRUE,
                          header = TRUE, # input data has a header
                          instance_id = TRUE,
                          from_file = FALSE){
  # Load features and labels through np_bnn.get_data() (via reticulate)
  # and return the resulting Python data object.
  #
  # The seed and the 0/1 switches all_class_in_testset, header and
  # instance_id (whether the input contains a column of instance names)
  # are coerced to integers. testsize, randomize_order and from_file are
  # forwarded as given.
  np_bnn <- reticulate::import("np_bnn")

  np_bnn$get_data(
    features,
    labels,
    seed = as.integer(seed),
    testsize = testsize,
    all_class_in_testset = as.integer(all_class_in_testset),
    randomize_order = randomize_order,
    header = as.integer(header),
    instance_id = as.integer(instance_id),
    from_file = from_file
  )
}


create_BNN_model <- function(feature_data,
                             n_nodes_list,
                             seed = 1234,
                             use_class_weight = TRUE,
                             use_bias_node = 3,
                             actfun = 'swish',
                             prior = 1, # 0) uniform, 1) normal, 2) Cauchy, 3) Laplace
                             p_scale = 1, # std for Normal, scale parameter for Cauchy and Laplace, boundaries for Uniform
                             init_std = 0.1) { # st dev of the initial weights
  # Build an np_bnn.npBNN model (via reticulate) and return the Python object.
  #
  # feature_data: data object passed straight to npBNN (presumably the
  #   result of bnn_load_data()).
  # n_nodes_list: number of nodes per hidden layer.
  # actfun: name passed to np_bnn.ActFun(fun = ...).
  bn <- reticulate::import("np_bnn")

  bn$npBNN(feature_data,
           # reticulate turns a length-1 R vector into a Python scalar;
           # a list always arrives as a Python list, even with one layer
           n_nodes = as.list(as.integer(n_nodes_list)),
           use_class_weights = as.integer(use_class_weight),
           actFun = bn$ActFun(fun = actfun),
           use_bias_node = as.integer(use_bias_node),
           prior_f = as.integer(prior),
           # a continuous scale: as.integer() would truncate e.g. 0.5 to 0
           p_scale = as.numeric(p_scale),
           seed = as.integer(seed),
           init_std = init_std)
}


MCMC_setup <- function(bnn_model,
                       update_f,
                       update_ws,
                       adapt_f,
                       adapt_fM,
                       MCMC_temperature = 1,
                       likelihood_tempering = 1,
                       n_iteration = 1000000,
                       sampling_f = 10, # how often to write to file (every n iterations)
                       print_f = 100, # how often to print to screen (every n iterations)
                       n_post_samples = 1000, # how many samples to keep in log file
                       sample_from_prior = FALSE) {
  # Create an np_bnn.MCMC sampler for `bnn_model` (via reticulate) and
  # return the Python object.
  #
  # update_f, update_ws, adapt_f, adapt_fM, the temperature and
  # likelihood_tempering are forwarded unchanged. Iteration counts and the
  # sample_from_prior flag are coerced to integers.
  np_bnn <- reticulate::import("np_bnn")

  np_bnn$MCMC(
    bnn_model,
    update_f = update_f,
    update_ws = update_ws,
    adapt_f = adapt_f,
    adapt_fM = adapt_fM,
    temperature = MCMC_temperature,
    likelihood_tempering = likelihood_tempering,
    n_iteration = as.integer(n_iteration),
    sampling_f = as.integer(sampling_f),
    print_f = as.integer(print_f),
    n_post_samples = as.integer(n_post_samples),
    sample_from_prior = as.integer(sample_from_prior)
  )
}


run_MCMC <- function(bnn_model,
                     mcmc_object,
                     filename_stem = "BNN",
                     log_all_weights = FALSE){
  # Run the np_bnn MCMC for `bnn_model`, writing posterior samples through
  # an np_bnn.postLogger whose output files use `filename_stem`.
  # Returns the logger object.
  np_bnn <- reticulate::import("np_bnn")

  post_logger <- np_bnn$postLogger(
    bnn_model,
    filename = filename_stem,
    log_all_weights = as.integer(log_all_weights)
  )

  np_bnn$run_mcmc(bnn_model, mcmc_object, post_logger)
  post_logger
}


calculate_accuracy <- function(bnn_data,
                               logger,
                               bnn_model,
                               data = 'test',
                               post_summary_mode = 0){
  # Predict the test or training set of `bnn_data` with the posterior
  # samples stored in the pickle file of `logger`, and return the result of
  # np_bnn.predictBNN().
  #
  # data: 'test' or 'train', selecting which part of `bnn_data` is used.
  # bnn_model: unused; kept for interface compatibility.
  bn <- reticulate::import("np_bnn")

  if (identical(data, 'test')) {
    features <- bnn_data$test_data
    labels <- bnn_data$test_labels
    instance_id <- bnn_data$id_test_data
  } else if (identical(data, 'train')) {
    features <- bnn_data$data
    labels <- bnn_data$labels
    instance_id <- bnn_data$id_data
  } else {
    stop("`data` must be either 'test' or 'train'.", call. = FALSE)
  }

  bn$predictBNN(features,
                # qualified: only reticulate::import is imported by this file
                pickle_file = reticulate::py_get_attr(logger, '_pklfile'),
                test_labels = labels,
                instance_id = instance_id,
                fname = bnn_data$file_name,
                post_summary_mode = as.integer(post_summary_mode))
}


bnn_predict <- function(features,
                        instance_id,
                        model_path,
                        post_cutoff,
                        filename,
                        post_summary_mode = 1){
  # Predict new instances with a pickled np_bnn model (via reticulate) and
  # return the result of np_bnn.predictBNN().
  #
  # features: feature table, converted to a matrix for Python.
  # model_path: path of the pickle file with the posterior samples.
  bn <- reticulate::import("np_bnn")

  bn$predictBNN(as.matrix(features),
                pickle_file = model_path,
                instance_id = instance_id,
                fname = filename,
                post_cutoff = post_cutoff,
                # pass a Python int; the R default 1 would arrive as float 1.0
                post_summary_mode = as.integer(post_summary_mode))
}

subsample_n_per_class <- function(features,
                                  labels,
                                  n_samples,
                                  classes = 0:4){
  # Randomly select `n_samples` instances of every class in `classes`
  # (sampling without replacement) and subset features to match.
  #
  # labels: list whose `labels` element is a data frame with `labels`
  #   (class) and `species` columns.
  # features: data frame with a `species` column.
  # classes: class values to sample, in output order (default 0:4).
  #
  # Returns list(features, labels), where features rows are matched to the
  # sampled species. Errors if a class has fewer than n_samples members.
  label_values <- labels$labels$labels

  picked <- unlist(lapply(classes, function(cl) {
    idx <- which(label_values == cl)
    # index via sample.int: sample(idx) would draw from 1:idx when a class
    # has exactly one member
    idx[sample.int(length(idx), n_samples)]
  }))

  labels$labels <- labels$labels[picked, ]
  target_sp <- labels$labels$species
  features <- features[match(target_sp, features$species), ]
  list(features, labels)
}

log_results <- function(res, logfile, iucnn_model_out, init_logfile = FALSE) {
  # Append one tab-separated line describing a tested IUCNN model to
  # `logfile`.
  #
  # res: model object; only objects inheriting from "iucnn_model" are logged.
  # logfile: path of the tab-separated log file.
  # iucnn_model_out: value written into the final `model_outpath` column.
  # init_logfile: if TRUE, (re)create `logfile` with a header line first. An
  #   existing file is only overwritten after an explicit "y"/"yes" answer.
  #
  # Returns NULL invisibly; called for its side effect on `logfile`.
  if (init_logfile) {
    header <- c("mode",
                "level",
                "dropout_rate",
                "seed",
                "max_epochs",
                "patience",
                "n_layers",
                "use_bias",
                "balance_classes",
                "rescale_features",
                "randomize_instances",
                "mc_dropout",
                "mc_dropout_reps",
                "act_f",
                "act_f_out",
                "cv_fold",
                "test_fraction",
                "label_stretch_factor",
                "label_noise_factor",
                "final_train_epoch_all",
                "final_train_epoch_mean",
                "train_acc",
                "val_acc",
                "test_acc",
                "training_loss",
                "validation_loss",
                "test_loss",
                "confusion_LC",
                "confusion_NT",
                "confusion_VU",
                "confusion_EN",
                "confusion_CR",
                "confusion_0",
                "confusion_1",
                "delta_LC",
                "delta_NT",
                "delta_VU",
                "delta_EN",
                "delta_CR",
                "delta_0",
                "delta_1",
                "model_outpath")
    if (file.exists(logfile)) {
      overwrite_prompt <- readline(prompt = "Specified log-file already exists. Do you want to overwrite? [y/N]: ")
      # Only an explicit yes overwrites. readline() returns "" in
      # non-interactive sessions, which therefore keeps the file.
      if (!tolower(trimws(overwrite_prompt)) %in% c("y", "yes")) {
        stop('Not overwriting existing log-file. Please specify different logfile path or set init_logfile=FALSE')
      }
    }
    cat(header, file = logfile, sep = "\t")
    cat('\n', file = logfile, append = TRUE)
  }

  if (!inherits(res, "iucnn_model")) {
    return(invisible(NULL))
  }

  # first n rows of the confusion matrix, each as an "n1_n2_..." string
  confusion_rows <- function(n) {
    vapply(seq_len(n),
           function(i) paste(res$confusion_matrix[i, ], collapse = '_'),
           character(1))
  }
  # absolute difference between true and predicted per-class counts
  count_deltas <- function(max_cat) {
    abs(get_cat_count(res$test_labels, max_cat = max_cat) -
          get_cat_count(res$test_predictions, max_cat = max_cat))
  }

  if (length(res$input_data$lookup.lab.num.z) == 2) {
    # broad (2-class) model: the five detailed columns are NaN
    label_level <- 'broad'
    ratio_prediction_lines <- c(rep(NaN, 5), count_deltas(1))
    confusion_matrix_lines <- c(rep(NaN, 5), confusion_rows(2))
  } else {
    # detailed (5-class) model: the two broad columns are NaN
    label_level <- 'detail'
    ratio_prediction_lines <- c(count_deltas(4), NaN, NaN)
    confusion_matrix_lines <- c(confusion_rows(5), NaN, NaN)
  }

  cat(c(res$model,
        label_level,
        res$dropout_rate,
        res$seed,
        res$max_epochs,
        res$patience,
        paste(res$n_layers, collapse = '_'),
        res$use_bias,
        res$balance_classes,
        res$rescale_features,
        res$randomize_instances,
        res$mc_dropout,
        res$mc_dropout_reps,
        res$act_f,
        res$act_f_out,
        res$cv_fold,
        res$test_fraction,
        res$label_stretch_factor,
        res$label_noise_factor,
        paste(res$final_training_epoch, collapse = '_'),
        round(mean(res$final_training_epoch), 0),
        round(res$training_accuracy, 6),
        round(res$validation_accuracy, 6),
        round(res$test_accuracy, 6),
        round(res$training_loss, 6),
        round(res$validation_loss, 6),
        round(res$test_loss, 6),
        confusion_matrix_lines,
        ratio_prediction_lines,
        iucnn_model_out), sep = "\t", file = logfile, append = TRUE)
  cat('\n', file = logfile, append = TRUE)
  message(paste0("Model-testing results written to file: ", logfile))
  invisible(NULL)
}


process_iucnn_input <- function(x,
                                lab = NaN,
                                mode = NaN,
                                outpath = '.',
                                write_data_files = FALSE,
                                verbose = 1) {
  # Prepare features (and, for training, labels) for the model back-end.
  #
  # x: data frame of features with a `species` column.
  # lab: list whose `labels` element is a data frame with `species` and
  #   `labels` columns, or NaN (default) when preparing data for prediction.
  # mode: model mode; 'bnn-class' keeps the instance-name column in the
  #   features and adds a dummy `names` column to the labels.
  # outpath, write_data_files: optionally write the prepared inputs as
  #   tab-separated text files into `outpath`.
  # verbose: 1 to emit warnings/messages about removed species and balance.
  #
  # Returns list(dataset, labels, instance_names). Species with incomplete
  # information are dropped.
  if (typeof(lab) == 'double') { # aka if lab=NaN when running from predict_iucnn
    # complete cases only
    tmp.in <- x[stats::complete.cases(x), ]
    if (nrow(tmp.in) != nrow(x)) {
      mis <- x[!stats::complete.cases(x), ]
      if (verbose == 1) {
        warning("Information for species was incomplete, species removed\n", paste(mis$species, "\n"))
      }
    }
    instance_names <- tmp.in$species
    dataset <- dplyr::select(tmp.in, -dplyr::all_of("species"))
    labels <- NaN
  } else {
    ## specific checks
    if (!"species" %in% names(x)) {
      stop("species column not found in x.
           The features input need a column named 'species'
           with the species names matching those in labels")
    }

    # Merge species and labels to match order. left_join keeps every row of
    # x; species without a label get NA and are dropped by complete.cases.
    tmp.in <- dplyr::left_join(x, lab$labels, by = "species")

    # species with features but no label
    mis <- x$species[!x$species %in% lab$labels$species]
    if (length(mis) > 0 && verbose == 1) {
      warning("Labels for species not found, species removed.\n", paste(mis, "\n"))
    }

    # species with a label but no features
    mis <- lab$labels$species[!lab$labels$species %in% x$species]
    if (length(mis) > 0 && verbose == 1) {
      warning("Features for species not found, species removed.\n", paste(mis, "\n"))
    }

    # complete cases only
    tmp <- tmp.in[stats::complete.cases(tmp.in), ]
    if (nrow(tmp) != nrow(tmp.in)) {
      mis <- tmp.in[!stats::complete.cases(tmp.in), ]
      if (verbose == 1) {
        warning("Information for species was incomplete, species removed\n", paste(mis$species, "\n"))
      }
    }

    # check that not all species were removed
    if (nrow(tmp) == 0) {
      stop("Labels and features do not match or there are no species with complete features.")
    }

    # report the number of species
    t1 <- nrow(tmp)
    if (t1 < 200 && verbose == 1) {
      warning("The number of training taxa is low, consider including more species")
    }
    if (verbose == 1) {
      message(sprintf("%s species included in model training", t1))
    }

    # check class balance
    t2 <- table(tmp$labels)
    balance_ratio <- max(t2) / min(t2)
    if (balance_ratio > 3 && verbose == 1) {
      warning("Classes unbalanced")
    }
    if (verbose == 1) {
      message(sprintf("Class max/min representation ratio: %s", round(balance_ratio, 1)))
    }

    # prepare input data for the python function
    dataset <- dplyr::select(tmp, -dplyr::all_of(c("species", "labels")))
    if (isTRUE(mode == 'bnn-class')) {
      # keep every column except the last one (the instance names stay in
      # the features); assumes the joined `labels` column is last
      dataset <- tmp[, seq_along(names(tmp)) - 1]
    }

    instance_names <- dplyr::select(tmp, dplyr::all_of("species"))
    labels <- dplyr::select(tmp, dplyr::all_of("labels"))

    # prepare labels to start at 0
    min_label <- min(labels$labels)
    if (min_label != 0) {
      if (verbose == 1) {
        warning(sprintf("Labels need to start at 0. Labels substracted with %s",
                        min_label))
      }
      labels$labels <- labels$labels - min_label
    }

    if (isTRUE(mode == 'bnn-class')) {
      # in the current npbnn function we need to add a dummy column of instance names
      labels[['names']] <- rep('sp.', length(labels$labels))
      labels <- labels[, c('names', 'labels')]
    }
  }

  if (write_data_files) {
    write.table(as.matrix(dataset),
                file.path(outpath, 'iucnn_input_features.txt'),
                sep = '\t',
                quote = FALSE,
                row.names = FALSE)
    if (typeof(lab) == 'list') {
      write.table(as.matrix(labels),
                  file.path(outpath, 'iucnn_input_labels.txt'),
                  sep = '\t',
                  quote = FALSE,
                  row.names = FALSE)
    }
    write.table(as.matrix(instance_names),
                file.path(outpath, 'iucnn_input_instance_names.txt'),
                sep = '\t',
                quote = FALSE,
                row.names = FALSE)
    write.table(names(dataset),
                file.path(outpath, 'iucnn_input_feature_names.txt'),
                sep = '\t',
                quote = FALSE,
                row.names = FALSE)
  }

  list(dataset, labels, instance_names)
}


rank_models <- function(model_testing_results, rank_by = "val_acc") {
  # Sort model-testing results (one row per tested model) so that the best
  # model comes first. Options for rank_by:
  #   "val_acc"             highest validation accuracy first
  #   "val_loss"            lowest validation loss first
  #   "weighted_error"      smallest class-distance-weighted error first
  #                         (adds a `weighted_error` column)
  #   "total_class_matches" smallest summed per-class count deviation first
  #                         (adds a `total_class_error` column)
  # Detailed (5-class) results are recognised by a character `confusion_LC`
  # column; otherwise the broad (2-class) columns are used.
  arrange_rows <- function(tbl, key, decreasing) {
    tbl[order(key, decreasing = decreasing), ]
  }

  switch(
    as.character(rank_by),
    val_acc = arrange_rows(model_testing_results,
                           model_testing_results$val_acc, TRUE),
    val_loss = arrange_rows(model_testing_results,
                            model_testing_results$validation_loss, FALSE),
    weighted_error = {
      if (typeof(model_testing_results$confusion_LC) == "character") {
        true_index_by_col <- c(confusion_LC = 1, confusion_NT = 2,
                               confusion_VU = 3, confusion_EN = 4,
                               confusion_CR = 5)
      } else {
        true_index_by_col <- c(confusion_0 = 1, confusion_1 = 2)
      }
      per_class_errors <- lapply(names(true_index_by_col), function(col) {
        get_weighted_errors(model_testing_results, col,
                            true_index_by_col[[col]])
      })
      model_testing_results["weighted_error"] <- Reduce(`+`, per_class_errors)
      arrange_rows(model_testing_results,
                   model_testing_results$weighted_error, FALSE)
    },
    total_class_matches = {
      if (typeof(model_testing_results$confusion_LC) == "character") {
        delta_cols <- c("delta_LC", "delta_NT", "delta_VU",
                        "delta_EN", "delta_CR")
      } else {
        delta_cols <- c("delta_0", "delta_1")
      }
      model_testing_results["total_class_error"] <-
        rowSums(model_testing_results[, delta_cols])
      arrange_rows(model_testing_results,
                   model_testing_results$total_class_error, FALSE)
    },
    stop(paste0("Invalid choice rank_by = '",
                rank_by,
                "'. Choose from 'val_acc','val_loss','weighted_error' ,or 'total_class_matches'"))
  )
}

evaluate_iucnn <- function(res) {
  # Run summary() and plot() on a trained model and return its accuracy
  # threshold table (res$accthres_tbl) invisibly.
  #
  # Warns when the model has no dropout (dropout_rate == 0).
  #
  # Class-frequency sampling (get_mc_dropout_cat_counts) is not run here. It
  # needs the MC-dropout probabilities, the label dictionary and the
  # mc_dropout flag; calling it with only `res` always errored, and its
  # result was never used.
  if (isTRUE(res$dropout_rate == 0)) {
    warning("No acc-thres-tbl and class-freq calculation. Provide model with dropout_rate > 0 to enable these functions.")
  }
  summary(res)
  plot(res)
  invisible(res$accthres_tbl)
}

get_weighted_errors <- function(model_testing_results,
                                colname = "confusion_LC",
                                true_index = 1) {
  # Compute a distance-weighted misclassification error per result row.
  #
  # Each cell of `colname` holds one confusion-matrix row as a string of
  # counts joined by "_" ("n1_n2_..._nk"). The error of a cell is
  # sum(|j - true_index| * n_j), so counts further from the true class
  # (position `true_index`) cost more.
  #
  # Returns a numeric vector with one value per row (numeric(0) when there
  # are no rows). [[ ]] also works for tibbles; as.character() handles
  # factor columns.
  cells <- as.character(model_testing_results[[colname]])
  count_rows <- lapply(strsplit(cells, "_", fixed = TRUE), as.numeric)
  vapply(count_rows,
         function(counts) sum(abs(seq_along(counts) - true_index) * counts),
         numeric(1))
}

get_cat_count <- function(target_vector,
                          max_cat = 4, include_NA = FALSE) {
  # Count occurrences of each class 0..max_cat in target_vector.
  # Returns a numeric vector named "0".."max_cat" in that order; absent
  # classes count as 0 and values outside 0..max_cat are ignored.
  # With include_NA = TRUE, an extra "NA" entry counts the NA/NaN values.
  categories <- as.character(0:max_cat)
  # fixed factor levels give a zero entry for every absent class
  tallies <- table(factor(target_vector, levels = categories))
  result <- as.numeric(tallies)
  names(result) <- categories
  if (include_NA) {
    result <- c(result, "NA" = sum(is.na(target_vector)))
  }
  result
}






get_confusion_matrix <- function(best_model) {
  # Rebuild the confusion matrix of a model from its model-testing results
  # row. Each confusion_* column holds one matrix row as counts joined by "_".
  #
  # A character `confusion_LC` column marks a detailed model, giving a 5x5
  # data frame labelled LC/NT/VU/EN/CR. Otherwise the result is a 2x2 data
  # frame built from confusion_0/confusion_1 and labelled
  # "Not threatened"/"Threatened". Rows and columns use identical labels.
  # Only the first row of `best_model` is used.
  if (typeof(best_model$confusion_LC) == "character") {
    target_cols <- c("confusion_LC", "confusion_NT", "confusion_VU",
                     "confusion_EN", "confusion_CR")
    class_labels <- c("LC", "NT", "VU", "EN", "CR")
  } else {
    target_cols <- c("confusion_0", "confusion_1")
    class_labels <- c("Not threatened", "Threatened")
  }

  # [[ ]] plus as.character() also handles tibbles and factor columns
  count_rows <- lapply(target_cols, function(col) {
    cell <- as.character(best_model[[col]][1])
    as.integer(strsplit(cell, "_", fixed = TRUE)[[1]])
  })

  confusion_matrix <- as.data.frame(do.call(rbind, count_rows))
  rownames(confusion_matrix) <- class_labels
  names(confusion_matrix) <- class_labels
  confusion_matrix
}


get_mc_dropout_cat_counts <- function(mc_dropout_probs,
                                      label_dict,
                                      mc_dropout,
                                      true_lab = NaN,
                                      nreps = 1000) {
  # Sample class counts from MC-dropout class probabilities.
  #
  # mc_dropout_probs: one row per instance, one column per class (0-based
  #   classes, length(label_dict) of them); used as sampling weights.
  # label_dict: only its length (the number of classes) is used.
  # mc_dropout: if not TRUE, a warning is issued and both results are NaN.
  # true_lab: optional true labels; their class counts are returned as
  #   `true_class_count` (NaN if true_lab is NaN).
  # nreps: number of sampling replicates.
  #
  # Returns a list with `predicted_class_count`, an nreps x n_classes matrix
  # whose row r holds the class counts over all instances in replicate r,
  # and `true_class_count`.
  if (!isTRUE(as.logical(mc_dropout))) {
    warning("This model contains no MC-dropout predictions for unseen data.
            No sampled_cat_freqs can be calculated for this model.")
    return(list(predicted_class_count = NaN, true_class_count = NaN))
  }

  nlabs <- length(label_dict)
  classes <- seq_len(nlabs) - 1

  if (is.nan(true_lab[1])) {
    true_cat_count <- NaN
  } else {
    true_cat_count <- get_cat_count(true_lab, max_cat = nlabs - 1)
  }

  probs <- as.matrix(mc_dropout_probs)

  # Column i holds nreps class draws for instance i. The matrix is
  # preallocated by vapply instead of grown with c(); sample.int avoids
  # sample()'s 1:x behaviour for a single class.
  draws <- vapply(seq_len(nrow(probs)), function(i) {
    classes[sample.int(nlabs, nreps, replace = TRUE, prob = probs[i, ])]
  }, numeric(nreps))
  draws <- matrix(draws, nrow = nreps)

  # column r: class counts over all instances in replicate r
  counts <- vapply(seq_len(nreps), function(r) {
    get_cat_count(draws[r, ], max_cat = nlabs - 1)
  }, numeric(nlabs))

  list(predicted_class_count = unname(t(matrix(counts, nrow = nlabs))),
       true_class_count = true_cat_count)
}

rnd <- function(x) {
  # Round halves away from zero (2.5 -> 3, -2.5 -> -3), unlike base
  # round(), which follows IEC 60559 and rounds halves to even.
  sign(x) * floor(abs(x) + 0.5)
}
# azizka/IUCNN documentation built on March 29, 2024, 9:38 a.m.