# R/glmnet-helpers.R
#
# Defines functions: cvg_lambda_help

#' Get info about the best lambda for [glmnet::glmnet()].
#'
#' Take a prepped [recipes::recipe()] and return the necessary info for choosing
#' a lambda for lasso regression using [glmnet::cv.glmnet()].
#'
#' The lambda path is first fitted with `nlambda = 100` and glmnet's default
#' `lambda.min.ratio`. For families other than `"binomial"` and
#' `"multinomial"`, if the best CV score sits at the smallest lambda of the
#' path, the path is refitted with twice as many lambdas and a
#' `lambda.min.ratio` ten times smaller, up to at most 1600 lambdas.
#'
#' @param prepped_rec A prepped [recipes::recipe()]. From this will be extracted
#'   the `x` and `y` for [glmnet::cv.glmnet()].
#' @inheritParams train_gbm
#' @param lambda_user A numeric vector of lambdas to investigate. Whether this
#'   is specified or not, the lambda sequence generated by [glmnet::cv.glmnet()]
#'   is used and the stats returned for `lambda_user` are actually the stats for
#'   the closest lambdas in this sequence generated by [glmnet::cv.glmnet()].
#'   This argument is not compulsory (the default `NULL` is fine; an empty
#'   vector is treated the same as `NULL`).
#' @inheritParams glmnet::cv.glmnet
#'
#' @return A list with the following elements:
#' * `best_lambda`: The lambda resulting in the best CV performance.
#' * `best_metric_mn`: The mean CV metric score with `best_lambda`.
#' * `best_metric_se`: The standard error in the CV scores with `best_lambda`.
#' * `lambda_1se`: The largest lambda resulting in a score within a standard
#'   error of the best score.
#' * `metric_mn_1se`: The mean CV metric score with `lambda_1se`.
#' * `metric_se_1se`: The standard error in the CV metric score with
#'   `lambda_1se`.
#' * `n_lambda`: The `nlambda` used in the final call to
#'   [glmnet::cv.glmnet()].
#' * `lambda_min_ratio`: The `lambda.min.ratio` used in the final call to
#'   [glmnet::cv.glmnet()].
#' * `type_measure`: The `type.measure` passed to [glmnet::cv.glmnet()].
#' * `family`: The `family` passed to [glmnet::cv.glmnet()].
#' * `lambda_user`: The single best performing `lambda` passed into
#'   `lambda_user` (so if you pass many lambdas in `lambda_user`, you get 1
#'   out). This will be chosen according to the `selection_method`. `NA` if no
#'   `lambda_user` was given.
#' * `lambda_user_metric_mn`: The mean CV metric score with `lambda_user`.
#' * `lambda_user_metric_se`: The standard error in the CV metric score with
#'   `lambda_user`.
#'
#' @noRd
cvg_lambda_help <- function(prepped_rec, outcome, metric, selection_method,
                            foldid, lambda_user, n_cores) {
  checked_args <- argchk_cvg_lambda_help(
    prepped_rec = prepped_rec,
    outcome = outcome,
    metric = metric,
    selection_method = selection_method,
    foldid = foldid,
    lambda_user = lambda_user,
    n_cores = n_cores
  )
  c(juiced_rec, x, y, fam, metric) %<-% checked_args[
    c("juiced_rec", "x", "y", "fam", "metric")
  ]

  # Parallel setup -------------------------------------------------------------
  doFuture::registerDoFuture()
  old_plan <- future::plan(future::multisession, workers = n_cores)
  on.exit(future::plan(old_plan), add = TRUE)

  # Main body ------------------------------------------------------------------
  max_n_lambda <- 1600
  n_lambda <- 100
  # Same rule glmnet uses for its own default `lambda.min.ratio`.
  lambda_min_ratio <- if (nrow(x) < ncol(x)) 0.01 else 1e-4
  # glmnet has no "rmse" measure; "mse" ranks lambdas identically. Families
  # other than gaussian are always scored by deviance.
  type_measure <- if (fam == "gaussian") metric[1] else "deviance"
  if (identical(type_measure, "rmse")) type_measure <- "mse"
  repeat {
    cvg_obj <- suppressMessages(
      glmnet::cv.glmnet(
        x, y,
        parallel = n_cores > 1,
        family = fam,
        nlambda = n_lambda,
        lambda.min.ratio = lambda_min_ratio,
        type.measure = type_measure,
        foldid = foldid
      )
    )
    # The minimum counts as found if glmnet ended the path early, or if the
    # best score is not at the path's smallest lambda.
    found_min <- (length(cvg_obj$lambda) < n_lambda) ||
      (which.min(cvg_obj$cvm) != length(cvg_obj$cvm))
    # Stop *before* growing the settings past max_n_lambda, so the returned
    # n_lambda/lambda_min_ratio always describe the fit held in cvg_obj.
    if (found_min || fam %in% c("binomial", "multinomial") ||
        2 * n_lambda > max_n_lambda) {
      break
    }
    # Extend the path (unlikely to ever be needed).
    n_lambda <- 2 * n_lambda
    lambda_min_ratio <- lambda_min_ratio / 10
  }
  # Look up the positions of the lambdas cv.glmnet() itself selected, so the
  # reported scores belong exactly to the reported lambdas.
  best_index <- match(cvg_obj$lambda.min, cvg_obj$lambda)
  index_1se <- match(cvg_obj$lambda.1se, cvg_obj$lambda)
  out <- list(
    best_lambda = cvg_obj$lambda.min,
    best_metric_mn = cvg_obj$cvm[best_index],
    best_metric_se = cvg_obj$cvsd[best_index],
    lambda_1se = cvg_obj$lambda.1se,
    metric_mn_1se = cvg_obj$cvm[index_1se],
    metric_se_1se = cvg_obj$cvsd[index_1se],
    n_lambda = n_lambda,
    lambda_min_ratio = lambda_min_ratio,
    type_measure = type_measure,
    family = fam
  )
  out_extra <- list(
    lambda_user = NA_real_,
    lambda_user_metric_mn = NA_real_,
    lambda_user_metric_se = NA_real_
  )
  if (length(lambda_user) > 0) {
    # cvm scores type_measure, NOT `metric`. type_measure is deviance for
    # non-gaussian families, and cv.glmnet() only supports mse/mae/deviance for
    # gaussian fits, so lower cvm is always better here (as assumed above when
    # checking found_min). The direction of `metric` must not be used: e.g.
    # for a binomial model scored by accuracy, cvm is still the deviance.
    closest_indices <- unique(vapply(
      lambda_user,
      function(lam) which.min(abs(cvg_obj$lambda - lam)),
      integer(1)
    ))
    index <- closest_indices[which.min(cvg_obj$cvm[closest_indices])]
    if (selection_method == "Breiman") {
      # 1-SE rule: among candidates within one standard error of the best,
      # take the largest lambda, i.e. the smallest index (glmnet's lambda path
      # is decreasing). `index` itself always qualifies.
      within_1se <- cvg_obj$cvm[closest_indices] <=
        cvg_obj$cvm[index] + cvg_obj$cvsd[index]
      index <- min(closest_indices[which(within_1se)], index)
    }
    out_extra <- list(
      # which.min() yields exactly one element even if lambda_user contains
      # duplicates or values tied in distance to the chosen glmnet lambda.
      lambda_user = lambda_user[
        which.min(abs(lambda_user - cvg_obj$lambda[index]))
      ],
      lambda_user_metric_mn = cvg_obj$cvm[index],
      lambda_user_metric_se = cvg_obj$cvsd[index]
    )
  }
  c(out, out_extra)
}
# mirvie/mirmodels documentation built on Jan. 14, 2022, 11:12 a.m.