R/3_SPEARobject.R

Defines functions new.spear

Documented in new.spear

# Define SPEAR class:
# R6 class bundling the training data, model/CV parameters, initial
# coefficient values and run-time options of a SPEAR analysis. The method
# functions attached at the bottom (add.data, train.spear, ...) are defined
# elsewhere in the package and are only wired into the class here.
SPEAR <- R6::R6Class("SPEAR",
  public = list(
    # Named list of datasets; initialize() adds the training data through
    # add.data(..., name = "train") and afterwards reads it as `$train`.
    data = NULL,
    # List of model / cross-validation parameters assembled in initialize().
    params = NULL,
    # List of initial coefficient values (a0, b0, a1, b1, a2, b2, L1, L2,
    # robust.eps).
    inits = NULL,
    # List of run-time options (quiet, current.response.idx,
    # current.weight.idx, parallel.method).
    options = NULL,

    # Update with run.spear or run.cv.spear
    fit = NULL,
    # Update with get.model
    model = NULL,

    # Functions:

    # Constructor, called by SPEAR$new(...).
    #
    # Arguments (see also new.spear(), which forwards to this function):
    #   data, response  - passed to add.data(); `response` identifies the
    #                     column(s) of colData() used as outcome(s).
    #   use.ordinal     - for factor responses that are not all two-level:
    #                     TRUE selects the "ordinal" family, otherwise
    #                     "multinomial" is used.
    #   weights.case    - one weight per sample (default: all 1).
    #   weights.x/.y    - paired weight grids of equal length.
    #   num.factors ... seed, num.folds, num.cores
    #                   - stored unchanged in `self$params`.
    #   a0 ... robust.eps
    #                   - initial coefficients; NULL selects the defaults
    #                     computed below.
    #   quiet           - if TRUE, the progress messages written with cat()
    #                     in this function are suppressed.
    #
    # Stops with an error when the responses have differing or unsupported
    # classes, when the weight vectors have mismatching lengths, or when
    # weights.case does not have one entry per sample.
    initialize = function(
      # data:
      data = NULL,
      response = NULL,
      use.ordinal = FALSE,
      # weights:
      weights.case = NULL,
      weights.x = NULL,
      weights.y = NULL,
      # model parameters:
      num.factors = 5,
      inits.type = "pca",
      inits.post.mu = NULL,
      sparsity.upper = 0.5,
      warm.up = 100,
      max.iter = 1000,
      thres.elbo = 0.01,
      thres.count = 5,
      thres.factor = 1e-8,
      print.out = 100,
      seed = 123,
      # cv parameters:
      num.folds = 5,
      num.cores = NULL,
      # coefficients:
      a0 = NULL,
      b0 = NULL,
      a1 = NULL,
      b1 = NULL,
      a2 = NULL,
      b2 = NULL,
      L1 = NULL,
      L2 = NULL,
      robust.eps = NULL,
      quiet = FALSE
    ) {
      # ---- Options ----
      options <- list()
      # quiet - should extra print statements be silenced? Defaults to FALSE
      options$quiet <- quiet
      # index of the response / weight pair currently selected
      options$current.response.idx <- 1
      options$current.weight.idx <- 1
      # parallel.method - use parLapply, change if desired to lapply or mclapply
      options$parallel.method <- "parLapply"
      self$options <- options

      if (!quiet) {
        cat("----------------------------------------------------------------\n")
        cat("SPEAR version 2.0.0   Please direct all questions to Jeremy Gygi\n(jeremy.gygi@yale.edu) or Leying Guan (leying.guan@yale.edu)\n", sep = "")
        cat("----------------------------------------------------------------\n")
        cat("Generating SPEAR object...\n")
      }

      # ---- Data ----
      if (!quiet) cat("$data...\t")
      self$data <- list()
      self$add.data(data = data, response = response, name = "train")
      if (!quiet) cat("Done!\n")

      # ---- Parameters ----
      if (!quiet) cat("$params...\t")
      params <- list()

      # Without a response the family cannot be determined; fail with an
      # explicit message instead of an obscure comparison error further down.
      if (length(response) == 0) {
        stop("ERROR: `response` must identify at least one column of `MultiAssayExperiment::colData(...)`.")
      }

      # Sample annotation of the training data; it is not modified below, so
      # it is looked up once and reused.
      coldata <- MultiAssayExperiment::colData(self$data$train)
      n.samples <- nrow(coldata)

      # Family (gaussian, binomial, ordinal, multinomial)
      # One type string per response. Every factor (including ordered
      # factors, whose class vector has two entries) is reported as "factor"
      # so that the comparison below sees exactly one value per response.
      response.types <- vapply(response, function(tmp) {
        column <- coldata[[tmp]]
        if (is.factor(column)) "factor" else class(column)[1]
      }, character(1))

      if (length(unique(response.types)) > 1) {
        stop("ERROR: multiple response require the same class (i.e. all Gaussian, all Multinomial, etc.).\nEnsure that the `class(MultiAssayExperiment::colData(...)[[response]])` is the same for each.")
      }
      # all responses share one type, so the first is representative:
      response.type <- response.types[[1]]

      if (response.type == "numeric") {
        family <- "gaussian"
        family.encoded <- 0
        nclasses <- rep(2, length(response))
      } else if (response.type == "factor") {
        all.two.level <- all(vapply(response, function(tmp) {
          length(levels(coldata[[tmp]])) == 2
        }, logical(1)))
        # the number of levels is taken from the first response only
        n.levels <- length(levels(coldata[[response[1]]]))
        if (all.two.level) {
          family <- "binomial"
          family.encoded <- 1
          nclasses <- 2
        } else if (isTRUE(use.ordinal)) {
          family <- "ordinal"
          family.encoded <- 2
          nclasses <- n.levels
        } else {
          family <- "multinomial"
          family.encoded <- 3
          nclasses <- rep(2, n.levels)
        }
      } else {
        stop("ERROR: response class not recognized. Must be of type numeric (for Gaussian) or factor (for binomial/ordinal/multinomial)")
      }
      # store:
      params$response <- response
      params$family <- family
      params$family.encoded <- family.encoded
      params$nclasses <- nclasses

      # TODO: potentially run factor exploration here?
      params$num.factors <- num.factors

      # Number of folds:
      params$num.folds <- num.folds

      # Number of cores: (for parallelization)
      params$num.cores <- num.cores

      # Obtain indices that can be used to parse through concatenated assays
      # together: assay d owns a contiguous run of nrow(assay d) indices.
      # seq_len() keeps the loop empty when there are no assays.
      params$assay.indices <- list()
      start.ind <- 1
      for (d in seq_len(length(self$data$train))) {
        end.ind <- start.ind + nrow(self$data$train[[d]]) - 1
        params$assay.indices[[d]] <- start.ind:end.ind
        start.ind <- end.ind + 1
      }

      # Weights: fill in whichever of weights.x / weights.y is missing.
      if (is.null(weights.x) && is.null(weights.y)) {
        weights.x <- c(0, .1, .5, 1, 2)
        weights.y <- rep(1, length(weights.x))
      } else if (is.null(weights.x)) {
        weights.x <- rep(1, length(weights.y))
      } else if (is.null(weights.y)) {
        weights.y <- rep(1, length(weights.x))
      } else if (length(weights.x) != length(weights.y)) {
        stop("ERROR: lengths of weights.x and weights.y do not match. They need to have the same length.")
      }
      # Rows are sorted by decreasing w.x. drop = FALSE keeps a single weight
      # pair as a 1 x 2 matrix so the dimnames below can still be assigned.
      params$weights <- cbind(weights.x, weights.y)
      params$weights <- params$weights[order(params$weights[, 1], decreasing = TRUE), , drop = FALSE]
      colnames(params$weights) <- c("w.x", "w.y")
      rownames(params$weights) <- seq_len(nrow(params$weights))

      if (is.null(weights.case)) {
        params$weights.case <- rep(1, n.samples)
      } else if (length(weights.case) != n.samples) {
        stop("ERROR: Supplied weights.case needs to match the number of columns in the provided data (number of samples).")
      } else {
        params$weights.case <- weights.case
      }

      # Seed
      params$seed <- seed

      # Misc. Parameters:
      params$inits.type <- inits.type
      params$inits.post.mu <- inits.post.mu
      params$sparsity.upper <- sparsity.upper
      params$warm.up <- warm.up
      params$max.iter <- max.iter
      params$thres.elbo <- thres.elbo
      params$thres.count <- thres.count
      params$thres.factor <- thres.factor
      params$print.out <- print.out

      # Save:
      self$params <- params

      # ---- Initial Coefficients ----
      if (!quiet) cat("Done!\n")
      if (!quiet) cat("$inits...\t")

      # user-supplied value if given, otherwise the default
      default.if.null <- function(value, fallback) {
        if (is.null(value)) fallback else value
      }
      inits <- list()
      inits$a0 <- default.if.null(a0, 1e-2)
      inits$b0 <- default.if.null(b0, 1e-2)
      inits$a1 <- default.if.null(a1, sqrt(n.samples))
      inits$b1 <- default.if.null(b1, sqrt(n.samples))
      inits$a2 <- default.if.null(a2, sqrt(n.samples))
      inits$b2 <- default.if.null(b2, sqrt(n.samples))
      inits$L1 <- default.if.null(L1, n.samples / log(n.samples))
      inits$L2 <- default.if.null(L2, .1)
      inits$robust.eps <- default.if.null(robust.eps, 1.0 / n.samples)

      # Random:
      # -- colors:
      private$set.color.scheme()

      # Save:
      # NOTE(review): params is assigned a second time after
      # set.color.scheme(), which would overwrite any change that method
      # made to self$params -- confirm this is intended.
      self$params <- params
      self$inits <- inits

      if (!quiet) {
        cat("Done!\n")
        cat("SPEAR object generated!\n\n")
      }
      private$print.out(type = "data", remove.formatting = self$options$remove.formatting, quiet = self$options$quiet)
      # the separator line is only useful when output is not silenced
      if (!quiet) cat("\n")
      private$print.out(type = "help", remove.formatting = self$options$remove.formatting, quiet = self$options$quiet)
    },

    # print function: shows the "data" and "help" summaries and returns the
    # object invisibly.
    print = function(...) {
      private$print.out(type = "data", remove.formatting = self$options$remove.formatting, quiet = self$options$quiet)
      cat("\n")
      private$print.out(type = "help", remove.formatting = self$options$remove.formatting, quiet = self$options$quiet)
      return(invisible(self))
    },

    # Method functions:
    add.data = add.data,
    assess = assess,
    cv.evaluate = cv.evaluate,
    estimate.num.factors = estimate.num.factors,
    generate.fold.ids = generate.fold.ids,
    get.analyte.scores = get.analyte.scores,
    get.cv.loss = get.cv.loss,
    get.data = get.data,
    get.factor.contributions = get.factor.contributions,
    get.factor.scores = get.factor.scores,
    get.predictions = get.predictions,
    get.signatures = get.signatures,
    get.variance.explained = get.variance.explained,
    plot.cv.loss = plot.cv.loss,
    plot.factor.contributions = plot.factor.contributions,
    plot.factor.scores = plot.factor.scores,
    plot.variance.explained = plot.variance.explained,
    remove.data = remove.data,
    set.response = set.response,
    set.weights = set.weights,
    train.spear = train.spear
    #SPEAR.help = SPEAR.help
  ), # end public
  private = list(
    check.fit = check.fit,
    check.fold.ids = check.fold.ids,
    create.params = create.params,
    encode.ordinal = encode.ordinal,
    generate.spear.ids = generate.spear.ids,
    get.concatenated.X = get.concatenated.X,
    get.widx.from.method = get.widx.from.method,
    get.Y = get.Y,
    one.hot.encode.multinomial = one.hot.encode.multinomial,
    print.out = print.out,
    spear = spear,
    set.color.scheme = set.color.scheme,
    update.dimnames = update.dimnames
    #impute.z = impute.z,
  )
)

#' Create a SPEAR object
#'
#' Convenience constructor: forwards every argument, by name and unchanged,
#' to `SPEAR$new()` and returns the resulting R6 object.
#'
#' @param data Training data, handed to the object's `add.data()` method and
#'   stored as the `"train"` dataset (presumably a MultiAssayExperiment --
#'   confirm against `add.data()`).
#' @param response Identifier(s) of the column(s) of
#'   `MultiAssayExperiment::colData()` used as response. All responses must
#'   have the same class: numeric (gaussian) or factor
#'   (binomial / ordinal / multinomial).
#' @param use.ordinal If `TRUE`, factor responses that are not all two-level
#'   are modelled as "ordinal" instead of "multinomial".
#' @param weights.case One weight per sample; defaults to 1 for every sample.
#' @param weights.x,weights.y Weight vectors of equal length. If both are
#'   `NULL`, `weights.x` defaults to `c(0, .1, .5, 1, 2)` and `weights.y` to
#'   all 1s; if only one is given the other is filled with 1s.
#' @param num.factors,inits.type,inits.post.mu,sparsity.upper,warm.up,max.iter,thres.elbo,thres.count,thres.factor,print.out,seed
#'   Model parameters, stored unchanged in the object's `$params` list.
#' @param num.folds,num.cores Cross-validation parameters (number of folds,
#'   number of cores for parallelization), stored in `$params`.
#' @param a0,b0 Initial coefficients; default `1e-2` when `NULL`.
#' @param a1,b1,a2,b2 Initial coefficients; default to the square root of the
#'   number of samples when `NULL`.
#' @param L1 Initial coefficient; defaults to `n / log(n)` (n = number of
#'   samples) when `NULL`.
#' @param L2 Initial coefficient; defaults to `.1` when `NULL`.
#' @param robust.eps Initial coefficient; defaults to `1 / n` when `NULL`.
#' @param quiet If `TRUE`, progress messages printed during construction are
#'   suppressed.
#' @return The object created by `SPEAR$new()`.
#' @export
new.spear <- function(
  # data:
  data = NULL, response = NULL, use.ordinal = FALSE,
  # weights:
  weights.case = NULL, weights.x = NULL, weights.y = NULL,
  # model parameters:
  num.factors = 5, inits.type = "pca", inits.post.mu = NULL,
  sparsity.upper = 0.5, warm.up = 100, max.iter = 1000,
  thres.elbo = 0.01, thres.count = 5, thres.factor = 1e-8,
  print.out = 100, seed = 123,
  # cv parameters:
  num.folds = 5, num.cores = NULL,
  # coefficients:
  a0 = NULL, b0 = NULL, a1 = NULL, b1 = NULL, a2 = NULL, b2 = NULL,
  L1 = NULL, L2 = NULL, robust.eps = NULL,
  quiet = FALSE
) {
  # The constructor call is the last expression, so its value is returned.
  SPEAR$new(
    # data:
    data = data, response = response, use.ordinal = use.ordinal,
    # weights:
    weights.case = weights.case, weights.x = weights.x, weights.y = weights.y,
    # model parameters:
    num.factors = num.factors, inits.type = inits.type,
    inits.post.mu = inits.post.mu, sparsity.upper = sparsity.upper,
    warm.up = warm.up, max.iter = max.iter,
    thres.elbo = thres.elbo, thres.count = thres.count,
    thres.factor = thres.factor, print.out = print.out, seed = seed,
    # cv parameters:
    num.folds = num.folds, num.cores = num.cores,
    # coefficients:
    a0 = a0, b0 = b0, a1 = a1, b1 = b1, a2 = a2, b2 = b2,
    L1 = L1, L2 = L2, robust.eps = robust.eps,
    quiet = quiet
  )
}
jgygi/SPEAR documentation built on July 5, 2023, 5:35 p.m.