R/reinflateFac.R

Defines functions reinflateFac

Documented in reinflateFac

#' Reinflate all datablocks from a model Fac object.
#'
#' Basically a wrapper function for [reinflateTensor()] and [reinflateMatrix()].
#' Every data block is rebuilt as the sum over components of the reinflated
#' loading vectors of the modes that the block spans. If `Fac` holds more
#' elements than there are modes (the ACMTF case), element `numModes + 1` is
#' read as a matrix of lambdas (rows: data blocks, columns: components) that
#' weighs each component; otherwise every weight is 1.
#'
#' @param Fac Fac object output from CMTF and ACMTF
#' @param Z Z object as generated by [setupCMTFdata()].
#' @param returnAsTensor Boolean to return data blocks as rTensor tensor objects (default FALSE)
#'
#' @return List of data blocks, one per entry of `Z$object`, each with the
#'   dimensions of the corresponding original block.
#' @export
#'
#' @examples
#' set.seed(123)
#' A = array(rnorm(108*2), c(108, 2))
#' B = array(rnorm(100*2), c(100, 2))
#' C = array(rnorm(10*2), c(10, 2))
#' D = array(rnorm(100*2), c(100,2))
#' E = array(rnorm(10*2), c(10,2))
#'
#' df1 = reinflateTensor(A, B, C)
#' df2 = reinflateTensor(A, D, E)
#' datasets = list(df1, df2)
#' modes = list(c(1,2,3), c(1,4,5))
#' Z = setupCMTFdata(datasets, modes, normalize=FALSE)
#'
#' result = cmtf_opt(Z, 1, max_iter=2)
#' Xhats = reinflateFac(result$Fac, Z)
reinflateFac = function(Fac, Z, returnAsTensor=FALSE){
  Fac = lapply(Fac, as.matrix) # Cast to matrix for correct indexation in the one-component case.
  numDatasets = length(Z$object)
  numModes = max(unlist(Z$modes))
  numComponents = ncol(Fac[[1]])

  # Preallocate the output instead of growing the list inside the loop.
  reinflatedFac = vector("list", numDatasets)

  # ACMTF models carry one extra element after the mode loadings: the lambdas.
  ACMTFcase = length(Fac) > numModes

  # seq_len() keeps the loops empty when there is nothing to iterate over
  # (1:0 would wrongly visit the indices 1 and 0).
  for(p in seq_len(numDatasets)){
    modes = Z$modes[[p]]
    numBlockModes = length(modes)

    # Only two-way and three-way blocks can be reinflated; check once per block.
    if(!(numBlockModes %in% c(2, 3))){
      stop("Reinflation of blocks of higher modes than 3 is not yet implemented.")
    }

    # Accumulator with the shape of the original block; a double zero keeps the
    # result type the same regardless of the number of components.
    reinflatedBlock = array(0, dim(Z$object[[p]]))

    for(i in seq_len(numComponents)){
      # Component weight: block- and component-specific lambda for ACMTF, else 1.
      if(ACMTFcase){
        lambda = Fac[[numModes+1]][p,i]
      } else{
        lambda = 1
      }

      if(numBlockModes == 3){
        component = reinflateTensor(Fac[[modes[1]]][,i], Fac[[modes[2]]][,i], Fac[[modes[3]]][,i])
      } else{
        component = reinflateMatrix(Fac[[modes[1]]][,i], Fac[[modes[2]]][,i])
      }
      reinflatedBlock = reinflatedBlock + lambda * component
    }

    if(isTRUE(returnAsTensor)){
      reinflatedFac[[p]] = rTensor::as.tensor(reinflatedBlock)
    } else{
      reinflatedFac[[p]] = reinflatedBlock
    }
  }

  return(reinflatedFac)
}

Try the CMTFtoolbox package in your browser

Any scripts or data that you put into this service are public.

CMTFtoolbox documentation built on Aug. 23, 2025, 1:11 a.m.