#' Get info about the best lambda for [glmnet::glmnet()].
#'
#' Take a prepped [recipes::recipe()] and return the necessary info for choosing
#' a lambda for lasso regression using [glmnet::cv.glmnet()].
#'
#' For families other than `"binomial"` and `"multinomial"`, a lambda path
#' whose best CV score sits at its smallest lambda is refitted with twice as
#' many lambdas and a `lambda.min.ratio` ten times smaller. This is done at
#' most up to 1600 lambdas.
#'
#' @param prepped_rec A prepped [recipes::recipe()]. From this will be extracted
#'   the `x` and `y` for [glmnet::cv.glmnet()].
#' @inheritParams train_gbm
#' @param lambda_user A numeric vector of lambdas to investigate. Whether this
#'   is specified or not, the lambda sequence generated by
#'   [glmnet::cv.glmnet()] is used. The stats returned for `lambda_user` are
#'   actually the stats for the closest lambdas in this generated sequence.
#'   This argument is not compulsory (the default `NULL` is fine).
#' @inheritParams glmnet::cv.glmnet
#'
#' @return A list with the following elements:
#' * `best_lambda`: The lambda resulting in the best CV performance.
#' * `best_metric_mn`: The mean CV metric score with `best_lambda`.
#' * `best_metric_se`: The standard error in the CV scores with `best_lambda`.
#' * `lambda_1se`: The largest lambda resulting in a score within a standard
#'   error of the best score.
#' * `metric_mn_1se`: The mean CV metric score with `lambda_1se`.
#' * `metric_se_1se`: The standard error in the CV metric score with
#'   `lambda_1se`.
#' * `n_lambda`: The `nlambda` in the final call to [glmnet::cv.glmnet()].
#' * `lambda_min_ratio`: The `lambda.min.ratio` in the final call to
#'   [glmnet::cv.glmnet()].
#' * `type_measure`: The `type.measure` in the call to [glmnet::cv.glmnet()].
#' * `family`: The `family` in the call to [glmnet::cv.glmnet()].
#' * `lambda_user`: The best performing `lambda` passed into `lambda_user`.
#'   If you pass many lambdas in `lambda_user`, you get exactly 1 out. It is
#'   chosen according to the `selection_method`:
#'   - By default, it is the candidate with the best mean CV score.
#'   - With `"Breiman"`, it is the largest candidate lambda within one standard
#'     error of that best candidate.
#'   When several user lambdas map to the same path lambda, the one nearest to
#'   it is returned.
#' * `lambda_user_metric_mn`: The mean CV metric score with `lambda_user`.
#' * `lambda_user_metric_se`: The standard error in the CV metric score with
#'   `lambda_user`.
#'
#' @noRd
cvg_lambda_help <- function(prepped_rec, outcome, metric, selection_method,
                            foldid, lambda_user, n_cores) {
  checked_args <- argchk_cvg_lambda_help(
    prepped_rec = prepped_rec,
    outcome = outcome,
    metric = metric,
    selection_method = selection_method,
    foldid = foldid,
    lambda_user = lambda_user,
    n_cores = n_cores
  )
  c(juiced_rec, x, y, fam, metric) %<-% checked_args[
    c("juiced_rec", "x", "y", "fam", "metric")
  ]
  # Parallel setup -------------------------------------------------------------
  # cv.glmnet() only consults the foreach (doFuture) backend when
  # `parallel = TRUE`, so workers are only started when n_cores > 1.
  use_parallel <- n_cores > 1
  if (use_parallel) {
    doFuture::registerDoFuture()
    old_plan <- future::plan(future::multisession, workers = n_cores)
    on.exit(future::plan(old_plan), add = TRUE)
  }
  # Main body ------------------------------------------------------------------
  max_n_lambda <- 1600
  n_lambda <- 100
  # Same rule glmnet uses for its default lambda.min.ratio.
  lambda_min_ratio <- if (nrow(x) < ncol(x)) 0.01 else 0.0001
  # Non-gaussian families are scored by deviance. cv.glmnet() has no "rmse"
  # measure, and "mse" ranks lambdas identically.
  type_measure <- if (fam == "gaussian") metric[1] else "deviance"
  if (type_measure == "rmse") type_measure <- "mse"
  repeat {
    cvg_obj <- suppressMessages(
      glmnet::cv.glmnet(
        x, y,
        parallel = use_parallel,
        family = fam,
        nlambda = n_lambda,
        lambda.min.ratio = lambda_min_ratio,
        type.measure = type_measure,
        foldid = foldid
      )
    )
    n_path <- length(cvg_obj$lambda)
    # The minimum counts as found if glmnet stopped the path early, or if the
    # best CV score is not at the end (smallest lambda) of the path.
    found_min <- n_path < n_lambda ||
      match(cvg_obj$lambda.min, cvg_obj$lambda) != n_path
    # Stop before growing past `max_n_lambda`, so that the returned `n_lambda`
    # and `lambda_min_ratio` are the ones used in the final fit. Refitting is
    # never attempted for classification families.
    if (found_min || fam %in% c("binomial", "multinomial") ||
      2 * n_lambda > max_n_lambda) {
      break
    }
    n_lambda <- 2 * n_lambda # unlikely to ever be needed
    lambda_min_ratio <- lambda_min_ratio / 10
  }
  best_index <- match(cvg_obj$lambda.min, cvg_obj$lambda)
  index_1se <- match(cvg_obj$lambda.1se, cvg_obj$lambda)
  out <- list(
    best_lambda = cvg_obj$lambda.min,
    best_metric_mn = cvg_obj$cvm[best_index],
    best_metric_se = cvg_obj$cvsd[best_index],
    lambda_1se = cvg_obj$lambda.1se,
    metric_mn_1se = cvg_obj$cvm[index_1se],
    metric_se_1se = cvg_obj$cvsd[index_1se],
    n_lambda = n_lambda,
    lambda_min_ratio = lambda_min_ratio,
    type_measure = type_measure,
    family = fam
  )
  out_extra <- list(
    lambda_user = NA_real_,
    lambda_user_metric_mn = NA_real_,
    lambda_user_metric_se = NA_real_
  )
  if (!is.null(lambda_user)) {
    # For each user lambda, the index of the nearest lambda on the glmnet path.
    user_indices <- purrr::map_dbl(
      lambda_user,
      ~ which.min(abs(cvg_obj$lambda - .))
    )
    closest_indices <- unique(user_indices)
    # `cvm` is on the scale of `type_measure` (not `metric`). Of cv.glmnet's
    # measures, only "auc" and "C" are larger-is-better, so flip those to a loss.
    cvm_loss <- cvg_obj$cvm
    if (type_measure %in% c("auc", "C")) cvm_loss <- -cvm_loss
    index <- closest_indices[which.min(cvm_loss[closest_indices])]
    if (selection_method == "Breiman") {
      # One-standard-error rule restricted to the candidates: keep those whose
      # loss is within one SE of the best, then take the smallest index. This
      # is the largest lambda, since glmnet's lambda sequence is decreasing.
      good_indices <- closest_indices[
        dplyr::between(
          cvm_loss[closest_indices],
          cvm_loss[index],
          cvm_loss[index] + cvg_obj$cvsd[index]
        )
      ]
      index <- min(good_indices)
    }
    # Only user lambdas that map to the chosen path lambda are eligible. Among
    # them, take the nearest one (first on ties) so exactly one comes out.
    candidates <- lambda_user[user_indices == index]
    chosen_lambda <- candidates[
      which.min(abs(candidates - cvg_obj$lambda[index]))
    ]
    out_extra <- list(
      lambda_user = chosen_lambda,
      lambda_user_metric_mn = cvg_obj$cvm[index],
      lambda_user_metric_se = cvg_obj$cvsd[index]
    )
  }
  c(out, out_extra)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.