# R/utils.R
#
# Defines functions: extract_res, extract_val, extract_vals, sl3.fit, make_stack, bound

#' Bounds for Q and g
#'
#' Function to bound Q and g estimates with user specified bounds.
#'
#' @param x observed values of either Q or g.
#' @param bds list containing the upper and lower bound for x.
#'
#' @export
#'

bound <- function(x, bds){
  x[x > max(bds)] <- max(bds)
  x[x < min(bds)] <- min(bds)
  x
}

#' Make a sl3 stack
#'
#' Function to make passing algorithms to sl3 easier.
#'
#' @param learner_lists list of algorithms, possibly with additional parameters.
#'
#' @importFrom sl3 make_learner
#'
#

make_stack <- function(learner_lists) {

  # Instantiate one learner from a spec; a bare (non-list) spec is wrapped
  # so that do.call always receives an argument list for make_learner.
  instantiate <- function(spec) {
    if (!is.list(spec)) {
      spec <- list(spec)
    }
    do.call(sl3::make_learner, spec)
  }

  learners <- lapply(learner_lists, instantiate)

  # Combine the instantiated learners into a single sl3 Stack object.
  make_learner(Stack, learners)
}

#' Fit sl3-based Super Learner
#'
#' Function to make Super Learning with \code{sl3} easier.
#'
#' @param task task object of \code{sl3::make_sl3_Task} form.
#' @param library list of \code{sl3} algorithms to be used for estimation.
#' @param metalrn metalearner used by the SuperLearner algorithm.
#' @param loss loss function used in order to get the risk estimate.
#'
#' @importFrom sl3 make_learner Lrnr_sl
#'
#' @export
#

sl3.fit <- function(task, library, metalrn = "Lrnr_nnls",
                    loss = sl3::loss_squared_error) {
  # Fit an sl3 Super Learner on `task` using the algorithms in `library`.
  # NOTE: the default loss is namespaced (sl3::loss_squared_error) so the
  # default works even when sl3 is imported but not attached; the bare name
  # was not covered by the file's @importFrom directives.

  # Build a Stack from the user-supplied learner specifications.
  sl_stack <- make_stack(learner_lists = library)

  # Assemble and train the Super Learner with the requested metalearner.
  metalearner <- sl3::make_learner(metalrn)
  sl <- sl3::Lrnr_sl$new(learners = sl_stack, metalearner = metalearner)
  sl_fit <- sl$train(task)

  # Cross-validated risk of each learner under the supplied loss.
  risk <- sl_fit$cv_risk(loss)

  # Super Learner predictions on the training task.
  sl_preds <- sl_fit$predict()

  # Per-fold fits from the internal cross-validation step.
  cv_fit <- sl_fit$fit_object$cv_fit$fit_object$fold_fits

  return(list(pred = sl_preds, risk = risk, coefs = sl_fit$coefficients,
              sl = sl, sl.fit = sl_fit, cv.fit = cv_fit))
}

#########################################################################
# Functions to extract samples that served as validation samples in the
# estimation step for Y, A, QAW, Q1W, Q0W, Blip, Rule, K
#########################################################################

#' Cross-validate and extract validation samples
#'
#' Function to extract and cross-validate validation samples from prediction on all
#' samples for use in CV-TMLE.
#'
#' @param folds user-specified list of folds. It should correspond to an element of \code{origami}.
#' @param split_preds Cross-validated result of \code{cv_split}.
#'
#' @export

extract_vals <- function(folds, split_preds) {

  # Collect validation-set predictions from every fold, then restore the
  # original sample ordering using the index column added by extract_val.
  val_preds <- origami::cross_validate(extract_val, folds, split_preds)$preds

  # Fix: return the reordered data frame as the last expression. The original
  # ended in an assignment, so the result was returned invisibly.
  val_preds[order(val_preds$index), ]
}

#' Extract validation samples
#'
#' Function to extract the validation sets from split based predictions for use in CV-TMLE.
#'
#' @param fold one fold from a list of folds.
#' @param split_preds Cross-validated result of \code{cv_split}.
#'
#' @export

extract_val <- function(fold, split_preds) {

  #Extract fold and corresponding validation samples.
  # NOTE(review): fold_index()/validation() take no arguments here, so they
  # presumably resolve the `fold` argument from the calling frame (origami
  # convention) — confirm against the origami documentation.
  v <- origami::fold_index()
  valid_idx <- origami::validation()

  #Extract validation samples for fold v for all the values of split_preds
  # Each split_pred is indexed first by fold, then by sample; sapply is
  # expected to simplify the per-fold vectors into a matrix for the
  # as.data.frame() call below.
  val_preds <- sapply(split_preds, function(split_pred) {
    split_pred[[v]][valid_idx]
  })

  #Add index to it (aka, which sample is in question?)
  # `index` records the original row positions so extract_vals can restore
  # the input ordering; `folds` records which fold each row validated in.
  val_preds <- as.data.frame(val_preds)
  val_preds$index <- valid_idx
  val_preds$folds <- rep(v,nrow(val_preds))
  result <- list(preds = val_preds)

  return(result)
}

#' Wrapper for final TMLE results (with gentmle)
#'
#' Function to make extraction of final TMLE results for the mean under the optimal
#' individualized treatment regime more streamlined. In particular, it should provide
#' inference and results for 4 different scenarios, where the exposure is set to both
#' binary possibilities, observed exposure, and the learned optimal individualized rule.
#'
#' @param res results from multiple calls to \code{ruletmle}.
#'
#' @export
#'

extract_res <- function(res) {
  # Reshape the four ruletmle results (A=0, A=1, observed A, optimal A)
  # into labelled data frames and named lists for downstream reporting.

  # Row/column labels shared by all per-rule summaries.
  rule_labels <- c("A=0", "A=1", "A=A", "A=optA")
  data_labels <- c("rule0", "rule1", "ruleA", "ruleOpt")

  # Pull one field out of each of the four rule-specific results.
  grab <- function(getter) lapply(seq_len(4), function(i) getter(res[[i]]))

  # Scalar-per-rule fields -> one-column data frame with rule-named rows.
  as_column <- function(values, col_name) {
    out <- data.frame(unlist(values))
    row.names(out) <- rule_labels
    names(out) <- col_name
    out
  }

  # Vector-per-rule fields -> data frame with one column per rule.
  as_wide <- function(values) {
    out <- data.frame(do.call("cbind", values))
    names(out) <- rule_labels
    out
  }

  # List-valued fields keep their list form, named by rule.
  as_named_list <- function(values) {
    names(values) <- data_labels
    values
  }

  #Point estimate, standard deviation and step count for each rule:
  psi <- as_column(grab(function(r) r$tmlePsi), "Psi")
  sd <- as_column(grab(function(r) r$tmleSD), "SD")
  steps <- as_column(grab(function(r) r$steps), "steps")

  #Confidence interval bounds for each rule:
  lower <- data.frame(unlist(grab(function(r) r$tmleCI[1])))
  upper <- data.frame(unlist(grab(function(r) r$tmleCI[2])))
  CI <- cbind.data.frame(lower = lower, upper = upper)
  names(CI) <- c("lower", "upper")
  row.names(CI) <- rule_labels

  #Influence curve and treatment rule for each rule:
  IC <- as_wide(grab(function(r) r$IC$Dstar_psi))
  rule <- as_wide(grab(function(r) r$rule))

  #Initial data, final data and full results for each rule:
  initialData <- as_named_list(grab(function(r) r$initialData))
  tmleData <- as_named_list(grab(function(r) r$tmleData))
  all <- as_named_list(grab(function(r) r$all))

  return(list(tmlePsi = psi, tmleSD = sd, tmleCI = CI, IC = IC, rule = rule,
              steps = steps, initialData = initialData, tmleData = tmleData,
              all = all))
}
# WaverlyWei/optimal-treatment- documentation built on May 6, 2019, 11:24 a.m.