# Nothing
#--- PPI ALL OLS ---------------------------------------------------------------
#' PPI "All" OLS
#'
#' @description
#' Helper function for PPI "All" for OLS estimation
#'
#' @details
#' Another look at statistical inference with machine learning-imputed data
#' (Gronsbell et al., 2026) \doi{10.48550/arXiv.2411.19908}
#'
#' The point estimate is the OLS fit of the predictions on the covariates
#' using the stacked (labeled + unlabeled) data, minus the OLS fit of the
#' prediction error \code{f_l - Y_l} on the covariates in the labeled data
#' (the "rectifier"). Standard errors combine the heteroskedasticity-robust
#' sandwich variances of the two fits.
#'
#' @param X_l (matrix): n x p matrix of covariates in the labeled data.
#'
#' @param Y_l (vector): n-vector of labeled outcomes.
#'
#' @param f_l (vector): n-vector (or n x 1 matrix) of predictions in the
#' labeled data.
#'
#' @param X_u (matrix): N x p matrix of covariates in the unlabeled data.
#'
#' @param f_u (vector): N-vector (or N x 1 matrix) of predictions in the
#' unlabeled data.
#'
#' @param w_l (vector, optional): Sample weights for the labeled data set.
#' Currently ignored by this function; the estimator is unweighted.
#'
#' @param w_u (vector, optional): Sample weights for the unlabeled
#' data set. Currently ignored by this function; the estimator is unweighted.
#'
#' @return (list): A list containing the following:
#'
#' \describe{
#' \item{est}{(matrix): p x 1 matrix of PPI OLS regression coefficient
#' estimates.}
#' \item{se}{(vector): vector of standard errors of the coefficients.}
#' \item{rectifier_est}{(matrix): p x 1 matrix of the rectifier OLS
#' regression coefficient estimates.}
#' }
#'
#' @examples
#'
#' dat <- simdat()
#'
#' form <- Y - f ~ X1
#'
#' X_l <- model.matrix(form, data = dat[dat$set_label == "labeled", ])
#'
#' Y_l <- dat[dat$set_label == "labeled", all.vars(form)[1]] |>
#'   matrix(ncol = 1)
#'
#' f_l <- dat[dat$set_label == "labeled", all.vars(form)[2]] |>
#'   matrix(ncol = 1)
#'
#' X_u <- model.matrix(form, data = dat[dat$set_label == "unlabeled", ])
#'
#' f_u <- dat[dat$set_label == "unlabeled", all.vars(form)[2]] |>
#'   matrix(ncol = 1)
#'
#' ppi_a_ols(X_l, Y_l, f_l, X_u, f_u)
#'
#' @import stats
#'
#' @export
ppi_a_ols <- function(
    X_l,
    Y_l,
    f_l,
    X_u,
    f_u,
    w_l = NULL,
    w_u = NULL) {
  #-- Predictions as column matrices: rbind() on two plain vectors would build
  #-- a 2-row matrix (recycling the shorter one) instead of stacking them.
  f_l <- as.matrix(f_l)
  f_u <- as.matrix(f_u)

  #-- Sample Sizes
  n <- NROW(X_l)
  p <- NCOL(X_l)
  N <- NROW(X_u)

  #-- Fail early on mismatched inputs rather than silently recycling
  stopifnot(
    NCOL(X_u) == p,
    NROW(Y_l) == n,
    NROW(f_l) == n,
    NROW(f_u) == N
  )

  f_a <- rbind(f_l, f_u)
  X_a <- rbind(X_l, X_u)
  n_total <- NROW(f_a)

  #- 1. Prediction-Powered Estimator
  #-    Solve the normal equations directly instead of forming the inverse.
  theta_tilde_f <- solve(crossprod(X_a), crossprod(X_a, f_a))
  delta_hat_f <- solve(crossprod(X_l), crossprod(X_l, f_l - Y_l))
  theta_hat_pp <- theta_tilde_f - delta_hat_f

  #- 2. Meat and Bread for Imputed Estimate
  Sigma_tilde <- crossprod(X_a) / n_total
  resid_a <- c(f_a - X_a %*% theta_tilde_f)
  # Row i of (X_a * resid_a) is resid_a[i] * x_i, so its cross-product is
  # sum_i resid_a[i]^2 * x_i x_i' (the average of the per-row outer products).
  M_tilde <- crossprod(X_a * resid_a) / n_total
  iSigma_tilde <- solve(Sigma_tilde)

  #- 3. Sandwich Variance Estimator for Imputed Estimate
  V_tilde <- iSigma_tilde %*% M_tilde %*% iSigma_tilde

  #- 4. Meat and Bread for Empirical Rectifier
  Sigma <- crossprod(X_l) / n
  resid_l <- c(f_l - Y_l - X_l %*% delta_hat_f)
  M <- crossprod(X_l * resid_l) / n
  iSigma <- solve(Sigma)

  #- 5. Sandwich Variance Estimator for Empirical Rectifier
  V <- iSigma %*% M %*% iSigma

  #- 6. Standard Error Estimates
  se <- sqrt(diag(V) / n + diag(V_tilde) / n_total)

  #- Output
  list(est = theta_hat_pp, se = se, rectifier_est = delta_hat_f)
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.