# R/PPclassify.R

# Defines functions PP.classify

#' Predict class for the test set and calculate prediction error
#'
#' After the tree structure has been found, drop every test observation down
#' the tree to predict its class and, when the true classes are supplied,
#' count the misclassified observations.
#' @usage PP.classify(test.data, true.class = NULL, Tree.result, Rule, ...)
#' @param test.data  the test dataset; coerced with \code{as.matrix}, one row
#'   per observation
#' @param true.class true class of test dataset if available. Non-numeric
#'   labels are recoded to 1, 2, ... following the sorted order of the labels
#'   present in \code{true.class}
#' @param Tree.result the result of PP.Tree; its \code{Tree.Struct},
#'   \code{Alpha.Keep} and \code{C.Keep} components are used
#' @param Rule split rule 1:mean of two group means, 2:weighted mean, 3: mean of max(left group) and min(right group), 4: weighted mean of max(left group) and min(right group)
#' @param ... currently unused
#' @return A list with components
#' \item{predict.error}{number of test observations whose predicted class
#'   differs from \code{true.class}; \code{NA} when \code{true.class} is
#'   \code{NULL}}
#' \item{predict.class}{predicted class code for each row of
#'   \code{test.data}}
#' @references Lee, YD, Cook, D., Park JW, and Lee, EK(2013) 
#' PPtree: Projection pursuit classification tree, 
#' Electronic Journal of Statistics, 7:1369-1386.
#' @export
#' @keywords tree
#' @examples
#' data(iris)
#' n <- nrow(iris)
#' tot <- c(1:n)
#' n.train <- round(n*0.9)
#' train <- sample(tot,n.train)
#' test <- tot[-train]
#' Tree.result <- PP.Tree("LDA",iris[train,5],iris[train,1:4])
#' PP.classify(iris[test,1:4],iris[test,5],Tree.result,1)
PP.classify <- function(test.data, true.class = NULL, Tree.result, Rule, ...) {
    test.data <- as.matrix(test.data)

    if (!is.null(true.class)) {
        true.class <- as.matrix(true.class)
        # A single-row matrix is treated as a row of labels; make it a column.
        if (nrow(true.class) == 1) true.class <- t(true.class)
        if (!is.numeric(true.class)) {
            # Recode labels to 1..k in sorted label order.
            # NOTE(review): the codes are derived from the labels present in
            # the test set only; this presumably agrees with the coding used
            # when the tree was built only if the same labels occur in both
            # -- confirm against PP.Tree.
            class.name <- names(table(true.class))
            true.class <- match(true.class, class.name)
        }
    }

    # Tree components read by the helpers below. As used here, a row of
    # Tree.Struct describes one node: column 2 is the left child id (0 for a
    # leaf), column 3 the right child id (the class code for a leaf) and
    # column 4 the row of Alpha.Keep / C.Keep used to split (0 for a leaf).
    Tree.Struct <- Tree.result$Tree.Struct
    Alpha.Keep <- Tree.result$Alpha.Keep
    C.Keep <- Tree.result$C.Keep
    n <- nrow(test.data)

    # Evaluate the split of every internal node on the whole test set,
    # visiting nodes in pre-order (node, left subtree, right subtree). Each
    # visited internal node appends one row: TRUE where the projected
    # observation falls below the node's cutoff for the chosen Rule.
    PP.Class.index <- function(split.rows, id) {
        if (Tree.Struct[id, 2] == 0) {
            return(split.rows)
        }
        id.proj <- Tree.Struct[id, 4]
        proj.test <- as.double(test.data %*% as.matrix(Alpha.Keep[id.proj, ]))
        go.left <- proj.test < C.Keep[id.proj, Rule]
        # deparse.level = 0 keeps the accumulated matrix free of row names.
        split.rows <- rbind(split.rows, go.left, deparse.level = 0)
        split.rows <- PP.Class.index(split.rows, Tree.Struct[id, 2])
        PP.Class.index(split.rows, Tree.Struct[id, 3])
    }

    # Walk the tree in the same pre-order, consuming one row of
    # test.class.index per internal node. IOindex is 1 for the observations
    # that reached the current node; row.id is the next unused row and is
    # returned so the right subtree continues where the left one stopped.
    PP.Classification <- function(test.class, IOindex, id, row.id) {
        if (Tree.Struct[id, 4] == 0) {
            # Leaf: label the observations that arrived here and are still
            # unassigned (coded 0) with the class stored in column 3.
            unassigned <- test.class == 0
            test.class <- test.class + IOindex * unassigned * Tree.Struct[id, 3]
            return(list(test.class = test.class, row.id = row.id))
        }
        go.left <- test.class.index[row.id, ]
        left <- PP.Classification(test.class, IOindex * go.left,
                                  Tree.Struct[id, 2], row.id + 1)
        PP.Classification(left$test.class, IOindex * (1 - go.left),
                          Tree.Struct[id, 3], left$row.id)
    }

    test.class.index <- PP.Class.index(NULL, 1)
    predict.class <- PP.Classification(rep(0, n), rep(1, n), 1, 1)$test.class

    if (is.null(true.class)) {
        predict.error <- NA
    } else {
        predict.error <- sum(true.class != predict.class)
    }
    list(predict.error = predict.error, predict.class = predict.class)
}
# EK-Lee/PPtreeCR documentation built on May 6, 2019, 3:08 p.m.