R/step_2.R

Defines functions get_vt2 vt2_ctree vt2_classtree vt2_rtree vt2_lasso

Documented in get_vt2 vt2_classtree vt2_ctree vt2_lasso vt2_rtree

#' Estimate the CATE using the Lasso for Step 2
#'
#' @inheritParams tunevt
#' @inheritParams get_mnpp
#' @param theta lasso penalty parameter (\code{lambda})
#'
#' @importFrom glmnet glmnet
#' @importFrom stats predict coef
#'
#' @family VT Step 2 functions
#'
#' @return a list of length 3 containing the following elements:
#'   \item{mod}{an object of class \code{glmnet}. See
#'     \code{\link[glmnet]{glmnet}}.}
#'   \item{coefficients}{coefficients associated with the penalty parameter
#'     \code{theta}.}
#'   \item{fitted.values}{predicted values associated with the penalty parameter
#'     \code{theta}.}
#'
vt2_lasso <- function(z, data, Trt, Y, theta) {

  keep_x_fit <- !(names(data) %in% c(Y, Trt))
  X <- data.matrix(subset(data, select = keep_x_fit))

  mod <- glmnet(
    x = X,
    y = data.matrix(z)
  )
  beta <- coef(mod, s = theta)
  preds <- drop(predict(mod, s = theta, newx = X))

  return(list(
    mod = mod, coefficients = beta, fitted.values = preds, theta = theta
  ))
}


#' Estimate the CATE using a regression tree for Step 2
#'
#' @inheritParams tunevt
#' @inheritParams get_mnpp
#' @param theta tree complexity parameter (\code{cp})
#'
#' @importFrom rpart rpart
#' @importFrom stats predict coef
#'
#' @family VT Step 2 functions
#'
#' @return an object of class \code{rpart}. See
#'   \code{\link[rpart]{rpart.object}}.
#'
vt2_rtree <- function(z, data, Trt, Y, theta) {

  keep_x_fit <- !(names(data) %in% c(Y, Trt))
  mod <- rpart::rpart(
    z ~ ., data = subset(data, select = keep_x_fit),
    method = "anova",
    cp = theta
  )

  return(mod)

}

#' Estimate the CATE using a classification tree for Step 2
#'
#' @inheritParams tunevt
#' @inheritParams get_mnpp
#' @param theta tree complexity parameter (\code{cp})
#'
#' @importFrom rpart rpart
#' @importFrom stats predict coef
#'
#' @family VT Step 2 functions
#'
#' @return an object of class \code{rpart}. See
#'   \code{\link[rpart]{rpart.object}}.
#'
vt2_classtree <- function(z, data, Trt, Y, theta, threshold) {

  z_class <- ifelse(z > threshold, 1, 0)

  keep_x_fit <- !(names(data) %in% c(Y, Trt))
  mod <- rpart::rpart(
    z_class ~ ., data = subset(data, select = keep_x_fit),
    method = "class",
    cp = theta
  )
  return(mod)
}


#' Estimate the CATE using a conditional inference tree for Step 2
#'
#' @inheritParams tunevt
#' @inheritParams get_mnpp
#' @param theta the value of the test statistic that must be exceeded in order
#'   to implement a split (\code{mincriterion})
#'
#' @importFrom party ctree ctree_control
#' @importFrom stats predict coef
#'
#' @family VT Step 2 functions
#'
#' @return An object of class \code{BinaryTree-class}. See
#'   \code{\link[party]{BinaryTree-class}}.
#'
vt2_ctree <- function(z, data, Trt, Y, theta) {

  data$z <- z
  keep_x_fit <- !(names(data) %in% c(Y, Trt))
  mod <- party::ctree(
    z ~ ., data = subset(data, select = keep_x_fit),
    controls = party::ctree_control(
      mincriterion = theta,
      testtype = "Teststatistic")
  )

  return(mod)

}


#' Get the appropriate Step 2 estimation function associated with a method
#'
#' @inheritParams tunevt
#'
#' @return a function that fits a model for the CATE through Step 2 of Virtual
#'   Twins
#'
get_vt2 <- function(step2) {
  if (step2 == "lasso") {
    f <- vt2_lasso
  } else if (step2 == "rtree") {
    f <- vt2_rtree
  } else if (step2 == "ctree") {
    f <- vt2_ctree
  } else if (step2 == "classtree") {
    f <- vt2_classtree
  } else {
    stop("Invalid argument to 'step2'. Accepts 'lasso', 'rtree', 'classtree', and 'ctree'.")
  }
  return(f)
}

Try the tehtuner package in your browser

Any scripts or data that you put into this service are public.

tehtuner documentation built on April 3, 2023, 5:16 p.m.