R/generic_mab.R

Defines functions new_MAB sim_bandit

Documented in new_MAB sim_bandit

## THIS SOURCE FILE CONTAINS FUNCTIONS AND S4 CLASSES ASSOCIATED WITH 
## CUSTOMISABLE MAB SIMULATIONS.


#' S4 class holding the output of a generic MAB simulation
#'
#' @slot posteriors List of posterior parameter values (beta, covar, a and b).
#' @slot trajectory Data frame recording the decision trajectory.
#' @slot basis List of basis functions.
#' @slot means List of mean-reward functions, one per lever.
#' @slot bas_type Character string naming the type of basis functions used.
#'
#' @export
methods::setClass(
  Class = "MAB",
  slots = c(
    posteriors = "list",
    trajectory = "data.frame",
    basis      = "list",
    means      = "list",
    bas_type   = "character"
  )
)


#' Construct an object of S4 class MAB
#'
#' Thin constructor: forwards every argument to the matching slot of a new
#' \code{MAB} object via \code{methods::new}.
#'
#' @param posteriors List of posterior parameter values (beta, covar, a and b).
#' @param trajectory Data frame recording the decision trajectory.
#' @param basis List of basis functions.
#' @param means List of mean-reward functions.
#' @param bas_type Character string naming the type of basis functions used.
#'
#' @return An object of the S4 class MAB.
#' @export
new_MAB <- function(posteriors, trajectory, basis, means, bas_type) {
  methods::new(
    Class = "MAB",
    posteriors = posteriors,
    trajectory = trajectory,
    basis = basis,
    means = means,
    bas_type = bas_type
  )
}



# CREATE A PLOT METHOD TO COMPARE THE REAL LEVERS AGAINST THE POSTERIOR APPROXIMATIONS
# The observed rewards are drawn as faint points. Each lever's true
# mean-reward curve is a solid line. The curve implied by that lever's
# posterior `beta` coefficients is a dashed line in the same colour.
#' @export
#' @importFrom graphics plot lines points legend
methods::setMethod("plot", "MAB",
  function(x) {
    traj <- x@trajectory
    post <- x@posteriors
    basis <- x@basis
    means <- x@means
    bas_type <- x@bas_type

    # The "poly" basis is evaluated on [-1, 1]; any other basis on [0, 1].
    x_vals <- if (bas_type == "poly") {
      seq(-1, 1, length.out = 200)
    } else {
      seq(0, 1, length.out = 200)
    }

    # Basis design matrix: one row per grid point, built in a single rbind
    # rather than growing the matrix inside a loop. Presumably one column per
    # basis function (depends on `expand`, defined elsewhere).
    phi <- do.call(rbind, lapply(x_vals, function(v) t(expand(v, basis))))

    # One row per lever: the curve implied by its posterior coefficients.
    # Only the first length(beta) columns of phi are used.
    y_post <- do.call(rbind, unname(lapply(post, function(p) {
      coeff <- as.matrix(p$beta)
      t(coeff) %*% t(phi[, seq_along(coeff), drop = FALSE])
    })))

    # One row per lever: the true mean-reward function on the grid
    # (the functions are called on the whole grid at once).
    y_true <- do.call(rbind, unname(lapply(means, function(f) f(x_vals))))

    n_levers <- NROW(y_true)
    ylims <- range(y_true, y_post)
    # One colour per lever plus a final colour for the reward observations.
    colr <- grDevices::rainbow(n_levers + 1)
    obs_col <- grDevices::rainbow(n_levers + 1, alpha = 0.1)[n_levers + 1]

    plot(traj$x, traj$reward, ylim = ylims, col = obs_col,
         xlab = "x", ylab = "Reward", main = paste("Steps = ", nrow(traj)))

    for (i in seq_len(n_levers)) {
      lines(x_vals, y_true[i, ], col = colr[i])
    }
    for (i in seq_len(NROW(y_post))) {
      lines(x_vals, y_post[i, ], lty = 2, col = colr[i])
    }

    legend("bottom", c(paste("Lever", seq_len(n_levers)), "Reward Obs"),
           lty = c(rep(1, n_levers), NA), pch = c(rep(NA, n_levers), 1),
           col = colr, bty = "n", xpd = TRUE, horiz = TRUE, inset = c(0, 0))
  }
)


#' Generic C-MAB Simulation
#'
#' Builds the requested basis, generates priors with \code{gen_priors()},
#' runs \code{thompson()} and wraps the result in an object of S4 class MAB.
#'
#' @param nsteps The number of steps in the simulation.
#' @param J The number of basis functions to be used in the model.
#' @param means A list containing mean-reward functions for each lever.
#' @param sd A vector with the standard deviation of each lever.
#' @param bas_type Either "fourier" or "poly" (Legendre basis).
#' @param alpha Parameter to control the rate of decay in the prior variance.
#' @param b Hyper-parameter of the gamma distribution. Directly influences the rate
#' of exploration in the simulation.
#'
#' @return An object of S4 class MAB holding the final posteriors, the decision
#' trajectory, the basis functions, the mean-reward functions and the basis type.
#'
#' @export
#'
sim_bandit <- function(nsteps, J, means, sd, bas_type = "poly",
                       alpha = 1, b = 1) {
  # Validate up front, as a single string, so an unsupported basis fails
  # before any simulation work and with a readable one-line message.
  if (!is.character(bas_type) || length(bas_type) != 1L ||
      !(bas_type %in% c("poly", "fourier"))) {
    stop("Basis must be either 'poly' or 'fourier'.", call. = FALSE)
  }
  basis <- if (bas_type == "poly") legendre_basis(J) else fourier_basis(J)

  nlevers <- length(means)
  posteriors <- gen_priors(nlevers, J, bas_type, alpha, b)
  sim <- thompson(nsteps, means, sd, basis, posteriors, bas_type)
  new_MAB(sim$posteriors, sim$traj, basis, means, bas_type)
}
dfcorbin/npbanditC documentation built on March 23, 2020, 5:25 a.m.