#' HAL Conditional Density Estimation in a Cross-validation Fold
#' @details Estimates the conditional density of A|W for a subset of the full
#'  set of observations based on the inputted structure of the cross-validation
#'  folds. This is a helper function intended to be used to select the optimal
#'  value of the penalization parameter for the highly adaptive lasso estimates
#'  of the conditional hazard (via \code{\link[origami]{cross_validate}}). The
#'  HAL model of the hazard is fit to the training records of the fold along
#'  the full sequence \code{lambda_seq}; its hazard predictions on the
#'  validation records are then mapped to density predictions for each
#'  validation observation, one per value of the regularization parameter.
#' @param fold Object specifying cross-validation folds as generated by a call
#'  to \code{\link[origami]{make_folds}}.
#' @param long_data A \code{data.table} or \code{data.frame} object containing
#'  the data in long format, as given in \insertRef{diaz2011super}{haldensify},
#'  as produced by \code{\link{format_long_hazards}}. Its first two columns are
#'  excluded from the HAL design matrix, and it must contain the columns
#'  \code{obs_id} and \code{in_bin}.
#' @param wts A \code{numeric} vector of observation-level weights, matching in
#'  its length the number of records present in the long format data. Default
#'  is to weight all observations equally.
#' @param lambda_seq A \code{numeric} sequence of values of the regularization
#'  parameter of Lasso regression; passed to \code{\link[hal9001]{fit_hal}}.
#' @param smoothness_orders An \code{integer} indicating the smoothness of the
#'  HAL basis functions; passed to \code{\link[hal9001]{fit_hal}}. The default
#'  is set to zero, for indicator basis functions.
#' @param ... Additional (optional) arguments of \code{\link[hal9001]{fit_hal}}
#'  that may be used to control fitting of the HAL regression model. Possible
#'  choices include \code{use_min}, \code{reduce_basis}, \code{return_lasso},
#'  and \code{return_x_basis}, but this list is not exhaustive. Consult the
#'  documentation of \code{\link[hal9001]{fit_hal}} for complete details.
#' @importFrom stats aggregate plogis
#' @importFrom origami training validation fold_index
#' @importFrom assertthat assert_that
#' @importFrom hal9001 fit_hal
#' @importFrom Rdpack reprompt
#' @return A \code{list}, containing density predictions (\code{preds}, with
#'  one row per validation observation, in the order given by \code{ids}),
#'  observation IDs (\code{ids}), observation-level weights (\code{wts}, in
#'  the same order), and the cross-validation fold index (\code{fold}) for
#'  conditional density estimation on a single fold of the overall data.
cv_haldensify <- function(fold,
                          long_data,
                          wts = rep(1, nrow(long_data)),
                          lambda_seq = exp(seq(-1, -13, length = 1000L)),
                          smoothness_orders = 0L,
                          ...) {
  # make training and validation folds
  train_set <- origami::training(long_data, fold = fold)
  valid_set <- origami::validation(long_data, fold = fold)

  # subset observation-level weights to the correct size
  wts_train <- wts[fold$training_set]
  wts_valid <- wts[fold$validation_set]

  # fit a HAL regression on the training set
  # NOTE: not selecting lambda by CV so no need to pass IDs for fold splitting
  fit_hal_args <- list(...)
  # exact name match: a substring match would also catch unrelated arguments
  if (!("fit_control" %in% names(fit_hal_args))) {
    fit_hal_args$fit_control <- list(cv_select = FALSE, weights = wts_train)
  } else {
    fit_hal_args$fit_control$cv_select <- FALSE
    fit_hal_args$fit_control$weights <- wts_train
  }
  fit_hal_args$X <- as.matrix(train_set[, -c(1, 2)])
  fit_hal_args$Y <- as.numeric(train_set$in_bin)
  fit_hal_args$family <- "binomial"
  fit_hal_args$lambda <- lambda_seq
  fit_hal_args$smoothness_orders <- smoothness_orders
  hal_fit_train <- do.call(hal9001::fit_hal, fit_hal_args)

  # get intercept and coefficient fits along the lambda sequence from glmnet
  alpha_hat <- hal_fit_train$lasso_fit$a0
  betas_hat <- hal_fit_train$lasso_fit$beta
  coefs_hat <- rbind(alpha_hat, betas_hat)

  # make design matrix for validation set manually, reusing the basis
  # functions and copy map learned on the training set so that the columns
  # line up with the fitted coefficients
  pred_x_basis <- hal9001::make_design_matrix(
    as.matrix(valid_set[, -c(1, 2)]),
    hal_fit_train$basis_list
  )
  pred_x_basis <- hal9001::apply_copy_map(
    pred_x_basis,
    hal_fit_train$copy_map
  )
  pred_x_basis <- cbind(rep(1, nrow(valid_set)), pred_x_basis)

  # manually predict along sequence of lambdas
  preds_logit <- pred_x_basis %*% coefs_hat
  preds <- stats::plogis(as.matrix(preds_logit))

  # map hazard predictions to a density prediction, one observation at a time
  valid_ids <- unique(valid_set$obs_id)
  density_pred_each_obs <- lapply(valid_ids, function(id) {
    # get predictions for the current observation only
    hazard_pred_this_obs <- matrix(preds[valid_set$obs_id == id, ],
      ncol = length(lambda_seq)
    )
    # map hazard to density for a single observation
    map_hazard_to_density(hazard_pred_single_obs = hazard_pred_this_obs)
  })

  # aggregate predictions across observations (rows follow valid_ids)
  density_pred <- do.call(rbind, density_pred_each_obs)

  # collapse weights to the observation level, in the same order as the rows
  # of density_pred; assumes the weight is replicated across all of an
  # observation's long-format records, so its first record is representative
  wts_valid_obs <- wts_valid[match(valid_ids, valid_set$obs_id)]

  # construct output
  list(
    preds = density_pred,
    ids = valid_ids,
    wts = wts_valid_obs,
    fold = origami::fold_index(fold = fold)
  )
}

#' Cross-validated HAL Conditional Density Estimation
#' @details Estimation of the conditional density A|W using the highly
#'  adaptive lasso to estimate the conditional hazard of failure in a given
#'  bin over the support of A. Cross-validation is used to select the optimal
#'  value of the penalization parameters, based on minimization of the weighted
#'  log-likelihood loss for a density.
#' @param A The \code{numeric} vector of observed values.
#' @param W A \code{data.frame}, \code{matrix}, or similar giving the values of
#'  baseline covariates (potential confounders) for the observed units. These
#'  make up the conditioning set for the density estimate. For estimation of a
#'  marginal density, specify a constant \code{numeric} vector or \code{NULL}.
#' @param wts A \code{numeric} vector of observation-level weights. The default
#'  is to weight all observations equally.
#' @param grid_type A \code{character} indicating the strategy to be used in
#'  creating bins along the observed support of \code{A}. For bins of equal
#'  range, use \code{"equal_range"}; consult the documentation of
#'  \code{\link[ggplot2]{cut_interval}} for more information. To ensure each
#'  bin has the same number of observations, use \code{"equal_mass"}; consult
#'  the documentation of \code{\link[ggplot2]{cut_number}} for details. The
#'  default is \code{"equal_range"} since this has been found to provide better
#'  performance in simulation experiments; however, both types may be specified
#'  (i.e., \code{c("equal_range", "equal_mass")}) together, in which case
#'  cross-validation will be used to select the optimal binning strategy.
#' @param n_bins This \code{numeric} value indicates the number(s) of bins into
#'  which the support of \code{A} is to be divided. As with \code{grid_type},
#'  multiple values may be specified, in which case cross-validation will be
#'  used to choose the optimal number of bins. The default sets the candidate
#'  choices of the number of bins based on heuristics tested in simulation.
#' @param cv_folds A \code{numeric} indicating the number of cross-validation
#'  folds to be used in fitting the sequence of HAL conditional density models.
#' @param lambda_seq A \code{numeric} sequence of values of the regularization
#'  parameter of Lasso regression; passed to \code{\link[hal9001]{fit_hal}} via
#'  its argument \code{lambda}.
#' @param smoothness_orders An \code{integer} indicating the smoothness of the
#'  HAL basis functions; passed to \code{\link[hal9001]{fit_hal}}. The default
#'  is set to zero, for indicator basis functions.
#' @param hal_basis_list A \code{list} consisting of a preconstructed set of
#'  HAL basis functions, as produced by \code{\link[hal9001]{fit_hal}}. The
#'  default of \code{NULL} results in creating such a set of basis functions.
#'  When specified, this is passed directly to the HAL model fitted upon the
#'  augmented (repeated measures) data structure, resulting in a much lowered
#'  computational cost. This is useful, for example, in fitting HAL conditional
#'  density estimates with external cross-validation or bootstrap samples.
#' @param ... Additional (optional) arguments of \code{\link[hal9001]{fit_hal}}
#'  that may be used to control fitting of the HAL regression model. Possible
#'  choices include \code{use_min}, \code{reduce_basis}, \code{return_lasso},
#'  and \code{return_x_basis}, but this list is not exhaustive. Consult the
#'  documentation of \code{\link[hal9001]{fit_hal}} for complete details.
#'  These arguments are used both in the cross-validated tuning fits and in
#'  the final fit.
#' @note Parallel evaluation of the cross-validation procedure to select tuning
#'  parameters for density estimation may be invoked via the framework exposed
#'  in the \pkg{future} ecosystem. Specifically, set \code{\link[future]{plan}}
#'  for \code{\link[future.apply]{future_mapply}} to be used internally.
#' @importFrom data.table ":="
#' @importFrom future.apply future_mapply
#' @importFrom hal9001 fit_hal
#' @return Object of class \code{haldensify}, containing a fitted
#'  \code{hal9001} object; a vector of break points used in binning \code{A}
#'  over its support; sizes of the bins used in each fit; the tuning
#'  parameters selected by cross-validation; the full sequence (in lambda) of
#'  HAL models for the CV-selected number of bins and binning strategy; and
#'  the range of \code{A}.
#' @export
#' @examples
#' # simulate data: W ~ U[-4, 4] and A|W ~ N(mu = W, sd = 0.5)
#' set.seed(429153)
#' n_train <- 50
#' w <- runif(n_train, -4, 4)
#' a <- rnorm(n_train, w, 0.5)
#' # learn relationship A|W using HAL-based density estimation procedure
#' haldensify_fit <- haldensify(
#'   A = a, W = w, n_bins = 10L, lambda_seq = exp(seq(-1, -10, length = 100)),
#'   # the following arguments are passed to hal9001::fit_hal()
#'   max_degree = 3, reduce_basis = 1 / sqrt(length(a))
#' )
haldensify <- function(A, W,
                       wts = rep(1, length(A)),
                       grid_type = "equal_range",
                       n_bins = round(c(0.5, 1, 1.5, 2) * sqrt(length(A))),
                       cv_folds = 5L,
                       lambda_seq = exp(seq(-1, -13, length = 1000L)),
                       smoothness_orders = 0L,
                       hal_basis_list = NULL,
                       ...) {
  # capture dot arguments to hal9001::fit_hal()
  fit_hal_args <- list(...)

  # if W is set to NULL, create a constant conditioning set
  # NOTE: this essentially recovers the marginal density of A
  if (is.null(W)) {
    W <- rep(1, length(A))
  }

  # run CV-HAL for all combinations of n_bins and grid_type
  tune_grid <- expand.grid(
    grid_type = grid_type, n_bins = n_bins,
    stringsAsFactors = FALSE
  )

  # run procedure to select tuning parameters via cross-validation
  # NOTE: even when the number of bins and discretization technique are fixed,
  #       this step is still required to produce a CV-selected choice of lambda
  # NOTE: the dot arguments are forwarded so the candidate fits use the same
  #       fit_hal() settings as the final fit below; SIMPLIFY = FALSE keeps
  #       one list of results per candidate instead of a simplified array
  select_out <- future.apply::future_mapply(
    FUN = fit_haldensify,
    grid_type = tune_grid$grid_type,
    n_bins = tune_grid$n_bins,
    MoreArgs = c(
      list(
        A = A, W = W, wts = wts,
        cv_folds = cv_folds,
        lambda_seq = lambda_seq,
        smoothness_orders = smoothness_orders
      ),
      fit_hal_args
    ),
    SIMPLIFY = FALSE,
    future.seed = TRUE
  )

  # select the n_bins/grid_type candidate with the smallest CV risk (the
  # minimum over the lambda sequence of each candidate's empirical risk)
  min_risk <- vapply(
    select_out, function(fit) min(fit$emp_risks), numeric(1)
  )
  best_idx <- which.min(min_risk)
  cv_selected_params <- tune_grid[best_idx, , drop = FALSE]
  cv_selected_fits <- select_out[[best_idx]]
  cv_selected_fits$density_pred <- NULL

  # re-format input data into long hazards structure
  reformatted_output <- format_long_hazards(
    A = A, W = W, wts = wts,
    grid_type = cv_selected_params$grid_type,
    n_bins = cv_selected_params$n_bins
  )
  long_data <- reformatted_output$data
  breakpoints <- reformatted_output$breaks
  bin_sizes <- reformatted_output$bin_length

  # fit a HAL regression on the full data set across a sequence in lambda
  # NOTE: no sample-splitting since there's no need to select among any of the
  #       tuning parameters -- advantage: simplifies working with re-sampled
  #       data (bootstrap); disadvantage: non-sample-split nuisance estimates
  if (!("fit_control" %in% names(fit_hal_args))) {
    fit_hal_args$fit_control <- list(
      cv_select = FALSE, weights = as.numeric(long_data$wts), n_folds = 1L
    )
  } else {
    fit_hal_args$fit_control$cv_select <- FALSE
    fit_hal_args$fit_control$weights <- as.numeric(long_data$wts)
    fit_hal_args$fit_control$n_folds <- 1L
  }
  fit_hal_args$X <- as.matrix(long_data[, -c("obs_id", "in_bin", "wts")])
  fit_hal_args$Y <- as.numeric(long_data$in_bin)
  fit_hal_args$basis_list <- hal_basis_list
  fit_hal_args$family <- "binomial"
  fit_hal_args$lambda <- lambda_seq
  fit_hal_args$smoothness_orders <- smoothness_orders
  hal_fit <- do.call(hal9001::fit_hal, fit_hal_args)

  # construct output
  out <- list(
    hal_fit = hal_fit,
    breaks = breakpoints,
    bin_sizes = bin_sizes,
    range_a = range(A),
    grid_type_cvselect = cv_selected_params$grid_type,
    n_bins_cvselect = cv_selected_params$n_bins,
    cv_tuning_results = cv_selected_fits
  )
  class(out) <- "haldensify"
  out
}

#' Fit Conditional Density Estimation for a Sequence of HAL Models
#' @details Estimation of the conditional density of A|W via a cross-validated
#'  highly adaptive lasso, used to estimate the conditional hazard of failure
#'  in a given bin over the support of A. The cross-validated risk of each
#'  model in the sequence is the mean of the weighted negative log-likelihood
#'  loss, \eqn{-w_i \log p(a_i | w_i)}, over all observations.
#' @param A The \code{numeric} vector of observed values.
#' @param W A \code{data.frame}, \code{matrix}, or similar giving the values of
#'  baseline covariates (potential confounders) for the observed units. These
#'  make up the conditioning set for the conditional density estimate.
#' @param wts A \code{numeric} vector of observation-level weights. The default
#'  is to weight all observations equally.
#' @param grid_type A single \code{character} indicating the strategy to be
#'  used in creating bins along the observed support of \code{A}. For bins of
#'  equal range, use \code{"equal_range"}; consult the documentation of
#'  \code{\link[ggplot2]{cut_interval}} for more information. To ensure each
#'  bin has the same number of observations, use \code{"equal_mass"}; consult
#'  the documentation of \code{\link[ggplot2]{cut_number}} for details.
#' @param n_bins A \code{numeric} value indicating the number of bins into
#'  which the support of \code{A} is to be divided; it is passed unchanged to
#'  \code{\link{format_long_hazards}}. Selection among several candidate
#'  values (and binning strategies) is carried out by \code{\link{haldensify}},
#'  which calls this function once per candidate.
#' @param cv_folds A \code{numeric} indicating the number of cross-validation
#'  folds to be used in fitting the sequence of HAL conditional density models.
#' @param lambda_seq A \code{numeric} sequence of values of the regularization
#'  parameter of Lasso regression; passed to \code{\link[hal9001]{fit_hal}}.
#' @param smoothness_orders An \code{integer} indicating the smoothness of the
#'  HAL basis functions; passed to \code{\link[hal9001]{fit_hal}}. The default
#'  is set to zero, for indicator basis functions.
#' @param ... Additional (optional) arguments of \code{\link[hal9001]{fit_hal}}
#'  that may be used to control fitting of the HAL regression model. Possible
#'  choices include \code{use_min}, \code{reduce_basis}, \code{return_lasso},
#'  and \code{return_x_basis}, but this list is not exhaustive. Consult the
#'  documentation of \code{\link[hal9001]{fit_hal}} for complete details.
#' @importFrom data.table ":="
#' @importFrom matrixStats colMeans2
#' @importFrom origami make_folds cross_validate
#' @return A \code{list}, containing density predictions for the sequence of
#'  fitted HAL models (rows are observations, ordered by cross-validation
#'  fold); the index and value of the L1 regularization parameter minimizing
#'  the density loss; and the sequence of empirical risks for the sequence of
#'  fitted HAL models.
#' @export
#' @examples
#' # simulate data: W ~ U[-4, 4] and A|W ~ N(mu = W, sd = 0.5)
#' n_train <- 50
#' w <- runif(n_train, -4, 4)
#' a <- rnorm(n_train, w, 0.5)
#' # fit cross-validated HAL-based density estimator of A|W
#' haldensify_cvfit <- fit_haldensify(
#'   A = a, W = w, n_bins = 10L, lambda_seq = exp(seq(-1, -10, length = 100)),
#'   # the following arguments are passed to hal9001::fit_hal()
#'   max_degree = 3, reduce_basis = 1 / sqrt(length(a))
#' )
fit_haldensify <- function(A, W,
                           wts = rep(1, length(A)),
                           grid_type = "equal_range",
                           n_bins = round(c(0.5, 1, 1.5, 2) * sqrt(length(A))),
                           cv_folds = 5L,
                           lambda_seq = exp(seq(-1, -13, length = 1000L)),
                           smoothness_orders = 0L,
                           ...) {
  # re-format input data into long hazards structure
  reformatted_output <- format_long_hazards(
    A = A, W = W, wts = wts,
    grid_type = grid_type, n_bins = n_bins
  )
  long_data <- reformatted_output$data
  bin_sizes <- reformatted_output$bin_length

  # extract weights from long format data structure
  wts_long <- long_data$wts
  long_data[, wts := NULL]

  # make folds with origami, keeping all records of an observation together
  folds <- origami::make_folds(long_data,
    V = cv_folds,
    cluster_ids = long_data$obs_id
  )

  # call cross_validate on cv_haldensify; the dot arguments are forwarded
  # through cv_haldensify() to hal9001::fit_hal()
  haldensity <- origami::cross_validate(
    cv_fun = cv_haldensify,
    folds = folds,
    long_data = long_data,
    wts = wts_long,
    lambda_seq = lambda_seq,
    smoothness_orders = smoothness_orders,
    ...,
    use_future = FALSE,
    .combine = FALSE
  )

  # re-organize output of the cross-validation procedure; rows are stacked
  # fold by fold, so IDs and weights are stacked the same way to stay aligned
  density_pred_unscaled <- as.matrix(do.call(rbind, as.list(haldensity$preds)))
  obs_ids <- unlist(haldensity$ids, use.names = FALSE)
  obs_wts <- unlist(haldensity$wts, use.names = FALSE)

  # re-scale predictions by dividing by the width of each observation's
  # failure bin; widths are looked up by observation ID because the fold
  # order of the rows generally differs from the row order of long_data
  failure_records <- long_data[in_bin == 1]
  failure_bin_ids <- failure_records$bin_id[
    match(obs_ids, failure_records$obs_id)
  ]
  # a length-nrow vector recycles down each column, i.e. row-wise scaling;
  # unlike apply(), this keeps a matrix even for a single lambda value
  density_pred_scaled <- density_pred_unscaled / bin_sizes[failure_bin_ids]

  # weighted negative log-likelihood loss for each observation and lambda
  # NOTE: the weight multiplies the loss; weighting inside the log would only
  #       add a constant to every risk and so could never affect selection
  density_loss <- -log(density_pred_scaled) * obs_wts

  # take column means to have average loss across sequence of lambdas
  emp_risks_density_loss <- matrixStats::colMeans2(density_loss)

  # find minimizer of loss in lambda sequence
  lambda_loss_min_idx <- which.min(emp_risks_density_loss)
  lambda_loss_min <- lambda_seq[lambda_loss_min_idx]

  # return loss minimizer in lambda, Pn losses, and all density estimates
  list(
    lambda_loss_min_idx = lambda_loss_min_idx,
    lambda_loss_min = lambda_loss_min,
    emp_risks = emp_risks_density_loss,
    density_pred = density_pred_scaled,
    lambda_seq = lambda_seq
  )
}