# Global knitr chunk options: echo the source of every chunk and render figures with the "ragg_png" device.
knitr::opts_chunk$set(echo = TRUE, dev = "ragg_png")

Notes: - There is a "suffix" for output files so that one can test without overwriting existing output.

library(abind)
library(gsubfn)
library(devtools)
library(CausalGrid)
library(causalTree)
library(doSNOW)
library(foreach)
library(xtable)
library(stringr)
library(glmnet)
library(ranger)
library(gridExtra)
library(ggplot2)
library(rprojroot)
# Locate the package root directory; the project paths below are built from it.
root_dir <- rprojroot::find_package_root_file()
# Run the project's helper scripts (their contents are not shown in this file).
source(paste0(root_dir,"/project/ct_utils.R"))
source(paste0(root_dir,"/tests/dgps.R"))
export_dir = paste0(root_dir,"/project/") #can be turned off below
# NOTE(review): unlike export_dir, the next two paths are not prefixed with root_dir,
# so they resolve against the current working directory — confirm this is intended.
sim_rdata_fname = "project/sim.RData"
log_file = "project/log.txt"
tbl_export_path = paste0(export_dir, "tables/")
fig_export_path = paste0(export_dir, "figs/")
# Inserted into every exported table/figure file name (before the extension).
suffix = "-temp"
#Sim parameters
# Number of parallel workers: the "cl.cores" option when it is set, otherwise 20.
n_parallel = getOption("cl.cores", default=20) #1 5 20; PARALLEL
# Number of replications simulated for each (design, sample size) pair.
S = 1000 #3 1000, n_parallel TODO: Is this just 1 in CT!?
# num.trees=100 # 100 below

# Sample sizes; each replication draws 2*N rows for fitting and N_test rows for evaluation.
Ns = c(500, 1000) #c(500, 1000)
N_test = 8000
# One logical mask per design; the mask length gives that design's feature count (see Ks).
# Splits on features where the mask is TRUE are the ones counted as "good" further below.
good_features = list(c(T, F), 
                     c(T, T, rep(F, 8)), 
                     c(rep(T, 4), rep(F, 16))
                     #,c(T,T, F)
                     )
D = length(good_features)
Ks = sapply(good_features, length)
#Estimation config
# b is passed as bucketNum to the causal tree fit and, doubled, as bucket_min_n to the grid fit.
b = 4
nfolds = 5
# minsize is passed to the causal tree fit; the grid fit uses 2*minsize as min_size.
minsize=25

NN = length(Ns)
NIters = NN*D*S


#Execution config
my_seed = 1337
set.seed(my_seed)
rf.num.threads = 1 #NULL will multi-thread, doesn't seem to help much with small data

# Estimator labels; their order defines the row order of the result tables.
est_names_ct = c("Causal Tree (CT)", "Causal Tree - Adaptive (CT-A)", "CT-A - Matching CT size (CT-A:M)"
                 #,"Causal Tree (CT CV2)", "Causal Tree - Adaptive (CT-A CV2)"
                 )
est_names_cg = c("Causal Grid (CG)", "CG + Linear Controls (CG-X)", "CG + RF Controls (CG-RF)", 
                 "CG - Matching CT size (CG:M)", "CG-X - Matching CT size (CG-X:M)", "CG-RF - Matching CT size (CG-RF:M)"
              )
# Rows of est_names exported to the main tables (CT, CG, CG-X, CG-RF) ...
small_table_idx = c(1,4,5,6)
# ... and to the "_app" tables (the three CG variants matched to the CT size).
app_table_idx = c(7,8,9)
est_names = c(est_names_ct, est_names_cg)

Helper functions

# Draw one simulated data set with `n` observations.
# Designs 1-3 are delegated to AI_sim(); any larger design number uses XOR2_sim().
sim_data <- function(n=500, design=1) {
  if (design > 3) {
    return(XOR2_sim(n))
  }
  AI_sim(n, design)
}

# Flatten (design d, sample-size index N_i, replication s) into a single 1-based
# index, with s varying fastest, then N_i, then d. Reads the globals NN and S.
get_iter <- function(d, N_i, s) {
  block_offset <- (d - 1) * NN + (N_i - 1)
  block_offset * S + s
}

# Flatten (design d, sample-size index N_i) into a single 1-based index,
# with N_i varying fastest. Reads the global NN.
get_dgp_idx <- function(d, N_i) {
  design_offset <- (d - 1) * NN
  design_offset + N_i
}

# Combine an outcome vector and a feature matrix into one data frame.
#
# y: outcome values, one per row of X.
# X: feature matrix (may have zero columns).
# Returns a data.frame whose first column is "Y", followed by "X1", "X2", ...
yX_data <- function(y, X) {
  yX <- cbind(y, X)
  # sprintf() over seq_len() yields character(0) for a zero-column X, whereas
  # 1:ncol(X) would produce c(1, 0) and a name vector of the wrong length.
  colnames(yX) <- c("Y", sprintf("X%d", seq_len(ncol(X))))
  as.data.frame(yX)
}

# Fit a causal tree through ct_cv_tree() using the simulation's global settings
# (minsize, b, my_seed).
#
# y, X, w:    outcome, features and treatment indicator.
# tr_sample:  passed to ct_cv_tree() as index_tr.
# honest:     used for both split.Honest and cv.Honest.
# cv_folds:   passed as xval (defaults to the global nfolds).
# complexity: forwarded unchanged to ct_cv_tree().
# Returns the fitted object with the environment attribute of its terms removed.
sim_ct_fit <- function(y, X, w, tr_sample, honest=FALSE, cv_folds=nfolds, complexity="CV") {
  yX <- yX_data(y, X)
  # Re-seed before every fit so each fit starts from the same RNG state.
  set.seed(my_seed)
  fit <- ct_cv_tree(
    "Y~.",
    data = yX,
    treatment = w,
    index_tr = tr_sample,
    minsize = minsize,
    split.Bucket = TRUE,
    bucketNum = b,
    xval = cv_folds,
    split.Honest = honest,
    cv.Honest = honest,
    complexity = complexity
  )
  # Drop the captured environment; otherwise it is kept alongside every stored fit.
  attr(fit$terms, ".Environment") <- NULL
  fit
}

# Predict from a fitted tree on test features.
# Columns are renamed X1, X2, ... before calling predict() with type "vector".
sim_ct_predict_te <- function(obj, X_te) {
  feature_names <- paste0("X", 1:ncol(X_te))
  colnames(X_te) <- feature_names
  te_frame <- as.data.frame(X_te)
  predict(obj, newdata = te_frame, type = "vector")
}

# Summarise one causal-tree fit as a length-4 numeric vector:
#   (number of cells, test MSE against tau_te,
#    number of splits on dimensions selected by `mask`, total number of splits).
sim_ct_vectors <- function(ct_fit, mask, X_te, tau_te) {
  n_cells <- num_cells(ct_fit)
  splits_per_dim <- ct_nsplits_by_dim(ct_fit, ncol(X_te))
  te_hat <- sim_ct_predict_te(ct_fit, X_te)
  c(
    n_cells,
    mean((tau_te - te_hat)^2),
    sum(splits_per_dim[mask]),
    sum(splits_per_dim)
  )
}

# Fit a Causal Grid partition through fit_estimate_partition() using the
# simulation's global settings (nfolds, minsize, b, my_seed).
#
# do_rf:                   when TRUE, ctrl_method is set to grid_rf(num.threads, num.trees).
# doCV, incl_comp_in_pick: packed into the bump_complexity list.
# ...:                     forwarded to fit_estimate_partition().
sim_cg_fit <- function(y, X, w, tr_sample, verbosity=0, do_rf=FALSE, num.threads=rf.num.threads, num.trees=100, doCV=FALSE, incl_comp_in_pick=FALSE, ...) {
  # Re-seed before every fit so each fit starts from the same RNG state.
  set.seed(my_seed)
  if (do_rf) {
    return(fit_estimate_partition(
      y, X,
      d = w, tr_split = tr_sample, cv_folds = nfolds, verbosity = verbosity,
      min_size = 2 * minsize, max_splits = 10,
      bucket_min_d_var = TRUE, bucket_min_n = 2 * b,
      ctrl_method = grid_rf(num.threads = num.threads, num.trees = num.trees),
      nsplits_k_warn_limit = NA,
      bump_complexity = list(doCV = doCV, incl_comp_in_pick = incl_comp_in_pick),
      ...
    ))
  }
  # Without do_rf, no ctrl_method is set here; a caller may supply one through `...`.
  fit_estimate_partition(
    y, X,
    d = w, tr_split = tr_sample, cv_folds = nfolds, verbosity = verbosity,
    min_size = 2 * minsize, max_splits = 10,
    bucket_min_d_var = TRUE, bucket_min_n = 2 * b,
    nsplits_k_warn_limit = NA,
    bump_complexity = list(doCV = doCV, incl_comp_in_pick = incl_comp_in_pick),
    ...
  )
}

# Remove the four stored forest objects from a fit's est_plan so the fit takes
# less space when kept. Returns the modified fit.
save_space_cg_rf <- function(grid_rf_fit) {
  grid_rf_fit$est_plan$rf_d_xfit <- NULL
  grid_rf_fit$est_plan$rf_d_fit <- NULL
  grid_rf_fit$est_plan$rf_y_xfit <- NULL
  grid_rf_fit$est_plan$rf_y_fit <- NULL
  grid_rf_fit
}

# Summarise one Causal Grid fit as a length-4 numeric vector:
#   (number of cells, test MSE against tau_te,
#    number of splits on dimensions selected by `mask`, total number of splits).
sim_cg_vectors <- function(grid_fit, mask, X_te, tau_te) {
  splits_per_dim <- grid_fit$partition$nsplits_by_dim
  te_hat <- predict(grid_fit, new_X = X_te)
  test_mse <- mean((te_hat - tau_te)^2)
  c(
    num_cells(grid_fit),
    test_mse,
    sum(splits_per_dim[mask]),
    sum(splits_per_dim)
  )
}

# Number of feature dimensions (out of K) on which the tree has at least one split.
ct_ndim_splits <- function(obj, K) {
  splits_per_dim <- ct_nsplits_by_dim(obj, K)
  sum(splits_per_dim > 0)
}

# Number of feature dimensions on which the grid partition has at least one split.
cg_ndim_splits <- function(obj) {
  splits_per_dim <- obj$partition$nsplits_by_dim
  sum(splits_per_dim > 0)
}

Generate Data

# Simulate all data up front, one entry per (design, sample size, replication),
# stored at the flat index returned by get_iter().
#   data1: sample of 2*N rows for each replication.
#   data2: sample of N_test rows for each replication.
# Both draws for a given replication are made back to back, so the stored data
# depend on the RNG state set earlier and on this exact loop order.
# NOTE: the loop variables (d, N_i, N, s, iter) stay defined in the global
# environment after the loops finish, holding their last values.
data1 = list()
data2 = list()
for(d in 1:D) {
  for(N_i in 1:NN) {
    N = Ns[N_i]
    for(s in 1:S){
      iter = get_iter(d, N_i, s)
      data1[[iter]] = sim_data(n=2*N, design=d)
      data2[[iter]] = sim_data(n=N_test, design=d)
    }
  }
}
# Pick the foreach operator once: parallel (%dopar%) when more than one worker
# is requested, sequential (%do%) otherwise.
`%doChange%` = if(n_parallel>1) `%dopar%` else `%do%` #so I can just change variable without changing other code for single-threaded
if(n_parallel>1) {
  # Any old log file is removed; note that the active makeCluster() call below
  # does not pass outfile, so no new log is written unless the commented line is used.
  if(file.exists(log_file)) file.remove(log_file)
  #cl <- makeCluster(n_parallel, outfile=log_file)
  cl <- makeCluster(n_parallel)
  registerDoSNOW(cl)
}

Eval CT

# Evaluate the Causal Tree estimators for every (design, sample size, replication).
# Result arrays are indexed [dgp index, estimator, replication]:
#   nl_*    number of cells, mse_* test MSE,
#   ngood_* splits on "good" features, ntot_* total splits.
outer_results_ct = list()
n_ct_group = length(est_names_ct)
nl_ct_group = array(dim=c(D*NN, n_ct_group, S))
mse_ct_group = array(dim=c(D*NN, n_ct_group, S))
ngood_ct_group = array(dim=c(D*NN, n_ct_group, S))
ntot_ct_group = array(dim=c(D*NN, n_ct_group, S))
# In parallel mode the bar advances once per (design, sample size); sequentially once per fit.
bar_length = if(n_parallel>1) NN*D else S*NN*D
pb = utils::txtProgressBar(0, bar_length, style = 1)
run=1

for(d in 1:D) {
  for(N_i in 1:NN) {
    dgp_i = get_dgp_idx(d, N_i)
    if(n_parallel>1) {
      utils::setTxtProgressBar(pb, run)
      run = run+1
    }
    #results_s = list() # To debug foreach
    #for(s in 1:S){  # To debug foreach
    # d1/d2 hand each task only its own data sets; data1/data2 themselves are not exported.
    results_s = foreach(s=1:S,
                        d1=lapply(1:S, function(s1) data1[[get_iter(d, N_i, s1)]]),
                        d2=lapply(1:S, function(s2) data2[[get_iter(d, N_i, s2)]]),
                        .packages=c("gsubfn","causalTree"),
                        .errorhandling = "pass",
                        .noexport=c("data1", "data2"),
                        .inorder=FALSE,
                        .export=c("num_cells","num_cells.rpart")) %doChange% {  #for some reason we have to export some functions manually !?
      if(n_parallel==1) {
        utils::setTxtProgressBar(pb, run)
        run = run+1
      }
      iter = get_iter(d, N_i, s)
      #d1 = data1[[iter]] # To debug foreach
      #d2 = data2[[iter]] # To debug foreach
      list[y, X, w, tau] = d1
      list[y_te, X_te, w_te, tau_te] = d2
      # Build the split indicator from the data itself. The previous version used the
      # global N left over from the data-generation loop (always the last entry of Ns),
      # which produced an indicator shorter than nrow(X) for the larger sample size.
      n_obs = nrow(X)
      n_tr = round(n_obs/2)
      tr_sample = c(rep(TRUE, n_tr), rep(FALSE, n_obs-n_tr))

      ct_h5 = sim_ct_fit(y, X, w, tr_sample, honest=T, cv_folds=5)
      nl_ct_h5 = num_cells(ct_h5)
      ct_a5 = sim_ct_fit(y, X, w, tr_sample, honest=F, cv_folds=5)
      # Adaptive tree forced to the same number of cells as the honest tree.
      ct_a5_m = sim_ct_fit(y, X, w, tr_sample, honest=F, cv_folds=5, complexity=nl_ct_h5)
      #ct_h2 = sim_ct_fit(y, X, w, tr_sample, honest=T, cv_folds=2)
      #ct_a2 = sim_ct_fit(y, X, w, tr_sample, honest=F, cv_folds=2)
      fits = list(ct_h5, ct_a5, ct_a5_m) #,ct_h2, ct_a2 
      # s is returned with the fits because .inorder=FALSE does not guarantee result order.
      res = list(s, fits)
      #results_s[[s]] = res #P to debug foreach
      res
    }

    #recombine data
    for(res_s in seq_len(S)){
      res = results_s[[res_s]]
      # With .errorhandling="pass" a failed task comes back as a condition object;
      # report it directly instead of unpacking it as if it were a result.
      if(inherits(res, "error")) {
        stop("Causal tree fit failed for d=", d, ", N_i=", N_i, ": ", conditionMessage(res))
      }
      list[s, fits] = res
      iter = get_iter(d, N_i, s)
      outer_results_ct[[iter]] = res
      list[y_te, X_te, w_te, tau_te] = data2[[iter]]

      for(f_i in seq_along(fits)) {
        fit = fits[[f_i]]
        list[nl_f, mse_f, ngood_f, ntot_f] = sim_ct_vectors(fit, good_features[[d]], X_te, tau_te)
        nl_ct_group[dgp_i, f_i, s] = nl_f
        mse_ct_group[dgp_i, f_i, s] = mse_f
        ngood_ct_group[dgp_i, f_i, s] = ngood_f
        ntot_ct_group[dgp_i, f_i, s] = ntot_f
      }
    }
  }
}
close(pb)
# # use just the CV5 ones
# nl_ct_group_orig = nl_ct_group
# mse_ct_group_orig = mse_ct_group
# ngood_ct_group_orig = ngood_ct_group
# ntot_ct_group_orig = ntot_ct_group
# nl_ct_group = nl_ct_group_orig[,3:4,]
# mse_ct_group = mse_ct_group_orig[,3:4,]
# ngood_ct_group = ngood_ct_group_orig[,3:4,]
# ntot_ct_group = ntot_ct_group_orig[,3:4,]
# est_names_ct = c("Causal Tree (CT)", "Causal Tree - Adaptive (CT-A)")
# est_names = c(est_names_ct, est_names_cg)
# n_ct_group = length(est_names_ct)
# Average the CT test MSE over replications (third array dimension) and print it;
# rows are (design, sample size) combinations, columns are CT estimators.
mse_ct_group_mean = apply(mse_ct_group, c(1,2), mean)
mse_ct_group_mean

Eval CG

# Evaluate the Causal Grid estimators for every (design, sample size, replication).
# Result arrays are indexed [dgp index, estimator, replication], mirroring the CT arrays.
t1 = Sys.time()
cat(paste("Start time: ",t1,"\n"))
# In parallel mode the bar advances once per (design, sample size); sequentially once per fit.
bar_length = if(n_parallel>1) NN*D else S*NN*D
pb = utils::txtProgressBar(0, bar_length, style = 1)
run=1
outer_results_cg = list()
n_cg_group = length(est_names_cg)
nl_cg_group = array(dim=c(D*NN, n_cg_group, S))
mse_cg_group = array(dim=c(D*NN, n_cg_group, S))
ngood_cg_group = array(dim=c(D*NN, n_cg_group, S))
ntot_cg_group = array(dim=c(D*NN, n_cg_group, S))

for(d in 1:D) {
  for(N_i in 1:NN) {
    dgp_i = get_dgp_idx(d, N_i)
    if(n_parallel>1) {
      utils::setTxtProgressBar(pb, run)
      run = run+1
    }
    #results_s = list() # To debug foreach
    #for(s in 1:S){  # To debug foreach
    results_s = foreach(s=1:S,
                        d1=lapply(1:S, function(s1) data1[[get_iter(d, N_i, s1)]]),
                        d2=lapply(1:S, function(s2) data2[[get_iter(d, N_i, s2)]]),
                        .packages=c("gsubfn","ranger", "CausalGrid"),
                        .errorhandling = "pass",
                        .inorder=FALSE) %doChange% {
      if(n_parallel==1) {
        utils::setTxtProgressBar(pb, run)
        run = run+1
      }
      N = Ns[N_i]
      iter = get_iter(d, N_i, s)
      #d1 = data1[[iter]] # To debug foreach
      #d2 = data2[[iter]] # To debug foreach
      list[y, X, w, tau] = d1
      list[y_te, X_te, w_te, tau_te] = d2
      # First N rows are flagged TRUE, the remaining N rows FALSE (data1 holds 2*N rows).
      tr_sample = c(rep(TRUE, N), rep(FALSE, N))

      grid_a_fit <- sim_cg_fit(y, X, w, tr_sample)
      grid_a_LassoCV_fit <- sim_cg_fit(y, X, w, tr_sample, ctrl_method="LassoCV")
      grid_a_RF_fit <- sim_cg_fit(y, X, w, tr_sample, do_rf=TRUE)

      # "Matching" variants: move each grid fit to the complexity whose
      # (complexity + 1) is closest to the honest causal tree's number of cells.
      ct_h_nl_s = nl_ct_group[dgp_i, 1, s]
      grid_a_hm_fit <- change_complexity(grid_a_fit, y, X, d=w, which.min(abs(ct_h_nl_s - (grid_a_fit$complexity_seq + 1))))
      grid_a_LassoCV_hm_fit <- change_complexity(grid_a_LassoCV_fit, y, X, d=w, which.min(abs(ct_h_nl_s - (grid_a_LassoCV_fit$complexity_seq + 1))))
      grid_a_RF_hm_fit <- change_complexity(grid_a_RF_fit, y, X, d=w, which.min(abs(ct_h_nl_s - (grid_a_RF_fit$complexity_seq + 1))))

      # Strip the stored forests only after the matched fit has been derived from them.
      grid_a_RF_fit = save_space_cg_rf(grid_a_RF_fit)

      fits_base = list(grid_a_fit, grid_a_LassoCV_fit, grid_a_RF_fit)
      fits_match = list(grid_a_hm_fit, grid_a_LassoCV_hm_fit, grid_a_RF_hm_fit)
      # One row per estimator, in est_names_cg order: (cells, MSE, good splits, total splits).
      res_mat = matrix(nrow=0, ncol=4)
      for(fit in c(fits_base, fits_match)) {
        res_mat = rbind(res_mat, sim_cg_vectors(fit, good_features[[d]], X_te, tau_te))
      }

      # s is returned with the results because .inorder=FALSE does not guarantee result order.
      res = list(s, fits_base, res_mat)
      #results_s[[s]] = res #P to debug foreach
      res
    }

    #recombine data
    for(res_s in seq_len(S)) {
      res = results_s[[res_s]]
      # With .errorhandling="pass" a failed task comes back as a condition object;
      # report it directly instead of unpacking it as if it were a result.
      if(inherits(res, "error")) {
        stop("Causal grid fit failed for d=", d, ", N_i=", N_i, ": ", conditionMessage(res))
      }
      list[s, fits_base, res_mat] = res
      iter = get_iter(d, N_i, s)
      for(r in seq_len(nrow(res_mat))) {
        nl_cg_group[dgp_i, r, s] = res_mat[r,1]
        mse_cg_group[dgp_i, r, s] = res_mat[r,2]
        ngood_cg_group[dgp_i, r, s] = res_mat[r,3]
        ntot_cg_group[dgp_i, r, s] = res_mat[r,4]
      }
    }
    # Raw results are stored per (design, sample size), in the order foreach returned them.
    outer_results_cg[[dgp_i]] = results_s
  }
}

t2 = Sys.time() #can use as.numeric(t1) to convert to seconds
td = t2-t1
close(pb)
cat(paste("Total time: ",format(as.numeric(td))," ", attr(td,"units"),"\n"))

#if(sum(n_errors)>0){
#  cat("N errors")
#  print(n_errors)
#}
# Shut down the worker cluster created during parallel setup (only exists when n_parallel > 1).
if(n_parallel>1) stopCluster(cl)

Collect results

# Join the CT and CG result arrays along the estimator dimension (dimension 2),
# CT estimators first, so the estimator order matches est_names.
nl_group = abind(nl_ct_group, nl_cg_group, along=2)
mse_group = abind(mse_ct_group, mse_cg_group, along=2)
ngood_group = abind(ngood_ct_group, ngood_cg_group, along=2)
ntot_group = abind(ntot_ct_group, ntot_cg_group, along=2)

Potentially save raw results to file

# Manual toggle (disabled): write the simulated data and raw results to sim_rdata_fname.
if(F){
  save(S, Ns, D, data1, data2, 
       nl_group, mse_group, ngood_group, ntot_group,
       outer_results_ct, outer_results_cg,
       file = sim_rdata_fname)
}
# Manual toggle (disabled): restore a previously saved run instead of re-simulating.
if(F){
  load(sim_rdata_fname)
}

Output Results

# Share of splits that fall on "good" features. Where the ratio is NA
# (0/0, i.e. no splits at all) it is set to 1.
ratio_good_group = ngood_group/ntot_group
ratio_good_group[is.na(ratio_good_group)] = 1

# Average every statistic over replications, leaving [dgp index, estimator] matrices.
nl_group_mean = apply(nl_group, c(1,2), mean)
mse_group_mean = apply(mse_group, c(1,2), mean)
ngood_group_mean = apply(ngood_group, c(1,2), mean)
ntot_group_mean = apply(ntot_group, c(1,2), mean)
pct_good_group_mean = apply(ratio_good_group, c(1,2), mean)
# Turn a [dgp index, estimator] matrix of means into a display table with one
# row per estimator and one column per (design, sample size).
#
# res_array: matrix with D*length(Ns) rows and length(est_names) columns.
# Returns the transposed matrix with row names taken from est_names and column
# names "N=<size>", repeated once per design.
compose_table <- function(res_array) {
  res_array <- t(res_array)
  # Derive the labels from Ns so they stay correct when the sample sizes change.
  size_labels <- paste0("N=", format(Ns, scientific = FALSE, trim = TRUE))
  colnames(res_array) <- rep(size_labels, D)
  rownames(res_array) <- est_names
  res_array
}
# Build and print the summary tables (rows: estimators, columns: design x sample size).
n_cells_comp <- compose_table(nl_group_mean)
mse_comp <- compose_table(mse_group_mean)
ratio_good_comp <- compose_table(pct_good_group_mean)
cat("n_cells_comp\n")
print(n_cells_comp)
cat("mse_comp\n")
print(mse_comp)
cat("ratio_good_comp\n")
print(ratio_good_comp)

# Relative MSE reduction versus the honest Causal Tree (row 1 of est_names).
# Rows 5 and 6 are the CG-X and CG-RF estimators, so the messages name them
# as such (they previously said "CT-X" / "CT-RF").
benefit = (mse_comp[1,] - mse_comp[5,])/mse_comp[1,]
cat(paste("benefit for CG-X=[", min(benefit), ",", max(benefit), "]\n"))
print(benefit)
benefit = (mse_comp[1,] - mse_comp[6,])/mse_comp[1,]
cat(paste("benefit for CG-RF=[", min(benefit), ",", max(benefit), "]\n"))
print(benefit)

Write out tex tables

# Write an xtable to `o_fname` as LaTeX, with an extra header row of
# "Design 1", "Design 2", ... cells inserted near the top of the table.
#
# xtbl:    an xtable object (one label column plus the data columns).
# o_fname: path of the .tex file to write.
fmt_table <- function(xtbl, o_fname) {
  # capture.output() keeps print() from echoing the table; the LaTeX text itself
  # is taken from print()'s return value.
  capt_ret <- capture.output(file_cont <- print(xtbl, floating=F, comment = F))
  #cat(file_cont, file=paste0(o_fname,"_2"))
  # The text is cut at a fixed character offset (28+2*D) and the design header row is
  # spliced in at that point.
  # NOTE(review): the offset presumably corresponds to the end of the first "\hline"
  # line after "\begin{tabular}{r...}" for a table with 1+2*D columns — confirm
  # against the xtable output format; it will be wrong if the column count changes.
  # NOTE(review): "\multicolumn{2}" assumes two sample sizes per design.
  file_cont = paste0(str_sub(file_cont, end=28+2*D), paste0(paste0("&\\multicolumn{2}{c}{Design ", 1:D, "}", collapse=""),"\\\\ \n"), str_sub(file_cont, start=29+2*D))
  cat(file_cont, file=o_fname)
}

# Manual toggle (enabled): export the summary tables as .tex files under tbl_export_path.
# The first three files hold the rows in small_table_idx, the "_app" files hold
# the rows in app_table_idx; `suffix` is inserted before the extension.
if(T){ #Output tables
  fmt_table(xtable(n_cells_comp[small_table_idx,], digits=2), paste0(tbl_export_path, "n_cells",suffix,".tex"))
  fmt_table(xtable(mse_comp[small_table_idx,], digits=3), paste0(tbl_export_path, "mse",suffix,".tex"))
  fmt_table(xtable(ratio_good_comp[small_table_idx,]), paste0(tbl_export_path, "ratio_good",suffix,".tex"))
  fmt_table(xtable(n_cells_comp[app_table_idx,], digits=2), paste0(tbl_export_path, "n_cells_app",suffix,".tex"))
  fmt_table(xtable(mse_comp[app_table_idx,], digits=3), paste0(tbl_export_path, "mse_app",suffix,".tex"))
  fmt_table(xtable(ratio_good_comp[app_table_idx,]), paste0(tbl_export_path, "ratio_good_app",suffix,".tex"))
}

Output examples

# Plot the 2D partition of a stored causal tree for replication (d, N_i, s).
# `honest` selects between the honest and the adaptive stored fit.
# NOTE(review): ct_h_fit_models / ct_a_fit_models are not defined in the visible
# part of this file — confirm where these stores are created.
ct_sim_plot <- function(d, N_i, s, honest=TRUE) {
  iter <- get_iter(d, N_i, s)
  list[y, X, w, tau] <- data1[[iter]]
  X_range <- get_X_range(X)
  fit_store <- if (honest) ct_h_fit_models else ct_a_fit_models
  plt <- plot_2D_partition.rpart(fit_store[[iter]], X_range = X_range)
  plt + ggtitle("Causal Tree") + labs(fill = "tau(X)")
}
# Plot the 2D partition of a stored Causal Grid fit for replication (d, N_i, s),
# after moving it to the complexity closest to the causal tree's number of cells.
# `lasso` selects the LassoCV-controlled fit instead of the plain one.
# NOTE(review): cg_a_fit_models, cg_a_LassoCV_fit_models and ct_h_nl are not
# defined in the visible part of this file — confirm where they are created.
cg_sim_plot <- function(d, N_i, s, lasso=FALSE) {
  iter = get_iter(d, N_i, s)
  list[y, X, w, tau] = data1[[iter]]
  if(lasso)
    grid_fit = cg_a_LassoCV_fit_models[[iter]]
  else
    grid_fit = cg_a_fit_models[[iter]]
  grid_a_m_fit <- change_complexity(grid_fit, y, X, d=w, which.min(abs(ct_h_nl[iter] - (grid_fit$complexity_seq + 1))))
  plt = plot_2D_partition.estimated_partition(grid_a_m_fit, c("X1", "X2"))
  # Title fixed: this is the grid plot, it was previously labelled "Causal Tree".
  return(plt + ggtitle("Causal Grid") + labs(fill = "tau(X)"))
}
# Draw the causal tree and causal grid partitions for one replication side by side.
ct_cg_plot <- function(d, N_i, s, ct_honest=TRUE, cg_lasso=FALSE) {
  left_panel <- ct_sim_plot(d, N_i, s, honest = ct_honest) + ggtitle("Causal Tree")
  right_panel <- cg_sim_plot(d, N_i, s, lasso = cg_lasso) + ggtitle("Causal Grid")
  grid.arrange(left_panel, right_panel, ncol = 2)
}
# Manual toggle (disabled): find a replication worth showing as an example.
# For every replication, count on how many dimensions each model splits.
# NOTE(review): ct_h_fit_models, ct_a_fit_models, cg_a_fit_models,
# cg_a_LassoCV_fit_models, results_ct_h and results_cg_a_LassoCV_m are not
# defined in the visible part of this file — confirm before enabling.
if(FALSE) {
  #Pick which one to show. 
  n_dim_splits = matrix(0,nrow=D*NN*S, ncol=6)
  for(d in 1:D) {
    k = Ks[d]
    for(N_i in 1:NN) {
      # Recompute the dgp index here; previously the stale value left over from
      # the evaluation loops was used for every (d, N_i).
      dgp_i = get_dgp_idx(d, N_i)
      for(s in 1:S) {
        iter = get_iter(d, N_i, s)
        N = Ns[N_i]
        list[y, X, w, tau] = data1[[iter]]
        #list[y_te, X_te, w_te, tau_te] = data2[[iter]]
        tr_sample = c(rep(TRUE, N), rep(FALSE, N))

        grid_a_fit <- cg_a_fit_models[[iter]]
        grid_a_LassoCV_fit <- cg_a_LassoCV_fit_models[[iter]]
        ct_h_nl_s = nl_ct_group[dgp_i, 1, s]
        grid_a_m_fit <- change_complexity(grid_a_fit, y, X, d=w, which.min(abs(ct_h_nl_s - (grid_a_fit$complexity_seq + 1))))
        grid_a_LassoCV_m_fit <- change_complexity(grid_a_LassoCV_fit, y, X, d=w, which.min(abs(ct_h_nl_s - (grid_a_LassoCV_fit$complexity_seq + 1))))

        # Columns: CT honest, CT adaptive, CG, CG-Lasso, CG matched, CG-Lasso matched.
        n_dim_splits[iter, 1] = ct_ndim_splits(ct_h_fit_models[[iter]], k)
        n_dim_splits[iter, 2] = ct_ndim_splits(ct_a_fit_models[[iter]], k)
        n_dim_splits[iter, 3] = cg_ndim_splits(grid_a_fit)
        n_dim_splits[iter, 4] = cg_ndim_splits(grid_a_LassoCV_fit)
        n_dim_splits[iter, 5] = cg_ndim_splits(grid_a_m_fit)
        n_dim_splits[iter, 6] = cg_ndim_splits(grid_a_LassoCV_m_fit)
      }
    }
  }
  #const_mask = results_ct_h[,1]==2 & results_ct_h[,4]>=4 & n_dim_splits[,1]==2 & n_dim_splits[,5]==2 #cg: normal
  #grph_pick = cbind(results_ct_h[const_mask,c(1,2,3, 4)],results_cg_a_m[const_mask,c(4)])
  const_mask = results_ct_h[,1]==2 & results_ct_h[,4]>=4 & n_dim_splits[,1]==2 & n_dim_splits[,6]==2 #cg: Lasso
  grph_pick = cbind(results_ct_h[const_mask,c(1,2,3, 4)],results_cg_a_LassoCV_m[const_mask,c(4)])
  grph_pick
  ct_cg_plot(2, 1, 57, ct_honest=TRUE, cg_lasso=TRUE)
}
# Build a table describing a grid fit: the partition description from
# get_desc_df() followed by the per-cell "param_est" column, relabelled "Est.".
# The result is returned invisibly.
cg_table <- function(obj, digits=3){
  est_col <- obj$cell_stats$stats["param_est"]
  colnames(est_col) <- "Est."
  invisible(cbind(
    get_desc_df(obj$partition, do_str = TRUE, drop_unsplit = TRUE, digits = digits),
    est_col
  ))
}

# Print a plot object to a PDF file.
#
# plt_obj:  object whose print() method draws the plot.
# filename: path of the PDF to create.
# Returns the value of dev.off(), as before.
ggplot_to_pdf <-function(plt_obj, filename) {
  pdf(file=filename)
  pdf_dev <- dev.cur()
  closed <- FALSE
  # Make sure the device opened above is closed even when print() fails;
  # otherwise it stays open and later plots are written to it.
  on.exit(if (!closed) dev.off(pdf_dev), add = TRUE)
  print(plt_obj)
  closed <- TRUE
  dev.off(pdf_dev)
}
# Manual toggle (disabled): export one worked example (tables and figures) for
# a single replication of design 2.
# NOTE(review): ct_h_fit_models, cg_a_LassoCV_fit_models and ct_h_nl are not
# defined in the visible part of this file — confirm before enabling.
if(F){
  d=2 #so we can have a 2d model
  N_i=1
  s=57
  iter = get_iter(d, N_i, s)
  list[y, X, w, tau] = data1[[iter]]
  X_range = get_X_range(X)

  # Causal tree: text description to .tex, partition plot to screen.
  ct_m = ct_h_fit_models[[iter]]
  ct_m_desc = ct_desc(ct_m)
  cat(ct_m_desc, file=paste0(tbl_export_path, "ct_ex_2d",suffix,".tex"), sep="\n")
  ct_plt = plot_2D_partition.rpart(ct_m, X_range=X_range) + labs(fill = "tau(X)")
  print(ct_plt + ggtitle("Causal Tree"))

  # Causal grid: move the fit to the complexity closest to the tree's cell count,
  # then export its table and plot.
  grid_fit = cg_a_LassoCV_fit_models[[iter]] #cg_a_fit_models[[iter]]
  cg_m <- change_complexity(grid_fit, y, X, d=w, which.min(abs(ct_h_nl[iter] - (grid_fit$complexity_seq + 1))))
  print(cg_m)
  cg_m_tbl = cg_table(cg_m)
  # capture.output() suppresses the echo; the LaTeX text is print()'s return value.
  temp <- capture.output(cg_tbl_file_cont <- print(xtable(cg_m_tbl, digits=3), floating=F, comment = F))
  cat(cg_tbl_file_cont, file=paste0(tbl_export_path, "cg_ex_2d",suffix,".tex"))
  cg_plt = plot_2D_partition.estimated_partition(cg_m, c("X1", "X2")) + labs(fill = "tau(X)")
  print(cg_plt + ggtitle("Causal Grid"))

  # Individual figures.
  ggplot_to_pdf(ct_plt + ggtitle("Causal Tree"), paste0(fig_export_path, "ct_ex_2d",suffix,".pdf"))
  ggplot_to_pdf(cg_plt + ggtitle("Causal Grid"), paste0(fig_export_path, "cg_ex_2d",suffix,".pdf"))

  # Combined side-by-side figure.
  pdf(file=paste0(fig_export_path, "cg_ct_ex_2d",suffix,".pdf"), width=8, height=4)
  grid.arrange(ct_plt + ggtitle("Causal Tree"), cg_plt + ggtitle("Causal Grid"), ncol=2)
  dev.off()
}


microsoft/CausalGrid documentation built on Aug. 25, 2022, 9:30 a.m.