# Show the source code of every chunk in the knitted document.
knitr::opts_chunk$set(echo = TRUE)
library(future)
# Select the `multicore` execution plan from the future package.
plan(multicore)

# Sample sizes to simulate; this run uses a single size.
ns <- c(1500)

# Run the CATE simulation once per sample size and tag each result with the
# sample size that produced it.
output <- lapply(ns, function(n) {
  out <- estCATE(
    n,
    # Candidate learners handed to estCATE.
    list(
      Lrnr_glm$new(),
      Lrnr_rpart$new(),
      Lrnr_gam$new(),
      Lrnr_earth$new(),
      Lrnr_xgboost$new(max_depth = 5, verbosity = 0)
    ),
    # NOTE(review): presumably the number of simulation replicates -- confirm
    # against the definition of estCATE.
    20,
    sim.CATE,
    list_of_sieves_uni,
    # Spelled out as FALSE: `F` is an ordinary variable that can be reassigned.
    positivity = FALSE,
    hard = FALSE
  )
  out$n <- n
  out
})
# Stack the per-sample-size results into a single table.
output <- do.call(rbind, output)
# NOTE(review): the file name below says posTRUEhardTRUE, but this run uses
# positivity = FALSE and hard = FALSE; correct the name before re-enabling.
#write.csv(output,"CATEsims_posTRUEhardTRUE.csv")
library(ggplot2)
# Work on a data.table version of the simulation results.
dt <- as.data.table(output) #fread("CATEsims_posFALSEhardFALSE_first.csv")

# Replace the long learner names with short labels for the facet titles.
dt$lrnr <- gsub(".+glm.+", "lm", dt$lrnr)
dt$lrnr <- gsub(".+earth.+", "earth", dt$lrnr)
dt$lrnr <- gsub(".+rpart.+", "regTree", dt$lrnr)
dt$lrnr <- gsub(".+ranger.+", "randomForest", dt$lrnr)
# Everything up to and including "xgboost_20_1_" becomes "xgboost_depth=";
# the remainder of the name is presumably the max_depth value -- confirm.
dt$lrnr <- gsub("(.+xgboost_20_1_)", "xgboost_depth=", dt$lrnr)

# Reshape to long format: one row per (n, learner, method) holding that
# method's value, with exact duplicate rows removed.
dt_long <- unique(rbind(
  data.table(n = dt$n, MSE = dt$cateplugin_adaptive, lrnr = dt$lrnr,
             method = "SievePluginAdapt"),
  # NOTE(review): this reuses `cateplugin_adaptive`, so the series is identical
  # to "SievePluginAdapt". Confirm whether a separate TMLE column was intended.
  data.table(n = dt$n, MSE = dt$cateplugin_adaptive, lrnr = dt$lrnr,
             method = "SievePluginAdapttmle"),
  data.table(n = dt$n, MSE = dt$cateplugin_oracle, lrnr = dt$lrnr,
             method = "SievePluginOracle"),
  data.table(n = dt$n, MSE = dt$cateonestep, lrnr = dt$lrnr,
             method = "OneStep"),
  data.table(n = dt$n, MSE = dt$cateonestep_oracle, lrnr = dt$lrnr,
             method = "OneStepOracle"),
  # Causal forest values are all filed under the "randomForest" label.
  data.table(n = dt$n, MSE = dt$causalforest, lrnr = "randomForest",
             method = "causalForest")
))

# Repeat the `subst` values under every learner label so the
# "SubstitutionEnsemble" series shows up in each facet of the plot.
tmp <- do.call(rbind, lapply(unique(dt$lrnr), function(lrnr) {
  data.table(n = dt$n, MSE = dt$subst, lrnr = lrnr,
             method = "SubstitutionEnsemble")
}))
dt_long <- rbind(dt_long, tmp)
dt_long <- unique(dt_long)
dt_long <- dt_long[order(dt_long$lrnr)]
# Keep only sample sizes of at least 250.
dt_long <- dt_long[dt_long$n >= 250]
# From here on the `MSE` column holds square roots of the original values; the
# column name (and therefore the y-axis label) is left unchanged.
dt_long$MSE <- sqrt(dt_long$MSE)
# Print the long table.
dt_long
# One panel per learner, one line per method, log-scaled y axis. `scales` is
# written in full rather than relying on partial matching of `scale`.
ggplot(dt_long, aes(x = n, y = MSE, group = method, color = method)) +
  geom_line() +
  scale_y_log10() +
  facet_wrap(. ~ lrnr, scales = "free")
# Read a saved set of simulation results from disk.
dt <- fread("CATEsims_posFALSEhardFALSE_first.csv")

# Keep the sample-size column as read, before any rows are dropped.
ns <- dt$n

library(ggplot2)
# Exclude the rows with n equal to 100.
dt <- dt[n != 100]
# Line plot of `cateplugin` against n: one panel per learner, two per row.
ggplot(dt, aes(x = n)) +
  geom_line(aes(y = cateplugin)) +
  facet_wrap(~ lrnr, ncol = 2)
library(future)
plan(multicore)
# Grid of sample sizes for this simulation run.
ns <- c(100, 250, 500, 750, 1000, 5000, 10000)

# Simulate at every sample size, label each result with its sample size, and
# row-bind the pieces into one table.
output <- do.call(rbind, lapply(ns, function(sample_size) {
  res <- estCATE(
    sample_size, lrnr_stack, 100, sim.CATE, list_of_sieves_uni,
    positivity = TRUE, hard = TRUE
  )
  res$n <- sample_size
  res
}))
# Save the combined results as CSV.
write.csv(output, "CATEsims_posTRUEhardTRUE.csv")


# Larsvanderlaan/npcausalML documentation built on July 30, 2023, 4:32 p.m.