Introduction

CISE is a package that fits the multiple-graph factorization (M-GRAF) model (Wang et al., 2017) for undirected graphs $A_i, i=1,\dots,n$:

$$ A_{i[uv]}\mid\Pi_{i[uv]}\stackrel{\mbox{ind}}{\sim}\mbox{Bernoulli}(\Pi_{i[uv]}),\quad u>v;\ u,v\in\{1,2,\dots,V\}, \qquad \mbox{logit}(\Pi_{i})=Z+D_{i},\quad i=1,\dots,n. $$

Here $Z$ carries no subject index and is therefore shared by all subjects, while $D_i$ is the subject-specific deviation. Suppose $D_i$ has low rank $K$. To accommodate different degrees of heterogeneity in the data, $D_i$ is assumed to take one of three decompositions:

  1. $D_{i}=Q_{i}\Lambda_{i}Q_{i}^{\top}$

  2. $D_{i}=Q_{i}\Lambda Q_{i}^{\top}$

  3. $D_{i}=Q \Lambda_{i }Q^{\top}$

Model checking is needed to decide which assumption on $D_i$ is reasonable.
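To make the model concrete, here is a minimal sketch that simulates a single graph under variant 2 using base R only. The dimensions, the values of $Z$ and $\Lambda$, and the choice of a $Q_i$ with orthonormal columns are illustrative assumptions, not values or conventions taken from the package.

```r
# Minimal sketch: simulate one undirected graph under variant 2,
# D_i = Q_i Lambda Q_i^T. All numeric values are illustrative assumptions.
set.seed(1)
V <- 10   # number of nodes
K <- 2    # rank of the deviation D_i

# Shared baseline log-odds Z (symmetric V x V matrix)
Z <- matrix(rnorm(V * V, mean = -1), V, V)
Z <- (Z + t(Z)) / 2

# Q_i: V x K matrix with orthonormal columns (assumption), via a QR decomposition
Q_i <- qr.Q(qr(matrix(rnorm(V * K), V, K)))
Lambda <- c(3, -2)   # diagonal entries, shared across subjects in variant 2

D_i <- Q_i %*% diag(Lambda, nrow = K) %*% t(Q_i)
Pi_i <- plogis(Z + D_i)   # inverse logit, applied elementwise

# Draw the entries with u > v independently, then symmetrize
A_i <- matrix(0, V, V)
lower <- lower.tri(A_i)
A_i[lower] <- rbinom(sum(lower), size = 1, prob = Pi_i[lower])
A_i <- A_i + t(A_i)
```

Variants 1 and 3 could be simulated the same way by letting Lambda vary with the subject, or by sharing a single Q across subjects, respectively.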

Example

Let's study the relationship between brain connectivity and one cognitive trait, visuospatial processing (VP). We first load the HCP data on the brain networks and VP ability of 212 subjects.

library(CISE)
data(A)
data(VSPLOT)

According to Wang et al. (2017), variant 2 of $D_i$ is reasonable for this dataset while keeping the model concise. We use the MGRAF2 function in the CISE package to extract the low-rank components $\{Q_i\}$ and $\Lambda$. Subjects are then classified into high and low VP groups with a distance-based procedure that uses these low-rank components, as described in Wang et al. (2017): each test subject is assigned to the group whose training subjects are closer on average. The following code repeats 10-fold cross-validation (CV) 30 times and displays the mean and standard deviation of the CV accuracies for each choice of $K$.
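The distance used by the classifier in the code that follows can be read as a Frobenius distance between fitted deviations. As a sketch, assume each $Q_i$ has orthonormal columns, $Q_i^{\top}Q_i=I_K$ (an assumption here, suggested by the 2 * sum(Lambda_best^2) term in the code rather than stated in this document), and write $M=Q_j^{\top}Q_i$ and $\lambda_k$ for the $k$th diagonal entry of $\Lambda$. Then

$$ \|D_i-D_j\|_F^2=\|D_i\|_F^2+\|D_j\|_F^2-2\,\mathrm{tr}(D_iD_j)=2\sum_{k=1}^{K}\lambda_k^2-2\sum_{k=1}^{K}\sum_{l=1}^{K}\lambda_k\lambda_l M_{kl}^2, $$

because $\|D_i\|_F^2=\mathrm{tr}(\Lambda Q_i^{\top}Q_i\Lambda Q_i^{\top}Q_i)=\sum_k\lambda_k^2$ and $\mathrm{tr}(D_iD_j)=\mathrm{tr}(\Lambda M^{\top}\Lambda M)=\sum_{k,l}\lambda_k\lambda_l M_{kl}^2$. The double sum is what the inner loop over k accumulates in temp, and dist is the square root of the right-hand side.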

n = length(VSPLOT)
y = numeric(n)
y[VSPLOT=="high"] = 1
zero_id = which(y==0)
one_id = which(y==1)

max_K = 7
acc_MGRAF2 = matrix(0, nrow=max_K, ncol=2) # store mean and sd of CV accuracies

for(K in 1:max_K){
  #print(paste("K=",K))
  res = MGRAF2(A = A, K=K, tol=0.01, maxit=5)
  Q_best = res$Q
  Lambda_best = res$Lambda

  ## do 10-fold cross validation and repeat 30 times
  nfd = 10 # number of folds
  rep = 30

  acc = numeric(nfd*rep) # store accuracy

  set.seed(22)

  for(t in 1:rep){
    foldid = sample(1:nfd, size=n, replace=TRUE)

    for(fd in 1:nfd){
      test_id = which(foldid==fd)
      train_zero_id = setdiff(zero_id, test_id)
      train_one_id = setdiff(one_id, test_id)

      pred = sapply(test_id, function(i){

        ave_dist_zero = mean( sapply(train_zero_id, function(j){
          M = crossprod(Q_best[,,j], Q_best[,,i]) # KxK
          temp = 0
          for(k in 1:K){
            temp = temp + Lambda_best[k] * sum(M[k,] * Lambda_best * M[k,])
          }
          dist =  sqrt( 2 * sum(Lambda_best^2) - 2 * temp )
        }) )

        ave_dist_one = mean( sapply(train_one_id, function(j){
          M = crossprod(Q_best[,,j], Q_best[,,i]) # KxK
          temp = 0
          for(k in 1:K){
            temp = temp + Lambda_best[k] * sum(M[k,] * Lambda_best * M[k,])
          }
          dist = sqrt( 2 * sum(Lambda_best^2) - 2 * temp )
        }) )

        if (ave_dist_zero < ave_dist_one){
          return(0)
        }else{
          return(1)
        }
      })

      acc[nfd*(t-1)+fd] = sum(pred == y[test_id])/length(pred) 
      #print(fd)
    }
    #print(t)
  }
  acc_MGRAF2[K,1] = mean(acc)
  acc_MGRAF2[K,2] = sd(acc)
}
load("ExmData.RData")
library(knitr)
acc_data = data.frame(K=1:max_K, mean = acc_MGRAF2[ ,1], sd = acc_MGRAF2[ ,2])
kable(acc_data, digits = 3)
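The pairwise distance inside the cross-validation loop is written out twice with an explicit loop over k. As a hedged sketch, the same quantity can be computed in one vectorized expression. The helper pair_dist below is hypothetical (it is not part of CISE), it assumes that each Q[,,i] has orthonormal columns, and it is checked here only against simulated inputs, not against MGRAF2 output.

```r
# Hypothetical helper (not part of CISE): distance between subjects i and j
# under variant 2, assuming Q[, , i] has orthonormal columns.
pair_dist <- function(Q, Lambda, i, j) {
  M <- crossprod(Q[, , j], Q[, , i])            # K x K matrix Q_j^T Q_i
  cross <- sum(outer(Lambda, Lambda) * M^2)     # sum over k, l of lambda_k * lambda_l * M[k, l]^2
  sqrt(max(0, 2 * sum(Lambda^2) - 2 * cross))   # clamp tiny negative round-off before sqrt
}

# Compare with the explicit Frobenius norm on simulated inputs
set.seed(2)
V <- 8; K <- 3; n <- 4
Q <- array(0, dim = c(V, K, n))
for (i in 1:n) Q[, , i] <- qr.Q(qr(matrix(rnorm(V * K), V, K)))
Lambda <- c(2, -1, 0.5)

D <- function(i) Q[, , i] %*% diag(Lambda) %*% t(Q[, , i])
direct <- norm(D(1) - D(2), type = "F")
all.equal(pair_dist(Q, Lambda, 1, 2), direct)
```

With such a helper, the two averaged distances in the loop would reduce to calls like mean(sapply(train_zero_id, function(j) pair_dist(Q_best, Lambda_best, i, j))). The clamp at zero is a design choice to avoid NaN from sqrt when round-off makes the argument slightly negative.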


wangronglu/CISE documentation built on May 17, 2019, 8:46 p.m.