get_class_weight: Estimate frequency of different classes

View source: R/preprocess.R

Description

Counts the number of nucleotides in each class and uses these counts to estimate the class distribution. Returns a list of class weights that can be passed to the class_weight argument of the train_model function.
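The idea can be illustrated with a minimal sketch in base R. It assumes weights inversely proportional to each class's share of nucleotides, scaled so that perfectly balanced data would give every class a weight of 1. This formula is an assumption and may differ from what get_class_weight actually computes; the counts are made up to mirror the example at the end of this page (twice as much data for the first class).

```r
# Hypothetical nucleotide counts per class: class "A" has twice as much
# sequence data as class "B"
count_nuc <- c(A = 300, B = 150)

# Assumed weighting scheme: inverse class frequency, scaled so that
# balanced data would give every class a weight of 1
total <- sum(count_nuc)
num_classes <- length(count_nuc)
class_weight <- as.list(total / (num_classes * count_nuc))

class_weight
# A = 0.75, B = 1.5
```

Under this scheme the over-represented class receives the smaller weight, so both classes contribute roughly equally to the loss.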

Usage

get_class_weight(
  path,
  vocabulary_label = NULL,
  format = "fasta",
  file_proportion = 1,
  train_type = "label_folder",
  named_list = FALSE,
  csv_path = NULL
)

Arguments

path

Path to training data. If train_type is label_folder, should be a vector or list where each entry corresponds to a class (list elements can be directories and/or individual files). If train_type is not label_folder, can be a single directory or file or a list of directories and/or files.

vocabulary_label

Character vector of possible targets. Targets outside vocabulary_label will get discarded.
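A minimal sketch of how per-class nucleotide counts could be tallied while discarding targets outside vocabulary_label. The labels and sequence lengths are hypothetical (for example, labels parsed from fasta headers with train_type = "label_header"); this illustrates the behaviour described above and is not the package's implementation.

```r
# Hypothetical labels (e.g. parsed from fasta headers) and the length
# of each corresponding sequence in nucleotides
labels <- c("A", "B", "C", "A", "B")
seq_lengths <- c(10, 10, 10, 20, 5)
vocabulary_label <- c("A", "B")

# Discard sequences whose label is not in vocabulary_label ("C" here)
keep <- labels %in% vocabulary_label

# Sum nucleotides per class; the factor levels fix the class order
count_nuc <- tapply(seq_lengths[keep],
                    factor(labels[keep], levels = vocabulary_label),
                    sum)
count_nuc
# A = 30, B = 15
```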

format

File format, either "fasta" or "fastq".

file_proportion

Proportion of files to randomly sample for estimating class distributions.
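Sampling a proportion of the files could look like the following sketch. The rounding rule (round up, keeping at least one file) is an assumption, and the file names are hypothetical.

```r
set.seed(42)

# Hypothetical files belonging to one class
files <- sprintf("file_%02d.fasta", 1:10)
file_proportion <- 0.5

# Assumed rule: round up, but always keep at least one file
n_sample <- max(1, ceiling(file_proportion * length(files)))
sampled_files <- sample(files, n_sample)
sampled_files
```

Lower values make the estimate cheaper to compute on large data sets, at the cost of a noisier estimate of the class distribution.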

train_type

One of "lm", "lm_rds" or "masked_lm" for a language model; "label_header", "label_folder", "label_csv" or "label_rds" for classification; or "dummy_gen".

  • A language model is trained to predict character(s) in a sequence.

  • With "label_header", "label_folder" or "label_csv", the model is trained to predict the corresponding class given a sequence as input.

  • If "label_header", the class is read from the fasta headers.

  • If "label_folder", the class is read from the folder, i.e. all files in one folder must belong to the same class.

  • If "label_csv", targets are read from a csv file. This file should have one column named "file". The targets then correspond to the entries in that file's row (except the "file" column). Example: if we are currently working with a file called "a.fasta" and the corresponding label is "label_1", the csv file should contain the row

    file      label_1 label_2
    "a.fasta" 1       0
  • If "label_rds", the generator iterates over a set of .rds files, each containing a list of input and target tensors. Not implemented for models with multiple inputs.

  • If "lm_rds", the generator iterates over a set of .rds files and splits each tensor according to the target_len argument (targets are the last target_len nucleotides of each sequence).

  • If "dummy_gen", the generator creates random data once and repeatedly feeds it to the model.

  • If "masked_lm", the generator masks some parts of the input. See the masked_lm argument for details.
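For train_type = "label_csv", a labels file in the format described above can be written with base R. The file names and label columns are hypothetical; csv_path would then point to the written file.

```r
# Hypothetical files and their one-hot encoded classes
labels_df <- data.frame(
  file = c("a.fasta", "b.fasta", "c.fasta"),
  label_1 = c(1, 0, 1),
  label_2 = c(0, 1, 0)
)

# Write the csv file that csv_path would point to
csv_path <- tempfile(fileext = ".csv")
write.csv(labels_df, csv_path, row.names = FALSE)

# Each row holds one file name followed by its targets
read.csv(csv_path)
```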

named_list

Whether to give the class weight list the names "0", "1", ... or not.
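The two list shapes can be sketched in base R. Zero-based names match the way Keras' fit() expects class_weight, as a mapping from class index to weight, which is presumably why the option exists; the weight values here are made up.

```r
# Hypothetical class weights for two classes
weights <- c(0.75, 1.5)

# named_list = FALSE: a plain, unnamed list
unnamed <- as.list(weights)

# named_list = TRUE: names "0", "1", ... (zero-based class indices)
named <- setNames(as.list(weights), as.character(seq_along(weights) - 1))
names(named)
# "0" "1"
```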

csv_path

If train_type = "label_csv", path to csv file containing labels.

Value

A list of numeric values (class weights).

Examples


library(deepG)

# create dummy data
path_1 <- tempfile()
path_2 <- tempfile()

for (current_path in c(path_1, path_2)) {
  
  dir.create(current_path)
  # create twice as much data for first class
  num_files <- ifelse(current_path == path_1, 6, 3)
  create_dummy_data(file_path = current_path,
                    num_files = num_files,
                    seq_length = 10,
                    num_seq = 5,
                    vocabulary = c("a", "c", "g", "t"))
}


class_weight <- get_class_weight(
  path = c(path_1, path_2),
  vocabulary_label = c("A", "B"),
  format = "fasta",
  file_proportion = 1,
  train_type = "label_folder",
  csv_path = NULL)

class_weight


GenomeNet/deepG documentation built on Dec. 24, 2024, 12:11 p.m.