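# declare variables used via data.table non-standard evaluation below, so
# that R CMD check does not flag them as undefined globals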
utils::globalVariables(c("in_bin", "bin_id"))
#' HAL Conditional Density Estimation in a Cross-validation Fold
#'
#' @details Estimates the conditional density of A|W for a subset of the full
#' set of observations, based on the structure of the cross-validation folds
#' provided. This is a helper function intended to be used to select the
#' optimal value of the penalization parameter for the highly adaptive lasso
#' estimates of the conditional hazard (via
#' \code{\link[origami]{cross_validate}}).
#'
#' @param fold Object specifying cross-validation folds as generated by a call
#' to \code{\link[origami]{make_folds}}.
#' @param long_data A \code{data.table} or \code{data.frame} object containing
#' the data in the long format described in
#' \insertRef{diaz2011super}{haldensify}, as produced by
#' \code{\link{format_long_hazards}}.
#' @param wts A \code{numeric} vector of observation-level weights, matching in
#' its length the number of records present in the long format data. Default
#' is to weight all observations equally.
#' @param lambda_seq A \code{numeric} sequence of values of the regularization
#' parameter of Lasso regression; passed to \code{\link[hal9001]{fit_hal}}.
#' @param smoothness_orders An \code{integer} indicating the smoothness of the
#' HAL basis functions; passed to \code{\link[hal9001]{fit_hal}}. The default
#' is set to zero, for indicator basis functions.
#' @param ... Additional (optional) arguments of \code{\link[hal9001]{fit_hal}}
#' that may be used to control fitting of the HAL regression model. Possible
#' choices include \code{use_min}, \code{reduce_basis}, \code{return_lasso},
#' and \code{return_x_basis}, but this list is not exhaustive. Consult the
#' documentation of \code{\link[hal9001]{fit_hal}} for complete details.
#'
#' @importFrom stats aggregate plogis
#' @importFrom origami training validation fold_index
#' @importFrom assertthat assert_that
#' @importFrom hal9001 fit_hal
#' @importFrom Rdpack reprompt
#'
#' @return A \code{list}, containing density predictions, observations IDs,
#' observation-level weights, and cross-validation indices for conditional
#' density estimation on a single fold of the overall data.
cv_haldensify <- function(fold,
long_data,
wts = rep(1, nrow(long_data)),
lambda_seq = exp(seq(-1, -13, length = 1000L)),
smoothness_orders = 0L,
...) {
# extract the training and validation sets for the current fold
train_set <- origami::training(long_data)
valid_set <- origami::validation(long_data)
# subset observation-level weights to the correct size
wts_train <- wts[fold$training_set]
wts_valid <- wts[fold$validation_set]
# fit a HAL regression on the training set
# NOTE: not selecting lambda by CV so no need to pass IDs for fold splitting
fit_hal_args <- list(...)
if (!any(grepl("fit_control", names(fit_hal_args)))) {
fit_hal_args$fit_control <- list(cv_select = FALSE, weights = wts_train)
} else {
fit_hal_args$fit_control$cv_select <- FALSE
fit_hal_args$fit_control$weights <- wts_train
}
fit_hal_args$X <- as.matrix(train_set[, -c(1, 2)])
fit_hal_args$Y <- as.numeric(train_set$in_bin)
fit_hal_args$family <- "binomial"
fit_hal_args$lambda <- lambda_seq
fit_hal_args$smoothness_orders <- smoothness_orders
hal_fit_train <- do.call(hal9001::fit_hal, fit_hal_args)
# get intercept and coefficient estimates across the full lambda sequence
# from the underlying glmnet fit
alpha_hat <- hal_fit_train$lasso_fit$a0
betas_hat <- hal_fit_train$lasso_fit$beta
coefs_hat <- rbind(alpha_hat, betas_hat)
# make design matrix for validation set manually
pred_x_basis <- hal9001::make_design_matrix(
as.matrix(valid_set[, -c(1, 2)]),
hal_fit_train$basis_list
)
pred_x_basis <- hal9001::apply_copy_map(
pred_x_basis,
hal_fit_train$copy_map
)
pred_x_basis <- cbind(rep(1, nrow(valid_set)), pred_x_basis)
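  # NOTE: the column of ones prepended here matches the intercept row
  #   prepended to the coefficient matrix above, so the matrix product below
  #   includes the intercept a0 for every value of lambda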
# manually predict along sequence of lambdas
preds_logit <- pred_x_basis %*% coefs_hat
preds <- stats::plogis(as.matrix(preds_logit))
# map per-bin hazard predictions to density predictions, one observation at
# a time
density_pred_each_obs <- lapply(unique(valid_set$obs_id), function(id) {
# get predictions for the current observation only
hazard_pred_this_obs <- matrix(preds[valid_set$obs_id == id, ],
ncol = length(lambda_seq)
)
# map hazard to density for a single observation and return
density_pred_this_obs <-
map_hazard_to_density(hazard_pred_single_obs = hazard_pred_this_obs)
# output estimated density for the given observation
return(density_pred_this_obs)
})
# aggregate predictions across observations
density_pred <- do.call(rbind, as.list(density_pred_each_obs))
# collapse weights to the observation level
wts_valid_reduced <- stats::aggregate(
wts_valid, list(valid_set$obs_id),
unique
)
colnames(wts_valid_reduced) <- c("id", "weight")
# construct output
out <- list(
preds = density_pred,
ids = wts_valid_reduced$id,
wts = wts_valid_reduced$weight,
fold = origami::fold_index()
)
return(out)
}
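# Illustrative sketch only: the hazard-to-density mapping assumed of the
# map_hazard_to_density() helper invoked above, which is defined elsewhere in
# the package. For a single observation whose value of A falls in bin k, the
# long-format records cover bins 1, ..., k, so the bin-membership probability
# is recovered from the per-bin hazards, columnwise across the sequence of
# lambdas, as P(A in bin k | W) = hazard_k * prod_{j < k} (1 - hazard_j).
.toy_map_hazard_to_density <- function(hazard_pred_single_obs) {
  n_bins_obs <- nrow(hazard_pred_single_obs)
  if (n_bins_obs == 1L) {
    # A falls in the first bin, so the density is just the first-bin hazard
    return(hazard_pred_single_obs)
  }
  # probability of "surviving" (not failing in) each bin preceding bin k
  surv_probs <- 1 - hazard_pred_single_obs[-n_bins_obs, , drop = FALSE]
  # hazard in bin k times the probability of reaching bin k, for each lambda
  matrix(
    hazard_pred_single_obs[n_bins_obs, ] * apply(surv_probs, 2, prod),
    nrow = 1L
  )
}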
###############################################################################
#' Cross-validated HAL Conditional Density Estimation
#'
#' @details Estimation of the conditional density of A|W, using the highly
#' adaptive lasso to estimate the conditional hazard of failure in a given
#' bin over the support of A. Cross-validation is used to select the optimal
#' value of the penalization parameters, based on minimization of the weighted
#' log-likelihood loss for a density.
#'
#' @param A The \code{numeric} vector of observed values.
#' @param W A \code{data.frame}, \code{matrix}, or similar giving the values of
#' baseline covariates (potential confounders) for the observed units. These
#' make up the conditioning set for the density estimate. For estimation of a
#' marginal density, specify a constant \code{numeric} vector or \code{NULL}.
#' @param wts A \code{numeric} vector of observation-level weights. The default
#' is to weight all observations equally.
#' @param grid_type A \code{character} indicating the strategy to be used in
#' creating bins along the observed support of \code{A}. For bins of equal
#' range, use \code{"equal_range"}; consult the documentation of
#' \code{\link[ggplot2]{cut_interval}} for more information. To ensure each
#' bin has the same number of observations, use \code{"equal_mass"}; consult
#' the documentation of \code{\link[ggplot2]{cut_number}} for details. The
#' default is \code{"equal_range"} since this has been found to provide better
#' performance in simulation experiments; however, both types may be specified
#' (i.e., \code{c("equal_range", "equal_mass")}) together, in which case
#' cross-validation will be used to select the optimal binning strategy.
#' @param n_bins This \code{numeric} value indicates the number(s) of bins into
#' which the support of \code{A} is to be divided. As with \code{grid_type},
#' multiple values may be specified, in which case cross-validation will be
#' used to choose the optimal number of bins. The default sets the candidate
#' choices of the number of bins based on heuristics tested in simulation.
#' @param cv_folds A \code{numeric} indicating the number of cross-validation
#' folds to be used in fitting the sequence of HAL conditional density models.
#' @param lambda_seq A \code{numeric} sequence of values of the regularization
#' parameter of Lasso regression; passed to \code{\link[hal9001]{fit_hal}} via
#' its argument \code{lambda}.
#' @param smoothness_orders An \code{integer} indicating the smoothness of the
#' HAL basis functions; passed to \code{\link[hal9001]{fit_hal}}. The default
#' is set to zero, for indicator basis functions.
#' @param hal_basis_list A \code{list} consisting of a preconstructed set of
#' HAL basis functions, as produced by \code{\link[hal9001]{fit_hal}}. The
#' default of \code{NULL} results in creating such a set of basis functions.
#' When specified, this is passed directly to the HAL model fitted on the
#' augmented (repeated measures) data structure, greatly reducing the
#' computational cost. This is useful, for example, in fitting HAL conditional
#' density estimates with external cross-validation or bootstrap samples.
#' @param ... Additional (optional) arguments of \code{\link[hal9001]{fit_hal}}
#' that may be used to control fitting of the HAL regression model. Possible
#' choices include \code{use_min}, \code{reduce_basis}, \code{return_lasso},
#' and \code{return_x_basis}, but this list is not exhaustive. Consult the
#' documentation of \code{\link[hal9001]{fit_hal}} for complete details.
#'
#' @note Parallel evaluation of the cross-validation procedure to select tuning
#' parameters for density estimation may be invoked via the framework exposed
#' by the \pkg{future} ecosystem. Specifically, set a \code{\link[future]{plan}}
#' so that the internal calls to \code{\link[future.apply]{future_mapply}} are
#' evaluated in parallel.
#'
#' @importFrom data.table ":="
#' @importFrom future.apply future_mapply
#' @importFrom hal9001 fit_hal
#'
#' @return Object of class \code{haldensify}, containing a fitted
#' \code{hal9001} object; a vector of break points used in binning \code{A}
#' over its support; the sizes of the bins used in each fit; the tuning
#' parameters selected by cross-validation; the full sequence (in lambda) of
#' HAL models for the CV-selected number of bins and binning strategy; and
#' the range of \code{A}.
#'
#' @export
#'
#' @examples
#' # simulate data: W ~ U[-4, 4] and A|W ~ N(mu = W, sd = 0.5)
#' set.seed(429153)
#' n_train <- 50
#' w <- runif(n_train, -4, 4)
#' a <- rnorm(n_train, w, 0.5)
#' # learn relationship A|W using HAL-based density estimation procedure
#' haldensify_fit <- haldensify(
#' A = a, W = w, n_bins = 10L, lambda_seq = exp(seq(-1, -10, length = 100)),
#' # the following arguments are passed to hal9001::fit_hal()
#' max_degree = 3, reduce_basis = 1 / sqrt(length(a))
#' )
haldensify <- function(A, W,
wts = rep(1, length(A)),
grid_type = "equal_range",
n_bins = round(c(0.5, 1, 1.5, 2) * sqrt(length(A))),
cv_folds = 5L,
lambda_seq = exp(seq(-1, -13, length = 1000L)),
smoothness_orders = 0L,
hal_basis_list = NULL,
...) {
# capture dot arguments to hal9001::fit_hal()
fit_hal_args <- list(...)
# if W is set to NULL, create a constant conditioning set
# NOTE: this essentially recovers the marginal density of A
if (is.null(W)) {
W <- rep(1, length(A))
}
# run CV-HAL for all combinations of n_bins and grid_type
tune_grid <- expand.grid(
grid_type = grid_type, n_bins = n_bins,
stringsAsFactors = FALSE
)
# run procedure to select tuning parameters via cross-validation
# NOTE: even when the number of bins and discretization technique are fixed,
# this step is still required to produce a CV-selected choice of lambda
select_out <- future.apply::future_mapply(
FUN = fit_haldensify,
grid_type = tune_grid$grid_type,
n_bins = tune_grid$n_bins,
MoreArgs = list(
A = A, W = W, wts = wts,
cv_folds = cv_folds,
lambda_seq = lambda_seq,
smoothness_orders = smoothness_orders,
...
),
SIMPLIFY = FALSE,
future.seed = TRUE
)
# extract n_bins/grid_type index that is empirical loss minimizer
emp_risk_per_lambda <- lapply(select_out, `[[`, "emp_risks")
min_loss_idx <- lapply(emp_risk_per_lambda, which.min)
  min_risk <- unlist(lapply(emp_risk_per_lambda, min))
cv_selected_params <- tune_grid[which.min(min_risk), , drop = FALSE]
cv_selected_fits <- select_out[[which.min(min_risk)]]
cv_selected_fits$density_pred <- NULL
# re-format input data into long hazards structure
reformatted_output <- format_long_hazards(
A = A, W = W, wts = wts,
grid_type = cv_selected_params$grid_type,
n_bins = cv_selected_params$n_bins
)
long_data <- reformatted_output$data
breakpoints <- reformatted_output$breaks
bin_sizes <- reformatted_output$bin_length
# fit a HAL regression on the full data set across a sequence in lambda
# NOTE: no sample-splitting since there's no need to select among any of the
# tuning parameters -- advantage: simplifies working with re-sampled
# data (bootstrap); disadvantage: non-sample-split nuisance estimates
if (!any(grepl("fit_control", names(fit_hal_args)))) {
fit_hal_args$fit_control <- list(
cv_select = FALSE, weights = as.numeric(long_data$wts), n_folds = 1
)
} else {
fit_hal_args$fit_control$cv_select <- FALSE
fit_hal_args$fit_control$weights <- as.numeric(long_data$wts)
fit_hal_args$fit_control$n_folds <- 1L
}
fit_hal_args$X <- as.matrix(long_data[, -c("obs_id", "in_bin", "wts")])
fit_hal_args$Y <- as.numeric(long_data$in_bin)
fit_hal_args$basis_list <- hal_basis_list
fit_hal_args$family <- "binomial"
fit_hal_args$lambda <- lambda_seq
fit_hal_args$smoothness_orders <- smoothness_orders
hal_fit <- do.call(hal9001::fit_hal, fit_hal_args)
# construct output
out <- list(
hal_fit = hal_fit,
breaks = breakpoints,
bin_sizes = bin_sizes,
range_a = range(A),
grid_type_cvselect = cv_selected_params$grid_type,
n_bins_cvselect = cv_selected_params$n_bins,
cv_tuning_results = cv_selected_fits
)
class(out) <- "haldensify"
return(out)
}
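# A minimal usage sketch, not a definitive interface: this assumes the
# package's predict() method for "haldensify" objects (defined elsewhere),
# which takes new observations through arguments new_A and new_W. Wrapped in
# a function so that sourcing this file does not execute it.
.demo_haldensify_predict <- function(haldensify_fit) {
  set.seed(76924)
  w_new <- runif(10, -4, 4)
  a_new <- rnorm(10, w_new, 0.5)
  # conditional density estimates q(a_new | w_new) from the CV-selected fit
  predict(haldensify_fit, new_A = a_new, new_W = w_new)
}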
###############################################################################
#' Fit Conditional Density Estimation for a Sequence of HAL Models
#'
#' @details Estimation of the conditional density of A|W via a cross-validated
#' highly adaptive lasso, used to estimate the conditional hazard of failure
#' in a given bin over the support of A.
#'
#' @param A The \code{numeric} vector of observed values.
#' @param W A \code{data.frame}, \code{matrix}, or similar giving the values of
#' baseline covariates (potential confounders) for the observed units. These
#' make up the conditioning set for the conditional density estimate.
#' @param wts A \code{numeric} vector of observation-level weights. The default
#' is to weight all observations equally.
#' @param grid_type A \code{character} indicating the strategy to be used in
#' creating bins along the observed support of \code{A}. For bins of equal
#' range, use \code{"equal_range"}; consult the documentation of
#' \code{\link[ggplot2]{cut_interval}} for more information. To ensure each
#' bin has the same number of observations, use \code{"equal_mass"}; consult
#' the documentation of \code{\link[ggplot2]{cut_number}} for details.
#' @param n_bins This \code{numeric} value indicates the number(s) of bins into
#' which the support of \code{A} is to be divided. As with \code{grid_type},
#' multiple values may be specified, in which case cross-validation will be
#' used to choose the optimal number of bins. The default sets the candidate
#' choices of the number of bins based on heuristics tested in simulation.
#' @param cv_folds A \code{numeric} indicating the number of cross-validation
#' folds to be used in fitting the sequence of HAL conditional density models.
#' @param lambda_seq A \code{numeric} sequence of values of the regularization
#' parameter of Lasso regression; passed to \code{\link[hal9001]{fit_hal}}.
#' @param smoothness_orders An \code{integer} indicating the smoothness of the
#' HAL basis functions; passed to \code{\link[hal9001]{fit_hal}}. The default
#' is set to zero, for indicator basis functions.
#' @param ... Additional (optional) arguments of \code{\link[hal9001]{fit_hal}}
#' that may be used to control fitting of the HAL regression model. Possible
#' choices include \code{use_min}, \code{reduce_basis}, \code{return_lasso},
#' and \code{return_x_basis}, but this list is not exhaustive. Consult the
#' documentation of \code{\link[hal9001]{fit_hal}} for complete details.
#'
#' @importFrom data.table ":="
#' @importFrom matrixStats colMeans2
#' @importFrom origami make_folds cross_validate
#'
#' @return A \code{list}, containing density predictions for the sequence of
#' fitted HAL models; the index and value of the L1 regularization parameter
#' minimizing the density loss; and the sequence of empirical risks for the
#' sequence of fitted HAL models.
#'
#' @export
#'
#' @examples
#' # simulate data: W ~ U[-4, 4] and A|W ~ N(mu = W, sd = 0.5)
#' set.seed(429153)
#' n_train <- 50
#' w <- runif(n_train, -4, 4)
#' a <- rnorm(n_train, w, 0.5)
#' # fit cross-validated HAL-based density estimator of A|W
#' haldensify_cvfit <- fit_haldensify(
#' A = a, W = w, n_bins = 10L, lambda_seq = exp(seq(-1, -10, length = 100)),
#' # the following arguments are passed to hal9001::fit_hal()
#' max_degree = 3, reduce_basis = 1 / sqrt(length(a))
#' )
fit_haldensify <- function(A, W,
wts = rep(1, length(A)),
grid_type = "equal_range",
n_bins = round(c(0.5, 1, 1.5, 2) * sqrt(length(A))),
cv_folds = 5L,
lambda_seq = exp(seq(-1, -13, length = 1000L)),
smoothness_orders = 0L,
...) {
# capture dot arguments for reference
dot_args <- list(...)
# re-format input data into long hazards structure
reformatted_output <- format_long_hazards(
A = A, W = W, wts = wts,
grid_type = grid_type, n_bins = n_bins
)
long_data <- reformatted_output$data
bin_sizes <- reformatted_output$bin_length
# extract weights from long format data structure
wts_long <- long_data$wts
long_data[, wts := NULL]
# make folds with origami
folds <- origami::make_folds(long_data,
V = cv_folds,
cluster_ids = long_data$obs_id
)
# call cross_validate on cv_density function
haldensity <- origami::cross_validate(
cv_fun = cv_haldensify,
folds = folds,
long_data = long_data,
wts = wts_long,
lambda_seq = lambda_seq,
smoothness_orders = smoothness_orders,
...,
use_future = FALSE,
.combine = FALSE
)
# re-organize the output of the cross-validation procedure
density_pred_unscaled <- do.call(rbind, as.list(haldensity$preds))
# re-scale bin-membership probabilities to the density scale by dividing by
# the width of the bin into which each observation's value of A falls
density_pred_scaled <- apply(density_pred_unscaled, 2, function(x) {
pred <- x / bin_sizes[long_data[in_bin == 1, bin_id]]
return(pred)
})
obs_wts <- do.call(c, as.list(haldensity$wts))
  # compute the weighted log-likelihood loss for each observation, per lambda
  # NOTE: the weights multiply the loss rather than the prediction, so that
  #   upweighted observations contribute proportionally more to the risk
  density_loss <- apply(density_pred_scaled, 2, function(x) {
    loss_weighted <- -log(x) * obs_wts
    return(loss_weighted)
  })
# take column means to get the average loss across the sequence of lambdas
emp_risks_density_loss <- matrixStats::colMeans2(density_loss)
# find minimizer of loss in lambda sequence
lambda_loss_min_idx <- which.min(emp_risks_density_loss)
lambda_loss_min <- lambda_seq[lambda_loss_min_idx]
# return the loss minimizer in lambda, the empirical risks, and all density
# estimates
out <- list(
lambda_loss_min_idx = lambda_loss_min_idx,
lambda_loss_min = lambda_loss_min,
emp_risks = emp_risks_density_loss,
density_pred = density_pred_scaled,
lambda_seq = lambda_seq
)
return(out)
}
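# Illustrative sketch only: a toy, single-observation version of the long
# "hazards" data structure produced by format_long_hazards(), which is
# defined elsewhere in the package. For an observation whose value of A falls
# in bin k of the discretized support, the long format contains one record
# per bin 1, ..., k, with the binary failure indicator in_bin = 1 only in
# record k. The helper and its arguments are hypothetical; only the column
# names mirror those referenced above.
.toy_long_hazards_one_obs <- function(a, w, breaks) {
  # index of the bin into which the observed value of A falls
  bin_id_obs <- findInterval(a, breaks, all.inside = TRUE)
  data.frame(
    obs_id = 1L,
    bin_id = seq_len(bin_id_obs),
    w = w,
    in_bin = as.integer(seq_len(bin_id_obs) == bin_id_obs)
  )
}
# e.g., .toy_long_hazards_one_obs(a = 0.7, w = 0.2, breaks = seq(0, 1, 0.25))
# yields records for bins 1-3, with in_bin = c(0, 0, 1)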