Nothing
#' Spectrally Deconfounded Additive Models
#'
#' Estimate high-dimensional additive models using spectral deconfounding \insertCite{scheidegger2023spectral}{SDModels}.
#' The covariates are expanded into B-spline basis functions. A spectral
#' transformation is used to remove bias arising from hidden confounding and
#' a group lasso objective is minimized to enforce component-wise sparsity.
#' Optimal number of basis functions per component and sparsity penalty are
#' chosen by cross validation.
#'@references
#' \insertAllCited{}
#' @author Cyrill Scheidegger
#' @param formula Object of class \code{formula} or describing the model to fit
#' of the form \code{y ~ x1 + x2 + ...} where \code{y} is a numeric response and
#' \code{x1, x2, ...} are vectors of covariates. Interactions are not supported.
#' @param data Training data of class \code{data.frame} containing the variables in the model.
#' @param x Matrix of covariates, alternative to \code{formula} and \code{data}.
#' @param y Vector of responses, alternative to \code{formula} and \code{data}.
#' @param Q_type Type of deconfounding, one of 'trim', 'pca', 'no_deconfounding'.
#' 'trim' corresponds to the Trim transform \insertCite{Cevid2020SpectralModels}{SDModels}
#' as implemented in the Doubly debiased lasso \insertCite{Guo2022DoublyConfounding}{SDModels},
#' 'pca' to the PCA transformation\insertCite{Paul2008PreconditioningProblems}{SDModels}.
#' See \code{\link{get_Q}}.
#' @param trim_quantile Quantile for Trim transform,
#' only needed for trim, see \code{\link{get_Q}}.
#' @param q_hat Assumed confounding dimension, only needed for pca,
#' see \code{\link{get_Q}}.
#' @param nfolds The number of folds for cross-validation. Default is 5.
#' @param cv_method The method for selecting the regularization parameter during cross-validation.
#' One of "min" (minimum cv-loss) and "1se" (one-standard-error rule) Default is "1se".
#' @param n_K The number of candidate values for the number of basis functions for B-splines. Default is 4.
#' @param n_lambda1 The number of candidate values for the regularization parameter in the initial cross-validation step. Default is 10.
#' @param n_lambda2 The number of candidate values for the regularization parameter in the second stage of cross-validation
#' (once the optimal number of basis function K is decided, a second stage of cross-validation for the regularization parameter
#' is performed on a finer grid). Default is 20.
#' @param Q_scale Should data be scaled to estimate the spectral transformation?
#' Default is \code{TRUE} to not reduce the signal of high variance covariates.
#' @param ind_lin A vector of indices specifying which covariates to model linearly (i.e. not expanded into basis function).
#' Default is `NULL`.
#' @param mc.cores Number of cores to use for parallel processing, if \code{mc.cores > 1}
#' the cross validation is parallelized. Default is `1`. (only supported for unix)
#' @param verbose If \code{TRUE} fitting information is shown.
#' @return An object of class `SDAM` containing the following elements:
#' \item{X}{The original design matrix.}
#' \item{p}{The number of covariates in `X`.}
#' \item{intercept}{The intercept term of the fitted model.}
#' \item{K}{A vector of the number of basis functions for each covariate,
#' where 1 corresponds to a linear term. The entries of the vector will mostly by
#' the same, but some entries might be lower if the corresponding component of
#' X contains only few unique values.}
#' \item{breaks}{A list of breakpoints used for the B-splines. Used to reconstruct the B-spline basis functions.}
#' \item{coefs}{A list of coefficients for the B-spline basis functions for each component.}
#' \item{active}{A vector of active covariates that contribute to the model.}
#' @seealso \code{\link{get_Q}}, \code{\link{predict.SDAM}}, \code{\link{varImp.SDAM}},
#' \code{\link{predict_individual_fj}}, \code{\link{partDependence}}
#' @examples
#' set.seed(1)
#' X <- matrix(rnorm(10 * 5), ncol = 5)
#' Y <- sin(X[, 1]) - X[, 2] + rnorm(10)
#' model <- SDAM(x = X, y = Y, Q_type = "trim", trim_quantile = 0.5, nfold = 2, n_K = 1)
#'
#'
#' \donttest{
#' library(HDclassif)
#' data(wine)
#' names(wine) <- c("class", "alcohol", "malicAcid", "ash", "alcalinityAsh", "magnesium",
#' "totPhenols", "flavanoids", "nonFlavPhenols", "proanthocyanins",
#' "colIntens", "hue", "OD", "proline")
#' wine <- log(wine)
#'
#' # estimate model
#' # do not use class in the model and restrict proline to be linear
#' model <- SDAM(alcohol ~ -class + ., wine, ind_lin = "proline", nfold = 3)
#'
#' # extract variable importance
#' varImp(model)
#'
#' # most important variable
#' mostImp <- names(which.max(varImp(model)))
#' mostImp
#'
#' # predict for individual Xj
#' predJ <- predict_individual_fj(object = model, j = mostImp)
#' plot(wine[, mostImp], predJ,
#' xlab = paste0("log ", mostImp), ylab = "log alcohol")
#'
#' # partial dependece
#' plot(partDependence(model, mostImp))
#'
#' # predict
#' predict(model, newdata = wine[42, ])
#'
#' ## alternative function call
#' mod_none <- SDAM(x = as.matrix(wine[1:10, -c(1, 2)]), y = wine$alcohol[1:10],
#' Q_type = "no_deconfounding", nfolds = 2, n_K = 4,
#' n_lambda1 = 4, n_lambda2 = 8)
#' }
#'
#' @export
SDAM <- function(formula = NULL, data = NULL, x = NULL, y = NULL,
Q_type = "trim", trim_quantile = 0.5, q_hat = 0, nfolds = 5,
cv_method = "1se", n_K = 4, n_lambda1 = 10, n_lambda2 = 20,
Q_scale = TRUE, ind_lin = NULL, mc.cores = 1, verbose = TRUE){
input_data <- data.handler(formula = formula, data = data, x = x, y = y)
X <- input_data$X
Y <- input_data$Y
n <- NROW(X)
p <- NCOL(X)
if(n != length(Y)) stop('X and Y must have the same number of observations')
if(!is.numeric(nfolds) || nfolds < 2 || nfolds > n) stop('nfolds must be an integer between 2 and n')
if(!is.numeric(n_K) || n_K < 1) stop('n_K must be a positive integer')
if(!is.numeric(n_lambda1) || n_lambda1 < 1) stop('n_lambda1 must be a positive integer')
if(!is.numeric(n_lambda2) || n_lambda2 < 1) stop('n_lambda2 must be a positive integer')
if(!is.numeric(mc.cores) || mc.cores < 1) stop('mc.cores must be a positive integer')
if(!is.null(ind_lin)){
if(!is.numeric(ind_lin)){
if(!is.character(ind_lin)) stop("ind_lin must either contain integers or variable names")
ind_lin <- which(colnames(data.frame(X)) %in% ind_lin)
}
if((min(ind_lin) < 1) || max(ind_lin) > p) stop("ind_lin must contain covariates in the data in [1, p]")
}
gprLassoControl <- grplasso::grpl.control(save.x = FALSE, save.y = FALSE, trace = 0)
# create vector of candidate values for K
# intuition: candidate values for K should be between K0 = 4 and 10*n^0.2
K.up <- round(10*n^0.2)
vK <- unique(round(seq(4, K.up, length.out = n_K)))
# spectral transformation
Qf <- get_Qf(X, type = Q_type, trim_quantile = trim_quantile, q_hat = q_hat,
gpu = FALSE, scaling = Q_scale)
QY <- Qf(Y)
# get number of unique elements in each column of X
n_unique_X <- apply(X, 2, function(x){length(unique(x))})
# Generate the design and model parameters for every K in vK
lmodK <- list()
for (i in 1:length(vK)){
K <- vK[i]
# effective number of basis functions for each Xj, j = 1,..., p
# K_eff[j] can be at most equal to the number of unique values of Xj
# K_eff[j] is set to 1 for all j in ind_lin
# K_eff[j] is set to 1 if K_eff[j]<=3, since B-spline needs 4 basis functions
K_eff <- rep(K, p)
K_eff[ind_lin] <- 1
K_eff <- pmin(K_eff, n_unique_X)
K_eff[K_eff < 4] <- 1
# first column is intercept
#B <- cbind(rep(1, n), matrix(nrow = n, ncol = sum(K_eff)))
B <- matrix(1, nrow = n, ncol = 1)
Rlist <- list()
lbreaks <- list()
# variable grouping, intercept not penalized gets NA
#index <- c(NA, rep(1:p, times = K_eff))
index <- NA
for (j in 1:p){
# number of breaks is number of basis functions minus order (4 by default) + 2
if(K_eff[j] >= 4){
breaks <- quantile(X[,j], probs=seq(0, 1, length.out = K_eff[j]-2))
breaks <- unique(breaks)
K_eff[j] <- length(breaks) + 2
lbreaks[[j]] <- breaks
Bj <- Bbasis(X[,j], breaks = breaks)
}
else{
lbreaks[[j]] <- NULL
Bj <- X[, j]
}
Rj.inv <- solve(chol(1/n*t(Bj) %*% Bj))
index <- c(index, rep(j, K_eff[j]))
B <- cbind(B, Bj %*% Rj.inv)
#B[, index == j & !is.na(index)] <- Bj %*% Rj.inv
Rlist[[j]] <- Rj.inv
}
QB <- Qf(B)
# calculate maximal lambda
lambdamax <- grplasso::lambdamax(QB, QY, index = index, model = grplasso::LinReg(),
center = FALSE, standardize = FALSE)
# lambdas for cross validation
lambda <- exp(seq(log(lambdamax), log(lambdamax/1000), length.out = n_lambda1))
lmodK[[i]] <- list(Rlist = Rlist, lbreaks = lbreaks, index = index, B = B,
QB = QB, lambda = lambda, K = K, K_eff = K_eff)
}
# generate folds for CV
ind <- sample(rep(1:nfolds, length.out = n), replace = FALSE)
# calculates mse on fold l and for a listK which has the form of a lmodK[[i]]
mse_fold_K <- function(l, listK){
test <- ind == l
# use capture.output to supress the output form grplasso
# use suppressWarnings to igrnore the warnings "Penalization not adjusted to non-penalized predictors"
# which we are aware of.
suppressWarnings(
mod <- grplasso::grplasso(listK$QB[!test, ], QY[!test], index = listK$index,
lambda = listK$lambda, model = grplasso::LinReg(),
center = FALSE, standardize = FALSE,
control = gprLassoControl)
)
QYpred <- predict(mod, newdata = listK$QB[test, ])
mse <- apply(QYpred, 2, function(y){mean((y - QY[test])^2)})
return(mse)
}
mse_fold <- function(l){
MSEl <- lapply(lmodK, function(listK){mse_fold_K(l, listK)})
return(unname(do.call(rbind, MSEl)))
}
if(verbose) print("Initial cross-validation")
if(mc.cores == 1){
MSES <- pbapply::pblapply(1:nfolds, mse_fold)
} else {
MSES <- parallel::mclapply(1:nfolds, mse_fold, mc.cores = mc.cores)
}
# aggregate MSEs over folds
MSES.agg <- Reduce("+", MSES) / nfolds
ind.min <- which(MSES.agg == min(MSES.agg), arr.ind = TRUE)
K.min <- vK[ind.min[1]]
lambda.min <- lmodK[[ind.min[1]]]$lambda[ind.min[2]]
# refit model for K.min and find best value for lambda in the neighborhood of lambda.min
modK.min <- lmodK[[ind.min[1]]]
modK.min$lambda <- exp(seq(log(lambda.min * 10), log(lambda.min/10),
length.out = n_lambda2))
if(verbose) print("Second stage cross-validation")
if(mc.cores == 1){
MSES1 <- pbapply::pblapply(1:nfolds, mse_fold_K, listK = modK.min)
} else {
MSES1 <- parallel::mclapply(1:nfolds, mse_fold_K, listK = modK.min,
mc.cores = mc.cores)
}
MSES1 <- do.call(rbind, MSES1)
MSE1.agg <- apply(MSES1, 2, mean)
se.agg <- 1/sqrt(nfolds) * apply(MSES1, 2, sd)
ind.min1 <- which.min(MSE1.agg)
if(cv_method == "min"){
lambdastar <- modK.min$lambda[ind.min1]
} else {
if(cv_method != "1se"){
warning("CV method not implemented. Taking '1se'.")
}
lambdastar <- max(modK.min$lambda[MSE1.agg <= MSE1.agg[ind.min1]+se.agg[ind.min1]])
}
## fit model on full data with K.min and lambdastar
suppressWarnings(
mod <- grplasso::grplasso(modK.min$QB, QY, index = modK.min$index,
lambda = lambdastar, model = grplasso::LinReg(),
center = FALSE, standardize = FALSE,
control = gprLassoControl)
)
# transform back to original scale
lcoef <- list()
active <- numeric()
running_ind <- 1
index <- modK.min$index
Rlist <- modK.min$Rlist
for(j in 1:p){
cj <- mod$coefficients[index == j & !is.na(index)]
if(sum(cj^2) != 0){
active <- c(active, j)
# transform back
lcoef[[j]] <- Rlist[[j]] %*% cj
}
}
intercept <- mod$coefficients[1]
lreturn <- list()
# original covariates
lreturn$X <- X
lreturn$p <- NCOL(X)
lreturn$var_names = colnames(data.frame(X))
# intercept
lreturn$intercept <-intercept
# number of basis functions for each component
lreturn$K <- modK.min$K_eff
# list of breaks of B-spline basis
lreturn$breaks <- modK.min$lbreaks
# list of coefficients
lreturn$coefs <- lcoef
# estimated active set
lreturn$active <- active
class(lreturn) <- "SDAM"
return(lreturn)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.