R/wlasso.R

Defines functions wlasso

Documented in wlasso

#' Weighted LASSO prediction models for complex survey data
#'
#'@description This function fits LASSO prediction models (linear or logistic) to complex survey data. It takes the sampling weights into account in the estimation process and selects the value of the tuning parameter (lambda) that minimizes the prediction error, which is estimated with one of several replicate weights methods.
#'
#' @param data A data frame with information about the response variable and covariates, as well as sampling weights and strata and cluster indicators. It could be \code{NULL} if the sampling design is indicated in the \code{design} argument.
#' @param col.y A numeric value indicating the number of the column in which information on the response variable can be found or a character string indicating the name of that column.
#' @param col.x A numeric vector indicating the numbers of the columns in which information on the covariates can be found or a vector of character strings indicating the names of these columns.
#' @param cluster A character string indicating the name of the column with cluster identifiers. It could be \code{NULL} if the sampling design is indicated in the \code{design} argument.
#' @param strata A character string indicating the name of the column with strata identifiers. It could be \code{NULL} if the sampling design is indicated in the \code{design} argument.
#' @param weights A character string indicating the name of the column with sampling weights. It could be \code{NULL} if the sampling design is indicated in the \code{design} argument.
#' @param design An object of class \code{survey.design} generated by \code{survey::svydesign()}. It could be \code{NULL} if information about \code{cluster}, \code{strata}, \code{weights} and \code{data} is given.
#' @param family A character string indicating the family to fit LASSO models. Choose between \code{gaussian} (to fit linear models) or \code{binomial} (for logistic models).
#' @param lambda.grid A numeric vector indicating a grid for penalization parameters. The default option is \code{lambda.grid = NULL}, which considers the default grid selected by the function \code{glmnet::glmnet()}.
#' @param method A character string indicating the method to be applied to define replicate weights. Choose one of the following: \code{JKn}, \code{dCV}, \code{bootstrap}, \code{subbootstrap}, \code{BRR}, \code{split}, \code{extrapolation}.
#' @param k A numeric value indicating the number of folds to be defined. Default is \code{k=10}. Only applies to the \code{dCV} method.
#' @param R A numeric value indicating the number of times the sample is partitioned. Default is \code{R=1}. Only applies to the \code{dCV}, \code{split} and \code{extrapolation} methods.
#' @param B A numeric value indicating the number of bootstrap resamples. Default is \code{B=200}. Only applies to the \code{bootstrap} and \code{subbootstrap} methods.
#' @param dCV.sw.test A logical value indicating how the error is estimated for the \code{dCV} method. \code{FALSE} (the default option) estimates the error for each test set and defines the cross-validated error based on the average strategy. \code{TRUE} estimates the cross-validated error based on the pooling strategy.
#' @param train.prob A numeric value between 0 and 1, indicating the proportion of clusters (for the method \code{split}) or strata (for the method \code{extrapolation}) to be set in the training sets. Default is \code{train.prob = 0.7}. Only applies to the \code{split} and \code{extrapolation} methods.
#' @param method.split A character string indicating the way in which replicate weights should be defined in the \code{split} method. Choose one of the following: \code{dCV}, \code{bootstrap} or \code{subbootstrap}. Only applies to the \code{split} method.
#' @param print.rw A logical value. If \code{TRUE}, the data set with the replicate weights is saved in the output object. Default is \code{print.rw=FALSE}.
#'
#' @importFrom graphics abline mtext
#' @importFrom stats as.formula coef predict runif
#'
#' @return The output object of the function \code{wlasso()} is an object of class \code{wlasso}. This object is a list containing 4 or 5 elements, depending on the value set to the argument \code{print.rw}. Below we describe the contents of these elements:
#' - `lambda`: A list with two elements:
#'   - `grid`: A numeric vector indicating all the values considered for the tuning parameter.
#'   - `min`: A numeric value indicating the value of the tuning parameter that minimizes the average error (i.e., selected optimal tuning parameter).
#' - `error`: A list with two elements:
#'   - `average`: A numeric vector indicating the average error corresponding to each tuning parameter.
#'   - `all`: A numeric matrix indicating the error of each test set for each tuning parameter.
#' - `model`: A list with two elements related to the fitted models. Note that all these models are fitted on the whole data set (not only on the training sets).
#'   - `grid`: A list with the information about the models fitted for each of the tuning parameters considered (i.e., all the values in the \code{lambda$grid} object):
#'     - `a0`: a numeric vector of model intercepts across the whole grid of tuning parameters (hence, of the same length as \code{lambda$grid}).
#'     - `beta`: a matrix of regression coefficients corresponding to all the considered covariates across the whole grid of tuning parameters (the number of rows is equal to the number of covariates considered and the number of columns to the length of \code{lambda$grid}).
#'     - `df`: a numeric vector of the degrees of freedom (i.e., the number of coefficients different from zero) across the whole grid of tuning parameters (hence, of the same length as \code{lambda$grid}).
#'   - `min`: A list with information about the model fitted using only the tuning parameter that minimizes the average error estimated in the test sets (i.e., the optimal tuning parameter selected among the elements in \code{lambda$grid}):
#'     - `a0`: a numeric value indicating the intercept value of the selected model.
#'     - `beta`: a matrix of regression coefficients corresponding to all the considered covariates for the selected tuning parameter (the number of rows is equal to the number of covariates considered and the number of columns is one).
#'     - `df`: a numeric value indicating the degrees of freedom (i.e., the number of coefficients different from zero) of the selected model.
#' - `data.rw`: A data frame containing the original data set and the replicate weights added to define training and test sets. Only included in the output object if \code{print.rw=TRUE}.
#' - `call`: an object containing the information about the way in which the function has been run.
#' @export
#'
#' @examples
#' data(simdata_lasso_binomial)
#' mcv <- wlasso(data = simdata_lasso_binomial,
#'               col.y = "y", col.x = 1:50,
#'               family = "binomial",
#'               cluster = "cluster", strata = "strata", weights = "weights",
#'               method = "dCV", k=10, R=1)
#'
#' # Or equivalently:
#' \donttest{
#' mydesign <- survey::svydesign(ids=~cluster, strata = ~strata, weights = ~weights,
#'                               nest = TRUE, data = simdata_lasso_binomial)
#' mcv <- wlasso(col.y = "y", col.x = 1:50, design = mydesign,
#'               family = "binomial",
#'               method = "dCV", k=10, R=1)
#' }


wlasso <- function(data = NULL, col.y = NULL, col.x = NULL,
                   cluster = NULL, strata = NULL, weights = NULL, design = NULL,
                   family = c("gaussian", "binomial"),
                   lambda.grid = NULL,
                   method = c("dCV", "JKn", "bootstrap", "subbootstrap", "BRR", "split", "extrapolation"),
                   k = 10, R = 1, B = 200,
                   dCV.sw.test = FALSE,
                   train.prob = 0.7, method.split = c("dCV", "bootstrap", "subbootstrap"),
                   print.rw = FALSE){

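  # Resolve the choice arguments to a single value (defaults: "gaussian" and "dCV")
  family <- match.arg(family)
  method <- match.arg(method)

  # Stops and messages:
  if(is.null(data) && is.null(design)){stop("Information about either the data set ('data') or the sampling design ('design') is needed.")}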

  if(method == "split"){
    if(is.null(train.prob)){stop("Selected replicate weights method: 'split'.\nPlease, set a value between 0 and 1 for the argument 'train.prob'.")}
    if(train.prob < 0 | train.prob > 1){stop("Selected replicate weights method: 'split'.\nPlease, set a value between 0 and 1 for the argument 'train.prob'.")}
    if(length(method.split)!=1){stop("Selected replicate weights method: 'split'.\nPlease, set a valid method for the argument 'method.split'. Choose between: 'dCV', 'bootstrap' or 'subbootstrap'.")}
  }

  if(method == "extrapolation"){
    if(is.null(train.prob)){stop("Selected replicate weights method: 'extrapolation'.\nPlease, set a value between 0 and 1 for the argument 'train.prob'.")}
    if(train.prob < 0 | train.prob > 1){stop("Selected replicate weights method: 'extrapolation'.\nPlease, set a value between 0 and 1 for the argument 'train.prob'.")}
  }

  if(method %in% c("JKn", "bootstrap", "subbootstrap", "BRR")){
    if(R != 1){message("Selected method: ", method, ". For this method, R = 1. Thus, the argument R = ", R, " has been ignored.")}
  }

  if(method %in% c("dCV", "split", "extrapolation")){
    if(R != round(R)){stop("The argument 'R' must be an integer greater than or equal to 1. R=",R," is not an integer.\nPlease, set a valid value for 'R' or skip the argument to select the default option R=1.")}
    if(R < 1){stop("The argument 'R' must be an integer greater than or equal to 1. R=",R," is lower than 1.\nPlease, set a valid value for 'R' or skip the argument to select the default option R=1.")}
  }

  if(method != "dCV"){
    if(!is.null(k) && k != 10){message("Selected method: ", method, ". The argument k = ", k, " is not needed and, hence, has been ignored.")}
  }

  if(method == "dCV"){
    if(k != round(k)){stop("The argument 'k' must be an integer. k=",k," is not an integer.\nPlease, set a valid value for 'k' or skip the argument to select the default option k=10.")}
    if(k < 1){stop("The argument 'k' must be a positive integer. k=",k," is not a positive integer.\nPlease, set a valid value for 'k' or skip the argument to select the default option k=10.")}
  }

  if(!(method %in% c("bootstrap", "subbootstrap"))){
    if(!is.null(B) && B != 200){message("Selected method: ", method, ". The argument B = ", B, " is not needed and, hence, has been ignored.")}
  }

  if(method %in% c("bootstrap", "subbootstrap")){
    if(B != round(B)){stop("The argument 'B' must be an integer. B=",B," is not an integer.\nPlease, set a valid value for 'B' or skip the argument to select the default option B=200.")}
    if(B < 1){stop("The argument 'B' must be a positive integer. B=",B," is not a positive integer.\nPlease, set a valid value for 'B' or skip the argument to select the default option B=200.")}
  }



  # Step 0: Notation
  if(!is.null(design)){
    cluster <- as.character(design$call$id[2])
    if(cluster == "1" || cluster == "0"){
      cluster <- NULL
    }
    strata <- as.character(design$call$strata[2])
    weights <- as.character(design$call$weights[2])
    data <- get(design$call$data)
  }


  # Step 1: Generate replicate weights based on the method
  newdata <- replicate.weights(data = data, method = method,
                               cluster = cluster, strata = strata, weights = weights,
                               k = k, R = R, B = B,
                               train.prob = train.prob, method.split = method.split,
                               rw.test = TRUE, dCV.sw.test = dCV.sw.test)


  # Step 2: fit the model on the whole sample; if lambda.grid is NULL, use the default grid chosen by glmnet::glmnet()
  if(is.null(lambda.grid)){
    model.orig <- glmnet::glmnet(y = as.numeric(newdata[,col.y]),
                                 x = as.matrix(newdata[,col.x]),
                                 weights = as.numeric(newdata[,weights]),
                                 family = family)
    lambda.grid <- model.orig$lambda
  } else {
    model.orig <- glmnet::glmnet(y = as.numeric(newdata[,col.y]),
                                 x = as.matrix(newdata[,col.x]),
                                 weights = as.numeric(newdata[,weights]),
                                 family = family,
                                 lambda = lambda.grid)
  }

  # Step 3: Fit the training models and estimate yhat for units in the sample
  rwtraincols <- grep("_train", colnames(newdata))
  l.yhat <- list()

  for(col.w in rwtraincols){

    model <- glmnet::glmnet(y = as.numeric(newdata[,col.y]),
                            x = as.matrix(newdata[,col.x]),
                            weights = as.numeric(newdata[,col.w]),
                            lambda = lambda.grid,
                            family = family)

    # Sample yhat
    yhat <- predict(model, newx=as.matrix(newdata[,col.x]), type = "response")
    l.yhat[[length(l.yhat) + 1]] <- yhat
    names(l.yhat)[[length(l.yhat)]] <- paste0("yhat_", colnames(newdata)[col.w])

  }

  # Step 4: estimate the error in the test sets
  error <- error.f(data = newdata, l.yhat = l.yhat,
                   method = method, cv.error.ind = dCV.sw.test,
                   R = R, k = k, B = B,
                   col.y = col.y, family = family, weights = weights)
  mean.error <- apply(error, 2, mean)

  lambda.min <- lambda.grid[which.min(mean.error)]

  model <- glmnet::glmnet(y = data[,col.y],
                          x = as.matrix(data[,col.x]),
                          weights = data[,weights],
                          lambda = lambda.min,
                          family = family)

  result <- list()
  result$lambda <- list(grid = lambda.grid,
                        min = lambda.min)
  result$error <- list(average = mean.error,
                       all = error)
  result$model <- list()
  result$model$grid <- list(a0 = model.orig$a0,
                            beta = model.orig$beta,
                            df = model.orig$df)
  result$model$min <- list(a0 = model$a0,
                           beta = model$beta,
                           df = model$df)
  result$call <- match.call()

  if(print.rw == TRUE){result$data.rw <- newdata}

  class(result) <- "wlasso"

  return(result)

}
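
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of svyVarSel; not exported and not called by
# wlasso()). Under simplifying assumptions, it shows how the two strategies
# described for 'dCV.sw.test' could combine test-set errors, and how the
# optimal lambda is then chosen as in Step 4 of wlasso(). The actual error
# computation is done by the internal function error.f(), whose definition
# may differ. Assumptions (hypothetical, for illustration only):
#   * gaussian family, with a weighted mean squared error as the loss;
#   * 'yhat' already holds out-of-fold predictions (units x lambda values);
#   * 'w' holds the weights of the units in their test set, and 'fold'
#     gives the test set each unit belongs to.
# The names cv_error_sketch() and select_lambda_sketch() are hypothetical.
# ---------------------------------------------------------------------------
cv_error_sketch <- function(y, yhat, w, fold, pooled = FALSE) {
  yhat <- as.matrix(yhat)          # one column per lambda value
  sq.err <- (y - yhat)^2           # squared error of each unit for each lambda

  if (pooled) {
    # Pooling strategy: a single weighted error over all test units together
    return(colSums(w * sq.err) / sum(w))
  }

  # Average strategy: weighted error within each test set, then averaged
  per.fold <- sapply(split(seq_along(y), fold), function(idx) {
    colSums(w[idx] * sq.err[idx, , drop = FALSE]) / sum(w[idx])
  })
  # Rows = lambda values, columns = test sets
  rowMeans(matrix(per.fold, nrow = ncol(yhat)))
}

# Step 4 analogue: 'error' has one row per test set and one column per lambda,
# and the selected lambda is the one with the smallest column mean.
select_lambda_sketch <- function(lambda.grid, error) {
  lambda.grid[which.min(colMeans(error))]
}

# Example usage (toy numbers):
#   cv_error_sketch(c(1, 2, 3, 4), cbind(c(1, 2, 3, 4), 2), c(1, 1, 2, 2), c(1, 1, 2, 2))
#   returns c(0, 1.5); with pooled = TRUE it returns c(0, 11/6).
# The two strategies differ because the average strategy gives each test set
# the same importance, whereas pooling lets test sets with larger total
# weight contribute more.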


svyVarSel documentation built on Oct. 15, 2024, 5:06 p.m.