Nothing
#' Allocation Rule Class
#'
#' @description
#' This class represents an allocation rule that generates a next allocation.
#' An object wraps an RLlib policy restored from a policy directory and also
#' keeps a zipped copy of that directory in memory, so that the directory can
#' be recreated when it is no longer found on disk.
#'
#' @field policy The RLlib policy that is a Python object.
#' @field dir Directory path of the allocation rule (policy).
#' @field dirpath Full path to the directory of the allocation rule.
#' @field created_at Created time of this object.
#' @field info Information when learning the allocation rule.
#' @field input Inputs for learning the allocation rule.
#' @field log The log of scores during the learning of the allocation rule.
#' @field checkpoints The integer vector of iteration counts for checkpoints.
#' @field checkpoints_paths The paths to the directories where each checkpoint is stored.
#'
#' @export
AllocationRule <- R6Class(
  "AllocationRule",
  public = list(
    policy = NULL,
    dir = NULL,
    dirpath = NULL,
    created_at = NULL,
    info = NULL,
    input = NULL,
    log = NULL,
    checkpoints = NULL,
    checkpoints_paths = NULL,
    #' @description
    #' Create a new AllocationRule object.
    #'
    #' @param dir A character value. A directory name or path where an
    #'        allocation rule is outputted. By default, the latest allocation
    #'        rule is searched in 'base_dir'.
    #' @param base_dir A character value. A directory path that is used as the
    #'        parent directory if the 'dir' argument is a directory name and is
    #'        not used otherwise.
    initialize = function(dir = "latest", base_dir = "allocation_rules") {
      # Check arguments
      stopifnot(length(dir) == 1L)
      stopifnot(length(base_dir) == 1L)
      # Identify the specified directory path
      if (dir == "latest") {
        # Search the most recently modified directory in 'base_dir'
        df <- file.info(dir(base_dir, full.names = TRUE))
        # which() drops entries whose 'isdir' is NA (e.g. unreadable entries)
        df <- df[which(df$isdir), , drop = FALSE]
        # Skip max() on an empty set so it does not warn; the stopifnot()
        # below reports the failure with a clear message instead.
        if (nrow(df) > 0L) {
          df <- df[df$mtime == max(df$mtime), , drop = FALSE]
        }
        stopifnot("Cannot identify the latest allocation rule" = nrow(df) == 1L)
        dir <- rownames(df)
      } else if (!grepl("/|\\\\", dir)) {
        # If 'dir' does not contain '/' or '\\', it is a directory name.
        dir <- file.path(base_dir, dir)
      }
      # If dir is a checkpoint directory, change it to its policy directory
      if ("policies" %in% list.dirs(dir, full.names = FALSE, recursive = FALSE)) {
        dir <- file.path(dir, "policies", "default_policy")
      }
      # Restore policy object
      policy_lib <- reticulate::import("ray.rllib.policy.policy")
      policy <- policy_lib$Policy$from_checkpoint(dir)
      # Compress the policy directory and keep it in binary format in this
      # object; the temporary zip file itself is removed on exit.
      compressed_policy_file <- tempfile(fileext = ".zip")
      on.exit(unlink(compressed_policy_file), add = TRUE)
      zip(zipfile = compressed_policy_file, files = list.files(dir, full.names = TRUE))
      private$policy_binary <- readBin(compressed_policy_file, what = "raw",
                                       n = file.info(compressed_policy_file)$size)
      self$policy <- policy
      self$dir <- dir
      self$dirpath <- normalizePath(dir)
      self$created_at <- format(Sys.time(), format = "%Y-%m-%d %H:%M:%S")
    },
    #' @description
    #' Compute optimal allocation probabilities using the obtained allocation rule for dose and response data.
    #'
    #' @param data_doses A numeric vector. The doses actually administered to each
    #'        subject in your clinical trial. It must include all previous
    #'        doses.
    #' @param data_resps A numeric vector. The values of responses corresponding to
    #'        each subject for the 'data_doses' argument.
    #'
    #' @return A numeric vector of the probabilities of the doses, named by
    #'         the doses used during learning.
    #'
    #' @importFrom glue glue
    opt_allocation_probs = function(data_doses, data_resps) {
      # If policy has been reset, reload it from the directory where it is stored
      if (is.null(self$policy$config)) {
        # If the directory where policy is stored cannot be found, restore it
        # from the binary data. unzip() extracts into the current working
        # directory using the paths recorded by zip() in initialize().
        # NOTE(review): if 'self$dir' was an absolute path, extraction may not
        # recreate it at the same location -- confirm.
        if (!dir.exists(self$dir)) {
          compressed_policy_file <- tempfile(fileext = ".zip")
          on.exit(unlink(compressed_policy_file), add = TRUE)
          writeBin(private$policy_binary, compressed_policy_file)
          unzip(compressed_policy_file)
          message(glue("Created allocation rule directory '{self$dir}'."))
        }
        policy_lib <- reticulate::import("ray.rllib.policy.policy")
        self$policy <- policy_lib$Policy$from_checkpoint(self$dir)
      }
      # Extract the clinical trial settings from the allocation rule (policy)
      policy <- self$policy
      env_config <- policy$config$env_config
      K <- policy$action_space$n
      N_total <- env_config$N_total
      # Check arguments
      stopifnot(length(data_doses) == length(data_resps))
      data_doses <- as.numeric(data_doses)
      data_resps <- as.numeric(data_resps)
      # Convert data_doses to data_actions (0-based action indices keyed by dose)
      actions <- seq_len(K) - 1L
      doses <- env_config$doses
      names(actions) <- as.character(doses)
      excluded_doses <- setdiff(unique(data_doses), doses)
      if (length(excluded_doses) > 0L) {
        excluded_doses <- paste(excluded_doses, collapse = ", ")
        stop(glue("{excluded_doses} is not included in doses on the learning."))
      }
      data_actions <- actions[as.character(data_doses)]
      # Obtain the probabilities of next actions: the third element of the
      # full fetch is the extra-info dict holding the action distribution inputs
      state <- compute_state(data_actions, data_resps, N_total)
      info <- policy$compute_single_action(state, full_fetch = TRUE)[[3L]]
      action_probs <- info$action_dist_inputs # array
      action_probs <- as.vector(action_probs) # cast to numeric vector
      action_probs <- softmax(action_probs)
      names(action_probs) <- as.character(doses)
      action_probs
    },
    #' @description
    #' Resume learning the allocation rule. This function updates the original
    #' AllocationRule object.
    #'
    #' @param iter A number of additional iterations.
    #'
    #' @return An updated \link{AllocationRule} object.
    resume_learning = function(iter) {
      # Restart the algorithm from the most recent checkpoint
      checkpoint_path <- tail(self$checkpoints_paths, 1L)
      algorithm <- reticulate::import("ray.rllib.algorithms.algorithm")
      algo <- algorithm$Algorithm$from_checkpoint(checkpoint_path)
      output_path <- self$dirpath
      # Strip the trailing "_<iteration>" suffix to get the checkpoint prefix
      output_checkpoint_path <- sub("^(.*)_\\d+$", "\\1", checkpoint_path)
      save_start_iter <- self$input$rl_config$save_start_iter
      save_every_iter <- self$input$rl_config$save_every_iter
      N_update <- self$info$iterations + iter
      n_start <- self$info$iterations + 1L
      result <- train_algo(algo, n_start, N_update,
                           output_path, output_checkpoint_path,
                           save_start_iter, save_every_iter)
      checkpoints <- c(self$checkpoints_paths, result$checkpoints)
      # Append to the accumulated log in 'self$log' (where set_info() and
      # previous resumes store it), so earlier episodes are not dropped.
      episode_data <- rbind(self$log, result$episode_data)
      self$info$iterations <- N_update
      self$log <- episode_data
      private$set_checkpoints(checkpoints)
      invisible(self)
    },
    #' @description
    #' Set information when learning the allocation rule.
    #'
    #' @param info Information when learning the allocation rule.
    #' @param input Inputs for learning the allocation rule.
    #' @param log The log of scores during the learning of the allocation rule.
    #' @param checkpoints The paths to the directories where each checkpoint is stored.
    set_info = function(info, input, log, checkpoints) {
      self$info <- info
      self$input <- input
      self$log <- log
      private$set_checkpoints(checkpoints)
    },
    #' @description
    #' Print function for AllocationRule object
    #'
    #' @importFrom glue glue
    print = function() {
      print(glue("<AllocationRule>"))
      print(glue("dir: {self$dir}"))
      print(glue("created at: {self$created_at}"))
      if (!is.null(self$info)) {
        print(glue("call:"))
        print(glue("{deparse(self$info$call)}"))
        print(glue("iterations: {self$info$iterations}"))
        checkpoints <- paste0(self$checkpoints, collapse = ", ")
        print(glue("checkpoints: {checkpoints}"))
      }
      # R6 print methods conventionally return the object invisibly
      invisible(self)
    }
  ),
  private = list(
    # Raw bytes of the zipped policy directory, written in initialize()
    policy_binary = NULL,
    # Store checkpoint paths and parse the trailing "_<iteration>" number of
    # each path into the integer 'checkpoints' field.
    set_checkpoints = function(checkpoints_paths) {
      self$checkpoints_paths <- checkpoints_paths
      self$checkpoints <- as.integer(sub(".*_(\\d+)$", "\\1", checkpoints_paths))
    }
  )
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.