# Default chunk options: figure size, where figures are written, and
# suppression of warnings/messages in the rendered output.
knitr::opts_chunk$set(
  fig.width = 12,
  fig.height = 8,
  fig.path = 'Figs/',
  warning = FALSE,
  message = FALSE
)

# Evaluate chunks relative to the parent directory.
knitr::opts_knit$set(root.dir = "../")

# Wide console so printed tables are not wrapped.
options(width = 250)
library(MLlibrary)
library(dplyr)
library(purrr)

# Threshold that get_reaches() hands to reach_by_pct_targeted().
THRESHOLD <- 0.4

# Dataset names reported in the pre-validation section.
preval_names <- c(
  'niger_pastoral',
  'niger_agricultural',
  'tanzania_2008',
  'tanzania_2010',
  'tanzania_2012',
  'ghana_pe',
  'mexico'
)

# Dataset names reported in the validation section.
val_names <- c(
  'south_africa_w1',
  'south_africa_w2',
  'south_africa_w3',
  'iraq',
  'brazil'
)
# Turn a named list of long per-method tables into a list of wide tables.
#
# tables: named list of data frames. Each is expected to have a `method`
#   column and its statistic in the second column.
# Returns an unnamed list with one cast table per name in `tables`; each has a
#   `dataset` column holding the element's name and one column per method.
#   A list without names yields an empty list, because iteration is over
#   names(tables).
table_stats <- function(tables) {
  widen_one <- function(ds_name) {
    long_tbl <- tables[[ds_name]]
    # The statistic to spread is whatever sits in the second column.
    stat_col <- colnames(long_tbl)[[2]]
    long_tbl$dataset <- ds_name
    reshape::cast(long_tbl, dataset ~ method, value = stat_col)
  }
  lapply(names(tables), widen_one)
}
# Size summary for every dataset used in the report: one row per dataset with
# its name, number of rows (N) and number of columns (K).
ds_stats <- lapply(c(preval_names, val_names), function(name) {
  df <- load_dataset(name)
  # stringsAsFactors = FALSE keeps `dataset` a character column on every R
  # version; under the pre-4.0 default each one-row frame would carry a
  # single-level factor that bind_rows() then has to coerce.
  data.frame(
    dataset = name,
    N = nrow(df),
    K = ncol(df),
    stringsAsFactors = FALSE
  )
})
ds_stats <- bind_rows(ds_stats)
# Load the validation output of each dataset and compute its reach.
#
# ds_names:  character vector of dataset names, each passed to
#   load_validation_models().
# threshold: value forwarded as the `threshold` argument of
#   reach_by_pct_targeted(). Defaults to the file-level THRESHOLD, so existing
#   single-argument calls behave as before.
# Returns a list with one reach_by_pct_targeted() result per dataset, named by
#   `ds_names`.
get_reaches <- function(ds_names, threshold = THRESHOLD) {
  reaches <- lapply(ds_names, function(name) {
    output <- load_validation_models(name)
    reach_by_pct_targeted(output, threshold = threshold)
  })
  names(reaches) <- ds_names
  reaches
}

# Build the combined reach table for a set of datasets.
#
# reaches: list of per-dataset reach results (as produced by get_reaches()).
# Returns the result of combine_tables() applied to table_stat() of each
#   element of `reaches`.
get_reach_table <- function(reaches) {
  combine_tables(lapply(reaches, table_stat))
}

# Build the combined budget-change table for a set of datasets.
#
# reaches: list of per-dataset reach results (as produced by get_reaches()).
# Returns the result of combine_tables() applied to budget_change() of each
#   element of `reaches`.
get_budget_table <- function(reaches) {
  combine_tables(lapply(reaches, budget_change))
}

# Stack per-dataset tables into one wide table annotated with dataset sizes.
#
# tables: named list of long per-method tables, as accepted by table_stats().
# Returns a data frame with one row per dataset: `dataset`, `N`, `K` and `ols`
#   first, then the remaining columns, sorted by `N`. Reads the file-level
#   `ds_stats` for the N/K columns.
combine_tables <- function(tables) {
  stacked <- bind_rows(table_stats(tables))
  with_sizes <- merge(stacked, ds_stats, by = 'dataset')
  reordered <- select(with_sizes, dataset, N, K, ols, everything())
  arrange(reordered, N)
}

# Summarise how the ensemble compares with OLS for each dataset.
#
# reaches: list of per-dataset reach results (as produced by get_reaches()).
# Returns a data frame, sorted by `N`, with per dataset:
#   reach_improvement          = ensemble - ols (from the reach table)
#   relative_reach_improvement = (ensemble - ols) / ols
#   budget_reduction           = -1 * ensemble (from the budget table)
difference_table <- function(reaches) {
  reach_wide <- get_reach_table(reaches)
  reach_gain <- mutate(
    reach_wide,
    reach_improvement = ensemble - ols,
    relative_reach_improvement = (ensemble - ols) / ols
  )
  reach_differences <- select(
    reach_gain,
    N, K, dataset, reach_improvement, relative_reach_improvement
  )

  budget_wide <- get_budget_table(reaches)
  # Flip the sign of the ensemble column to report it as a reduction.
  budget_signed <- mutate(budget_wide, budget_reduction = -1 * ensemble)
  budget_table <- select(budget_signed, dataset, budget_reduction)

  joined <- merge(reach_differences, budget_table, by = 'dataset')
  arrange(joined, N)
}

# Pre-validation ----

# Compute reaches for the pre-validation datasets, then print the combined
# reach table followed by the ensemble-vs-OLS difference table.
preval_reaches <- get_reaches(preval_names)
preval_table <- get_reach_table(preval_reaches)
print(preval_table)
print(difference_table(preval_reaches))

# Validation ----

# Compute reaches for the validation datasets, then print the combined reach
# table followed by the ensemble-vs-OLS difference table.
val_reaches <- get_reaches(val_names)
val_table <- get_reach_table(val_reaches)
print(val_table)
print(difference_table(val_reaches))


# ml-e/ML-library documentation built on May 23, 2019, 2:03 a.m.