library(doParallel)
library(ggplot2)
library(gridExtra)
# ---- Session setup ----
# Register a parallel backend for foreach, using every core reported by detectCores().
registerDoParallel(cores = detectCores())
# Remove every object from the current workspace before the script defines its own.
rm(list = ls())

# ---- Plot appearance ----
# Start from ggplot2's black-and-white theme, then enlarge the text elements,
# whiten the facet strips, widen the panel spacing and move the legend below the plot.
theme_set(theme_bw())
theme_update(
  axis.text.x = element_text(size = 20),
  axis.text.y = element_text(size = 20),
  axis.title.x = element_text(size = 25, margin = margin(t = 20, r = 0, b = 0, l = 0)),
  axis.title.y = element_text(size = 25, angle = 90, margin = margin(t = 0, r = 20, b = 0, l = 0)),
  legend.text = element_text(size = 20),
  legend.title = element_text(size = 20),
  title = element_text(size = 30),
  strip.text = element_text(size = 25),
  strip.background = element_rect(fill = "white"),
  panel.spacing = unit(2, "lines"),
  legend.position = "bottom"
)

# Fix the RNG seed so that repeated runs draw the same random numbers.
set.seed(11)
# String prepended to every file name this script reads or writes.
prefix <- ""
# NOTE(review): presumably this file defines `target` and `metricL1`, which are used
# further down but not defined in this script - confirm.
source("gamma_normal_functions.R")
# Number of observations; used below to match the lengths of observed and synthetic samples.
n <- 100
# Disabled code that apparently produced the stored observations:
# gen_obs_data = function(n){rgamma(n,10,5)}
# obs = gen_obs_data(n)
# save(obs, file = paste0(prefix,"gamma_obs.Rdata"))
# Read the stored data back in; the next line expects this to define `obs`.
load(file = paste0(prefix, "gamma_obs.Rdata"))
sort_obs <- sort(obs)
# Every value in m should divide n, or be divisible by n: the code below repeats the
# shorter sample n/m or m/n times to equalise lengths.
# Number of independent replications of the experiment.
M <- 500
# Numbers of sets of randomness averaged in the objective (called k in the paper).
N <- c(1, 5, 10, 20, 50, 100, 500, 1000)
# Synthetic sample sizes.
m <- c(10, 20, 50, 100, 300, 1000, 5000, 10000)
# File in which the results are stored.
filename <- paste0(prefix, "gamma_optim_n", n, "_varying_k_m.RData")
# ---- Monte Carlo study: MEWE for every combination of k (N) and m ----
# Each of the M replications draws max(N) sets of randomness, then, for every synthetic
# sample size m[i] and every number of sets N[j], minimises the average metricL1 distance
# between the sorted observations and an affine transform of the sorted randomness.
# `mewe_k_m` is a list with one element per replication; each element is a list of four
# length(N) x length(m) matrices: first parameter (mu), second parameter (sigma),
# elapsed optimisation time, and number of objective evaluations.
# NOTE(review): `%dorng%` is provided by the doRNG package, which is not attached in the
# visible part of this file; `target` and `metricL1` are not defined here either.
# Presumably all three come from gamma_normal_functions.R - confirm.
t <- proc.time()
mewe_k_m <- foreach(rep_id = seq_len(M)) %dorng% {
  # Result containers: rows index N, columns index m.
  mewe_mu_store <- matrix(0, length(N), length(m))
  mewe_sigma_store <- matrix(0, length(N), length(m))
  mewe_time_store <- matrix(0, length(N), length(m))
  count_evaluations_store <- matrix(0, length(N), length(m))

  # Generate all the randomness needed up front: one row per set.
  randomness <- t(sapply(seq_len(max(N)), function(k) target$generate_randomness(max(m))))

  # For each synthetic sample size m[i], find the MEWE using different numbers of sets.
  for (i in seq_along(m)) {
    # Keep the first m[i] draws of every set and sort each set (row).
    use_randomness <- randomness[, 1:(m[i])]
    sort_randomness <- t(apply(use_randomness, MARGIN = 1, FUN = sort))

    # Make the observed and synthetic samples the same length by repeating each
    # element of the shorter one.
    if (m[i] > n) {
      sort_obs_mult <- rep(sort_obs, each = m[i] / n)
      sort_randomness_mult <- sort_randomness
    } else {
      sort_obs_mult <- sort_obs
      sort_randomness_mult <- t(apply(sort_randomness, MARGIN = 1,
                                      FUN = function(x) rep(x, each = n / m[i])))
    }

    for (j in seq_along(N)) {
      # Use only the first N[j] sets; drop = FALSE keeps a matrix even when N[j] is 1.
      sort_randomness_sub <- sort_randomness_mult[seq_len(N[j]), , drop = FALSE]

      # Objective defining the MEWE: mean distance over the selected sets, with
      # theta[1] added to and theta[2] multiplying the sorted randomness.
      mewe_objective <- function(theta) {
        wass_dists <- apply(sort_randomness_sub, MARGIN = 1,
                            FUN = function(x) metricL1(sort_obs_mult, theta[2] * x + theta[1]))
        mean(wass_dists)
      }

      # Random starting point, then optimise with optim's default method.
      # The original author notes that passing lower = c(-Inf, 0.01), upper = c(Inf, Inf)
      # takes longer.
      init <- c(runif(1, -4, 4), runif(1, 0.1, 6))
      optim_time <- proc.time()
      out <- optim(init, mewe_objective)
      optim_time <- proc.time() - optim_time

      # Store the results; the full component name `counts` avoids relying on
      # partial matching by `$`.
      mewe_mu_store[j, i] <- out$par[1]
      mewe_sigma_store[j, i] <- out$par[2]
      mewe_time_store[j, i] <- optim_time[3] # third entry of proc.time() is elapsed time
      count_evaluations_store[j, i] <- out$counts[1] # first entry counts calls to the objective
    }
  }

  # Value of the foreach body for this replication.
  list(mewe_mu_store, mewe_sigma_store, mewe_time_store, count_evaluations_store)
}
t <- proc.time() - t
# ---- Reference estimate: MWE from a single, very large synthetic sample ----
mm <- 1e8
randomness <- target$generate_randomness(mm)
sort_randomness <- sort(randomness)
# Repeat each sorted observation mm / n times so both samples have the same length.
sort_obs_mult <- rep(sort_obs, each = mm / n)
# Same form of objective as in the simulation above, but with one set of randomness:
# theta[1] is added to and theta[2] multiplies the sorted randomness.
mewe_objective <- function(theta) {
  metricL1(sort_obs_mult, theta[2] * sort_randomness + theta[1])
}
# Random starting point for the optimiser.
init <- c(runif(1, -4, 4), runif(1, 0.1, 6))
optim_time <- proc.time()
mwe <- optim(init, mewe_objective)$par
optim_time <- proc.time() - optim_time
# Store the reference estimate together with the simulation results.
save(mwe, mewe_k_m, file = filename)
# Reload the saved results (`mwe` and `mewe_k_m`).
load(filename)
# One-row data frames holding the reference estimate, used for the reference lines in the plots.
df_mwek <- data.frame(mu = mwe[1], sigma = mwe[2], k = 1)
df_mwem <- data.frame(mu = mwe[1], sigma = mwe[2], m = 1)
# mewe[[i]][[j]] is an M x 2 matrix with the (mu, sigma) estimates of all replications
# for sample size m[i] and number of sets N[j].
mewe <- lapply(seq_along(m), function(i) {
  lapply(seq_along(N), function(j) {
    t(vapply(seq_len(M),
             function(k) c(mewe_k_m[[k]][[1]][j, i], mewe_k_m[[k]][[2]][j, i]),
             numeric(2)))
  })
})
# For every m[i], stack the estimates for all N into one data frame whose third
# column is the index of the N value (1, 2, ...), not the value itself.
df_mewe_list <- list()
for (i in seq_along(m)) {
  dummy <- lapply(seq_along(N), function(k) cbind(mewe[[i]][[k]], rep(k, M)))
  dummy <- data.frame(do.call(rbind, dummy))
  names(dummy) <- c("mu", "sigma", "k")
  # Also expose the data frame under the name df_mewe_m<m[i]>.
  assign(paste0("df_mewe_m", m[i]), dummy)
  df_mewe_list[[i]] <- dummy
}
# ## Plot all levels of k for different m
# #On the same axes
# g = list()
# for(i in 1:length(m)){
# g[[i]] <- ggplot(data = df_mewe_list[[i]], aes(x = mu, y = sigma, colour = k, group = k)) + ylim(0, 1.5) + xlim(1,3)
# g[[i]] <- g[[i]] + geom_point(alpha = 0.5)
# g[[i]] <- g[[i]] + scale_colour_gradient2(midpoint = 5) + theme(legend.position = "none", plot.title = element_text(size=10))
# g[[i]] <- g[[i]] + theme(axis.title.x=element_text(size=14), axis.text.x=element_text(size=12),axis.title.y=element_text(size=14), axis.text.y=element_text(size=12))
# g[[i]] <- g[[i]] + xlab(expression(gamma)) + ylab(expression(sigma)) + ggtitle(paste("m = ",m[i],sep=""))
# g[[i]] <- g[[i]] + geom_point(data = df_mwek, aes(x = mu, y = sigma), color="black")
# }
# do.call(grid.arrange, c(g, ncol=4))
# #On scaled axes
# g = list()
# for(i in 1:length(m)){
# g[[i]] <- ggplot(data = df_mewe_list[[i]] %>% filter(sigma> 0.2), aes(x = mu, y = sigma, colour = k, group = k)) #+ ylim(0, 1.5) + xlim(1,3)
# g[[i]] <- g[[i]] + geom_point(alpha = 0.5)
# g[[i]] <- g[[i]] + scale_colour_gradient2(midpoint = 5) + theme(legend.position = "none", plot.title = element_text(size=10))
# g[[i]] <- g[[i]] + theme(axis.title.x=element_text(size=14), axis.text.x=element_text(size=12),axis.title.y=element_text(size=14), axis.text.y=element_text(size=12))
# g[[i]] <- g[[i]] + xlab(expression(gamma)) + ylab(expression(sigma)) + ggtitle(paste("m = ",m[i],sep=""))
# g[[i]] <- g[[i]] + geom_point(data = df_mwek, aes(x = mu, y = sigma), color="black")
# }
# do.call(grid.arrange, c(g, ncol=4))
# On free axes, one panel per selected m value, points coloured by the index of k.
msubset <- c(10, 100, 1000, 10000)
# Positions of the selected values within m, looked up rather than hard-coded so the
# panel titles and the plotted data cannot drift apart.
subset <- match(msubset, m)
g <- list()
for (i in seq_along(subset)) {
  # Keep rows with sigma > 0.2. Base subsetting is used because neither `%>%` nor a
  # data-frame filter() is provided by the packages attached at the top of this file
  # (a bare filter() would resolve to stats::filter). which() drops NA comparisons.
  plot_data <- df_mewe_list[[subset[i]]]
  plot_data <- plot_data[which(plot_data$sigma > 0.2), , drop = FALSE]
  g[[i]] <- ggplot(data = plot_data, aes(x = mu, y = sigma, colour = k, group = k)) +
    geom_point(alpha = 0.5) +
    scale_colour_gradient2(midpoint = 5) +
    theme(legend.position = "none",
          plot.title = element_text(size = 10, hjust = 0.5),
          axis.title.x = element_text(size = 14),
          axis.text.x = element_text(size = 12),
          axis.title.y = element_text(size = 14),
          axis.text.y = element_text(size = 12)) +
    xlab(expression(gamma)) +
    ylab(expression(sigma)) +
    ggtitle(paste0("m = ", msubset[i])) +
    # Reference lines at the large-sample estimate stored in df_mwek.
    geom_vline(xintercept = df_mwek$mu) +
    geom_hline(yintercept = df_mwek$sigma)
}
# Write the 2-column grid of panels to a PNG file.
png(filename = paste0(prefix, "gamma_n100_all.k_some.m.png"))
do.call(grid.arrange, c(g, ncol = 2))
dev.off()
### Stratify on k instead
# mewe_k[[j]][[i]] is an M x 2 matrix with the (mu, sigma) estimates of all replications
# for number of sets N[j] and sample size m[i].
mewe_k <- lapply(seq_along(N), function(j) {
  lapply(seq_along(m), function(i) {
    t(vapply(seq_len(M),
             function(k) c(mewe_k_m[[k]][[1]][j, i], mewe_k_m[[k]][[2]][j, i]),
             numeric(2)))
  })
})
# For every N[i], stack the estimates for all m into one data frame whose third
# column is the index of the m value (1, 2, ...), not the value itself.
df_mewe_k_list <- list()
for (i in seq_along(N)) {
  dummy <- lapply(seq_along(m), function(k) cbind(mewe_k[[i]][[k]], rep(k, M)))
  dummy <- data.frame(do.call(rbind, dummy))
  names(dummy) <- c("mu", "sigma", "m")
  # Also expose the data frame under the name df_mewe_k<N[i]>.
  assign(paste0("df_mewe_k", N[i]), dummy)
  df_mewe_k_list[[i]] <- dummy
}
# ### Plot all levels of m for different k
# #On the same axes
# g = list()
# for(i in 1:length(N)){
# g[[i]] <- ggplot(data = df_mewe_k_list[[i]] %>% filter(sigma> 0.2), aes(x = mu, y = sigma, colour = m, group = m)) + ylim(0.2, 1.4) + xlim(1.2,2.6)
# g[[i]] <- g[[i]] + geom_point(alpha = 0.5)
# g[[i]] <- g[[i]] + scale_colour_gradient2(midpoint = 5) + theme(legend.position = "none", plot.title = element_text(size=10))
# g[[i]] <- g[[i]] + theme(axis.title.x=element_text(size=14), axis.text.x=element_text(size=12),axis.title.y=element_text(size=14), axis.text.y=element_text(size=12))
# g[[i]] <- g[[i]] + xlab(expression(gamma)) + ylab(expression(sigma)) + ggtitle(paste("k = ",N[i],sep=""))
# g[[i]] <- g[[i]] + geom_point(data = df_mwem, aes(x = mu, y = sigma), color="black")
# }
# do.call(grid.arrange, c(g, ncol=4))
# #On scaled axes
# g = list()
# for(i in 1:length(N)){
# g[[i]] <- ggplot(data = df_mewe_k_list[[i]] %>% filter(sigma > 0.2), aes(x = mu, y = sigma, colour = m, group = m)) #+ ylim(0.2, 1.4) + xlim(1.2,2.6)
# g[[i]] <- g[[i]] + geom_point(alpha = 0.5)
# g[[i]] <- g[[i]] + scale_colour_gradient2(midpoint = 5) + theme(legend.position = "none", plot.title = element_text(size=10))
# g[[i]] <- g[[i]] + theme(axis.title.x=element_text(size=14), axis.text.x=element_text(size=12),axis.title.y=element_text(size=14), axis.text.y=element_text(size=12))
# g[[i]] <- g[[i]] + xlab(expression(gamma)) + ylab(expression(sigma)) + ggtitle(paste("k = ",N[i],sep=""))
# g[[i]] <- g[[i]] + geom_point(data = df_mwem, aes(x = mu, y = sigma), color="black")
# }
# do.call(grid.arrange, c(g, ncol=4))
# On free axes, one panel per selected k value, points coloured by the index of m.
ksubset <- c(1, 10, 100, 1000)
# Positions of the selected values within N, looked up rather than hard-coded.
subset <- match(ksubset, N)
g <- list()
# Loop over the selected panels only: iterating over all of N would run past the end of
# `subset`/`ksubset` and fail on the fifth iteration.
for (i in seq_along(subset)) {
  # Keep rows with sigma > 0.2. Base subsetting is used because neither `%>%` nor a
  # data-frame filter() is provided by the packages attached at the top of this file
  # (a bare filter() would resolve to stats::filter). which() drops NA comparisons.
  plot_data <- df_mewe_k_list[[subset[i]]]
  plot_data <- plot_data[which(plot_data$sigma > 0.2), , drop = FALSE]
  g[[i]] <- ggplot(data = plot_data, aes(x = mu, y = sigma, colour = m, group = m)) +
    geom_point(alpha = 0.5) +
    scale_colour_gradient2(midpoint = 5) +
    theme(legend.position = "none",
          plot.title = element_text(size = 10, hjust = 0.5),
          axis.title.x = element_text(size = 14),
          axis.text.x = element_text(size = 12),
          axis.title.y = element_text(size = 14),
          axis.text.y = element_text(size = 12)) +
    xlab(expression(gamma)) +
    ylab(expression(sigma)) +
    ggtitle(paste0("k = ", ksubset[i])) +
    # Reference lines at the large-sample estimate stored in df_mwek.
    geom_vline(xintercept = df_mwek$mu) +
    geom_hline(yintercept = df_mwek$sigma)
}
# Write the 2-column grid of panels to a PNG file.
png(filename = paste0(prefix, "gamma_n100_all.m_some.k.png"))
do.call(grid.arrange, c(g, ncol = 2))
dev.off()
# #Runtimes
# average_runtimes = matrix(0,length(N),length(m))
# for(i in 1:M){
# average_runtimes = mewe_k_m[[i]][[3]] + average_runtimes
# }
# average_runtimes = 1/M*average_runtimes
#
# #Fixed k, runtime as function of m
# plot(log(m),log(average_runtimes[4,]))
# plot(log(m),log(average_runtimes[8,]))
# plot(m,average_runtimes[8,])
#
#
# #Fixed m, runtime as function of k
# plot(log(N),log(average_runtimes[,4]))
# plot(log(N),log(average_runtimes[,8]))
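# NOTE: the two lines below are boilerplate from the web page this script was copied
# from; they are not part of the analysis and are kept only as comments.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.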