R/ModelGenerator.R

#' Generates an ensemble of neural network models.
#' 
#' \code{\link{ModelGenerator}} generates an ensemble of neural network models
#' each trained to classify cellular phenotypes using the reference data set.
#'
#' @param R Reference data set returned by \code{\link{GetTrainingData_HPCA}}
#' @param N Number of neural networks to train. Default is 1.
#' @param num.cores Number of cores to use for parallel computing. Default is 1.
#' @param verbose If TRUE, progress and timing messages are printed. Default is TRUE.
#' @param hidden Number of hidden neurons in each layer, passed to \code{neuralnet::neuralnet}. Default is 1.
#' @param set.seed If TRUE, the seed is set to ensure reproducibility of the results. Default is TRUE.
#' @param seed Seed value used when set.seed is TRUE. Default is 42.
#' @seealso [SignacFast()] for a function that uses the models generated by this function.
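#' @return A list with one element per classification problem in R; each element holds the gene names (\code{genes}) and the N trained neural network models (\code{classifiers}).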
#' @export
#' @examples
#' \dontrun{
#' # download training data set from GitHub
#' Ref = GetTrainingData_HPCA()
#' 
#' # train 100 neural network models for each classification problem in the reference
#' Models = ModelGenerator(R = Ref, N = 100, num.cores = 4)
#' 
#' # save models
#' save(Models, file = "models.rda")
#' }
ModelGenerator <- function(R, N = 1, num.cores = 1, verbose = TRUE, hidden = 1, set.seed = TRUE, seed = 42)
{
  if (verbose)
  {
    cat(" ..........  Entry in ModelGenerator \n")
    ta = proc.time()[3]
    cat(" ..........  Running ModelGenerator on input reference dataset :\n")
    cat("             classification problems = ", length(R$Reference), "\n", sep = "")
    cat("             total number of classifiers in ensemble = ", N * length(R$Reference), "\n", sep = "")
  }
  
  res = lapply(R$Reference, function(x){
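    # fix the RNG state so that training is reproducible
    if (set.seed)
    {
      RNGkind("L'Ecuyer-CMRG")
      set.seed(seed = seed)
    }
    # train N neural networks on this classification problem
    out = suppressWarnings(pbmcapply::pbmclapply(seq_len(N), function(y) {
      model.fit = neuralnet::neuralnet(celltypes ~ ., hidden = hidden, data = x, act.fct = 'logistic', linear.output = FALSE)
      # strip stored components from the fitted object, which shrinks the returned ensemble
      model.fit$generalized.weights <- NULL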
      model.fit$covariate <- NULL
      model.fit$data <- NULL
      model.fit$response <- NULL
      model.fit$startweights <- NULL
      model.fit$call <- NULL
      model.fit$net.result <- NULL
      model.fit$exclude <- NULL
      model.fit$err.fct <- NULL
      model.fit$result.matrix <- NULL
      return(model.fit)
    }, mc.cores = num.cores))
    # gene names are all columns except the last one (celltypes)
    outs = list(
      genes = colnames(x)[-ncol(x)],
      classifiers = out
    )
    return(outs)
  })
  if (verbose) {
    tb = proc.time()[3] - ta
    cat("\n ..........  Exit ModelGenerator.\n")
    cat("             Execution time = ", tb, " s.\n", sep = "")
  }
  return(res)
}
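
The shape of the returned list can be checked without downloading the HPCA reference. The following is a minimal sketch, not part of the package: it assumes the `ModelGenerator` function above is already defined in the session and that the `neuralnet` and `pbmcapply` packages are installed. The objects `toy`, `ToyRef` and `ToyModels`, and the column names `geneA` and `geneB`, are hypothetical; the only structure taken from the source is that each element of `R$Reference` is a data frame whose last column is `celltypes`.

```r
# Hypothetical toy reference: two "genes" and a binary label in the last column.
set.seed(1)
n <- 40
toy <- data.frame(
  geneA = c(rnorm(n / 2, mean = 0), rnorm(n / 2, mean = 3)),
  geneB = c(rnorm(n / 2, mean = 3), rnorm(n / 2, mean = 0)),
  celltypes = rep(c(0, 1), each = n / 2)
)

# ModelGenerator iterates over R$Reference, so wrap the data frame accordingly.
ToyRef <- list(Reference = list(toy_problem = toy))

# Train two small networks for the single toy classification problem.
ToyModels <- ModelGenerator(R = ToyRef, N = 2, num.cores = 1, verbose = FALSE)

# One element per classification problem, named after the reference entry.
names(ToyModels)              # "toy_problem"

# Gene names are every column of the training data except the last.
ToyModels$toy_problem$genes   # "geneA" "geneB"

# The trained (and stripped) neuralnet objects for this problem.
str(ToyModels$toy_problem$classifiers, max.level = 1)
```

Keeping `num.cores = 1` avoids forking, which is the safer choice on platforms where multi-core execution through `pbmcapply` is unavailable.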

SignacX documentation built on Nov. 18, 2021, 5:07 p.m.