# R/allocation_rule.R

#' Allocation Rule Class
#'
#' @description
#' This class represents an allocation rule that generates a next allocation.
#' It wraps an RLlib policy (a Python object accessed via reticulate) and also
#' keeps a zip-compressed binary copy of the policy directory, which is used
#' to recreate the directory when it can no longer be found.
#'
#' @field policy The RLlib policy that is a Python object.
#' @field dir Directory path of the allocation rule (policy).
#' @field dirpath Full path to the directory of the allocation rule.
#' @field created_at Created time of this object.
#' @field info Information when learning the allocation rule.
#' @field input Inputs for learning the allocation rule.
#' @field log The log of scores during the learning of the allocation rule.
#' @field checkpoints The integer vector of iteration counts for checkpoints.
#' @field checkpoints_paths The paths to the directories where each checkpoint is stored.
#'
#' @export
AllocationRule <- R6Class(
  "AllocationRule",

  public = list(
    policy = NULL,
    dir = NULL,
    dirpath = NULL,
    created_at = NULL,

    info = NULL,
    input = NULL,
    log = NULL,
    checkpoints = NULL,
    checkpoints_paths = NULL,

    #' @description
    #' Create a new AllocationRule object.
    #'
    #' @param dir A character value. A directory name or path where an
    #'        allocation rule is outputted. By default, the latest allocation
    #'        rule is searched in 'base_dir'.
    #' @param base_dir A character value. A directory path that is used as the
    #'        parent directory if the 'dir' argument is a directory name and is
    #'        not used otherwise.
    initialize = function(dir = "latest", base_dir = "allocation_rules") {
      # Check arguments
      stopifnot(is.character(dir), length(dir) == 1L, !is.na(dir))
      stopifnot(is.character(base_dir), length(base_dir) == 1L)

      # Identify the specified directory path
      if (dir == "latest") {
        # Search the latest directory in 'base_dir'. `list.files()` is the
        # same function as `dir()`; it is used here so the call is not
        # confused with the local variable 'dir'.
        df <- file.info(list.files(base_dir, full.names = TRUE))
        # `%in% TRUE` keeps only rows whose 'isdir' is TRUE (drops NA rows)
        df <- df[df$isdir %in% TRUE, , drop = FALSE]
        # Guard before max(), which would warn and return -Inf on no rows
        stopifnot("Cannot identify the latest allocation rule" = nrow(df) > 0L)
        df <- df[df$mtime == max(df$mtime), , drop = FALSE]
        # Ties on the modification time are ambiguous, so they are an error
        stopifnot("Cannot identify the latest allocation rule" = nrow(df) == 1L)
        dir <- rownames(df)
      } else if (!grepl("/|\\\\", dir)) {
        # If 'dir' does not contain '/' or '\\', it is a directory name.
        dir <- file.path(base_dir, dir)
      }
      stopifnot("Cannot find the allocation rule directory" = dir.exists(dir))

      # If dir is a checkpoint directory, change it to its policy directory
      if ("policies" %in% list.dirs(dir, full.names = FALSE, recursive = FALSE)) {
        dir <- file.path(dir, "policies", "default_policy")
      }

      # Restore policy object
      policy_lib <- reticulate::import("ray.rllib.policy.policy")
      policy <- policy_lib$Policy$from_checkpoint(dir)

      # Compress the policy directory and keep it in binary format in this
      # object. Archive entries use the paths returned by
      # list.files(dir, full.names = TRUE); opt_allocation_probs() unzips the
      # archive into the working directory to recreate 'dir'.
      compressed_policy_file <- tempfile(fileext = ".zip")
      on.exit(unlink(compressed_policy_file), add = TRUE)
      zip(zipfile = compressed_policy_file, files = list.files(dir, full.names = TRUE))
      stopifnot(
        "Failed to compress the allocation rule directory" =
          file.exists(compressed_policy_file)
      )
      private$policy_binary <- readBin(compressed_policy_file, what = "raw",
                                       n = file.info(compressed_policy_file)$size)

      self$policy <- policy
      self$dir <- dir
      self$dirpath <- normalizePath(dir)
      self$created_at <- format(Sys.time(), format = "%Y-%m-%d %H:%M:%S")
    },

    #' @description
    #' Compute optimal allocation probabilities using the obtained allocation rule for dose and response data.
    #'
    #' @param data_doses A numeric vector. The doses actually administered to each
    #'        subject in your clinical trial. It must include all previous
    #'        doses.
    #' @param data_resps A numeric vector. The values of responses corresponding to
    #'        each subject for the 'data_doses' argument.
    #'
    #' @return A vector of the probabilities of the doses.
    #'
    #' @importFrom glue glue
    opt_allocation_probs = function(data_doses, data_resps) {
      # If the policy has been reset (its 'config' is NULL), reload it from
      # the directory where it is stored
      if (is.null(self$policy$config)) {
        # If the directory where policy is stored cannot be found, restore it
        # from the binary data kept in this object
        if (!dir.exists(self$dir)) {
          compressed_policy_file <- tempfile(fileext = ".zip")
          on.exit(unlink(compressed_policy_file), add = TRUE)
          writeBin(private$policy_binary, compressed_policy_file)
          # Extracting into the working directory recreates 'self$dir' when
          # it is a relative path.
          # NOTE(review): an absolute 'self$dir' is presumably extracted
          # under the working directory instead (zip strips the leading
          # '/'), so the reload below may fail — confirm.
          unzip(compressed_policy_file)
          message(glue("Created allocation rule directory '{self$dir}'."))
        }
        policy_lib <- reticulate::import("ray.rllib.policy.policy")
        self$policy <- policy_lib$Policy$from_checkpoint(self$dir)
      }

      # Extract the clinical trial settings from the allocation rule (policy)
      policy <- self$policy
      env_config <- policy$config$env_config
      K <- policy$action_space$n
      N_total <- env_config$N_total

      # Check arguments
      stopifnot(length(data_doses) == length(data_resps))

      data_doses <- as.numeric(data_doses)
      data_resps <- as.numeric(data_resps)

      # Convert data_doses to data_actions (0-based action indices named by
      # the doses used during learning)
      actions <- seq_len(K) - 1L
      doses <- env_config$doses
      names(actions) <- as.character(doses)
      excluded_doses <- setdiff(unique(data_doses), doses)
      if (length(excluded_doses) > 0L) {
        excluded_doses <- paste(excluded_doses, collapse = ", ")
        stop(glue("{excluded_doses} is not included in doses on the learning."))
      }
      data_actions <- actions[as.character(data_doses)]

      # Obtain the probabilities of next actions: the third element of the
      # full fetch is the info dict holding the action distribution inputs
      state <- compute_state(data_actions, data_resps, N_total)
      info <- policy$compute_single_action(state, full_fetch = TRUE)[[3L]]
      action_probs <- info$action_dist_inputs  # array
      action_probs <- as.vector(action_probs)  # cast to numeric vector
      action_probs <- softmax(action_probs)
      names(action_probs) <- as.character(doses)
      action_probs
    },

    #' @description
    #' Resume learning the allocation rule. This function updates the original
    #' AllocationRule object.
    #'
    #' @param iter A number of additional iterations.
    #'
    #' @return An updated \link{AllocationRule} object.
    resume_learning = function(iter) {
      stopifnot(is.numeric(iter), length(iter) == 1L, !is.na(iter), iter >= 1)
      stopifnot(
        "No checkpoints are available to resume learning" =
          length(self$checkpoints_paths) > 0L
      )

      # Restart the algorithm from the most recent checkpoint
      checkpoint_path <- tail(self$checkpoints_paths, 1L)
      algorithm <- reticulate::import("ray.rllib.algorithms.algorithm")
      algo <- algorithm$Algorithm$from_checkpoint(checkpoint_path)

      output_path <- self$dirpath
      # Drop the trailing "_<iteration>" suffix to get the checkpoint prefix
      output_checkpoint_path <- sub("^(.*)_\\d+$", "\\1", checkpoint_path)
      save_start_iter <- self$input$rl_config$save_start_iter
      save_every_iter <- self$input$rl_config$save_every_iter
      n_start <- self$info$iterations + 1L
      N_update <- self$info$iterations + iter

      result <- train_algo(algo, n_start, N_update,
                           output_path, output_checkpoint_path,
                           save_start_iter, save_every_iter)

      checkpoints <- c(self$checkpoints_paths, result$checkpoints)
      # Append to 'self$log', which set_info() and earlier resumptions keep
      # up to date, so previously recorded scores are preserved
      episode_data <- rbind(self$log, result$episode_data)

      self$info$iterations <- N_update
      self$log <- episode_data
      private$set_checkpoints(checkpoints)
      # NOTE(review): 'self$policy' and the stored policy binary are not
      # refreshed here; confirm whether train_algo() writes the final policy
      # to 'output_path' and, if so, whether it should be reloaded.

      invisible(self)
    },

    #' @description
    #' Set information when learning the allocation rule.
    #'
    #' @param info Information when learning the allocation rule.
    #' @param input Inputs for learning the allocation rule.
    #' @param log The log of scores during the learning of the allocation rule.
    #' @param checkpoints The paths to the directories where each checkpoint is stored.
    #'
    #' @return The object itself, invisibly.
    set_info = function(info, input, log, checkpoints) {
      self$info <- info
      self$input <- input
      self$log <- log
      private$set_checkpoints(checkpoints)
      invisible(self)
    },

    #' @description
    #' Print function for AllocationRule object
    #'
    #' @return The object itself, invisibly.
    #'
    #' @importFrom glue glue
    print = function() {
      print(glue("<AllocationRule>"))
      print(glue("dir: {self$dir}"))
      print(glue("created at: {self$created_at}"))
      if (!is.null(self$info)) {
        print(glue("call:"))
        print(glue("{deparse(self$info$call)}"))
        print(glue("iterations: {self$info$iterations}"))
        checkpoints <- paste0(self$checkpoints, collapse = ", ")
        print(glue("checkpoints: {checkpoints}"))
      }
      invisible(self)
    }
  ),

  private = list(
    policy_binary = NULL,

    # Store checkpoint paths and derive their iteration counts from the
    # trailing "_<digits>" of each path.
    set_checkpoints = function(checkpoints_paths) {
      self$checkpoints_paths <- checkpoints_paths
      self$checkpoints <- as.integer(sub(".*_(\\d+)$", "\\1", checkpoints_paths))
    }
  )
)

# Try the RLoptimal package in your browser

# Any scripts or data that you put into this service are public.

# RLoptimal documentation built on April 4, 2025, 4:43 a.m.