R/classifyV.R

Defines functions classifyV

Documented in classifyV

# Classification for the test set based on V
#
# Projects the training and test data onto the discriminant directions in V
# and assigns a class label to every row of Xtest.
#
# Arguments:
#   Xtrain  n1 by p matrix of training features.
#   Ytrain  vector of n1 training labels. The code takes G = max(Ytrain) as
#           the number of classes, and the two-class equal-prior path indexes
#           the classes as 1 and 2, so labels are presumably coded 1, ..., G.
#   Xtest   n2 by p matrix of test features.
#   V       p by (G - 1) matrix of discriminant directions (anything with
#           p * (G - 1) elements that as.matrix() turns into that shape).
#   prior   if TRUE, lda() is called without a `prior` argument; if FALSE,
#           classes are treated as equally likely (nearest projected class
#           mean when G == 2, prior = rep(1 / G, G) passed to lda() otherwise).
#   tol1    eigenvalues of the centred projected covariance that do not
#           exceed this threshold are discarded (used only when G > 2).
#
# Value:
#   For G == 2 with prior = FALSE, an integer vector of 1s and 2s from
#   which.min(). In every other case the `class` component of
#   predict(lda(...)) -- presumably a factor, assuming `lda` is MASS::lda
#   (it is not defined in this file; confirm against the package NAMESPACE).
#
# Errors:
#   Stops if any input contains NA or if the dimensions of Xtrain, Xtest,
#   Ytrain and V are inconsistent.
classifyV <- function(Xtrain, Ytrain, Xtest, V, prior = TRUE, tol1 = 1e-10) {
  # Scalar conditions: use short-circuit `||` rather than elementwise `|`.
  if (any(is.na(Xtest)) || any(is.na(Ytrain)) || any(is.na(Xtrain))) {
    stop("Missing values are not allowed!")
  }
  p <- ncol(Xtrain)
  if (ncol(Xtest) != p) {
    stop("Dimensions of Xtrain and Xtest don't match!")
  }

  G <- max(Ytrain)
  if (length(V) / (G - 1) != p) {
    stop("Dimensions of Xtrain and V don't match!")
  }
  ntrain <- nrow(Xtrain)
  if (length(Ytrain) != ntrain) {
    stop("Dimensions of Xtrain and Ytrain don't match!")
  }

  ntest <- nrow(Xtest)
  V <- as.matrix(V)

  trainproj <- Xtrain %*% V
  testproj <- Xtest %*% V

  if (G == 2) {
    if (prior) {
      outlda <- lda(trainproj, grouping = Ytrain, tol = 1e-16)
      return(predict(outlda, testproj)$class)
    }
    # Equal priors with one discriminant direction: pick the class whose
    # projected training mean is closest to the projected test point.
    means <- matrix(
      vapply(1:2, function(i) mean(trainproj[Ytrain == i, ]), numeric(1)),
      nrow = 2, ncol = 1
    )
    # Squared distance expanded as x^2 - 2 * x * m + m^2, one column per class.
    Dis <- matrix(testproj^2, ntest, 2) -
      2 * tcrossprod(testproj, means) +
      matrix(means^2, ntest, 2, byrow = TRUE)
    return(apply(Dis, 1, which.min))
  }

  ######### G>2 ########################
  myg <- as.factor(Ytrain)
  # Class means of every projected coordinate (rows: classes, cols: coordinates).
  group.means <- tapply(
    trainproj, list(rep(myg, ncol(V)), col(trainproj)), mean
  )
  # Covariance of the projected training data after removing class means.
  A1 <- var(trainproj - group.means[myg, ])
  tmp <- eigen(A1, symmetric = TRUE)

  # Rescale V by the inverse square root of that covariance, dropping
  # directions whose eigenvalue does not exceed tol1.
  keep <- tmp$values > tol1
  if (sum(keep) > 1) {
    # With G > 2 there are at least two eigenvalues, so this also covers the
    # full-rank case (all of them above tol1).
    V <- V %*% tmp$vectors[, keep, drop = FALSE] %*%
      diag(1 / sqrt(tmp$values[keep]))
  } else {
    # V is low rank: keep only the leading direction. which.max() selects
    # exactly one column even when eigenvalues tie, whereas an `== max(...)`
    # mask could select several and recycle the scaling incorrectly.
    top <- which.max(tmp$values)
    V <- V %*% tmp$vectors[, top, drop = FALSE] / sqrt(tmp$values[top])
  }

  trainproj <- Xtrain %*% V
  testproj <- Xtest %*% V

  if (prior) {
    outlda <- lda(trainproj, grouping = Ytrain, tol = 1e-16)
  } else {
    outlda <- lda(trainproj, grouping = Ytrain, prior = rep(1 / G, G),
                  tol = 1e-16)
  }
  predict(outlda, testproj)$class
}

Try the MGSDA package in your browser

Any scripts or data that you put into this service are public.

MGSDA documentation built on Sept. 4, 2023, 1:06 a.m.