R/formula_hal9001.R

Defines functions h `+.formula_hal9001` formula_hal

Documented in formula_hal h

#' HAL Formula: Convert formula or string to `formula_hal9001` object.
#'
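#' @param formula A \code{formula}, or a string coercible to one via
#'  \code{as.formula}, whose right-hand side is built from calls to \code{h}
#'  (e.g. `~ h(W1) + h(W2)`).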
#' @param smoothness_orders A default value for \code{s} if not provided
#'  explicitly to the function \code{h}.
#' @param num_knots A default value for \code{k} if not provided explicitly to
#'  the function \code{h}.
#' @param X Controls inheritance of the variable `X` from parent environment.
#'  When `NULL` (the default), such a variable is inherited.
#'
#' @importFrom stats as.formula
#'
#' @export
formula_hal <- function(formula, smoothness_orders, num_knots, X = NULL) {
  # Frame of the caller: the place `X`, `smoothness_orders` and `num_knots`
  # are inherited from when they are not supplied explicitly.
  caller <- parent.frame()

  if (is.null(X)) {
    X <- get("X", envir = caller)
  }
  if (missing(smoothness_orders) &&
    !is.null(get0("smoothness_orders", envir = caller))) {
    smoothness_orders <- get("smoothness_orders", envir = caller)
  }
  if (missing(num_knots) && !is.null(get0("num_knots", envir = caller))) {
    num_knots <- get("num_knots", envir = caller)
  }

  # Evaluate the defaults eagerly when they are available. They are only
  # required by `h()` terms that omit `k`/`s`, so a formula whose terms are
  # fully specified must not fail just because no default exists.
  if (!missing(smoothness_orders)) {
    force(smoothness_orders)
  }
  if (!missing(num_knots)) {
    force(num_knots)
  }

  # The last element of a formula is its right-hand side (one-sided formulas
  # have length 2, two-sided ones length 3).
  parsed <- stats::as.formula(formula)
  rhs <- parsed[[length(parsed)]]

  # Evaluate the right-hand side in this function's frame so that the terms
  # (calls to `h`) see `X`, `num_knots` and `smoothness_orders` through their
  # `parent.frame()`. Evaluating the language object directly avoids a
  # deparse/parse round trip through text.
  eval(rhs, envir = environment())
}

#' HAL Formula addition: Adding formula term object together into a single
#' formula object term.
#'
#' @param x A `formula_hal9001` object as outputted by \code{h}.
#' @param y A `formula_hal9001` object as outputted by \code{h}.
#'
#' @export
`+.formula_hal9001` <- function(x, y) {
  # Unary plus (`+x`) dispatches here with `y` missing; it is a no-op.
  if (missing(y)) {
    return(x)
  }
  if (!inherits(x, "formula_hal9001") || !inherits(y, "formula_hal9001")) {
    stop("Both terms must be `formula_hal9001` objects as outputted by `h`.")
  }
  # The covariate names must agree element by element, in the same order;
  # comparing them only as sets would let reordered or duplicated column
  # names through, contradicting the error message below.
  if (!identical(x$covariates, y$covariates)) {
    stop("Order of `colnames(X)` must be the same for both terms in formula.")
  }

  # Concatenate the per-basis components of both terms, then drop every basis
  # function that already appeared earlier (first occurrence wins), keeping
  # the penalty factors and coefficient limits aligned with the basis list.
  basis_all <- c(x$basis_list, y$basis_list)
  keep <- !duplicated(basis_all)

  out <- list(
    formula_term = paste0(x$formula_term, " + ", y$formula_term),
    basis_list = basis_all[keep],
    penalty.factors = c(x$penalty.factors, y$penalty.factors)[keep],
    lower.limits = c(x$lower.limits, y$lower.limits)[keep],
    upper.limits = c(x$upper.limits, y$upper.limits)[keep],
    covariates = x$covariates
  )
  class(out) <- "formula_hal9001"
  out
}

#' HAL Formula term: Generate a single term of the HAL basis
#'
#' @param ... Variables for which to generate multivariate interaction basis
#'  function where the variables can be found in a matrix `X` in a parent
#'  environment/frame. Note, just like standard \code{formula} objects, the
#'  variables should not be characters (e.g. do h(W1,W2) not h("W1", "W2"))
#'  h(W1,W2,W3) will generate three-way HAL basis functions between W1, W2, and
#'  W3. It will `not` generate the lower dimensional basis functions.
#' @param k The number of knots for each univariate basis function used to
#'  generate the tensor product basis functions. If a single value then this
#'  value is used for the univariate basis functions for each variable.
#'  Otherwise, this should be a variable named list that specifies for each
#'  variable how many knots points should be used.
#'  `h(W1,W2,W3, k = list(W1 = 3, W2 = 2, W3=1))` is equivalent to first
#'  binning the variables `W1`, `W2` and `W3` into `3`, `2` and `1` unique
#'  values and then calling `h(W1,W2,W3)`. This coarsening of the data ensures
#'  that fewer basis functions are generated, which can lead to substantial
#'  computational speed-ups. If not provided and the variable \code{num_knots}
#'  is in the parent environment, then \code{k} will be set to
#'  \code{num_knots}.
#' @param s The \code{smoothness_orders} for the basis functions. The possible
#'  values are `0` for piece-wise constant zero-order splines or `1` for
#'  piece-wise linear first-order splines. If not provided and the variable
#'  \code{smoothness_orders} is in the parent environment, then \code{s} will
#'  be set to \code{smoothness_orders}.
#' @param pf A `penalty.factor` value the generated basis functions that is
#'  used by \code{glmnet} in the LASSO penalization procedure. `pf = 1`
#'  (default) is the standard penalization factor used by \code{glmnet} and
#'  `pf = 0` means the generated basis functions are unpenalized.
#' @param monotone Whether the basis functions should enforce monotonicity of
#'  the interaction term. If `\code{s} = 0`, this is monotonicity of the
#'  function, and, if `\code{s} = 1`, this is monotonicity of its derivative
#'  (e.g., enforcing a convex fit). Set `"none"` for no constraints, `"i"` for
#'  a monotone increasing constraint, and `"d"` for a monotone decreasing
#'  constraint. Using `"i"` constrains the basis functions to have positive
#'  coefficients in the fit, and `"d"` constrains the basis functions to have
#'  negative coefficients.
#' @param . Just like with \code{formula}, `.` as in `h(.)` or `h(.,.)` is
#'  treated as a wildcard variable that generates terms using all variables in
#'  the data. The argument \code{.} should be a character vector of variable
#'  names that `.` iterates over. Specifically,
#'  `h(., k=1, . = c("W1", "W2", "W3"))` is equivalent to
#'  `h(W1, k=1) + h(W2, k=1) + h(W3, k=1)`, and
#'  `h(., .,  k=1, . = c("W1", "W2", "W3"))` is equivalent to
#'  `h(W1,W2, k=1) + h(W2,W3, k=1) + h(W1, W3, k=1)`
#' @param dot_args_as_string Whether the arguments `...` are characters or
#'  character vectors and should thus be evaluated directly. When `TRUE`, the
#'  expression h("W1", "W2") can be used.
#' @param X An optional design matrix where the variables given in \code{...}
#'  can be found. Otherwise, `X` is taken from the parent environment.
#'
#' @importFrom stringr str_match str_split str_detect str_remove str_replace str_extract str_match_all
#' @importFrom assertthat assert_that
#'
#' @export
h <- function(..., k = NULL, s = NULL, pf = 1,
              monotone = c("none", "i", "d"), . = NULL,
              dot_args_as_string = FALSE, X = NULL) {
  monotone <- match.arg(monotone)
  # Frame of the caller: `X`, `num_knots` and `smoothness_orders` are looked
  # up there (and in its enclosing environments) when not given explicitly.
  caller <- parent.frame()
  if (is.null(X)) {
    # Get design matrix from parent environment
    X <- as.matrix(get("X", envir = caller))
  }
  if (is.null(.)) {
    . <- colnames(X)
  }

  if (!dot_args_as_string) {
    # Recover the variable names from the unevaluated arguments, since the
    # variables need not exist as objects (e.g., inside a formula).
    # deparse() may split a long call over several strings, so collapse them
    # before matching; otherwise trailing variables are silently dropped.
    arg_text <- paste0(deparse(substitute(c(...))), collapse = "")
    arg_text <- stringr::str_replace_all(arg_text, " ", "")
    captured <- stringr::str_match_all(
      arg_text, "[,(]([^()]+)[,)]"
    )[[1]][, -1]
    if (length(captured) == 0) {
      stop("h() requires at least one variable.")
    }
    var_names <- stringr::str_split(captured, ",")[[1]]
  } else {
    var_names <- unlist(list(...))
  }
  formula_term <- paste0("h(", paste0(var_names, collapse = ", "), ")")

  if (is.null(k)) {
    # `num_knots` is indexed by the number of variables in the term; it is
    # recycled first so that a shorter vector still yields a value.
    k <- get("num_knots", envir = caller)
    k <- suppressWarnings(k + rep(0, length(var_names))) # recycle
    k <- k[length(var_names)]
  }
  if (is.null(s)) {
    s <- get("smoothness_orders", envir = caller)[1]
  }

  if ("." %in% var_names) {
    # Wildcard term: expand `.` into every concrete variable combination,
    # build each concrete term, and concatenate their components.
    var_names_filled <- fill_dots(var_names, . = .)
    X_mat <- as.matrix(X)

    all_items <- lapply(var_names_filled, function(vars) {
      h(vars,
        k = k, s = s, pf = pf, monotone = monotone, . = .,
        dot_args_as_string = TRUE, X = X_mat
      )
    })
    out <- list(
      formula_term = formula_term,
      basis_list = unlist(
        lapply(all_items, function(item) item$basis_list),
        recursive = FALSE
      ),
      penalty.factors = unlist(
        lapply(all_items, function(item) item$penalty.factors)
      ),
      lower.limits = unlist(
        lapply(all_items, function(item) item$lower.limits)
      ),
      upper.limits = unlist(
        lapply(all_items, function(item) item$upper.limits)
      ),
      covariates = colnames(X)
    )
    class(out) <- "formula_hal9001"
    return(out)
  }

  # Get corresponding column indices
  col_index <- match(var_names, colnames(X))
  if (anyNA(col_index)) {
    stop(
      "Variable(s) not found in `colnames(X)`: ",
      paste(var_names[is.na(col_index)], collapse = ", ")
    )
  }

  # Coarsen each selected column of the local copy of `X` onto the quantile
  # grid implied by its number of knots.
  for (i in seq_along(col_index)) {
    var <- var_names[i]
    j <- col_index[i]

    if (length(k) == 1) {
      # A single value applies to every variable.
      k_var <- unlist(k)
    } else {
      # Named list/vector: use the variable's own entry, else the "." entry.
      key <- if (var %in% names(k)) var else "."
      k_var <- unlist(k[key])
      if (length(k_var) != 1 || !is.numeric(k_var) || is.na(k_var)) {
        stop("k must be a variable named list or vector.")
      }
    }

    x <- X[, j]
    bins <- stats::quantile(x, seq(0, 1, length.out = k_var + 1))
    X[, j] <- bins[findInterval(x, bins, all.inside = TRUE)]
  }

  basis_list_item <- make_basis_list(
    X[, col_index, drop = FALSE],
    col_index, rep(s, ncol(X))
  )
  n_basis <- length(basis_list_item)

  # Coefficient bounds: "i" gives lower bound 0, "d" gives upper bound 0,
  # and "none" leaves both sides unbounded.
  bounds <- switch(monotone,
    i = c(0, Inf),
    d = c(-Inf, 0),
    c(-Inf, Inf)
  )
  out <- list(
    formula_term = formula_term, basis_list = basis_list_item,
    penalty.factors = rep(pf, n_basis),
    lower.limits = rep(bounds[1], n_basis),
    upper.limits = rep(bounds[2], n_basis),
    covariates = colnames(X)
  )
  class(out) <- "formula_hal9001"
  out
}


#' Print formula_hal9001 object
#'
#' @param x A formula_hal9001 object.
#' @param ... Other arguments (ignored).
#'
#' @export
print.formula_hal9001 <- function(x, ...) {
  # One-line summary of the term(s) stored in the object. The trailing
  # newline keeps the console prompt from landing on the same line.
  cat(paste0("A hal9001 formula object of the form: ~ ", x$formula_term), "\n",
    sep = ""
  )
  # Print methods return their argument invisibly.
  invisible(x)
}

#'
#' @param var_names A \code{character} vector of variable names representing a single type of interaction
#'  (e.g. var_names = c("W1", "W2", "W3") encodes three way interactions between W1, W2 and W3).
#' var_names may include the wildcard variable "." in which case the argument `.` must be specified
#' so that all interactions matching the form of var_names are generated.
#' @param . Specification of variables for use in the formula.
#'   This function takes a character vector `var_names` of the form c(name1, name2, ".", name3, ".")
#' with any number of name{int} variables and any number of wild card variables ".".
#' It returns a list of character vectors of the form c(name1, name2, wildcard1, name3, wildcard2)
#' where wildcard1 and wildcard2 are iterated over all possible character names given in the argument `.`.
#' @rdname formula_helpers
fill_dots <- function(var_names, .) {
  dot_positions <- which(var_names == ".")
  # Base case: nothing left to expand; return the names in sorted order so
  # that the same set of variables always yields the same vector.
  if (length(dot_positions) == 0) {
    return(sort(var_names))
  }

  # Substitute every candidate variable for the first wildcard and recurse
  # on the remaining ones.
  first_dot <- dot_positions[1]
  expansions <- lapply(., function(var) {
    candidate <- var_names
    candidate[first_dot] <- var
    fill_dots(candidate, .)
  })

  # With more than one wildcard each recursive call returned a (possibly
  # empty) list of character vectors, so flatten exactly one level. With a
  # single wildcard the calls returned the character vectors themselves.
  if (length(dot_positions) > 1) {
    expansions <- unlist(expansions, recursive = FALSE)
  }
  # Guard against flattening to NULL when there is nothing to expand.
  expansions <- as.list(expansions)

  # Remove combinations of variable names that have duplicates.
  # This removes generated interactions that include two of the same variable.
  keep <- vapply(
    expansions,
    function(item) !any(duplicated(item)),
    logical(1)
  )

  unique(expansions[keep])
}

Try the hal9001 package in your browser

Any scripts or data that you put into this service are public.

hal9001 documentation built on Nov. 14, 2023, 5:08 p.m.