Nothing
#' @title R6 class for prediction models
#' @description Interface for statistical and machine learning models to be used
#' for nuisance model estimation in targeted learning.
#'
#' The following list provides an overview of constructors for many commonly
#' used models.
#'
#' Regression and classification: [learner_glm], [learner_gam], [learner_grf],
#' [learner_hal], [learner_glmnet_cv], [learner_svm], [learner_xgboost],
#' [learner_mars] \cr
#' Regression: [learner_isoreg] \cr
#' Classification: [learner_naivebayes] \cr
#' Ensemble (super learner): [learner_sl]
#' @param data data.frame
#' @author Klaus Kähler Holst, Benedikt Sommer
#' @examples
#' data(iris)
#' rf <- function(formula, ...) {
#' learner$new(formula,
#' info = "grf::probability_forest",
#' estimate = function(x, y, ...) {
#' grf::probability_forest(X = x, Y = y, ...)
#' },
#' predict = function(object, newdata) {
#' predict(object, newdata)$predictions
#' },
#' estimate.args = list(...)
#' )
#' }
#'
#' args <- expand.list(
#' num.trees = c(100, 200), mtry = 1:3,
#' formula = c(Species ~ ., Species ~ Sepal.Length + Sepal.Width)
#' )
#' models <- lapply(args, function(par) do.call(rf, par))
#'
#' x <- models[[1]]$clone()
#' x$estimate(iris)
#' predict(x, newdata = head(iris))
#'
#' \donttest{
#' # Reduce Ex. timing
#' a <- targeted::cv(models, data = iris)
#' cbind(coef(a), attr(args, "table"))
#' }
#'
#' # defining learner via function with arguments y (response)
#' # and x (design matrix)
#' f1 <- learner$new(
#' estimate = function(y, x) lm.fit(x = x, y = y),
#' predict = function(object, newdata) newdata %*% object$coefficients
#' )
#' # defining the learner via arguments formula and data
#' f2 <- learner$new(
#' estimate = function(formula, data, ...) glm(formula, data, ...)
#' )
#' # generic learner defined from function (predict method derived per default
#' # from stats::predict
#' f3 <- learner$new(
#' estimate = function(dt, ...) {
#' lm(y ~ x, data = dt)
#' }
#' )
#' @export
learner <- R6::R6Class("learner", # nolint
public = list(
#' @field info Optional information/name of the model
info = NULL,
#' @description
#' Create a new prediction model object
#' @param formula formula specifying outcome and design matrix
#' @param estimate function for fitting the model. This must be a function
#' with response, 'y', and design matrix, 'x'. Alternatively, a function
#' with a formula and data argument. See the examples section.
#' @param predict prediction function (must be a function of model
#' object, 'object', and new design matrix, 'newdata')
#' @param info optional description of the model
#' @param predict.args optional arguments to prediction function
#' @param estimate.args optional arguments to estimate function
#' @param specials optional specials terms (weights, offset,
#' id, subset, ...) passed on to [targeted::design]
#' @param formula.keep.specials if TRUE then special terms defined by
#' `specials` will be removed from the formula before it is being passed to
#' the estimate print.function()
#' @param intercept (logical) include intercept in design matrix
initialize = function(formula = NULL,
estimate,
predict = stats::predict,
predict.args = NULL,
estimate.args = NULL,
info = NULL,
specials = c(),
formula.keep.specials = FALSE,
intercept = FALSE
) {
estimate <- add_dots(estimate)
private$des.args <- list(specials = specials, intercept = intercept)
fit_formula <- "formula" %in% formalArgs(estimate)
fit_data_arg <- "data" %in% formalArgs(estimate)
private$init.estimate <- estimate
private$init.predict <- predict
private$estimate.args <- estimate.args
no_formula <- is.null(formula)
if (!no_formula && is.character(formula) || is.function(formula)) {
no_formula <- TRUE
}
if (no_formula) {
private$fitfun <- function(...) {
args <- private$update_args(private$estimate.args, ...)
return(do.call(private$init.estimate, args))
}
private$predfun <- function(...) {
args <- private$update_args(predict.args, ...)
return(do.call(private$init.predict, args))
}
} else {
if (fit_formula) { # Formula in arguments of estimation procedure
private$fitfun <- function(data, ...) {
des <- do.call(
targeted::design,
c(list(formula = self$formula,
data = data,
design.matrix = FALSE),
private$des.args
)
)
args <- private$update_args(private$estimate.args, ...) #
form <- self$formula
if (!private$formula.keep.specials) form <- des$formula
args <- c(
args, list(formula = form, data = data)
)
if (length(des$specials) > 0) {
args <- c(args, des[des$specials])
}
return(structure(do.call(private$init.estimate, args),
design = summary(des)
))
}
} else {
# Formula automatically processed into design matrix & response
private$fitfun <- function(data, ...) {
xx <- do.call(
targeted::design,
c(list(formula = self$formula, data = data), private$des.args)
)
args <- private$update_args(private$estimate.args, ...)
args <- c(list(x = xx$x, y = xx$y), args)
if (length(xx$specials) > 0) {
args <- c(args, xx[xx$specials])
}
return(structure(do.call(private$init.estimate, args),
design = summary(xx)
))
}
}
private$predfun <- function(object, data, ...) {
if (no_formula) {
predict_args_call <- private$update_args(predict.args, ...)
args <- c(list(object, newdata = data), predict_args_call)
} else {
args <- list(...)
des <- update(attr(object, "design"), data)
for (s in des$specials) {
if (is.null(args[[s]])) args[[s]] <- des[[s]]
}
predict_args_call <- predict.args
predict_args_call[names(args)] <- args
newdata <- data
if (!fit_formula) {
newdata <- model.matrix(des)
}
args <- c(list(object,
newdata = newdata
), predict_args_call)
}
return(do.call(private$init.predict, args))
}
}
private$.formula <- formula
private$formula.keep.specials <- formula.keep.specials
self$info <- info
private$init <- list(
estimate.args = estimate.args,
predict.args = predict.args,
estimate = estimate,
predict = predict,
specials = specials,
intercept = intercept
)
},
#' @description
#' Estimation method
#' @param ... Additional arguments to estimation method
#' @param store Logical determining if estimated model should be
#' stored inside the class.
estimate = function(data, ..., store = TRUE) {
res <- private$fitfun(data, ...)
if (store) private$fitted <- res
return(invisible(res))
},
#' @description
#' Prediction method
#' @param newdata data.frame
#' @param ... Additional arguments to prediction method
#' @param object Optional model fit object
predict = function(newdata, ..., object = NULL) {
if (is.null(object)) object <- private$fitted
if (is.null(object)) stop("Provide estimated model object")
return(private$predfun(object, newdata, ...))
},
#' @description
#' Update formula
#' @param formula formula or character which defines the new response
update = function(formula) {
if (is.character(formula)) {
if (grepl("~", formula)) {
formula <- as.formula(formula)
} else {
formula <- reformulate(as.character(private$.formula)[3], formula)
}
}
private$.formula <- formula
environment(private$fitfun)$formula <- formula
environment(private$fitfun)$self <- self
return(invisible(formula))
},
#' @description
#' Print method
print = function() {
learner_print(self, private)
return(invisible())
},
#' @description
#' Summary method to provide more extensive information than
#' [learner$print()][learner].
#' @return summarized_learner object, which is a list with the following
#' elements:
#' \describe{
#' \item{info}{description of the learner}
#' \item{formula}{formula specifying outcome and design matrix}
#' \item{estimate}{function for fitting the model}
#' \item{estimate.args}{arguments to estimate function}
#' \item{predict}{function for making predictions from fitted model}
#' \item{predict.args}{arguments to predict function}
#' \item{specials}{provided special terms}
#' \item{intercept}{include intercept in design matrix}
#' }
#' @examples
#' lr <- learner_glm(y ~ x, family = "nb")
#' lr$summary()
#'
#' lr_sum <- lr$summary() # store returned summary in new object
#' names(lr_sum)
#' print(lr_sum)
summary = function() {
obj <- structure(
c(list(formula = self$formula, info = self$info), private$init),
class = "summarized_learner"
)
return(obj)
},
#' @description
#' Extract response from data
#' @param eval when FALSE return the untransformed outcome
#' (i.e., return 'a' if formula defined as I(a==1) ~ ...)
#' @param ... additional arguments to [targeted::design]
response = function(data, eval = TRUE, ...) {
if (eval) {
return(self$design(data = data, ..., design.matrix = FALSE)$y)
}
if (is.null(self$formula)) return(NULL)
newf <- update(self$formula, ~1)
return(data[, all.vars(newf), drop = TRUE])
},
#' @description
#' Generate [targeted::design] object (design matrix and response) from data
#' @param ... additional arguments to [targeted::design]
design = function(data, ...) {
args <- c(private$des.args, list(data = data))
args[...names()] <- list(...)
return(do.call(design, c(list(self$formula), args)))
},
#' @description
#' Get options
#' @param arg name of option to get value of
opt = function(arg) {
return(private$estimate.args[[arg]])
}
),
active = list(
#' @field clear Remove fitted model from the learner object
clear = function() invisible(private$fitted <- NULL),
#' @field fit Return estimated model object.
fit = function(value) {
if (missing(value)) return(private$fitted)
else private$fitted <- NULL
},
#' @field formula Return model formula. Use [learner$update()][learner] to
#' update the formula.
formula = function() {
private$.formula
}
),
private = list(
# @field des.args Arguments for targeted::design
des.args = NULL,
# @field estimate.args Arguments for estimate method
estimate.args = NULL,
# @field init.estimate Original estimate method supplied at initialization
init.estimate = NULL,
# @field init.predict Original predict method supplied at initialization
init.predict = NULL,
# @field predfun Prediction method
predfun = NULL,
# @field fitfun Estimation method
fitfun = NULL,
# @field fitted Fitted model object
fitted = NULL,
# @field .formula Model formula object // uses dot as a pre-fix to allow
# using formula as an active binding
.formula = NULL,
# @field formula.keep.specials if TRUE then special terms defined by
# `specials` will be removed from the formula before it is being passed to
# the estimate print.function()
formula.keep.specials = NULL,
# @field init Information on the initialized model
init = NULL,
# When x$clone(deep=TRUE) is called, the deep_clone gets invoked once for
# each field, with the name and value.
deep_clone = function(name, value) {
if (name == "fitfun") {
env <- list2env(
as.list.environment(environment(value),
all.names = TRUE
),
parent = globalenv()
)
environment(value) <- env
return(value)
} else {
# For everything else, just return it. This results in a shallow
# copy of s3.
return(value)
}
},
# Utility to update list of arguments with ellipsis
# @param args list or NULL
update_args = function(args, ...) {
if (is.null(args)) args <- list() # because predict.args = NULL by default
dots <- list(...)
# update args for unnamed list of arguments
if (length(dots) > 0 && is.null(names(dots))) {
args <- c(args, dots)
} else {
args[names(dots)] <- dots
}
return(args)
}
)
)
#' @export
estimate.learner <- function(x, ...) {
return(x$estimate(...))
}
#' @export
predict.learner <- function(object, ...) {
return(object$predict(...))
}
format_fit_predict_args <- function(args) {
if (length(args) == 0) return(" ")
funs <- c(is.numeric, is.character, is.integer, is.logical, is.null)
# print family attribute of family objects instead of printing only that the
# argument is of class <family>
mask <- unlist(lapply(args, \(x) inherits(x, "family")))
vals <- lapply(args[mask], \(x) x$family)
args[mask] <- vals
mask <- unlist(lapply(args, \(x) any(sapply(funs, \(f) f(x)))))
args_class <- paste0("<", lapply(args, \(x) class(x)[[1]]), ">")
args[!mask] <- args_class[!mask]
return(paste0(names(args), "=", args, collapse =", "))
}
learner_print <- function(self, private) {
cat_ruler(" learner object ", 10)
if (!is.null(self$info)) {
cat(self$info, "\n\n")
}
cat(
"Estimate arguments:",
format_fit_predict_args(private$init$estimate.args),
"\nPredict arguments:",
format_fit_predict_args(private$init$predict.args),
"\nFormula:",
capture.output(print(self$formula)),
"\n"
)
if (!is.null(private$fitted)) {
cat_ruler("\u2500", 18)
fit <- self$fit
attr(fit, "design") <- NULL
if (!is.atomic(fit) && !is.null(fit$call)) fit$call <- substitute()
if (!is.atomic(fit) && !is.null(attributes(fit)$call)) {
attributes(fit$call) <- substitute()
}
cat(capture.output(print(fit)), sep ="\n")
}
return(invisible())
}
#' @export
print.summarized_learner <- function(x, ...) {
cat_ruler(" learner object ", 10)
if (!is.null(x$info)) {
cat(x$info, "\n\n")
}
cat(
"formula:",
capture.output(print(x$formula)),
"\nestimate:", paste0(names(formals(x$estimate)), collapse = ", "),
"\nestimate.args:",
format_fit_predict_args(x$estimate.args),
"\npredict:", paste0(names(formals(x$predict)), collapse = ", "),
"\npredict.args:",
format_fit_predict_args(x$predict.args),
"\nspecials:", paste(x$specials, collapse = ", "),
"\n"
)
}
#' @title R6 class for prediction models
#' @description Replaced by [learner]
#' @export
ml_model <- R6Class("ml_model",
inherit = learner,
public = list(
#' @description Create a new prediction model object
#' @param ... deprecated
initialize = function(...) {
rlang::warn(paste0(
"targeted::ml_model is deprecated and will ",
"be removed in targeted v0.7.0. Use targeted::learner instead.")
)
super$initialize(...)
}
)
)
#' @export
estimate.ml_model <- function(x, ...) {
rlang::warn(paste0(
"targeted::ml_model is deprecated and will ",
"be removed in targeted v0.7.0. Use targeted::learner instead.")
)
return(x$estimate(...))
}
#' @export
predict.ml_model <- function(object, ...) {
rlang::warn(paste0(
"targeted::ml_model is deprecated and will ",
"be removed in targeted v0.7.0. Use targeted::learner instead.")
)
return(object$predict(...))
}
#' ML model
#'
#' Wrapper for ml_model
#' @export
#' @param formula formula
#' @param model model (sl, rf, pf, glm, ...)
#' @param ... additional arguments to model object
ML <- function(formula, model="glm", ...) {
stop(
"targeted::ML has been removed in targeted 0.6. ",
"Please use the targeted::learner_ functions instead."
)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.