sandbox/Lrnr_pkg_condensier_logisfitR6.R

## Allows us to use any learner in sl3 as a bin learner in condensier
## This additional injection works when Lrnr_condensier$new(bin_estimator =
## Lrnr_base) is a learner from sl3 package.
##
## In this, the sl3 learner 'Lrnr_base' object is injected into the wraper class
## below. This R6 object then provides a communication link between the two
## packages (sl3 <-> condensier).

##  NEED TO UPDATE THIS BASE CLASS FROM CONDENSIER TO ADD WEIGHTS

#' sl3 Learner wrapper for condensier
#'
#' This wrapper allows the use of any \code{sl3} Learner as a Learner for
#' \code{condensier}. For details, see the \code{\link[condensier]{fit_density}}
#' function.
#'
#' @docType class
#'
#' @export
#
Lrnr_pkg_condensier_logisfitR6 <- R6Class(
  "Lrnr_pkg_condensier_logisfitR6",
  inherit = condensier::logisfitR6,
  public = list(
    lmclass = NULL,
    fitfunname = NULL,

    initialize = function(sl3_lrnr, ...) {
      assert_that(is(sl3_lrnr, "Lrnr_base"))
      self$lmclass <- paste0("sl3::", class(sl3_lrnr)[1L])
      self$fitfunname <- paste0("sl3::", class(sl3_lrnr)[1L], "$train")
      private$sl3_lrnr <- sl3_lrnr
      self$get_validity
      super$initialize(...)
    },

    fit = function(datsum_obj) {
      verbose <- getOption("sl3.verbose")
      if (verbose) print(paste("calling ", self$fitfunname))
      X_mat <- datsum_obj$getXDT
      Y_vals <- datsum_obj$getY
      # capture weights conditionally
      if (!is.null(wts <- datsum_obj$getweights)) {
        dataDT <- cbind(X_mat, weights = wts, Y = Y_vals)
      } else {
        dataDT <- cbind(X_mat, Y = Y_vals)
      }
      sl3_lrnr <- private$sl3_lrnr
      if (nrow(dataDT) > 0) {
        task <- sl3_Task$new(
          data = dataDT,
          covariates = colnames(X_mat),
          outcome = colnames(dataDT)[ncol(dataDT)],
          weights = if (!is.null(wts)) "weights"
        )
        out <- sl3_lrnr <- try(suppressWarnings(sl3_lrnr$train(task)))
        if (inherits(sl3_lrnr, "try-error") ||
          inherits(sl3_lrnr$fit_object, "try-error")) {
          # if failed, use fall back learner
          if (verbose) {
            message(paste0(
              "learner ", self$fitfunname,
              " failed, trying private$fallback_learner; "
            ))
            if (inherits(sl3_lrnr, "Lrnr_base")) {
              print(sl3_lrnr$name)
            } else {
              print(sl3_lrnr)
            }
          }
          sl3_lrnr <- private$fallback_learner$new()$train(task)
        }
      }
      if (verbose) print(sl3_lrnr)
      return(sl3_lrnr)
    },

    predict = function(datsum_obj, m.fit) {
      verbose <- getOption("sl3.verbose")
      X_mat <- datsum_obj$getXDT
      assert_that(!is.null(X_mat))
      assert_that(!is.null(datsum_obj$subset_idx))
      pAout <- rep.int(NA_real_, datsum_obj$n)
      if (sum(datsum_obj$subset_idx > 0)) {
        new_task <- sl3_Task$new(
          X_mat,
          covariates = colnames(X_mat),
          outcome = NULL
        )
        ## Learner hasn't been trained,
        ## means we are making predictionsin for a last degenerate bin.
        ## These predictions play no role and aren't used.
        if (is(m.fit, "Lrnr_base") && !m.fit$is_trained) {
          pAout[datsum_obj$subset_idx] <- 0.5
        } else {
          pAout[datsum_obj$subset_idx] <- m.fit$predict(new_task)[[1]]
        }
      }
      return(pAout)
    }
  ),
  private = list(
    fallback_learner = Lrnr_glm_fast,
    sl3_lrnr = NULL
  )
)
jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.