#' Learning the Optimal Rule using the Revere framework
#'
#' Functions used to learn the Optimal Rule given a tmle_task and likelihood,
#' using the Revere framework. Complements 'tmle3_Spec_mopttx_blip_revere' class.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom data.table data.table
#' @importFrom tmle3 tmle3_Spec
#'
#' @export
#'
#' @keywords data
#'
#' @return A optimal rule object inheriting from \code{\link[tmle3]{tmle3_Spec}}
#' with methods for learning the optimal rule. For a full list of the available
#' functionality, see the complete documentation of \code{\link[tmle3]{tmle3_Spec}}.
#'
#' @format An \code{\link[R6]{R6Class}} object inheriting from
#' \code{\link[tmle3]{tmle3_Spec}}.
#'
#'
#' @section Parameters:
#' - \code{tmle_task}: Task object specifying the data and node structure.
#' - \code{tmle_spec}: Spec object of \code{\link[keras]{tmle3}}. Allows for different
#' Specs to use the current class for learning the Optimal Rule.
#' - \code{likelihood}: Likelihood object of \code{\link[keras]{tmle3}}, corresponding
#' to the current estimate of the required parts of the likelihood necessary for the target
#' parameter.
#' - \code{V}: User-specified list of covariates used to define the rule.
#' - \code{options}: Information on all the variables passed to the original Spec.
#'
#' @examples
#' \dontrun{
#' library(sl3)
#' library(tmle3)
#' library(data.table)
#'
#' data("data_bin")
#' data <- data_bin
#'
#' Q_lib <- make_learner_stack("Lrnr_mean", "Lrnr_glm_fast")
#' g_lib <- make_learner_stack("Lrnr_mean", "Lrnr_glm_fast")
#' B_lib <- make_learner_stack("Lrnr_glm_fast", "Lrnr_xgboost")
#'
#' metalearner <- make_learner(Lrnr_nnls)
#' Q_learner <- make_learner(Lrnr_sl, Q_lib, metalearner)
#' g_learner <- make_learner(Lrnr_sl, g_lib, metalearner)
#' B_learner <- make_learner(Lrnr_sl, B_lib, metalearner)
#'
#' learner_list <- list(Y = Q_learner, A = g_learner, B = B_learner)
#'
#' node_list <- list(W = c("W1", "W2", "W3"), A = "A", Y = "Y")
#'
#' tmle_spec <- tmle3_mopttx_blip_revere(
#' V = c("W1", "W2", "W3"),
#' type = "blip1", learners = learner_list, maximize = TRUE,
#' complex = TRUE, realistic = TRUE
#' )
#'}
Optimal_Rule_Revere <- R6Class(
classname = "Optimal_Rule_Revere",
portable = TRUE,
class = TRUE,
inherit = tmle3::tmle3_Spec,
lock_objects = FALSE,
public = list(
initialize = function(tmle_task, tmle_spec, likelihood, V, options,
shift_grid = seq(-1, 1, by = 0.5)) {
private$.tmle_task <- tmle_task
private$.tmle_spec <- tmle_spec
private$.likelihood <- likelihood
private$.blip_type <- options$type
private$.learners <- options$learners
private$.maximize <- options$maximize
private$.realistic <- options$realistic
private$.resource <- options$resource
private$.interpret <- options$interpret
private$.likelihood_override <- options$likelihood_override
private$.reference <- options$reference
private$.shift_grid <- shift_grid
A_vals <- private$.tmle_task$npsem$A$variable_type$levels
#If binary A, force blip1
if(length(A_vals)==2){
private$.blip_type <- "blip1"
}
#Pick reference as the smallest category, if not assigned
if(is.null(private$.reference) & private$.blip_type=="blip1"){
if(is.factor(A_vals)){
A_vals <- as.numeric(levels(A_vals))[A_vals]
}
private$.reference <- min(A_vals)
}
if (missing(V)) {
V <- tmle_task$npsem$W$variables
}
private$.V <- V
},
factor_to_indicators = function(x, x_vals) {
ind_mat <- sapply(x_vals, function(x_val) as.numeric(x_val == x))
colnames(ind_mat) <- x_vals
return(ind_mat)
},
V_data = function(tmle_task, fold = NULL) {
if (is.null(fold)) {
tmle_task$data[, self$V, with = FALSE]
} else {
tmle_task$data[, self$V, with = FALSE][tmle_task$folds[[fold]]$training_set, ]
}
},
DR_full = function(v, indx) {
DR <- data.frame(private$.DR_full[[v]])
return(data.frame(DR[indx, ]))
},
blip_revere_function = function(tmle_task, fold_number) {
#Grab parameters:
likelihood <- self$likelihood
A_vals <- tmle_task$npsem$A$variable_type$levels
ref <- which(A_vals %in% private$.reference)
V <- self$V
# Generate counterfactual tasks for each value of A:
cf_tasks <- lapply(A_vals, function(A_val) {
if (is.character(A_val)) {
A_val <- as.numeric(A_val)
# A_val<-as.factor(A_val)
}
newdata <- data.table(A = A_val)
cf_task <- tmle_task$generate_counterfactual_task(UUIDgenerate(), new_data = newdata)
return(cf_task)
})
# DR A-IPW mapping of blip
A <- tmle_task$get_tmle_node("A")
Y <- tmle_task$get_tmle_node("Y")
A_vals <- tmle_task$npsem$A$variable_type$levels
A_ind <- self$factor_to_indicators(A, A_vals)
Y_mat <- replicate(length(A_vals), Y)
# Use fold_number fits for Q and g:
Q_vals <- sapply(cf_tasks, likelihood$get_likelihood, "Y", fold_number)
g_vals <- sapply(cf_tasks, likelihood$get_likelihood, "A", fold_number)
DR <- (A_ind / g_vals) * (Y_mat - Q_vals) + Q_vals
# Type of pseudo-blip:
blip_type <- self$blip_type
# If there are missing values in Y,
# Blip outcomes will have missing values as well
if (blip_type == "blip1") {
vals <- seq(A_vals)
blip <- lapply(vals,function(val){DR[, val] - DR[, ref]})
blip <- do.call(cbind,blip)
colnames(blip) <- paste(A_vals)
#Correct for binary A
if(length(vals)==2){
blip <- rowSums(blip)
}
} else if (blip_type == "blip2") {
blip <- DR - rowMeans(DR)
colnames(blip) <- paste0("A=",A_vals)
} else if (blip_type == "blip3") {
blip <- DR - (rowMeans(DR) * g_vals)
}
# TO DO: Nicer solutions. Do it one by one, for now
# If there are missing Ys, there will be missing blips- for now drop these rows.
# (otherwise we train on imputed values... )
if (is.null(V)) {
data <- data.table(V = blip, blip = blip)
outcomes <- grep("blip", names(data), value = TRUE)
V <- grep("V", names(data), value = TRUE)
revere_task <- make_sl3_Task(data,
outcome = outcomes, covariates = V,
folds = tmle_task$folds
)
} else {
V <- tmle_task$data[, self$V, with = FALSE]
data <- data.table(V, blip = blip)
outcomes <- grep("blip", names(data), value = TRUE)
revere_task <- make_sl3_Task(data,
outcome = outcomes, covariates = self$V,
folds = tmle_task$folds
)
}
return(revere_task)
},
bound = function(cv_g) {
cv_g[cv_g < 0.01] <- 0.01
cv_g[cv_g > 0.99] <- 0.99
return(cv_g)
},
fit_blip = function() {
#Grab all parameters:
tmle_task <- self$tmle_task
tmle_spec <- self$tmle_spec
likelihood <- self$likelihood
V <- self$V
type <- tmle_spec$options$type
maximize <- tmle_spec$options$maximize
complex <- tmle_spec$options$complex
realistic <- tmle_spec$options$realistic
resource <- tmle_spec$options$resource
interpret <- tmle_spec$options$interpret
learner_list <- tmle_spec$options$learners
# Edit the tmle3 task so it avoids missing values:
if (!is.null(tmle_task$npsem$Y$censoring_node)) {
delta <- tmle_task$npsem$Y$censoring_node$name
# Subset data and nodes:
observed <- tmle_task$get_tmle_node(delta)
data <- tmle_task$get_data()
data <- data[observed]
data <- data[, (ncol(data)) := NULL]
folds <- sl3::subset_folds(tmle_task$folds, which(observed))
# Create node list:
W <- c(tmle_task$.__enclos_env__$private$.npsem$W$variables)
A <- tmle_task$.__enclos_env__$private$.npsem$A$variables
Y <- tmle_task$.__enclos_env__$private$.npsem$Y$variables
node_list <- list(W = W, A = A, Y = Y)
tmle_spec_new <- tmle3_mopttx_blip_revere(
V = V, type = type,
learners = learner_list, maximize = maximize,
complex = complex, realistic = realistic, resource = resource
)
tmle_task_noC <- tmle_spec_new$make_tmle_task(data, node_list = node_list, folds)
} else {
tmle_task_noC <- tmle_task
}
blip_revere_task <- sl3:::sl3_revere_Task$new(self$blip_revere_function,
tmle_task_noC)
blip_fit <- self$blip_library$train(blip_revere_task)
if(interpret){
blip_task <- self$blip_revere_function(tmle_task, fold_number="full")
preds <- blip_fit$predict(blip_task)
blip_fit_interpret <- hal9001::fit_hal(
X = blip_task$X,
Y = preds,
yolo = FALSE,
return_x_basis=TRUE,
return_lasso = TRUE,
reduce_basis = 1/nrow(blip_task$data)
)
}else{
blip_fit_interpret <- NULL
}
private$.blip_fit <- blip_fit
private$.blip_fit_interpret <- blip_fit_interpret
},
rule = function(tmle_task, fold_number = "full") {
#Get paramaters
realistic <- private$.realistic
resource <- private$.resource
likelihood <- self$likelihood
likelihood_override <- private$.likelihood_override
#Get A values
A_vals <- tmle_task$npsem$A$variable_type$levels
### NOTE:
# If there is missing outcome, this will return rules for ALL values.
# This is ok- we don't have missing Ws or As, just Ys (hence, we can get a predicted value).
# This outputs a warning, but that's ok.
blip_task <- self$blip_revere_function(tmle_task, fold_number)
blip_preds <- self$blip_fit$predict_fold(blip_task, fold_number)
#Get dimensions
n <- nrow(tmle_task$data)
# Type of pseudo-blip:
blip_type <- self$blip_type
if (is.list(blip_preds)) {
blip_preds <- unpack_predictions(blip_preds)
}
# flip sign if we're minimizing
if (!private$.maximize) {
blip_preds <- blip_preds * -1
}
# add an extra 0 column for blip1 so that there's always one column per A level
if ((blip_type == "blip1") & length(A_vals)==2) {
blip_preds <- cbind(0, blip_preds)
}
if (realistic) {
# Need to grab the propensity score:
g_task <- tmle_task$get_regression_task("A")
if(!is.null(likelihood_override)){
g_preds <- likelihood_override$get_likelihood(tmle_task, node="A")
g_preds <- unpack_predictions(g_preds)
}else{
g_learner <- likelihood$factor_list[["A"]]$learner
g_preds <- unpack_predictions(g_learner$predict(g_task))
}
min_g <- 0.05
# make unrealistic rules not optimal
g_preds <- normalize_rows(g_preds)
blip_preds[g_preds < min_g] <- -Inf
}
rule_preds <- max.col(blip_preds)
rule_preds <- A_vals[rule_preds]
#User can put resource=1 for binary rule
if( (length(resource)<length(A_vals)) ){
resource <- c( rep(1,length(A_vals)-length(resource)) ,resource)
}
#General resource constraint:
if(sum(resource)<length(A_vals)){
#Rank based on the current rule allocation and resource
oit <- data.frame(percent=table(rule_preds)/n*100)
names(oit) <- c("A_vals","Freq")
oit <- oit[match(A_vals,oit$A_vals),]
row.names(oit) <- NULL
oit$A_vals <- A_vals
oit[is.na(oit$Freq),"Freq"] <- 0
#Start with the constraint:
resource_inter <- cbind.data.frame(A_vals=seq(A_vals),
resource=resource,
n=resource*n,
oit=oit$Freq)
resource_inter <- resource_inter[order((-resource_inter$resource),resource_inter$oit,
decreasing = TRUE),]
#Add ids and rules to blips:
blip_preds_inter <- cbind.data.frame(id=seq(n),
blip_preds)
for(i in 1:nrow(resource_inter)){
inter <- blip_preds_inter
#Start with A that has the most constraints
A_val <- resource_inter[i,"A_vals"]
#Order blips of current A_val
inter <- inter[order(inter[,(A_val+1)],decreasing = TRUE),]
#Get the current rule
rule_preds_inter <- max.col(inter[,-1]) #without id
rule_preds_inter <- cbind.data.frame(id=inter$id,
rule=(rule_preds_inter==A_val))
rule_preds_inter <- rule_preds_inter[rule_preds_inter$rule==TRUE,]
#If no one benefit from this rule, assign to the unconstrained
if(dim(rule_preds_inter)[1]!=0){
rule_preds_inter$seq <- seq(1:nrow(rule_preds_inter))
inter$seq <- rule_preds_inter[match(inter$id, rule_preds_inter$id),"seq"]
max_num_A <- max(rule_preds_inter$seq)
if(max_num_A>resource_inter[i,"n"]){
inter[(inter$seq>resource_inter[i,"n"] & !is.na(inter$seq)),(A_val+1)] <- -Inf
}
blip_preds_inter <- inter[,-ncol(inter)]
}
}
blip_preds_inter <- blip_preds_inter[order(blip_preds_inter$id,decreasing = FALSE),]
blip_preds_fin <- blip_preds_inter[,-1]
rule_preds_resource <- max.col(blip_preds_fin)
rule_preds_resource <- A_vals[rule_preds_resource]
}else{
rule_preds_resource <- rule_preds
}
#Allow resource constrain only on binary treatment for now
#if(length(A_vals) == 2 & resource < 1){
# #TO DO: Note that this doesn't really allow us to rank blip < 0
# max_preds <-apply(blip_preds, 1, max)
# rank_df <- data.table("id" = c(1:length(max_preds)),
# "blip_preds" = max_preds)
# rank_df <- rank_df[order(rank_df[,2],decreasing=TRUE),]
# self$.rank <- rank_df
#Total to get treatment:
# A1 <- sum(rank_df$blip_preds>0)
# A1_constrain <- floor(A1 * resource)
# get_A_id <- rank_df[1:A1_constrain, "id"]
# get_A_id <- get_A_id$id
# rank_df <- rank_df[order(rank_df[,1],decreasing=FALSE),]
# if(is.factor(A_vals)){
# A_vals <- factor(A_vals, ordered = TRUE)
# }
# rule_preds_resource <- rule_preds
# rule_preds_resource[!(rank_df$id %in% get_A_id)] <- min(A_vals[rule_preds])
# rule_preds <- rule_preds_resource
#}
return(rule_preds_resource)
},
# Think carefully as to how this should be done with folds.
rule_stochastic = function(tmle_task, fold_number = "full") {
likelihood <- self$likelihood
shift_grid <- self$shift_grid
A <- tmle_task$get_tmle_node("A")
# Only supports additive shifts for now.
# Generate counterfactual tasks for each delta shift of A:
cf_tasks <- lapply(shift_grid, function(shift) {
newdata <- data.table(A = A + shift)
cf_task <- tmle_task$generate_counterfactual_task(UUIDgenerate(), new_data = newdata)
return(cf_task)
})
Q_vals <- sapply(cf_tasks, likelihood$get_likelihood, "Y", fold_number)
opt_col <- max.col(Q_vals)
opt_A <- Q_vals[cbind(seq_along(opt_col), opt_col)]
private$.opt_delta <- shift_grid[opt_col]
private$.opt_A <- opt_A
private$.Q_vals <- Q_vals
return(opt_A)
}
),
active = list(
tmle_task = function() {
return(private$.tmle_task)
},
tmle_spec = function() {
return(private$.tmle_spec)
},
likelihood = function() {
return(private$.likelihood)
},
V = function() {
return(private$.V)
},
blip_type = function() {
return(private$.blip_type)
},
blip_fit = function() {
return(private$.blip_fit)
},
blip_fit_interpret = function() {
return(private$.blip_fit_interpret)
},
blip_library = function() {
return(private$.learners$B)
},
A_library = function() {
return(private$.learners$A)
},
shift_grid = function() {
return(private$.shift_grid)
},
return_rank = function() {
return(private$.rank)
}
),
private = list(
.tmle_task = NULL,
.tmle_spec = NULL,
.likelihood = NULL,
.V = NULL,
.blip_type = NULL,
.blip_fit = NULL,
.blip_fit_interpret = NULL,
.learners = NULL,
.maximize = NULL,
.realistic = NULL,
.resource = NULL,
.shift_grid = NULL,
.opt_delta = NULL,
.opt_A = NULL,
.Q_vals = NULL,
.rank = NULL,
.interpret = NULL,
.reference = NULL
)
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.