R/predictionMap.R

Defines functions format.PredictionMap format trainSVMe1071 predictionMap

Documented in predictionMap

#' Construction of a Prediction Map Object
#'
#' Makes a PredictionMap object for the given data.
#'
#' @importFrom TunePareto trainTuneParetoClassifier
#' @importFrom e1071 svm
#' @importFrom stats predict
#' @importFrom foreach getDoParWorkers
#' @importFrom utils sessionInfo capture.output
#'
#' @param data
#' The data either as matrix or data.frame with samples in rows and features in columns.
#' @param labels
#' A vector of labels of data, consecutively labelled starting with 0.
#' @param foldList
#' A set of partitions for a cross-validation experiment.
#' This list comprises as many elements as cross-validation runs.
#' Each run is a list of as many vectors as folds. The entries are the indices of the samples that are left out in the folds.
#' This fold list can be generated by using \code{\link[TunePareto:generateCVRuns]{TunePareto::generateCVRuns()}}.
#' If the foldList is set to \code{NULL} (default) reclassification is performed.
#' @param parallel
#' Either TRUE or FALSE (default). If TRUE, the pairwise training is performed in parallel; this requires a registered parallel backend with more than one worker.
#' @param classifier
#' A TunePareto classifier object. \cr 
#' For detailed information refer to \code{\link[TunePareto:tuneParetoClassifier]{TunePareto::tuneParetoClassifier()}}.
#' @param ...
#' Further parameters of the classifier object.
#'
#' @details
#' Using a reclassification or a cross-validation set-up, this function performs a pairwise training for all class combinations and evaluates all samples using the trained classifiers.
#' This means that for each run, fold and binary classifier the predicted class for each sample is calculated.
#'
#'
#' @return
#' A PredictionMap object.
#' It is made up of a list of two matrices, which are called meta and pred. Both matrices provide information for individual samples column-wise.
#' The meta information in meta connects the values in the pred-matrix to a specific fold, run, sample and contains the original label.
#' The rownames of the pred-matrix (e.g. [0vs1]) show the classes of the binary base classifier. The elements are the prediction result of a specific training.
#' The rows that correspond to base classifiers that would separate a class from itself (e.g. [0vs0]) consist of -1. Those rows are not used within the analysis.
#'
#' @seealso \code{\link{summary.PredictionMap}}, \code{\link{print.PredictionMap}}, \code{\link{plot.PredictionMap}}
#'
#' @examples
#' library(TunePareto)
#' data(esl)
#' data = esl$data
#' labels = esl$labels
#' foldList = generateCVRuns(labels  = labels,
#'                           ntimes      = 2,
#'                           nfold       = 2,
#'                           leaveOneOut = FALSE,
#'                           stratified  = TRUE)
#'
#' # svm with linear kernel
#' predMap = predictionMap(data, labels, foldList = foldList,
#'                        classifier = tunePareto.svm(), kernel='linear')
#' \donttest{
#' # knn with k = 3
#' predMap = predictionMap(data, labels, foldList = foldList,
#'                        classifier = tunePareto.knn(), k = 3)
#' # randomForest
#' predMap = predictionMap(data, labels, foldList = foldList,
#'                        classifier = tunePareto.randomForest())
#'                            }
#' @export
predictionMap <- function(data = NULL,
                          labels = NULL,
                          foldList = NULL,
                          parallel = FALSE,
                          classifier = NULL, ...) {

  # Choose the foreach operator. Parallel execution is only accepted when
  # the registered backend reports more than one worker.
  if (parallel == TRUE) {
    if (getDoParWorkers() <= 1)
      stop(errorStrings('parallel'))
    `%par%` <- foreach::`%dopar%`
  } else {
    `%par%` <- foreach::`%do%`
  }

  #################################################
  ##
  ## Check parameter 'data'

  if (is.null(data))
    stop(errorStrings('dataMissing'))

  if (!(is.data.frame(data) || is.matrix(data)))
    stop(errorStrings('data'))

  #################################################
  ##
  ## Check parameter 'labels'

  if (is.null(labels))
    stop(errorStrings('labelsMissing'))

  # Scalar conditions, so use the short-circuiting operator. An empty label
  # vector is rejected here instead of failing later with an indexing error.
  if (!is.numeric(labels) || length(labels) == 0L ||
      length(labels) != nrow(data))
    stop(errorStrings('labels'))

  classes <- sort(unique(labels))
  sample.id <- seq_along(labels)

  # Labels have to be exactly 0, 1, ..., (number of classes - 1).
  if (!all(classes == 0:(length(classes) - 1)))
    stop(errorStrings('labels'))

  #################################################
  ##
  ## Check parameter 'foldList'

  if (!is.null(foldList)) {
    .validateFoldList(foldList, nrow(data))
  }
  is_reclassification <- is.null(foldList)

  #################################################
  ##
  ## Check parameter 'classifier'

  if (is.null(classifier))
    stop(errorStrings('classifierMissing'))

  if (!inherits(classifier, "TuneParetoClassifier"))
    stop(errorStrings("classifier"))

  #################################################

  numSam <- length(labels)

  # All ordered class pairs (c1, c2), including the diagonal c1 == c2, which
  # only yields placeholder rows filled with -1.
  class.combs <- as.matrix(expand.grid(c2 = classes, c1 = classes))
  class.combs <- class.combs[, c(2, 1)]
  comb.names <- paste0('[', class.combs[, 1], 'vs', class.combs[, 2], ']')

  packages <- classifier$requiredPackages

  print.output.classifier <- c()

  # Every foreach task returns its predictions together with the captured
  # print-out of the trained model. Assigning to a variable of this function
  # from inside the foreach body only works with %do%; with %dopar% the
  # assignment happens on the worker and is lost.
  lastClassifierOutput <- function(results) {
    outputs <- Filter(Negate(is.null), lapply(results, `[[`, "output"))
    if (length(outputs) > 0) outputs[[length(outputs)]] else NULL
  }

  if (is_reclassification) {
    # Reclassification experiment (RE)
    #################################################
    ##
    ## Generating pred (RE)

    results <- foreach::foreach(combCount = seq_len(nrow(class.combs)),
                                .packages = packages) %par% {
      comb <- class.combs[combCount, ]

      if (comb[1] == comb[2]) {
        list(pred = rep(-1, numSam), output = NULL)
      } else {
        # Train on the samples of the two classes, predict all samples.
        train.id <- c(which(labels[sample.id] == comb[1]),
                      which(labels[sample.id] == comb[2]))

        model <- TunePareto::trainTuneParetoClassifier(
          classifier  = classifier,
          trainData   = data[train.id, , drop = FALSE],
          trainLabels = labels[train.id],
          ...)

        list(pred = as.numeric(as.character(predict(object = model,
                                                    newdata = data))),
             output = capture.output(model))
      }
    }

    pred <- do.call("rbind", lapply(results, `[[`, "pred"))
    rownames(pred) <- comb.names
    print.output.classifier <- lastClassifierOutput(results)

    ########################################################
    ###
    ### Generating meta (RE)

    meta <- rbind(label  = as.vector(labels),
                  run    = rep(1, numSam),
                  fold   = rep(1, numSam),
                  sample = sample.id)

  } else {  # Crossvalidation experiment (CV)

    #################################################
    ##
    ## Generating pred (CV)

    numRun  <- length(foldList)
    numFold <- length(foldList[[1]])

    task.combs <- as.matrix(expand.grid(run  = seq_len(numRun),
                                        fold = seq_len(numFold)))

    # Preallocated with the placeholder value: rows of the diagonal class
    # pairs stay at -1, all other rows are overwritten below.
    pred <- matrix(-1, nrow = nrow(class.combs), ncol = numSam * numRun)

    for (combCount in seq_len(nrow(class.combs))) {
      comb <- class.combs[combCount, ]

      if (comb[1] == comb[2])
        next

      results <- foreach::foreach(rC = seq_len(nrow(task.combs)),
                                  .packages = packages) %par% {

        fold <- foldList[[task.combs[rC, 'run']]][[task.combs[rC, 'fold']]]

        # setdiff() instead of sample.id[-fold]: a negative index built
        # from an empty fold would select no training samples at all.
        train.id <- setdiff(sample.id, fold)
        train.id <- train.id[c(which(labels[train.id] == comb[1]),
                               which(labels[train.id] == comb[2]))]

        model <- TunePareto::trainTuneParetoClassifier(
          classifier  = classifier,
          trainData   = data[train.id, , drop = FALSE],
          trainLabels = labels[train.id],
          ...)

        list(pred = as.numeric(as.character(
                      predict(model, newdata = data[fold, , drop = FALSE]))),
             output = capture.output(model))
      }

      # Scatter the per-fold predictions back to their sample positions;
      # one column per run, samples in rows.
      mat <- matrix(-1, nrow = numSam, ncol = numRun)

      for (rC in seq_len(nrow(task.combs))) {
        my.run  <- task.combs[rC, 'run']
        my.fold <- task.combs[rC, 'fold']
        fold <- foldList[[my.run]][[my.fold]]

        mat[fold, my.run] <- results[[rC]]$pred
      }

      pred[combCount, ] <- as.vector(mat)

      trained.output <- lastClassifierOutput(results)
      if (!is.null(trained.output))
        print.output.classifier <- trained.output
    }

    rownames(pred) <- comb.names

    ########################################################
    ###
    ### Generating meta (CV)

    run.mat  <- matrix(-1, nrow = numSam, ncol = numRun)
    fold.mat <- matrix(-1, nrow = numSam, ncol = numRun)

    for (rC in seq_len(nrow(task.combs))) {
      my.run  <- task.combs[rC, 'run']
      my.fold <- task.combs[rC, 'fold']
      fold <- foldList[[my.run]][[my.fold]]

      run.mat[fold, my.run]  <- my.run
      fold.mat[fold, my.run] <- my.fold
    }

    meta <- rbind(label  = as.vector(rep(labels, numRun)),
                  run    = as.vector(run.mat),
                  fold   = as.vector(fold.mat),
                  sample = rep(sample.id, numRun))
  }

  ########################################################
  ###
  ### Generating result object

  # Sort the columns by label, then run, then fold, then sample. Sorting on
  # the four keys directly avoids packing them into a single double, which
  # loses precision for large label values and mixes the fields as soon as
  # one of them reaches 1e5.
  index <- order(meta["label", ], meta["run", ],
                 meta["fold", ], meta["sample", ])

  fold.description <- if (!is.null(foldList)) {
    paste0(length(foldList), " Runs, ", length(foldList[[1]]), " Folds")
  } else {
    "Reclassification"
  }

  structure(list(pred = pred[, index, drop = FALSE],
                 meta = meta[, index, drop = FALSE],
                 printParams = list(data = paste(nrow(data), "x", ncol(data), " matrix"),
                                    labels = paste("Vector of length ", length(labels)),
                                    foldList = fold.description,
                                    parallel = parallel),
                 printClassifierParams = print.output.classifier),
            class = "PredictionMap")
}

# specific trainers
# e1071
# Fits a support vector machine with e1071::svm().
#   x       - training data, handed to svm() as `x`
#   y       - training response, handed to svm() as `y`
#   control - not used by this function
#   ...     - further arguments forwarded to e1071::svm()
# Returns the fitted model invisibly.
trainSVMe1071 <- function(x, y, control, ...) {
  # The previous call ended in a stray empty argument and never passed `...`
  # on, so additional svm() settings supplied by the caller were ignored.
  model <- e1071::svm(x = x, y = y, ...)
  invisible(model)
}

# S3 generic for producing formatted output. Dispatches on the class of `x`
# (for example to format.PredictionMap) and hands `...` to the method.
format <- function(x, ...) {
  UseMethod("format")
}

# Implementation of the generic format() for PredictionMap objects: prints
# the meta matrix and/or the prediction matrix to the console.
#   x        - a PredictionMap, i.e. a list holding the matrices meta and pred
#   showMeta - if TRUE, x$meta is printed
#   showPred - if TRUE, x$pred is printed without the rows that consist
#              entirely of -1
#   ...      - further arguments passed on to print.default()
#' @export
format.PredictionMap <- function(x, showMeta=TRUE, showPred=TRUE, ...) {
  if (showMeta) {
    cat("Meta data:\n")
    print.default(x$meta, ...)
  }
  if (showPred) {
    # Logical mask instead of a negative index: x$pred[-which(...), ] selects
    # zero rows when no row is a placeholder, so nothing would be printed.
    # %in% never yields NA, so rows containing NA are kept.
    is.placeholder <- rowMeans(x$pred) %in% -1
    cat("\nPredictions:\n")
    # drop = FALSE keeps the row labels even if a single row or column remains.
    print.default(x$pred[!is.placeholder, , drop = FALSE], ...)
  }
}

Try the ORION package in your browser

Any scripts or data that you put into this service are public.

ORION documentation built on Feb. 12, 2026, 5:07 p.m.