# tests/testthat/obsolete_test_tracecalcs.R

library(bmdupdate)
source("helper.R")

context("Trace calculation")


# Simulation settings ----
ndata <- 300 # rows of X and Y (observations)
m <- 400     # columns of X (predictors)
nY <- 2      # columns of Y (responses)
rho <- 0.5   # off-diagonal entry of the predictor covariance matrix
Beta <- 0.2  # coefficient multiplying rowSums(X) in each response
s2 <- 1      # variance of the additive Gaussian noise in Y

# Draw `n` rows from a multivariate normal with mean vector `mu` and
# covariance `sigma`, using the upper Cholesky factor of `sigma`.
# Returns an n x ncol(sigma) matrix.
mvrnormR <- function(n, mu, sigma) {
  p <- ncol(sigma)
  # Repeating each mean n times lines mu[j] up with column j under R's
  # column-major recycling, so no explicit mean matrix is needed.
  mu_rep <- rep(mu, each = n)
  z <- matrix(rnorm(n * p), ncol = p)
  mu_rep + z %*% chol(sigma)
}

# Simulate `ndata` rows of `m` equicorrelated predictors and `nY` responses.
#
# X is drawn from N(0, Sig), where Sig has unit diagonal and every
# off-diagonal entry equal to `rho`. Each column of Y is
# `beta * rowSums(X)` plus independent N(0, s2) noise.
#
# Returns list(X = ndata x m matrix, Y = ndata x nY matrix).
generate_data <- function(m, nY, rho, ndata, beta, s2) {
  # Guard against passing a function (e.g. base::beta) or other junk by
  # mistake; the previous version silently ignored this argument.
  stopifnot(is.numeric(beta))
  # diag(1 - rho) + rho everywhere => unit diagonal, rho off-diagonal.
  Sig <- diag(1 - rho, m) + matrix(rho, nrow = m, ncol = m)
  X <- mvrnormR(ndata, rep(0, m), Sig)
  noise <- matrix(rnorm(ndata * nY, sd = sqrt(s2)), ncol = nY)
  # The length-ndata signal vector is recycled down every column of noise.
  Y <- beta * rowSums(X) + noise

  list(X = X, Y = Y)
}

# Standardize X and Y and cache the powers, transposes and cross-products
# consumed by the trace routines (traceXC/traceMC here, and presumably
# traceXR/traceMR from helper.R).
#
# Returns this function's evaluation environment, so every local below
# (including the scaled X and Y) is readable by name, e.g. env$XXt.
# Locals that look unused here may be read by the helper.R routines,
# so keep them.
precomputations <- function(X, Y) {
  X <- scale(X)
  Y <- scale(Y)
  ndata <- nrow(X)
  X2 <- X^2
  X3 <- X^3
  X4 <- X^4
  Y2 <- Y^2
  Y3 <- Y^3
  Y4 <- Y^4

  tX <- t(X)
  tY <- t(Y)
  tX2 <- t(X2)
  tX3 <- t(X3)
  XXt <- tcrossprod(X)
  XXt2 <- XXt^2

  Y4ColSums <- colSums(Y4)
  X4ColSums <- colSums(X4)
  X1RowSums <- rowSums(X)
  X2RowSums <- rowSums(X2)

  # scale() divides by the n - 1 standard deviation, so cross-products of
  # the standardized columns divided by n - 1 are sample correlations.
  M <- ndata - 1

  allr <- crossprod(Y, X) / M      # nY x m
  allrSums <- rowSums(allr)
  allr22 <- crossprod(Y2, X2) / M  # nY x m
  allr31 <- crossprod(Y3, X) / M   # nY x m

  environment()
}

# Trace computation for response column `i` via traceXCpp (presumably the
# package's compiled implementation), fed from a `precomputations()`
# environment `env`.
traceXC <- function(i, env) {
  traceXCpp(
    y = env$Y[, i],
    y2 = env$Y2[, i],
    y3 = env$Y3[, i],
    y4Sum = env$Y4ColSums[i],
    rSqSum = sum(env$allr[i, ]^2),
    r = env$allr[i, ],
    r2 = env$allr22[i, ],
    r3 = env$allr31[i, ],
    X = env$X,
    XXt = env$XXt,
    XXt2 = env$XXt2,
    X2 = env$X2,
    X3 = env$X3,
    tX = env$tX,
    X2RowSums = env$X2RowSums,
    X4ColSums = env$X4ColSums
  )
}


# Trace computation for response column `i` via traceMCpp (presumably the
# package's compiled implementation), fed from a `precomputations()`
# environment `env`.
traceMC <- function(i, env) {
  traceMCpp(
    y = env$Y[, i],
    y4Sum = env$Y4ColSums[i],
    X = env$X,
    tX2 = env$tX2,
    r = env$allr[i, ],
    r2 = env$allr22[i, ],
    r3 = env$allr31[i, ]
  )
}


set.seed(1234567)

# Pass the numeric `Beta` setting; the bare name `beta` is not defined in
# this script and would resolve to the base R function base::beta.
sim <- generate_data(m, nY, rho, ndata, Beta, s2)
# Named to avoid shadowing base::c().
pre_env <- precomputations(sim$X, sim$Y)

test_that("traceXC gives the same result as traceXR", {
  for (i in seq_len(nY)) {
    expect_equal(
      traceXR(i, pre_env), traceXC(i, pre_env),
      info = paste("response column", i)
    )
  }
})

test_that("traceMC gives the same result as traceMR", {
  for (i in seq_len(nY)) {
    expect_equal(
      traceMR(i, pre_env), traceMC(i, pre_env),
      info = paste("response column", i)
    )
  }
})
# miheerdew/bmdupdate documentation built on May 17, 2019, 1:35 p.m.