R/mse.R

Defines functions mse

Documented in mse

#' Mean Squared Error (MSE) of a factorization
#'
#' @description Mean squared error of the factor model "W * H" given "A".
#'
#' @details Inputs are validated here (classes and conformable dimensions)
#'   before the computation is delegated to \code{mse_detail()} when
#'   \code{detail = TRUE}, or to \code{mse_simple()} otherwise.
#'
#' @param A dgCMatrix of features (rows) by samples (columns)
#' @param W matrix of class "matrix" of features (rows) by factors (columns);
#'   \code{nrow(W)} must equal \code{nrow(A)}
#' @param H matrix of class "matrix" of factors (rows) by samples (columns);
#'   \code{ncol(H)} must equal \code{ncol(A)} and \code{nrow(H)} must equal
#'   \code{ncol(W)}
#' @param detail Whether to calculate residual and imputed error for both the
#'   entire and the zero-masked input (default \code{TRUE})
#' @return
#' If \code{detail = FALSE}, returns only the mean squared error. If
#' \code{detail = TRUE}, returns a list of:
#' \itemize{
#'   \item mse: mean squared error, mean((A - W * H)^2)
#'   \item err.imp: imputed signal in W and H not in A, given by (W * H - A)_{>0}
#'   \item err.res: residual signal in A not in W and H, given by (A - W * H)_{>0}
#'   \item zm.err.mse: mean squared error for all non-zero values in A
#'   \item zm.err.imp: imputed signal for all non-zero values in A, regardless of whether they are non-zero in W * H
#'   \item zm.err.res: residual signal for all non-zero values in A, regardless of whether they are non-zero in W * H
#' }
#' @examples
#' \dontrun{
#' data(moca7k)
#' model <- lsmf(moca7k, 5, rel.tol = 1e-2)
#' err.summary <- mse(moca7k, model$W, model$H)
#' # compare floating-point values with a tolerance, not `==`
#' isTRUE(all.equal(err.summary$mse, model$mse))
#' # [1] TRUE
#' }
mse <- function(A, W, H, detail = TRUE) {
  if (!inherits(A, "dgCMatrix")) stop("A must be a dgCMatrix")
  # Since R 4.0, class(<matrix>) is c("matrix", "array"), so test by
  # inheritance instead of comparing the whole class vector with `!=`
  # (a length-2 condition inside `||` is an error in R >= 4.3).
  if (!inherits(W, "matrix") || !inherits(H, "matrix")) {
    stop("W and H must be of class 'matrix'")
  }
  # A is an S4 sparse matrix; its dimensions live in the @Dim slot.
  if (A@Dim[1] != nrow(W)) stop("nrow(A) must == nrow(W)")
  if (A@Dim[2] != ncol(H)) stop("ncol(A) must == ncol(H)")
  if (ncol(W) != nrow(H)) stop("ncol(W) must == nrow(H)")
  if (detail) {
    mse_detail(A, W, H)
  } else {
    mse_simple(A, W, H)
  }
}
zdebruine/LSMF documentation built on Jan. 1, 2021, 1:50 p.m.