library(pgR)
library(tidyverse)
library(microbenchmark)
## Simulate multinomial count data.
## The random draws happen in a fixed order (runif, rnorm, rmultinom), so the
## seed below always reproduces the same data set.
set.seed(11)
N  <- 5000   # number of observations (rows of X and Y)
d  <- 40     # number of categories (columns of Y)
df <- 500    # number of covariates, not counting the intercept

## design matrix: an intercept column followed by uniform(0, 1) covariates
X <- cbind(1, matrix(runif(N * df), nrow = N, ncol = df))

## regression coefficients: one column for each of the d - 1 linear predictors
beta <- matrix(rnorm(ncol(X) * (d - 1)), nrow = ncol(X), ncol = d - 1)

## mean vector, covariance matrix and precision matrix that the sample_beta*
## functions further down combine with the data
mu_beta        <- numeric(ncol(X))
Sigma_beta     <- diag(5, ncol(X))
Sigma_beta_inv <- solve(Sigma_beta)

## linear predictor, and the probabilities eta_to_pi() derives from it
eta <- X %*% beta
## NOTE(review): this assignment masks base R's constant `pi` for the rest of
## the script
pi  <- eta_to_pi(eta)

## one multinomial draw with 50 trials for every observation
Y <- matrix(0, nrow = N, ncol = d)
for (i in seq_len(N)) {
    Y[i, ] <- rmultinom(1, size = 50, prob = pi[i, ])
}

## Calculate Mi: for each row of Y, the row total minus the counts in all
## earlier categories, i.e.
##     Mi[i, j] = sum(Y[i, ]) - sum(Y[i, seq_len(j - 1)])
## so the first column is the plain row total. apply() returns the running
## sums with observations in columns, hence the transpose.
Mi <- rowSums(Y) - cbind(0, t(apply(Y[, 1:(d - 2)], 1, cumsum)))

## logical N x (d - 1) mask of the entries where Mi is non-zero
nonzero_idx <- Mi != 0

## initialize kappa: counts in the first d - 1 categories minus half of Mi
kappa <- Y[, 1:(d - 1)] - Mi / 2

## omega is drawn with pgdraw() only where Mi is non-zero; every other entry
## keeps its initial value of 0
omega <- matrix(0, nrow = N, ncol = d - 1)
omega[nonzero_idx] <- pgdraw(Mi[nonzero_idx], eta[nonzero_idx], cores = 1L)

## dense diagonal matrix built from each column of omega
Omega <- lapply(seq_len(d - 1), function(k) diag(omega[, k]))

## `j` picks the component used in the check below and in the timing code
## further down: the last one
j <- length(Omega)

## multiplying by the dense diagonal matrix and scaling the rows of X by the
## omega column must give the same cross-product
all.equal(
    t(X) %*% (Omega[[j]] %*% X),
    t(X) %*% (omega[, j] * X)
)

## Time the two ways of forming t(X) %*% Omega_j %*% X. The timings are cached
## on disk: if the cache file exists it is loaded (it is expected to contain
## `bm`), otherwise the benchmark is run and its result is saved.
diag_timings_file <- here::here("results", "timings-diagonal-multiplication.RData")
if (file.exists(diag_timings_file)) {
    load(diag_timings_file)
} else {
    bm <- microbenchmark::microbenchmark(
        t(X) %*% (Omega[[j]] %*% X),
        t(X) %*% (omega[, j] * X)
    )
    ## save() does not create missing directories; make sure the target
    ## directory exists so the timings are not lost after the benchmark ran
    dir.create(dirname(diag_timings_file), recursive = TRUE, showWarnings = FALSE)
    save(bm, file = diag_timings_file)
}
print(bm)
autoplot(bm)

## example sample functions

## Draw every column of beta from a multivariate normal, using the dense
## diagonal matrices in Omega.
##
## Takes no arguments: d, X, Omega, kappa, mu_beta, Sigma_beta_inv and beta are
## read from the enclosing (global) environment.
##
## For each j the draw has
##     covariance Sigma_tilde = (Sigma_beta_inv + t(X) Omega[[j]] X)^{-1}
##     mean       mu_tilde    = Sigma_tilde (Sigma_beta_inv mu_beta + t(X) kappa[, j])
##
## Returns (invisibly) the updated copy of beta. The assignment inside the loop
## only changes a function-local copy, so the global beta is left untouched and
## the draws would be lost if they were not returned.
sample_beta <- function() {
    for (j in seq_len(d - 1)) {
        ## invert the precision matrix through its Cholesky factor
        Sigma_tilde <- chol2inv(chol(Sigma_beta_inv + t(X) %*% (Omega[[j]] %*% X)))
        mu_tilde    <- c(Sigma_tilde %*% (Sigma_beta_inv %*% mu_beta + t(X) %*% kappa[, j]))
        beta[, j]   <- mvnfast::rmvn(1, mu_tilde, Sigma_tilde)
    }
    invisible(beta)
}

## Draw every column of beta from a multivariate normal, scaling the rows of X
## by omega[, j] instead of multiplying by a dense diagonal matrix.
##
## Takes no arguments: d, X, omega, kappa, mu_beta, Sigma_beta_inv and beta are
## read from the enclosing (global) environment.
##
## For each j the draw has
##     covariance Sigma_tilde = (Sigma_beta_inv + t(X) (omega[, j] * X))^{-1}
##     mean       mu_tilde    = Sigma_tilde (Sigma_beta_inv mu_beta + t(X) kappa[, j])
##
## Returns (invisibly) the updated copy of beta. The assignment inside the loop
## only changes a function-local copy, so the global beta is left untouched and
## the draws would be lost if they were not returned.
sample_beta_fast <- function() {
    for (j in seq_len(d - 1)) {
        ## invert the precision matrix through its Cholesky factor
        Sigma_tilde <- chol2inv(chol(Sigma_beta_inv + t(X) %*% (omega[, j] * X)))
        mu_tilde    <- c(Sigma_tilde %*% (Sigma_beta_inv %*% mu_beta + t(X) %*% kappa[, j]))
        beta[, j]   <- mvnfast::rmvn(1, mu_tilde, Sigma_tilde)
    }
    invisible(beta)
}

## Draw every column of beta with rmvn_arma(), using the dense diagonal
## matrices in Omega.
##
## Takes no arguments: d, X, Omega, kappa, mu_beta, Sigma_beta_inv and beta are
## read from the enclosing (global) environment.
##
## For each j, rmvn_arma() receives
##     A = Sigma_beta_inv + t(X) Omega[[j]] X
##     b = Sigma_beta_inv mu_beta + t(X) kappa[, j]
## NOTE(review): presumably rmvn_arma(A, b) draws from N(A^{-1} b, A^{-1});
## confirm against the pgR documentation.
##
## Returns (invisibly) the updated copy of beta. The assignment inside the loop
## only changes a function-local copy, so the global beta is left untouched and
## the draws would be lost if they were not returned.
sample_beta_AB <- function() {
    for (j in seq_len(d - 1)) {
        A <- Sigma_beta_inv + t(X) %*% (Omega[[j]] %*% X)
        b <- Sigma_beta_inv %*% mu_beta + t(X) %*% kappa[, j]
        beta[, j]   <- rmvn_arma(A, b)
    }
    invisible(beta)
}

## Draw every column of beta with rmvn_arma(), scaling the rows of X by
## omega[, j] instead of multiplying by a dense diagonal matrix.
##
## Takes no arguments: d, X, omega, kappa, mu_beta, Sigma_beta_inv and beta are
## read from the enclosing (global) environment.
##
## For each j, rmvn_arma() receives
##     A = Sigma_beta_inv + t(X) (omega[, j] * X)
##     b = Sigma_beta_inv mu_beta + t(X) kappa[, j]
## NOTE(review): presumably rmvn_arma(A, b) draws from N(A^{-1} b, A^{-1});
## confirm against the pgR documentation.
##
## Returns (invisibly) the updated copy of beta. The assignment inside the loop
## only changes a function-local copy, so the global beta is left untouched and
## the draws would be lost if they were not returned.
sample_beta_AB_fast <- function() {
    for (j in seq_len(d - 1)) {
        A <- Sigma_beta_inv + t(X) %*% (omega[, j] * X)
        b <- Sigma_beta_inv %*% mu_beta + t(X) %*% kappa[, j]
        beta[, j]   <- rmvn_arma(A, b)
    }
    invisible(beta)
}



## Time the four sample_beta* variants. The timings are cached on disk: if the
## cache file exists it is loaded (it is expected to contain `bm`), otherwise
## the benchmark is run and its result is saved.
beta_timings_file <- here::here("results", "timings-beta-full-conditionals.RData")
if (file.exists(beta_timings_file)) {
    load(beta_timings_file)
} else {
    bm <- microbenchmark::microbenchmark(
        sample_beta(),
        sample_beta_fast(),
        sample_beta_AB(),
        sample_beta_AB_fast()
    )
    ## save() does not create missing directories; make sure the target
    ## directory exists so the timings are not lost after the benchmark ran
    dir.create(dirname(beta_timings_file), recursive = TRUE, showWarnings = FALSE)
    save(bm, file = beta_timings_file)
}

# knitr::kable(print(bm))
print(bm)
autoplot(bm)


## jtipton25/pgR documentation built on July 8, 2022, 12:44 a.m.