policy_learn

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library("data.table")
library("polle")

This vignette is a guide to policy_learn() and some of the associated S3 methods. The purpose of policy_learn() is to specify a policy learning algorithm and to estimate an optimal policy. For details on the methodology, see the associated paper [@nordland2023policy].

As a general setup, we consider a fixed two-stage problem. We simulate data using sim_two_stage() and create a policy_data object using policy_data():

d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd

Specifying and applying a policy learner

policy_learn() specifies a policy learning algorithm via the type argument: Q-learning (ql), doubly robust Q-learning (drql), doubly robust blip learning (blip), policy tree learning (ptl), and outcome weighted learning (owl).
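For instance, a plain Q-learning algorithm can be specified using the defaults alone, since its Q-models are only supplied later when the learner is applied (a minimal sketch):

pl_ql <- policy_learn(type = "ql")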

Because the control arguments vary with the policy learning type, they are passed as a list via the control argument. To help the user set the required control arguments and to provide documentation, each type has a helper function control_type(), which sets the default control arguments and overwrites them with any values supplied by the user.

As an example we specify a doubly robust blip learner:

pl_blip <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  )
)

For details on the implementation, see Algorithm 3 in [@nordland2023policy]. The only required control argument for blip learning is blip_models, which expects a q_model. In this case we input a simple linear model as implemented in q_glm().
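If the stages call for different model specifications, blip_models should also accept a list of q_model objects, one per stage (a sketch, assuming list input is supported as for q_models):

pl_blip_stages <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = list(q_glm(formula = ~ BB + L),
                       q_glm(formula = ~ BB + L + C))
  )
)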

The output of policy_learn() is again a function:

pl_blip

In order to apply the policy learner, we input a policy_data object along with the nuisance models g_models and q_models used to compute the doubly robust score.

(po_blip <- pl_blip(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
))
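If the same model specification applies at every stage, a single model can be supplied in place of a list and is then reused across stages (a sketch, assuming the same model recycling as in policy_eval()):

po_blip_single <- pl_blip(
  pd,
  g_models = g_glm(),
  q_models = q_glm()
)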

Cross-fitting the doubly robust score

As with policy_eval(), it is possible to cross-fit the doubly robust score used as input to the policy model. The number of folds for the cross-fitting procedure is set via the L argument. By default, the cross-fitted nuisance models are not saved; they can be saved via the save_cross_fit_models argument:

pl_blip_cross <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  L = 2,
  save_cross_fit_models = TRUE
)
po_blip_cross <- pl_blip_cross(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)

From a user perspective, nothing has changed. However, the policy object now contains each of the cross-fitted nuisance models:

po_blip_cross$g_functions_cf
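The cross-fitted Q-models are stored in the same way (assuming the corresponding element is named q_functions_cf):

po_blip_cross$q_functions_cf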

Realistic policy learning

Realistic policy learning is implemented for types ql, drql, blip, and ptl (for a binary action set). The alpha argument sets the probability threshold used to define the realistic action set. For implementation details, see Algorithm 5 in [@nordland2023policy]. Here we set a 5% restriction:

pl_blip_alpha <- policy_learn(
  type = "blip",
  control = control_blip(
    blip_models = q_glm(formula = ~ BB + L + C)
  ),
  alpha = 0.05,
  L = 2
)
po_blip_alpha <- pl_blip_alpha(
  pd,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)

The policy object now lists the alpha level as well as the g-model used to define the realistic action set:

po_blip_alpha$alpha
po_blip_alpha$g_functions
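The realistic policy itself can be extracted and applied with get_policy() (described below), in the same way as for the unrestricted learner:

get_policy(po_blip_alpha)(pd) |> head(4)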

Implementation/Simulation and get_policy_functions()

A policy function is convenient for evaluating a given policy, and even for implementing or simulating from a single-stage policy. However, it is not suited for implementing or simulating from a learned multi-stage policy, where actions must be generated one stage at a time. To access the policy function for each stage, we use get_policy_functions(). Here we get the second-stage policy function:

pf_blip <- get_policy_functions(po_blip, stage = 2)

The stage-specific policy function requires a data.table with named columns as input and returns a character vector of recommended actions:

pf_blip(
  H = data.table(BB = c("group2", "group1"),
                 L = c(1, 0),
                 C = c(1, 2))
)
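The first-stage policy function follows the same pattern; with the blip model formula used here, the required columns are again BB, L, and C (a sketch with hypothetical input values):

pf1_blip <- get_policy_functions(po_blip, stage = 1)
pf1_blip(
  H = data.table(BB = c("group2", "group1"),
                 L = c(0, 1),
                 C = c(-1, 3))
)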

Policy objects and get_policy()

Applying the policy learner returns a policy_object containing all of the components needed to specify the learned policy. In this case, the only component of the policy is a model for the blip function:

po_blip$blip_functions$stage_1$blip_model

To access and apply the policy itself, use get_policy(). The result behaves as a policy, meaning that we can apply it to any (suitable) policy_data object to get the policy actions:

get_policy(po_blip)(pd) |> head(4)
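The value of the learned policy can also be estimated directly by passing the policy learner to policy_eval(); a sketch, assuming the same nuisance model setup as above (see the policy_eval vignette for details):

policy_eval(
  pd,
  policy_learn = pl_blip,
  g_models = list(g_glm(), g_glm()),
  q_models = list(q_glm(), q_glm())
)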

SessionInfo

sessionInfo()

References


