R/model_prediction.R

Defines functions predictMPRModel

Documented in predictMPRModel

#' predictMPRModel
#'
#' Apply a fitted \code{MPRModel} to new data. The prediction function is
#' chosen from the model's \code{modelType} and \code{modelMethod} elements.
#'
#' @param model An \code{MPRModel} object. This is typically returned from a
#'   call to \code{\link{fitMPRModel}}. Its \code{modelType} and
#'   \code{modelMethod} elements (each a single string) select the
#'   method-specific prediction function.
#' @param data The data.frame/matrix that the model will be applied to. It is
#'   checked with \code{checkNA} and \code{checkMatrixOrDF} before use.
#' @param ... Other arguments to be passed to the method-specific prediction
#'   function.
#'
#' @details Supported \code{modelType}/\code{modelMethod} combinations:
#'   \itemize{
#'     \item \code{"binary"}: \code{"glmnet"}, \code{"biglasso"},
#'       \code{"bart"}, \code{"rf"}
#'     \item \code{"survival"}: \code{"glmnet"}, \code{"biglasso"}
#'     \item \code{"continuous"}: \code{"glmnet"}, \code{"biglasso"},
#'       \code{"bart"}, \code{"rf"}
#'   }
#'   Any other combination raises an error naming the supported values.
#'
#' @return The result from the method-specific predict function.
#' @export
predictMPRModel <- function(model, data, ...) {
  checkNA(data)
  checkMatrixOrDF(data)
  # inherits() copes with multi-class objects; comparing a class vector with
  # `!=` inside `if` errors in R >= 4.2 when the vector has length > 1.
  if (!inherits(model, "MPRModel")) {
    stop("predictMPRModel: model must be of class MPRModel\n")
  }
  # Exact-name extraction: `$` would silently partial-match on lists.
  type <- model[["modelType"]]
  method <- model[["modelMethod"]]

  predictFunctionLookup <- list(
    "binary" = list(
      "glmnet" = predictMPRModelglmnet,
      "biglasso" = predictMPRModelbiglasso,
      "bart" = predictMPRModelBinaryBART,
      "rf" = predictMPRModelBinaryRF
    ),
    "survival" = list(
      "glmnet" = predictMPRModelglmnet,
      "biglasso" = predictMPRModelbiglasso
    ),
    "continuous" = list(
      "glmnet" = predictMPRModelglmnet,
      "biglasso" = predictMPRModelbiglasso,
      "bart" = predictMPRModelContinuousBART,
      "rf" = predictMPRModelContinuousRF
    )
  )

  # A usable lookup key is a single, non-missing string.
  isScalarString <- function(x) {
    is.character(x) && length(x) == 1L && !is.na(x)
  }

  # Without these checks an unknown key yields NULL from `[[`, and calling
  # NULL fails with the uninformative "attempt to apply non-function".
  if (!isScalarString(type) || !(type %in% names(predictFunctionLookup))) {
    stop(
      "predictMPRModel: unsupported modelType '", toString(type),
      "'. Supported types: ",
      toString(names(predictFunctionLookup)), "\n"
    )
  }
  methodsForType <- predictFunctionLookup[[type]]
  if (!isScalarString(method) || !(method %in% names(methodsForType))) {
    stop(
      "predictMPRModel: unsupported modelMethod '", toString(method),
      "' for modelType '", type, "'. Supported methods: ",
      toString(names(methodsForType)), "\n"
    )
  }

  predictFunction <- methodsForType[[method]]
  predictFunction(model, data, ...)
}
marioni-group/MethylPipeR documentation built on Oct. 10, 2024, 3:32 p.m.