R/test_CP.R

Defines functions test_CP

Documented in test_CP

#' Perform 'cold start' prediction using the BaTFLED algorithm for CP models
#'
#' For each of the three modes, the predictor matrix in \code{d} is restricted
#' to, and reordered by, the row names of the model's \code{modeN.A.mean}
#' matrix (when that matrix has row names). If any predictor columns remain,
#' the mode's factor matrix is \code{safe_prod(X, modeN.A.mean)}; otherwise the
#' model's \code{modeN.H.mean} is used as-is. The three factor matrices are
#' combined through a superdiagonal \code{R x R x R} core with \code{mult_3d},
#' where \code{R = ncol(m$mode1.H.mean)}.
#'
#' The input object \code{d} is not modified.
#'
#' @export
#' @param d an input data object created with \code{input_data}
#' @param m a \code{CP_model} object created with \code{mk_model}
#'
#' @return Response tensor generated by multiplying the input data through the trained model
#' @examples
#' data.params <- get_data_params(c('decomp=CP'))
#' toy <- mk_toy(data.params)
#' train.data <- input_data$new(mode1.X=toy$mode1.X[,-1],
#'                              mode2.X=toy$mode2.X[,-1],
#'                              mode3.X=toy$mode3.X[,-1],
#'                              resp=toy$resp)
#' model.params <- get_model_params(c('decomp=CP'))
#' toy.model <- mk_model(train.data, model.params)
#' toy.model$rand_init(model.params)
#'
#' train(d=train.data, m=toy.model, new.iter=1, params=model.params)
#'
#' resp <- test_CP(train.data, toy.model)

test_CP <- function(d, m) {
  # First argument is the input data for the test set, second is a trained CP
  # model object.

  # Project one mode's predictors through the model. Returns the mode's factor
  # matrix: X %*% A (via safe_prod) when predictors remain, otherwise H.
  project_mode <- function(X, A, H) {
    predictors <- rownames(A)
    if (!is.null(X) && length(predictors) > 0) {
      # Keep only the columns the model was trained on, in the model's row
      # order. Predictors absent from X get index 0 and are dropped.
      # NOTE(review): if only SOME model predictors are present, X ends up
      # narrower than nrow(A) and is passed to safe_prod unchanged (same as
      # the previous implementation). Confirm safe_prod handles that case.
      X <- X[, match(predictors, colnames(X), nomatch = 0), drop = FALSE]
    }
    if (is.null(X) || ncol(X) == 0) {
      return(H)
    }
    safe_prod(X, A)
  }

  # Read the predictors into locals instead of assigning back into `d`.
  # input_data objects are created with `$new` (reference semantics), so
  # writing to d$modeN.X here would silently alter the caller's data.
  mode1.X.times.A <- project_mode(d$mode1.X, m$mode1.A.mean, m$mode1.H.mean)
  mode2.X.times.A <- project_mode(d$mode2.X, m$mode2.A.mean, m$mode2.H.mean)
  mode3.X.times.A <- project_mode(d$mode3.X, m$mode3.A.mean, m$mode3.H.mean)

  # Superdiagonal core: a CP decomposition is a Tucker model whose core has
  # ones only at [r, r, r]. seq_len() keeps this correct even when R == 0.
  R <- ncol(m$mode1.H.mean)
  core <- array(0, dim = c(R, R, R))
  diag_idx <- seq_len(R)
  core[cbind(diag_idx, diag_idx, diag_idx)] <- 1

  mult_3d(core, mode1.X.times.A, mode2.X.times.A, mode3.X.times.A)
}
nathanlazar/BaTFLED3D documentation built on May 23, 2019, 12:19 p.m.