#' Plot mab_sim Output
#'
#' Builds a ggplot of a simulation's state on the context grid x in [-1, 1].
#' The plot shows:
#' * the true and approximated reward curve of every arm;
#' * the confidence band returned by `arm_CI()` for one arm;
#' * the contexts pulled so far, drawn below the curves at a different height
#'   per arm;
#' * optionally, the partition boundaries of selected arms, as dashed
#'   vertical lines.
#'
#' @param arms A list of objects of class "arm".
#' @param history A dataframe giving the history of the simulation. Its
#'   columns `x` (context) and `a` (index of the pulled arm) are used.
#' @param display_part Indices of the arms to display the partitions for.
#' @param CI_index Index of the arm to display confidence interval for.
#' @param reward_funs Optional list of true reward functions, one per arm.
#'   Each is called once on the whole grid of 200 points. When `NULL` (the
#'   default), a variable named `reward_funs` is looked up from the
#'   environment `plot_arms` was defined in and its parents. This is the same
#'   lookup earlier versions (which had no such argument) performed.
#'
#' @return A ggplot object.
#' @export
#' @import ggplot2
#' @include class_defn.R
plot_arms <- function(arms, history, display_part, CI_index,
                      reward_funs = NULL) {
  if (is.null(reward_funs)) {
    # Backwards compatibility: resolve `reward_funs` exactly like the old free
    # variable did (enclosing environment first, then its parents).
    reward_funs <- get("reward_funs", envir = parent.env(environment()))
  }
  approx_funs <- lapply(arms, merge_arm)
  num_arms <- length(arms)
  x <- seq(-1, 1, length.out = 200)

  # Confidence band for arm `CI_index`. arm_CI() takes one context at a time,
  # so evaluate it per grid point; column t of `bounds` is (lower, upper).
  ribbon_colr <- gg_color_hue(num_arms, alpha = 0.2)
  ci_arm <- arms[[CI_index]]
  bounds <- vapply(x, function(xt) arm_CI(xt, ci_arm)[1:2], numeric(2))
  ribbon <- data.frame(x = x, ymin = bounds[1, ], ymax = bounds[2, ])
  pl <- ggplot() +
    geom_ribbon(data = ribbon, mapping = aes(x = x, ymin = ymin, ymax = ymax),
                col = ribbon_colr[CI_index], fill = ribbon_colr[CI_index])

  # True and approximated curves in long format (one row per grid point, arm
  # and curve type) so ggplot can map Arm to colour and Type to linetype.
  curves <- do.call(rbind, lapply(seq_len(num_arms), function(i) {
    arm_label <- paste("Arm", i)
    rbind(
      data.frame(x = x, y = reward_funs[[i]](x),
                 Arm = arm_label, Type = "True"),
      data.frame(x = x, y = approx_funs[[i]](x),
                 Arm = arm_label, Type = "Approx")
    )
  }))
  pl <- pl + geom_line(data = curves,
                       mapping = aes(x = x, y = y, color = Arm,
                                     linetype = Type))

  # Pulled contexts: points for arm a sit 0.2 * a below the lowest curve value.
  point_colr <- gg_color_hue(num_arms, alpha = 0.1)
  point_data <- history[c("x", "a")]
  point_data$y <- min(curves$y) - 0.2 * point_data$a
  point_data$colr <- point_colr[point_data$a]
  pl <- pl + geom_point(data = point_data, mapping = aes(x = x, y = y),
                        col = point_data$colr)

  # Partition boundaries: column 2 of every partition row except the last
  # (presumably the cells' upper edges, the last being the end of the domain).
  part_colr <- gg_color_hue(num_arms, alpha = 0.5)
  partition_lines <- do.call(rbind, lapply(display_part, function(i) {
    part <- arms[[i]]@partition
    if (nrow(part) == 1) {
      return(NULL)
    }
    data.frame(vlines = part[-nrow(part), 2], colr = part_colr[i])
  }))
  # Skip the layer entirely when there is nothing to draw, instead of handing
  # geom_vline() NULL intercepts and colours.
  if (!is.null(partition_lines) && nrow(partition_lines) > 0) {
    pl <- pl + geom_vline(xintercept = partition_lines$vlines,
                          col = partition_lines$colr, linetype = "dashed")
  }

  pl + ggtitle(paste("#Steps =", nrow(history)))
}
#' Evenly spaced HCL colours
#'
#' Returns `n` colours whose hues are evenly spaced around the HCL colour
#' wheel, starting at 15 degrees. Luminance is 65 and chroma is 100. These
#' mirror ggplot2's default discrete hue palette.
#'
#' @param n Number of colours (a non-negative whole number).
#' @param alpha Opacity in [0, 1], passed on to `grDevices::hcl()`.
#' @return A character vector of `n` colour strings; empty when `n` is 0.
#' @noRd
gg_color_hue <- function(n, alpha) {
  # n + 1 hues from 15 to 375 degrees. The last one wraps back onto 15, so it
  # is dropped to keep the colours distinct.
  hues <- seq(15, 375, length.out = n + 1)
  # seq_len(n), unlike 1:n, yields no colours when n is 0.
  grDevices::hcl(h = hues, l = 65, c = 100, alpha = alpha)[seq_len(n)]
}
#' Confidence Interval for Arm Approximation (variance known)
#'
#' Computes a normal-approximation interval for the arm's fitted reward at a
#' single context. The steps are:
#' 1. Assign the context to a partition cell with `allocate()`.
#' 2. Map it into that cell with `remap()`.
#' 3. Expand it in the basis from `legendre_basis()`, whose order is one less
#'    than the cell's number of coefficients.
#'
#' The noise variance is plugged in as `b / (a - 1)` from the cell's
#' distribution and treated as known.
#'
#' @param x The context observation (a single value).
#' @param arm An object of S4 class "arm".
#' @param z Standard-normal multiplier for the half-width. The default 1.96
#'   gives an approximately 95% two-sided interval.
#'
#' @return A vector of length two giving the confidence interval
#'   (lower bound, upper bound).
#' @export
#'
arm_CI <- function(x, arm, z = 1.96) {
  part <- arm@partition
  index <- allocate(x, part)
  x_map <- remap(x, part[index, ])
  dis <- arm@distributions[[index]]
  basis <- legendre_basis(length(dis@beta) - 1)
  # Plug-in noise variance b / (a - 1). NOTE(review): presumably the mean of
  # an inverse-gamma posterior on sigma^2; it is only finite when a > 1.
  noise_var <- dis@b / (dis@a - 1)
  covar <- noise_var * dis@Sigma
  phi <- expand(x_map, basis)
  # Variance phi' Cov phi and mean phi' beta of the fitted value. drop()
  # turns the 1x1 matrix products into plain numbers.
  fit_var <- drop(t(phi) %*% covar %*% phi)
  fit_mean <- drop(t(phi) %*% dis@beta)
  half_width <- z * sqrt(fit_var)
  c(fit_mean - half_width, fit_mean + half_width)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.