#' Define a Machine Learning Task
#'
#' An increasingly thick wrapper around a \code{\link[data.table]{data.table}}
#' containing the data for a prediction task. This contains metadata about the
#' particular machine learning problem, including which variables are to be
#' used as covariates and outcomes.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom assertthat assert_that
#' @importFrom origami make_folds
#' @importFrom uuid UUIDgenerate
#' @importFrom digest digest
#' @importFrom data.table as.data.table data.table setcolorder setDT setnames ":="
#'
#' @export
#'
#' @keywords data
#'
#' @return \code{sl3_Task} object
#'
#' @format \code{\link{R6Class}} object.
#'
#' @template sl3_Task_extra
sl3_Task <- R6Class(
  classname = "sl3_Task",
  portable = TRUE,
  class = TRUE,
  public = list(
    # Build a task. `data` is either a data.frame (processed by process_data()
    # and wrapped in a new Shared_Data object) or an existing Shared_Data
    # object (stored without reprocessing). Node roles come from `nodes` when
    # supplied, otherwise from the individual role arguments. `column_names`
    # maps node names to the column names actually stored in the data.
    initialize = function(data, covariates, outcome = NULL,
                          outcome_type = NULL, outcome_levels = NULL,
                          id = NULL, weights = NULL, offset = NULL,
                          time = NULL, nodes = NULL, column_names = NULL,
                          row_index = NULL, folds = NULL, flag = TRUE,
                          drop_missing_outcome = FALSE) {
      # generate node list from other arguments if not explicitly specified
      if (is.null(nodes)) {
        nodes <- list(
          covariates = covariates, outcome = outcome, id = id,
          weights = weights, offset = offset, time = time
        )
      }
      all_nodes <- unlist(nodes)

      # get column names from data (and check data class in the process);
      # copy() is namespace-qualified so it resolves even when data.table's
      # copy is not imported into the package namespace
      if (inherits(data, "data.frame")) {
        data_names <- data.table::copy(names(data))
      } else if (inherits(data, "Shared_Data")) {
        data_names <- data.table::copy(data$column_names)
      } else {
        stop(sprintf("Data of class %s not supported", class(data)[[1]]))
      }

      # default column map: every data column maps to itself
      if (is.null(column_names)) {
        column_names <- data_names
        names(column_names) <- column_names
      }

      # verify nodes are contained in column map
      missing_cols <- setdiff(all_nodes, names(column_names))
      assertthat::assert_that(
        length(missing_cols) == 0,
        msg = sprintf(
          "Couldn't find %s",
          paste(missing_cols, collapse = " ")
        )
      )
      # verify referenced columns are actually in data
      referenced_columns <- column_names[all_nodes]
      missing_cols <- setdiff(referenced_columns, data_names)
      assertthat::assert_that(
        length(missing_cols) == 0,
        msg = sprintf(
          "Data doesn't contain referenced columns %s",
          paste(missing_cols, collapse = " ")
        )
      )

      if (inherits(data, "Shared_Data")) {
        # already a Shared_Data object: store it without reprocessing, on the
        # assumption that processing was done when it was first built
        private$.shared_data <- data
      } else {
        # process characters and missings, then wrap in Shared_Data
        processed <- process_data(
          data,
          nodes,
          column_names = column_names,
          flag = flag, drop_missing_outcome = drop_missing_outcome
        )
        nodes <- processed$nodes
        column_names <- processed$column_names
        # process_data copies, so don't copy again here
        private$.shared_data <- Shared_Data$new(
          processed$data,
          force_copy = FALSE
        )
      }

      # store final nodes, column names, rows and folds; the row index is set
      # before the outcome type is resolved because self$Y reads it
      private$.nodes <- nodes
      private$.column_names <- column_names
      private$.row_index <- row_index
      private$.folds <- folds

      if (is.null(outcome_type)) {
        # guess the type from the outcome values (or "none" without outcome)
        if (is.null(nodes$outcome)) {
          outcome_type <- variable_type("none")
        } else {
          outcome_type <- variable_type(x = self$Y)
        }
      } else {
        unsupported_type <- function(type) {
          stop(sprintf(
            "The supplied outcome_type %s is not supported.",
            paste(c(type), collapse = ", ")
          ))
        }
        # glm-style family objects are accepted through their family name
        if (inherits(outcome_type, "family")) {
          outcome_type <- outcome_type$family
        }
        if (is.character(outcome_type)) {
          # check the length first so the comparisons below are scalar
          if (length(outcome_type) != 1) {
            unsupported_type(outcome_type)
          }
          # map accepted aliases onto the canonical type names
          aliases <- c(
            binary = "binomial", gaussian = "continuous",
            multinomial = "categorical"
          )
          if (outcome_type %in% names(aliases)) {
            outcome_type <- aliases[[outcome_type]]
          }
          allowed_types <- c(
            "binomial", "categorical", "continuous", "multivariate", "none",
            "quasibinomial"
          )
          if (!outcome_type %in% allowed_types) {
            unsupported_type(outcome_type)
          }
          outcome_type <- variable_type(
            type = outcome_type, levels = outcome_levels, x = self$Y
          )
        }
        if (!inherits(outcome_type, "Variable_Type")) {
          unsupported_type(outcome_type)
        }
      }
      private$.outcome_type <- outcome_type

      # identify the task by a digest of its data
      private$.uuid <- digest::digest(self$data)
      invisible(self)
    },

    # Add product (interaction) columns to the shared data and return a new
    # task whose covariates include them. `interactions` is a list of
    # character vectors of covariate names; names(interactions), when
    # present, are used to detect interactions that already exist.
    add_interactions = function(interactions, warn_on_existing = TRUE) {
      # elementwise product of all columns of a data.table (or list)
      prod_columns <- function(x) {
        Reduce(`*`, x)
      }
      # the expanded covariate table is recomputed on every access of
      # self$X, so compute it once here
      X <- self$X

      old_names <- self$column_names
      interaction_names <- names(interactions)
      if (is.null(interaction_names)) {
        interaction_names <- vapply(
          interactions, paste0, character(1),
          collapse = "_"
        )
      }
      is_new <- !(interaction_names %in% old_names)
      if (warn_on_existing && any(!is_new)) {
        warning(
          "The following interactions already exist: ",
          paste0(interaction_names[!is_new], collapse = ", ")
        )
      }
      # nothing to add: return an unchanged task in the chain
      if (!any(is_new)) {
        return(self$next_in_chain())
      }

      interaction_data <- lapply(interactions[is_new], function(int) {
        # which interaction terms are numeric columns of X (named by term)
        int_numeric <- vapply(
          int, function(i) is.numeric(X[[i]]), logical(1)
        )
        if (all(int_numeric)) {
          # purely numeric interaction: one product column
          d_int <- data.table::data.table(
            X[, prod_columns(.SD), .SDcols = int]
          )
          data.table::setnames(d_int, paste0(int, collapse = "_"))
          return(d_int)
        }
        # at least one factor term: match each term to the (expanded)
        # columns of X by pattern, then take the product of every
        # combination of matched columns
        # NOTE(review): grep() treats each term as a regex, so a term can
        # match unrelated columns sharing its prefix — confirm naming scheme
        Xmatch <- lapply(int, function(i) {
          grep(i, colnames(X), value = TRUE)
        })
        Xint <- as.list(data.table::as.data.table(t(expand.grid(Xmatch))))
        d_Xint <- lapply(Xint, function(cols) {
          X[, prod_columns(.SD), .SDcols = cols]
        })
        data.table::setDT(d_Xint)
        data.table::setnames(
          d_Xint,
          vapply(Xint, paste0, character(1), collapse = "_")
        )
        no_Xint <- rowSums(d_Xint) == 0 # happens when we omit 1 factor level
        if (any(no_Xint)) {
          # extra indicator column for rows not covered by any combination
          d_Xint$other <- rep(0, nrow(d_Xint))
          d_Xint[no_Xint, "other"] <- 1
          if (any(int_numeric)) {
            # need to take the product if we have a numeric covariate
            d_Xint[no_Xint, "other"] <- prod_columns(data.table::data.table(
              rep(1, sum(no_Xint)),
              X[no_Xint, names(which(int_numeric)), with = FALSE]
            ))
          }
          other_name <- paste0("other.", paste0(int, collapse = "_"))
          colnames(d_Xint)[ncol(d_Xint)] <- other_name
        }
        d_Xint
      })

      interaction_names <- unlist(lapply(interaction_data, colnames))
      interaction_data <- data.table::data.table(
        do.call(cbind, interaction_data)
      )
      data.table::setnames(interaction_data, interaction_names)
      interaction_cols <- self$add_columns(interaction_data,
        column_uuid = NULL
      )
      new_covariates <- c(self$nodes$covariates, interaction_names)
      self$next_in_chain(
        covariates = new_covariates,
        column_names = interaction_cols
      )
    },

    # Add `new_data` to the shared data (restricted to this task's rows when
    # a row index is set) and return an updated copy of the column map. The
    # task's own column map is not modified.
    add_columns = function(new_data, column_uuid = uuid::UUIDgenerate()) {
      row_index <- private$.row_index
      if (is.numeric(row_index)) {
        row_index <- as.integer(row_index)
      }
      new_col_map <- private$.shared_data$add_columns(
        new_data, column_uuid,
        row_index
      )
      column_names <- private$.column_names
      column_names[names(new_col_map)] <- new_col_map
      column_names
    },

    # Return a new task sharing this task's data but with modified nodes,
    # folds or column map. Extra arguments in `...` go to initialize().
    next_in_chain = function(covariates = NULL, outcome = NULL, id = NULL,
                             weights = NULL, offset = NULL, time = NULL,
                             folds = NULL, column_names = NULL,
                             new_nodes = NULL, ...) {
      if (is.null(new_nodes)) {
        new_nodes <- self$nodes
        # covariates and outcome are replaced only when non-NULL ...
        if (!is.null(covariates)) {
          new_nodes$covariates <- covariates
        }
        if (!is.null(outcome)) {
          new_nodes$outcome <- outcome
        }
        # ... whereas these are replaced whenever supplied, so an explicit
        # NULL removes the node
        if (!missing(id)) {
          new_nodes$id <- id
        }
        if (!missing(weights)) {
          new_nodes$weights <- weights
        }
        if (!missing(offset)) {
          new_nodes$offset <- offset
        }
        if (!missing(time)) {
          new_nodes$time <- time
        }
      }
      if (is.null(column_names)) {
        column_names <- private$.column_names
      }
      if (is.null(folds)) {
        folds <- private$.folds
      }

      # verify nodes are contained in the column map
      missing_cols <- setdiff(unlist(new_nodes), names(column_names))
      assertthat::assert_that(
        length(missing_cols) == 0,
        msg = sprintf(
          "Couldn't find %s",
          paste(missing_cols, collapse = " ")
        )
      )

      # transfer outcome properties only when the outcome is unchanged; the
      # length check stops a NULL outcome from matching a non-NULL one
      # (x == NULL is logical(0), and all(logical(0)) is TRUE)
      same_outcome <- length(new_nodes$outcome) ==
        length(self$nodes$outcome) &&
        all(new_nodes$outcome == self$nodes$outcome)
      new_outcome_type <- if (same_outcome) self$outcome_type else NULL

      new_task <- self$clone()
      new_task$initialize(
        private$.shared_data,
        nodes = new_nodes,
        folds = folds,
        column_names = column_names,
        row_index = private$.row_index,
        outcome_type = new_outcome_type,
        ...
      )
      new_task
    },

    # Return a task restricted to `row_index` (positions or a logical mask,
    # relative to this task's rows). Duplicated rows force a reindexed copy
    # of the shared data, which is only allowed when folds are dropped.
    subset_task = function(row_index, drop_folds = FALSE) {
      if (is.logical(row_index)) {
        row_index <- which(row_index)
      }
      # translate task-relative positions into shared-data positions
      if (!is.null(private$.row_index)) {
        row_index <- private$.row_index[row_index]
      }
      must_reindex <- any(duplicated(row_index))
      if (must_reindex && !drop_folds) {
        stop("subset indices have copies, this requires dropping folds.")
      }
      if (must_reindex) {
        new_shared_data <- private$.shared_data$clone()
        new_shared_data$reindex(row_index)
        row_index <- seq_along(row_index)
        new_folds <- NULL
      } else {
        new_shared_data <- private$.shared_data
        new_folds <- if (drop_folds) {
          NULL
        } else {
          subset_folds(private$.folds, row_index)
        }
      }
      new_task <- self$clone()
      new_task$initialize(
        new_shared_data,
        nodes = private$.nodes,
        folds = new_folds,
        column_names = private$.column_names,
        row_index = row_index,
        outcome_type = self$outcome_type
      )
      new_task
    },

    # Fetch `columns` (node-level names) for `rows`. An omitted `rows` means
    # this task's rows; an explicit value (including NULL) is passed to
    # Shared_Data$get_data unchanged. Columns come back under node names.
    get_data = function(rows = NULL, columns, expand_factors = FALSE) {
      if (missing(rows)) {
        rows <- private$.row_index
      }
      true_columns <- unlist(private$.column_names[columns])
      subset <- private$.shared_data$get_data(rows, true_columns)
      if (ncol(subset) > 0) {
        data.table::setnames(subset, true_columns, columns)
      }
      if (expand_factors) {
        subset <- dt_expand_factors(subset)
      }
      subset
    },

    # TRUE when the node `node_name` is specified for this task.
    has_node = function(node_name) {
      !is.null(private$.nodes[[node_name]])
    },

    # Return the data for node `node_name`: a vector for single-column nodes,
    # otherwise a data.table. For unspecified nodes the result of
    # generator_fun(node_name, nrow) is returned; the default errors.
    get_node = function(node_name, generator_fun = NULL,
                        expand_factors = FALSE) {
      # is.null() also covers an omitted argument (default NULL)
      if (is.null(generator_fun)) {
        generator_fun <- function(node_name, n) {
          stop(sprintf("Node %s not specified", node_name))
        }
      }
      node_var <- private$.nodes[[node_name]]
      if (is.null(node_var)) {
        return(generator_fun(node_name, self$nrow))
      }
      data_col <- self$get_data(, node_var, expand_factors)
      if (ncol(data_col) == 1) {
        unlist(data_col, use.names = FALSE)
      } else {
        data_col
      }
    },

    # Offset values, passed through `link_fun` when the sl3.transform.offset
    # option is TRUE. Without an offset node this returns NULL, except when
    # `for_prediction` is TRUE, where it is an error.
    offset_transformed = function(link_fun = NULL, for_prediction = FALSE) {
      if (!self$has_node("offset")) {
        if (for_prediction) {
          stop("Trained with offsets but predict method called on task without.")
        }
        return(NULL)
      }
      offset <- self$offset
      # isTRUE() treats an unset option as FALSE instead of erroring in &&
      if (isTRUE(getOption("sl3.transform.offset")) && !is.null(link_fun)) {
        offset <- link_fun(offset)
      }
      offset
    },

    # Print a one-line summary and the node list; returns the task invisibly.
    print = function() {
      cat(sprintf("An sl3 Task with %d obs and these nodes:\n", self$nrow))
      print(self$nodes)
      invisible(self)
    },

    # For a plain task the revere fold task is the task itself.
    revere_fold_task = function(fold_number) {
      self
    },

    # Build cross-validation folds with origami::make_folds(); `...` goes to
    # `fold_fun`. Clusters default to the id node, and strata to the outcome
    # for binomial/categorical outcomes, unless stratification would break
    # clusters (a cluster spanning several strata).
    get_folds = function(n = self$nrow, fold_fun = origami::folds_vfold,
                         cluster_ids = NULL, strata_ids = NULL, ...) {
      args <- list(...)
      args$n <- n
      args$fold_fun <- fold_fun
      args$cluster_ids <- cluster_ids
      args$strata_ids <- strata_ids
      # specify clustered CV (this is the default when id in task)
      if (self$has_node("id") && is.null(args$cluster_ids)) {
        args$cluster_ids <- self$id
      }
      # specify stratified CV (this is the default when outcome is discrete)
      if (self$outcome_type$type %in% c("binomial", "categorical") &&
        is.null(args$strata_ids)) {
        args$strata_ids <- self$Y
      }
      # do not consider stratified CV if it prevents clustered CV
      if (!is.null(args$strata_ids) && !is.null(args$cluster_ids)) {
        clusters_nested <- all(
          rowSums(table(args$cluster_ids, args$strata_ids) > 0) == 1
        )
        if (!clusters_nested) {
          args$strata_ids <- NULL
        }
      }
      do.call(origami::make_folds, args)
    }
  ),
  active = list(
    internal_data = function() {
      private$.shared_data
    },
    # all node columns of this task's rows
    data = function() {
      all_nodes <- unique(unlist(private$.nodes))
      self$get_data(, all_nodes)
    },
    nrow = function() {
      if (is.null(private$.row_index)) {
        private$.shared_data$nrow
      } else {
        length(private$.row_index)
      }
    },
    nodes = function() {
      private$.nodes
    },
    # covariates with factors expanded (recomputed on every access)
    X = function() {
      self$get_data(, private$.nodes$covariates, expand_factors = TRUE)
    },
    # X with a leading column of ones named `intercept`
    X_intercept = function() {
      X_dt <- self$X
      if (ncol(X_dt) == 0) {
        return(data.table::data.table(intercept = rep(1, self$nrow)))
      }
      old_ncol <- ncol(X_dt)
      # NOTE(review): := modifies X_dt by reference; assumes get_data()
      # returns a fresh table rather than a view of the shared data
      X_dt[, intercept := 1]
      # make intercept first column
      data.table::setcolorder(X_dt, c(old_ncol + 1, seq_len(old_ncol)))
      X_dt
    },
    Y = function() {
      self$get_node("outcome")
    },
    offset = function() {
      self$get_node("offset")
    },
    # defaults to unit weights when no weights node is specified
    weights = function() {
      self$get_node("weights", function(node_var, n) {
        rep(1, n)
      })
    },
    # defaults to one id per row when no id node is specified
    id = function() {
      self$get_node("id", function(node_var, n) {
        seq_len(n)
      })
    },
    # without a time node (and without ids) fall back to row positions:
    # the row index when one is set, otherwise 1..nrow
    time = function() {
      self$get_node("time", function(node_var, n) {
        if (self$has_node("id")) {
          stop("times requested but not specified for this task")
        }
        if (is.null(self$row_index)) seq_len(n) else self$row_index
      })
    },
    # folds are generated lazily: NULL means default folds and a number is
    # passed to get_folds() as `V`; assigning sets them directly
    folds = function(new_folds) {
      if (!missing(new_folds)) {
        private$.folds <- new_folds
      } else if (is.null(private$.folds) || is.numeric(private$.folds)) {
        if (is.numeric(private$.folds)) {
          new_folds <- self$get_folds(V = private$.folds)
        } else {
          new_folds <- self$get_folds()
        }
        private$.folds <- new_folds
      }
      private$.folds
    },
    uuid = function() {
      private$.uuid
    },
    column_names = function() {
      private$.column_names
    },
    outcome_type = function() {
      private$.outcome_type
    },
    row_index = function() {
      private$.row_index
    }
  ),
  private = list(
    .shared_data = NULL,
    .nodes = NULL,
    # not referenced by this class's methods
    .X = NULL,
    .folds = NULL,
    .uuid = NULL,
    .column_names = NULL,
    .row_index = NULL,
    .outcome_type = NULL
  )
)
#' @export
`[.sl3_Task` <- function(x, i = NULL, j = NULL, ...) {
  # Row subsetting only: `i` is forwarded to the task's subset_task()
  # method; `j` and `...` are accepted to fit the `[` generic but ignored.
  row_subset <- x$subset_task(i)
  row_subset
}
#' @param ... All arguments are passed unchanged to the \code{sl3_Task}
#'  constructor (\code{sl3_Task$new}); see the constructor documentation on
#'  this page for the accepted arguments.
#'
#' @rdname sl3_Task
#'
#' @export
make_sl3_Task <- sl3_Task$new
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.