# R/Lrnr_density_hse.R

#' Density Estimation With Mean Model and Homoscedastic Errors
#'
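#' This learner assumes a mean model with homoscedastic errors:
#' Y = E(Y|W) + epsilon. E(Y|W) is fit using any mean learner, and then the
#' distribution of the in-sample residuals Y - E(Y|W) is estimated with
#' kernel density estimation. Predicted densities are obtained by evaluating
#' that residual density at Y - E(Y|W) for each observation.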
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom assertthat assert_that is.count is.flag
#'
#' @export
#'
#' @keywords data
#'
#' @return Learner object with methods for training and prediction. See
#'  \code{\link{Lrnr_base}} for documentation on learners.
#'
#' @format \code{\link{R6Class}} object.
#'
#' @family Learners
#'
#' @section Parameters:
#' \describe{
#'   \item{\code{mean_learner}}{The learner used to fit the conditional mean
#'     E(Y|W). If \code{NULL} (the default), \code{Lrnr_glm_fast} is used.}
#' }
#'
#' @template common_parameters
#'
#' @examples
#' # load example data
#' data(cpp_imputed)
#'
#' # create sl3 task
#' task <- sl3_Task$new(
#'   cpp_imputed,
#'   covariates = c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs"),
#'   outcome = "haz"
#' )
#'
#' # train density hse learner and make predictions
#' lrnr_density_hse <- Lrnr_density_hse$new(mean_learner = Lrnr_glm$new())
#' fit_density_hse <- lrnr_density_hse$train(task)
#' preds_density_hse <- fit_density_hse$predict()
Lrnr_density_hse <- R6Class(
  classname = "Lrnr_density_hse",
  inherit = Lrnr_base,
  portable = TRUE,
  class = TRUE,
  public = list(
    # Constructor. When `mean_learner` is NULL a default Lrnr_glm_fast learner
    # is created. Extra arguments are stored in the parameter list next to
    # `mean_learner` and are also forwarded to Lrnr_base$initialize().
    initialize = function(mean_learner = NULL, ...) {
      if (is.null(mean_learner)) mean_learner <- make_learner(Lrnr_glm_fast)
      learner_params <- list(mean_learner = mean_learner, ...)
      super$initialize(params = learner_params, ...)
    }
  ),
  private = list(
    .properties = c("density"),

    # Fit the conditional mean on `task`, then run stats::density() (default
    # settings) on the in-sample residuals Y - E(Y|W). The returned list keeps
    # both fits under `mean_fit` and `dens_fit` for use in .predict().
    .train = function(task) {
      fitted_mean <- self$params$mean_learner$train(task)

      # TODO: these are in-sample residuals; cross-validated residuals may be
      # preferable.
      train_resid <- task$Y - fitted_mean$predict()

      list(
        mean_fit = fitted_mean,
        dens_fit = density(train_resid)
      )
    },

    # Predicted density for each row of `task`: compute the residual
    # Y - E(Y|W) with the trained mean fit, then linearly interpolate the
    # residual KDE at that point. With rule = 2, residuals falling outside the
    # KDE grid receive the density value at the nearest grid endpoint rather
    # than NA.
    .predict = function(task) {
      fits <- self$fit_object
      kde <- fits$dens_fit
      new_resid <- task$Y - fits$mean_fit$predict(task)
      interpolated <- approx(x = kde$x, y = kde$y, xout = new_resid, rule = 2)
      interpolated$y
    },

    .required_packages = c()
  )
)
# jeremyrcoyle/sl3 documentation built on April 30, 2024, 10:16 p.m.