R/visual.R

Defines functions arm_CI gg_color_hue plot_arms

Documented in arm_CI plot_arms

#' Plot mab_sim Output
#'
#' Draws the following layers:
#' * for every arm, the true reward curve and its merged approximation over a
#'   grid of 200 points on [-1, 1], told apart by linetype;
#' * a shaded confidence band for one chosen arm;
#' * the contexts recorded in `history`, drawn as a strip of points below the
#'   curves;
#' * dashed vertical lines at the partition boundaries of the selected arms.
#'
#' @param arms A list of objects of class "arm".
#' @param history A dataframe giving the history of the simulation. It must
#'   contain a column `x` (the context) and a column `a` (the index of the arm
#'   pulled). `a` is used to colour each point and to offset it vertically.
#' @param display_part Indices of the arms to display the partitions for.
#' @param CI_index Index of the arm to display confidence interval for.
#' @param reward_funs Optional list of true reward functions, one per arm. Each
#'   function is called with the whole grid of x values. When `NULL` (the
#'   default), a variable named `reward_funs` is looked up from the function's
#'   enclosing environment, as earlier versions did implicitly.
#'
#' @return A ggplot object.
#' @export
#' @import ggplot2
#' @include class_defn.R
plot_arms <- function(arms, history, display_part, CI_index,
                      reward_funs = NULL) {
  if (is.null(reward_funs)) {
    # Backward compatibility: this function used to read `reward_funs` as a
    # free variable. parent.env(environment()) is the closure's enclosing
    # environment, so the lookup chain matches the old behaviour.
    reward_funs <- get("reward_funs", envir = parent.env(environment()))
  }

  approx_funs <- lapply(arms, merge_arm)
  num_arms <- length(arms)
  x <- seq(-1, 1, length.out = 200)

  # Confidence band for a single arm. arm_CI() handles one context value at a
  # time and returns c(lower, upper), so vapply() gives a 2 x length(x) matrix.
  ribbon_colr <- gg_color_hue(num_arms, alpha = 0.2)
  ci <- vapply(x, arm_CI, numeric(2), arm = arms[[CI_index]])
  ribbon <- data.frame(x = x, ymin = ci[1, ], ymax = ci[2, ])

  pl <- ggplot() +
    geom_ribbon(
      data = ribbon, mapping = aes(x = x, ymin = ymin, ymax = ymax),
      col = ribbon_colr[CI_index], fill = ribbon_colr[CI_index]
    )

  # True and approximate curves for every arm, in long format for ggplot.
  curves <- lapply(seq_len(num_arms), function(i) {
    arm_label <- paste("Arm", i, sep = " ")
    rbind(
      data.frame(x = x, y = reward_funs[[i]](x), Arm = arm_label,
                 Type = "True"),
      data.frame(x = x, y = approx_funs[[i]](x), Arm = arm_label,
                 Type = "Approx")
    )
  })
  curve_data <- do.call(rbind, curves)
  pl <- pl +
    geom_line(
      data = curve_data,
      mapping = aes(x = x, y = y, color = Arm, linetype = Type)
    )

  # Visited contexts form a strip below the lowest curve. Each point is shifted
  # down by 0.2 per arm index, so each arm's pulls sit on their own row.
  if (nrow(history) > 0) {
    arm_pulled <- history[["a"]]
    point_colr <- gg_color_hue(num_arms, alpha = 0.1)
    point_data <- data.frame(
      x = history[["x"]],
      y = min(curve_data$y) - 0.2 * arm_pulled
    )
    pl <- pl +
      geom_point(data = point_data, mapping = aes(x = x, y = y),
                 col = point_colr[arm_pulled])
  }

  # Partition boundaries for the requested arms. Column 2 of each partition
  # supplies the line positions. Its last entry is dropped (presumably the
  # upper domain edge rather than an interior split).
  part_colr <- gg_color_hue(num_arms, alpha = 0.5)
  vline_list <- lapply(display_part, function(i) {
    part <- arms[[i]]@partition
    if (nrow(part) == 1) {
      return(NULL)
    }
    bounds <- part[, 2]
    data.frame(vlines = bounds[-length(bounds)], colr = part_colr[i],
               stringsAsFactors = FALSE)
  })
  partition_lines <- do.call(rbind, vline_list)
  # Only add the layer when there is something to draw, rather than handing
  # geom_vline() a NULL xintercept and colour.
  if (!is.null(partition_lines) && nrow(partition_lines) > 0) {
    pl <- pl +
      geom_vline(xintercept = partition_lines$vlines,
                 col = partition_lines$colr, linetype = "dashed")
  }

  pl + ggtitle(paste("#Steps =", nrow(history), sep = " "))
}

# Evenly spaced HCL hues, matching ggplot2's default discrete hue palette.
#
# Returns `n` colour strings. Their hues are spread evenly around the colour
# wheel starting at 15 degrees, with luminance 65, chroma 100 and transparency
# `alpha` (passed straight to grDevices::hcl()). Returns character(0) when
# n is 0.
gg_color_hue <- function(n, alpha) {
  # n + 1 points from 15 to 375 go once around the circle. The last equals the
  # first (375 = 15 + 360), so only the first n are kept.
  hues <- seq(15, 375, length.out = n + 1)
  # seq_len() rather than 1:n, so that n = 0 yields no colours instead of one.
  grDevices::hcl(h = hues, l = 65, c = 100, alpha = alpha)[seq_len(n)]
}

#' Confidence Interval for Arm Approximation (variance known)
#'
#' Computes an interval for the arm's approximated reward at one context value.
#' The steps are:
#' 1. find the partition interval containing the value and remap the value
#'    into it;
#' 2. expand the remapped value in the basis returned by legendre_basis();
#' 3. combine the basis vector with that interval's coefficient vector and
#'    scaled covariance matrix.
#'
#' @param x The context observation (a single value).
#' @param arm An object of S4 class "arm".
#' @param z Multiplier applied to the standard error. The default, 1.96, gives
#'   an approximately 95 percent normal interval.
#'
#' @return A numeric vector of length two: the lower and upper bounds of the
#'   confidence interval.
#' @export
#'
arm_CI <- function(x, arm, z = 1.96) {
  part <- arm@partition
  index <- allocate(x, part)
  # Express x relative to the partition row it was allocated to.
  x_map <- remap(x, part[index, ])
  dis <- arm@distributions[[index]]
  basis <- legendre_basis(length(dis@beta) - 1)

  # Plug-in noise variance b / (a - 1).
  # NOTE(review): this is finite and positive only when a > 1. Confirm callers
  # guarantee that; otherwise the interval contains NaN/Inf.
  noise_var <- dis@b / (dis@a - 1)
  covar <- noise_var * dis@Sigma
  phi <- expand(x_map, basis)

  # drop() turns the 1 x 1 matrix products into plain scalars.
  pred_var <- drop(t(phi) %*% covar %*% phi)
  pred_mean <- drop(t(phi) %*% dis@beta)
  half_width <- z * sqrt(pred_var)
  c(pred_mean - half_width, pred_mean + half_width)
}
dfcorbin/MABsim documentation built on April 26, 2020, 8:26 a.m.