R/policy_eval.R

#' Policy Evaluation
#'
#' \code{policy_eval()} is used to estimate the value of a given fixed policy
#' or a data adaptive policy (e.g. a policy learned from the data).
#' @param policy_data Policy data object created by [policy_data()].
#' @param policy Policy object created by [policy_def()].
#' @param policy_learn Policy learner object created by [policy_learn()].
#' @param g_models List of action probability models/g-models for each stage
#' created by [g_empir()], [g_glm()], [g_rf()], [g_sl()] or similar functions.
#' Only used for evaluation if \code{g_functions} is \code{NULL}.
#' If a single model is provided and \code{g_full_history} is \code{FALSE},
#' a single g-model is fitted across all stages. If \code{g_full_history} is
#' \code{TRUE}, the model is reused at every stage.
#' @param q_models Outcome regression models/Q-models created by
#' [q_glm()], [q_rf()], [q_sl()] or similar functions.
#' Only used for evaluation if \code{q_functions} is \code{NULL}.
#' If a single model is provided, the model is reused at every stage.
#' @param g_functions Fitted g-model objects, see [nuisance_functions].
#' Preferably, use \code{g_models}.
#' @param q_functions Fitted Q-model objects, see [nuisance_functions].
#' Only valid if the Q-functions are fitted using the same policy.
#' Preferably, use \code{q_models}.
#' @param g_full_history If TRUE, the full history is used to fit each g-model.
#' If FALSE, the state/Markov type history is used to fit each g-model.
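#' @param q_full_history If TRUE, the full history is used to fit each Q-model.
#' If FALSE, the state/Markov type history is used to fit each Q-model.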
#' @param save_g_functions If TRUE, the fitted g-functions are saved.
#' @param save_q_functions If TRUE, the fitted Q-functions are saved.
#' @param M Number of folds for cross-fitting.
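#' @param type Character string. Type of evaluation: \code{"dr"} (doubly
#' robust, alias \code{"aipw"}), \code{"ipw"} (inverse probability weighting),
#' or \code{"or"} (outcome regression, alias \code{"q"}).
#' @param future_args Arguments passed to [future.apply::future_lapply()].
#' @param name Character string. Name of the evaluation. If \code{NULL}, the
#' \code{"name"} attribute of \code{policy} or \code{policy_learn} is used.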
#' @param object,x,y Objects of class "policy_eval".
#' @param labels Name(s) of the estimate(s).
#' @param paired \code{TRUE} indicates that the estimates are based on
#' the same data sample.
#' @param ... Additional arguments.
#' @returns \code{policy_eval()} returns an object of class "policy_eval".
#' The object is a list containing the following elements:
#' \item{\code{value_estimate}}{Numeric. The estimated value of the policy.}
#' \item{\code{type}}{Character string. The type of evaluation ("dr", "ipw",
#' "or").}
#' \item{\code{IC}}{Numeric vector. Estimated influence curve associated with
#' the value estimate.}
#' \item{\code{value_estimate_ipw}}{(only if \code{type = "dr"}) Numeric.
#' The estimated value of the policy based on inverse probability weighting.}
#' \item{\code{value_estimate_or}}{(only if \code{type = "dr"}) Numeric.
#' The estimated value of the policy based on outcome regression.}
#' \item{\code{id}}{Character vector. The IDs of the observations.}
#' \item{\code{policy_actions}}{[data.table] with keys id and stage. Actions
#' associated with the policy for every observation and stage.}
#' \item{\code{policy_object}}{(only if \code{policy = NULL} and \code{M = 1})
#' The policy object returned by \code{policy_learn}, see [policy_learn].}
#' \item{\code{g_functions}}{(only if \code{M = 1}) The
#' fitted g-functions. Object of class "nuisance_functions".}
#' \item{\code{g_values}}{The fitted g-function values.}
#' \item{\code{q_functions}}{(only if \code{M = 1}) The
#' fitted Q-functions. Object of class "nuisance_functions".}
#' \item{\code{q_values}}{The fitted Q-function values.}
#' \item{\code{cross_fits}}{(only if \code{M > 1}) List containing the
#' "policy_eval" object for every (validation) fold.}
#' \item{\code{folds}}{(only if \code{M > 1}) The (validation) folds used
#' for cross-fitting.}
#' @section S3 generics:
#' The following S3 generic functions are available for an object of
#' class \code{policy_eval}:
#' \describe{
#' \item{[get_g_functions()]}{ Extract the fitted g-functions.}
#' \item{[get_q_functions()]}{ Extract the fitted Q-functions.}
#' \item{[get_policy()]}{ Extract the fitted policy object.}
#' \item{[get_policy_functions()]}{ Extract the fitted policy function for
#'                                 a given stage.}
#' \item{[get_policy_actions()]}{ Extract the (fitted) policy actions.}
#' \item{[plot.policy_eval()]}{Plot diagnostics.}
#' }
#' @seealso [lava::IC], [lava::estimate.default].
#' @details
#' Each observation has the sequential form
#' \deqn{O = \{B, U_1, X_1, A_1, ..., U_K, X_K, A_K, U_{K+1}\},}
#' for a possibly stochastic number of stages K.
#' \itemize{
#'  \item \eqn{B} is a vector of baseline covariates.
#'  \item \eqn{U_k} is the reward at stage k (not influenced by the action \eqn{A_k}).
#'  \item \eqn{X_k} is a vector of state covariates summarizing the state at stage k.
#'  \item \eqn{A_k} is the categorical action within the action set \eqn{\mathcal{A}} at stage k.
#' }
#' The utility is given by the sum of the rewards, i.e.,
#' \eqn{U = \sum_{k = 1}^{K+1} U_k}.
#'
#' A policy is a set of functions
#' \deqn{d = \{d_1, ..., d_K\},}
#' where \eqn{d_k} for \eqn{k\in \{1, ..., K\}} maps \eqn{\{B, X_1, A_1, ..., A_{k-1}, X_k\}} into the
#' action set.
#'
#' Recursively define the Q-models (\code{q_models}):
#' \deqn{Q^d_K(h_K, a_K) = E[U|H_K = h_K, A_K = a_K]}
#' \deqn{Q^d_k(h_k, a_k) = E[Q^d_{k+1}(H_{k+1}, d_{k+1}(B, X_1, A_1, ..., X_{k+1}))|H_k = h_k, A_k = a_k].}
#' If \code{q_full_history = TRUE},
#' \eqn{H_k = \{B, X_1, A_1, ..., A_{k-1}, X_k\}}, and if
#' \code{q_full_history = FALSE}, \eqn{H_k = \{B, X_k\}}.
#'
#' The g-models (\code{g_models}) are defined as
#' \deqn{g_k(h_k, a_k) = P(A_k = a_k|H_k = h_k).}
#' If \code{g_full_history = TRUE},
#' \eqn{H_k = \{B, X_1, A_1, ..., A_{k-1}, X_k\}}, and if
#' \code{g_full_history = FALSE}, \eqn{H_k = \{B, X_k\}}.
#' Furthermore, if \code{g_full_history = FALSE} and \code{g_models} is a
#' single model, it is assumed that \eqn{g_1(h_1, a_1) = ... = g_K(h_K, a_K)}.
#'
#' If \code{type = "or"}, \code{policy_eval()} returns the empirical estimate of
#' the value (\code{value_estimate}):
#' \deqn{E[Q^d_1(H_1, d_1(...))]}
#' for an appropriate input \eqn{...} to the policy.
#'
#' If \code{type = "ipw"}, \code{policy_eval()} returns the empirical estimates of
#' the value (\code{value_estimate}) and influence curve (\code{IC}):
#' \deqn{E[(\prod_{k=1}^K I\{A_k = d_k(...)\} g_k(H_k, A_k)^{-1}) U],}
#' \deqn{(\prod_{k=1}^K I\{A_k = d_k(...)\} g_k(H_k, A_k)^{-1}) U - E[(\prod_{k=1}^K I\{A_k = d_k(...)\} g_k(H_k, A_k)^{-1}) U].}
#'
#' If \code{type = "dr"}, \code{policy_eval()} returns the empirical estimates of
#' the value (\code{value_estimate}) and influence curve (\code{IC}):
#' \deqn{E[Z^d_1],}
#' \deqn{Z^d_1 - E[Z^d_1],}
#' where
#' \deqn{
#' Z^d_1 = Q^d_1(H_1 , d_1(...)) + \sum_{r = 1}^K \prod_{j = 1}^{r}
#' \frac{I\{A_j = d_j(...)\}}{g_{j}(H_j, A_j)}
#' \{Q_{r+1}^d(H_{r+1} , d_{r+1}(...)) - Q_{r}^d(H_r , d_r(...))\}.
#' }
#' with the convention that \eqn{Q^d_{K+1}(H_{K+1}, d_{K+1}(...))} is the
#' utility \eqn{U}.
#' @references
#' van der Laan, Mark J., and Alexander R. Luedtke. "Targeted learning of the
#' mean outcome under an optimal dynamic treatment rule." Journal of Causal
#' Inference 3.1 (2015): 61-95. \doi{10.1515/jci-2013-0022}\cr
#' \cr
#' Tsiatis, Anastasios A., et al. Dynamic treatment regimes: Statistical methods
#' for precision medicine. Chapman and Hall/CRC, 2019. \doi{10.1201/9780429192692}.
#' @export
#' @examples
#' library("polle")
#' ### Single stage:
#' d1 <- sim_single_stage(5e2, seed=1)
#' pd1 <- policy_data(d1, action="A", covariates=list("Z", "B", "L"), utility="U")
#' pd1
#'
#' # defining a static policy (A=1):
#' pl1 <- policy_def(1)
#'
#' # evaluating the policy:
#' pe1 <- policy_eval(policy_data = pd1,
#'                    policy = pl1,
#'                    g_models = g_glm(),
#'                    q_models = q_glm(),
#'                    name = "A=1 (glm)")
#'
#' # summarizing the estimated value of the policy:
#' # (equivalent to summary(pe1)):
#' pe1
#' coef(pe1) # value coefficient
#' sqrt(vcov(pe1)) # value standard error
#'
#' # getting the g-function and Q-function values:
#' head(predict(get_g_functions(pe1), pd1))
#' head(predict(get_q_functions(pe1), pd1))
#'
#' # getting the fitted influence curve (IC) for the value:
#' head(IC(pe1))
#'
#' # evaluating the policy using random forest nuisance models:
#' set.seed(1)
#' pe1_rf <- policy_eval(policy_data = pd1,
#'                       policy = pl1,
#'                       g_models = g_rf(),
#'                       q_models = q_rf(),
#'                       name = "A=1 (rf)")
#'
#' # merging the two estimates (equivalent to pe1 + pe1_rf):
#' (est1 <- merge(pe1, pe1_rf))
#' coef(est1)
#' head(IC(est1))
#'
#' ### Two stages:
#' d2 <- sim_two_stage(5e2, seed=1)
#' pd2 <- policy_data(d2,
#'                    action = c("A_1", "A_2"),
#'                    covariates = list(L = c("L_1", "L_2"),
#'                                      C = c("C_1", "C_2")),
#'                    utility = c("U_1", "U_2", "U_3"))
#' pd2
#'
#' # defining a policy learner based on cross-fitted doubly robust Q-learning:
#' pl2 <- policy_learn(type = "drql",
#'                     control = control_drql(qv_models = list(q_glm(~C_1),
#'                                                             q_glm(~C_1+C_2))),
#'                     full_history = TRUE,
#'                     L = 2) # number of folds for cross-fitting
#'
#' # evaluating the policy learner using 2-fold cross fitting:
#' pe2 <- policy_eval(type = "dr",
#'                    policy_data = pd2,
#'                    policy_learn = pl2,
#'                    q_models = q_glm(),
#'                    g_models = g_glm(),
#'                    M = 2, # number of folds for cross-fitting
#'                    name = "drql")
#' # summarizing the estimated value of the policy:
#' pe2
#'
#' # getting the cross-fitted policy actions:
#' head(get_policy_actions(pe2))
policy_eval <- function(policy_data,
                        policy = NULL, policy_learn = NULL,
                        g_functions = NULL, g_models = g_glm(), g_full_history = FALSE, save_g_functions = TRUE,
                        q_functions = NULL, q_models = q_glm(), q_full_history = FALSE, save_q_functions = TRUE,
                        type = "dr",
                        M = 1, future_args = list(future.seed=TRUE),
                        name = NULL
                        ) {
  args <- as.list(environment())
  args[["policy_data"]] <- NULL
  args[["M"]] <- NULL
  args[["future_args"]] <- NULL
  args[["name"]] <- NULL

  # input checks:
  if (!inherits(policy_data, what = "policy_data"))
    stop("policy_data must be of inherited class 'policy_data'.")
  if (!is.null(policy)){
    if (!inherits(policy, what = "policy"))
      stop("policy must be of inherited class 'policy'.")
  }
  if ((is.null(policy) & is.null(policy_learn)) |
      (!is.null(policy_learn) & !is.null(policy)))
    stop("Provide either policy or policy_learn.")
  if (is.null(policy) & !is.null(policy_learn)){
    if (!inherits(policy_learn, what = "policy_learn"))
      stop("policy_learn must be of inherited class 'policy_learn'.")
  }
  if (!is.null(g_functions)){
    if(!(inherits(g_functions, "g_functions")))
      stop("g_functions must be of class 'g_functions'.")
  }
  if (!(is.logical(g_full_history) & (length(g_full_history) == 1)))
    stop("g_full_history must be TRUE or FALSE")
  if (!is.null(q_functions)){
    if(!(inherits(q_functions, "q_functions")))
      stop("q_functions must be of class 'q_functions'.")
  }
  if (!(is.logical(q_full_history) & (length(q_full_history) == 1)))
    stop("q_full_history must be TRUE or FALSE")
  if (!(is.numeric(M) & (length(M) == 1)))
    stop("M must be an integer greater than 0.")
  if (!(M %% 1 == 0))
    stop("M must be an integer greater than 0.")
  if (M<=0)
    stop("M must be an integer greater than 0.")
  if (!is.list(future_args))
    stop("future_args must be a list.")
  if (!is.null(name)){
    name <- as.character(name)
    if (length(name) != 1)
      stop("name must be a character string.")
  }


  if (M > 1){
    val <- policy_eval_cross(args = args,
                             policy_data = policy_data,
                             M = M,
                             future_args = future_args)
  } else {
    args[["train_policy_data"]] <- policy_data
    args[["valid_policy_data"]] <- policy_data
    val <- do.call(what = policy_eval_type, args = args)
  }
  if (is.null(name)){
    if (!is.null(policy))
      val$name <- attr(policy, "name")
    else
      val$name <- attr(policy_learn, "name")
  } else
    val$name <- name

  return(val)
}
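
# Illustrative sketch (hypothetical helper, NOT part of the polle API) of the
# type = "ipw" formulas documented for policy_eval() above, specialised to a
# single stage (K = 1). The caller is assumed to supply the observed actions
# `A`, the utilities `U`, the policy actions `d` and the fitted propensities
# `g`, i.e. g_1(H_1, A_1) at the observed action. In polle the corresponding
# computation happens in value(), whose internals may differ from this.
sketch_ipw_value <- function(A, U, d, g) {
  n <- length(U)
  stopifnot(length(A) == n, length(d) == n, length(g) == n)
  # inverse probability weights I{A = d} / g(H, A):
  w <- (A == d) / g
  # empirical value estimate, the mean of w * U:
  value_estimate <- mean(w * U)
  # influence curve, w * U minus its empirical mean:
  IC <- w * U - value_estimate
  list(value_estimate = value_estimate, IC = IC)
}
# Usage with simulated data where P(A = 1) = 0.5 is known:
# set.seed(1)
# A <- rbinom(100, 1, 0.5); U <- rnorm(100, mean = A)
# sketch_ipw_value(A = A, U = U, d = rep(1, 100), g = rep(0.5, 100))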

policy_eval_type <- function(type,
                             train_policy_data,
                             valid_policy_data,
                             policy, policy_learn,
                             g_models, g_functions, g_full_history, save_g_functions,
                             q_models, q_functions, q_full_history, save_q_functions){

  type <- tolower(type)
  if (length(type) != 1)
    stop("type must be a character string.")

  if (type %in% c("dr", "aipw")){
    type <- "dr"
  } else if (type %in% c("ipw")){
    type <- "ipw"
  } else if (type %in% c("or", "q")) {
    type <- "or"
  } else{
    stop("type must be either 'dr', 'ipw' or 'or'.")
  }

  # fitting the g-functions, the q-functions and the policy (functions):
  fits <- fit_functions(policy_data = train_policy_data,
                        type = type,
                        policy = policy, policy_learn = policy_learn,
                        g_models = g_models, g_functions = g_functions, g_full_history = g_full_history,
                        q_models = q_models, q_functions = q_functions, q_full_history = q_full_history)

  # getting the fitted policy and associated actions:
  if (is.null(policy)){
    policy <- get_policy(getElement(fits, "policy_object"))
  }

  # calculating the value estimate and influence curve for the given type:
  g_functions <- getElement(fits, "g_functions")
  q_functions <- getElement(fits, "q_functions")
  value_object <- value(type = type,
                        policy_data = valid_policy_data,
                        policy = policy,
                        g_functions = g_functions,
                        q_functions = q_functions)
  g_values <- getElement(value_object, "g_values")
  q_values <- getElement(value_object, "q_values")

  # setting g-functions output:
  if (save_g_functions != TRUE){
    g_functions <- NULL
  }
  # setting Q-functions output:
  if(save_q_functions != TRUE){
    q_functions <- NULL
  }

  out <- list(
    value_estimate = getElement(value_object, "value_estimate"),
    type = type,
    IC = getElement(value_object, "IC"),
    value_estimate_ipw = getElement(value_object, "value_estimate_ipw"),
    value_estimate_or = getElement(value_object, "value_estimate_or"),
    id = get_id(valid_policy_data),
    policy_actions = getElement(value_object, "policy_actions"),
    policy_object = getElement(fits, "policy_object"),
    g_functions = g_functions,
    g_values = g_values,
    q_functions = q_functions,
    q_values = q_values
  )
  out <- Filter(Negate(is.null), out)

  class(out) <- c("policy_eval")
  return(out)
}
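
# Illustrative sketch (hypothetical helper, NOT part of the polle API) of the
# doubly robust Z^d_1 documented for type = "dr", specialised to K = 1 with
# the convention Q^d_2 = U:
#   Z = Q(H, d) + I{A = d} / g(H, A) * (U - Q(H, d)).
# The caller is assumed to supply `q_d`, the fitted Q-function evaluated at
# the policy action, Q^d_1(H_1, d_1(...)). The returned ipw/or components are
# the plain IPW and OR estimates from the same inputs; polle's
# value_estimate_ipw and value_estimate_or are computed inside value() and
# may differ in detail.
sketch_dr_value <- function(A, U, d, g, q_d) {
  n <- length(U)
  stopifnot(length(A) == n, length(d) == n, length(g) == n, length(q_d) == n)
  w <- (A == d) / g
  # augmented score: outcome regression plus weighted residual correction
  Z <- q_d + w * (U - q_d)
  value_estimate <- mean(Z)
  list(value_estimate = value_estimate,
       IC = Z - value_estimate,
       value_estimate_ipw = mean(w * U),
       value_estimate_or = mean(q_d))
}
# Design note: with q_d identically 0, Z reduces to the IPW term w * U; the
# estimator is consistent if either the g-model or the Q-model is correct.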

policy_eval_cross <- function(args,
                              policy_data,
                              M,
                              future_args){
  n <- get_n(policy_data)
  id <- get_id(policy_data)

  # setting up the folds
  folds <- split(sample(1:n, n), rep(1:M, length.out = n))
  folds <- lapply(folds, sort)
  names(folds) <- paste("fold_", 1:M, sep = "")

  prog <- progressor(along = folds)
  cross_args <- append(list(X = folds,
                            FUN = policy_eval_fold,
                            policy_data = policy_data,
                            args = args,
                            prog = prog),
                       future_args)

  # cross fitting the policy evaluation using the folds:
  cross_fits <- do.call(what = future.apply::future_lapply, cross_args)

  # collecting ids:
  id <- unlist(lapply(cross_fits, function(x) getElement(x, "id")), use.names = FALSE)

  # collecting the value estimates:
  n <- unlist(lapply(cross_fits, function(x) length(getElement(x, "id"))))
  value_estimate <- unlist(lapply(cross_fits, function(x) getElement(x, "value_estimate")))
  value_estimate <- sum((n / sum(n)) * value_estimate)

  # collecting the IC decompositions:
  IC <- unlist(lapply(cross_fits, function(x) getElement(x, "IC")), use.names = FALSE)

  # collecting the IPW value estimates (only if type == "dr")
  value_estimate_ipw <- unlist(lapply(cross_fits, function(x) getElement(x, "value_estimate_ipw")))
  if (!is.null(value_estimate_ipw)){
    value_estimate_ipw <- sum((n / sum(n)) * value_estimate_ipw)
  }

  # collecting the OR value estimates (only if type == "dr")
  value_estimate_or <- unlist(lapply(cross_fits, function(x) getElement(x, "value_estimate_or")))
  if (!is.null(value_estimate_or)){
    value_estimate_or <- sum((n / sum(n)) * value_estimate_or)
  }

  # collecting the policy actions
  policy_actions <- lapply(cross_fits, function(x) getElement(x, "policy_actions"))
  policy_actions <- rbindlist(policy_actions)
  setkey(policy_actions, "id", "stage")

  # collecting the g- and Q-values:
  g_values <- lapply(cross_fits, function(x) getElement(x, "g_values"))
  null_g_values <- unlist(lapply(g_values, is.null))
  if (!all(null_g_values)){
    g_values <- data.table::rbindlist(g_values)
    setkey(g_values, "id", "stage")
  } else {
    g_values <- NULL
  }

  q_values <- lapply(cross_fits, function(x) getElement(x, "q_values"))
  null_q_values <- unlist(lapply(q_values, is.null))
  if (!all(null_q_values)){
    q_values <- data.table::rbindlist(q_values)
    setkey(q_values, "id", "stage")
  } else {
    q_values <- NULL
  }

  # sorting via the IDs:
  IC <- IC[order(id)]
  id <- id[order(id)]

  out <- list(value_estimate = value_estimate,
              # the normalised type ("dr", "ipw" or "or") from the fold fits:
              type = getElement(cross_fits[[1]], "type"),
              IC = IC,
              value_estimate_ipw = value_estimate_ipw,
              value_estimate_or = value_estimate_or,
              id = id,
              policy_actions = policy_actions,
              g_values = g_values,
              q_values = q_values,
              cross_fits = cross_fits,
              folds = folds
  )

  out <- Filter(Negate(is.null), out)

  class(out) <- c("policy_eval")
  return(out)
}
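
# Illustrative sketch (hypothetical helpers, NOT part of the polle API) of the
# two bookkeeping steps in policy_eval_cross() above: randomly splitting the
# positions 1..n into M folds, and pooling the fold-wise estimates as the
# fold-size weighted average sum_m (n_m / n) * value_m. When every fold
# estimate is a mean over its validation observations, the pooled value
# equals the mean over all n observations. sample.int(n) is used here to
# obtain a random permutation of 1..n.
sketch_make_folds <- function(n, M) {
  folds <- split(sample.int(n), rep(seq_len(M), length.out = n))
  folds <- lapply(folds, sort)
  names(folds) <- paste0("fold_", seq_len(M))
  folds
}
sketch_pool_estimates <- function(value_estimates, fold_sizes) {
  stopifnot(length(value_estimates) == length(fold_sizes))
  sum((fold_sizes / sum(fold_sizes)) * value_estimates)
}
# Usage:
# set.seed(1); folds <- sketch_make_folds(n = 10, M = 3)
# sketch_pool_estimates(c(1, 2), fold_sizes = c(1, 3))  # 1.75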

policy_eval_fold <- function(fold,
                             policy_data,
                             args,
                             prog
){

  K <- get_K(policy_data)
  id <- get_id(policy_data)

  train_id <- id[-fold]
  validation_id <- id[fold]

  # training data:
  train_policy_data <- subset_id(policy_data, train_id)
  if (get_K(train_policy_data) != K) stop("The number of stages varies across the training folds.")

  # validation data:
  valid_policy_data <- subset_id(policy_data, validation_id)
  if (get_K(valid_policy_data) != K) stop("The number of stages varies across the validation folds.")

  eval_args <- append(args, list(valid_policy_data = valid_policy_data,
                                 train_policy_data = train_policy_data))

  out <- do.call(what = "policy_eval_type", args = eval_args)

  # progress:
  prog()

  return(out)
}
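
# Illustrative sketch (hypothetical helper, NOT part of the polle API) of the
# backward Q-recursion behind type = "or", as documented for policy_eval(),
# for K = 2 stages, a static policy (d_1, d_2), Markov-type histories
# H_k = X_k, and linear Q-models fitted with lm(). The data.frame columns
# X1, A1, X2, A2 and U (total utility) are assumptions made for this example;
# polle itself fits the Q-models through q_models (e.g. q_glm()).
sketch_or_value_two_stage <- function(data, d1, d2) {
  # stage 2: Q_2(h_2, a_2) = E[U | H_2 = h_2, A_2 = a_2]
  fit2 <- lm(U ~ X2 * A2, data = data)
  # pseudo-outcome Q_2(H_2, d_2): predictions with A2 set to the policy action
  data$V2 <- predict(fit2, newdata = transform(data, A2 = d2))
  # stage 1: Q_1(h_1, a_1) = E[Q_2(H_2, d_2) | H_1 = h_1, A_1 = a_1]
  fit1 <- lm(V2 ~ X1 * A1, data = data)
  # value estimate: empirical mean of Q_1(H_1, d_1)
  mean(predict(fit1, newdata = transform(data, A1 = d1)))
}
# Usage: sketch_or_value_two_stage(dat, d1 = 1, d2 = 1) for a data.frame `dat`
# with the columns listed above.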

polle documentation built on May 29, 2024, 1:15 a.m.