# R/gpb.Predictor.R

#' @importFrom methods is
#' @importFrom R6 R6Class
#' @importFrom utils read.delim
# Internal R6 wrapper around a native booster handle that is used to compute
# predictions. The handle is either created from a model file (in which case
# this object owns it and frees it in finalize()) or borrowed from an existing
# "gpb.Booster.handle" (in which case it is never freed here).
#
# NOTE: several native routines called below receive an R object as their last
# argument (cur_iter, npred, preds) and the R code reads that object afterwards
# without re-assigning it, i.e. the native side is expected to fill it in place.
Predictor <- R6::R6Class(

  classname = "gpb.Predictor",
  cloneable = FALSE,
  public = list(

    # Release the native booster handle, but only when this object created it
    # (model loaded from a file in initialize()). Returns NULL invisibly.
    finalize = function() {

      # Borrowed handles (need_free_handle == FALSE) are left untouched
      if (private$need_free_handle) {

        .Call(
          LGBM_BoosterFree_R
          , private$handle
        )
        private$handle <- NULL
        # finalize() is public and may be called manually before the garbage
        # collector runs it; clearing the flag keeps a second call from passing
        # a NULL handle to the native free routine
        private$need_free_handle <- FALSE

      }

      return(invisible(NULL))

    },

    # Set up the predictor.
    #   modelfile: either a character value that is passed to the native
    #              "create from model file" routine, or an existing object of
    #              class "gpb.Booster.handle"
    #   params:    list of parameters, serialised with gpb.params2str() and
    #              forwarded to the native prediction routines later on
    # Stops with an error for any other type of 'modelfile'.
    initialize = function(modelfile, params = list()) {
      private$params <- gpb.params2str(params = params)
      handle <- NULL

      if (is.character(modelfile)) {

        # Load the model from file; this object now owns the handle
        handle <- .Call(
          LGBM_BoosterCreateFromModelfile_R
          , modelfile
        )
        private$need_free_handle <- TRUE

      } else if (methods::is(modelfile, "gpb.Booster.handle")) {

        # Re-use a handle owned by someone else; never free it from here
        handle <- modelfile
        private$need_free_handle <- FALSE

      } else {

        stop("gpb.Predictor: modelfile must be either a character filename or an gpb.Booster.handle")

      }

      # Override class and store it
      class(handle) <- "gpb.Booster.handle"
      private$handle <- handle

      return(invisible(NULL))

    },

    # Query the native booster for its current iteration and return it
    current_iter = function() {

      # Output slot handed to the native routine (expected to be filled there)
      cur_iter <- 0L
      .Call(
        LGBM_BoosterGetCurrentIteration_R
        , private$handle
        , cur_iter
      )
      return(cur_iter)

    },

    # Compute predictions.
    #   data:            a single character value (treated as the name of a
    #                    data file), a base matrix, or a "dgCMatrix"
    #   start_iteration: NULL is replaced by 0L
    #   num_iteration:   NULL is replaced by -1L
    #   rawscore, predleaf, predcontrib:
    #                    flags, coerced to integer and forwarded to the native
    #                    prediction routines
    #   header:          flag, only forwarded when predicting from a file
    #   reshape:         if TRUE and more than one value per row is produced,
    #                    return a matrix instead of a flat vector
    # Returns a numeric vector, or (when predleaf or predcontrib is set, or
    # when reshape applies) a matrix with one row per row of 'data', filled
    # row-wise. Stops if 'data' has an unsupported class or if the number of
    # predictions is not a multiple of the number of rows.
    predict = function(data,
                       start_iteration = NULL,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       predcontrib = FALSE,
                       header = FALSE,
                       reshape = FALSE) {

      # Defaults for the iteration window
      if (is.null(num_iteration)) {
        num_iteration <- -1L
      }
      if (is.null(start_iteration)) {
        start_iteration <- 0L
      }

      num_row <- 0L

      if (identical(class(data), "character") && length(data) == 1L) {

        # 'data' is a file name: the native routine writes its result to a
        # temporary file, which is removed again when this function exits
        tmp_filename <- tempfile(pattern = "gpboost_")
        on.exit(unlink(tmp_filename), add = TRUE)

        .Call(
          LGBM_BoosterPredictForFile_R
          , private$handle
          , data
          , as.integer(header)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , private$params
          , tmp_filename
        )

        # Read the tab-separated result back; one line per input row.
        # t() + as.vector() flattens it row by row.
        preds <- utils::read.delim(tmp_filename, header = FALSE, sep = "\t")
        num_row <- nrow(preds)
        preds <- as.vector(t(preds))

      } else {

        # Reject unsupported inputs before any native call is made: for such
        # objects nrow() may be NULL, and as.integer(NULL) must not be handed
        # to native code
        data_is_dense <- is.matrix(data)
        data_is_csc <- !data_is_dense && methods::is(data, "dgCMatrix")
        if (!data_is_dense && !data_is_csc) {
          stop(
            "predict: cannot predict on data of class "
            , paste(sQuote(class(data)), collapse = ", ")
          )
        }

        num_row <- nrow(data)

        # Output slot for the total number of predicted values
        npred <- 0L

        .Call(
          LGBM_BoosterCalcNumPredict_R
          , private$handle
          , as.integer(num_row)
          , as.integer(rawscore)
          , as.integer(predleaf)
          , as.integer(predcontrib)
          , as.integer(start_iteration)
          , as.integer(num_iteration)
          , npred
        )

        # Pre-allocate the output vector passed to the native predictor
        preds <- numeric(npred)

        if (data_is_dense) {
          # this if() prevents the memory and computational costs
          # of converting something that is already "double" to "double"
          if (storage.mode(data) != "double") {
            storage.mode(data) <- "double"
          }
          .Call(
            LGBM_BoosterPredictForMat_R
            , private$handle
            , data
            , as.integer(nrow(data))
            , as.integer(ncol(data))
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
          )

        } else {

          # Sparse matrix in column compressed format (dgCMatrix)
          if (length(data@p) > 2147483647L) {
            stop("Cannot support large CSC matrix")
          }
          .Call(
            LGBM_BoosterPredictForCSC_R
            , private$handle
            , data@p
            , data@i
            , data@x
            , length(data@p)
            , length(data@x)
            , nrow(data)
            , as.integer(rawscore)
            , as.integer(predleaf)
            , as.integer(predcontrib)
            , as.integer(start_iteration)
            , as.integer(num_iteration)
            , private$params
            , preds
          )

        }
      }

      # The flat prediction vector must split evenly over the input rows
      if (length(preds) %% num_row != 0L) {
        stop(
          "predict: prediction length "
          , sQuote(length(preds))
          , " is not a multiple of nrows(data): "
          , sQuote(num_row)
        )
      }

      # Number of predicted values per input row
      npred_per_case <- length(preds) / num_row

      # Reshaping is mandatory for predleaf / predcontrib output and optional
      # (on request, when there is more than one value per row) otherwise.
      # '||' is used because these are scalar flags inside an if() condition.
      needs_matrix <- predleaf || predcontrib || (reshape && npred_per_case > 1L)
      if (needs_matrix) {
        preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
      }

      return(preds)

    }

  ),
  private = list(
    # Native booster handle (class "gpb.Booster.handle"), NULL until set
    handle = NULL
    # TRUE only when the handle was created here and must be freed here
    , need_free_handle = FALSE
    # Parameter string produced by gpb.params2str() in initialize()
    , params = ""
  )
)

# Try the gpboost package in your browser
#
# Any scripts or data that you put into this service are public.
#
# gpboost documentation built on Oct. 24, 2023, 9:09 a.m.