## Following the guide of the SL3 package
## Requires:
## - tidyverse
## - sl3
## - R.utils
## - OnlineSuperLearner
## - devtools
## Load the OSL package from the current directory (".") through devtools
devtools::load_all(".")
## Don't ask for input while plotting
devAskNewPage(ask = FALSE)
######################
## Helper functions ##
######################
## Build one formula per dependent variable.
##
## For every name in `dependent` a formula `dep ~ x1 + x2 + ...` is created.
## Its right-hand side holds all entries of `independent` except `dep`
## itself, so a variable is never used as its own predictor.
##
## @param dependent character vector with the names of the variables to
##   build a formula for (may be NULL / empty, which yields an empty list).
## @param independent character vector with the names of the candidate
##   predictors.
## @return a list of formula objects, named after `dependent`.
generate_fomula <- function(dependent, independent) {
  form <- lapply(dependent, function(dep) {
    ## Drop the dependent variable from its own set of predictors
    predictors <- independent[independent != dep]
    ## Without any predictor left, fall back to an intercept-only formula
    ## instead of handing a string without `~` to formula()
    if (length(predictors) == 0) predictors <- "1"
    formula(paste(dep, paste(predictors, collapse = " + "), sep = " ~ "))
  })
  names(form) <- dependent
  form
}
## Generate the formulae for a complete set of W, A and Y variables
## (presumably covariates, treatment and outcome -- confirm with the
## package documentation).
##
## @param W character vector of variable names; every W is regressed on all
##   other entries of W.
## @param A character vector of variable names (may be NULL / empty); every
##   A is regressed on W.
## @param Y character vector of variable names; every Y is regressed on W
##   and A.
## @return a list with the entries `W`, `A` and `Y`, each a list of formulas
##   named after the corresponding variables.
generate_formulae <- function(W, A, Y) {
  list(
    W = generate_fomula(W, W),
    ## The A formulae were previously built from (W, W), which duplicated
    ## the W formulae and never used A as a dependent variable
    A = generate_fomula(A, W),
    Y = generate_fomula(Y, c(W, A))
  )
}
######################
## Fix the seed of the random number generator
set.seed(49753)

## Install the required packages when they are missing. requireNamespace()
## checks a single package directly instead of scanning the complete
## installed.packages() table.
if (!requireNamespace("devtools", quietly = TRUE)) {
  install.packages("devtools")
}
## sl3 is distributed through GitHub (tlverse/sl3) rather than CRAN
if (!requireNamespace("sl3", quietly = TRUE)) {
  devtools::install_github("tlverse/sl3")
}
library(sl3)
library(tidyverse)

## Load the cpp data
data(cpp)
## Keep the rows with an observed haz and replace every remaining NA by 0.
## A plain function is passed to mutate_all(), as funs() is deprecated.
cpp <- cpp %>%
  dplyr::filter(!is.na(haz)) %>%
  mutate_all(function(column) replace(column, is.na(column), 0))
## Names of the variables in the cpp data: W holds the covariates, A is left
## empty and Y is the variable that is modelled below
Y <- "haz"
A <- NULL
W <- c(
  "apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs", "sexn"
)

## Generate the corresponding formulae
formulae <- generate_formulae(W, A, Y)

## Only haz is registered as a relevant variable. The covariates can be
## registered in the same way by adding entries of the following form for
## apgar1, apgar5, parity, gagebrth, meducyrs and sexn:
##   RelevantVariable$new(formula = formulae$W$apgar1, family = 'gaussian')
relevantVariables <- list(
  RelevantVariable$new(formula = formulae[["Y"]][["haz"]], family = 'gaussian')
)
## The algorithm specifications handed to the OnlineSuperLearner fit, written
## as a single list literal with one entry per algorithm
algos <- list(
  list(
    algorithm = "ML.XGBoost",
    algorithm_params = list(alpha = c(0.3, 0.4, 0.5)),
    params = list(nbins = c(30, 40, 50), online = TRUE)
  ),
  list(
    algorithm = "condensier::speedglmR6",
    params = list(nbins = c(5, 10, 15), online = FALSE)
  )
)
## General settings: the number of iterations, the number of available rows
## and the verbose object that is handed to the fit
nb_iter <- 10
training_set_size <- nrow(cpp)
log <- R.utils::Arguments$getVerbose(-1, timestamp = TRUE)

## Fit the online SuperLearner. Half of the rows form the initial data set;
## that half divided by the number of iterations is the mini batch size.
## NOTE(review): both sizes can be fractional when nrow(cpp) is not a
## multiple of 2 * nb_iter -- confirm that the fit accepts such values.
osl <- OnlineSuperLearner::fit.OnlineSuperLearner(
  formulae = relevantVariables,
  algorithms = algos,
  data = cpp,
  initial_data_size = training_set_size / 2,
  mini_batch_size = (training_set_size / 2) / nb_iter,
  max_iterations = nb_iter,
  test_set_size = 5 + (3 * 3 + 3),
  verbose = log
)
## The leftover browser() debugging call was removed here: it halted every
## interactive run before any result was produced.

## With the fitted OSL we can now draw samples for the relevant variables
preds <- sampledata(osl, newdata = cpp, relevantVariables, plot = TRUE)
print(preds$osl.estimator)
print(preds$dosl.estimator)

## ... and perform predictions on the same data
preds <- predict(osl, newdata = cpp, relevantVariables, plot = TRUE)
print(preds$osl.estimator)
print(preds$dosl.estimator)
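## NOTE: the two lines below are leftover web-page text, not R code; they are
## kept as comments so that the script can be parsed.
## Add the following code to your website.
## For more information on customizing the embed code, read Embedding Snippets.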