R/split.R

Defines functions partition remap get_lims sapply_arg split_thompson merge_post split_bandit

Documented in split_bandit

# Split the range `lims` = c(lower, upper) into `k` equal-width intervals.
#
# Returns a k x 2 matrix with columns "start" and "end", one row per interval,
# ordered from lims[1] upward. Adjacent intervals share their boundary value
# exactly and the last interval ends exactly at lims[2], so every point in
# [lims[1], lims[2]] lies in at least one interval.
partition <- function(k, lims) {
  stopifnot(length(k) == 1L, k >= 1)
  width <- (lims[2] - lims[1]) / k
  start <- lims[1] + (seq_len(k) - 1) * width
  # Take each end from the next start (and lims[2] for the last interval)
  # instead of start + width: the latter can round to just below lims[2]
  # (e.g. k = 3 on c(-1, 1)), leaving x = lims[2] outside every interval.
  end <- c(start[-1], lims[2])
  cbind(start, end)
}

# Linearly map `x` from the interval c(lower, upper) onto [-1, 1]:
# `lower` goes to -1 and `upper` goes to 1. Vectorised over `x`.
remap <- function(x, interval) {
  lo <- interval[1]
  span <- interval[2] - lo
  (x - lo) * 2 / span - 1
}

# Find the interval (row of `intervals`) that contains the scalar `x`.
#
# `intervals` is a two-column matrix of c(start, end) limits, one row per
# group. Intervals are treated as closed and the first matching row wins, so a
# point on a shared boundary is assigned to the lower-indexed interval.
# Returns list(intr = <that row of intervals>, group = <its row index>).
# Signals an error if `x` is not inside any interval (previously this
# silently returned NULL and failed obscurely in the caller).
get_lims <- function(x, intervals) {
  stopifnot(length(x) == 1L)
  hits <- which(intervals[, 1] <= x & x <= intervals[, 2])
  if (length(hits) == 0L) {
    stop("x = ", x, " does not fall inside any interval.", call. = FALSE)
  }
  group <- hits[1]
  list("intr" = intervals[group, ], "group" = group)
}


# Evaluate every function in `means` at the same argument `x`.
#
# means  List of functions, each expected to return a single numeric value.
#
# Returns an unnamed numeric vector with one entry per element of `means`
# (numeric(0) when `means` is empty; the old 1:length() loop errored there).
# vapply() also guarantees each function returned exactly one number.
sapply_arg <- function(x, means) {
  vapply(means, function(f) f(x), numeric(1), USE.NAMES = FALSE)
}


# Simulate `nsteps` rounds of a contextual bandit in which the context range is
# split into groups, each group keeping its own set of per-lever posteriors.
#
# Arguments:
#   nsteps     Number of rounds to simulate; 0 gives an empty trajectory.
#   means      List of functions; means[[j]](x) is the true mean reward of
#              lever j at context x.
#   sd         Reward standard deviations, indexed by lever.
#   intervals  Two-column matrix of interval limits (one row per group), e.g.
#              from partition().
#   group_post List with one element per group; each element is the list of
#              per-lever posteriors consumed by sim_params() / update_post().
#   basis      Basis passed to expand() to build the feature vector.
#
# Returns a list with:
#   traj       data.frame with one row per round: x, action, reward and
#              total_regret (cumulative regret against the true means).
#   group_post The posteriors after all updates.
split_thompson <- function(nsteps, means, sd, intervals, group_post, basis) {

  # Preallocate outputs
  x <- rep(0, nsteps)
  a <- rep(0, nsteps)
  r <- rep(0, nsteps)
  t_regret <- rep(0, nsteps)
  regret_count <- 0
  # seq_len() rather than 1:nsteps so that nsteps = 0 runs no rounds.
  for (t in seq_len(nsteps)) {
    # Draw a context uniformly on (-1, 1) and find the group containing it.
    x[t] <- stats::runif(1, -1, 1)
    lims <- get_lims(x[t], intervals)
    # Remap the context onto [-1, 1] within its group before expanding it.
    x_new <- remap(x[t], lims$intr)
    phi <- expand(x_new, basis)
    post <- group_post[[lims$group]] # Posteriors for the corresponding group
    # Score each lever with the parameters returned by sim_params()
    # (presumably a posterior draw, i.e. the Thompson step) and pull the
    # highest-scoring lever.
    param <- sim_params(post)
    r_exp <- param %*% phi
    a[t] <- which.max(r_exp)
    # Observe a noisy reward around the chosen lever's true mean.
    r[t] <- stats::rnorm(1, means[[a[t]]](x[t]), sd[a[t]])
    post[[a[t]]] <- update_post(phi, r[t], post[[a[t]]])
    group_post[[lims$group]] <- post # Write the updated group back.
    # Regret: best true mean at x[t] minus the chosen lever's true mean.
    r_exp_true <- sapply_arg(x[t], means)
    regret_count <- regret_count + (max(r_exp_true) - r_exp_true[a[t]])
    t_regret[t] <- regret_count
  }

  traj <- data.frame(
    "x" = x, "action" = a, "reward" = r, "total_regret" = t_regret
  )
  list("traj" = traj, "group_post" = group_post)
}


# Combine per-group posteriors into one fitted mean function per lever.
#
# Returns a list of `nlevers` functions (empty when nlevers = 0). The i-th
# function, called with a scalar context x:
#   1. finds the group of `intervals` containing x,
#   2. remaps x onto [-1, 1] within that group and expands it with `basis`,
#   3. returns t(beta) %*% phi, where beta is group_post[[group]][[i]]$beta
#      (a 1 x 1 matrix when beta and phi are equal-length vectors).
merge_post <- function(nlevers, intervals, group_post, basis) {
  # Evaluate the shared inputs now so the closures hold their values instead
  # of unevaluated promises that keep the caller's frame alive.
  force(intervals)
  force(group_post)
  force(basis)

  create_fun <- function(i) { # Takes the index of a lever as an argument
    force(i) # Bind this lever's index now rather than lazily.
    function(x) {
      lims <- get_lims(x, intervals)
      x_new <- remap(x, lims$intr)
      beta <- group_post[[lims$group]][[i]]$beta
      phi <- expand(x_new, basis)
      t(beta) %*% phi
    }
  }
  # seq_len() so that nlevers = 0 yields an empty list, not two bogus closures.
  lapply(seq_len(nlevers), create_fun)
}



#' MAB simulation with partitioned context
#'
#' Splits the context range (from -1 to 1) into \code{k} equal-width intervals
#' and gives each interval its own, independent set of lever posteriors. Each
#' context is remapped onto the range -1 to 1 within its interval before being
#' expanded in the basis returned by \code{legendre_basis(J)}. A simulation of
#' \code{nsteps} steps is then run (\code{split_thompson}), after which the
#' per-interval posteriors are merged into one fitted mean function per lever.
#'
#' @param nsteps The number of steps in the simulation.
#' @param means A list of functions representing the true mean of the levers.
#' @param sd A vector of the standard deviations of each lever.
#' @param k The number of groups to partition the context vector into. Must be
#'   a single whole number of at least 1.
#' @param J The number of basis functions to use inside each group.
#' @param alpha Parameter to control the rate of decay of the prior variance.
#' @param b Hyper parameter for the inverse gamma distribution. Directly
#'   controls the level of exploration.
#'
#' @return An object of S4 class MAB_split, built by \code{new_MAB_split()}
#'   from the merged per-lever functions (\code{model}), the simulation
#'   trajectory (a data.frame with columns x, action, reward and
#'   total_regret), the basis, the true means and the interval limits.
#'
split_bandit <- function(nsteps, means, sd, k, J, alpha = 1, b = 1) {
  # Fail early with a clear message: k < 1 would otherwise give a malformed
  # partition and an invalid list index further down.
  stopifnot(length(k) == 1L, !is.na(k), k >= 1, k == round(k))

  ## INITIAL SETUP
  intervals <- partition(k, c(-1, 1))
  nlevers <- length(means)
  basis <- legendre_basis(J)
  # Construct a separate set of posteriors for each group.
  group_post <- lapply(seq_len(k), function(i) {
    gen_priors(nlevers, J, "poly", alpha, b)
  })

  sim <- split_thompson(nsteps, means, sd, intervals, group_post, basis)
  merged <- merge_post(nlevers, intervals, sim$group_post, basis)
  new_MAB_split(
    model = merged, trajectory = sim$traj, basis = basis, means = means,
    intr = intervals
  )
}
dfcorbin/npbanditC documentation built on March 23, 2020, 5:25 a.m.