R/ten_reshape.R

Defines functions ten_reshape

Documented in ten_reshape

#' Tensor reshape
#'
#' @description Performs a tensor reshape on a given tensor: the modes listed
#'   in `AA` are moved behind all other modes and merged into a single last
#'   mode.
#' @param ten An array representing a tensor.
#' @param AA A vector representing mode indices to reshape along. Duplicated
#'   indices are dropped and the order of the indices does not matter.
#' @param time.mode Logical. TRUE if mode-1 of the input tensor is the time
#'   mode and hence not involved in reshape; otherwise reshape is on the
#'   entire input tensor. When TRUE, the indices in `AA` are counted from the
#'   second mode of `ten`, i.e. `AA = 1` refers to mode-2 of `ten`. Default
#'   is TRUE.
#'
#' @return An array representing the reshaped tensor. The modes of `ten` that
#'   are not reshaped keep their relative order and come first; the modes in
#'   `AA` are merged into the last mode, whose size is the product of their
#'   sizes, with the lowest-numbered mode varying fastest. When `AA` holds a
#'   single mode, that mode is simply moved to the end. `NULL` is returned,
#'   after a message, when the inputs are invalid or the tensorMiss package
#'   is not available.
#' @export
#' @import tensorMiss
#'
#' @examples
#' ten_reshape(array(1:24, dim = c(2, 3, 4)), c(2, 3), FALSE)
#'
ten_reshape <- function(ten, AA, time.mode = TRUE) {
  # Reject empty, missing, non-numeric or fractional mode indices up front.
  # Otherwise they slip past the range checks below (min()/max() of an empty
  # vector are Inf/-Inf, and an NA makes `if` fail) and the function either
  # errors obscurely, silently returns NULL, or indexes the wrong modes.
  if (!is.numeric(AA) || length(AA) == 0L || anyNA(AA) ||
      any(AA != round(AA))) {
    message("Parameter `AA` should be a non-empty vector of whole-number mode indices.")
    return(NULL)
  }
  AA <- unique(AA)
  ten_K <- length(dim(ten))

  if (time.mode) {
    if (ten_K < 2) {
      message("Parameter `ten` should be at least a vector time series, i.e. the length of dimension of 'ten' should be at least 2.")
      return(NULL)
    }
    if (min(AA) < 1 || max(AA) > ten_K - 1) {
      message("The elements in `AA` can only start from 1 to the length of dimension of 'ten' deducted by 1.")
      return(NULL)
    }
  } else {
    if (ten_K < 1) {
      message("Parameter `ten` should be at least a vector, i.e. the length of dimension of 'ten' should be at least 1.")
      return(NULL)
    }
    if (min(AA) < 1 || max(AA) > ten_K) {
      message("The elements in `AA` can only start from 1 to the length of dimension of 'ten'.")
      return(NULL)
    }
  }

  # Report a missing dependency explicitly instead of silently returning NULL.
  if (!requireNamespace("tensorMiss", quietly = TRUE)) {
    message("Package 'tensorMiss' is required by `ten_reshape`.")
    return(NULL)
  }

  # In time mode the indices in `AA` skip mode-1, so shift them onto the
  # modes of `ten` itself.
  offset <- if (time.mode) 1 else 0
  AA <- sort(AA + offset)
  n_AA <- length(AA)
  dims <- dim(ten)

  if (n_AA == 1L) {
    # Move the single mode to the end: unfold along it, then refold with that
    # mode placed last.
    return(tensorMiss::refold(tensorMiss::unfold(ten, AA), ten_K,
                              c(dims[-AA], dims[AA])))
  }

  if (n_AA == 2L) {
    # Each row of `mat_2` is the slice of `ten` at one index of mode AA[2].
    # Unfolding each slice along AA[1] (its position is unchanged because
    # AA[1] < AA[2]) and stacking the results row-wise yields a matrix whose
    # rows run over the merged mode, with AA[1] varying fastest.
    mat_2 <- tensorMiss::unfold(ten, AA[2])
    slice_dims <- dims[-AA[2]]
    slices <- lapply(seq_len(nrow(mat_2)), function(i) {
      tensorMiss::unfold(array(mat_2[i, ], dim = slice_dims), AA[1])
    })
    # Bind once instead of growing the matrix inside a loop; a lone slice is
    # passed through untouched so it is not re-shaped by rbind().
    mat_1 <- if (length(slices) == 1L) slices[[1L]] else do.call(rbind, slices)
    return(tensorMiss::refold(mat_1, ten_K - 1,
                              c(dims[-AA], prod(dims[AA]))))
  }

  # Three or more modes: merge the last two first. Their merged mode becomes
  # the last mode of a tensor with ten_K - 1 modes, while the remaining
  # (smaller) indices keep their positions. Then merge the rest into it.
  # `offset` is removed again because each recursive call re-applies it.
  last_two <- c(n_AA - 1L, n_AA)
  partial <- ten_reshape(ten, AA[last_two] - offset, time.mode)
  ten_reshape(partial, c(AA[-last_two], ten_K - 1) - offset, time.mode)
}

Try the KOFM package in your browser

Any scripts or data that you put into this service are public.

KOFM documentation built on April 3, 2025, 11:05 p.m.