# knitr chunk option: show each chunk's source code in the rendered output.
knitr::opts_chunk$set(echo = TRUE)
#out1 <- out
# NOTE(review): `out` is not created anywhere in this file; presumably it is a
# results table that already exists in the session -- confirm.
# Copy the row names into a regular column so rows can be selected by pattern.
 out$name <- rownames(out)
library(stringr)
library(ggplot2)
# Auto-print the table for inspection.
out
# Patterns identifying each learner; get_data() hands them to grep().
# NOTE(review): this variable shadows the name of base::names().
names <- c("glm-", "glm.inter", "gam_2", "gam_4", "xgboost_3", "xgboost_4", "hal9001")
# One get_data() result per pattern (IPW and plugin rows), stacked row-wise.
new_out <- do.call(rbind, lapply(names, get_data, out = out, type = c("IPW", "plugin") ))


# NOTE(review): get_data() and plot_sim() are defined further down in this
# file, so they must already exist in the session when these lines run --
# confirm the intended execution order.
print(plot_sim(new_out, column = "risks"))
#' Collect the rows of a results table that belong to one learner
#'
#' Selects the rows of `out` whose `name` matches both the learner pattern and
#' the requested risk-estimator tag(s), and annotates them for plotting.
#'
#' @param name Learner pattern (length one). It is passed to `grep()` as a
#'   regular expression (so `.` matches any character) and is also used
#'   verbatim as the estimator label.
#' @param out Data frame with a character column `name`. The sieve degree is
#'   parsed from the first `k=<digits>` occurrence in each name.
#' @param type Risk estimators to keep: `"plugin"`, `"IPW"` or both. Plugin
#'   rows always come first in the result, whatever the order given here.
#' @return The selected rows of `out` with four extra columns: `degree`
#'   (numeric, `NA` when no `k=<digits>` is present), `est`
#'   (`"<name> (<Type>)"`), `Type` (`"Plugin"` or `"IPW"`) and `Estimator`
#'   (`name`). The attribute `"len"` holds the row positions in `out` of the
#'   last type processed (IPW when requested, otherwise plugin). A
#'   learner/type pair without any matching row contributes zero rows and
#'   raises a warning instead of an error.
get_data <- function(name, out, type = c("plugin", "IPW")) {
  # Tag searched for in `out$name` -> label shown in the plot legend.
  type_labels <- c(plugin = "Plugin", IPW = "IPW")

  unknown <- setdiff(type, names(type_labels))
  if (length(type) == 0L || length(unknown) > 0L) {
    stop(
      "`type` must be \"plugin\", \"IPW\" or both; got: ",
      paste(type, collapse = ", "),
      call. = FALSE
    )
  }

  # Sieve degree encoded as "k=<digits>" in the row name; NA when absent.
  degree_of <- function(x) {
    x <- as.character(x)
    hits <- regmatches(x, regexec("k=([0-9]+)", x))
    vapply(
      hits,
      function(m) if (length(m) >= 2L) as.numeric(m[[2L]]) else NA_real_,
      numeric(1),
      USE.NAMES = FALSE
    )
  }

  name_hits <- grep(name, out$name, value = FALSE)

  # Rows for one estimator tag, annotated with the plotting columns.
  rows_for <- function(tag) {
    idx <- intersect(name_hits, grep(tag, out$name, value = FALSE))
    if (length(idx) == 0L) {
      warning(
        sprintf("No rows of `out` match learner '%s' with tag '%s'.", name, tag),
        call. = FALSE
      )
    }
    # drop = FALSE keeps a data frame even when `out` has a single column.
    sub <- out[idx, , drop = FALSE]
    n_rows <- nrow(sub)
    label <- type_labels[[tag]]
    sub$degree <- degree_of(sub$name)
    # rep() keeps these assignments valid when no row matched.
    sub$est <- rep(sprintf("%s (%s)", name, label), n_rows)
    sub$Type <- rep(label, n_rows)
    sub$Estimator <- rep(name, n_rows)
    list(index = idx, data = sub)
  }

  # Fixed processing order: plugin first, then IPW.
  requested <- intersect(names(type_labels), type)
  pieces <- lapply(requested, rows_for)

  dat <- if (length(pieces) == 1L) {
    pieces[[1L]]$data
  } else {
    rbind(pieces[[1L]]$data, pieces[[2L]]$data)
  }

  # plot_sim() reads the length of this attribute to size the x-axis breaks.
  attr(dat, "len") <- pieces[[length(pieces)]]$index
  dat
}





#' Plot a risk column against sieve size, one coloured line per estimator
#'
#' @param dat Data frame with columns `degree`, `Estimator`, `Type` and the
#'   column named by `column`. If it carries a `"len"` attribute (as set by
#'   `get_data()`), the x-axis breaks run from 1 to the length of that
#'   attribute plus one; otherwise they run from 1 to the largest `degree`.
#' @param column Name of the column of `dat` plotted on the y axis.
#' @param baseline_risk Divisor applied to `column` when `scale` is `TRUE`.
#' @param scale If `TRUE`, plot `column / baseline_risk` instead of `column`.
#' @return A ggplot object.
plot_sim <- function(dat, column = "risk", baseline_risk = 0, scale = FALSE) {
  if (!is.character(column) || length(column) != 1L ||
      !column %in% names(dat)) {
    stop("`column` must name a single column of `dat`.", call. = FALSE)
  }

  if (scale) {
    if (isTRUE(baseline_risk == 0)) {
      warning(
        "`baseline_risk` is 0; scaled values will not be finite.",
        call. = FALSE
      )
    }
    dat[[column]] <- dat[[column]] / baseline_risk
  }

  sieve_index <- attr(dat, "len")
  max_break <- if (is.null(sieve_index)) {
    # No "len" attribute: fall back to the observed degrees.
    max(dat$degree, 1, na.rm = TRUE)
  } else {
    length(sieve_index) + 1L
  }

  # A horizontal reference line at `baseline_risk` was disabled in the
  # original code and stays disabled here.
  g <- ggplot(dat, aes(x = degree, y = .data[[column]], color = Estimator)) +
    geom_point() +
    geom_line(aes(linetype = Type)) +
    scale_x_continuous(breaks = seq(1, max_break, 1)) +
    scale_y_continuous(n.breaks = 10) +
    scale_linetype(name = "Risk Function") +
    xlab("Size of Sieve") +
    ylab("Risk")

  # theme_bw() is a complete theme: it must come before the element tweaks,
  # otherwise it resets every size set below.
  g +
    theme_bw() +
    theme(
      axis.text.x = element_text(size = 13),
      axis.text.y = element_text(size = 17),
      axis.title.x = element_text(size = 13),
      axis.title.y = element_text(size = 18),
      legend.text = element_text(size = 10),
      legend.title = element_text(size = 10)
    )
}


# Larsvanderlaan/npcausalML documentation built on July 30, 2023, 4:32 p.m.