# R/predict.missoNet.R
#
# Defines functions predict.cv.missoNet predict.missoNet
#
# Documented in predict.cv.missoNet predict.missoNet

#' Predict method for \code{missoNet} models
#'
#' Generate predicted responses for new observations from a fitted
#' \code{\link{missoNet}} (or cross-validated) model. The prediction at a given
#' regularization choice \eqn{(\lambda_B,\lambda_\Theta)} uses the fitted
#' intercept(s) \eqn{\hat{\mu}} and coefficient matrix \eqn{\hat{B}}:
#' \deqn{\hat{Y} = \mathbf{1}_{n_\mathrm{new}} \hat{\mu}^{\mathsf{T}} + X_\mathrm{new}\,\hat{B}.}
#'
#' @section Which solution is used:
#' The \code{s} argument selects the stored solution:
#' \itemize{
#'   \item \code{"lambda.min"} (default): the minimum CV error or selected
#'         GoF solution, stored in \code{object$est.min}.
#'   \item \code{"lambda.1se.beta"}: the 1-SE solution favoring larger
#'         \eqn{\lambda_B}, stored in \code{object$est.1se.beta}.
#'   \item \code{"lambda.1se.theta"}: the 1-SE solution favoring larger
#'         \eqn{\lambda_\Theta}, stored in \code{object$est.1se.theta}.
#' }
#' 1-SE solutions are available only if the model was fit with
#' \code{compute.1se = TRUE} during training or cross-validation.
#'
#' @param object A fitted \code{missoNet} (or cross-validated \code{missoNet})
#'   object that contains the components \code{$est.min} (and optionally
#'   \code{$est.1se.beta}, \code{$est.1se.theta}), each with numeric fields
#'   \code{$mu} (length \eqn{q}; a \eqn{1 \times q} or \eqn{q \times 1}
#'   matrix is also accepted) and \code{$Beta} (\eqn{p \times q}).
#'   Components are looked up by exact name.
#' @param newx Numeric matrix of predictors with \eqn{p} columns (no intercept
#'   column of 1s). A data frame is accepted and converted with
#'   \code{as.matrix()}. Missing or non-finite values are not allowed. Columns
#'   must be on the same scale used to fit \code{object}; see Details for how
#'   column order is handled.
#' @param s Character string selecting the stored solution; one of
#'   \code{c("lambda.min","lambda.1se.beta","lambda.1se.theta")}.
#' @param ... Ignored; included for S3 compatibility.
#'
#' @return A numeric matrix of predicted responses of dimension
#'   \eqn{n_\mathrm{new} \times q}. Row names are taken from \code{newx} (if
#'   any). Column names are taken from the fitted coefficient matrix, or, if
#'   it has none, from the names of \code{$mu} (if any).
#'
#' @details
#' This method does not modify or standardize \code{newx}. If the model was
#' trained with standardization, ensure that \code{newx` has been prepared in
#' the same way as the training data (same centering/scaling).
#'
#' Column alignment: if both \code{colnames(newx)} and \code{rownames(Beta)}
#' are present and differ, \code{newx} is reordered by name when its column
#' names are a permutation of the (unique) training feature names. Otherwise a
#' warning is issued and columns are matched by position.
#'
#' @seealso \code{\link{missoNet}}, \code{\link{cv.missoNet}}, \code{\link{plot.missoNet}}
#'
#' @examples
#' sim <- generateData(n = 200, p = 8, q = 6, rho = 0.1,
#'                     missing.type = "MCAR", seed = 123)
#' tr  <- 1:150
#' tst <- 151:200
#'
#' \donttest{
#' ## Cross-validated fit, keeping 1-SE solutions
#' cvfit <- cv.missoNet(X = sim$X[tr, ], Y = sim$Z[tr, ], kfold = 5,
#'                      compute.1se = TRUE, verbose = 0)
#'
#' ## Predict on held-out set
#' yhat_min  <- predict(cvfit, newx = sim$X[tst, ], s = "lambda.min")
#' yhat_b1se <- predict(cvfit, newx = sim$X[tst, ], s = "lambda.1se.beta")
#' yhat_t1se <- predict(cvfit, newx = sim$X[tst, ], s = "lambda.1se.theta")
#' dim(yhat_min)  # 50 x q
#' }
#'
#' @keywords models regression
#' @method predict missoNet
#' @export

predict.missoNet <- function(object,
                             newx,
                             s = c("lambda.min", "lambda.1se.beta", "lambda.1se.theta"),
                             ...) {
  ## ---- Input checks -------------------------------------------------------
  if (missing(newx) || is.null(newx)) {
    stop("Argument 'newx' must be provided as a numeric matrix (n x p).", call. = FALSE)
  }
  if (is.data.frame(newx)) newx <- as.matrix(newx)
  if (!is.matrix(newx) || !is.numeric(newx)) {
    stop("'newx' must be a numeric matrix.", call. = FALSE)
  }
  if (any(!is.finite(newx))) {
    stop("'newx' contains NA/NaN/Inf; predictions require finite values.", call. = FALSE)
  }

  s <- match.arg(s)

  ## ---- Select the stored solution -----------------------------------------
  ## Use `[[` (exact matching) rather than `$`: `$` partial-matches on lists,
  ## so an absent slot could silently resolve to a similarly named component
  ## instead of triggering the informative error below.
  slot_name <- switch(
    s,
    "lambda.min"       = "est.min",
    "lambda.1se.beta"  = "est.1se.beta",
    "lambda.1se.theta" = "est.1se.theta"
  )
  est <- object[[slot_name]]
  if (is.null(est)) {
    msg  <- sprintf("Requested '%s' but 'object$%s' is NULL.\n", s, slot_name)
    hint <- "Ensure the model was fit with the corresponding solution stored (e.g., compute.1se = TRUE for 1-SE choices)."
    stop(paste0(msg, hint), call. = FALSE)
  }

  Beta <- est[["Beta"]]
  mu   <- est[["mu"]]
  if (is.null(Beta) || is.null(mu)) {
    stop("Selected estimate does not contain both '$Beta' and '$mu'.", call. = FALSE)
  }
  if (!is.matrix(Beta) || !is.numeric(Beta)) {
    stop("'$Beta' must be a numeric matrix.", call. = FALSE)
  }
  if (!is.numeric(mu)) {
    stop("'$mu' must be a numeric vector.", call. = FALSE)
  }

  ## ---- Dimension checks & column alignment --------------------------------
  p_fit <- nrow(Beta)
  q_fit <- ncol(Beta)

  ## Fail early with a clear message instead of a late "non-conformable" error.
  if (length(mu) != q_fit) {
    stop(sprintf("'$mu' has length %d, but '$Beta' has q = %d columns.",
                 length(mu), q_fit),
         call. = FALSE)
  }
  if (ncol(newx) != p_fit) {
    stop(sprintf("Column mismatch: ncol(newx) = %d, but the fitted model expects p = %d.",
                 ncol(newx), p_fit),
         call. = FALSE)
  }

  ## If both sides are named and differ, reorder newx by name when its column
  ## names are a true permutation of the (unique) training feature names;
  ## otherwise fall back to positional matching with a warning.
  beta_rownms <- rownames(Beta)
  newx_colnms <- colnames(newx)
  if (!is.null(beta_rownms) && !is.null(newx_colnms) &&
      !identical(beta_rownms, newx_colnms)) {
    is_permutation <- !anyDuplicated(beta_rownms) &&
      setequal(beta_rownms, newx_colnms)
    if (is_permutation) {
      newx <- newx[, beta_rownms, drop = FALSE]
    } else {
      warning("Column names of 'newx' do not match training features; proceeding by position.", call. = FALSE)
    }
  }

  ## ---- Prediction ----------------------------------------------------------
  n_new <- nrow(newx)
  ## as.numeric() drops any dim on mu (1 x q or q x 1 matrix), so the
  ## intercept term is always the n_new x q outer product 1 mu^T.
  mu_vec <- as.numeric(mu)
  yhat   <- newx %*% Beta + matrix(1, n_new, 1L) %*% matrix(mu_vec, nrow = 1L)

  ## Dimnames: rows from newx; columns from Beta, falling back to names(mu).
  out_colnms <- colnames(Beta)
  if (is.null(out_colnms)) out_colnms <- names(mu)
  dimnames(yhat) <- list(rownames(newx), out_colnms)

  yhat
}

#' @rdname predict.missoNet
#' @method predict cv.missoNet
#' @export

predict.cv.missoNet <- function(object,
                                newx,
                                s = c("lambda.min", "lambda.1se.beta", "lambda.1se.theta"),
                                ...) {
  # Prediction is delegated unchanged to predict.missoNet(), which reads the
  # est.min / est.1se.beta / est.1se.theta components of `object`.
  # `newx` and `s` are forwarded as-is, so all argument validation
  # (including match.arg() on `s`) happens in that method.
  predictions <- predict.missoNet(object, newx = newx, s = s, ...)
  predictions
}

# Try the missoNet package in your browser
#
# Any scripts or data that you put into this service are public.
#
# missoNet documentation built on Sept. 9, 2025, 5:55 p.m.