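# Nothing
# NOTE(review): the line above was a stray bare word (apparently page-scrape
# residue, not code); it is commented out so the file can be parsed.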
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
#' Extract model ID from a string or fitted model object
#'
#' Accepts either a bare character model ID or a list with a `model_id`
#' element (as returned by [midas_fit()] or [midas()]). The ID must be a
#' single, non-missing value; anything else raises an error here instead of
#' producing malformed request paths further down.
#'
#' @param x A character string or a list with a `model_id` element.
#' @return The model ID (length one).
#' @keywords internal
extract_model_id <- function(x) {
  if (is.list(x)) {
    # Exact lookup: `x$model_id` would partially match list names and could
    # silently pick up an unrelated element such as `model_id_old`.
    id <- x[["model_id"]]
  } else if (is.character(x)) {
    id <- x
  } else {
    id <- NULL
  }

  # Callers paste the ID into a URL path, so it has to be exactly one
  # non-missing atomic value.
  if (is.atomic(id) && length(id) == 1L && !is.na(id)) {
    return(id)
  }

  rlang::abort(
    paste0(
      "`model_id` must be a character string or a list with a `$model_id` ",
      "element (as returned by midas_fit() or midas())."
    )
  )
}
#' Convert an R matrix / data.frame to a nested list suitable for JSON
#'
#' Produces one unnamed list per row; missing values become `NULL` entries.
#'
#' Data frames whose columns are all numeric or logical (and matrices) are
#' converted through `as.matrix()`. Data frames containing any other column
#' type are converted column by column, because `as.matrix()` would coerce
#' the whole frame to character and `format()`-pad the numeric columns
#' (e.g. `1.5` becoming `" 1.5"`). In that case factors and other
#' non-character, non-numeric, non-logical columns are sent as
#' `as.character()` values.
#'
#' @param x A matrix or data frame.
#' @return A nested list of rows.
#' @keywords internal
to_nested_list <- function(x) {
  is_plain <- function(col) is.numeric(col) || is.logical(col)

  if (is.data.frame(x) && !all(vapply(x, is_plain, logical(1)))) {
    # Mixed-type frame: keep every column in its own type.
    cols <- lapply(unname(as.list(x)), function(col) {
      if (is.character(col) || is_plain(col)) col else as.character(col)
    })
    return(lapply(seq_len(nrow(x)), function(i) {
      lapply(cols, function(col) {
        v <- col[[i]]
        if (is.na(v)) NULL else v
      })
    }))
  }

  mat <- as.matrix(x)
  lapply(seq_len(nrow(mat)), function(i) {
    # Dropping to a plain vector is intentional: each row becomes a flat,
    # unnamed list of scalars.
    cells <- unname(as.list(mat[i, ]))
    lapply(cells, function(v) if (is.na(v)) NULL else v)
  })
}
#' POST JSON and return parsed body
#'
#' Builds the request with `base_req()`, attaches `body` as JSON (length-one
#' vectors are unboxed), performs it and parses the JSON response with
#' `simplifyVector = TRUE`.
#'
#' On an HTTP error the response's `detail` field, when present and
#' parseable, is added to the error message. This mirrors the error handling
#' of the multipart (parquet) requests in this file.
#'
#' @param path API path.
#' @param body List to send as JSON.
#' @param timeout Request timeout in seconds.
#' @return Parsed JSON response as a list.
#' @keywords internal
post_json <- function(path, body, timeout = 600) {
  # Best effort only: a body that is not JSON, or has no `detail`, adds
  # nothing to the error rather than masking the original HTTP failure.
  error_detail <- function(resp) {
    tryCatch(
      {
        detail <- httr2::resp_body_json(resp)$detail
        if (is.null(detail)) {
          NULL
        } else if (is.character(detail)) {
          detail
        } else {
          paste(detail, collapse = "; ")
        }
      },
      error = function(e) NULL
    )
  }

  resp <- base_req(path) |>
    httr2::req_body_json(body, auto_unbox = TRUE) |>
    httr2::req_timeout(timeout) |>
    httr2::req_error(body = error_detail) |>
    httr2::req_perform()

  httr2::resp_body_json(resp, simplifyVector = TRUE)
}
#' GET and return parsed body
#'
#' Builds the request with `base_req()`, performs it and parses the JSON
#' response with `simplifyVector = TRUE`.
#'
#' On an HTTP error the response's `detail` field, when present and
#' parseable, is added to the error message. This mirrors the error handling
#' of the multipart (parquet) requests in this file.
#'
#' @param path API path.
#' @param timeout Request timeout in seconds.
#' @return Parsed JSON response as a list.
#' @keywords internal
get_json <- function(path, timeout = 60) {
  # Best effort only: a body that is not JSON, or has no `detail`, adds
  # nothing to the error rather than masking the original HTTP failure.
  error_detail <- function(resp) {
    tryCatch(
      {
        detail <- httr2::resp_body_json(resp)$detail
        if (is.null(detail)) {
          NULL
        } else if (is.character(detail)) {
          detail
        } else {
          paste(detail, collapse = "; ")
        }
      },
      error = function(e) NULL
    )
  }

  resp <- base_req(path) |>
    httr2::req_timeout(timeout) |>
    httr2::req_error(body = error_detail) |>
    httr2::req_perform()

  httr2::resp_body_json(resp, simplifyVector = TRUE)
}
#' Parse a JSON table response into a data.frame
#'
#' `res$data` may be a matrix or a list of rows; rows are bound into a
#' matrix and converted to a data frame named by `res$columns`.
#'
#' Two edge cases are handled explicitly:
#' * An empty `data` yields a zero-row data frame with the requested column
#'   names instead of failing on the name assignment.
#' * When the table arrives as a character matrix (which happens when rows
#'   mix strings and numbers), each character column is converted back to
#'   numeric/logical, but only if converting the result back to character
#'   reproduces the original strings exactly. Values such as `"001"`
#'   therefore stay character.
#'
#' @param res List with `data` and `columns` elements.
#' @return A data frame.
#' @keywords internal
parse_table <- function(res) {
  col_names <- res$columns

  if (length(res$data) == 0L) {
    empty <- as.data.frame(matrix(nrow = 0L, ncol = length(col_names)))
    if (length(col_names) > 0L) colnames(empty) <- col_names
    return(empty)
  }

  mat <- if (is.matrix(res$data)) res$data else do.call(rbind, res$data)
  # Explicit so that strings never become factors on R < 4.0.
  df <- as.data.frame(mat, stringsAsFactors = FALSE)
  colnames(df) <- col_names

  # Index by position so duplicated column names are handled correctly.
  for (j in seq_along(df)) {
    if (is.character(df[[j]])) {
      converted <- utils::type.convert(df[[j]], as.is = TRUE)
      # Lossless guard: keep the conversion only when it round-trips.
      if (identical(as.character(converted), df[[j]])) {
        df[[j]] <- converted
      }
    }
  }
  df
}
# ---------------------------------------------------------------------------
# Fit
# ---------------------------------------------------------------------------
#' Fit a MIDAS model
#'
#' Sends data to the server and fits a MIDAS denoising autoencoder.
#'
#' Data with more than 100,000 rows is written to a temporary parquet file
#' and uploaded to the `/fit_parquet` endpoint when the arrow package is
#' installed; otherwise the data is sent as JSON to `/fit`. The temporary
#' file is removed when the function exits.
#'
#' @param data A data frame (may contain `NA` for missing values).
#' @param hidden_layers Integer vector of hidden layer sizes
#'   (default `c(256, 128, 64)`).
#' @param dropout_prob Numeric. Dropout probability (default 0.5).
#' @param epochs Integer. Number of training epochs (default 75).
#' @param batch_size Integer. Mini-batch size (default 64).
#' @param lr Numeric. Learning rate (default 0.001).
#' @param corrupt_rate Numeric. Corruption rate for denoising (default 0.8).
#' @param num_adj Numeric. Loss multiplier for numeric columns (default 1).
#' @param cat_adj Numeric. Loss multiplier for categorical columns (default 1).
#' @param bin_adj Numeric. Loss multiplier for binary columns (default 1).
#' @param pos_adj Numeric. Loss multiplier for positive columns (default 1).
#' @param omit_first Logical. Omit first column from encoder input
#'   (default `FALSE`).
#' @param seed Integer. Random seed (default 89).
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A list with `model_id`, `n_rows`, `n_cols`, `col_types`.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200), X3 = rnorm(200))
#' df$X2[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' fit$model_id
#' }
#' @export
midas_fit <- function(data, hidden_layers = c(256L, 128L, 64L),
                      dropout_prob = 0.5, epochs = 75L, batch_size = 64L,
                      lr = 0.001, corrupt_rate = 0.8,
                      num_adj = 1, cat_adj = 1, bin_adj = 1, pos_adj = 1,
                      omit_first = FALSE, seed = 89L, ...) {
  # Fail fast, before `ensure_server()` is invoked: an object without rows
  # would otherwise surface as a cryptic error in the size check below.
  n_rows <- nrow(data)
  if (is.null(n_rows)) {
    rlang::abort("`data` must be a data frame or matrix.")
  }

  ensure_server(...)

  # Auto-switch to parquet for large data
  if (n_rows > 100000L && rlang::is_installed("arrow")) {
    # Whole-number form fields are rendered in fixed notation. as.character()
    # turns an integer-valued double such as `seed = 1e6` into "1e+06".
    # NOTE(review): presumably the server parses these fields as integers and
    # would reject scientific notation; the server is not visible here.
    as_whole <- function(v) {
      format(v, scientific = FALSE, trim = TRUE, digits = 15)
    }

    tmp <- tempfile(fileext = ".parquet")
    on.exit(unlink(tmp), add = TRUE)
    arrow::write_parquet(data, tmp)

    rlang::inform(paste0("Training MIDAS model (", epochs, " epochs)..."))
    resp <- base_req("/fit_parquet") |>
      httr2::req_body_multipart(
        file = curl::form_file(tmp, type = "application/octet-stream"),
        hidden_layers = paste(as_whole(hidden_layers), collapse = ","),
        dropout_prob = as.character(dropout_prob),
        epochs = as_whole(epochs),
        batch_size = as_whole(batch_size),
        lr = as.character(lr),
        corrupt_rate = as.character(corrupt_rate),
        num_adj = as.character(num_adj),
        cat_adj = as.character(cat_adj),
        bin_adj = as.character(bin_adj),
        pos_adj = as.character(pos_adj),
        omit_first = as.character(omit_first),
        # An empty string stands in for "no seed" in the multipart form.
        seed = if (is.null(seed)) "" else as_whole(seed)
      ) |>
      httr2::req_timeout(600) |>
      # Add the response's `detail` field to the error message when the
      # body can be parsed; otherwise add nothing.
      httr2::req_error(body = function(resp) {
        tryCatch({
          detail <- httr2::resp_body_json(resp)$detail
          if (is.character(detail)) detail else paste(detail, collapse = "; ")
        }, error = function(e) NULL)
      }) |>
      httr2::req_perform()
    rlang::inform("Training complete.")
    return(httr2::resp_body_json(resp, simplifyVector = TRUE))
  }

  body <- list(
    data = to_nested_list(data),
    columns = colnames(data),
    # as.list() keeps a single layer size serialised as a JSON array.
    hidden_layers = as.list(hidden_layers),
    dropout_prob = dropout_prob,
    epochs = epochs,
    batch_size = batch_size,
    lr = lr,
    corrupt_rate = corrupt_rate,
    num_adj = num_adj,
    cat_adj = cat_adj,
    bin_adj = bin_adj,
    pos_adj = pos_adj,
    omit_first = omit_first,
    verbose = FALSE,
    seed = seed
  )
  rlang::inform(paste0("Training MIDAS model (", epochs, " epochs)..."))
  res <- post_json("/fit", body, timeout = 600)
  rlang::inform("Training complete.")
  res
}
# ---------------------------------------------------------------------------
# Transform
# ---------------------------------------------------------------------------
#' Generate multiple imputations
#'
#' Asks the server to generate `m` imputations for a fitted MIDAS model and
#' then downloads them one at a time.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param m Integer. Number of imputations (default 5).
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A list of `m` data frames, each with imputed values.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' imps <- midas_transform(fit, m = 10)
#' head(imps[[1]])
#' }
#' @export
midas_transform <- function(model_id, m = 5L, ...) {
  model_id <- extract_model_id(model_id)
  ensure_server(...)

  model_path <- paste0("/model/", model_id)

  # Kick off the transform; the response body is not needed.
  post_json(paste0(model_path, "/transform"), list(m = m), timeout = 600)

  # The imputation endpoint is zero-indexed, hence `k - 1L`.
  lapply(seq_len(m), function(k) {
    table_res <- get_json(
      paste0(model_path, "/imputations/", k - 1L),
      timeout = 60
    )
    parse_table(table_res)
  })
}
# ---------------------------------------------------------------------------
# All-in-one
# ---------------------------------------------------------------------------
#' Multiple imputation (all-in-one)
#'
#' Convenience function that fits a MIDAS model and generates imputations
#' in a single call. Equivalent to calling [midas_fit()] followed by
#' [midas_transform()].
#'
#' Data with more than 100,000 rows is written to a temporary parquet file
#' and uploaded to the `/complete_parquet` endpoint when the arrow package is
#' installed; otherwise the data is sent as JSON to `/complete`.
#'
#' When the server response arrives as character data (which happens when
#' rows mix strings and numbers), each character column of an imputation is
#' converted back to numeric/logical, but only if converting the result back
#' to character reproduces the original strings exactly.
#'
#' @inheritParams midas_fit
#' @param m Integer. Number of imputations (default 5).
#'
#' @return A list with `model_id` and `imputations` (a list of data frames).
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' result <- midas(df, m = 5, epochs = 10)
#' head(result$imputations[[1]])
#' }
#' @export
midas <- function(data, m = 5L, hidden_layers = c(256L, 128L, 64L),
                  dropout_prob = 0.5, epochs = 75L, batch_size = 64L,
                  lr = 0.001, corrupt_rate = 0.8,
                  num_adj = 1, cat_adj = 1, bin_adj = 1, pos_adj = 1,
                  omit_first = FALSE, seed = 89L, ...) {
  # Fail fast, before `ensure_server()` is invoked: an object without rows
  # would otherwise surface as a cryptic error in the size check below.
  n_rows <- nrow(data)
  if (is.null(n_rows)) {
    rlang::abort("`data` must be a data frame or matrix.")
  }

  ensure_server(...)

  # For large data use parquet path
  if (n_rows > 100000L && rlang::is_installed("arrow")) {
    # Whole-number form fields are rendered in fixed notation. as.character()
    # turns an integer-valued double such as `seed = 1e6` into "1e+06".
    # NOTE(review): presumably the server parses these fields as integers and
    # would reject scientific notation; the server is not visible here.
    as_whole <- function(v) {
      format(v, scientific = FALSE, trim = TRUE, digits = 15)
    }

    tmp <- tempfile(fileext = ".parquet")
    on.exit(unlink(tmp), add = TRUE)
    arrow::write_parquet(data, tmp)

    rlang::inform(paste0("Training MIDAS model (", epochs, " epochs)..."))
    resp <- base_req("/complete_parquet") |>
      httr2::req_body_multipart(
        file = curl::form_file(tmp, type = "application/octet-stream"),
        m = as_whole(m),
        hidden_layers = paste(as_whole(hidden_layers), collapse = ","),
        dropout_prob = as.character(dropout_prob),
        epochs = as_whole(epochs),
        batch_size = as_whole(batch_size),
        lr = as.character(lr),
        corrupt_rate = as.character(corrupt_rate),
        num_adj = as.character(num_adj),
        cat_adj = as.character(cat_adj),
        bin_adj = as.character(bin_adj),
        pos_adj = as.character(pos_adj),
        omit_first = as.character(omit_first),
        # An empty string stands in for "no seed" in the multipart form.
        seed = if (is.null(seed)) "" else as_whole(seed)
      ) |>
      httr2::req_timeout(600) |>
      # Add the response's `detail` field to the error message when the
      # body can be parsed; otherwise add nothing.
      httr2::req_error(body = function(resp) {
        tryCatch({
          detail <- httr2::resp_body_json(resp)$detail
          if (is.character(detail)) detail else paste(detail, collapse = "; ")
        }, error = function(e) NULL)
      }) |>
      httr2::req_perform()
    rlang::inform("Training complete.")
    res <- httr2::resp_body_json(resp, simplifyVector = TRUE)
  } else {
    body <- list(
      data = to_nested_list(data),
      columns = colnames(data),
      m = m,
      # as.list() keeps a single layer size serialised as a JSON array.
      hidden_layers = as.list(hidden_layers),
      dropout_prob = dropout_prob,
      epochs = epochs,
      batch_size = batch_size,
      lr = lr,
      corrupt_rate = corrupt_rate,
      num_adj = num_adj,
      cat_adj = cat_adj,
      bin_adj = bin_adj,
      pos_adj = pos_adj,
      omit_first = omit_first,
      verbose = FALSE,
      seed = seed
    )
    rlang::inform(paste0("Training MIDAS model (", epochs, " epochs)..."))
    res <- post_json("/complete", body, timeout = 600)
    rlang::inform("Training complete.")
  }

  # Turn one imputation (a matrix) into a named data frame.
  to_imputation_df <- function(mat) {
    # Explicit so that strings never become factors on R < 4.0.
    df <- as.data.frame(mat, stringsAsFactors = FALSE)
    colnames(df) <- res$columns
    for (j in seq_along(df)) {
      if (is.character(df[[j]])) {
        converted <- utils::type.convert(df[[j]], as.is = TRUE)
        # Lossless guard: keep the conversion only when it round-trips.
        if (identical(as.character(converted), df[[j]])) {
          df[[j]] <- converted
        }
      }
    }
    df
  }

  # Parse imputations from response
  # simplifyVector may produce a 3D array or a list of matrices
  raw <- res$imputations
  if (is.array(raw) && length(dim(raw)) == 3L) {
    slice_dim <- dim(raw)[2:3]
    imp_list <- lapply(seq_len(dim(raw)[1L]), function(i) {
      # drop = FALSE plus an explicit dim keeps a one-row or one-column
      # imputation as a rows x columns matrix instead of a bare vector.
      slice <- raw[i, , , drop = FALSE]
      dim(slice) <- slice_dim
      to_imputation_df(slice)
    })
  } else {
    imp_list <- lapply(raw, function(imp_data) {
      mat <- if (is.matrix(imp_data)) imp_data else do.call(rbind, imp_data)
      to_imputation_df(mat)
    })
  }

  list(model_id = res$model_id, imputations = imp_list)
}
# ---------------------------------------------------------------------------
# imp_mean
# ---------------------------------------------------------------------------
#' Compute mean imputation
#'
#' Requests the mean across a model's stored imputations from the server's
#' `imp_mean` endpoint and returns it as a data frame.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A data frame with the mean imputed values.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' midas_transform(fit, m = 10)
#' mean_df <- imp_mean(fit)
#' }
#' @export
imp_mean <- function(model_id, ...) {
  id <- extract_model_id(model_id)
  ensure_server(...)

  # The endpoint takes no parameters, so an empty body is posted.
  endpoint <- paste0("/model/", id, "/imp_mean")
  table_res <- post_json(endpoint, list())
  parse_table(table_res)
}
# ---------------------------------------------------------------------------
# Combine (Rubin's rules)
# ---------------------------------------------------------------------------
#' Combine results using Rubin's rules
#'
#' Posts the model specification to the server's `combine` endpoint, which
#' is documented as running a GLM across all stored imputations and pooling
#' the results with Rubin's combination rules.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param y Character. Name of the outcome variable.
#' @param ind_vars Character vector of independent variable names, or `NULL`
#'   for all non-outcome columns.
#' @param dof_adjust Logical. Apply Barnard-Rubin degrees-of-freedom
#'   adjustment (default `TRUE`).
#' @param incl_constant Logical. Include an intercept (default `TRUE`).
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A data frame with columns `term`, `estimate`, `std.error`,
#'   `statistic`, `df`, and `p.value`.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(Y = rnorm(200), X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' midas_transform(fit, m = 10)
#' results <- combine(fit, y = "Y")
#' results
#' }
#' @export
combine <- function(model_id, y, ind_vars = NULL, dof_adjust = TRUE,
                    incl_constant = TRUE, ...) {
  model_id <- extract_model_id(model_id)
  ensure_server(...)

  # `ind_vars` is only included when supplied; as.list() keeps a single
  # variable name serialised as a JSON array.
  predictors <- if (is.null(ind_vars)) {
    NULL
  } else {
    list(ind_vars = as.list(ind_vars))
  }
  payload <- c(
    list(y = y, dof_adjust = dof_adjust, incl_constant = incl_constant),
    predictors
  )

  endpoint <- paste0("/model/", model_id, "/combine")
  parse_table(post_json(endpoint, payload, timeout = 300))
}
# ---------------------------------------------------------------------------
# Overimpute diagnostic
# ---------------------------------------------------------------------------
#' Overimputation diagnostic
#'
#' Requests the server's `overimpute` diagnostic for a fitted model. The
#' diagnostic is documented as masking a fraction of observed values,
#' re-imputing them, and computing RMSE to assess imputation quality.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param mask_frac Numeric. Fraction of observed values to mask (default 0.1).
#' @param m Integer. Number of imputations for the diagnostic (default 5).
#' @param seed Integer or `NULL`. Random seed.
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A list with `rmse` (named numeric vector) and `mean_rmse`.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' diag <- overimpute(fit, mask_frac = 0.1)
#' diag$mean_rmse
#' }
#' @export
overimpute <- function(model_id, mask_frac = 0.1, m = 5L, seed = NULL, ...) {
  model_id <- extract_model_id(model_id)
  ensure_server(...)

  # The seed is sent only when the caller provides one.
  seed_field <- if (is.null(seed)) NULL else list(seed = seed)
  payload <- c(list(mask_frac = mask_frac, m = m), seed_field)

  endpoint <- paste0("/model/", model_id, "/overimpute")
  post_json(endpoint, payload, timeout = 300)
}
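# NOTE(review): the three lines below are hosting-page boilerplate rather
# than R code; they are commented out so the file can be parsed.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.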