# Report-wide knitr configuration.

# Chunk defaults: 12 x 8 figures saved under Figs/, with warnings and
# messages kept out of the rendered output.
knitr::opts_chunk$set(
  fig.width = 12,
  fig.height = 8,
  fig.path = 'Figs/',
  warning = FALSE,
  message = FALSE
)

# Evaluate chunks from the parent directory of this document.
knitr::opts_knit$set(root.dir = "../")

# Wide console so printed tables are not wrapped.
options(width = 250)
# Format numbers with a fixed number of decimal places.
#
# x:      numeric vector to format.
# digits: number of decimals (>= 1). If a vector is supplied, only its
#         first element is used and a warning is issued.
#
# Returns a character vector the same length as x. A negative value that
# would print as "-0.00" (for the chosen number of digits) is returned as
# "0.00" instead.
myround <- function(x, digits=1) {
  # Collapse a vector `digits` before it is used in an `if` condition:
  # a condition of length > 1 is an error in R >= 4.2, which would make
  # this warning unreachable if the range check ran first.
  if (length(digits) > 1) {
    digits <- digits[1]
    warning("Using only digits[1]")
  }
  if (digits < 1) {
    stop("This is intended for the case digits >= 1.")
  }

  tmp <- sprintf(paste0("%.", digits, "f"), x)

  # deal with "-0.00" case
  zero <- paste0("0.", paste(rep("0", digits), collapse=""))
  tmp[tmp == paste0("-", zero)] <- zero
  tmp
}

# Upper-case the first letter of every space-separated word.
#
# x: a character string; only the first element of x is used.
#
# Returns a single string with the words re-joined by single spaces.
cap <- function(x) {
  s <- strsplit(x, " ")[[1]]
  paste0(toupper(substring(s, 1, 1)), substring(s, 2), collapse=" ")
}
# Packages (attach order is kept as-is because it determines masking).
library(MLlibrary)
library(dplyr)
library(purrr)
library(reshape2)
library(ggplot2)
library(grid)    # for unit
library(scales)  # for brewer_pal
library(knitr)
# library(xlsx)
library(readxl)
library(tidyr)
library(xtable)
library(zoo)

# Threshold passed to reach_by_pct_targeted() and used to rescale reach.
THRESHOLD <- 0.2

# Datasets analysed in the main results. Earlier selections kept for reference:
# all_names <- c('niger_pastoral', 'niger_agricultural', 'tanzania_2008', 'tanzania_2010', 'tanzania_2012', 'ghana_pe', 'mexico', 'south_africa_w1', 'south_africa_w2', 'south_africa_w3', 'iraq', 'brazil')
# all_names <- c('niger_pastoral', 'niger_agricultural', 'tanzania_2008')
all_names <- c(
  'niger_pastoral_tuned',
  'niger_agricultural_tuned',
  'brazil_tuned',
  'ghana_tuned_pe',
  'iraq_tuned',
  'mexico_tuned'
)

# Capitalised first token of each dataset name, de-duplicated.
countries <- strsplit(all_names, '_') %>%
  purrr::map(first) %>%
  unique() %>%
  purrr::map(cap)

# PMT datasets. Earlier selections kept for reference:
# pmt_names <- c('niger_pastoral_pmt', 'niger_agricultural_pmt', 'ghana')
# pmt_names <- c('niger_pastoral_pmt_tuned', 'niger_agricultural_pmt_tuned', 'ghana_tuned')
pmt_names <- c(
  'niger_pastoral_pmt_tuned',
  'niger_agricultural_pmt_tuned',
  'ghana_tuned'
)

# Map from dataset identifiers to the short labels used in tables and plots.
renames <- list(
  niger_pastoral         = 'niger_1',
  niger_pastoral_pmt     = 'niger_1_pmt',
  niger_agricultural     = 'niger_2',
  niger_agricultural_pmt = 'niger_2_pmt',
  ghana_pe               = 'ghana',
  ghana                  = 'ghana_pmt',
  south_africa_w1        = 'sa_08',
  south_africa_w2        = 'sa_10',
  south_africa_w3        = 'sa_12',
  tanzania_2008          = 'tnz_08',
  tanzania_2010          = 'tnz_10',
  tanzania_2012          = 'tnz_12',
  iraq                   = 'iraq',
  mexico                 = 'mexico',
  brazil                 = 'brazil'
)

# The same labels apply to the "<name>_tuned" variants; 'ghana_tuned_pe' does
# not follow that naming pattern, so it is added by hand.
tuned_names <- setNames(renames, paste(names(renames), 'tuned', sep='_'))
tuned_names[['ghana_tuned_pe']] <- 'ghana'
renames <- c(renames, tuned_names)

# Look up the short label for a dataset name; names without an entry in
# `renames` are returned unchanged.
clean_name <- function(name) {
  mapped <- renames[[name]]
  if (is.null(mapped)) name else mapped
}
# Convert a named list of long-format tables to wide format.
#
# tables: named list of data frames, each with a `method` column and a
#         `value` column.
#
# For every element, a `dataset` column holding the element's name is added
# and the `method` column is spread into one column per method, filled from
# `value`. Returns an unnamed list of the reshaped data frames, in the order
# of names(tables).
table_stats <- function(tables) {
  lapply(names(tables), function(name) {
    df <- tables[[name]]
    df$dataset <- name
    tidyr::spread(df, method, value)
  })
}
# One row per dataset (main and PMT): its display label, number of rows (N)
# and number of columns (K), ordered by N.
ds_stats <- c(all_names, pmt_names) %>%
  lapply(function(ds_name) {
    loaded <- load_dataset(ds_name)
    data.frame(
      dataset=clean_name(ds_name),
      N=nrow(loaded),
      K=ncol(loaded)
    )
  }) %>%
  bind_rows() %>%
  arrange(N)
# Load the validation output of every dataset in `ds_names` and compute its
# reach at THRESHOLD. Returns a list named by the cleaned dataset labels.
get_reaches <- function(ds_names) {
  reaches <- lapply(ds_names, function(name) {
    output <- load_validation_models(name)
    reach_by_pct_targeted(output, threshold=THRESHOLD)
  })
  # vapply guarantees a character vector of labels regardless of input length.
  names(reaches) <- vapply(ds_names, clean_name, character(1), USE.NAMES=FALSE)
  reaches
}

# Wide table of reach per dataset and method.
#
# reaches:       named list as returned by get_reaches().
# subset_models: if TRUE keep only the ols_25 / ensemble_25 columns;
#                otherwise keep ols, enet, forest, ols_plus_forest, ensemble.
get_reach_table <- function(reaches, subset_models=FALSE) {
  # NOTE(review): table_stat (singular) is not defined in this file and is
  # distinct from table_stats; presumably provided by MLlibrary - confirm.
  tables <- lapply(reaches, table_stat)
  combined <- combine_tables(tables)
  if (!subset_models) {
    combined %>%
      select(dataset, N, K, ols, enet, tuned_forest, opf, ensemble) %>%
      rename(ols_plus_forest=opf) %>%
      rename(forest=tuned_forest)
  } else {
    combined %>%
      select(dataset, N, K, ols_25, ensemble_25)
  }
}

# Wide table of budget change per dataset and method, computed by
# budget_change() against the 'ols' baseline ('ols_25' when subset_models).
get_budget_table <- function(reaches, subset_models=FALSE) {
  if (subset_models) {
    base <- 'ols_25'
  } else {
    base <- 'ols'
  }
  tables <- lapply(reaches, budget_change, base=base)
  combine_tables(tables)
}

# Spread each table to wide format, stack them, attach N and K from the
# global `ds_stats`, and order rows by N.
combine_tables <- function(tables) {
  table_stats(tables) %>%
    bind_rows() %>%
    merge(ds_stats, by='dataset') %>%
    select(dataset, N, K, ols, everything()) %>%
    arrange(N)
}

# Per-dataset comparison of the ensemble against OLS.
#
# Returns one row per dataset with:
#   reach_improvement          = ensemble - ols
#   relative_reach_improvement = (ensemble - ols) / ols
#   budget_reduction           = -1 * ensemble column of the budget table
# (using the *_25 columns when subset_models is TRUE), ordered by N.
difference_table <- function(reaches, subset_models=FALSE) {
  reach_table <- get_reach_table(reaches, subset_models)
  if (!subset_models) {
    reach_differences <- reach_table %>%
      mutate(reach_improvement=ensemble-ols) %>%
      mutate(relative_reach_improvement=(ensemble-ols)/ols) %>%
      select(N, K, dataset, reach_improvement, relative_reach_improvement)
    budget_table <- get_budget_table(reaches) %>%
      mutate(budget_reduction=-1 * ensemble) %>%
      select(dataset, budget_reduction)
  } else {
    reach_differences <- reach_table %>%
      mutate(reach_improvement=ensemble_25-ols_25) %>%
      mutate(relative_reach_improvement=(ensemble_25-ols_25)/ols_25) %>%
      select(N, K, dataset, reach_improvement, relative_reach_improvement)
    budget_table <- get_budget_table(reaches, subset_models) %>%
      mutate(budget_reduction=-1 * ensemble_25) %>%
      select(dataset, budget_reduction)
  }
  merge(reach_differences, budget_table, by='dataset') %>%
    arrange(N)
}

# Plot reach against percent targeted for the 'ols' and 'ensemble' methods.
#
# dsname:    dataset name passed to load_validation_models().
# threshold: threshold(s) passed to reach_by_pct_targeted(); the plot is
#            faceted by threshold. When a single threshold is given, a
#            magnified inset around the (threshold, OLS reach) point is added,
#            annotated with the budget and reach differences.
# target:    currently unused; kept so existing callers keep working.
#
# Returns a ggplot object.
plot_reach_vs_pct_targeted <- function(dsname, threshold=DEFAULT_THRESHOLDS, target=NULL) {
  # Bare theme for the inset: no axes, grid, legend or strip labels.
  zoomtheme <- theme(
    legend.position="none",
    axis.line=element_blank(),
    axis.text.x=element_blank(),
    axis.text.y=element_blank(),
    axis.ticks=element_blank(),
    axis.title.x=element_blank(),
    axis.title.y=element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_rect(color='gray', fill="white"),
    plot.margin = unit(c(-6,0,-6,-6),"mm"),
    strip.background=element_blank(),
    strip.text=element_blank())

  output <- load_validation_models(dsname)
  country_reach <- reach_by_pct_targeted(output, threshold=threshold)
  to_plot <- filter(country_reach, method %in% c('ols', 'ensemble'))

  p <- ggplot(to_plot, aes(x=pct_targeted, y=value, color=method)) +
    geom_line() +
    facet_wrap(~ threshold) +
    scale_color_brewer(type='qual', palette=2) +
    # Axis label typo fixed ("popuation" -> "population").
    ylab('true poor reached (as pct of population)') +
    xlab('pct targeted')

  if (length(threshold)==1) {
    ols_reach <- filter(ungroup(to_plot), method=='ols')
    ols_value <- value_at_pct(ols_reach)[[1]]
    ensemble_reach <- filter(ungroup(to_plot), method=='ensemble')
    ensemble_value <- value_at_pct(ensemble_reach)[[1]]
    # Third column of the ensemble row of budget_change(); the [[1]] below
    # extracts a scalar whether `[, 3]` yields a vector or a one-column table.
    budget <- filter(budget_change(to_plot), method=='ensemble')[, 3]
    ensemble_pct_targeted <- (threshold + budget * threshold)[[1]]

    # Placement of the inset on the main plot, in data coordinates.
    xmin <- threshold + .005
    xmax <- .9
    width <- xmax - xmin
    ymin <- .1
    ymax <- ols_value + .005
    height <- ymax - ymin
    # Magnification factor of the inset relative to its placement box.
    mag <- 35

    p.zoom <- p +
      coord_cartesian(
        xlim=c(threshold-width/mag, threshold+width/mag),
        ylim=c(ols_value-height/mag, ols_value+height/mag)) +
      geom_segment(x=threshold, xend=threshold, y=ensemble_value, yend=ols_value, color='black') +
      geom_segment(y=ols_value, yend=ols_value, xend=threshold, x=ensemble_pct_targeted, color='black') +
      annotate("text", label='Delta~budget', x=threshold-(threshold-ensemble_pct_targeted)/2, y=ols_value-.0005, parse=TRUE) +
      annotate("text", label='Delta~reach', x=threshold-.001, y=ensemble_value-(ensemble_value-ols_value)/2, parse=TRUE, angle=90) +
      zoomtheme
    g <- ggplotGrob(p.zoom)
    p <- p + annotation_custom(g, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
  }

  # NOTE(review): the reference line is fixed at .4 and does not follow
  # `threshold`; confirm whether it should.
  p <- p + geom_vline(xintercept=.4, linetype='longdash', color='black')
  p + theme_bw()
}

# Main results: reach per dataset, the wide reach table, and the
# ensemble-vs-OLS differences.
reaches <- get_reaches(all_names)
reacht <- get_reach_table(reaches)
difft <- difference_table(reaches)
# World map with the study countries highlighted by their `type` label.

# Label attached to each highlighted country, keyed by map region name.
country_info <- c(
  'Niger'        = 'explore',
  'Tanzania'     = 'explore',
  'South Africa' = 'holdout',
  'Ghana'        = 'explore',
  'Iraq'         = 'holdout',
  'Mexico'       = 'explore',
  'Brazil'       = 'holdout'
)
country_codes <- names(country_info)

map_dat <- map_data('world')

# Polygons of the highlighted countries, tagged with their label.
country_dat <- map_dat %>%
  filter(region %in% country_codes) %>%
  mutate(type=country_info[region])

# Grey base layer for the whole world, then the coloured countries on top.
ggplot() +
  geom_polygon(aes(long, lat, group=group), fill="grey65", data=map_dat) +
  theme_bw() +
  theme(axis.text = element_blank(), axis.title=element_blank()) +
  geom_polygon(data=country_dat, aes(long, lat, group=group, fill=type))
# LaTeX table describing the datasets used, read from the first sheet of the
# spreadsheet and ordered by N.
# NOTE(review): read.xlsx is not provided by any package visibly attached in
# this file (library(xlsx) is commented out) - confirm it is available when
# this chunk runs.
ds_info <- read.xlsx('analyses/datasets_used_20151220.xlsx', 1, stringsAsFactors=FALSE) %>%
  arrange(N)

# check.names=FALSE keeps the non-syntactic column headers ('< $1.90', ...).
to_print <- data.frame(
  Country=ds_info$Country,
  Year=ds_info$Year,
  Survey=ds_info$Survey,
  '< $1.90'=ds_info$Poverty.Ratio.at..1.90.PPP,
  '< $3.10'=ds_info$Poverty.Ratio.at..3.10.PPP,
  '< NPL'=ds_info$Poverty.Ratio.at.National.Poverty.Line,
  N=as.integer(ds_info$N),
  K=as.integer(ds_info$K),
  check.names=FALSE,
  stringsAsFactors=FALSE
)

# Override the survey name (third column) for the first two rows.
to_print[1, 3] <- 'LSMS'
to_print[2, 3] <- 'LSMS'

# Drop the last row. head(x, -1) also behaves correctly for a one-row table,
# where 1:(nrow - 1) would expand to c(1, 0) and keep the row.
to_print <- head(to_print, -1)

# Show missing cells as blanks.
to_print[is.na(to_print)] <- ''
print(xtable(to_print), comment=FALSE, include.rownames=FALSE, scalebox=.85)
Key:
After an exploratory phase, we chose to focus on a small number of methods.
# Table of the methods compared: short key, full name, and a note on each.
method_df <- data.frame(
  key=c(
    'OLS',
    'enet',
    'forest',
    'OLS + RF',
    'ensemble'
  ),
  method=c(
    'OLS',
    'elastic net',
    'random forest',
    'OLS with random forest on residuals',
    'ensemble'
  ),
  notes=c(
    'Baseline',
    'Popular linear method for predictions.',
    'Popular nonlinear method. Performed well in exploratory phase.',
    'Simple way to combine OLS and forests. Performed well in exploratory phase.',
    'Most popular way to optimize predictions.'
  )
)

# Emit as LaTeX; the align vector has one entry for the (hidden) row names
# followed by one per column.
method_df %>%
  xtable(align = c('l', 'l', 'p{3cm}', 'p{5cm}')) %>%
  print(comment=FALSE, include.rownames=FALSE)
We think of poverty targeting as a classification problem and focus on two metrics of success:
# Illustrative reach-vs-targeting curve for the 'ghana_tuned' dataset at a
# single threshold, which also draws the magnified inset.
plot_reach_vs_pct_targeted(dsname='ghana_tuned', threshold=0.4)
# Reach per dataset and method, rescaled by THRESHOLD, with datasets ordered
# as in ds_stats.
# df <- melt(reacht, id=c('dataset', 'N', 'K', 'country'))
df <- reacht %>%
  melt(id=c('dataset', 'N', 'K')) %>%
  mutate(
    value=value / THRESHOLD,
    dataset=factor(dataset, levels=ds_stats$dataset)
  )

# Boxplot geometry with upper, middle and ymax all mapped to the same value
# and lower fixed at 0.
ggplot(
  df,
  aes(x=dataset, y=value, ymax=value, upper=value, middle=value, fill=variable)
) +
  geom_boxplot(
    position=position_dodge(width=.72),
    width=.7,
    lwd=.1,
    fatten=10,
    lower=0
  ) +
  scale_fill_manual(
    values=c('#9ecae1', '#a1d99b', '#74c476', '#41ab5d', '#238b45', '#005a32')
  ) +
  ylab('reach') +
  # coord_cartesian(ylim=c(.16, .4)) +
  theme_bw() +
  theme(
    panel.grid.major.x=element_blank(),
    panel.grid.minor.x=element_blank()
  )
# Relative reach improvement and budget reduction of the ensemble over OLS,
# one pair of bars per dataset (datasets ordered as in ds_stats).
df <- difft %>%
  rename(reach=relative_reach_improvement, budget=budget_reduction) %>%
  melt(id=c('dataset', 'N', 'K')) %>%
  filter(variable != 'reach_improvement') %>%
  mutate(dataset=factor(dataset, levels=ds_stats$dataset))
# df <- filter(df, variable=='reach')

ggplot(df, aes(x=dataset, y=value, fill=variable)) +
  geom_bar(
    position=position_dodge(width=0.5),
    width=.4,
    stat='identity'
  ) +
  scale_fill_manual(values=c('#33a02c', '#fcbba1')) +
  ylab(expression(frac(ensemble - ols, ols))) +
  theme_bw() +
  theme(
    panel.grid.major.x=element_blank(),
    panel.grid.minor.x=element_blank()
  )
# reacht_25 <- get_reach_table(reaches, TRUE) # difft_25 <- difference_table(reaches, TRUE) # # orig_diff <- select(difft, dataset=dataset, orig=relative_reach_improvement) # difft2 <- merge(orig_diff, difft_25, by='dataset') # # df <- rename(difft2, original=orig, subset_25=relative_reach_improvement) # df <- melt(df, id=c('dataset', 'N', 'K')) # df <- filter(df, variable %in% c('original', 'subset_25')) # df$dataset <- factor(df$dataset, levels=ds_stats$dataset) # ggplot(df, aes(y=value, x=dataset, fill=variable)) + # geom_bar(position=position_dodge(width=0.5), width=.4, stat='identity') + # scale_fill_manual(values=c('#33a02c', '#fcbba1')) + # ylab(expression(frac(ensemble_25 - ols_25, ols_25))) + # theme_bw() + # theme(panel.grid.major.x=element_blank(), panel.grid.minor.x=element_blank())
# Relative reach improvement on the PMT datasets next to the improvement on
# the corresponding baseline datasets.
reaches_pmt <- get_reaches(pmt_names)
reacht_pmt <- get_reach_table(reaches_pmt)
difft_pmt <- difference_table(reaches_pmt)

df <- select(difft_pmt, dataset, N, K, pmt=relative_reach_improvement)

# Pair every PMT dataset with its baseline by name: the labels in `renames`
# give each PMT dataset the label of its baseline plus a '_pmt' suffix.
# The previous lookup collected row positions within difft_pmt and then used
# them to index difft, so baselines could be taken from the wrong datasets.
# A PMT dataset without a baseline gets NA.
idx <- match(sub('_pmt$', '', df$dataset), difft$dataset)
baseline_df <- difft[idx, ]
df$baseline <- baseline_df$relative_reach_improvement

df <- melt(df, id=c('dataset', 'N', 'K'))
df$dataset <- factor(df$dataset, levels=ds_stats$dataset)

ggplot(df, aes(y=value, x=dataset, fill=variable)) +
  geom_bar(position=position_dodge(width=0.5), width=.4, stat='identity') +
  scale_fill_manual(values=c('#33a02c', '#a1d99b')) +
  ylab(expression(frac(ensemble - ols, ols))) +
  theme_bw() +
  theme(panel.grid.major.x=element_blank(), panel.grid.minor.x=element_blank())
# country_list <- ds_info$Country[1:(length(ds_info$Country)-1)] # country_list <- na.locf(country_list) # if (length(country_list) == nrow(reacht)) { # reacht$country <- country_list # } else { # reacht$country <- reacht$dataset # } # ggplot(reacht, aes(y=enet-ols, x=N / K, label=country)) + # geom_point(size=4) + # scale_x_log10() + # xlab('log(N/k)') + # geom_text(nudge_x = .02, nudge_y=.0005, check_overlap=TRUE) + # theme_bw()
# forests <- lapply(all_names, function(name) { # df <- load_dataset(name) # df <- df[order(df[, TARGET_VARIABLE]), ] # df$X <- NULL # output <- load_validation_models(name) %>% # filter(method=='forest') %>% # arrange(true) # tol <- .0001 # merged <- df # if (nrow(merged) == nrow(output)) { # if (all(abs(merged[, TARGET_VARIABLE] - output$true) < tol)) { # merged[, TARGET_VARIABLE] <- output$predicted # } # } else { # merged[, TARGET_VARIABLE] <- output$predicted[match(df[, TARGET_VARIABLE], output$true)] # } # model <- fit_ols(merged) # rsq <- summary(model)$r.squared # data.frame(dataset=clean_name(name), N=nrow(df), K=ncol(df), rsq=rsq) # }) # forests <- rbind_all(forests) # forests$dataset <- factor(forests$dataset, levels=ds_stats$dataset) # ggplot(forests, aes(y=rsq, x=dataset)) + # geom_bar(stat='identity') + # guides(fill=FALSE) + # ylab(expression(r^2)) + # coord_cartesian(ylim=c(0, 1)) + # ylim(0, 1) + # theme_bw() + # theme(panel.grid.major.x=element_blank(), panel.grid.minor.x=element_blank())
TODO[Jack]
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.