R/feasible.r

Defines functions feasible

Documented in feasible

# Determine feasible intervention values for continuous treatment nodes.
#
# For every treatment node (time point) the treatment range is discretised
# into a grid ('query_abar'). The density helper named by 'd.method' is
# evaluated on that grid, and each row of its output is normalised to
# integrate to one over the grid bins. Within each row, the lowest-density
# bins holding at most probability mass (1 - alpha) are flagged as low
# density. Every flagged grid point is then replaced by the nearest (by grid
# index) grid point that is not flagged. The replacement values at the grid
# points equal to 'abar' are reported as the feasible interventions.
#
# Arguments:
#   X               data frame containing all nodes.
#   Anodes, Ynodes  character vectors of treatment / outcome column names.
#   Lnodes, Cnodes  optional character vectors of covariate / censoring names.
#   abar            numeric vector (one strategy per value, held constant over
#                   time) or matrix (one row per strategy, one column per
#                   element of 'Anodes').
#   alpha           number in (0, 1); see the description of the threshold.
#   grid.size       spacing of the evaluation grid. If NULL, the grid is the
#                   sorted unique values of 'abar' (which must then be a vector).
#   tol             grid points closer than 'tol' to a value of 'abar' are
#                   dropped; the 'abar' values themselves are always kept.
#   left.boundary, right.boundary
#                   optional grid limits; by default the range of the observed
#                   treatment values together with 'abar'.
#   screen          if TRUE, treatment formulas come from
#                   model.formulas.update().
#   survival        passed to make.model.formulas().
#   d.method        name of the density-estimation function to call.
#   verbose         passed to the helpers; also enables the warning about
#                   ignored '...' arguments.
#   ...             forwarded to the density function: all named arguments if
#                   it accepts '...', otherwise only those matching its formals.
#
# Returns an object of class "feasible": a list with
#   feasible    one matrix per row of 'abar' (one column per time point) of
#               feasible intervention values;
#   low_matrix  one logical matrix per time point flagging grid points whose
#               normalised density is below the row's threshold;
# plus attribute "summary", a data frame giving, per time point and strategy,
# the mean feasible value and the share of rows where 'abar' is low density.
feasible <- function(X, Anodes = NULL, Ynodes = NULL, Lnodes = NULL, Cnodes = NULL,
                     abar = NULL,
                     alpha = 0.95, grid.size = 0.5, tol = 1e-2,
                     left.boundary = NULL, right.boundary = NULL,
                     screen = FALSE, survival = FALSE,
                     d.method = c("hazardbinning", "binning", "parametric", "hal_density"),
                     verbose = TRUE, ...) {

  # Input checks ----
  # Scalar predicates: vectors or NA would otherwise reach `||` / `if` below
  # and fail there with an unhelpful message.
  is_flag <- function(x) is.logical(x) && length(x) == 1L && !is.na(x)
  is_number <- function(x) is.numeric(x) && length(x) == 1L && !is.na(x)

  if (!is.data.frame(X)) stop("'X' must be a data frame")
  if (is.null(Anodes) || !is.character(Anodes)) stop("'Anodes' must be a character vector")
  if (is.null(Ynodes) || !is.character(Ynodes)) stop("'Ynodes' must be a character vector")
  if (!is.null(Lnodes) && !is.character(Lnodes)) stop("'Lnodes' must be a character vector or NULL")
  if (!is.null(Cnodes) && !is.character(Cnodes)) stop("'Cnodes' must be a character vector or NULL")
  if (is.null(abar)) stop("Please provide values for 'abar'")
  if (!is.numeric(abar)) stop("'abar' must be numeric")
  if (!is_number(alpha) || alpha <= 0 || alpha >= 1) stop("'alpha' must be a number in (0, 1)")
  if (!is.null(grid.size) && (!is_number(grid.size) || grid.size <= 0)) {
    stop("'grid.size' must be a positive number")
  }
  if (!is_number(tol) || tol < 0) stop("'tol' must be a non-negative number")
  if (!is.null(left.boundary) && !is_number(left.boundary)) {
    stop("'left.boundary' must be a single number or NULL")
  }
  if (!is.null(right.boundary) && !is_number(right.boundary)) {
    stop("'right.boundary' must be a single number or NULL")
  }
  if (!is_flag(screen)) stop("'screen' must be logical (TRUE or FALSE)")
  if (!is_flag(survival)) stop("'survival' must be logical (TRUE or FALSE)")
  if (!is_flag(verbose)) stop("'verbose' must be logical (TRUE or FALSE)")

  all_nodes <- c(Ynodes, Lnodes, Anodes, Cnodes)
  missing_vars <- setdiff(all_nodes, colnames(X))
  if (length(missing_vars) > 0L) {
    stop("The following nodes are missing from X: ", paste(missing_vars, collapse = ", "))
  }
  if (min(which(colnames(X) %in% Ynodes)) < min(which(colnames(X) %in% Anodes))) {
    stop("Ynodes occur before Anodes. Check that you didn't specify pre-intervention variables as outcomes.")
  }
  # NOTE(review): only Anodes are checked here, although the message also
  # mentions Cnodes.
  if (max(which(colnames(X) %in% Anodes)) == ncol(X)) {
    stop("Anodes or Cnodes should not be in the last column of X.")
  }

  d.method <- match.arg(d.method)
  times <- length(Anodes)

  # Evaluation grid ----
  if (is.null(grid.size)) {
    if (is.matrix(abar)) {
      stop("If abar is a matrix, please specify grid.size.")
    }
    query_abar <- sort(unique(abar))
  } else {
    X_Avals <- unlist(X[, Anodes], use.names = FALSE)
    grid_min <- if (is.null(left.boundary)) min(c(X_Avals, abar), na.rm = TRUE) else left.boundary
    grid_max <- if (is.null(right.boundary)) max(c(X_Avals, abar), na.rm = TRUE) else right.boundary
    # min()/max() with na.rm = TRUE return -Inf/Inf (not NA) when every value
    # is missing, so test finiteness rather than NA.
    if (!is.finite(grid_min) || !is.finite(grid_max)) {
      stop("grid_min or grid_max is not finite. Check for missing values in Anodes or abar.")
    }
    if (grid_min >= grid_max) {
      stop("left.boundary must be less than right.boundary.")
    }
    query_abar <- seq(from = grid_min, to = grid_max, by = grid.size)
    # Drop grid points that (nearly) duplicate a target value, then make sure
    # every target value is itself a grid point.
    near_abar <- vapply(query_abar, function(q) any(abs(q - abar) < tol), logical(1L))
    query_abar <- sort(unique(c(query_abar[!near_abar], abar)))
  }
  if (length(query_abar) < 2L) {
    stop("The evaluation grid must contain at least two distinct values; ",
         "supply more 'abar' values or a smaller 'grid.size'.")
  }

  if (!is.matrix(abar)) {
    # A vector gives one strategy per value, held constant over all time points.
    abar <- matrix(abar, ncol = times, nrow = length(abar), byrow = FALSE)
  }
  if (ncol(abar) != times) {
    stop("'abar' must have one column per element of 'Anodes'")
  }

  # Bins centred on the grid points: interior cuts are midpoints, and the
  # outer bins are widened by the mean grid spacing.
  mids <- (head(query_abar, -1) + tail(query_abar, -1)) / 2
  step <- mean(diff(query_abar))
  cuts <- c(mids[1] - step, mids, mids[length(mids)] + step)
  bin_length <- diff(cuts)

  # Density helper and forwarded arguments ----
  gdf <- get(d.method, mode = "function")

  dots <- list(...)
  if (d.method == "hal_density" && "SL.library" %in% names(dots)) {
    stop("Please omit 'SL.library' when using d.method = 'hal_density' (it is not supported for haldensify).")
  }
  dot_names <- names(dots)
  dots_named <- if (is.null(dot_names)) {
    list()
  } else {
    dots[dot_names != "" & !is.na(dot_names)]
  }

  gdf_formals <- names(formals(gdf))
  if ("..." %in% gdf_formals) {
    dots_gdf <- dots_named
    ignored <- character(0L)
  } else {
    keep <- intersect(names(dots_named), gdf_formals)
    dots_gdf <- dots_named[keep]
    ignored <- setdiff(names(dots_named), keep)
  }
  if (length(ignored) > 0L && isTRUE(verbose)) {
    warning("Ignoring arguments not used by '", d.method, "': ",
            paste(ignored, collapse = ", "))
  }

  # Treatment model formulas ----
  Aform <- make.model.formulas(X = X, Anodes = Anodes, Cnodes = Cnodes, Ynodes = Ynodes, survival = survival)
  Aform.n <- if (screen) {
    model.formulas.update(Aform$model.names, X, pw = verbose)$Anames
  } else {
    Aform$model.names$Anames
  }
  Aform.d <- paste0(Anodes, "~1")

  g.preds <- do.call(
    gdf,
    c(
      dots_gdf,
      list(
        form.n  = Aform.n,
        form.d  = Aform.d,
        X       = X,
        Anodes  = Anodes,
        abar    = query_abar,
        verbose = verbose
      )
    )
  )
  # Assumes element 1 is a per-time-point list of matrices with one column
  # per grid point (rows presumably observations) -- TODO confirm with helpers.
  g.preds.n <- g.preds[[1]]

  # Normalise each row so that sum(density * bin width) == 1.
  den_reg <- lapply(seq_len(times), function(tt) {
    dens.t <- g.preds.n[[tt]]
    dens.norm <- t(apply(dens.t, 1, function(row) row / sum(row * bin_length)))
    colnames(dens.norm) <- NULL
    dens.norm
  })

  # Density threshold f_alpha per row: walk the bins from lowest to highest
  # density, accumulating probability MASS (density x bin width), and take the
  # density of the first bin at which the mass exceeds 1 - alpha. Summing raw
  # densities would make the threshold depend on the grid spacing.
  falphas <- lapply(den_reg, function(dens.t) {
    apply(dens.t, 1, function(row) {
      ord <- order(row, na.last = NA)
      mass <- cumsum(row[ord] * bin_length[ord])
      row[ord[which(mass > 1 - alpha)[1L]]]
    })
  })

  # Low-density indicator per row and grid point.
  low_matrix <- lapply(seq_len(times), function(tt) {
    sweep(den_reg[[tt]], 1, falphas[[tt]], `<`)
  })

  # Feasible values: map each low-density grid point to the nearest grid
  # point (by index) that is not low density, then keep the 'abar' columns.
  feasible_matrix <- lapply(seq_len(times), function(tt) {
    den.t <- den_reg[[tt]]
    n_bins <- ncol(den.t)

    replace_row <- function(row, threshold) {
      if (all(is.na(row)) || is.na(threshold)) {
        return(rep(NA_real_, n_bins))
      }
      # Missing densities are neither below nor above the threshold, so they
      # keep their own grid value instead of producing NA inside `if`.
      below <- !is.na(row) & row < threshold
      above <- !is.na(row) & row >= threshold
      if (!any(below) || !any(above)) {
        return(query_abar)
      }
      idx_above <- which(above)
      idx_below <- which(below)
      nearest <- vapply(idx_below, function(i) {
        idx_above[which.min(abs(i - idx_above))]
      }, integer(1L))
      target <- seq_len(n_bins)
      target[idx_below] <- nearest
      query_abar[target]
    }

    # rbind() keeps one row per observation regardless of n_bins, unlike
    # t(mapply(...)), which collapses when each result has length one.
    result_matrix <- do.call(rbind, Map(replace_row, split(den.t, row(den.t)), falphas[[tt]]))
    result_matrix[, findInterval(abar[, tt], cuts), drop = FALSE]
  })

  # Summary of low-density share and mean feasible value per strategy ----
  summary_dat <- do.call(rbind, lapply(seq_len(times), function(tt) {
    abar_cols <- findInterval(abar[, tt], cuts)
    data.frame(
      time = tt,
      Strategy = seq_len(nrow(abar)),
      Abar = abar[, tt],
      Feasible = colMeans(feasible_matrix[[tt]], na.rm = TRUE),
      Low = colMeans(low_matrix[[tt]][, abar_cols, drop = FALSE], na.rm = TRUE)
    )
  }))
  rownames(summary_dat) <- NULL

  # Feasible values per strategy: one column per time point.
  strategy_list <- lapply(seq_len(nrow(abar)), function(i) {
    strat <- do.call(cbind, lapply(feasible_matrix, function(mat) mat[, i]))
    colnames(strat) <- abar[i, ]
    strat
  })

  obj <- list(
    feasible   = strategy_list,
    low_matrix = low_matrix
  )
  attr(obj, "summary") <- summary_dat
  class(obj) <- "feasible"
  obj
}

Try the CICI package in your browser

Any scripts or data that you put into this service are public.

CICI documentation built on April 7, 2026, 5:08 p.m.