inst/chapter4/4coupled_subsets.R

## this implements a coupling of the sampler described in Section 6.2 of
## "Rapidly Mixing Markov Chains: A Comparison of Techniques" by Venkatesan Guruswami;
## the sampler draws subsets of size k uniformly from a set of n elements
rm(list=ls())
set.seed(1)
library(couplingsmontecarlo)
graphsettings <- set_theme_chapter4()
library(ggridges)
library(reshape2)
library(dplyr)
library(doParallel)
library(doRNG)
registerDoParallel(cores = detectCores()-2)
## number of elements in total
n <- 10
fullset <- 1:n
## number of elements in subset
k <- 3
## initial state
rinit <- function() 1:k
state_current <- rinit()
## Markov kernel
single_kernel <- function(state){
    u <- runif(1)
    if (u < 0.5){
        ## lazy step: stay put with probability 1/2
        return(state)
    } else {
        ## otherwise swap a uniformly chosen element of the subset
        ## with a uniformly chosen element of its complement
        i <- sample(x = state, size = 1)
        j <- sample(x = setdiff(fullset, state), size = 1)
        return(c(setdiff(state, i), j))
    }
}
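## quick sanity check (illustrative addition, not in the original script):
## one application of the kernel should return a valid k-subset of {1,...,n}
check_state <- single_kernel(rinit())
stopifnot(length(check_state) == k, all(check_state %in% fullset))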
## MCMC run for 'nmcmc' iterations
nmcmc <- 1e4
state_history <- matrix(NA, nrow = nmcmc, ncol = k)
state_history[1,] <- state_current
for (imcmc in 2:nmcmc){
    state_current <- single_kernel(state_current)
    state_history[imcmc,] <- state_current
}
##
## trace plot of the first 50 iterations, displaying each state as a set
## (elements are shifted down by one so the labels start at zero)
df <- data.frame(iteration = 1:50)
df$label <- apply(state_history[df$iteration,], 1, function(v) paste0("{", paste(sort(v)-1, collapse = ","), "}"))
ggplot(df, aes(y = iteration, x = 0, label = label)) + geom_text() + scale_y_reverse()

## Now, how do we know whether the chains converged?
burnin <- 1e3
table(as.numeric(state_history[burnin:nmcmc,])) / ((nmcmc-burnin+1)*k)
## the frequencies of occurrence of each element look pretty close to uniform ('1/n')
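## as a quick numeric summary (illustrative addition): the largest deviation
## of the element frequencies from the uniform value 1/n
freqs_elements <- table(as.numeric(state_history[burnin:nmcmc,])) / ((nmcmc-burnin+1)*k)
max(abs(freqs_elements - 1/n))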
## for small n and k we can actually list all the possible subsets
## number of subsets with k elements out of n
choose(n, k)
all_subsets <- combn(1:n, k, simplify = F)
## all_subsets is a list of 'choose(n, k)' vectors of sorted indices

sorted_state_history <- t(apply(state_history, 1, sort))
## estimate the post-burn-in frequency of each subset, in parallel
freq_subsets <- foreach (isubset = 1:choose(n,k), .combine = c) %dopar% {
    subset <- all_subsets[[isubset]]
    ## count how many times this subset appeared post-burnin
    freq_subset_ <- sum(apply(sorted_state_history[burnin:nmcmc,], 1, function(v) all(v == subset)))
    freq_subset_ / (nmcmc-burnin+1)
}
plot(freq_subsets, ylim = c(0, 2/choose(n,k)))
abline(h = 1/choose(n,k))
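## rough goodness-of-fit check (illustrative addition): chi-square statistic of
## the subset counts against the uniform distribution; the chain samples are
## autocorrelated, so the usual chi-square calibration is only indicative
counts_subsets <- round(freq_subsets * (nmcmc-burnin+1))
chisq.test(counts_subsets)$statistic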
## at this point the sampler seems to be doing OK, but we still don't know how many
## iterations were actually necessary as burn-in

## function to sample two elements from maximal coupling of uniform
## distributions on subsets of {1,...,n}
max_unif_sets <- function(subset1, subset2){
    s1 <- length(subset1)
    s2 <- length(subset2)
    # original vectors of probabilities
    p1 <- rep(0, n)
    p2 <- rep(0, n)
    p1[subset1] <- 1/s1
    p2[subset2] <- 1/s2
    # decompose into the common (overlap) part
    common_part <- pmin(p1, p2)
    overlap <- sum(common_part)
    # and residual parts; if the two subsets are identical the overlap equals one
    # and the residuals are never used, since the common branch below is then taken
    r1 <- (p1-common_part)/(1-overlap)
    r2 <- (p2-common_part)/(1-overlap)
    # sample a pair of indices: with probability 'overlap' draw a common index,
    # otherwise draw the two indices independently from the residuals
    if (runif(1) < overlap){
        index <- sample(x = 1:n, size = 1, prob = common_part/overlap)
        return(c(index, index))
    } else {
        index1 <- sample(x = 1:n, size = 1, prob = r1)
        index2 <- sample(x = 1:n, size = 1, prob = r2)
        return(c(index1, index2))
    }
}
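## sanity check (illustrative addition): under a maximal coupling the two draws
## coincide with probability equal to the overlap sum(pmin(p1, p2)); for the
## subsets {1,2,3} and {1,5,9} with k = 3 that overlap is 1/3
draws <- t(replicate(1e4, max_unif_sets(c(1,2,3), c(1,5,9))))
mean(draws[,1] == draws[,2])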

## coupled kernel
coupled_kernel <- function(state1, state2){
    ## common uniform draw: both chains hold, or both attempt a swap
    u <- runif(1)
    if (u < 0.5){
        return(list(state1 = state1, state2 = state2))
    } else {
        ## sample i and iprime from a maximal coupling of the uniform
        ## distributions on the two current subsets
        i12 <- max_unif_sets(state1, state2)
        i1 <- i12[1]; i2 <- i12[2]
        ## sample j and jprime from a maximal coupling of the uniform
        ## distributions on the two complements
        j12 <- max_unif_sets(setdiff(fullset, state1), setdiff(fullset, state2))
        j1 <- j12[1]; j2 <- j12[2]
        return(list(state1 = c(setdiff(state1, i1), j1),
                    state2 = c(setdiff(state2, i2), j2)))
    }
}
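## sanity check (illustrative addition): when the two states are equal as sets,
## the maximal couplings return identical draws, so the chains stay together
## once they have met (the coupling is 'faithful')
cs <- coupled_kernel(c(1,2,3), c(3,1,2))
stopifnot(setequal(cs$state1, cs$state2))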

## construction of two chains with a lag L
## using single_kernel and coupled_kernel
sample_meetingtime <- function(single_kernel, coupled_kernel, rinit, lag = 1, max_iterations = Inf){
    starttime <- Sys.time()
    # initialize two chains
    state1 <- rinit(); state2 <- rinit()
    # move first chain for 'lag' iterations
    time <- 0
    for (t in 1:lag){
        time <- time + 1
        state1 <- single_kernel(state1)
    }
    # move two chains until meeting (or until max_iterations)
    meetingtime <- Inf
    # two chains could already be identical by chance
    if (all(sort(state1) == sort(state2))) meetingtime <- lag
    while (is.infinite(meetingtime) && (time < max_iterations)){
        time <- time + 1
        # use coupled kernel
        coupledstates <- coupled_kernel(state1, state2)
        state1 <- coupledstates$state1
        state2 <- coupledstates$state2
        # check if meeting has occurred
        if (all(sort(state1) == sort(state2))) meetingtime <- time
    }
    currenttime <- Sys.time()
    elapsedtime <- as.numeric(difftime(currenttime, starttime, units = "secs"))
    return(list(meetingtime = meetingtime, elapsedtime = elapsedtime))
}
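## single draw (illustrative addition): one lagged meeting time with the default lag of one
sample_meetingtime(single_kernel, coupled_kernel, rinit)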

## generate a number of meeting times, for a certain lag
nrep <- 1e3
lag <- 30
meetings <- foreach(irep = 1:nrep) %dorng% sample_meetingtime(single_kernel, coupled_kernel, rinit, lag)
meeting_times <- sapply(meetings, function(x) x$meetingtime)
## plot meeting times
ghist <- qplot(x = meeting_times - lag, geom = "blank") + geom_histogram(aes(y=..density..))
ghist <- ghist + xlab("meeting time - lag")
ghist <- ghist + theme_minimal()
ghist
## compute TV upper bounds
tv_upper_bound_estimates <- function(meeting_times, L, t){
    return(mean(pmax(0,ceiling((meeting_times-L-t)/L))))
}
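## worked example (illustrative addition): with lag L = 30 and a single meeting
## time of 45, the bound at t = 10 is max(0, ceiling((45-30-10)/30)) = 1,
## and it drops to zero for all t >= 15
tv_upper_bound_estimates(45, 30, 10)
tv_upper_bound_estimates(45, 30, 15)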
niter <- 50
upperbounds <- sapply(1:niter, function(t) tv_upper_bound_estimates(meeting_times, lag, t))
g_tvbounds <- qplot(x = 1:niter, y = upperbounds, geom = "line")
g_tvbounds <- g_tvbounds + ylab("TV upper bounds") + xlab("iteration")
g_tvbounds <- g_tvbounds + theme_minimal() + scale_y_continuous(breaks = (1:10)/10, limits = c(0,1.1))

# pdf("../subsets.meetingtimes.pdf")
ghist
# dev.off()

# pdf("../subsets.tvbounds.pdf")
hist_noplot <- hist(meeting_times - lag, plot = F, nclass = 50, prob = T)
g_tvbounds + geom_ribbon(data=data.frame(x = hist_noplot$mids,
                                         ymin = rep(0, length(hist_noplot$density)),
                                         y = hist_noplot$density/max(hist_noplot$density)),
                         aes(x= x, ymin = ymin, ymax = y, y=NULL), alpha = 0.5, fill = graphsettings$colors[1]) + geom_line()
# dev.off()



hist_noplot <- hist(meeting_times - lag, plot = F, nclass = 30)
xgrid <- c(min(hist_noplot$breaks), hist_noplot$mids, max(hist_noplot$breaks))
densitygrid <- c(0, hist_noplot$density, 0)
g_tvbounds <- qplot(x = 1:niter, y = upperbounds, geom = "line")
g_tvbounds <- g_tvbounds + ylab("TV upper bounds") + xlab("iteration")
g_tvbounds <- g_tvbounds + scale_y_continuous(breaks = c(0,1), limits = c(0,1.1))
g_tvbounds <- g_tvbounds + geom_ribbon(data=data.frame(x = xgrid,
                                                       ymin = rep(0, length(xgrid)),
                                                       y = densitygrid/max(densitygrid)),
                                       aes(x= x, ymin = ymin, ymax = y, y=NULL), alpha = .3, fill = graphsettings$colors[2]) + geom_line()
g_tvbounds <- g_tvbounds + theme(axis.text.x = element_text(size = 20), axis.title.x = element_text(size = 20),
                                 axis.text.y = element_text(size = 20), axis.title.y = element_text(size = 20, angle = 90))
g_tvbounds
pdf("../subsets.tvbounds.pdf", width = 10, height = 5)
print(g_tvbounds)
dev.off()



## this suggests that 20 or 30 steps might be enough;
## we can check that by running independent chains for that many steps
## and seeing whether we obtain a good approximation of the target (uniform) distribution
nrep <- 1e4
nmcmc <- 20
parallel_chains <- foreach(irep = 1:nrep, .combine = rbind) %dorng% {
    state_current <- rinit()
    for (imcmc in 1:nmcmc){
        state_current <- single_kernel(state_current)
    }
    state_current
}
## frequencies of each element
table(as.numeric(parallel_chains)) / (nrep*k)
## compute frequencies of each subset
sorted_parallel_chains <- t(apply(parallel_chains, 1, sort))
freq_subsets <- foreach (isubset = 1:choose(n,k), .combine = c) %dopar% {
    subset <- all_subsets[[isubset]]
    ## count how many of the independent chains ended in this subset
    freq_subset_ <- sum(apply(sorted_parallel_chains, 1, function(v) all(v == subset)))
    freq_subset_ / nrep
}
plot(freq_subsets, ylim = c(0, 2/choose(n,k)))
abline(h = 1/choose(n,k))
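## largest deviation from the uniform value 1/choose(n,k) (illustrative addition)
max(abs(freq_subsets - 1/choose(n,k)))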

## manual illustration of the lagged-chain construction:
## draw two initial states and advance the first chain by 'lag' iterations
state1 <- rinit(); state2 <- rinit()
# move first chain for 'lag' iterations
time <- 0
for (t in 1:lag){
    time <- time + 1
    state1 <- single_kernel(state1)
}

## helper to display a state as a set
printsubset <- function(v) paste0("{", paste(sort(v), collapse = ","), "}")
state1 <- c(1,2,3)
state2 <- c(1,5,9)
printsubset(state1)
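## and the second example state, for comparison (illustrative addition)
printsubset(state2)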