# Document-wide knitr chunk defaults: figure size, the directory prefix for
# generated figures, and suppression of warnings/messages in chunk output.
knitr::opts_chunk$set(fig.width=12, fig.height=8, fig.path='Figs/',
                      warning=FALSE, message=FALSE)
# Evaluate all chunks with the parent directory as the working directory.
knitr::opts_knit$set(root.dir="../")
# Set the console print width to 250 characters.
options(width = 250)
# Format numbers with a fixed number of decimal places.
#
# Args:
#   x: numeric vector to format.
#   digits: number of decimals (>= 1). If a vector is supplied, only the first
#     element is used and a warning is raised.
#
# Returns:
#   A character vector the same length as `x`. A negative zero such as
#   "-0.00" is rewritten as "0.00".
#
# Raises:
#   An error when `digits` is smaller than 1.
myround <- function(x, digits=1) {
  # Reduce `digits` to a scalar before the range check so that the `if`
  # condition below always has length one.
  if (length(digits) > 1) {
    digits <- digits[1]
    warning("Using only digits[1]")
  }
  if (digits < 1) stop("This is intended for the case digits >= 1.")
  tmp <- sprintf(paste0("%.", digits, "f"), x)
  # deal with "-0.00" case
  zero <- paste0("0.", paste(rep("0", digits), collapse=""))
  tmp[tmp == paste0("-", zero)] <- zero
  tmp
}

# Capitalise the first letter of every space-separated word in a single
# string, e.g. "south africa" becomes "South Africa". Only the first element
# of `x` is processed.
cap <- function(x) {
  words <- strsplit(x, " ")[[1]]
  initials <- toupper(substr(words, 1, 1))
  remainders <- substring(words, 2)
  paste0(initials, remainders, collapse = " ")
}
library(MLlibrary)
library(dplyr)
library(purrr)
library(reshape2)
library(ggplot2)
library(grid) # for unit
library(scales) # for brewer_pal
library(knitr)
# library(xlsx)
library(readxl)
library(tidyr)
library(xtable)
library(zoo)

# Threshold passed to reach_by_pct_targeted() and used to rescale reach values.
THRESHOLD <- 0.2

# Datasets analysed in the main results. Earlier selections kept for reference:
# all_names <- c('niger_pastoral', 'niger_agricultural', 'tanzania_2008', 'tanzania_2010', 'tanzania_2012', 'ghana_pe', 'mexico', 'south_africa_w1', 'south_africa_w2', 'south_africa_w3', 'iraq', 'brazil')
# all_names <- c('niger_pastoral', 'niger_agricultural', 'tanzania_2008')
all_names <- c(
  "niger_pastoral_tuned",
  "niger_agricultural_tuned",
  "brazil_tuned",
  "ghana_tuned_pe",
  "iraq_tuned",
  "mexico_tuned"
)

# Unique, capitalised first name component of each dataset name.
countries <- all_names %>%
  strsplit("_") %>%
  purrr::map(first) %>%
  unique() %>%
  purrr::map(cap)

# Datasets used for the PMT comparison. Earlier selections kept for reference:
# pmt_names <- c('niger_pastoral_pmt', 'niger_agricultural_pmt', 'ghana')
# pmt_names <- c('niger_pastoral_pmt_tuned', 'niger_agricultural_pmt_tuned', 'ghana_tuned')
pmt_names <- c(
  "niger_pastoral_pmt_tuned",
  "niger_agricultural_pmt_tuned",
  "ghana_tuned"
)
# Lookup from raw dataset names to the short labels used in tables and plots.
renames <- list(
  niger_pastoral = "niger_1",
  niger_pastoral_pmt = "niger_1_pmt",
  niger_agricultural = "niger_2",
  niger_agricultural_pmt = "niger_2_pmt",
  ghana_pe = "ghana",
  ghana = "ghana_pmt",
  south_africa_w1 = "sa_08",
  south_africa_w2 = "sa_10",
  south_africa_w3 = "sa_12",
  tanzania_2008 = "tnz_08",
  tanzania_2010 = "tnz_10",
  tanzania_2012 = "tnz_12",
  iraq = "iraq",
  mexico = "mexico",
  brazil = "brazil"
)

# Every raw name also gets a "<name>_tuned" alias with the same label.
tuned_names <- renames
names(tuned_names) <- paste(names(renames), "tuned", sep = "_")
# 'ghana_tuned_pe' places the suffix before '_pe', so it is added explicitly.
tuned_names[["ghana_tuned_pe"]] <- "ghana"
renames <- c(renames, tuned_names)

# Translate a raw dataset name into its short label via the global `renames`
# list; names without an entry are returned unchanged.
clean_name <- function(name) {
  label <- renames[[name]]
  if (is.null(label)) name else label
}
# Turn a named list of long-format tables into one-row-per-dataset wide tables.
#
# Args:
#   tables: named list of data frames. Each must contain the columns `method`
#     and `value`, which are the key/value pair passed to tidyr::spread().
#
# Returns:
#   An unnamed list with one data frame per element of `tables`. Each has a
#   `dataset` column holding the element's name and one column per method.
table_stats <- function(tables) {
  lapply(names(tables), function(name) {
    df <- tables[[name]]
    # Tag the rows with the dataset so the tables can be stacked afterwards.
    df$dataset <- name
    tidyr::spread(df, method, value)
  })
}
# Row count (N) and column count (K) of every dataset used below, labelled with
# the cleaned dataset name and sorted by N.
ds_stats <- c(all_names, pmt_names) %>%
  lapply(function(name) {
    loaded <- load_dataset(name)
    data.frame(dataset = clean_name(name), N = nrow(loaded), K = ncol(loaded))
  }) %>%
  bind_rows() %>%
  arrange(N)
# Compute the reach curves for a set of datasets.
#
# Args:
#   ds_names: character vector of dataset names understood by
#     load_validation_models().
#
# Returns:
#   A list with one reach_by_pct_targeted() result per dataset (evaluated at
#   the global THRESHOLD), named by the cleaned dataset names.
get_reaches <- function(ds_names) {
  reaches <- lapply(ds_names, function(name) {
    output <- load_validation_models(name)
    reach_by_pct_targeted(output, threshold=THRESHOLD)
  })
  # vapply guarantees a plain character vector of labels, whereas sapply's
  # return type depends on its input.
  names(reaches) <- vapply(ds_names, clean_name, character(1), USE.NAMES=FALSE)
  reaches
}

# Build the per-dataset reach table.
#
# Applies table_stat() to every element of `reaches`, combines the results via
# combine_tables(), and keeps either the full-model columns (the default, with
# `opf` renamed to `ols_plus_forest` and `tuned_forest` to `forest`) or, when
# `subset_models` is TRUE, only the `ols_25` / `ensemble_25` columns.
get_reach_table <- function(reaches, subset_models=FALSE) {
  combined <- combine_tables(lapply(reaches, table_stat))
  if (subset_models) {
    return(select(combined, dataset, N, K, ols_25, ensemble_25))
  }
  combined %>%
    select(dataset, N, K, ols, enet, tuned_forest, opf, ensemble) %>%
    rename(ols_plus_forest = opf, forest = tuned_forest)
}


# Build the per-dataset budget-change table: budget_change() is applied to every
# element of `reaches` with 'ols' (or 'ols_25' when `subset_models` is TRUE) as
# the base, and the results are combined via combine_tables().
get_budget_table <- function(reaches, subset_models=FALSE) {
  base <- if (subset_models) "ols_25" else "ols"
  combine_tables(lapply(reaches, budget_change, base = base))
}

# Stack the wide per-dataset tables produced by table_stats(), attach N and K
# from the global `ds_stats`, move dataset/N/K/ols to the front and sort by N.
combine_tables <- function(tables) {
  stacked <- bind_rows(table_stats(tables))
  with_sizes <- merge(stacked, ds_stats, by = "dataset")
  ordered_cols <- select(with_sizes, dataset, N, K, ols, everything())
  arrange(ordered_cols, N)
}

# Summarise how the ensemble compares with OLS for every dataset.
#
# Returns a data frame sorted by N with the absolute and relative reach
# improvement of the ensemble over OLS, plus `budget_reduction`, which is the
# negated ensemble column of the budget table. With `subset_models = TRUE` the
# `ensemble_25` / `ols_25` columns are used instead.
difference_table <- function(reaches, subset_models=FALSE) {
  reach_table <- get_reach_table(reaches, subset_models)
  if (subset_models) {
    reach_differences <- reach_table %>%
      mutate(
        reach_improvement = ensemble_25 - ols_25,
        relative_reach_improvement = (ensemble_25 - ols_25) / ols_25
      ) %>%
      select(N, K, dataset, reach_improvement, relative_reach_improvement)
    budget_table <- get_budget_table(reaches, subset_models) %>%
      mutate(budget_reduction = -1 * ensemble_25) %>%
      select(dataset, budget_reduction)
  } else {
    reach_differences <- reach_table %>%
      mutate(
        reach_improvement = ensemble - ols,
        relative_reach_improvement = (ensemble - ols) / ols
      ) %>%
      select(N, K, dataset, reach_improvement, relative_reach_improvement)
    budget_table <- get_budget_table(reaches) %>%
      mutate(budget_reduction = -1 * ensemble) %>%
      select(dataset, budget_reduction)
  }
  joined <- merge(reach_differences, budget_table, by = "dataset")
  arrange(joined, N)
}

# Plot reach against the share of the population targeted for one dataset.
#
# Loads the validation models for `dsname`, computes reach curves with
# reach_by_pct_targeted(), and draws the 'ols' and 'ensemble' curves with one
# facet per threshold. When exactly one threshold is supplied, a magnified
# inset around the point where the OLS curve meets the threshold is added,
# annotated with the reach and budget differences between the two methods.
#
# Args:
#   dsname: dataset name passed to load_validation_models().
#   threshold: threshold(s) passed to reach_by_pct_targeted(); defaults to
#     DEFAULT_THRESHOLDS, which is defined outside this file.
#   target: currently unused.
#
# Returns:
#   A ggplot object.
plot_reach_vs_pct_targeted <- function(dsname, threshold=DEFAULT_THRESHOLDS, target=NULL) {
  # Bare theme for the inset: no legend, axes, grid or facet strips.
  zoomtheme <- theme(legend.position="none", axis.line=element_blank(),axis.text.x=element_blank(),
                   axis.text.y=element_blank(),axis.ticks=element_blank(),
                   axis.title.x=element_blank(),axis.title.y=element_blank(),
                   panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
                   panel.background = element_rect(color='gray', fill="white"),
                   plot.margin = unit(c(-6,0,-6,-6),"mm"),
                   strip.background=element_blank(),
                   strip.text=element_blank())

  output <- load_validation_models(dsname)
  country_reach <- reach_by_pct_targeted(output, threshold=threshold)
  to_plot <- filter(country_reach, method %in% c('ols', 'ensemble'))
  p <- ggplot(to_plot, aes(x=pct_targeted, y=value, color=method)) +
    geom_line() +
    facet_wrap(~ threshold) +
    scale_color_brewer(type='qual', palette=2) +
    ylab('true poor reached (as pct of population)') +
    xlab('pct targeted')
  if (length(threshold)==1) {
    # NOTE(review): value_at_pct() and budget_change() come from outside this
    # file; presumably they give the reach at the threshold and the relative
    # budget change versus OLS -- confirm against their definitions.
    ols_reach <- filter(ungroup(to_plot), method=='ols')
    ols_value <- value_at_pct(ols_reach)[[1]]
    ensemble_reach <- filter(ungroup(to_plot), method=='ensemble')
    ensemble_value <- value_at_pct(ensemble_reach)[[1]]
    budget <- filter(budget_change(to_plot), method=='ensemble')[, 3]
    ensemble_pct_targeted <- (threshold + budget * threshold)[[1]]
    # Rectangle (in data coordinates of the main plot) that hosts the inset.
    xmin <- threshold + .005
    xmax <- .9
    width <- xmax - xmin
    ymin <- .1
    ymax <- ols_value + .005
    height <- ymax - ymin
    # Magnification factor: the inset shows a window 1/mag the size of the
    # host rectangle, centred on (threshold, ols_value).
    mag <- 35
    p.zoom <- p +
      coord_cartesian(xlim = c(threshold-width/mag, threshold+width/mag), ylim=c(ols_value-height/mag, ols_value+height/mag)) +
      geom_segment(x=threshold, xend=threshold, y=ensemble_value, yend=ols_value, color='black') +
      geom_segment(y=ols_value, yend=ols_value, xend=threshold, x=ensemble_pct_targeted, color='black') +
      annotate("text", label='Delta~budget', x=threshold-(threshold-ensemble_pct_targeted)/2, y=ols_value-.0005, parse=TRUE) +
      annotate("text", label='Delta~reach', x=threshold-.001, y=ensemble_value-(ensemble_value-ols_value)/2, parse=TRUE, angle=90) +
      zoomtheme
    g <- ggplotGrob(p.zoom)
    p <- p + annotation_custom(g, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
  }
  # NOTE(review): the reference line is fixed at .4 regardless of `threshold`;
  # confirm whether it should follow the argument.
  p <- p + geom_vline(xintercept=.4, linetype='longdash', color='black')
  p + theme_bw()
}

# Main results for the datasets in `all_names`: raw reach curves, the wide
# reach table, and the ensemble-vs-OLS difference table.
reaches <- get_reaches(all_names)
reacht <- get_reach_table(reaches)
difft  <- difference_table(reaches)

Targeting

Current Practice

Research Design

Datasets

# Role of each country in the study, keyed by the region name used in the
# 'world' map data.
country_info <- c(
  "Niger" = "explore",
  "Tanzania" = "explore",
  "South Africa" = "holdout",
  "Ghana" = "explore",
  "Iraq" = "holdout",
  "Mexico" = "explore",
  "Brazil" = "holdout"
)
country_codes <- names(country_info)

# World map in grey with the study countries filled by their role.
map_dat <- map_data("world")
country_dat <- map_dat %>%
  filter(region %in% country_codes) %>%
  mutate(type = country_info[region])
ggplot() +
  geom_polygon(data = map_dat, aes(long, lat, group = group), fill = "grey65") +
  theme_bw() +
  theme(axis.text = element_blank(), axis.title = element_blank()) +
  geom_polygon(data = country_dat, aes(long, lat, group = group, fill = type))

# Dataset overview table, read from the first sheet of the metadata workbook.
# The xlsx package is not attached in this file (its library() call is
# commented out), so the workbook is read with readxl, which is attached.
ds_info <- readxl::read_excel('analyses/datasets_used_20151220.xlsx', sheet = 1) %>%
  as.data.frame(stringsAsFactors = FALSE)
# read_excel keeps the raw header text; convert it to syntactic names so the
# dotted column names referenced below (e.g. Poverty.Ratio.at..1.90.PPP) exist.
names(ds_info) <- make.names(names(ds_info), unique = TRUE)
ds_info <- arrange(ds_info, N)

to_print <- data.frame(
  Country=ds_info$Country,
  Year=ds_info$Year,
  Survey=ds_info$Survey,
  '< $1.90'=ds_info$Poverty.Ratio.at..1.90.PPP,
  '< $3.10'=ds_info$Poverty.Ratio.at..3.10.PPP,
  '< NPL'=ds_info$Poverty.Ratio.at.National.Poverty.Line,
  N=as.integer(ds_info$N),
  K=as.integer(ds_info$K),
  check.names=FALSE,
  stringsAsFactors=FALSE
)
# Override the survey label of the first two rows.
to_print[1, 3] <- 'LSMS'
to_print[2, 3] <- 'LSMS'
# Drop the last row; head(-1) also behaves sensibly for tables with fewer than
# two rows, unlike 1:(nrow - 1).
to_print <- head(to_print, -1)
# Print missing cells as blanks rather than NA.
to_print[is.na(to_print)] <- ''
print(xtable(to_print),
      comment=FALSE,
      include.rownames=FALSE,
      scalebox=.85)

Key:

Methods

After an exploratory phase, we chose to focus on a small number of methods.

# Overview of the compared methods: short key, full name and a one-line note.
method_df <- data.frame(
  key = c("OLS", "enet", "forest", "OLS + RF", "ensemble"),
  method = c(
    "OLS",
    "elastic net",
    "random forest",
    "OLS with random forest on residuals",
    "ensemble"
  ),
  notes = c(
    "Baseline",
    "Popular linear method for predictions.",
    "Popular nonlinear method. Performed well in exploratory phase.",
    "Simple way to combine OLS and forests. Performed well in exploratory phase.",
    "Most popular way to optimize predictions."
  )
)
# The first alignment entry is for the row-name column, which is not printed.
method_xtable <- xtable(method_df, align = c("l", "l", "p{3cm}", "p{5cm}"))
print(method_xtable, comment = FALSE, include.rownames = FALSE)

Metrics

We think of poverty targeting as a classification problem and focus on two metrics of success:

# Illustrate the two metrics on one dataset with a single threshold of 0.4.
plot_reach_vs_pct_targeted("ghana_tuned", threshold = 0.4)

Results

# Reach per dataset and method, rescaled by THRESHOLD, with datasets ordered as
# in ds_stats.
# df <- melt(reacht, id=c('dataset', 'N', 'K', 'country'))
df <- reacht %>%
  melt(id = c("dataset", "N", "K")) %>%
  mutate(value = value / THRESHOLD)
df$dataset <- factor(df$dataset, levels = ds_stats$dataset)

# Every boxplot statistic is pinned (ymax/upper/middle to the value, lower to
# 0), so each box is drawn from 0 up to the value.
ggplot(
  df,
  aes(ymax = value, y = value, upper = value, middle = value,
      x = dataset, fill = variable)
) +
  geom_boxplot(
    position = position_dodge(width = .72),
    width = .7,
    lwd = .1,
    fatten = 10,
    lower = 0
  ) +
  scale_fill_manual(
    values = c("#9ecae1", "#a1d99b", "#74c476", "#41ab5d", "#238b45", "#005a32")
  ) +
  ylab("reach") +
  # coord_cartesian(ylim=c(.16, .4)) +
  theme_bw() +
  theme(
    panel.grid.major.x = element_blank(),
    panel.grid.minor.x = element_blank()
  )

Ensembles outperform OLS

# Relative reach improvement and budget reduction of the ensemble over OLS,
# one pair of bars per dataset (ordered as in ds_stats).
df <- difft %>%
  rename(reach = relative_reach_improvement, budget = budget_reduction) %>%
  melt(id = c("dataset", "N", "K")) %>%
  filter(variable != "reach_improvement")
df$dataset <- factor(df$dataset, levels = ds_stats$dataset)
# df <- filter(df, variable=='reach')
ggplot(df, aes(y = value, x = dataset, fill = variable)) +
  geom_bar(
    position = position_dodge(width = 0.5),
    width = .4,
    stat = "identity"
  ) +
  scale_fill_manual(values = c("#33a02c", "#fcbba1")) +
  ylab(expression(frac(ensemble - ols, ols))) +
  theme_bw() +
  theme(
    panel.grid.major.x = element_blank(),
    panel.grid.minor.x = element_blank()
  )

Results are similar on 25 feature subsets

# reacht_25 <- get_reach_table(reaches, TRUE)
# difft_25 <- difference_table(reaches, TRUE)
# 
# orig_diff <- select(difft, dataset=dataset, orig=relative_reach_improvement)
# difft2 <- merge(orig_diff, difft_25, by='dataset')
# 
# df <- rename(difft2, original=orig, subset_25=relative_reach_improvement)
# df <- melt(df, id=c('dataset', 'N', 'K'))
# df <- filter(df, variable %in% c('original', 'subset_25'))
# df$dataset <- factor(df$dataset, levels=ds_stats$dataset)
# ggplot(df, aes(y=value, x=dataset, fill=variable)) +
#   geom_bar(position=position_dodge(width=0.5), width=.4, stat='identity') +
#   scale_fill_manual(values=c('#33a02c', '#fcbba1')) +
#   ylab(expression(frac(ensemble_25 - ols_25, ols_25))) +
#   theme_bw() +
#   theme(panel.grid.major.x=element_blank(), panel.grid.minor.x=element_blank())

Ensemble performs better on real PMTs

# Reach results for the PMT datasets.
reaches_pmt <- get_reaches(pmt_names)
reacht_pmt <- get_reach_table(reaches_pmt)
difft_pmt <- difference_table(reaches_pmt)

# Pair every PMT dataset with its baseline row in `difft` by name: the PMT
# labels are the baseline labels with a '_pmt' suffix. Matching by name keeps
# `baseline_df` aligned row-for-row with `difft_pmt`, independent of how either
# table happens to be sorted.
idx <- match(sub('_pmt$', '', difft_pmt$dataset), difft$dataset)
# Fail loudly if a PMT dataset has no baseline counterpart.
stopifnot(!anyNA(idx))
baseline_df <- difft[idx, ]
df <- select(difft_pmt, dataset, N, K, pmt=relative_reach_improvement)
df$baseline <- baseline_df$relative_reach_improvement
df <- melt(df, id=c('dataset', 'N', 'K'))
df$dataset <- factor(df$dataset, levels=ds_stats$dataset)
ggplot(df, aes(y=value, x=dataset, fill=variable)) +
  geom_bar(position=position_dodge(width=0.5), width=.4, stat='identity') +
  scale_fill_manual(values=c('#33a02c', '#a1d99b')) +
  ylab(expression(frac(ensemble - ols, ols))) +
  theme_bw() +
  theme(panel.grid.major.x=element_blank(), panel.grid.minor.x=element_blank())

Benefit from regularizing OLS decreases as N / K increases

# country_list <- ds_info$Country[1:(length(ds_info$Country)-1)]
# country_list <- na.locf(country_list)
# if (length(country_list) == nrow(reacht)) {
#   reacht$country <- country_list
# } else {
#   reacht$country <- reacht$dataset
# }
# ggplot(reacht, aes(y=enet-ols, x=N / K, label=country)) + 
#   geom_point(size=4) +
#   scale_x_log10() +
#   xlab('log(N/k)') +
#   geom_text(nudge_x = .02, nudge_y=.0005, check_overlap=TRUE) +
#   theme_bw()

Random forests are well-approximated by OLS

# forests <- lapply(all_names, function(name) {
#   df <- load_dataset(name)
#   df <- df[order(df[, TARGET_VARIABLE]), ]
#   df$X <- NULL
#   output <- load_validation_models(name) %>%
#     filter(method=='forest') %>%
#     arrange(true)
#   tol <- .0001
#   merged <- df
#   if (nrow(merged) == nrow(output)) {
#     if (all(abs(merged[, TARGET_VARIABLE] - output$true) < tol)) {
#       merged[, TARGET_VARIABLE] <- output$predicted
#     }
#   } else {
#     merged[, TARGET_VARIABLE] <- output$predicted[match(df[, TARGET_VARIABLE], output$true)]
#   }
#   model <- fit_ols(merged)
#   rsq <- summary(model)$r.squared
#   data.frame(dataset=clean_name(name), N=nrow(df), K=ncol(df), rsq=rsq)
# })
# forests <- rbind_all(forests)
# forests$dataset <- factor(forests$dataset, levels=ds_stats$dataset)
# ggplot(forests, aes(y=rsq, x=dataset)) + 
#   geom_bar(stat='identity') +
#   guides(fill=FALSE) +
#   ylab(expression(r^2)) +
#   coord_cartesian(ylim=c(0, 1)) +
#   ylim(0, 1) + 
#   theme_bw() + 
#   theme(panel.grid.major.x=element_blank(), panel.grid.minor.x=element_blank())

Improvements are meaningful

TODO[Jack]



ml-e/ML-library documentation built on May 23, 2019, 2:03 a.m.