R/armed_bandit_var.R

Defines functions cram_bandit_var

Documented in cram_bandit_var

#' Cram Bandit Variance of the Policy Value Estimate
#'
#' This function implements the crammed variance estimate of the policy value
#' estimate for the contextual armed bandit on-policy evaluation setting.
#'
#' @param pi An array of shape (T × B, T, K) or (T × B, T),
#' where T is the number of learning steps (or policy updates),
#' B is the batch size, K is the number of arms,
#' T x B is the total number of contexts.
#' If 3D, pi[j, t, a] gives the probability that
#' the policy pi_t assigns arm a to context X_j.
#' If 2D, pi[j, t] gives the probability that the policy pi_t
#' assigns arm A_j (arm actually chosen under X_j in the history)
#' to context X_j. Please see vignette for more details.
#' @param reward A vector of observed rewards of length T x B
#' @param arm A vector of length T x B indicating which arm was selected in each context
#' @param batch (Optional) A vector or integer. If a vector, gives the
#' batch assignment for each context. If an integer, interpreted as the batch
#' size and contexts are assigned to a batch in the order of the dataset.
#' Default is 1.
#' @return The crammed variance estimate of the policy value estimate.
#' @export

cram_bandit_var <- function(pi, reward, arm, batch=1) {
  dims_result <- dim(pi)

  if (is.numeric(batch) && length(batch) == 1) {
    n <- dims_result[1]
    # `batch` is an integer, interpret it as `batch_size`
    batch_size <- batch  # Guaranteed to be an integer since n is divisible by B
    nb_batch <- n / batch_size

    # Assign batch indices in order without shuffling
    indices <- 1:n
    group_labels <- rep(1:nb_batch, each = batch_size)  # Assign first B to batch 1, etc.

    # Split indices into batches
    batches <- split(indices, group_labels)

  } else {

    batch_assinement <- unlist(batch)
    batches <- split(1:n, batch_assinement)
    nb_batch <- length(batches)
    batch_size <- length(batches[[1]])

  }

  if (batch_size == 1) {

    if (length(dims_result) == 3) {

      # Extract relevant dimensions
      nb_arms <- dims_result[3]
      nb_timesteps <- dims_result[2]


      ## POLICY SLICED: remove the arm dimension as Xj is associated to Aj

      # pi:
      # for each row j, column t, depth a, gives pi_t(Xj, a)

      # We do not need the last column and the first row
      # We only need, for each row j, pi_t(Xj, Aj), where Aj is the arm chosen from context j
      # Aj is the jth element of the vector arm, and corresponds to a depth index

      # drop = False to maintain 3D structure

      pi <- pi[-1, -ncol(pi), , drop = FALSE]

      depth_indices <- arm[2:nb_timesteps]

      pi <- extract_2d_from_3d(pi, depth_indices)

    } else {

      nb_timesteps <- dims_result[2]

      pi <- pi[-1, -ncol(pi), drop = FALSE]

    }

    # pi is now a (T-1) x (T-1) matrix

    ## POLICY DIFFERENCE

    # Before doing policy differences with lag 1, we add a column of 0
    # such that the first policy difference corresponds to pi_1
    pi <- cbind(0, pi)

    pi_diff <- pi[, -1] - pi[, -ncol(pi)]

    # Vector of terms for each column
    policy_diff_weights <- 1 / (nb_timesteps - (1:(nb_timesteps - 1)))

    # Multiply each column of pi_diff by corresponding policy_diff_weight
    pi_diff <- sweep(pi_diff, 2, policy_diff_weights, FUN = "*")

    # Cumulative sum
    pi_diff <- t(apply(pi_diff, 1, cumsum))

    # pi_diff is a (T-1) x (T-1) matrix


    ## MULTIPLY by Rk / pi_k-1

    # Get diagonal elements from pi
    pi_diag <- pi[cbind(1:(nrow(pi)), 2:(ncol(pi)))]

    # Create multipliers using vectorized operations
    multipliers <- (1 / pi_diag) * reward[2:length(reward)]

    # Apply row-wise multiplication using efficient matrix operation
    # mult_pi_diff <- pi_diff * multipliers  # Works via R's recycling rules (most efficient)
    mult_pi_diff <- sweep(pi_diff, 1, multipliers, FUN = "*")


    ## VARIANCE PER COLUMN

    # Create the mask for batch indices
    mask <- matrix(1, nrow = nrow(mult_pi_diff), ncol = ncol(mult_pi_diff))

    # Assign NaN to upper triangle elements (excluding diagonal)
    mask[upper.tri(mask, diag = FALSE)] <- NaN

    # Apply the mask to policy_diff
    mult_pi_diff <- mult_pi_diff * mask

    # Calculate variance for each column
    # column_variances <- apply(mult_pi_diff, 2, function(x) var(x, na.rm = TRUE))
    # column_variances <- apply(mult_pi_diff, 2, function(x) {
    #   n <- sum(!is.na(x))  # Count of non-NA values
    #   sample_var <- var(x, na.rm = TRUE)  # Compute sample variance (divides by n-1)
    #   sample_var * (n - 1) / n  # Convert to population variance (divides by n)
    # })

    column_variances <- apply(mult_pi_diff, 2, function(x) {
      n <- sum(!is.na(x))  # Count of non-NA values
      if (n == 1){
        return(0)
      } else {
        sample_var <- var(x, na.rm = TRUE)  # Compute sample variance (divides by n-1)
        sample_var <- sample_var * (n - 1) / n  # Convert to population variance (divides by n)
        return(sample_var)
      }
    })

    # print(column_variances)
    total_variance <- sum(column_variances)

  } else { # batch size higher than 1

    if (length(dims_result) == 3) {

      # Extract relevant dimensions
      nb_arms <- dims_result[3]
      nb_timesteps <- dims_result[2]

      sample_size <- nb_timesteps * batch_size


      ## POLICY SLICED: remove the arm dimension as Xj is associated to Aj

      # pi:
      # for each row j, column t, depth a, gives pi_t(Xj, a)

      # We do not need the last column and the first row
      # We only need, for each row j, pi_t(Xj, Aj), where Aj is the arm chosen from context j
      # Aj is the jth element of the vector arm, and corresponds to a depth index

      # drop = False to maintain 3D structure

      # pi <- pi[-1, -ncol(pi), , drop = FALSE]
      pi <- pi[-(1:batch_size), -ncol(pi), , drop = FALSE]


      # depth_indices <- arm[2:nb_timesteps]
      depth_indices <- arm[(batch_size+1):sample_size]

      pi <- extract_2d_from_3d(pi, depth_indices)

    } else {

      pi <- pi[, colSums(is.na(pi)) == 0, drop = FALSE]

      dims_result <- dim(pi)

      nb_timesteps <- dims_result[2]

      sample_size <- nb_timesteps * batch_size

      # pi <- pi[-1, -ncol(pi), drop = FALSE]
      pi <- pi[-(1:batch_size), -ncol(pi), drop = FALSE]

    }

    # pi is now a (T-1)B x (T-1) matrix

    ## POLICY DIFFERENCE

    # Before doing policy differences with lag 1, we add a column of 0
    # such that the first policy difference corresponds to pi_1
    pi <- cbind(0, pi)

    pi_diff <- pi[, -1] - pi[, -ncol(pi)]

    # pi_diff is a (T-1)B x (T-1) matrix

    # Vector of terms for each column
    policy_diff_weights <- 1 / (nb_timesteps - (1:(nb_timesteps - 1)))

    # Multiply each column of pi_diff by corresponding policy_diff_weight
    pi_diff <- sweep(pi_diff, 2, policy_diff_weights, FUN = "*")

    # Cumulative sum
    pi_diff <- t(apply(pi_diff, 1, cumsum))

    # pi_diff is a (T-1) x (T-1) matrix


    ## MULTIPLY by Rk / pi_k-1

    row_indices <- 1:nrow(pi)  # All row indices
    col_indices <- rep(2:ncol(pi), each = batch_size)  # Repeats column indices B times

    pi_diag <- pi[cbind(row_indices, col_indices)]

    # Create multipliers using vectorized operations
    # multipliers <- (1 / pi_diag) * reward[3:length(reward)]
    multipliers <- (1 / pi_diag) * reward[(batch_size+1):length(reward)]


    # # Get diagonal elements from pi
    # pi_diag <- pi[cbind(1:(nrow(pi)), 2:(ncol(pi)))]
    #
    # # Create multipliers using vectorized operations
    # multipliers <- (1 / pi_diag) * reward[2:length(reward)]

    # Apply row-wise multiplication using efficient matrix operation
    # mult_pi_diff <- pi_diff * multipliers  # Works via R's recycling rules (most efficient)
    mult_pi_diff <- sweep(pi_diff, 1, multipliers, FUN = "*")


    ## VARIANCE PER COLUMN


    # # Create the mask for batch indices
    # mask <- matrix(1, nrow = nrow(mult_pi_diff), ncol = ncol(mult_pi_diff))
    #
    # # Assign NaN to upper triangle elements (excluding diagonal)
    # mask[upper.tri(mask, diag = FALSE)] <- NaN

    num_rows <- nrow(mult_pi_diff)
    num_cols <- ncol(mult_pi_diff)

    # Create a matrix with row indices
    row_indices <- matrix(seq_len(num_rows), nrow = num_rows, ncol = num_cols, byrow = FALSE)

    # Compute the NaN threshold for each column
    nan_thresholds <- (seq_len(num_cols) - 1) * batch_size

    # Construct a mask where values should be 1 or NaN
    mask <- row_indices > matrix(nan_thresholds, nrow = num_rows, ncol = num_cols, byrow = TRUE)
    mask[!mask] <- NaN  # Assign NaN where condition is FALSE

    # Apply the mask to policy_diff
    mult_pi_diff <- mult_pi_diff * mask

    # Calculate variance for each column
    # column_variances <- apply(mult_pi_diff, 2, function(x) var(x, na.rm = TRUE))
    # column_variances <- apply(mult_pi_diff, 2, function(x) {
    #   n <- sum(!is.na(x))  # Count of non-NA values
    #   sample_var <- var(x, na.rm = TRUE)  # Compute sample variance (divides by n-1)
    #   sample_var * (n - 1) / n  # Convert to population variance (divides by n)
    # })

    column_variances <- apply(mult_pi_diff, 2, function(x) {
      n <- sum(!is.na(x))  # Count of non-NA values
      if (n == 1){
        return(0)
      } else {
        sample_var <- var(x, na.rm = TRUE)  # Compute sample variance (divides by n-1)
        sample_var <- sample_var * (n - 1) / n  # Convert to population variance (divides by n)
        return(sample_var)
      }
    })

    # print(column_variances)
    total_variance <- sum(column_variances) / batch_size

  }


  return(total_variance)
}

Try the cramR package in your browser

Any scripts or data that you put into this service are public.

cramR documentation built on Aug. 25, 2025, 1:12 a.m.