# demo/replication_li_2010/demo_yahoo_classes/yahoo_policy_ucb1_alpha.R

# YahooUCB1AlphaPolicy: UCB1-style bandit policy with a tunable exploration
# weight `alpha`. An arm's score is its running mean reward plus
# alpha / sqrt(number of pulls); the original comment attributes this form to
# the Li 2010 paper. Arms that have never been pulled are tried first.
#
# The class is non-portable (portable = FALSE), so a bare `action` inside
# get_action is looked up on the object itself.
# NOTE(review): presumably `action` is supplied by the parent Policy class --
# confirm there.
#' @export
YahooUCB1AlphaPolicy <- R6::R6Class(
  portable = FALSE,
  class = FALSE,
  inherit = Policy,
  public = list(
    # Exploration weight applied to the 1 / sqrt(n) bonus term.
    alpha = NULL,
    class_name = "YahooUCB1AlphaPolicy",

    # Run the parent initializer, then store the exploration weight.
    initialize = function(alpha) {
      super$initialize()
      self$alpha <- alpha
    },

    # Declare the per-arm parameter template: pull count `n` and running
    # `mean`, both starting at 0.
    # NOTE(review): presumably the framework expands this template into
    # self$theta, one entry per arm -- confirm in Policy.
    set_parameters = function(context_params) {
      self$theta_to_arms <- list("n" = 0, "mean" = 0)
    },

    # Choose an arm among context$arms and return it as action$choice.
    get_action = function(t, context) {
      candidates <- context$arms
      pulls      <- self$theta$n
      untried    <- which(pulls[candidates] == 0)

      if (length(untried) > 0) {
        # At least one candidate has no observations yet: pick among those.
        pick <- sample_one_of(untried)
      } else {
        means  <- self$theta$mean
        scores <- numeric(length(candidates))
        for (i in seq_along(candidates)) {
          k <- candidates[i]
          # Upper confidence score: running mean plus exploration bonus.
          scores[i] <- means[[k]] + self$alpha / sqrt(pulls[[k]])
        }
        pick <- which_max_tied(scores)
      }

      action$choice <- candidates[pick]
      action
    },

    # Fold the observed reward into the chosen arm's pull count and running
    # mean, then return the updated theta.
    set_reward = function(t, context, action, reward) {
      chosen    <- action$choice
      payoff    <- reward$reward
      new_count <- self$theta$n[[chosen]] + 1
      old_mean  <- self$theta$mean[[chosen]]

      self$theta$n[[chosen]]    <- new_count
      # Incremental mean: shift the old mean towards the new observation.
      self$theta$mean[[chosen]] <- old_mean + (payoff - old_mean) / new_count

      self$theta
    }
  )
)

# Try the contextual package in your browser
#
# Any scripts or data that you put into this service are public.
#
# contextual documentation built on July 26, 2020, 1:06 a.m.