R/split_methods.R

Defines functions new_MAB_split

#' MAB_split class
#'
#' S4 container for a MAB simulation: the per-lever mean models, the
#' simulation history, the basis, the true mean-reward functions, and the
#' intervals of the partition.
#'
#' @slot model A list of functions, one per lever, that model that lever's
#'   mean reward. The plot method evaluates each one at a single x value at a
#'   time.
#' @slot trajectory A data frame containing the history of the MAB simulation;
#'   the plot method reads its `x` and `reward` columns.
#' @slot basis A list of the basis functions used in the model.
#' @slot means A list of the true mean-reward functions, one per lever; the
#'   plot method calls each one on a numeric vector of x values.
#' @slot intr A matrix with rows that give the intervals of the partition
#'   (presumably lower bound in column 1 and upper bound in column 2 --
#'   confirm against the code that builds it).
#'
#' @export
methods::setClass(
  Class = "MAB_split",
  slots = list(
    model      = "list",
    trajectory = "data.frame",
    basis      = "list",
    means      = "list",
    intr       = "matrix"
  )
)


#' Construct a MAB_split object
#'
#' Thin wrapper around `methods::new()` that fills every slot of the
#' `MAB_split` class; each argument must be of the class declared for the
#' slot of the same name.
#'
#' @param model List of per-lever model functions.
#' @param trajectory Data frame with the simulation history.
#' @param basis List of basis functions.
#' @param means List of true mean-reward functions.
#' @param intr Matrix of partition intervals.
#' @return A new `MAB_split` object.
#' @noRd
new_MAB_split <- function(model, trajectory, basis, means, intr) {
  methods::new(
    Class = "MAB_split",
    model = model,
    trajectory = trajectory,
    basis = basis,
    means = means,
    intr = intr
  )
}



#' Plot a MAB_split simulation
#'
#' Draws the observed rewards in the trajectory as faint points, overlays
#' each lever's true mean-reward curve (solid) and model curve (dashed) on a
#' 200-point grid over [-1, 0.99], and draws dashed vertical lines at the
#' partition boundaries stored in `intr`.
#'
#' @param x A `MAB_split` object. Its `trajectory` must have columns `x` and
#'   `reward`, and its `means` list must be non-empty.
#' @param y Unused; present only so the method matches the `plot()` generic.
#'   Supplying it triggers a warning.
#' @param ... Graphical parameters for the initial scatter plot. They override
#'   the defaults chosen here (e.g. `main`, `xlab`, `ylim`).
#' @return `x`, invisibly.
methods::setMethod("plot", "MAB_split",
  function(x, y, ...) {
    if (!missing(y)) {
      warning("argument 'y' is ignored when plotting a MAB_split object",
              call. = FALSE)
    }
    traj <- x@trajectory
    intr <- x@intr
    n_levers <- length(x@means)
    if (n_levers == 0L) {
      stop("cannot plot a MAB_split object without mean-reward functions",
           call. = FALSE)
    }

    x_vals <- seq(-1, 0.99, length.out = 200)
    n_grid <- length(x_vals)

    # True means are called once on the whole grid; a length-1 result
    # (e.g. a constant mean) is recycled so it still plots as a line.
    eval_mean <- function(f) {
      vals <- f(x_vals)
      if (length(vals) == 1L) rep(vals, n_grid) else vals
    }
    # Model functions are called one grid point at a time, so they need not
    # be vectorised.
    eval_model <- function(f) vapply(x_vals, f, numeric(1))

    # Both matrices are n_grid x (number of curves): one column per curve.
    y_true <- vapply(x@means, eval_mean, numeric(n_grid))
    y_post <- vapply(x@model, eval_model, numeric(n_grid))

    ylims <- range(y_true, y_post, finite = TRUE)
    colr <- grDevices::rainbow(n_levers + 1)
    # Reward observations use the last rainbow colour at alpha = 0.1.
    obs_col <- grDevices::rainbow(n_levers + 1, alpha = 0.1)[n_levers + 1]

    # Defaults for the scatter plot; any argument supplied in `...` wins.
    plot_args <- list(
      x = traj[["x"]], y = traj[["reward"]], ylim = ylims, col = obs_col,
      xlab = "x", ylab = "Reward", main = paste("Steps =", nrow(traj))
    )
    extra <- list(...)
    plot_args <- c(plot_args[setdiff(names(plot_args), names(extra))], extra)
    do.call(graphics::plot, plot_args)

    graphics::matlines(x_vals, y_true, lty = 1, col = colr[seq_len(n_levers)])
    if (ncol(y_post) > 0L) {
      graphics::matlines(x_vals, y_post, lty = 2,
                         col = colr[seq_len(ncol(y_post))])
    }

    graphics::legend("bottom",
                     c(paste("Lever", seq_len(n_levers)), "Reward Obs"),
                     lty = c(rep(1, n_levers), NA),
                     pch = c(rep(NA, n_levers), 1),
                     col = colr, bty = "n", xpd = TRUE, horiz = TRUE,
                     inset = c(0, 0))

    # Boundaries: intr[1, 1] plus every entry of intr[, 2] (presumably the
    # left edge of the partition and each interval's right edge).
    if (nrow(intr) > 0L && ncol(intr) >= 2L) {
      graphics::abline(v = c(intr[1, 1], intr[, 2]), lty = 2)
    }
    invisible(x)
  }
)
dfcorbin/npbanditC documentation built on March 23, 2020, 5:25 a.m.