# R/gt_snmf.R
#
# Defines functions: extract_cross_entropy, gt_snmf
#
# Documented in: gt_snmf

#' Run SNMF from R in tidypopgen
#'
#' @details This is a wrapper for the function snmf from R package LEA. The
#'   text printed by `LEA::snmf()` is captured, and the `.Q` and `.G` files
#'   written for every combination of `k` and run are read back from the
#'   `.snmf` directory that sits next to the input `.geno` file.
#'
#' @param x a `gen_tibble` or a character giving the path to the input geno
#'   file (the path must exist and end in `.geno`)
#' @param k an integer giving the number of clusters
#' @param project one of "continue", "new", and "force": "continue" stores files
#'   in the current project, "new" creates a new project, and "force" stores
#'   results in the current project even if the .geno input file has been
#'   altered.
#' @param n_runs the number of runs for each k value (defaults to 1)
#' @param alpha numeric snmf regularization parameter. See LEA::snmf for
#'   details. No default is set by this wrapper.
#' @param tolerance numeric value of tolerance (default 0.00001)
#' @param entropy boolean indicating whether to estimate cross-entropy
#' @param percentage numeric value indicating percentage of masked genotypes,
#'   ranging between 0 and 1, to be used when entropy = TRUE
#' @param I number of SNPs for initialising the snmf algorithm
#' @param iterations numeric integer for maximum iterations (default 200)
#' @param ploidy the ploidy of the input data (defaults to 2)
#' @param seed the seed for the random number generator. Unless `NULL`, it
#'   must contain exactly `n_runs` values.
#' @return an object of class `gt_admix` consisting of a list with the following
#'   elements:
#' - `k` the number of clusters used for each stored run
#' - `Q` a list of `q_matrix` objects with the admixture proportions, one per
#'    combination of `k` and run
#' - `P` an empty list (not populated by this function)
#' - `G` a list with the contents of the `.G` files written by `LEA::snmf()`
#'    (see the LEA documentation), one per combination of `k` and run
#' - `log` a log of the output generated by `LEA::snmf()` (usually printed
#'    on the screen)
#' - `cv` the masked cross-entropy (if `entropy` is TRUE)
#' - `loglik` an empty numeric vector (not populated by this function)
#' - `id` the id column of the input `gen_tibble` (if applicable)
#' - `group` the group column of the input `gen_tibble` (if applicable)
#' - `algorithm` the string "SNMF"
#' @export
#' @examplesIf rlang::is_installed("LEA")
#' # run the example only if we have the package installed
#' example_gt <- load_example_gt("gen_tbl")
#'
#' # To run SNMF on a gen_tibble:
#' example_gt %>% gt_snmf(
#'   k = 1:3, project = "force", entropy = TRUE,
#'   percentage = 0.5, n_runs = 1, seed = 1, alpha = 100
#' )
gt_snmf <- function(
    x,
    k,
    project = "continue",
    n_runs = 1,
    alpha,
    tolerance = 0.00001,
    entropy = FALSE,
    percentage = 0.05,
    I, # nolint
    iterations = 200,
    ploidy = 2,
    seed = -1) {
  # one seed per run is required (or NULL)
  if (!is.null(seed) && length(seed) != n_runs) {
    stop("'seed' should be a vector of length 'n_runs'")
  }

  # validate x up front, so that bad input fails fast and before any file is
  # written or any optional dependency is needed
  is_gen_tbl <- inherits(x, "gen_tbl")
  if (!is_gen_tbl) {
    if (!inherits(x, "character")) {
      stop(paste(
        "x must be a gen_tibble or a character giving the path to the",
        "input geno file"
      ))
    }
    if (!file.exists(x)) {
      stop("The file ", x, " does not exist")
    }
    # the extension has to be a literal ".geno" at the very end of the path
    if (!grepl("\\.geno$", x)) {
      stop("The input file must be a .geno file")
    }
  }

  # LEA is an optional dependency (distributed through Bioconductor)
  if (!requireNamespace("LEA", quietly = TRUE)) {
    stop(
      "to use this function, first install package 'LEA' with\n",
      "BiocManager::install('LEA')"
    )
  }

  if (is_gen_tbl) {
    # write the genotypes to a .geno file and expand its path to a full path
    input_file <- normalizePath(gt_as_geno_lea(x))
  } else {
    input_file <- x
  }
  # strip only the trailing extension: an unanchored ".geno" pattern would
  # also hit e.g. "xgeno" inside a directory name
  out_file <- sub("\\.geno$", "", input_file)
  file_name <- sub("\\.geno$", "", basename(input_file))

  # cast k as an integer
  k <- as.integer(k)

  # initialise list to store results
  adm_list <- list(
    k = integer(),
    Q = list(),
    P = list(),
    log = list(),
    loglik = numeric(),
    G = list()
  )
  class(adm_list) <- c("gt_admix", "list")

  # entropy and percentage are only handed to LEA when entropy is requested
  if (entropy) {
    snmf_res <- utils::capture.output(LEA::snmf(
      input.file = input_file,
      K = k,
      project = project,
      repetitions = n_runs,
      alpha = alpha,
      tolerance = tolerance,
      entropy = entropy,
      percentage = percentage,
      I = I,
      iterations = iterations,
      ploidy = ploidy,
      seed = seed
    ))
  } else {
    snmf_res <- utils::capture.output(LEA::snmf(
      input.file = input_file,
      K = k,
      project = project,
      repetitions = n_runs,
      alpha = alpha,
      tolerance = tolerance,
      I = I,
      iterations = iterations,
      ploidy = ploidy,
      seed = seed
    ))
  }

  # read one result table of a given run; files are laid out as
  # <out_file>.snmf/K<k>/run<rep>/<file_name>_r<rep>.<k>.<extension>
  read_run_table <- function(this_k, this_rep, extension) {
    utils::read.table(
      paste0(
        out_file, ".snmf/K", this_k,
        "/run", this_rep,
        "/", file_name, "_r", this_rep, ".", this_k, ".", extension
      ),
      header = FALSE
    )
  }

  # loop over values of k and number of repeats
  index <- 1
  for (this_k in k) {
    for (this_rep in seq_len(n_runs)) {
      adm_list$k[index] <- this_k
      adm_list$Q[[index]] <- q_matrix(read_run_table(this_k, this_rep, "Q"))
      adm_list$G[[index]] <- q_matrix(read_run_table(this_k, this_rep, "G"))
      index <- index + 1
    }
  }

  # add log
  adm_list$log <- snmf_res

  # add entropy to cv slot
  if (entropy) {
    # parse the masked cross-entropy of every run out of the captured log
    cross_entropy <- extract_cross_entropy(snmf_res)
    adm_list$cv <- cross_entropy$cross_entropy_masked
  }

  # add metadata if x is a gen_tibble
  if (is_gen_tbl) {
    adm_list$id <- x$id
    # if it is grouped, add the group
    if (inherits(x, "grouped_gen_tbl")) {
      adm_list$group <- x[[dplyr::group_vars(x)]]
    }
  }

  # add info on algorithm
  adm_list$algorithm <- "SNMF"

  adm_list
}


# Internal function for extracting cross-entropy from the log output
#
# `log_text` is a character vector with one element per line of captured
# output. The lines are scanned in order: a "sNMF K = <k>  repetition <r>"
# line sets the current K and repetition, a "Cross-Entropy (all data):" line
# sets the current all-data value, and every "Cross-Entropy (masked data):"
# line emits one row holding the current state plus its own value.
#
# Returns a tibble with columns `K`, `repetition`, `cross_entropy_all` and
# `cross_entropy_masked`, with one row per masked-data line (zero rows if
# there is none). Values that have not been seen yet, or that cannot be
# parsed, are NA.
extract_cross_entropy <- function(log_text) {
  header_pattern <- "sNMF K = (\\d+)  repetition (\\d+)"
  all_prefix <- "Cross-Entropy \\(all data\\):"
  masked_prefix <- "Cross-Entropy \\(masked data\\):"
  value_suffix <- "\\s+([0-9.]+)"

  # full match followed by the capture groups; character(0) if no match
  capture <- function(pattern, line) {
    regmatches(line, regexec(pattern, line))[[1]]
  }

  # every masked-data line yields exactly one row, so the output vectors can
  # be allocated once instead of growing a tibble inside the loop
  is_masked <- grepl(masked_prefix, log_text)
  n_rows <- sum(is_masked)
  k_out <- rep(NA_integer_, n_rows)
  repetition_out <- rep(NA_integer_, n_rows)
  all_out <- rep(NA_real_, n_rows)
  masked_out <- rep(NA_real_, n_rows)

  # parser state starts as NA, so an entropy line that appears before any
  # header line gives NA rather than an "object not found" error
  current_k <- NA_integer_
  current_repetition <- NA_integer_
  current_all <- NA_real_

  row <- 0L
  for (i in seq_along(log_text)) {
    line <- log_text[i]

    # Match the "sNMF K = x repetition y" line
    header <- capture(header_pattern, line)
    if (length(header) > 0) {
      current_k <- as.integer(header[2])
      current_repetition <- as.integer(header[3])
    }

    # Match the "Cross-Entropy (all data):" line
    if (grepl(all_prefix, line)) {
      current_all <- as.numeric(
        capture(paste0(all_prefix, value_suffix), line)[2]
      )
    }

    # Match the "Cross-Entropy (masked data):" line and emit a row
    if (is_masked[i]) {
      row <- row + 1L
      k_out[row] <- current_k
      repetition_out[row] <- current_repetition
      all_out[row] <- current_all
      masked_out[row] <- as.numeric(
        capture(paste0(masked_prefix, value_suffix), line)[2]
      )
    }
  }

  tibble(
    K = k_out,
    repetition = repetition_out,
    cross_entropy_all = all_out,
    cross_entropy_masked = masked_out
  )
}

# Try the tidypopgen package in your browser.
#
# Any scripts or data that you put into this service are public.
#
# tidypopgen documentation built on Aug. 28, 2025, 1:08 a.m.