# Set the knitr chunk option `echo` to TRUE as the default for the chunks
# that follow (chunk source code is shown in the rendered document).
knitr::opts_chunk$set(echo = TRUE)
# Simulation run: call estCATE() once per sample size in `ns`, tag each
# result with its sample size, and stack the results row-wise in `output`.
library(future)
# NOTE(review): plan() only affects code that dispatches through futures; the
# lapply() below is sequential. Presumably estCATE() or the learners use
# futures internally -- confirm.
plan(multicore)

# Sample sizes to simulate (a single size in this run).
ns <- c(1500)

output <- lapply(ns, function(n) {
  # Candidate learners handed to estCATE().
  learners <- list(
    Lrnr_glm$new(),
    Lrnr_rpart$new(),
    Lrnr_gam$new(),
    Lrnr_earth$new(),
    Lrnr_xgboost$new(max_depth = 5, verbosity = 0)
  )
  # The third positional argument (20) is passed through unchanged;
  # presumably a replication count -- confirm against estCATE().
  out <- estCATE(
    n, learners, 20, sim.CATE, list_of_sieves_uni,
    positivity = FALSE, hard = FALSE
  )
  out$n <- n
  out
})
output <- do.call(rbind, output)

# Saving is disabled for this run.
# NOTE(review): the file name below says posTRUE/hardTRUE while this run uses
# positivity = FALSE, hard = FALSE -- rename before re-enabling.
# write.csv(output, "CATEsims_posTRUEhardTRUE.csv")
# Plot the error of each CATE estimation method against sample size,
# one facet per learner, using the in-memory simulation `output`.
library(ggplot2)

dt <- as.data.table(output)
# Alternative input: fread("CATEsims_posFALSEhardFALSE_first.csv")

# Replace learner names matching each pattern with a short label.
dt$lrnr <- gsub(".+glm.+", "lm", dt$lrnr)
dt$lrnr <- gsub(".+earth.+", "earth", dt$lrnr)
dt$lrnr <- gsub(".+rpart.+", "regTree", dt$lrnr)
dt$lrnr <- gsub(".+ranger.+", "randomForest", dt$lrnr)
dt$lrnr <- gsub("(.+xgboost_20_1_)", "xgboost_depth=", dt$lrnr)

# Long format: one row per (n, learner, method) with that method's error.
dt_long <- unique(rbind(
  data.table(
    n = dt$n, MSE = dt$cateplugin_adaptive, lrnr = dt$lrnr,
    method = "SievePluginAdapt"
  ),
  # NOTE(review): this reads the same column as "SievePluginAdapt" above, so
  # the two curves coincide. Looks like a copy-paste slip -- confirm which
  # column holds the TMLE variant.
  data.table(
    n = dt$n, MSE = dt$cateplugin_adaptive, lrnr = dt$lrnr,
    method = "SievePluginAdapttmle"
  ),
  data.table(
    n = dt$n, MSE = dt$cateplugin_oracle, lrnr = dt$lrnr,
    method = "SievePluginOracle"
  ),
  data.table(
    n = dt$n, MSE = dt$cateonestep, lrnr = dt$lrnr,
    method = "OneStep"
  ),
  data.table(
    n = dt$n, MSE = dt$cateonestep_oracle, lrnr = dt$lrnr,
    method = "OneStepOracle"
  ),
  # The causal forest error is always filed under the "randomForest" label.
  data.table(
    n = dt$n, MSE = dt$causalforest, lrnr = "randomForest",
    method = "causalForest"
  )
))

# The `subst` column is repeated under every learner label so that the
# substitution ensemble appears in each facet.
tmp <- do.call(rbind, lapply(unique(dt$lrnr), function(lrnr) {
  data.table(
    n = dt$n, MSE = dt$subst, lrnr = lrnr,
    method = "SubstitutionEnsemble"
  )
}))
dt_long <- rbind(dt_long, tmp)
dt_long <- unique(dt_long)
dt_long <- dt_long[order(dt_long$lrnr)]

# Keep sample sizes of at least 250.
dt_long <- dt_long[dt_long$n >= 250]

# From here on the `MSE` column holds the square root of the MSE.
dt_long$MSE <- sqrt(dt_long$MSE)
dt_long

# `scales` is spelled out in full; the original `scale =` only worked through
# partial argument matching.
ggplot(dt_long, aes(x = n, y = MSE, group = method, color = method)) +
  geom_line() +
  scale_y_log10() +
  facet_wrap(. ~ lrnr, scales = "free")
# Plot the `cateplugin` column against sample size from a saved results
# file, one facet per learner.
library(ggplot2)

dt <- fread("CATEsims_posFALSEhardFALSE_first.csv")

# Record the sample sizes present in the file before any rows are dropped.
ns <- dt$n

# Exclude the n = 100 rows.
dt <- dt[n != 100]

ggplot(dt, aes(x = n, y = cateplugin)) +
  geom_line() +
  facet_wrap(. ~ lrnr, ncol = 2)
# Simulation run over a grid of sample sizes with positivity = TRUE and
# hard = TRUE; the stacked results are written to CSV.
library(future)
plan(multicore)

ns <- c(100, 250, 500, 750, 1000, 5000, 10000)

# One estCATE() call per sample size, each result tagged with its sample
# size and then bound together row-wise.
output <- do.call(
  rbind,
  lapply(ns, function(sample_size) {
    res <- estCATE(
      sample_size, lrnr_stack, 100, sim.CATE, list_of_sieves_uni,
      positivity = TRUE, hard = TRUE
    )
    res$n <- sample_size
    res
  })
)

write.csv(output, "CATEsims_posTRUEhardTRUE.csv")
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.