R/proportional_hazards-glmnet.R

Defines functions survival_prob_coxnet get_missings_coxnet survival_time_coxnet multi_predict_coxnet_linear_pred multi_predict._coxnet predict_raw._coxnet predict_linear_pred._coxnet predict_survival._coxnet predict._coxnet coxnet_prepare_x print._coxnet check_dots_coxnet check_strata_remaining check_intercept_model drop_strata remove_strata get_strata_glmnet check_strata_nterms has_strata coxnet_train

Documented in coxnet_train survival_prob_coxnet survival_time_coxnet

#' Wrapper for glmnet for censored
#'
#' Not to be used directly by users.
#'
#' @details
#' glmnet works on a predictor matrix and expects stratification to be attached
#' to the _response_ via [glmnet::stratifySurv()]. censored instead lets users
#' write a [survival::strata()] term on the right-hand side of the formula.
#' This wrapper bridges the two:
#' - The full formula supplies the stratification information for the response.
#' - The formula with the strata term removed is used to build glmnet's model
#'   matrix.
#'
#' The returned object keeps the pre-processed training data next to the fit
#' so that predictions can be made from it later.
#'
#' @param formula The model formula.
#' @param data The data.
#' @inheritParams glmnet::glmnet
#' @param ... additional parameters passed to glmnet::glmnet.
#' @param call The call passed to [rlang::abort()].
#'
#' @return A list of class `coxnet` with elements:
#' - `fit`: the fitted `glmnet` model.
#' - `preproc`: the output of parsnip's formula-to-matrix conversion.
#' @export
#' @keywords internal
#' @examples
#' coxnet_mod <- coxnet_train(Surv(time, status) ~ age + sex, data = lung)
coxnet_train <- function(formula,
                         data,
                         alpha = 1,
                         lambda = NULL,
                         weights = NULL,
                         ...,
                         call = caller_env()) {
  check_dots_coxnet(rlang::quos(...), call = call)

  # How parsnip encodes predictors for this engine/mode combination.
  glmnet_encoding <- dplyr::filter(
    parsnip::get_encoding("proportional_hazards"),
    mode == "censored regression",
    engine == "glmnet"
  )
  indicator_type <- dplyr::pull(glmnet_encoding, predictor_indicators)
  drop_intercept <- dplyr::pull(glmnet_encoding, remove_intercept)

  # The model matrix is built without the strata() term; stratification is
  # attached to the response further down.
  predictor_formula <- remove_strata(formula, data, call = call)
  xy <- parsnip::.convert_form_to_xy_fit(
    formula = predictor_formula,
    data = data,
    composition = "matrix",
    indicators = indicator_type,
    remove_intercept = drop_intercept
  )

  if (has_strata(formula, data)) {
    check_strata_nterms(formula, data, call = call)
    # NOTE(review): strata come from a separate model.frame() on the
    # right-hand side only (response deleted, na.omit), while x/y come from
    # parsnip's conversion. If missing values drop different rows in the two,
    # the lengths may not line up. Confirm the intended NA handling.
    xy$y <- glmnet::stratifySurv(
      xy$y,
      strata = get_strata_glmnet(formula, data)
    )
  }

  glmnet_fit <- glmnet::glmnet(
    xy$x,
    xy$y,
    family = "cox",
    alpha = alpha,
    lambda = lambda,
    weights = weights,
    ...
  )

  # TODO: remove offset from the preprocessing object?
  structure(list(fit = glmnet_fit, preproc = xy), class = "coxnet")
}

has_strata <- function(formula, data) {
  # terms() records the positions of strata() calls under its "specials"
  # attribute; the entry is NULL when the formula has none.
  formula_terms <- stats::terms(formula, specials = "strata", data = data)
  strata_positions <- attr(formula_terms, "specials")[["strata"]]
  !is.null(strata_positions)
}

# glmnet accepts a single stratification variable, so the formula may contain
# at most one strata() term (that term may still combine several columns).
# Returns `formula` invisibly when the check passes.
check_strata_nterms <- function(formula, data, call = caller_env()) {
  formula_specials <- attr(
    stats::terms(formula, specials = "strata", data = data),
    "specials"
  )
  if (length(formula_specials$strata) <= 1) {
    return(invisible(formula))
  }
  rlang::abort(
    c(
      "There can only be a single 'strata' term specified using the `strata()` function.",
      i = "It can contain multiple strata columns, e.g., ` ~ x + strata(s1, s2)`."
    ),
    call = call
  )
}

# Evaluate the strata() term of `formula` on `data`. Only the right-hand side
# is used (the response is dropped), and rows are filtered by `na.action`.
get_strata_glmnet <- function(formula, data, na.action = stats::na.omit) {
  rhs_terms <- stats::delete.response(
    stats::terms(formula, specials = "strata", data = data)
  )
  rhs_frame <- stats::model.frame(rhs_terms, data, na.action = na.action)

  # delete.response() re-indexes the specials, so this position refers to a
  # column of `rhs_frame`.
  strata_column <- attr(rhs_terms, "specials")$strata
  purrr::pluck(rhs_frame, strata_column)
}

# Return `formula` with the strata() term removed from its right-hand side.
# Aborts if nothing but an intercept/strata would remain, or if a strata()
# call is nested somewhere it cannot be removed cleanly.
remove_strata <- function(formula, data, call = rlang::caller_env()) {
  if (has_strata(formula, data)) {
    cleaned_rhs <- drop_strata(formula[[3]])
    cleaned_rhs <- check_intercept_model(cleaned_rhs, call = call)
    cleaned_rhs <- check_strata_remaining(cleaned_rhs, call = call)
    formula[[3]] <- cleaned_rhs
  }
  formula
}

# Remove strata() terms that sit directly in the top-level chain of `+` calls.
# Sub-expressions under any other call (including parentheses) are traversed
# with `in_plus = FALSE`, which leaves their strata() calls in place so that
# check_strata_remaining() can report the incorrect usage.
drop_strata <- function(expr, in_plus = TRUE) {
  if (is_call(expr, "+", n = 2) && in_plus) {
    left <- drop_strata(expr[[2]], in_plus = TRUE)
    right <- drop_strata(expr[[3]], in_plus = TRUE)
    if (is_call(left, "strata")) {
      return(right)
    }
    if (is_call(right, "strata")) {
      return(left)
    }
    return(rlang::call2("+", left, right))
  }

  if (!is_call(expr)) {
    return(expr)
  }

  expr[-1] <- map(as.list(expr[-1]), drop_strata, in_plus = FALSE)
  expr
}

# Abort when, after drop_strata() has run, the right-hand side has no
# predictors left. That happens in two cases:
# - Only the intercept `1` remains (e.g. `~ 1 + strata(s)`).
# - Only a strata() call remains (e.g. `~ strata(s)`).
# Returns `expr` unchanged otherwise.
check_intercept_model <- function(expr, call = caller_env()) {
  # A literal `1` in a formula parses as a numeric constant. The symbol and
  # string spellings of "1" are accepted as well.
  is_intercept_only <- identical(expr, as.name("1")) ||
    ((is.numeric(expr) || is.character(expr)) &&
      length(expr) == 1L &&
      isTRUE(expr == 1))

  # Scalar `||` short-circuits and never deparses arbitrary calls.
  if (is_intercept_only || is_call(expr, "strata")) {
    abort(
      "The Cox model does not contain an intercept, please add a predictor.",
      call = call
    )
  }
  expr
}

# Walk an expression tree and abort on any strata() call still present, i.e.
# one that drop_strata() left because it was not in the top-level `+` chain.
# Returns `expr` when no such call is found.
check_strata_remaining <- function(expr, call = rlang::caller_env()) {
  if (!is_call(expr)) {
    return(expr)
  }

  if (is_call(expr, "strata")) {
    abort(
      c(
        "Stratification must be nested under a chain of `+` calls.",
        i = "# Good: ~ x1 + x2 + strata(s)",
        i = "# Bad: ~ x1 + (x2 + strata(s))"
      ),
      call = call
    )
  }

  # lapply() rather than map(), so the error is not annotated with the index
  # of the argument in which it occurred.
  expr[-1] <- lapply(as.list(expr[-1]), check_strata_remaining, call = call)
  expr
}

# Reject arguments in `...` that coxnet_train() does not allow to be
# forwarded to glmnet. Returns NULL invisibly when all names are acceptable.
check_dots_coxnet <- function(x, call = caller_env()) {
  disallowed <- c("subset", "contrasts", "offset", "family")
  arg_names <- names(x)
  offending <- arg_names[arg_names %in% disallowed]

  if (length(offending) == 0) {
    return(invisible(NULL))
  }

  rlang::abort(
    glue::glue(
      "These argument(s) cannot be used to create the model: ",
      glue::glue_collapse(glue::glue("`{offending}`"), sep = ", ")
    ),
    call = call
  )
}

#' @export
print._coxnet <- function(x, ...) {
  cat("parsnip model object\n\n")
  cat("Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n")

  # `x$fit` normally holds the "coxnet" object built by coxnet_train(): a list
  # whose `fit` element is the glmnet model. A try-error could be stored either
  # as `x$fit` itself or as `x$fit$fit`, so check the outer level first.
  # Indexing `$fit` into an atomic try-error would itself error.
  model <- if (inherits(x$fit, "try-error")) x$fit else x$fit$fit

  if (inherits(model, "try-error")) {
    # Print the error text itself. The coxnet list in `x$fit` holds a
    # `preproc` list, which cat() cannot print.
    cat("Model fit failed with error:\n", model, "\n")
  } else {
    print(model, ...)
    cat("The training data has been saved for prediction.\n")
  }
  invisible(x)
}


# prediction --------------------------------------------------------------

# Build the predictor matrix for `new_data`.
# - Models fit through the formula interface store a preprocessing blueprint
#   in `object$preproc$coxnet`, which parsnip replays on the new data.
# - Otherwise the columns named in `object$preproc$x_var` are taken directly
#   from `new_data`.
coxnet_prepare_x <- function(new_data, object) {
  blueprint <- object$preproc$coxnet

  if (is.null(blueprint)) {
    predictor_cols <- new_data[, object$preproc$x_var, drop = FALSE]
    return(as.matrix(predictor_cols))
  }

  converted <- parsnip::.convert_form_to_xy_new(
    blueprint,
    new_data,
    composition = "matrix"
  )
  converted$x
}

# notes adapted from parsnip:

# glmnet call stack for censored regression using `predict` when object has
# classes "_coxnet" and "model_fit":
#
#  predict()
#   predict._coxnet(penalty = NULL)   <-- checks and sets penalty
#    predict.model_fit()              <-- checks for extra vars in ...
#     predict_survival()
#      predict_survival._coxnet()
#       predict_survival.model_fit()
#        survival_prob_coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "linear_pred")` when object has
# classes "_coxnet" and "model_fit":
#
# 	multi_predict()
#    multi_predict._coxnet(penalty = NULL)
#      predict._coxnet(multi = TRUE)          <-- checks and sets penalty
#       predict.model_fit()                   <-- checks for extra vars in ...
#        predict_raw()
#         predict_raw._coxnet()
#          predict_raw.model_fit(opts = list(s = penalty))
#           predict.coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "survival")` when object has
# classes "_coxnet" and "model_fit":
#
# 	multi_predict()
#    multi_predict._coxnet(penalty = NULL)
#      predict._coxnet(multi = TRUE)          <-- checks and sets penalty
#       predict.model_fit()                  <-- checks for extra vars in ...
#        predict_survival()
#         predict_survival._coxnet()
#          predict_survival.model_fit()
#           survival_prob_coxnet()

#' @export
predict._coxnet <-
  function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
    # When no penalty is supplied, fall back to the one stored in the model
    # specification. See discussion in
    # https://github.com/tidymodels/parsnip/issues/195
    # Both operands are scalars, so use the short-circuiting `&&`.
    if (is.null(penalty) && !is.null(object$spec$args$penalty)) {
      penalty <- object$spec$args$penalty
    }

    # parsnip validates the penalty; `multi = TRUE` is passed by
    # multi_predict._coxnet().
    object$spec$args$penalty <- parsnip::.check_glmnet_penalty_predict(penalty, object, multi)

    object$spec <- eval_args(object$spec)
    predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
  }

#' @export
predict_survival._coxnet <- function(object, new_data, ...) {
  # Evaluate the arguments stored in the model specification (e.g. the
  # penalty) before deferring to the next predict_survival() method.
  object$spec <- eval_args(object$spec)
  # NextMethod() forwards the current value of `object` from this frame, so
  # the evaluated spec is what the next method sees.
  NextMethod()
}

#' @export
predict_linear_pred._coxnet <- function(object,
                                        new_data,
                                        ...,
                                        increasing = TRUE) {
  linear_pred <- NextMethod()
  # glmnet's linear predictor decreases as survival time increases. With
  # `increasing = TRUE` the sign is flipped so that, consistent with other
  # models, the linear predictor increases with time.
  if (increasing) -linear_pred else linear_pred
}

#' @export
predict_raw._coxnet <- function(object, new_data, opts = list(), ...) {
  # Evaluate the model specification's arguments, then pass its penalty on to
  # the raw predict call as `s` (glmnet's name for the lambda value(s) at which
  # to predict).
  object$spec <- eval_args(object$spec)
  opts$s <- object$spec$args$penalty
  # NextMethod() forwards the current values of `object` and `opts` from this
  # frame, so both modifications reach the next predict_raw() method.
  NextMethod()
}


# multi_predict -----------------------------------------------------------

#' @export
multi_predict._coxnet <- function(object,
                                  new_data,
                                  type = NULL,
                                  penalty = NULL,
                                  ...) {
  dots <- list(...)

  if (any(names(dots) == "newdata")) {
    rlang::abort("Please use `new_data` instead of `newdata`.")
  }

  object$spec <- eval_args(object$spec)

  if (is.null(penalty)) {
    # See discussion in https://github.com/tidymodels/parsnip/issues/195
    if (!is.null(object$spec$args$penalty)) {
      penalty <- object$spec$args$penalty
    } else {
      # `object$fit` is the "coxnet" wrapper from coxnet_train()
      # (`list(fit, preproc)`). The glmnet fit, and with it the lambda path,
      # is one level down.
      penalty <- object$fit$fit$lambda
    }
  }

  # identical() keeps the default `type = NULL` from failing inside `if ()`
  # with a zero-length condition. A NULL `type` is passed through to
  # predict() unchanged.
  if (identical(type, "linear_pred")) {
    pred <- multi_predict_coxnet_linear_pred(
      object,
      new_data = new_data,
      opts = dots,
      penalty = penalty
    )
  } else {
    pred <- predict(
      object,
      new_data = new_data,
      type = type,
      ...,
      penalty = penalty,
      multi = TRUE
    )
  }

  pred
}

# Linear predictions for several penalty values at once.
# Returns a tibble with one row per row of `new_data`. Its `.pred` list column
# holds, for each observation, a tibble of `penalty` and `.pred_linear_pred`
# sorted by penalty.
multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
  # `increasing` is consumed here and must not be forwarded to predict().
  if ("increasing" %in% names(opts)) {
    increasing <- opts$increasing
    opts$increasing <- NULL
  } else {
    increasing <- TRUE
  }

  # Raw glmnet predictions: one column per penalty value.
  pred <- predict(
    object,
    new_data = new_data,
    type = "raw",
    opts = opts,
    penalty = penalty,
    multi = TRUE
  )

  if (increasing) {
    # For consistency with other models, we want the lp to increase with
    # time. For this, we change the sign
    pred <- -pred
  }

  # Map each prediction column back to the penalty value it belongs to.
  param_key <- tibble(group = colnames(pred), penalty = penalty)
  pred <- pred %>%
    as_tibble() %>%
    dplyr::mutate(.row = seq_len(nrow(pred))) %>%
    tidyr::pivot_longer(
      -.row,
      names_to = "group",
      values_to = ".pred_linear_pred"
    )
  # `multiple =` is only understood by newer dplyr versions.
  if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
    pred <- dplyr::inner_join(param_key, pred, by = "group", multiple = "all")
  } else {
    pred <- dplyr::inner_join(param_key, pred, by = "group")
  }

  # Return the nested result explicitly instead of via a trailing assignment,
  # which would make it invisible.
  pred %>%
    dplyr::select(-group) %>%
    dplyr::arrange(.row, penalty) %>%
    tidyr::nest(.pred = c(-.row)) %>%
    dplyr::select(-.row)
}


# prediction: time --------------------------------------------------------

#' A wrapper for survival times with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param new_data Data for prediction.
#' @param penalty Penalty value(s).
#' @param ... Options to pass to [survival::survfit()].
#' @return A numeric vector with one element per row of `new_data`. Rows with
#'   missing predictors (or strata) get `NA`.
#' @keywords internal
#' @export
#' @examples
#' cox_mod <- proportional_hazards(penalty = 0.1) %>%
#'   set_engine("glmnet") %>%
#'   fit(Surv(time, status) ~ ., data = lung)
#' survival_time_coxnet(cox_mod, new_data = lung[1:3, ], penalty = 0.1)
survival_time_coxnet <- function(object, new_data, penalty = NULL, ...) {
  new_x <- coxnet_prepare_x(new_data, object)
  n_total <- nrow(new_data)

  # Strata are only needed when the model was fit through the formula
  # interface with a strata() term. NAs are kept here and handled below.
  went_through_formula_interface <- !is.null(object$preproc$coxnet)
  if (went_through_formula_interface &&
    has_strata(object$formula, object$training_data)) {
    new_strata <- get_strata_glmnet(
      object$formula,
      data = new_data,
      na.action = stats::na.pass
    )
  } else {
    new_strata <- NULL
  }

  # Predict only for complete rows; NA is re-inserted for the others at the end.
  missings_in_new_data <- get_missings_coxnet(new_x, new_strata)
  if (!is.null(missings_in_new_data)) {
    if (length(missings_in_new_data) == n_total) {
      # Numeric NA, so the result type matches the partially-missing and
      # complete cases.
      return(rep(NA_real_, n_total))
    }
    new_x <- new_x[-missings_in_new_data, , drop = FALSE]
    new_strata <- new_strata[-missings_in_new_data]
  }

  y <- survival::survfit(
    object$fit$fit,
    newx = new_x,
    newstrata = new_strata,
    s = penalty,
    x = object$training_data$x,
    y = object$training_data$y,
    weights = object$preproc$coxnet$weights,
    na.action = stats::na.exclude,
    ...
  )

  # Use the "rmean" entry of the summary table as the predicted time.
  # Punctuation is stripped from the names first. The table is a matrix for
  # several curves and a named vector for a single one.
  tabs <- summary(y)$table
  if (is.matrix(tabs)) {
    colnames(tabs) <- gsub("[[:punct:]]", "", colnames(tabs))
    res <- unname(tabs[, "rmean"])
  } else {
    names(tabs) <- gsub("[[:punct:]]", "", names(tabs))
    res <- unname(tabs["rmean"])
  }

  if (!is.null(missings_in_new_data)) {
    # Spread the predictions back over all rows, with NA for incomplete ones.
    index_with_na <- rep(NA, n_total)
    index_with_na[-missings_in_new_data] <- seq_along(res)
    res <- res[index_with_na]
  }
  res
}


# Indices of the rows where any predictor in `new_x` (or the stratum in
# `new_strata`, if given) is missing. Returns NULL when every row is complete.
get_missings_coxnet <- function(new_x, new_strata) {
  row_has_na <- rowSums(is.na(cbind(new_x, new_strata))) > 0
  if (any(row_has_na)) which(row_has_na) else NULL
}

# prediction: survival ----------------------------------------------------


#' A wrapper for survival probabilities with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param new_data Data for prediction.
#' @param eval_time A vector of integers for prediction times.
#' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times.
#' @param output One of "surv" or "haz".
#' @param penalty Penalty value(s). Defaults to the penalty stored in the
#'   model specification.
#' @param ... Options to pass to [survival::survfit()].
#' @return A tibble with a list column of nested tibbles.
#' @keywords internal
#' @export
#' @examples
#' cox_mod <- proportional_hazards(penalty = 0.1) %>%
#'   set_engine("glmnet") %>%
#'   fit(Surv(time, status) ~ ., data = lung)
#' survival_prob_coxnet(cox_mod, new_data = lung[1:3, ], eval_time = 300)
survival_prob_coxnet <- function(object,
                                 new_data,
                                 eval_time,
                                 time = deprecated(),
                                 output = "surv",
                                 penalty = NULL,
                                 ...) {
  if (lifecycle::is_present(time)) {
    lifecycle::deprecate_warn(
      "0.2.0",
      "survival_prob_coxnet(time)",
      "survival_prob_coxnet(eval_time)"
    )
    eval_time <- time
  }

  if (is.null(penalty)) {
    penalty <- object$spec$args$penalty
  }
  output <- match.arg(output, c("surv", "haz"))
  # Several penalty values give one survfit result per penalty.
  multi <- length(penalty) > 1

  new_x <- coxnet_prepare_x(new_data, object)

  # Strata are only needed for formula-interface fits with a strata() term.
  # NAs are kept so the affected rows can be patched back in later.
  uses_strata <- !is.null(object$preproc$coxnet) &&
    has_strata(object$formula, object$training_data)
  if (uses_strata) {
    new_strata <- get_strata_glmnet(
      object$formula,
      data = new_data,
      na.action = stats::na.pass
    )
  } else {
    new_strata <- NULL
  }

  n_obs <- nrow(new_data)
  missings_in_new_data <- get_missings_coxnet(new_x, new_strata)

  if (!is.null(missings_in_new_data)) {
    n_missing <- length(missings_in_new_data)
    if (n_missing == n_obs) {
      # Nothing can be predicted: every row gets the all-NA template.
      na_template <- predict_survival_na(eval_time, interval = "none")
      return(tibble(.pred = rep(list(na_template), n_missing)))
    }
    new_x <- new_x[-missings_in_new_data, , drop = FALSE]
    new_strata <- new_strata[-missings_in_new_data]
  }

  y <- survival::survfit(
    object$fit$fit,
    newx = new_x,
    newstrata = new_strata,
    s = penalty,
    x = object$training_data$x,
    y = object$training_data$y,
    weights = object$preproc$coxnet$weights,
    na.action = na.exclude,
    ...
  )

  # Summarise one survfit result at `eval_time`, re-inserting rows that were
  # dropped for missingness.
  patch_summary <- function(fit_summary) {
    survfit_summary_to_patched_tibble(
      fit_summary,
      index_missing = missings_in_new_data,
      eval_time = eval_time,
      n_obs = n_obs
    )
  }

  if (multi) {
    flat <- tibble::tibble(
      penalty = penalty,
      res_patched = purrr::map(y, patch_summary)
    ) %>%
      tidyr::unnest(cols = res_patched) %>%
      keep_cols(output, keep_penalty = TRUE)
  } else {
    flat <- keep_cols(patch_summary(y), output)
  }

  # One row per observation, with its predictions nested in `.pred`.
  flat %>%
    tidyr::nest(.pred = c(-.row)) %>%
    dplyr::select(-.row)
}

Try the censored package in your browser

Any scripts or data that you put into this service are public.

censored documentation built on April 14, 2023, 12:30 a.m.