library(duvernaygeomechmodel)
set.seed(1)

# Infer the model family from a model directory name.
#
# model_dir: a single string (directory name or path).
# Returns 'rf' if the string contains "rf", otherwise 'glm' if it contains
# "glm", otherwise 'mars' (the fallback for anything else).
get_model_type <- function(model_dir){
  # Checks run in priority order; the first match wins.
  if (str_detect(model_dir, 'rf')) {
    return('rf')
  }
  if (str_detect(model_dir, 'glm')) {
    return('glm')
  }
  'mars'
}

# Infer the prediction target code from a model directory name.
#
# model_dir: a single string (directory name or path).
# Returns 'ym' if the string contains "ym", otherwise 'pr' if it contains
# "pr", otherwise 'br' (the fallback for anything else).
get_model_target <- function(model_dir){
  # Checks run in priority order; the first match wins.
  if (str_detect(model_dir, 'ym')) {
    return('ym')
  }
  if (str_detect(model_dir, 'pr')) {
    return('pr')
  }
  'br'
}

# Read the performance numbers stored in <model_dir>/model_results.txt.
#
# The file is parsed by fixed row position (after empty rows are skipped):
#   rows 2, 4, 6, 8 : single values printed with a leading "[1] " prefix,
#                     read as train bias, train variance, test bias and
#                     test variance respectively
#   rows 17 and 20  : space-separated values whose first two tokens are read
#                     as MAE and RMSE for the train and test sets
#
# model_dir: path of the model directory containing model_results.txt.
# Returns a named list: train_mae, train_rmse, train_bias, train_var,
# test_mae, test_rmse, test_bias, test_var.
read_results <- function(model_dir){
  results <- read_tsv(str_c(model_dir,'/model_results.txt'),
                      skip_empty_rows = TRUE, col_names=FALSE)

  # Value from a row printed as "[1] <value>".
  prefixed_value <- function(row) {
    as.numeric(str_replace(results[row,],"\\[1\\] ",""))
  }

  # Space-separated tokens of a row.
  row_tokens <- function(row) {
    str_split(results[row,], " ")[[1]]
  }

  train_tokens <- row_tokens(17)
  test_tokens <- row_tokens(20)

  list(
    train_mae = as.numeric(train_tokens[1]),
    train_rmse = as.numeric(train_tokens[2]),
    train_bias = prefixed_value(2),
    train_var = prefixed_value(4),
    test_mae = as.numeric(test_tokens[1]),
    test_rmse = as.numeric(test_tokens[2]),
    test_bias = prefixed_value(6),
    test_var = prefixed_value(8)
  )
}

# Iterate through the fitted models and collect their results

# Directories: where summary figures are written, and where each fitted
# model's outputs live (one sub-directory per model).
summary_dir <- '../output/statistical_model_summaries/'
output_dir <- '../output/statistical_models/'

# One sub-directory per model. The "wave" directory holds the wave equation
# results, which are stored differently and are loaded separately.
# basename() is used instead of stripping a hard-coded path prefix, so this
# keeps working if output_dir changes or is written without a trailing slash.
model_dirs <- list.dirs(output_dir, recursive = FALSE)
model_dirs <- model_dirs[basename(model_dirs) != "wave"]

# Collect per-model pieces in preallocated lists and bind once at the end,
# rather than growing the tables inside the loop.
comp_tbl_parts <- vector("list", length(model_dirs))
comp_resamp_parts <- vector("list", length(model_dirs))

for (i in seq_along(model_dirs)) {
  model_dir <- model_dirs[[i]]
  model_name <- basename(model_dir)

  # Model type and target are encoded in the directory name; only the name
  # is inspected so parent directories cannot produce a false match.
  model_type <- get_model_type(model_name)
  target <- get_model_target(model_name)

  # What remains after removing "_<model_type>" and "_<target>" is the
  # feature set label.
  feat_set <- model_name %>%
    str_replace(str_c('_', model_type), "") %>%
    str_replace(str_c('_', target), "")

  res <- read_results(model_dir)

  full_target <- switch(
    target,
    ym = 'Young\'s Modulus',
    pr = 'Poisson\'s Ratio',
    'Brittleness'
  )

  model_label <- str_to_upper(model_type)

  # Two rows per model: train metrics first, then test metrics.
  comp_tbl_parts[[i]] <- tibble(
    model_type = model_label,
    target = full_target,
    feat_set = feat_set,
    set = c('train', 'test'),
    mae = c(res$train_mae, res$test_mae),
    rmse = c(res$train_rmse, res$test_rmse),
    bias = c(res$train_bias, res$test_bias),
    var = c(res$train_var, res$test_var)
  )

  resample_obj <- read_rds(str_c(model_dir,'/resample_results.rds'))$resample_obj

  # `!!` forces the loop variables to be used even if the resampling tables
  # happen to contain columns with the same names.
  comp_resamp_parts[[i]] <- bind_rows(
    resample_obj$measures.train %>% mutate(set = 'train'),
    resample_obj$measures.test %>% mutate(set = 'test')
  ) %>%
    mutate(
      model_type = !!model_label,
      target = !!target,
      feat_set = !!feat_set
    ) %>%
    as_tibble()
}

# Summary metrics (one train row and one test row per model).
comp_tbl <- bind_rows(comp_tbl_parts)
# Per-iteration resampling measures for every model.
comp_resamp <- bind_rows(comp_resamp_parts)

# Convert `set` and `feat_set` to factors with an explicit level order so
# that plots show train before test and the feature sets in a fixed order.
comp_tbl$set <- factor(comp_tbl$set, levels = c("train", "test"))

comp_tbl$feat_set <- factor(
  comp_tbl$feat_set,
  levels = c(
    "vp_rho", "vp_vs_rho",
    "dev_vp_rho", "dev_vp_vs_rho",
    "dev_conf_vp_rho", "dev_conf_vp_vs_rho"
  )
)

# The resampling table uses exactly the same level order as comp_tbl.
comp_resamp$set <- factor(comp_resamp$set, levels = levels(comp_tbl$set))

comp_resamp$feat_set <- factor(
  comp_resamp$feat_set,
  levels = levels(comp_tbl$feat_set)
)

# Get the wave equation results and add them to comp_resamp

# Wave equation resampling results, stored under <output_dir>/wave.
wave_results <- read_rds(file.path(output_dir, "wave", "resample_results.rds"))

# The `correction` column is dropped before the wave equation rows are
# appended to the model resampling results.
comp_resamp <- bind_rows(comp_resamp, wave_results %>% select(-correction))

# Mean of the `correction` column for the Young's modulus ("ym") rows.
# summarise() replaces the superseded summarize_at().
mean_ym_correction <- wave_results %>%
  filter(target == "ym") %>%
  summarise(correction = mean(correction)) %>%
  pull(correction)

# Mean of the `correction` column for the Poisson's ratio ("pr") rows.
mean_pr_correction <- wave_results %>%
  filter(target == "pr") %>%
  summarise(correction = mean(correction)) %>%
  pull(correction)

# Dot plots of model performance: one plot per metric (MAE, RMSE, bias, variance), by model type and feature set, faceted by target

# Draw one summary dot plot and save it as a 7 x 5 PDF in summary_dir.
#
# data:      table with columns model_type, feat_set, set, target and the
#            metric column named by `metric`
# metric:    name of the column plotted on the y axis (string)
# y_label:   y axis label
# file_name: output file name, appended to summary_dir
# Returns the ggplot object, so a top-level call still displays the plot.
#
# The plot is passed to ggsave() explicitly. The previous
# `ggplot(...) + ggsave(...)` form is not supported by current ggplot2
# releases and stops with an error.
save_summary_dot_plot <- function(data, metric, y_label, file_name) {
  p <- ggplot(data) +
    geom_point(
      aes(x = model_type, y = .data[[metric]], fill = feat_set,
          color = feat_set, shape = set),
      size = 3
    ) +
    scale_color_brewer(type = 'qual', palette = 'Paired') +
    scale_fill_brewer(type = 'qual', palette = 'Paired') +
    scale_shape_manual(values = c(21, 24)) +
    facet_wrap(. ~ target, scales = 'free') +
    xlab("") +
    ylab(y_label)

  ggsave(str_c(summary_dir, file_name), plot = p, width = 7, height = 5)
  p
}

save_summary_dot_plot(comp_tbl, "mae", "MAE", 'summary_mae.pdf')

save_summary_dot_plot(comp_tbl, "rmse", "RMSE", 'summary_rmse.pdf')

# Brittleness rows with bias above 0.03 are left out of the bias and variance
# plots.
# NOTE(review): the variance plot is filtered on bias as well, not on
# variance - confirm this is intended.
comp_tbl_bias_filtered <- comp_tbl %>%
  filter(!(target == 'Brittleness' & bias > 0.03))

save_summary_dot_plot(comp_tbl_bias_filtered, "bias", "Bias",
                      'summary_bias.pdf')

save_summary_dot_plot(comp_tbl_bias_filtered, "var", "Variance",
                      'summary_variance.pdf')

# Box plots of the resampling results

# Draw one box plot of a resampling metric for a single target and save it
# as a square PDF in summary_dir.
#
# target_code: value of comp_resamp$target to plot ('ym', 'pr' or 'br')
# metric:      name of the column plotted on the y axis (string)
# y_label:     y axis label
# y_max:       upper y limit of the visible range (the lower limit is 0)
# title:       plot title
# file_name:   output file name, appended to summary_dir
# size:        width and height of the saved PDF
# Returns the ggplot object, so a top-level call still displays the plot.
#
# The plot is passed to ggsave() explicitly. The previous
# `ggplot(...) + ggsave(...)` form is not supported by current ggplot2
# releases and stops with an error.
save_resample_boxplot <- function(target_code, metric, y_label, y_max, title,
                                  file_name, size = 6) {
  p <- ggplot(comp_resamp %>% filter(target == !!target_code)) +
    geom_boxplot(
      aes(y = .data[[metric]], x = model_type, fill = feat_set),
      outlier.size = 0.25
    ) +
    scale_fill_brewer(type = 'qual', palette = 'Paired') +
    facet_wrap(. ~ set, scales = 'fixed') +
    xlab("") +
    # coord_cartesian() limits the visible range without dropping data, so
    # the box statistics are still computed from every observation.
    coord_cartesian(ylim = c(0, y_max)) +
    ylab(y_label) +
    ggtitle(title)

  ggsave(str_c(summary_dir, file_name), plot = p, width = size, height = size)
  p
}

# MAE
save_resample_boxplot('ym', "mae", 'MAE', 7,
                      '(a) Young\'s Modulus (GPa)', 'ym_mae.pdf')

save_resample_boxplot('pr', "mae", 'MAE', 0.05,
                      '(b) Poisson\'s Ratio', 'pr_mae.pdf')

save_resample_boxplot('br', "mae", 'MAE', 0.2,
                      '(c) Brittleness', 'br_mae.pdf')

# RMSE (the axis label previously read 'RSME')
save_resample_boxplot('ym', "rmse", 'RMSE', 7,
                      '(a) Young\'s Modulus (GPa)', 'ym_rmse.pdf')

save_resample_boxplot('pr', "rmse", 'RMSE', 0.05,
                      '(b) Poisson\'s Ratio', 'pr_rmse.pdf')

# NOTE(review): this figure is saved at 5 x 5 while the others are 6 x 6;
# the original size is kept - confirm whether that is intended.
save_resample_boxplot('br', "rmse", 'RMSE', 0.2,
                      '(c) Brittleness', 'br_rmse.pdf', size = 5)

# Load the feature importance / interaction results

# One sub-directory per model; the wave equation directory is ignored.
# basename() is used instead of stripping a hard-coded path prefix, so this
# keeps working if output_dir changes or is written without a trailing slash.
imp_model_dirs <- list.dirs(output_dir, recursive = FALSE)
imp_model_dirs <- imp_model_dirs[basename(imp_model_dirs) != "wave"]

# Collect per-model tables in a preallocated list and bind once at the end,
# rather than growing the table inside the loop.
imp_int_parts <- vector("list", length(imp_model_dirs))

for (i in seq_along(imp_model_dirs)) {
  model_dir <- imp_model_dirs[[i]]
  model_name <- basename(model_dir)

  # Model type and target are encoded in the directory name; only the name
  # is inspected so parent directories cannot produce a false match.
  model_type <- get_model_type(model_name)
  target <- get_model_target(model_name)

  # What remains after removing "_<model_type>" and "_<target>" is the
  # feature set label.
  feat_set <- model_name %>%
    str_replace(str_c('_', model_type), "") %>%
    str_replace(str_c('_', target), "")

  full_target <- switch(
    target,
    ym = 'Young\'s Modulus',
    pr = 'Poisson\'s Ratio',
    'Brittleness'
  )

  # `!!` forces the loop variables to be used even if the CSV happens to
  # contain columns with the same names.
  imp_int_parts[[i]] <- read_csv(str_c(model_dir,'/imp_int_res.csv')) %>%
    janitor::clean_names() %>%
    mutate(
      model_type = !!model_type,
      target = !!target,
      feat_set = !!feat_set,
      full_target = !!full_target
    )
}

# Importance / interaction results for every model.
comp_int_imp <- bind_rows(imp_int_parts)

# Explicit level order so plots show feature sets and targets consistently.
comp_int_imp$feat_set <- factor(
  comp_int_imp$feat_set,
  levels = c("vp_rho", "vp_vs_rho", "dev_vp_rho",
             "dev_vp_vs_rho", "dev_conf_vp_rho",
             "dev_conf_vp_vs_rho")
)

comp_int_imp$target <- factor(
  comp_int_imp$target,
  levels = c("ym", "pr", "br")
)

# Summarize feature importance and interactions in per-target and combined plots

# Draw a permutation importance plot (points plus 5%-95% error bars, log10 y
# axis) and save it as a PDF of height 5 in summary_dir.
#
# data:            rows of comp_int_imp to plot
# file_name:       output file name, appended to summary_dir
# width:           width of the saved PDF
# title:           optional plot title (no title when NULL)
# facet_by_target: when TRUE, draw one panel per target
# Returns the ggplot object, so a top-level call still displays the plot.
#
# The plot is passed to ggsave() explicitly. The previous
# `ggplot(...) + ggsave(...)` form is not supported by current ggplot2
# releases and stops with an error.
save_importance_plot <- function(data, file_name, width, title = NULL,
                                 facet_by_target = FALSE) {
  p <- ggplot(data) +
    geom_point(
      aes(y = importance, x = feature, color = feat_set, shape = model_type),
      size = 2
    ) +
    geom_errorbar(
      aes(x = feature, ymin = importance_05, ymax = importance_95,
          color = feat_set, width = 0.1)
    ) +
    xlab("") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    scale_y_log10() +
    ylab("Permutation Importance Ratio") +
    scale_color_brewer(type = 'qual', palette = 'Paired')

  if (!is.null(title)) {
    p <- p + ggtitle(title)
  }
  if (facet_by_target) {
    p <- p + facet_wrap(. ~ target, scales = 'fixed')
  }

  ggsave(str_c(summary_dir, file_name), plot = p, width = width, height = 5)
  p
}

# One importance plot per target.
save_importance_plot(comp_int_imp %>% filter(target == 'ym'), 'ym_imp.pdf',
                     width = 4, title = 'Young\'s Modulus')

save_importance_plot(comp_int_imp %>% filter(target == 'pr'), 'pr_imp.pdf',
                     width = 4, title = 'Poisson\'s Ratio')

save_importance_plot(comp_int_imp %>% filter(target == 'br'), 'br_imp.pdf',
                     width = 4, title = 'Brittleness')

# All targets in one figure, one panel per target.
save_importance_plot(comp_int_imp, 'importance.pdf', width = 7,
                     facet_by_target = TRUE)

# Interaction values for all targets, one panel per target.
interaction_plot <- ggplot(comp_int_imp) +
  geom_point(
    aes(y = interaction, x = feature, color = feat_set, shape = model_type),
    size = 2
  ) +
  scale_y_continuous(
    labels = function(x) format(x, scientific = TRUE, digits = 3)
  ) +
  xlab("") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  ylab("Interaction") +
  facet_wrap(. ~ target, scales = 'fixed') +
  scale_color_brewer(type = 'qual', palette = 'Paired')

ggsave(str_c(summary_dir, 'interaction.pdf'), plot = interaction_plot,
       width = 7, height = 5)
interaction_plot


# Source: ScottHMcKean/duvernay_geomech_model documentation, built on Sept. 12, 2022, 10:08 a.m.