R/model.R

Defines functions tail_sampling_logs create_data compile_model sample_model cmdstan_model

#' Create a CmdStan model, compiled locally or on the Cloudstan platform
#'
#' With `local = TRUE` this delegates to [cmdstanr::cmdstan_model()] using the
#' CmdStan installation on the user's machine. Otherwise the Stan program is
#' compiled remotely via Cloudstan. A cmdstanr model object is then created
#' without local compilation. Its `exe_file` slot is set to the remote model
#' ID, and its `sample` method is replaced with one that samples on Cloudstan.
#'
#' @param stan_file Path to the `.stan` source file.
#' @param compile Forwarded to [cmdstanr::cmdstan_model()] when
#'   `local = TRUE`; not consulted on the remote path.
#' @param version Stan compiler version requested from Cloudstan; ignored
#'   when `local = TRUE`.
#' @param local If `TRUE`, compile with the local CmdStan installation.
#' @param ... Passed to [cmdstanr::cmdstan_model()] when `local = TRUE` only.
#' @return A cmdstanr model object.
#' @export
cmdstan_model <- function(stan_file, compile = TRUE, version = "2.27.0", local = FALSE, ...) {
    if (local) {
        message("* Compiling model using local machine executor.")
        # Local path: use the user's own CmdStan; `version` is not consulted.
        return(cmdstanr::cmdstan_model(stan_file = stan_file, compile = compile, ...))
    }

    message("* Compiling model using remote Cloudstan executor.")
    remote_id <- compile_model(stan_file = stan_file, version = version)
    model <- cmdstanr::cmdstan_model(stan_file = stan_file, compile = FALSE)

    # The remote model ID is kept in the model's `exe_file` slot, which is
    # repurposed for this since the object offers no dedicated field for it.
    model$exe_file(remote_id)

    # `sample` is a locked binding on the model object. Unlock it, install
    # the Cloudstan-backed sampler (which injects its own chain handling so
    # related samplings can be tracked together), then lock it again.
    unlockBinding("sample", model)
    model$sample <- sample_model(model)
    lockBinding("sample", model)

    model
}

# Build the Cloudstan-backed replacement for a cmdstanr model's `$sample()`.
#
# `mod` must provide `exe_file()`, which cmdstan_model() repurposes to hold
# the remote Cloudstan model ID. The returned closure keeps cmdstanr's
# `$sample()` argument list so existing call sites keep working. Only `data`,
# `seed`, `chains`, `iter_sampling`, `iter_warmup` and `refresh` are
# forwarded to Cloudstan; the remaining arguments are accepted and ignored.
#
# Each call uploads `data` and starts a remote sampling run. It optionally
# tails the run's logs (`logs`) and optionally blocks until the run finishes
# (`wait`):
#   * wait = TRUE  -> returns load_sampling_results(<sampling id>).
#   * wait = FALSE -> invisibly returns the sampling ID so the caller can
#                     follow up on the run later.
# An "ERROR" response from the CLI is raised as `stop("<code>: <message>")`.
sample_model <- function(mod) {
    func <- function(data,
                     seed = NULL,
                     refresh = NULL,
                     init = NULL,
                     chains = 4,
                     iter_warmup = NULL,
                     iter_sampling = NULL,
                     logs = TRUE,
                     wait = TRUE,

                     # Not supported when sampling remotely: accepted only
                     # for signature compatibility with cmdstanr, ignored.
                     parallel_chains = 1,
                     save_latent_dynamics = FALSE,
                     output_dir = NULL,
                     output_basename = NULL,
                     sig_figs = NULL,
                     chain_ids = seq_len(chains),
                     threads_per_chain = NULL,
                     opencl_ids = NULL,
                     save_warmup = FALSE,
                     thin = NULL,
                     max_treedepth = NULL,
                     adapt_engaged = TRUE,
                     adapt_delta = NULL,
                     step_size = NULL,
                     metric = NULL,
                     metric_file = NULL,
                     inv_metric = NULL,
                     init_buffer = NULL,
                     term_buffer = NULL,
                     window = NULL,
                     fixed_param = FALSE,
                     validate_csv = TRUE,
                     show_messages = TRUE,
                     ...) {

        # Upload the data first; the sampling run refers to it by ID.
        data_id <- create_data(data)

        # Append `key val` only when a value was supplied, so unset options
        # are left off the command line entirely.
        apn <- function(args, key, val = NULL) {
            if (is.null(val)) args else c(args, key, val)
        }

        args <- c(
            "samplings", "create",
            "--data-id", data_id,
            "--model-id", mod$exe_file()
        )
        args <- apn(args, "--seed", seed)
        args <- apn(args, "--chains", chains)
        args <- apn(args, "--iterations", iter_sampling)
        args <- apn(args, "--warmup", iter_warmup)
        args <- apn(args, "--refresh", refresh)

        message("* Model sampling using Cloudstan platform started!")

        rsp <- exec_bin(args = args)
        if (rsp[1] == "ERROR") {
            # paste0 so the message reads "CODE: message" without stray spaces.
            stop(paste0(rsp$code, ": ", rsp$message), call. = FALSE)
        }
        sampling_id <- rsp$id

        if (logs) {
            tail_sampling_logs(sampling_id)
        }

        if (!wait) {
            # Hand back the ID so a non-blocking run is not lost.
            return(invisible(sampling_id))
        }

        message("* Waiting for sampling to complete")
        exec_bin_raw(c("samplings", "wait", sampling_id))
        message("* Sampling complete!")
        load_sampling_results(sampling_id)
    }
    func
}

# Compile a Stan program on Cloudstan and return the remote model ID.
#
# The model is registered under the file's base name, with the extension
# stripped and spaces replaced by underscores. If Cloudstan answers with
# code STAN_MODEL_EXISTS, the ID of the existing model (`rsp$details$id`) is
# returned instead. Any other "ERROR" response is raised as
# `stop("<code>: <message>")`.
#
# @param stan_file Path to a readable `.stan` file.
# @param version Stan compiler version requested from Cloudstan.
# @return The remote model ID.
compile_model <- function(stan_file, version) {
    # Validate before normalizing: normalizePath() warns on missing files,
    # which would otherwise precede the clearer checkmate error.
    checkmate::assert_file_exists(stan_file, access = "r", extension = "stan")
    stan_file <- normalizePath(stan_file, mustWork = TRUE)

    # Replace every space, not just the first one, in the model name.
    model_name <- gsub(" ", "_", strip_ext(basename(stan_file)), fixed = TRUE)

    message("* Model compilation started.")
    args <- c(
        "models", "create",
        "--name", model_name,  # Human readable name.
        "--code", stan_file,   # Absolute path pointing to Stan source file.
        "--compiler", version, # Which Stan compiler version to use.
        "--wait"               # Wait for compilation to complete.
    )
    rsp <- exec_bin(args)

    if (rsp[1] == "ERROR") {
        # Only "model already exists" is recoverable: reuse that model.
        # isTRUE() guards against a missing `code` field.
        if (!isTRUE(rsp$code == "STAN_MODEL_EXISTS")) {
            stop(paste0(rsp$code, ": ", rsp$message), call. = FALSE)
        }
        message("* Compiled model already exists. Skipping compilation.")
        return(rsp$details$id)
    }

    message("* Model successfully compiled.")
    rsp$id
}

# Upload sampling data to Cloudstan and return the remote data ID.
#
# `data` is serialized with write_stan_json() into a temporary file. That
# file is passed to the CLI (`data create --data <file>`) and deleted when
# this function exits, whether the upload succeeded or failed. An "ERROR"
# response is raised as `stop("<code>: <message>")`.
create_data <- function(data) {
    temp_file <- tempfile()
    # The JSON copy is only needed for the upload. NOTE(review): this assumes
    # exec_bin() runs the CLI to completion before returning; it returns the
    # parsed response used below.
    on.exit(unlink(temp_file), add = TRUE)
    write_stan_json(data, file = temp_file)

    message("* Uploading sampling data.")
    args <- c(
        "data", "create",
        "--data", temp_file
    )
    rsp <- exec_bin(args = args)

    if (rsp[1] == "ERROR") {
        # paste0 so the message reads "CODE: message" without stray spaces.
        stop(paste0(rsp$code, ": ", rsp$message), call. = FALSE)
    }

    message("* Data successfully uploaded.")
    rsp$id
}

# Show the logs of a Cloudstan sampling run in the R console.
#
# Runs `samplings logs <id>` through exec_bin_raw() with a callback that
# forwards each line it receives to message(). Returns whatever
# exec_bin_raw() returns.
tail_sampling_logs <- function(id) {
    relay_line <- function(line, proc) message(line)
    exec_bin_raw(c("samplings", "logs", id), relay_line)
}
uroshercog/cloudstan-r documentation built on Dec. 23, 2021, 2:03 p.m.