# R/Lrnr_density_semiparametric.R

#' Density Estimation With Mean Model and Homoscedastic Errors
#'
#' This learner assumes a mean model with homoscedastic errors: Y = E(Y|W) + epsilon,
#' where epsilon is independent of W. E(Y|W) is fit using any mean learner, and the
#' density of the residuals is then estimated by kernel density estimation. If a
#' \code{var_learner} is supplied, the residuals are first standardized by the
#' estimated conditional standard deviation, allowing heteroscedastic
#' (location-scale) errors.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom assertthat assert_that is.count is.flag
#'
#' @export
#'
#' @keywords data
#'
#' @return Learner object with methods for training and prediction. See
#'  \code{\link{Lrnr_base}} for documentation on learners.
#'
#' @format \code{\link{R6Class}} object.
#'
#' @family Learners
#'
#' @section Parameters:
#' \describe{
#'   \item{\code{mean_learner}}{The learner used to fit the conditional mean
#'     E(Y|W). If \code{NULL}, defaults to \code{Lrnr_glm_fast}.}
#'   \item{\code{var_learner}}{Optional learner used to fit the conditional
#'     variance of the errors by regressing the squared residuals on the
#'     covariates. If \code{NULL}, the errors are treated as homoscedastic.}
#'   \item{\code{bw}}{Bandwidth passed to \code{\link[stats]{density}}.
#'     If \code{NULL}, defaults to \code{"nrd0"}.}
#' }
#'
#' @template common_parameters
#'
#' @examples
#' # load example data
#' data(cpp_imputed)
#'
#' # create sl3 task
#' task <- sl3_Task$new(
#'   cpp_imputed,
#'   covariates = c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs"),
#'   outcome = "haz"
#' )
#'
#' # train density hse learner and make predictions
#' lrnr_density_semi <- Lrnr_density_semiparametric$new(
#'   mean_learner = Lrnr_glm$new()
#' )
#' lrnr_density_semi_fit <- lrnr_density_semi$train(task)
#' lrnr_density_semi_pred <- lrnr_density_semi_fit$predict()
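#'
#' # optionally, a variance learner can be supplied so the errors are
#' # standardized before the kernel density fit (a heteroscedastic,
#' # location-scale sketch; Lrnr_glm is just an illustrative choice here)
#' lrnr_density_semi_hetero <- Lrnr_density_semiparametric$new(
#'   mean_learner = Lrnr_glm$new(),
#'   var_learner = Lrnr_glm$new()
#' )
#' lrnr_density_semi_hetero_fit <- lrnr_density_semi_hetero$train(task)
#' lrnr_density_semi_hetero_pred <- lrnr_density_semi_hetero_fit$predict()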
Lrnr_density_semiparametric <- R6Class(
  classname = "Lrnr_density_semiparametric",
  inherit = Lrnr_base, portable = TRUE,
  class = TRUE,
  public = list(
    initialize = function(mean_learner = NULL, var_learner = NULL, bw = NULL, ...) {
      params <- args_to_list()

      if (is.null(params$mean_learner)) {
        params$mean_learner <- make_learner(Lrnr_glm_fast)
      }

      super$initialize(params = params, ...)
    },
    sample = function(task, n_samples = 30, fold_number = "full") {
      # TODO: respect fold_number once cross-validated fits are available
      # method: inverse-CDF sampling from the fitted error density; the same
      # draws are reused across observations, then shifted/scaled per row
      mean_fit <- self$fit_object$mean_fit
      var_fit <- self$fit_object$var_fit
      dens_fit <- self$fit_object$dens_fit

      mean_preds <- mean_fit$predict(task)
      if (!is.null(var_fit)) {
        var_preds <- var_fit$predict(task)

        var_preds[var_preds < 0] <- self$fit_object$min_obs_error
        sd_preds <- sqrt(var_preds)
      } else {
        sd_preds <- rep(1, task$nrow)
      }

      # draw standardized errors from the fitted error density, then invert
      # the standardization: y = E(Y|W) + sd(W) * epsilon
      samples <- inverse_sample(n_samples, dens_fit)
      obs_samples <- sapply(
        seq_len(task$nrow),
        function(i) mean_preds[i] + sd_preds[i] * samples
      )

      return(t(obs_samples))
    },
    get_outcome_range = function(task = NULL, fold_number = "full") {
      # TODO: respect fold_number once cross-validated fits are available
      # NOTE: the range is the support of the fitted error density shifted by
      # the predicted mean; a fitted variance model is not applied here
      mean_fit <- self$fit_object$mean_fit
      dens_fit <- self$fit_object$dens_fit

      if (!is.null(task)) {
        # per-observation range: one (minimum, maximum) row per row of task
        mean_preds <- mean_fit$predict(task)
        minimum <- min(dens_fit$x) + mean_preds
        maximum <- max(dens_fit$x) + mean_preds
        range <- cbind(minimum, maximum)
      } else {
        # observation-independent range based on the training data
        mean_preds <- mean_fit$predict(self$training_task)
        minimum <- min(dens_fit$x) + min(mean_preds)
        maximum <- max(dens_fit$x) + max(mean_preds)
        range <- c(minimum, maximum)
      }
      return(range)
    }
  ),
  private = list(
    .properties = c("density", "sampling"),
    .train = function(task) {
      mean_learner <- self$params$mean_learner
      var_learner <- self$params$var_learner
      bw <- self$params$bw

      if (is.null(bw)) {
        bw <- "nrd0"
      }
      mean_fit <- mean_learner$train(task)

      # TODO: consider using cross-validated predictions here, so the error
      # density is not fit to in-sample (potentially overfit) residuals
      mean_preds <- mean_fit$predict()
      errors <- task$Y - mean_preds

      if (!is.null(var_learner)) {
        # regress the squared residuals on the covariates to estimate the
        # conditional variance, then standardize the residuals before the
        # kernel density fit
        new_columns <- task$add_columns(data.table(squared_error = errors^2))
        se_task <- task$next_in_chain(
          column_names = new_columns, outcome = "squared_error"
        )
        # twice the smallest observed squared error, used here (and in
        # predict/sample) to replace negative variance predictions
        min_obs_error <- 2 * min(se_task$Y)
        var_fit <- var_learner$train(se_task)
        var_preds <- var_fit$predict()
        var_preds[var_preds < 0] <- min_obs_error
        sd_preds <- sqrt(var_preds)
        errors <- errors / sd_preds
      } else {
        var_fit <- NULL
        min_obs_error <- NA
      }

      # kernel density estimate of the (standardized) residuals
      dens_fit <- density(errors, bw = bw)
      fit_object <- list(
        mean_fit = mean_fit,
        var_fit = var_fit,
        min_obs_error = min_obs_error,
        dens_fit = dens_fit
      )

      return(fit_object)
    },
    .predict = function(task) {
      mean_fit <- self$fit_object$mean_fit
      var_fit <- self$fit_object$var_fit
      dens_fit <- self$fit_object$dens_fit
      mean_preds <- mean_fit$predict(task)
      errors <- task$Y - mean_preds
      if (!is.null(var_fit)) {
        var_preds <- var_fit$predict(task)

        var_preds[var_preds < 0] <- self$fit_object$min_obs_error
        sd_preds <- sqrt(var_preds)
      } else {
        sd_preds <- rep(1, task$nrow)
      }

      errors <- errors / sd_preds

      # evaluate the fitted error density at the standardized residuals,
      # extrapolating flat beyond the estimated support (rule = 2)
      dens_preds <- approx(dens_fit$x, dens_fit$y, errors, rule = 2)$y

      # change of variables: f_Y(y | w) = f_eps((y - mu(w)) / sd(w)) / sd(w)
      dens_preds <- dens_preds / sd_preds

      # dens_preds[is.na(dens_preds)] <- 0
      return(dens_preds)
    }
  )
)
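
# A minimal usage sketch (kept as comments so this package source file is not
# executed on load): drawing conditional samples and querying the outcome
# range from a fitted learner. `task` is assumed to be an sl3_Task with a
# continuous outcome, as in the roxygen example above; all variable names are
# illustrative only.
#
# fit <- Lrnr_density_semiparametric$new(mean_learner = Lrnr_glm$new())$train(task)
# dens_preds <- fit$predict(task)               # conditional density at observed Y
# y_samples <- fit$sample(task, n_samples = 10) # task$nrow x 10 matrix of draws
# y_range <- fit$get_outcome_range()            # overall (min, max) over training data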