# Document rendering: echo = TRUE is passed to knitr's chunk options so that
# chunk source code is shown alongside its output.
library(knitr)
knitr::opts_chunk$set(echo = TRUE)
# Fix the RNG seed so the simulated data sets below are reproducible.
set.seed(7194568)

# Plotting, string helpers, and data-manipulation verbs used by the timing plots.
library(ggplot2)
library(stringr)
library(tidyverse)

# microbenchmark drives every timing comparison in this document.
# NOTE(review): future is attached but not referenced in the visible code --
# confirm whether it is still needed.
library(future)
library(microbenchmark)

# The estimators being compared.
library(SuperLearner)
library(hal9001)

Introduction

This document presents some simple benchmarks comparing the implementation choices available in hal9001, in particular the glmnet and lassi backends for the LASSO computation. The purpose of this document is two-fold:

  1. Compare the computational performance of these methods
  2. Illustrate the use of these different methods
# Mean squared error between a set of predictions and the observed outcomes.
#   preds: numeric predictions
#   y:     numeric observed outcomes, conformable with preds
# Returns a single number: the average squared difference.
mse <- function(preds, y) {
  squared_error <- (preds - y)^2
  mean(squared_error)
}

The Continuous Outcome Case

Standard use-cases with simple data structures

# Training sample: n observations of p independent standard-normal covariates.
n <- 100
p <- 3
xmat <- matrix(rnorm(n * p), nrow = n, ncol = p)
x <- xmat
# The outcome depends on the first two covariates only, plus Gaussian noise.
y <- sin(x[, 1]) * sin(x[, 2]) + rnorm(n, mean = 0, sd = 0.2)

# Large independent test sample drawn from the same data-generating process.
test_n <- 10000
test_x <- matrix(rnorm(test_n * p), nrow = test_n, ncol = p)
test_y <- sin(test_x[, 1]) * sin(test_x[, 2]) +
  rnorm(test_n, mean = 0, sd = 0.2)
# HAL fit using the glmnet backend; show the timings stored on the fit object.
glmnet_hal_fit <- fit_hal(
  Y = y,
  X = x,
  fit_type = "glmnet"
)
glmnet_hal_fit$times

# In-sample predictions and their mean squared error.
preds_glmnet <- predict(glmnet_hal_fit, new_data = x)
glmnet_hal_mse <- mse(preds = preds_glmnet, y = y)

# Predictions on the independent test sample (not used for fitting) and
# their mean squared error.
oob_preds_glmnet <- predict(glmnet_hal_fit, new_data = test_x)
oob_glmnet_hal_mse <- mse(preds = oob_preds_glmnet, y = test_y)
# HAL fit using the lassi backend; show the timings stored on the fit object.
lassi_hal_fit <- fit_hal(
  Y = y,
  X = x,
  fit_type = "lassi"
)
lassi_hal_fit$times

# In-sample predictions and their mean squared error.
preds_lassi <- predict(lassi_hal_fit, new_data = x)
lassi_hal_mse <- mse(preds = preds_lassi, y = y)

# Predictions on the independent test sample (not used for fitting) and
# their mean squared error.
oob_preds_lassi <- predict(lassi_hal_fit, new_data = test_x)
oob_lassi_hal_mse <- mse(preds = oob_preds_lassi, y = test_y)
# Time 20 repeated HAL fits per backend on the simulated training data;
# unit = "s" makes summary() report the timings in seconds.
# The plotting code that follows labels each timing by inspecting the text of
# the benchmarked expression, so keep the backend name visible in that
# expression (and keep the assigned names as they are) when editing it.
mb_hal_lassi <- microbenchmark(unit = "s", times = 20,
  hal_fit_lassi <- fit_hal(Y = y, X = x, fit_type = "lassi", yolo = FALSE)
)
summary(mb_hal_lassi)

# Same benchmark with the glmnet backend.
mb_hal_glmnet <- microbenchmark(unit = "s", times = 20,
  hal_fit_glmnet <- fit_hal(Y = y, X = x, fit_type = "glmnet", yolo = FALSE)
)
summary(mb_hal_glmnet)

# Visualize the per-run fit times of the two backends as box plots.
# Each timing is labelled by searching the text of the benchmarked expression
# for "glmnet" -- the same rule the other timing plots in this document use --
# instead of relying on fixed character positions, which silently mislabel
# every run as soon as the assigned variable name changes length.
p_fit_times <- rbind(mb_hal_lassi, mb_hal_glmnet) %>%
  data.frame() %>%
  transmute(
    type = ifelse(str_detect(as.character(expr), "glmnet"),
                  "HAL-glmnet", "HAL-lassi"),
    time = time / 1e9  # microbenchmark records nanoseconds; plot seconds
  ) %>%
  ggplot(aes(x = type, y = time, colour = type)) +
  geom_boxplot() +
  geom_point() +
  stat_boxplot(geom = "errorbar") +
  # wesanderson is never attached in this document, so qualify the call;
  # current wesanderson releases name this palette "GrandBudapest1".
  scale_color_manual(values = wesanderson::wes_palette("GrandBudapest1")) +
  xlab("") +
  ylab("time (sec.)") +
  ggtitle("") +
  theme_bw() +
  theme(legend.position = "none")
p_fit_times

Advanced use-cases, with the drtmle R package

# Simulate a two-time-point longitudinal data set.
#
#   n: number of observations to draw. It is a required argument: the former
#      default `n = n` referred to itself, so omitting `n` always failed with
#      a "recursive default argument reference" error.
#
# Returns a list with
#   L0, L1: one-column data frames of covariates (x.0 and x.1),
#   A0, A1: 0/1 treatment vectors drawn with rbinom(),
#   L2:     numeric outcome vector.
# The draws are made in a fixed order (L0, A0, L1, A1, L2), so results are
# reproducible for a given seed.
makeData <- function(n) {
  L0 <- data.frame(x.0 = runif(n, -1, 1))
  A0 <- rbinom(n, 1, plogis(L0$x.0^2))
  L1 <- data.frame(x.1 = L0$x.0^2 * A0 + runif(n))
  A1 <- rbinom(n, 1, plogis(L0$x.0 * L1$x.1))
  L2 <- rnorm(n, L0$x.0^2 * A0 * A1 + L1$x.1)
  list(L0 = L0, L1 = L1, L2 = L2, A0 = A0, A1 = A1)
}
# Simulate one data set of 200 observations for the benchmarks that follow.
dat <- makeData(n = 200)

# add drtmle defaults
# abar holds the treatment values that dat$A0 and dat$A1 are compared against;
# stratify controls whether the outcome-regression inputs further below are
# restricted to observations whose treatments match abar.
abar <- c(1, 1)
stratify <- TRUE
# Benchmark (20 runs, reported in seconds) a HAL fit of the indicator
# "A0 equals abar[1]" on the baseline covariate L0, using the lassi backend.
# The plotting code labels timings by looking for "glmnet" in the text of the
# benchmarked expression, so keep the backend name visible in it.
mb_hal_lassi_estimateG <- microbenchmark(unit = "s", times = 20,
  hal_lassi_estimateG <- fit_hal(Y = as.numeric(dat$A0 == abar[1]),
                                 X = dat$L0,
                                 fit_type = "lassi",
                                 yolo = FALSE)
)
summary(mb_hal_lassi_estimateG)

# Same benchmark with the glmnet backend.
mb_hal_glmnet_estimateG <- microbenchmark(unit = "s", times = 20,
  hal_glmnet_estimateG <- fit_hal(Y = as.numeric(dat$A0 == abar[1]),
                                  X = dat$L0,
                                  fit_type = "glmnet",
                                  yolo = FALSE)
)
summary(mb_hal_glmnet_estimateG)

# Visualize the per-run times of the treatment-model (estimateG) fits.
# Timings are labelled by searching the text of the benchmarked expression
# for "glmnet"; anything else is taken to be the lassi backend.
p_estimateG_times <- rbind(mb_hal_lassi_estimateG,
                           mb_hal_glmnet_estimateG) %>%
  data.frame() %>%
  transmute(
    type = ifelse(str_detect(as.character(expr), "glmnet"),
                  "HAL-glmnet", "HAL-lassi"),
    time = time / 1e9  # microbenchmark records nanoseconds; plot seconds
  ) %>%
  ggplot(aes(x = type, y = time, colour = type)) +
  geom_boxplot() +
  geom_point() +
  stat_boxplot(geom = "errorbar") +
  # wesanderson is never attached in this document, so qualify the call;
  # current wesanderson releases name this palette "GrandBudapest1".
  scale_color_manual(values = wesanderson::wes_palette("GrandBudapest1")) +
  xlab("") +
  ylab("time (sec.)") +
  ggtitle("") +
  theme_bw() +
  theme(legend.position = "none")
p_estimateG_times
# Outcome and design matrix for the outcome-regression (estimateQ) benchmarks.
# With stratify = TRUE only observations whose treatments match abar are kept
# and the covariates are L0 and L1; otherwise every row is used and the two
# treatment vectors are appended as extra columns.
#
# A plain if/else is used on purpose. ifelse() returns a result shaped like
# its test, and the test here has length one, so the previous
# ifelse(stratify, <data frame>, <data frame>) kept only the first column
# (x.0) and silently dropped x.1 from the design.
in_stratum <- dat$A0 == abar[1] & dat$A1 == abar[2]
Y_estQ <- if (stratify) dat$L2[in_stratum] else dat$L2
X_estQ <- if (stratify) {
  # drop = FALSE keeps a two-dimensional object even for a single row
  as.matrix(cbind(dat$L0, dat$L1)[in_stratum, , drop = FALSE])
} else {
  as.matrix(cbind(dat$L0, dat$L1, A0 = dat$A0, A1 = dat$A1))
}

# Benchmark (20 runs, reported in seconds) a HAL fit of Y_estQ on X_estQ with
# the lassi backend. The plotting code labels timings by looking for "glmnet"
# in the text of the benchmarked expression, so keep the backend name visible.
mb_hal_lassi_estimateQ <- microbenchmark(unit = "s", times = 20,
  hal_lassi_estimateQ <- fit_hal(Y = Y_estQ,
                                 X = X_estQ,
                                 fit_type = "lassi",
                                 yolo = FALSE)
)
summary(mb_hal_lassi_estimateQ)

# Same benchmark with the glmnet backend.
mb_hal_glmnet_estimateQ <- microbenchmark(unit = "s", times = 20,
  hal_glmnet_estimateQ <- fit_hal(Y = Y_estQ,
                                  X = X_estQ,
                                  fit_type = "glmnet",
                                  yolo = FALSE)
)
summary(mb_hal_glmnet_estimateQ)

# Visualize the per-run times of the outcome-regression (estimateQ) fits.
# Timings are labelled by searching the text of the benchmarked expression
# for "glmnet"; anything else is taken to be the lassi backend.
p_estimateQ_times <- rbind(mb_hal_lassi_estimateQ,
                           mb_hal_glmnet_estimateQ) %>%
  data.frame() %>%
  transmute(
    type = ifelse(str_detect(as.character(expr), "glmnet"),
                  "HAL-glmnet", "HAL-lassi"),
    time = time / 1e9  # microbenchmark records nanoseconds; plot seconds
  ) %>%
  ggplot(aes(x = type, y = time, colour = type)) +
  geom_boxplot() +
  geom_point() +
  stat_boxplot(geom = "errorbar") +
  # wesanderson is never attached in this document, so qualify the call;
  # current wesanderson releases name this palette "GrandBudapest1".
  scale_color_manual(values = wesanderson::wes_palette("GrandBudapest1")) +
  xlab("") +
  ylab("time (sec.)") +
  ggtitle("") +
  theme_bw() +
  theme(legend.position = "none")
p_estimateQ_times

Using hal9001 with SuperLearner

The hal9001 R package includes a learner wrapper for use with the SuperLearner R package. In order to use this learner, the user need only add "SL.hal9001" to their SL.library input. Here, we are interested in the difference in the performance of SuperLearner when the hal9001 wrapper uses the lassi backend versus the glmnet backend for the LASSO computation. To facilitate this, we create two new SL wrappers from the one included:

# SuperLearner wrapper that forwards every argument to SL.hal9001 with the
# lassi backend selected by default.
SL.hal9001.lassi <- function(..., fit_type = "lassi") {
  SL.hal9001(fit_type = fit_type, ...)
}

# SuperLearner wrapper that forwards every argument to SL.hal9001 with the
# glmnet backend selected by default.
SL.hal9001.glmnet <- function(..., fit_type = "glmnet") {
  SL.hal9001(fit_type = fit_type, ...)
}
# Benchmark (20 runs, reported in seconds) a SuperLearner fit of the indicator
# "A0 equals abar[1]" on L0, with a library made of SL.speedglm and the
# lassi-backed HAL wrapper. The plotting code labels timings by looking for
# "glmnet" in the text of the benchmarked expression, so keep the wrapper
# name visible in it.
# NOTE(review): no family argument is passed although the outcome is binary,
# so SuperLearner's default family applies -- confirm this is intended.
mb_hal_lassi_sl_estimateG <- microbenchmark(unit = "s", times = 20,
  hal_lassi_sl_estimateG <- SuperLearner(Y = as.numeric(dat$A0 == abar[1]),
                                         X = dat$L0,
                                         SL.library = c("SL.speedglm",
                                                        "SL.hal9001.lassi")
                                        )
)
summary(mb_hal_lassi_sl_estimateG)

# Same benchmark with the glmnet-backed HAL wrapper in the library.
mb_hal_glmnet_sl_estimateG <- microbenchmark(unit = "s", times = 20,
  hal_glmnet_sl_estimateG <- SuperLearner(Y = as.numeric(dat$A0 == abar[1]),
                                          X = dat$L0,
                                          SL.library = c("SL.speedglm",
                                                         "SL.hal9001.glmnet")
                                         )
)
summary(mb_hal_glmnet_sl_estimateG)

# Visualize the per-run times of the two SuperLearner fits.
# Timings are labelled by searching the text of the benchmarked expression
# for "glmnet"; anything else is taken to be the lassi backend.
p_sl_estimateG_times <- rbind(mb_hal_lassi_sl_estimateG,
                              mb_hal_glmnet_sl_estimateG) %>%
  data.frame() %>%
  transmute(
    type = ifelse(str_detect(as.character(expr), "glmnet"),
                  "HAL-glmnet", "HAL-lassi"),
    time = time / 1e9  # microbenchmark records nanoseconds; plot seconds
  ) %>%
  ggplot(aes(x = type, y = time, colour = type)) +
  geom_boxplot() +
  geom_point() +
  stat_boxplot(geom = "errorbar") +
  # wesanderson is never attached in this document, so qualify the call;
  # current wesanderson releases name this palette "GrandBudapest1".
  scale_color_manual(values = wesanderson::wes_palette("GrandBudapest1")) +
  xlab("") +
  ylab("time (sec.)") +
  ggtitle("SuperLearner with HAL glmnet v. lassi") +
  theme_bw() +
  theme(legend.position = "none")
p_sl_estimateG_times

The Binary Outcome Case

# Binary-outcome training sample: n observations of p standard-normal
# covariates.
n <- 100
p <- 3
xmat <- matrix(rnorm(n * p), nrow = n, ncol = p)
x <- xmat
# Success probability: inverse-logit of the noisy signal built from the first
# two covariates; the outcome is one Bernoulli draw per observation.
y_p <- plogis(sin(x[, 1]) * sin(x[, 2]) + rnorm(n, mean = 0, sd = 0.2))
y <- rbinom(n, size = 1, prob = y_p)

# Single HAL fit for the binary outcome (glmnet backend, binomial family);
# show the timings stored on the fit object.
glmnet_hal_bin <- fit_hal(
  X = x,
  Y = y,
  family = "binomial",
  fit_type = "glmnet"
)
glmnet_hal_bin$times

# Repeat the same fit 20 times and report the run times in seconds.
mb_hal_bin_glmnet <- microbenchmark(
  hal_bin_glmnet <- fit_hal(
    Y = y, X = x,
    fit_type = "glmnet", family = "binomial", yolo = FALSE
  ),
  times = 20,
  unit = "s"
)
summary(mb_hal_bin_glmnet)


jeremyrcoyle/mangolassi documentation built on Nov. 18, 2023, 6:22 p.m.