Nothing
#' Perform 'cold start' prediction for Tucker models
#'
#' Projects each mode's test-set predictors through the trained model's
#' coefficient matrices (\code{mode*.A.mean}) and multiplies the resulting
#' per-mode matrices through the core tensor (\code{core.mean}) with
#' \code{mult_3d}.
#'
#' For each of the three modes:
#' \itemize{
#'   \item If the test data has no predictors for the mode (a zero-column or
#'     \code{NULL} matrix), the trained \code{mode*.H.mean} matrix is used
#'     directly.
#'   \item Otherwise, if \code{mode*.A.mean} has row names, the predictor
#'     columns are reordered to match them, and columns the model does not
#'     use are dropped. If it has no row names, the predictors are used in
#'     their given order.
#'   \item If \code{mode*.A.mean} and \code{mode*.H.mean} have different
#'     numbers of columns, a leading column of ones is prepended to the
#'     projection (presumably a constant latent factor; verify against the
#'     training code).
#' }
#' Predictor data should already be normalized in the same way as the
#' training data.
#'
#' @export
#' @param d an input data object created with \code{input_data}, holding the
#'   test-set predictors in \code{mode1.X}, \code{mode2.X} and \code{mode3.X}
#' @param m a trained \code{Tucker_model} object created with \code{mk_model}
#' @return Response tensor generated by multiplying the input data through the trained model
#' @examples
#' data.params <- get_data_params(c('decomp=Tucker'))
#' toy <- mk_toy(data.params)
#' train.data <- input_data$new(mode1.X=toy$mode1.X[,-1],
#' mode2.X=toy$mode2.X[,-1],
#' mode3.X=toy$mode3.X[,-1],
#' resp=toy$resp)
#' model.params <- get_model_params(c('decomp=Tucker'))
#' toy.model <- mk_model(train.data, model.params)
#' toy.model$rand_init(model.params)
#'
#' resp <- test_Tucker(train.data, toy.model)
test_Tucker <- function(d, m) {
  # Project one mode's predictors X through coefficient matrix A, falling
  # back to the trained matrix H when X carries no predictors.
  project_mode <- function(X, A, H) {
    if (is.null(X) || ncol(X) == 0) {
      return(H)
    }
    # Align predictor columns with the model's coefficient rows by name when
    # the model provides names. Without names, X is used as given (it must
    # never be left undefined, or lookup falls through to enclosing scopes).
    model_feats <- rownames(A)
    if (length(model_feats) > 0) {
      X <- X[, match(model_feats, colnames(X), nomatch = 0), drop = FALSE]
    }
    XA <- safe_prod(X, A)
    if (ncol(A) == ncol(H)) {
      XA
    } else {
      # A and H differ in width: prepend a column of ones.
      cbind(1, XA)
    }
  }

  mode1.X.times.A <- project_mode(
    d$mode1.X, m$mode1.A.mean, m$mode1.H.mean
  )
  mode2.X.times.A <- project_mode(
    d$mode2.X, m$mode2.A.mean, m$mode2.H.mean
  )
  mode3.X.times.A <- project_mode(
    d$mode3.X, m$mode3.A.mean, m$mode3.H.mean
  )

  resp <- mult_3d(
    m$core.mean, mode1.X.times.A, mode2.X.times.A, mode3.X.times.A
  )
  # Rescaling by m$multiplier is currently disabled:
  # if(m$multiplier != 0) resp <- resp / m$multiplier
  resp
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.