#############################################################################################################
# Authors:
# Amrit Singh, University of British Columbia, Vancouver.
# Florian Rohart, The University of Queensland, The University of Queensland Diamantina Institute, Translational Research Institute, Brisbane, QLD
# Kim-Anh Le Cao, The University of Queensland, The University of Queensland Diamantina Institute, Translational Research Institute, Brisbane, QLD
#
# created: 01-04-2015
# last modified: 27-05-2016
#
# Copyright (C) 2015
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#############################################################################################################
# ----------------------------------------------------------------------------------------------------------
# perf.assess.plsda - Function to evaluate the performance of fitted PLS-DA (cross-validation)
# ----------------------------------------------------------------------------------------------------------
## ------------------------------- (s)PLSDA ------------------------------- ##
#' @rdname perf.assess
#' @importFrom methods hasArg
#' @export
perf.assess.mixo_plsda <- function(object,
dist = c("all", "max.dist", "centroids.dist", "mahalanobis.dist"),
validation = c("Mfold", "loo"),
folds = 10,
nrepeat = 1,
auc = FALSE,
progressBar = FALSE,
signif.threshold = 0.01,
BPPARAM = SerialParam(),
seed = NULL,
...)
{
#-- initialising arguments --#
set.seed(seed)
BPPARAM$RNGseed <- seed
# these data are the centered and scaled X output or the unmapped(Y) scaled and centered
X = object$input.X
level.Y = object$names$colnames$Y #to make sure the levels are ordered
Y = object$Y
ncomp = object$ncomp
n = nrow(X)
logratio = object$logratio
if (is.null(logratio))
logratio = "none"
multilevel = object$multilevel # repeated measurement and Y
near.zero.var = !is.null(object$nzv) # if near.zero.var was used, we set it to TRUE. if not used, object$nzv is NULL
#-- tells which variables are selected in X and in Y --#
if (is(object, "mixo_splsda"))
{
keepX = object$keepX
} else {
keepX = rep(ncol(X), ncomp)
}
tol = object$tol
max.iter = object$max.iter
scale = object$scale
# check input arguments
if (hasArg(method.predict))
stop("'method.predict' argument has been replaced by 'dist' to match the 'tune' function")
method.predict = NULL # to pass R CMD check
dist = match.arg(dist, choices = c("all", "max.dist", "centroids.dist", "mahalanobis.dist"), several.ok = TRUE)
if (any(dist == "all"))
{
nmthdd = 3
dist = c("max.dist", "centroids.dist", "mahalanobis.dist")
} else {
nmthdd = length(dist)
}
if (length(validation) > 1 )
validation = validation [1]
if (!(validation %in% c("Mfold", "loo")))
stop("Choose 'validation' among the two following possibilities: 'Mfold' or 'loo'")
if (validation == "loo")
{
if (nrepeat != 1)
warning("Leave-One-Out validation does not need to be repeated: 'nrepeat' is set to '1'.")
nrepeat = 1
}
if (any(table(object$Y) <= 1)) {
stop(paste("Cannot evaluate performance when a class level ('",
names(table(object$Y))[which(table(object$Y) == 1)],
"') has only a single associated sample.", sep = ""))
}
if (!is.logical(progressBar))
stop("'progressBar' must be either TRUE or FALSE")
measure = c("overall","BER") # one of c("overall","BER")
if (!(logratio %in% c("none", "CLR")))
stop("Choose one of the two following logratio transformation: 'none' or 'CLR'")
#fold is checked in 'MCVfold'
#-- check significance threshold
signif.threshold <- .check_alpha(signif.threshold)
#---------------------------------------------------------------------------#
#-- logration + multilevel approach ----------------------------------------#
# we can do logratio and multilevel on the whole data as these transformation are done per sample
X = logratio.transfo(X = X, logratio = logratio)
if (!is.null(multilevel))
{
Xw = withinVariation(X, design = multilevel)
X = Xw
}
# -------------------------------------
# first check for near zero var on the whole data set
if (near.zero.var == TRUE)
{
nzv = nearZeroVar(X)
if (length(nzv$Position > 0))
{
warning("Zero- or near-zero variance predictors.\nReset predictors matrix to not near-zero variance predictors.\nSee $nzv for problematic predictors.")
X = X[, -nzv$Position, drop=TRUE]
if (ncol(X)==0)
stop("No more predictors after Near Zero Var has been applied!")
if (any(keepX > ncol(X)))
keepX = ncol(X)
}
}
# and then we start from the X data set with the nzv removed
#---------------------------------------------------------------------------#
#-- NA calculation ----------------------------------------------------#
misdata = c(X=anyNA(X), Y=FALSE) # Detection of missing data. we assume no missing values in the factor Y
if (any(misdata))
{
is.na.A = is.na(X)
#ind.NA = which(apply(is.na.A, 1, sum) > 0) # calculated only once
#ind.NA.col = which(apply(is.na.A, 2, sum) > 0) # calculated only once
} else {
is.na.A = NULL
#ind.NA = ind.NA.col = NULL
}
#-- Set up empty lists for outputs to go ------------------------------#
#---------------------------------------------------------------------------#
list.features = list()
mat.error.rate = mat.sd.error = mat.mean.error = error.per.class.keepX.opt = error.per.class.keepX.opt.mean = list()
error.per.class = list()
final=list()
for (measure_i in measure)
{
mat.sd.error[[measure_i]] = matrix(0, nrow = 1, ncol = length(dist),
dimnames = list(c(paste0('comp', ncomp)), dist))
mat.mean.error[[measure_i]] = matrix(0,nrow = 1, ncol = length(dist),
dimnames = list(c(paste0('comp', ncomp)), dist))
error.per.class.keepX.opt[[measure_i]] = list()
error.per.class.keepX.opt.mean[[measure_i]] = list()
mat.error.rate[[measure_i]]=list()
for(ijk in dist)
{
mat.error.rate[[measure_i]][[ijk]] = matrix(0, nrow = 1, ncol = nrepeat,
dimnames = list(c(paste0('comp', ncomp)), c(paste0('nrep', 1 : nrepeat))))
error.per.class.keepX.opt[[measure_i]][[ijk]] <- array(0, c(nlevels(Y), nrepeat, 1),
dimnames = list(levels(Y), paste0("nrep", 1:nrepeat), paste0('comp', ncomp)))
error.per.class.keepX.opt.mean[[measure_i]][[ijk]] = matrix(nrow = nlevels(Y), ncol = 1,
dimnames = list(c(levels(Y)), c(paste0('comp', ncomp))))
}
}
if(auc == TRUE)
{
auc.mean=list()
auc.all=list()
}
prediction.all = class.all = auc.mean = auc.all = list()
for(ijk in dist)
{
class.all[[ijk]] = array(0, c(nrow(X), nrepeat, 1),
dimnames = list(rownames(X),c(paste0('nrep', 1 : nrepeat)),c(paste0('comp', ncomp))))
}
class.object=class(object)
#-- Calculate error values - just on detected ncomp -------------------#
#---------------------------------------------------------------------------#
if (progressBar == TRUE)
cat("\ncomp", ncomp, "\n")
if(ncomp > 1)
{
choice.keepX = keepX[1 : (ncomp - 1)]
} else {
choice.keepX = NULL
}
test.keepX = keepX[ncomp]
names(test.keepX) = test.keepX
#test.keepX is a value
# estimate performance of the model
result = MCVfold.spls (X, Y, multilevel = multilevel, validation = validation, folds = folds, nrepeat = nrepeat, ncomp = ncomp,
choice.keepX = choice.keepX, test.keepX = test.keepX, test.keepY = nlevels(Y),
measure = measure, dist = dist, scale=scale,
near.zero.var = near.zero.var,
auc = auc, progressBar = progressBar, class.object = class.object,
BPPARAM = BPPARAM,
misdata = misdata, is.na.A = is.na.A)#, ind.NA = ind.NA, ind.NA.col = ind.NA.col)
# ---- extract stability of features ----- # NEW
if (is(object, "mixo_splsda"))
list.features[[1]] = result$features$stable
for (ijk in dist)
{
for (measure_i in measure)
{
mat.error.rate[[measure_i]][[ijk]] = result[[measure_i]]$mat.error.rate[[ijk]][1,]
mat.mean.error[[measure_i]][, ijk]=result[[measure_i]]$error.rate.mean[[ijk]]
if (!is.null(result[[measure_i]]$error.rate.sd))
{
mat.sd.error[[measure_i]][, ijk]=result[[measure_i]]$error.rate.sd[[ijk]]
} else {
mat.sd.error= NULL
}
# confusion matrix for keepX.opt, for each nrep
error.per.class.keepX.opt[[measure_i]][[ijk]] = result[[measure_i]]$confusion[[ijk]]
# confusion matrix for keepX.opt, averaged over all nrep
error.per.class.keepX.opt.mean[[measure_i]][[ijk]] = apply(result[[measure_i]]$confusion[[ijk]],1 , mean)
}
#prediction of each samples for each fold and each repeat, on each comp
class.all[[ijk]] = result$class.comp[[ijk]][,,1]
}
prediction.all[[1]] = array(unlist(result$prediction.comp),c(nrow(result$prediction.comp[[1]]), ncol(result$prediction.comp[[1]]), nrepeat),
dimnames = c(dimnames(result$prediction.comp[[1]])[1:2], list(paste0("nrep",1:nrepeat))))#[[1]][, , 1] #take only one component [[1]] and one of test.keepX [,,1]
if(auc == TRUE)
{
auc.all[[1]] = lapply(result$auc.all, function(x) x[,,1])
auc.mean[[1]] = result$auc[, , 1]
}
names(prediction.all) = paste0("comp", ncomp)
#-- Put together outputs -----------------------------------------------#
#---------------------------------------------------------------------------#
result = list(error.rate = mat.mean.error,
error.rate.sd = mat.sd.error,
error.rate.all = mat.error.rate,
error.rate.class = error.per.class.keepX.opt.mean[[1]],
error.rate.class.all = error.per.class.keepX.opt[[1]],
predict = prediction.all,
class = class.all)
if(auc)
{
names(auc.mean) = c(paste0("comp", ncomp))
result$auc = auc.mean
names(auc.all) = c(paste0("comp", ncomp))
result$auc.all =auc.all
}
if (is(object, "mixo_splsda"))
{
names(list.features) = paste0("comp", ncomp)
result$features$stable = list.features
}
if (progressBar == TRUE)
cat('\n')
# added
if (near.zero.var == TRUE)
result$nzvX = nzv$Position
if (is(object, "mixo_splsda"))
{
method = "splsda.mthd"
} else if (is(object, "mixo_plsda")) {
method = "plsda.mthd"
} else {
warning("Something that should not happen happened. Please contact us.")
}
# edit class to 'perf' and remove 'perf.plsda.mthd' class so that plotting functionality is stopped
class(result) = "perf"
result$call = match.call()
#updated outputs
return(invisible(result))
}
#' @rdname perf.assess
#' @export
perf.assess.mixo_splsda <- perf.assess.mixo_plsda
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.