#' Create a Stan model compiled locally or on Cloudstan
#'
#' With `local = TRUE` this is a thin wrapper around
#' [cmdstanr::cmdstan_model()] that uses the CmdStan installation on the
#' user's machine. Otherwise the Stan program is compiled by the remote
#' Cloudstan service, and the returned model's `$sample()` method is replaced
#' so that sampling also runs remotely.
#'
#' @param stan_file Path to the Stan program.
#' @param compile Passed to [cmdstanr::cmdstan_model()] when `local = TRUE`.
#'   Ignored for remote models, which are always compiled on Cloudstan.
#' @param version Stan compiler version requested from Cloudstan. Ignored when
#'   `local = TRUE`.
#' @param local `TRUE` to compile with the local CmdStan installation, `FALSE`
#'   (the default) to compile on Cloudstan. Must be a single non-missing
#'   logical.
#' @param ... Passed to [cmdstanr::cmdstan_model()] when `local = TRUE`;
#'   ignored for remote models.
#' @return The model object created by [cmdstanr::cmdstan_model()]. For remote
#'   models, its `$exe_file()` holds the Cloudstan model ID instead of a path
#'   to an executable.
#' @export
cmdstan_model <- function(stan_file, compile = TRUE, version = "2.27.0",
                          local = FALSE, ...) {
  # Reject NA / non-logical values up front with a clear message instead of
  # an opaque "argument is not interpretable as logical" from `if`.
  checkmate::assert_flag(local)
  if (local) {
    message("* Compiling model using local machine executor.")
    # Use the user's own CmdStan installation; `version` is not applied here.
    cmdstanr::cmdstan_model(stan_file = stan_file, compile = compile, ...)
  } else {
    message("* Compiling model using remote Cloudstan executor.")
    path <- compile_model(stan_file = stan_file, version = version)
    # Only the Stan source is needed locally, so skip local compilation.
    mod <- cmdstanr::cmdstan_model(stan_file = stan_file, compile = FALSE)
    # exe_file is the only slot on the model we can use to remember the
    # remote model ID; the replacement $sample() reads it back.
    mod$exe_file(path)
    # The `sample` method binding is locked, so unlock it just long enough to
    # swap in the Cloudstan implementation.
    unlockBinding("sample", mod)
    mod$sample <- sample_model(mod)
    lockBinding("sample", mod)
    mod
  }
}
# Build the replacement `$sample()` method for a remotely compiled model.
#
# `mod` is a model whose `$exe_file()` holds the Cloudstan model ID (see
# cmdstan_model()). The returned function takes cmdstanr-style arguments, but
# only `data`, `seed`, `refresh`, `chains`, `iter_warmup`, `iter_sampling`,
# `logs` and `wait` are used; every other argument is accepted and ignored.
#
# When `wait` is TRUE it blocks until sampling finishes and returns
# load_sampling_results() for the sampling. When `wait` is FALSE it returns the
# sampling ID invisibly, so the caller can collect results later.
sample_model <- function(mod) {
  func <- function(data,
                   seed = NULL,
                   refresh = NULL,
                   init = NULL,
                   chains = 4,
                   iter_warmup = NULL,
                   iter_sampling = NULL,
                   logs = TRUE,
                   wait = TRUE,
                   # Not supported when sampling remotely.
                   parallel_chains = 1,
                   save_latent_dynamics = FALSE,
                   output_dir = NULL,
                   output_basename = NULL,
                   sig_figs = NULL,
                   chain_ids = seq_len(chains),
                   threads_per_chain = NULL,
                   opencl_ids = NULL,
                   save_warmup = FALSE,
                   thin = NULL,
                   max_treedepth = NULL,
                   adapt_engaged = TRUE,
                   adapt_delta = NULL,
                   step_size = NULL,
                   metric = NULL,
                   metric_file = NULL,
                   inv_metric = NULL,
                   init_buffer = NULL,
                   term_buffer = NULL,
                   window = NULL,
                   fixed_param = FALSE,
                   validate_csv = TRUE,
                   show_messages = TRUE,
                   ...) {
    # Upload the data first; from here on `data` is the remote data ID.
    data <- create_data(data)

    # Append `key val` to the CLI arguments only when `val` is non-NULL, so
    # unset options are left off the command line entirely.
    apn <- function(args, key, val = NULL) {
      if (is.null(val)) {
        return(args)
      }
      # c() would coerce e.g. 100000 to "1e+05", which is not a valid integer
      # on a command line; force plain decimal notation instead.
      if (is.numeric(val)) {
        val <- format(val, scientific = FALSE, trim = TRUE)
      }
      c(args, key, val)
    }

    args <- c(
      "samplings", "create",
      "--data-id", data,
      "--model-id", mod$exe_file()
    )
    args <- apn(args, "--seed", seed)
    args <- apn(args, "--chains", chains)
    args <- apn(args, "--iterations", iter_sampling)
    args <- apn(args, "--warmup", iter_warmup)
    args <- apn(args, "--refresh", refresh)

    message("* Model sampling using Cloudstan platform started!")
    rsp <- exec_bin(args = args)
    if (identical(rsp[[1]], "ERROR")) {
      stop(paste0(rsp$code, ": ", rsp$message), call. = FALSE)
    }
    if (logs) {
      tail_sampling_logs(rsp$id)
    }
    if (!wait) {
      # Without the ID the caller would have no way to fetch results later.
      return(invisible(rsp$id))
    }
    message("* Waiting for sampling to complete")
    exec_bin_raw(c("samplings", "wait", rsp$id))
    message("* Sampling complete!")
    load_sampling_results(rsp$id)
  }
  func
}
# Compile a Stan file on Cloudstan and return the remote model ID.
#
# `stan_file` must be a readable file with a `.stan` extension; `version` is
# the Stan compiler version passed as `--compiler`. If Cloudstan reports that
# the model already exists (code STAN_MODEL_EXISTS), the existing model's ID is
# returned instead of failing. Any other error response is raised with stop().
compile_model <- function(stan_file, version) {
  # mustWork = FALSE: let the checkmate assertion below report a missing file
  # instead of normalizePath() emitting its own warning first.
  stan_file <- normalizePath(stan_file, mustWork = FALSE)
  checkmate::assert_file_exists(stan_file, access = "r", extension = "stan")
  # Replace every space, not just the first one, in the human-readable name.
  model_name <- gsub(" ", "_", strip_ext(basename(stan_file)), fixed = TRUE)
  message("* Model compilation started.")
  args <- c(
    "models", "create",
    "--name", model_name, # Human readable name.
    "--code", stan_file, # Absolute path pointing to Stan source file.
    "--compiler", version, # Which Stan compiler version to use.
    "--wait" # Wait for compilation to complete.
  )
  rsp <- exec_bin(args)
  if (!identical(rsp[[1]], "ERROR")) {
    message("* Model successfully compiled.")
    return(rsp$id)
  }
  # An identical model was compiled before: reuse it rather than failing.
  if (!identical(rsp$code, "STAN_MODEL_EXISTS")) {
    stop(paste0(rsp$code, ": ", rsp$message), call. = FALSE)
  }
  message("* Compiled model already exists. Skipping compilation.")
  rsp$details$id
}
# Upload sampling data to Cloudstan and return the remote data ID.
#
# `data` is serialised with write_stan_json() to a temporary file, which is
# handed to the `data create` command. An ERROR response is raised with stop().
create_data <- function(data) {
  temp_file <- tempfile()
  # Remove the temporary JSON when we leave, including on error. This assumes
  # exec_bin() has finished reading the file by the time it returns; it returns
  # the parsed response, so that should hold.
  on.exit(unlink(temp_file), add = TRUE)
  write_stan_json(data, file = temp_file)
  message("* Uploading sampling data.")
  args <- c(
    "data", "create",
    "--data", temp_file
  )
  rsp <- exec_bin(args = args)
  if (identical(rsp[[1]], "ERROR")) {
    stop(paste0(rsp$code, ": ", rsp$message), call. = FALSE)
  }
  message("* Data successfully uploaded.")
  rsp$id
}
# Stream the log output of sampling `id`, echoing each line as an R message.
# The callback keeps the original two-argument `(line, proc)` signature; `proc`
# is unused here. Returns whatever exec_bin_raw() returns.
tail_sampling_logs <- function(id) {
  exec_bin_raw(
    c("samplings", "logs", id),
    function(line, proc) message(line)
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.