R/perturbation-clustering-helpers.R

Defines functions classify classifierProb classifierSmall classifierBig CalcPerturbedDiscrepancy CalcAUC GetPerturbedSimilarity GetOriginalSimilarity ClusterToConnectivity BuildConnectivityMatrix GetPerturbationAlgorithm GetClusteringAlgorithm

# Resolve the clustering routine used by the perturbation-clustering pipeline.
#
# clusteringMethod   "kmeans", "pam" or "hclust"; selects a built-in wrapper when
#                    no clusteringFunction is supplied.
# clusteringFunction optional function(data, k, ...); when it is a function it takes
#                    precedence over clusteringMethod and the reported name is "Unknow".
# clusteringOptions  optional list of extra arguments appended to every call.
#
# Returns list(fun = function(data, k), name = <method name>). An unrecognised
# clusteringMethod with no clusteringFunction is an error.
GetClusteringAlgorithm <- function(clusteringMethod = "kmeans", clusteringFunction = NULL, clusteringOptions = NULL, ...){
    algorithmName <- clusteringMethod

    if (is.function(clusteringFunction)) {
        algorithmName <- "Unknow"
    } else {
        switch(
            clusteringMethod,
            kmeans = clusteringFunction <- kmeansWrapper,
            pam    = clusteringFunction <- pamWrapper,
            hclust = clusteringFunction <- hclustWrapper,
            stop("clusteringMethod not found. Please pass a clusteringFunction instead")
        )
    }

    # closure over the resolved function and the fixed extra options
    runClustering <- function(data, k) {
        callArgs <- c(list(data = data, k = k), clusteringOptions)
        do.call(clusteringFunction, callArgs)
    }

    list(fun = runClustering, name = algorithmName)
}

# Resolve the data-perturbation routine used by the perturbation-clustering pipeline.
#
# data            data matrix; read only by the "noise" method to estimate the noise
#                 level when perturbOptions does not give one. The default
#                 `data = data` is self-referential, so callers must pass it
#                 explicitly on that path.
# perturbMethod   "noise" or "subsampling"; selects a built-in routine when no
#                 perturbFunction is supplied.
# perturbFunction optional function(data, ...); when it is a function it takes
#                 precedence and the reported name is "Unknow".
# perturbOptions  optional list of extra arguments appended to every call. For
#                 "noise" the entries `noise` (explicit noise level) and
#                 `noisePercent` (forwarded to GetNoise) are read.
#
# Returns list(fun = function(data), name = <method name>). An unrecognised
# perturbMethod with no perturbFunction is an error.
GetPerturbationAlgorithm <- function(data = data, perturbMethod = "noise", perturbFunction = NULL, perturbOptions = NULL, ...){
    name <- perturbMethod
    
    if (!is.function(perturbFunction)) {
        switch(
            perturbMethod,
            noise = {
                if (is.null(perturbOptions))
                    perturbOptions <- list()
                
                # `[[` matches names exactly. `$` matches partially, so with only
                # `noisePercent` supplied `perturbOptions$noise` returned the
                # percentage and used it as the noise level itself.
                noise <- perturbOptions[["noise"]]
                if (is.null(noise))
                    noise <- GetNoise(data, noisePercent = perturbOptions[["noisePercent"]])
                
                # the perturbed data is the input plus noise at the resolved level
                perturbFunction <- function(data, ...) {
                    AddNoisePerturb(data = data, noise = noise)
                }
            },
            subsampling = {
                perturbFunction <- SubSampling
            },
            {
                stop("perturbMethod not found. Please pass a perturbFunction instead")
            }
        )
    }
    else {
        name <- "Unknow"
    }
    
    list(
        fun = function(data) do.call(perturbFunction, c(list(data = data), perturbOptions)),
        name = name
    )
}

# Cluster `data` into `clus` groups and encode the result as a co-membership matrix.
#
# clusteringAlgorithm is called as clusteringAlgorithm(data, clus) and is expected to
# return one label per row of `data`. Entry [a, b] of the integer matrix is 1L when
# rows a and b share a label in 1:clus and 0L otherwise; rows and columns carry
# rownames(data).
#
# Returns list(matrix = <nrow(data) x nrow(data) integer matrix>, groups = <labels>).
BuildConnectivityMatrix <- function(data, clus, clusteringAlgorithm) {
    n <- nrow(data)
    connectivity <- matrix(0L, n, n)
    labels <- clusteringAlgorithm(data, clus)

    for (group in 1:clus) {
        members <- labels == group
        connectivity[members, members] <- 1L
    }

    sampleNames <- rownames(data)
    rownames(connectivity) <- sampleNames
    colnames(connectivity) <- sampleNames

    list(matrix = connectivity, groups = labels)
}

# Turn a vector of cluster labels into an unnamed co-membership matrix.
#
# cluster  numeric labels (positive integers are expected).
# Returns a length(cluster) x length(cluster) integer matrix holding 1L where the two
# positions carry the same label in 1:max(cluster), 0L elsewhere.
ClusterToConnectivity <- function(cluster){
    size <- length(cluster)
    result <- matrix(0L, nrow = size, ncol = size)
    for (label in 1:max(cluster)) {
        same <- cluster == label
        result[same, same] <- 1L
    }
    result
}

# Cluster the unperturbed data once for every k in clusRange.
#
# data                matrix whose rows are the samples being clustered.
# clusRange           vector of cluster counts k to evaluate.
# clusteringAlgorithm function(data, k) returning one label per row.
# showProgress        draw a text progress bar while clustering.
# ncore               accepted for signature compatibility; not used here.
#
# The RNG is reseeded before each k from a seed vector drawn once up front, so each
# k's clustering is reproducible from the caller's RNG state.
#
# Returns list(origS, groupings), both indexed by k (entries for k outside clusRange
# are NULL): origS[[k]] is the co-membership matrix named by rownames(data) and
# groupings[[k]] the label vector returned by clusteringAlgorithm.
GetOriginalSimilarity <- function(data, clusRange, clusteringAlgorithm, showProgress = FALSE, ncore) {
    groupings <- list()
    origS <- list()
    
    if (showProgress) {
        pb <- txtProgressBar(min = 0, max = length(clusRange), style = 3)
        # terminate the bar line even if clustering fails part-way
        on.exit(close(pb), add = TRUE)
    }
    
    seeds <- abs(round(rnorm(max(clusRange)) * 10^6))
    
    for (i in seq_along(clusRange)) {
        clus <- clusRange[i]
        set.seed(seeds[clus])
        groupings[[clus]] <- clusteringAlgorithm(data, clus)
        origS[[clus]] <- ClusterToConnectivity(groupings[[clus]])
        rownames(origS[[clus]]) <- colnames(origS[[clus]]) <- rownames(data)
        # Advance by position: the bar's maximum is length(clusRange), so passing the
        # k value itself overshoots (and is silently ignored) when clusRange does not
        # start at 1.
        if (showProgress) setTxtProgressBar(pb, i)
    }
    
    list(origS = origS, groupings = groupings)
}

# Repeatedly perturb the data, recluster it for every k in clusRange and collect the
# resulting connectivity matrices until each k's stopping criterion fires or iterMax
# iterations have been queued.
#
# data                    matrix whose rows are the samples.
# clusRange               vector of cluster counts k.
# iterMax                 maximum number of iterations queued per k.
# iterMin                 iterations per k in the first batch (later batches run 10);
#                         also caps the worker count passed to doParallel.
# origS                   list indexed by k of connectivity matrices of the
#                         unperturbed data; each perturbed matrix is scored against
#                         it with CalcAUC.
# clusteringAlgorithm     function(data, k) returning one label per row.
# perturbedFunction       function(data) returning a list with `data` (the perturbed
#                         data) and `ConnectivityMatrixHandler`, called as
#                         (connectivityMatrix, iter, k) to post-process each matrix.
# stoppingCriteriaHandler called with (iter, k, auc) after each recorded result and
#                         with (k, type = 1) between batches; returning 1 stops that k,
#                         returning 2 from the per-result call stops everything.
# showProgress            draw a text progress bar.
# ncore                   worker count for doParallel (unix only).
#
# Returns pertS, a list indexed by k of lists of connectivity matrices, one per
# recorded iteration.
GetPerturbedSimilarity <- function(data, clusRange, iterMax, iterMin, origS, clusteringAlgorithm, perturbedFunction, stoppingCriteriaHandler, showProgress = FALSE, ncore) {
    pertS <- list()
    currentIter <- rep(0, max(clusRange))

    # One queued job per (iteration, k); k values are interleaved so every k
    # advances at the same pace.
    jobs <- rep(clusRange, iterMax)
    maxJob <- length(jobs)

    # Per-k seed streams: job (k, iter) always reseeds with seeds[[k]][iter].
    seeds <- list()

    for (clus in clusRange) {
        pertS[[clus]] <- list()
        seeds[[clus]] <- round(rnorm(max(iterMax, 1000)) * 10^6)
    }

    if (showProgress) pb <- txtProgressBar(min = 0, max = maxJob, style = 3)

    # kProgress[k]: number of jobs dispatched for k so far, or -1 once k has stopped.
    kProgress <- rep(0, max(clusRange))
    step <- iterMin
    useParallel <- .Platform$OS.type == "unix"
    if (useParallel) {
        doParallel::registerDoParallel(min(ncore, step))
        # release the workers even if a job or a handler raises an error
        on.exit(doParallel::stopImplicitCluster(), add = TRUE)
    }

    # Perturb the data, recluster it into job$clus groups, post-process the
    # connectivity matrix and score it against the original one.
    runJob <- function(job) {
        clus <- job$clus

        set.seed(seeds[[clus]][job$iter])
        perturbedRet <- perturbedFunction(data = data)

        cMatrix <- BuildConnectivityMatrix(data = perturbedRet$data, clus, clusteringAlgorithm)
        connectivityMatrix <- perturbedRet$ConnectivityMatrixHandler(connectivityMatrix = cMatrix$matrix, iter = job$iter, k = clus)

        list(
            connectivityMatrix = connectivityMatrix,
            clus = clus,
            auc = CalcAUC(origS[[clus]], connectivityMatrix)$area
        )
    }

    while (length(jobs) > 0) {
        # the first batch runs iterMin iterations per k, later batches run 10
        jobLength <- step * length(clusRange)
        step <- 10

        if (jobLength > length(jobs)) {
            jobLength <- length(jobs)
        }

        currentJobs <- list()
        count <- 1
        for (clus in jobs[1:jobLength]) {
            kProgress[clus] <- kProgress[clus] + 1
            currentJobs[[count]] <- list(iter = kProgress[clus], clus = clus)
            count <- count + 1
        }

        if (jobLength < length(jobs)) {
            jobs <- jobs[(jobLength + 1):length(jobs)]
        } else {
            jobs <- c()
        }

        # presumably here to silence R CMD check's "no visible binding" note for the
        # foreach iterator `job`
        job <- NULL

        if (useParallel) {
            rets <- foreach(job = currentJobs) %dopar% runJob(job)
        } else {
            rets <- foreach(job = currentJobs) %do% runJob(job)
        }

        allStop <- FALSE
        for (ret in rets) {
            # Take k from this result before asking whether k has already stopped;
            # testing first checked whatever k the previous iteration left in `clus`.
            clus <- ret$clus
            if (kProgress[clus] != -1) {
                currentIter[clus] <- currentIter[clus] + 1
                pertS[[clus]][[currentIter[clus]]] <- ret$connectivityMatrix

                decision <- stoppingCriteriaHandler(iter = currentIter[clus], k = clus, auc = ret$auc)
                if (decision == 2) {
                    allStop <- TRUE
                } else if (decision == 1) {
                    # drop k's remaining queued jobs and count them as done
                    if (showProgress) setTxtProgressBar(pb, getTxtProgressBar(pb) + length(jobs[jobs == clus]))
                    jobs <- jobs[jobs != clus]
                    kProgress[clus] <- -1
                }
            }
            if (showProgress) setTxtProgressBar(pb, getTxtProgressBar(pb) + 1)
        }
        if (allStop) break

        # Between batches, ask every still-running k whether it should stop; repeat
        # until a full pass stops no further k.
        countDone <- -1
        while (countDone != length(which(kProgress[clusRange] == -1))) {
            countDone <- length(which(kProgress[clusRange] == -1))

            for (clus in clusRange) {
                if (kProgress[clus] != -1) {
                    decision <- stoppingCriteriaHandler(k = clus, type = 1)
                    if (decision == 1) {
                        if (showProgress) setTxtProgressBar(pb, getTxtProgressBar(pb) + length(jobs[jobs == clus]))
                        jobs <- jobs[jobs != clus]
                        kProgress[clus] <- -1
                    }
                }
            }
        }
    }

    if (showProgress) {
        setTxtProgressBar(pb, maxJob)
        cat("\n")
    }

    pertS
}

# Score how far a perturbed connectivity matrix is from the original one.
#
# orig, pert  numeric matrices of the same N x N size.
#
# S = |orig - pert| with its diagonal zeroed. The empirical CDF P(S <= a) is
# evaluated at every distinct non-zero entry of S, at 0, at a point just below 0
# and, if no entry reaches it, at 1. It is then integrated as a left step function
# over that grid.
#
# Returns list(area = <integral>, entry = <grid points>, cdf = <CDF at each point>).
# Any NA in S makes every CDF value and the area NA.
CalcAUC <- function(orig, pert) {
    N <- nrow(orig)
    S <- abs(orig - pert)
    diag(S) <- 0
    # -10^(-5) is prepended for visual purposes (the CDF starts just below 0)
    A <- c(-10^(-5), 0, sort(unique(as.numeric(S[S != 0]))))
    if (max(A) < 1)
        A <- c(A, 1)

    # findInterval(a, sorted) counts the sorted entries <= a, so one sort replaces a
    # full pass over S per grid point. N^2 is formed in double so large N cannot
    # overflow integer arithmetic.
    cellCount <- as.numeric(N)^2
    if (anyNA(S)) {
        B <- rep(NA_real_, length(A))
    } else {
        B <- findInterval(A, sort(as.numeric(S))) / cellCount
    }

    # left Riemann sum of the step CDF, accumulated in grid order
    area <- 0
    for (i in 2:length(A)) {
        area <- area + B[i - 1] * (A[i] - A[i - 1])
    }

    list(area = area, entry = A, cdf = B)
}

# Summarise, for every k in clusRange, how far pertS[[k]] is from origS[[k]].
#
# Returns a list whose components are indexed by k (NA or NULL for k outside
# clusRange):
#   Diff   sum(|origS[[k]] - pertS[[k]]|), rounded to 10 digits
#   Entry  grid points returned by CalcAUC
#   CDF    CDF values returned by CalcAUC
#   AUC    area returned by CalcAUC, rounded to 10 digits
CalcPerturbedDiscrepancy <- function(origS, pertS, clusRange) {
    totalDiff <- NULL
    areas <- NULL
    entryList <- list()
    cdfList <- list()

    for (k in clusRange) {
        totalDiff[k] <- sum(abs(origS[[k]] - pertS[[k]]))
    }

    for (k in clusRange) {
        aucResult <- CalcAUC(origS[[k]], pertS[[k]])
        entryList[[k]] <- aucResult$entry
        cdfList[[k]] <- aucResult$cdf
        areas[k] <- aucResult$area
    }

    list(
        Diff = round(totalDiff, digits = 10),
        Entry = entryList,
        CDF = cdfList,
        AUC = round(areas, digits = 10)
    )
}


# k-nearest-neighbour class probabilities for a large training set.
#
# train  samples x features training data (coerced to a data.frame).
# label  class label of every training row.
# test   samples x features query data.
# knn.k  number of neighbours; when NULL it is chosen by 5-fold cross-validation
#        over 5..min(round(nrow(train)/5), 50), keeping the largest k among those
#        with minimal error rate.
#
# Returns a nrow(test) x length(unique(label)) matrix holding, per test row, the
# fraction of its bestK neighbours in each class (0 for absent classes). Columns
# are named by unique(label) and rows by rownames(test).
classifierBig <- function(train, label, test, knn.k){
    train <- as.data.frame(train)
    nTrain <- nrow(train)
    
    # drawn even when knn.k is given, which keeps the RNG stream unchanged
    fold <- sample(rep(1:5, ceiling(nTrain/5))[1:nTrain])
    
    if (is.null(knn.k)){
        minK <- 5
        maxK <- min(round(nTrain/5), 50)
        kGrid <- minK:maxK
        
        errorRate <- colMeans(
            do.call(
                what = rbind,
                args = lapply(1:5, function(foldIter){
                    inFold <- fold == foldIter
                    trainLabel <- label[!inFold]
                    nn_index <- FNN::knnx.index(train[!inFold, ], train[inFold, ], k = maxK)
                    sapply(kGrid, function(k){
                        # drop = FALSE keeps a matrix for a single held-out row
                        predicted <- apply(nn_index[, 1:k, drop = FALSE], 1, function(index){
                            tbl <- table(trainLabel[index])
                            names(tbl)[which.max(tbl)]
                        })
                        
                        sum(predicted != label[inFold])/length(predicted)
                    })
                })
            )
        )
        
        minR <- which(errorRate == min(errorRate))
        bestK <- kGrid[minR[length(minR)]]
    } else {
        bestK <- knn.k
    }
    
    labels <- unique(label)
    labelNames <- as.character(labels)
    nn_index <- FNN::knnx.index(train, test, k = bestK)
    
    # One probability vector per test row, laid out row-wise. Building the matrix
    # explicitly keeps its orientation when there is a single class, where
    # t(apply(...)) gave a 1 x nrow(test) matrix and colnames<- then failed.
    probs <- lapply(seq_len(nrow(nn_index)), function(i) {
        tbl <- table(label[nn_index[i, ]])[labelNames]/bestK
        tbl/sum(tbl, na.rm = TRUE)
    })
    res <- matrix(as.numeric(unlist(probs, use.names = FALSE)),
                  nrow = nrow(nn_index), ncol = length(labels), byrow = TRUE)
    
    # classes missing from a row's neighbourhood come back as NA from the lookup
    res[is.na(res)] <- 0
    
    colnames(res) <- labels
    rownames(res) <- rownames(test)
    res
}

# Class scores for a small training set via k-means seeded at class centroids.
#
# train  samples x features training data.
# label  class label of every training row.
# test   samples x features query data.
#
# Training and test rows are clustered together by k-means, started from one
# centroid per class. For each test row and class, the score is the fraction of
# that class's training rows that fall in the same k-means cluster as the test row.
#
# Returns a nrow(test) x length(unique(label)) matrix (0 where no training row of a
# class shares the cluster). Columns are named by unique(label) and rows by
# rownames(test).
classifierSmall <- function(train, label, test){
    
    colnames(train) <- colnames(test) <- NULL
    dat <- as.matrix(rbind(train, test))
    
    labels <- unique(label)
    labelNames <- as.character(labels)
    
    # One starting centre per class. drop = FALSE keeps single-member classes and
    # single-feature data as matrices, so colMeans applies in every case.
    centers <- as.matrix(
        do.call(
            what = rbind,
            args = lapply(labels, function(l){
                colMeans(train[label == l, , drop = FALSE])
            })
        )
    )
    
    cluster <- kmeans(dat, centers = centers, iter.max = 1000)$cluster
    
    nTrain <- nrow(train)
    trainCluster <- cluster[seq_len(nTrain)]
    classSizes <- table(label)[labelNames]
    
    # seq_len: an empty test set yields an empty result instead of two spurious rows
    # from 1:0
    rows <- lapply(seq_len(nrow(test)), function(i){
        table(label[trainCluster == cluster[nTrain + i]])[labelNames]/classSizes
    })
    res <- matrix(as.numeric(unlist(rows, use.names = FALSE)),
                  nrow = nrow(test), ncol = length(labels), byrow = TRUE)
    
    # classes absent from the test row's cluster come back as NA from the lookup
    res[is.na(res)] <- 0
    
    colnames(res) <- labels
    rownames(res) <- rownames(test)
    res
}

# Class-probability matrix from the classifier suited to the training-set size:
# classifierSmall for at most 100 training rows, classifierBig (which uses knn.k)
# otherwise.
classifierProb <- function(train, label, test, knn.k){
    if (nrow(train) <= 100) {
        return(classifierSmall(train, label, test))
    }
    classifierBig(train, label, test, knn.k)
}

# Predict one class label per test row.
#
# Class probabilities come from classifierProb; each row's prediction is the
# column name with the highest probability (the first such column on ties, as
# which.max picks the first maximum).
#
# Returns a character vector of predicted labels.
classify <- function(train, label, test, knn.k){
    # reuse classifierProb's size-based dispatch instead of duplicating it
    prob <- classifierProb(train, label, test, knn.k)
    
    unlist(apply(prob, 1, function(p) colnames(prob)[which.max(p)]))
}

Try the PINSPlus package in your browser

Any scripts or data that you put into this service are public.

PINSPlus documentation built on May 29, 2024, 6:12 a.m.