R/client.R

Defines functions overimpute combine imp_mean midas midas_transform midas_fit parse_table get_json post_json to_nested_list extract_model_id

Documented in combine extract_model_id get_json imp_mean midas midas_fit midas_transform overimpute parse_table post_json to_nested_list

# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

#' Extract model ID from a string or fitted model object
#'
#' Accepts either a bare character model ID or a list with a `$model_id`
#' element (as returned by [midas_fit()] or [midas()]). The ID is pasted into
#' request paths such as `/model/<id>/transform`, so it must be a single,
#' non-missing, non-empty string.
#'
#' @param x A character string or a list with a `$model_id` element.
#' @return Character model ID (length 1).
#' @keywords internal
extract_model_id <- function(x) {
  # `[[` matches names exactly; `$` would partially match on lists and could
  # silently pick up an element such as `model_idx`.
  id <- if (is.list(x)) x[["model_id"]] else x
  if (is.character(id) && length(id) == 1L && !is.na(id) && nzchar(id)) {
    return(id)
  }
  rlang::abort(
    paste0(
      "`model_id` must be a character string or a list with a `$model_id` ",
      "element (as returned by midas_fit() or midas())."
    )
  )
}

#' Convert an R matrix / data.frame to a nested list suitable for JSON
#'
#' Produces one unnamed list per row; each cell becomes a length-1 value, or
#' `NULL` when it is `NA`. Data frames are processed column by column so that
#' every column keeps its own type: `as.matrix()` on a mixed-type data frame
#' would `format()` numeric columns into padded, rounded strings. Factor
#' columns are sent as their labels.
#'
#' @param x A matrix or data frame.
#' @return A nested list of rows.
#' @keywords internal
to_nested_list <- function(x) {
  if (is.data.frame(x)) {
    n_rows <- nrow(x)
    cols <- lapply(x, function(col) {
      if (is.factor(col)) as.character(col) else col
    })
  } else {
    x <- as.matrix(x)
    n_rows <- nrow(x)
    cols <- lapply(seq_len(ncol(x)), function(j) x[, j])
  }
  # Rows are sent as JSON arrays, so drop column names.
  cols <- unname(cols)
  lapply(seq_len(n_rows), function(i) {
    lapply(cols, function(col) {
      v <- col[[i]]
      if (is.na(v)) NULL else v
    })
  })
}

#' POST JSON and return parsed body
#'
#' Sends `body` as a JSON request body to `path` and parses the JSON reply
#' with `simplifyVector = TRUE`. On an HTTP error, the server's `detail`
#' field (when the error body is JSON and has one) is added to the error
#' message, matching the parquet upload paths in this file.
#'
#' @param path API path.
#' @param body List to send as JSON (length-1 vectors are unboxed).
#' @param timeout Request timeout in seconds.
#' @return Parsed JSON response as a list.
#' @keywords internal
post_json <- function(path, body, timeout = 600) {
  resp <- base_req(path) |>
    httr2::req_body_json(body, auto_unbox = TRUE) |>
    httr2::req_timeout(timeout) |>
    httr2::req_error(body = function(resp) {
      # Best effort: a non-JSON error body simply adds no extra detail.
      tryCatch({
        detail <- httr2::resp_body_json(resp)$detail
        if (is.null(detail) || is.character(detail)) {
          detail
        } else {
          paste(detail, collapse = "; ")
        }
      }, error = function(e) NULL)
    }) |>
    httr2::req_perform()
  httr2::resp_body_json(resp, simplifyVector = TRUE)
}

#' GET and return parsed body
#'
#' Performs a GET request on `path` and parses the JSON reply with
#' `simplifyVector = TRUE`. On an HTTP error, the server's `detail` field
#' (when the error body is JSON and has one) is added to the error message,
#' matching the parquet upload paths in this file.
#'
#' @param path API path.
#' @param timeout Request timeout in seconds.
#' @return Parsed JSON response as a list.
#' @keywords internal
get_json <- function(path, timeout = 60) {
  resp <- base_req(path) |>
    httr2::req_timeout(timeout) |>
    httr2::req_error(body = function(resp) {
      # Best effort: a non-JSON error body simply adds no extra detail.
      tryCatch({
        detail <- httr2::resp_body_json(resp)$detail
        if (is.null(detail) || is.character(detail)) {
          detail
        } else {
          paste(detail, collapse = "; ")
        }
      }, error = function(e) NULL)
    }) |>
    httr2::req_perform()
  httr2::resp_body_json(resp, simplifyVector = TRUE)
}

#' Parse a JSON table response into a data.frame
#'
#' `res$data` holds the table rows. It is either already a matrix (as
#' produced by `simplifyVector = TRUE`) or a list of row vectors, which are
#' `rbind()`-ed here. When rows mix strings and numbers (e.g. a `term` column
#' next to numeric estimates), jsonlite simplifies the whole table to a
#' character matrix. Character columns whose every non-missing value parses
#' as a number are therefore converted back to numeric.
#'
#' @param res List with `data` and `columns` elements.
#' @return A data frame with one column per entry of `res$columns`; an empty
#'   `res$data` yields a zero-row data frame with those columns.
#' @keywords internal
parse_table <- function(res) {
  rows <- res$data
  mat <- if (is.matrix(rows)) rows else do.call(rbind, rows)
  df <- if (is.null(mat)) {
    # do.call(rbind, list()) is NULL: build an empty frame of the right width.
    as.data.frame(matrix(nrow = 0L, ncol = length(res$columns)))
  } else {
    as.data.frame(mat)
  }
  colnames(df) <- res$columns
  df[] <- lapply(df, function(col) {
    if (!is.character(col)) {
      return(col)
    }
    num <- suppressWarnings(as.numeric(col))
    # Convert only when no value was lost, i.e. the column is truly numeric.
    if (identical(is.na(num), is.na(col))) num else col
  })
  df
}


# ---------------------------------------------------------------------------
# Fit
# ---------------------------------------------------------------------------

#' Fit a MIDAS model
#'
#' Sends data to the server and fits a MIDAS denoising autoencoder.
#'
#' Data with more than 100,000 rows is uploaded as a temporary parquet file
#' (`/fit_parquet`) when the arrow package is installed. Otherwise the data is
#' sent row-wise as JSON (`/fit`).
#'
#' @param data A data frame or matrix (may contain `NA` for missing values).
#' @param hidden_layers Integer vector of hidden layer sizes
#'   (default `c(256, 128, 64)`).
#' @param dropout_prob Numeric. Dropout probability (default 0.5).
#' @param epochs Integer. Number of training epochs (default 75).
#' @param batch_size Integer. Mini-batch size (default 64).
#' @param lr Numeric. Learning rate (default 0.001).
#' @param corrupt_rate Numeric. Corruption rate for denoising (default 0.8).
#' @param num_adj Numeric. Loss multiplier for numeric columns (default 1).
#' @param cat_adj Numeric. Loss multiplier for categorical columns (default 1).
#' @param bin_adj Numeric. Loss multiplier for binary columns (default 1).
#' @param pos_adj Numeric. Loss multiplier for positive columns (default 1).
#' @param omit_first Logical. Omit first column from encoder input
#'   (default `FALSE`).
#' @param seed Integer. Random seed (default 89).
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A list with `model_id`, `n_rows`, `n_cols`, `col_types`.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200), X3 = rnorm(200))
#' df$X2[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' fit$model_id
#' }
#' @export
midas_fit <- function(data, hidden_layers = c(256L, 128L, 64L),
                      dropout_prob = 0.5, epochs = 75L, batch_size = 64L,
                      lr = 0.001, corrupt_rate = 0.8,
                      num_adj = 1, cat_adj = 1, bin_adj = 1, pos_adj = 1,
                      omit_first = FALSE, seed = 89L, ...) {
  # Fail fast, before starting a server: nrow() of a non-tabular object is
  # NULL, which would otherwise surface as an obscure error in the `&&` below.
  if (!is.data.frame(data) && !is.matrix(data)) {
    rlang::abort("`data` must be a data frame or matrix.")
  }
  ensure_server(...)

  training_msg <- paste0("Training MIDAS model (", epochs, " epochs)...")

  # Auto-switch to parquet for large data
  if (nrow(data) > 100000L && rlang::is_installed("arrow")) {
    tmp <- tempfile(fileext = ".parquet")
    on.exit(unlink(tmp), add = TRUE)
    # arrow::write_parquet() expects a data frame, not a bare matrix.
    arrow::write_parquet(
      if (is.matrix(data)) as.data.frame(data) else data,
      tmp
    )
    rlang::inform(training_msg)
    resp <- base_req("/fit_parquet") |>
      httr2::req_body_multipart(
        file = curl::form_file(tmp, type = "application/octet-stream"),
        hidden_layers = paste(hidden_layers, collapse = ","),
        dropout_prob = as.character(dropout_prob),
        epochs = as.character(epochs),
        batch_size = as.character(batch_size),
        lr = as.character(lr),
        corrupt_rate = as.character(corrupt_rate),
        num_adj = as.character(num_adj),
        cat_adj = as.character(cat_adj),
        bin_adj = as.character(bin_adj),
        pos_adj = as.character(pos_adj),
        omit_first = as.character(omit_first),
        # Multipart fields are strings; a NULL seed is sent as "".
        seed = if (!is.null(seed)) as.character(seed) else ""
      ) |>
      httr2::req_timeout(600) |>
      httr2::req_error(body = function(resp) {
        # Surface the server's `detail` field in the R error message.
        tryCatch({
          detail <- httr2::resp_body_json(resp)$detail
          if (is.null(detail) || is.character(detail)) {
            detail
          } else {
            paste(detail, collapse = "; ")
          }
        }, error = function(e) NULL)
      }) |>
      httr2::req_perform()
    rlang::inform("Training complete.")
    return(httr2::resp_body_json(resp, simplifyVector = TRUE))
  }

  body <- list(
    data = to_nested_list(data),
    columns = colnames(data),
    # as.list() keeps a single layer size a JSON array under auto_unbox.
    hidden_layers = as.list(hidden_layers),
    dropout_prob = dropout_prob,
    epochs = epochs,
    batch_size = batch_size,
    lr = lr,
    corrupt_rate = corrupt_rate,
    num_adj = num_adj,
    cat_adj = cat_adj,
    bin_adj = bin_adj,
    pos_adj = pos_adj,
    omit_first = omit_first,
    verbose = FALSE,
    seed = seed
  )
  rlang::inform(training_msg)
  res <- post_json("/fit", body, timeout = 600)
  rlang::inform("Training complete.")
  res
}


# ---------------------------------------------------------------------------
# Transform
# ---------------------------------------------------------------------------

#' Generate multiple imputations
#'
#' Generates `m` imputed datasets from a fitted MIDAS model.
#'
#' The server is first asked to run the transform. Each imputation is then
#' downloaded with its own request (imputations are addressed from index 0)
#' and parsed into a data frame.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param m Integer. Number of imputations (default 5).
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A list of `m` data frames, each with imputed values.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' imps <- midas_transform(fit, m = 10)
#' head(imps[[1]])
#' }
#' @export
midas_transform <- function(model_id, m = 5L, ...) {
  id <- extract_model_id(model_id)
  ensure_server(...)

  model_path <- paste0("/model/", id)

  # Kick off the transform; only its side effect on the server matters here.
  post_json(paste0(model_path, "/transform"), list(m = m), timeout = 600)

  # One GET per imputation; R index k maps to server index k - 1.
  lapply(seq_len(m), function(k) {
    imputation <- get_json(
      paste0(model_path, "/imputations/", k - 1L),
      timeout = 60
    )
    parse_table(imputation)
  })
}


# ---------------------------------------------------------------------------
# All-in-one
# ---------------------------------------------------------------------------

#' Multiple imputation (all-in-one)
#'
#' Convenience function that fits a MIDAS model and generates imputations
#' in a single call. Equivalent to calling [midas_fit()] followed by
#' [midas_transform()].
#'
#' Data with more than 100,000 rows is uploaded as a temporary parquet file
#' (`/complete_parquet`) when the arrow package is installed. Otherwise the
#' data is sent row-wise as JSON (`/complete`).
#'
#' @inheritParams midas_fit
#' @param m Integer. Number of imputations (default 5).
#'
#' @return A list with `model_id` and `imputations` (a list of data frames).
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' result <- midas(df, m = 5, epochs = 10)
#' head(result$imputations[[1]])
#' }
#' @export
midas <- function(data, m = 5L, hidden_layers = c(256L, 128L, 64L),
                  dropout_prob = 0.5, epochs = 75L, batch_size = 64L,
                  lr = 0.001, corrupt_rate = 0.8,
                  num_adj = 1, cat_adj = 1, bin_adj = 1, pos_adj = 1,
                  omit_first = FALSE, seed = 89L, ...) {
  # Fail fast, before starting a server: nrow() of a non-tabular object is
  # NULL, which would otherwise surface as an obscure error in the `&&` below.
  if (!is.data.frame(data) && !is.matrix(data)) {
    rlang::abort("`data` must be a data frame or matrix.")
  }
  ensure_server(...)

  training_msg <- paste0("Training MIDAS model (", epochs, " epochs)...")

  # For large data use parquet path
  if (nrow(data) > 100000L && rlang::is_installed("arrow")) {
    tmp <- tempfile(fileext = ".parquet")
    on.exit(unlink(tmp), add = TRUE)
    # arrow::write_parquet() expects a data frame, not a bare matrix.
    arrow::write_parquet(
      if (is.matrix(data)) as.data.frame(data) else data,
      tmp
    )
    rlang::inform(training_msg)
    resp <- base_req("/complete_parquet") |>
      httr2::req_body_multipart(
        file = curl::form_file(tmp, type = "application/octet-stream"),
        m = as.character(m),
        hidden_layers = paste(hidden_layers, collapse = ","),
        dropout_prob = as.character(dropout_prob),
        epochs = as.character(epochs),
        batch_size = as.character(batch_size),
        lr = as.character(lr),
        corrupt_rate = as.character(corrupt_rate),
        num_adj = as.character(num_adj),
        cat_adj = as.character(cat_adj),
        bin_adj = as.character(bin_adj),
        pos_adj = as.character(pos_adj),
        omit_first = as.character(omit_first),
        # Multipart fields are strings; a NULL seed is sent as "".
        seed = if (!is.null(seed)) as.character(seed) else ""
      ) |>
      httr2::req_timeout(600) |>
      httr2::req_error(body = function(resp) {
        # Surface the server's `detail` field in the R error message.
        tryCatch({
          detail <- httr2::resp_body_json(resp)$detail
          if (is.null(detail) || is.character(detail)) {
            detail
          } else {
            paste(detail, collapse = "; ")
          }
        }, error = function(e) NULL)
      }) |>
      httr2::req_perform()
    rlang::inform("Training complete.")
    res <- httr2::resp_body_json(resp, simplifyVector = TRUE)
  } else {
    body <- list(
      data = to_nested_list(data),
      columns = colnames(data),
      m = m,
      # as.list() keeps a single layer size a JSON array under auto_unbox.
      hidden_layers = as.list(hidden_layers),
      dropout_prob = dropout_prob,
      epochs = epochs,
      batch_size = batch_size,
      lr = lr,
      corrupt_rate = corrupt_rate,
      num_adj = num_adj,
      cat_adj = cat_adj,
      bin_adj = bin_adj,
      pos_adj = pos_adj,
      omit_first = omit_first,
      verbose = FALSE,
      seed = seed
    )
    rlang::inform(training_msg)
    res <- post_json("/complete", body, timeout = 600)
    rlang::inform("Training complete.")
  }

  # Parse imputations from the response. simplifyVector may produce a 3D
  # array indexed [imputation, row, column] or a list of per-imputation
  # matrices / row lists; both are handed to parse_table().
  raw <- res$imputations
  as_table <- function(imp_data) {
    parse_table(list(data = imp_data, columns = res$columns))
  }
  imp_list <- if (is.array(raw) && length(dim(raw)) == 3L) {
    d <- dim(raw)
    lapply(seq_len(d[1L]), function(i) {
      # Rebuild the slice with explicit dims: `raw[i, , ]` drops to a plain
      # vector when the data has a single row or a single column.
      as_table(matrix(raw[i, , ], nrow = d[2L], ncol = d[3L]))
    })
  } else {
    lapply(raw, as_table)
  }

  list(model_id = res$model_id, imputations = imp_list)
}


# ---------------------------------------------------------------------------
# imp_mean
# ---------------------------------------------------------------------------

#' Compute mean imputation
#'
#' Calculates the element-wise mean across all stored imputations for a model.
#' Imputations must already exist on the server (e.g. via
#' [midas_transform()]); this call sends an empty JSON payload.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A data frame with the mean imputed values.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' midas_transform(fit, m = 10)
#' mean_df <- imp_mean(fit)
#' }
#' @export
imp_mean <- function(model_id, ...) {
  id <- extract_model_id(model_id)
  ensure_server(...)
  endpoint <- paste0("/model/", id, "/imp_mean")
  averaged <- post_json(endpoint, list())
  parse_table(averaged)
}


# ---------------------------------------------------------------------------
# Combine (Rubin's rules)
# ---------------------------------------------------------------------------

#' Combine results using Rubin's rules
#'
#' Runs a GLM across all stored imputations and combines the results
#' using Rubin's combination rules for multiple imputation inference.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param y Character. Name of the outcome variable.
#' @param ind_vars Character vector of independent variable names, or `NULL`
#'   for all non-outcome columns. When `NULL`, the field is omitted from the
#'   request.
#' @param dof_adjust Logical. Apply Barnard-Rubin degrees-of-freedom
#'   adjustment (default `TRUE`).
#' @param incl_constant Logical. Include an intercept (default `TRUE`).
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A data frame with columns `term`, `estimate`, `std.error`,
#'   `statistic`, `df`, and `p.value`.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(Y = rnorm(200), X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' midas_transform(fit, m = 10)
#' results <- combine(fit, y = "Y")
#' results
#' }
#' @export
combine <- function(model_id, y, ind_vars = NULL, dof_adjust = TRUE,
                    incl_constant = TRUE, ...) {
  id <- extract_model_id(model_id)
  ensure_server(...)

  payload <- list(
    y = y,
    dof_adjust = dof_adjust,
    incl_constant = incl_constant
  )
  if (!is.null(ind_vars)) {
    # as.list() keeps a single regressor a JSON array; post_json() unboxes
    # length-1 vectors.
    payload$ind_vars <- as.list(ind_vars)
  }

  endpoint <- paste0("/model/", id, "/combine")
  parse_table(post_json(endpoint, payload, timeout = 300))
}


# ---------------------------------------------------------------------------
# Overimpute diagnostic
# ---------------------------------------------------------------------------

#' Overimputation diagnostic
#'
#' Masks a fraction of observed values, re-imputes them, and computes
#' RMSE to assess imputation quality.
#'
#' @param model_id A character model ID, or a fitted model object (list with
#'   a `$model_id` element) as returned by [midas_fit()] or [midas()].
#' @param mask_frac Numeric. Fraction of observed values to mask (default 0.1).
#' @param m Integer. Number of imputations for the diagnostic (default 5).
#' @param seed Integer or `NULL`. Random seed; omitted from the request when
#'   `NULL`.
#' @param ... Arguments forwarded to [ensure_server()].
#'
#' @return A list with `rmse` (named numeric vector) and `mean_rmse`.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(X1 = rnorm(200), X2 = rnorm(200))
#' df$X1[sample(200, 40)] <- NA
#' fit <- midas_fit(df, epochs = 10L)
#' diag <- overimpute(fit, mask_frac = 0.1)
#' diag$mean_rmse
#' }
#' @export
overimpute <- function(model_id, mask_frac = 0.1, m = 5L, seed = NULL, ...) {
  model_id <- extract_model_id(model_id)
  ensure_server(...)
  body <- list(mask_frac = mask_frac, m = m)
  if (!is.null(seed)) body$seed <- seed
  res <- post_json(
    paste0("/model/", model_id, "/overimpute"),
    body,
    timeout = 300
  )

  # Even with simplifyVector = TRUE, jsonlite maps a JSON object to a named
  # list, so a per-column RMSE object is not yet the documented named numeric
  # vector. Flatten it when every entry is a numeric scalar or null (null ->
  # NA); leave any other shape untouched.
  rmse <- if (is.list(res)) res[["rmse"]] else NULL
  is_scalar_num <- function(v) {
    is.null(v) || (is.numeric(v) && length(v) == 1L)
  }
  if (is.list(rmse) && all(vapply(rmse, is_scalar_num, logical(1)))) {
    res[["rmse"]] <- vapply(
      rmse,
      function(v) if (is.null(v)) NA_real_ else as.numeric(v),
      numeric(1)
    )
  }
  res
}

Try the rMIDAS2 package in your browser

Any scripts or data that you put into this service are public.

rMIDAS2 documentation built on March 12, 2026, 9:07 a.m.