# tests/testthat/test-sbm-bernoulli-with-covariates.R

# Fix the RNG seed so that the networks drawn in the tests below via
# `$rNetwork()` (and possibly any randomised steps of the fits) are
# reproducible from one run to the next.
set.seed(1234)
library(sbm)
library(aricode)  # ARI(), used below to compare node memberships

# Simulation settings. They are defined *before* sourcing utils_test.R,
# presumably because that script reads them to build the samplers -- TODO confirm.
N_cov <- 80  # presumably the number of nodes in the simulated networks
Q <- 2       # number of blocks; used below in setModel(Q) and cls[[Q]]
M <- 1       # presumably the number of covariates (tests say "one covariate")
# Evaluate the shared test helpers in this file's environment. The tests below
# use, without defining them here: sampler_undirected_cov, sampler_directed_cov,
# covarList_undirected, covarList_directed and rmse() -- presumably all
# created by this script.
source("utils_test.R", local = TRUE)

test_that("SimpleSBM_fit 'Bernoulli' model, undirected, one covariate", {

  # Simulate an undirected network under the covariate model; `store = TRUE`
  # keeps it in the sampler so it can be read back through `$networkData`.
  sampler_undirected_cov$rNetwork(store = TRUE)

  ## Reference estimate: sbm (blockmodels backend), then select model Q
  ref_fit <- sbm::SimpleSBM_fit$new(
    sampler_undirected_cov$networkData,
    "bernoulli",
    FALSE,
    covarList = covarList_undirected
  )
  ref_fit$optimize(estimOptions = list(verbosity = 0, plot = FALSE))
  ref_fit$setModel(Q)

  ## Estimate under test: missSBM VEM, initialised from the Q-th candidate
  ## clustering among those computed for 1..2Q
  observed_net <- missSBM:::partlyObservedNetwork$new(
    sampler_undirected_cov$networkData,
    covariates = covarList_undirected
  )
  candidate_partitions <- observed_net$clustering(seq_len(2 * Q))
  miss_fit <- missSBM:::SimpleSBM_fit_withCov$new(
    observed_net,
    clusterInit = candidate_partitions[[Q]],
    covarList = covarList_undirected
  )
  miss_fit$doVEM()

  ## Agreement with the sbm/blockmodels estimator
  expect_lt(rmse(miss_fit$connectParam$mean, ref_fit$connectParam$mean), 0.05)
  expect_gt(ARI(miss_fit$memberships, ref_fit$memberships), 0.8)
  expect_lt(rmse(miss_fit$loglik, ref_fit$loglik), 0.01)

  ## Agreement with the parameters used for simulation
  expect_lt(
    rmse(miss_fit$connectParam$mean, sampler_undirected_cov$connectParam$mean),
    0.1
  )
  expect_lt(
    rmse(miss_fit$covarParam, sampler_undirected_cov$covarParam),
    0.1
  )
  expect_gt(
    ARI(miss_fit$memberships, sampler_undirected_cov$memberships),
    0.85
  )

})

test_that("SimpleSBM_fit 'Bernoulli' model, directed, one covariate", {

  # Simulate a directed network under the covariate model; `store = TRUE`
  # keeps it in the sampler so it can be read back through `$networkData`.
  sampler_directed_cov$rNetwork(store = TRUE)

  ## Reference estimate: sbm (blockmodels backend), then select model Q
  ref_fit <- sbm::SimpleSBM_fit$new(
    sampler_directed_cov$networkData,
    "bernoulli",
    TRUE,
    covarList = covarList_directed
  )
  ref_fit$optimize(estimOptions = list(verbosity = 0, plot = FALSE))
  ref_fit$setModel(Q)

  ## Estimate under test: missSBM VEM, initialised from the Q-th candidate
  ## clustering among those computed for 1..2Q
  observed_net <- missSBM:::partlyObservedNetwork$new(
    sampler_directed_cov$networkData,
    covariates = covarList_directed
  )
  candidate_partitions <- observed_net$clustering(seq_len(2 * Q))
  miss_fit <- missSBM:::SimpleSBM_fit_withCov$new(
    observed_net,
    clusterInit = candidate_partitions[[Q]],
    covarList = covarList_directed
  )
  # Called without `trace = TRUE` (a leftover debug flag), as in the
  # undirected test, so the test run stays quiet.
  miss_fit$doVEM()

  ## Agreement with the sbm/blockmodels estimator
  expect_lt(rmse(miss_fit$connectParam$mean, ref_fit$connectParam$mean), 0.1)
  expect_gt(ARI(miss_fit$memberships, ref_fit$memberships), 0.8)
  expect_lt(rmse(miss_fit$loglik, ref_fit$loglik), 0.05)

  ## Agreement with the parameters used for simulation
  # NOTE(review): unlike the undirected test, covarParam is not compared with
  # sampler_directed_cov$covarParam here -- TODO confirm whether intentional.
  expect_lt(
    rmse(miss_fit$connectParam$mean, sampler_directed_cov$connectParam$mean),
    0.1
  )
  expect_gt(
    ARI(miss_fit$memberships, sampler_directed_cov$memberships),
    0.85
  )

})
# jchiquet/missSBM documentation built on Oct. 25, 2023, 5:30 a.m.