R/predict.R

Defines functions predict.cre

Documented in predict.cre

#' @title
#' Predict individual treatment effects via a causal rule ensemble
#'
#' @description
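#' Predicts individual treatment effects (ITEs) for new observations using a
#' fitted causal rule ensemble. Each prediction is the intercept (the first
#' estimate in `object$CATE`) plus the estimates of the selected rules that
#' the observation satisfies. If no rules were selected, every observation
#' receives the intercept.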
#'
#' @param object A `cre` object from running the CRE function.
#' @param X A covariate matrix (or data.frame) with one row per observation.
#' @param ... Additional arguments (currently unused).
#'
#' @return
#' A numeric array of estimated individual treatment effects (ITEs), one per
#' row of `X`.
#'
#' @export
#'
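predict.cre <- function(object, X, ...) {
  # Intercept of the fitted model (first row of the CATE table).
  intercept <- object$CATE$Estimate[1]
  if (is.null(object$rules)) {
    # No rules were selected: every observation receives the intercept.
    ite_pred <- rep(intercept, times = nrow(X))
  } else {
    # Rule matrix for X: one column per rule in object$rules.
    rules_matrix <- generate_rules_matrix(X, object$rules)
    rownames(object$CATE) <- object$CATE$Rule
    # Rule-level estimates (all rows except the intercept) as a column matrix.
    rule_effects <- as.matrix(object$CATE[2:nrow(object$CATE), ]["Estimate"])
    # Keep the sum on one line: in R, a line that starts with `+` is parsed
    # as a new statement, which would silently drop the intercept.
    ite_pred <- rules_matrix %*% rule_effects + intercept
  }
  return(ite_pred)
}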
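
As a rough illustration of the prediction rule, the sketch below reimplements it on toy inputs. It is hypothetical: `predict_ite_sketch`, the rule labels, and the estimates are made up, and it assumes (as `predict.cre` suggests but does not guarantee) that the first row of `object$CATE` holds the intercept, that the remaining rows follow the column order of the rule matrix, and that the rule matrix contains 0/1 indicators.

```r
# Hypothetical 0/1 rule-indicator matrix for 4 observations and 2 rules:
# entry [i, j] is 1 if observation i satisfies rule j, else 0.
rules_matrix <- matrix(
  c(1, 0,
    1, 1,
    0, 0,
    0, 1),
  nrow = 4, byrow = TRUE,
  dimnames = list(NULL, c("x1>0.5", "x2<=0")))

# Assumed CATE layout: intercept first, then one estimate per rule,
# in the same order as the columns of rules_matrix.
cate <- data.frame(
  Rule = c("(Intercept)", "x1>0.5", "x2<=0"),
  Estimate = c(1.0, 0.5, -2.0))

# Hypothetical helper mirroring the logic of predict.cre.
predict_ite_sketch <- function(rules_matrix, cate, n_obs = nrow(rules_matrix)) {
  intercept <- cate$Estimate[1]
  if (is.null(rules_matrix)) {
    # No rules: every observation receives the intercept.
    return(rep(intercept, times = n_obs))
  }
  rule_effects <- cate$Estimate[-1]
  # Row i: intercept + sum of the effects of the rules row i satisfies.
  # %*% aligns rules by position, not by name.
  as.vector(rules_matrix %*% rule_effects + intercept)
}

predict_ite_sketch(rules_matrix, cate)
# c(1.5, -0.5, 1.0, -1.0)

predict_ite_sketch(NULL, cate[1, ], n_obs = 3)
# c(1, 1, 1)
```

Because `%*%` matches by position rather than by name, a mismatch between the rule order in `object$CATE` and the columns of the rule matrix would not raise an error; it would silently attribute effects to the wrong rules.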

CRE documentation built on Oct. 19, 2024, 5:07 p.m.