# R/test_Tucker.R
#
# Defines function: test_Tucker
#
# Documented in: test_Tucker

#' Perform 'cold start' prediction for Tucker models
#'
#' Projects the predictor matrices of a (test) data set through the learned
#' coefficient matrices of a trained Tucker model, then combines the three
#' resulting matrices with the model's core tensor via \code{mult_3d}.
#'
#' For each mode, when the coefficient matrix \code{A} has row names the
#' predictor columns are reordered by name to follow those rows; predictor
#' names that do not appear in \code{A} are ignored. If a mode's predictor
#' matrix has zero columns, that mode's trained latent matrix
#' (\code{m$modeN.H.mean}) is used directly. If \code{A} and \code{H} have a
#' different number of columns, a column of ones is prepended to the
#' projection.
#'
#' Predictor data must already be normalized in the same way as the
#' training data.
#'
#' @export
#' @param d an input data object created with \code{input_data}
#' @param m a \code{Tucker_model} object created with \code{mk_model}
#' @return Response tensor generated by multiplying the input data through the
#'   trained model (the value returned by \code{mult_3d}).
#' @examples
#' data.params <- get_data_params(c('decomp=Tucker'))
#' toy <- mk_toy(data.params)
#' train.data <- input_data$new(mode1.X=toy$mode1.X[,-1],
#'                              mode2.X=toy$mode2.X[,-1],
#'                              mode3.X=toy$mode3.X[,-1],
#'                              resp=toy$resp)
#' model.params <- get_model_params(c('decomp=Tucker'))
#' toy.model <- mk_model(train.data, model.params)
#' toy.model$rand_init(model.params)
#'
#' resp <- test_Tucker(train.data, toy.model)
test_Tucker <- function(d, m) {
  # Project one mode's predictors X through its coefficient matrix A.
  # Falls back to the trained latent matrix H when X has no columns.
  project_mode <- function(X, A, H) {
    # Default: use X as given. This also covers an A without row names,
    # which previously left the aligned matrix undefined and crashed.
    X.aligned <- X
    A.names <- rownames(A)
    if (length(A.names) > 0) {
      # Reorder X's columns to follow A's rows. Names of A missing from X are
      # dropped (nomatch = 0). NOTE(review): that presumably leaves X and A
      # non-conformable for safe_prod; consider failing early instead.
      X.aligned <- X[, match(A.names, colnames(X), nomatch = 0), drop = FALSE]
    }

    if (ncol(X) == 0) {
      H
    } else if (ncol(A) == ncol(H)) {
      safe_prod(X.aligned, A)
    } else {
      # H is wider than A: prepend a column of ones to the projection
      # (assumed to be a constant latent column -- confirm against mk_model).
      cbind(1, safe_prod(X.aligned, A))
    }
  }

  mode1.X.times.A <- project_mode(d$mode1.X, m$mode1.A.mean, m$mode1.H.mean)
  mode2.X.times.A <- project_mode(d$mode2.X, m$mode2.A.mean, m$mode2.H.mean)
  mode3.X.times.A <- project_mode(d$mode3.X, m$mode3.A.mean, m$mode3.H.mean)

  resp <- mult_3d(m$core.mean, mode1.X.times.A, mode2.X.times.A,
                  mode3.X.times.A)

  # NOTE: rescaling by m$multiplier is intentionally disabled.
  # if(m$multiplier != 0) resp <- resp / m$multiplier

  resp
}

# Try the BaTFLED3D package in your browser.
#
# Any scripts or data that you put into this service are public.
#
# BaTFLED3D documentation built on May 2, 2019, 2:38 p.m.