calibrate: Calibrate high-dimensional Cox models

View source: R/4_1_calibrate.R

calibrateR Documentation

Calibrate high-dimensional Cox models

Description

Calibrate high-dimensional Cox models

Usage

calibrate(
  x,
  time,
  event,
  model.type = c("lasso", "alasso", "flasso", "enet", "aenet", "mcp", "mnet", "scad",
    "snet"),
  alpha,
  lambda,
  pen.factor = NULL,
  gamma,
  lambda1,
  lambda2,
  method = c("fitting", "bootstrap", "cv", "repeated.cv"),
  boot.times = NULL,
  nfolds = NULL,
  rep.times = NULL,
  pred.at,
  ngroup = 5,
  seed = 1001,
  trace = TRUE
)

Arguments

x

Matrix of training data used for fitting the model; on which to run the calibration.

time

Survival time. Must be of the same length with the number of rows as x.

event

Status indicator, normally 0 = alive, 1 = dead. Must be of the same length with the number of rows as x.

model.type

Model type to calibrate. Could be one of "lasso", "alasso", "flasso", "enet", "aenet", "mcp", "mnet", "scad", or "snet".

alpha

Value of the elastic-net mixing parameter alpha for enet, aenet, mnet, and snet models. For lasso, alasso, mcp, and scad models, please set alpha = 1. alpha=1: lasso (l1) penalty; alpha=0: ridge (l2) penalty. Note that for mnet and snet models, alpha can be set to very close to 0 but not 0 exactly.

lambda

Value of the penalty parameter lambda to use in the model fits on the resampled data. From the Cox model you have built.

pen.factor

Penalty factors to apply to each coefficient. From the built adaptive lasso or adaptive elastic-net model.

gamma

Value of the model parameter gamma for MCP/SCAD/Mnet/Snet models.

lambda1

Value of the penalty parameter lambda1 for fused lasso model.

lambda2

Value of the penalty parameter lambda2 for fused lasso model.

method

Calibration method. Options including "fitting", "bootstrap", "cv", and "repeated.cv".

boot.times

Number of repetitions for bootstrap.

nfolds

Number of folds for cross-validation and repeated cross-validation.

rep.times

Number of repeated times for repeated cross-validation.

pred.at

Time point at which calibration should take place.

ngroup

Number of groups to be formed for calibration.

seed

A random seed for resampling.

trace

Logical. Output the calibration progress or not. Default is TRUE.

Examples

data("smart")
x <- as.matrix(smart[, -c(1, 2)])
time <- smart$TEVENT
event <- smart$EVENT
y <- survival::Surv(time, event)

# Fit Cox model with lasso penalty
fit <- fit_lasso(x, y, nfolds = 5, rule = "lambda.1se", seed = 11)

# Model calibration by fitting the original data directly
cal.fitting <- calibrate(
  x, time, event,
  model.type = "lasso",
  alpha = 1, lambda = fit$lambda,
  method = "fitting",
  pred.at = 365 * 9, ngroup = 5,
  seed = 1010
)

# Model calibration by 5-fold cross-validation
cal.cv <- calibrate(
  x, time, event,
  model.type = "lasso",
  alpha = 1, lambda = fit$lambda,
  method = "cv", nfolds = 5,
  pred.at = 365 * 9, ngroup = 5,
  seed = 1010
)

print(cal.fitting)
summary(cal.fitting)
plot(cal.fitting)

print(cal.cv)
summary(cal.cv)
plot(cal.cv)

# # Test fused lasso, SCAD, and Mnet models
# data(smart)
# x = as.matrix(smart[, -c(1, 2)])[1:500, ]
# time = smart$TEVENT[1:500]
# event = smart$EVENT[1:500]
# y = survival::Surv(time, event)
#
# set.seed(1010)
# cal.fitting = calibrate(
#   x, time, event, model.type = "flasso",
#   lambda1 = 5, lambda2 = 2,
#   method = "fitting",
#   pred.at = 365 * 9, ngroup = 5,
#   seed = 1010)
#
# cal.boot = calibrate(
#   x, time, event, model.type = "scad",
#   gamma = 3.7, alpha = 1, lambda = 0.03,
#   method = "bootstrap", boot.times = 10,
#   pred.at = 365 * 9, ngroup = 5,
#   seed = 1010)
#
# cal.cv = calibrate(
#   x, time, event, model.type = "mnet",
#   gamma = 3, alpha = 0.3, lambda = 0.03,
#   method = "cv", nfolds = 5,
#   pred.at = 365 * 9, ngroup = 5,
#   seed = 1010)
#
# cal.repcv = calibrate(
#   x, time, event, model.type = "flasso",
#   lambda1 = 5, lambda2 = 2,
#   method = "repeated.cv", nfolds = 5, rep.times = 3,
#   pred.at = 365 * 9, ngroup = 5,
#   seed = 1010)
#
# print(cal.fitting)
# summary(cal.fitting)
# plot(cal.fitting)
#
# print(cal.boot)
# summary(cal.boot)
# plot(cal.boot)
#
# print(cal.cv)
# summary(cal.cv)
# plot(cal.cv)
#
# print(cal.repcv)
# summary(cal.repcv)
# plot(cal.repcv)

road2stat/hdnom documentation built on March 14, 2024, 11:10 p.m.