To Do

This vignette tests the Pólya-gamma linear model `pg_lm()` on simulated data.

library(pgR)
# library(MCMCpack)
library(splines)
library(tidyverse)
library(patchwork)
library(BayesLogit)

Simulate some data

## Simulate multinomial counts (one trial per observation) whose logits are a
## B-spline function of a single uniform covariate X.
set.seed(404)
N <- 5000
J <- 2
X <- runif(N)
# number of B-spline basis functions; named n_basis rather than `df` so the
# script does not mask stats::df()
n_basis <- 4
Xbs <- bs(X, df = n_basis)
beta <- matrix(rnorm((J - 1) * n_basis), n_basis, J - 1)
## shift the first spline coefficient of each species down (by 4 for the first
## species, decreasing to 0 for the last) to reduce the stochastic ordering
## effect
beta[1, ] <- beta[1, ] - seq(from = 4, to = 0, length.out = J - 1)
eta <- Xbs %*% beta
## NOTE(review): `pi` masks base::pi for the rest of the script; later code
## refers to this name, so it is intentionally left as is.
pi <- eta_to_pi(eta)

# one multinomial draw per observation; kept as a row loop so the RNG stream
# (and hence the simulated data for this seed) is unchanged
Y <- matrix(0, N, J)
for (i in seq_len(N)) {
  Y[i, ] <- rmultinom(1, 1, pi[i, ])
}
Y_prop <- counts_to_proportions(Y)

Plot the simulated data

## Long-format observed proportions and true probabilities: c() stacks the
## matrix columns, so species varies slowest and X is recycled once per species.
dat <- data.frame(
  y        = c(Y_prop),
  x        = X,
  species  = factor(rep(seq_len(J), each = N))
)

dat_truth <- data.frame(
  y        = c(pi),
  x        = X,
  species  = factor(rep(seq_len(J), each = N))
)

# plot at most 500 observations per species, with the true curve in black
dat %>%
  group_by(species) %>%
  slice_sample(n = min(nrow(Y), 500)) %>%
  ungroup() %>%
  ggplot(aes(y = y, x = x, group = species, color = species)) +
  geom_point(alpha = 0.2) +
  ylab("Proportion of count") +
  geom_line(
    data = dat_truth, aes(y = y, x = x, group = species),
    color = "black", lwd = 0.5
  ) +
  facet_wrap(~ species, ncol = 4) +
  ggtitle("Simulated data")
## Fit the model to check parameter recovery, or reload a cached fit if one
## exists on disk.
params <- default_params()
params$n_adapt   <- 1000
params$n_mcmc    <- 1000
params$n_message <- 500
params$n_thin    <- 1
priors <- default_priors_pg_lm(Y, Xbs)
## NOTE(review): `inits` is computed but never passed to pg_lm(); confirm
## whether pg_lm() takes an `inits` argument and pass it if so.
inits <- default_inits_pg_lm(Y, Xbs, priors)

## NOTE(review): the cache is named "pg_logit.RData" although this fits
## pg_lm(), and it is loaded whenever it exists; confirm it cannot pick up a
## stale fit from a different model or simulation.
results_file <- here::here("results", "pg_logit.RData")
if (file.exists(results_file)) {
  load(results_file)
} else {
  start_time <- Sys.time()
  out <- pg_lm(
    Y, as.matrix(Xbs), params, priors,
    n_cores = 1L, sample_rmvn = FALSE
  )
  runtime <- Sys.time() - start_time

  # save() errors if the target directory does not exist yet
  dir.create(dirname(results_file), showWarnings = FALSE, recursive = TRUE)
  save(out, runtime, file = results_file)
}
## Trace plots of the posterior draws of beta: one line per coefficient, one
## panel per species. Assumes out$beta is (draws x coefficients x species),
## matching how it is indexed below.
betas <- out$beta
dimnames(betas) <- list(
  iteration = seq_len(dim(betas)[1]),
  parameter = seq_len(dim(betas)[2]),
  species   = seq_len(dim(betas)[3])
)
as.data.frame.table(betas, responseName = "value") %>%
  # as.data.frame.table() returns factors; convert via the labels rather than
  # the internal level codes
  mutate(iteration = as.numeric(as.character(iteration))) %>%
  ggplot(aes(x = iteration, y = value, group = parameter, color = parameter)) +
  geom_line() +
  facet_wrap(~ species)

## Posterior mean of each coefficient for each species, printed next to the
## true coefficient matrix. Averaging over draws only (margins 2 and 3) keeps
## the result the same shape as `beta`; averaging over margin 2 alone would
## also pool across species whenever J > 2.
apply(out$beta, c(2, 3), mean)
beta
# layout(matrix(1:9, 3, 3))
# for (i in 1:9) {
#   matplot(out$beta[, , i], type = 'l', main = paste("species", i))
#   abline(h = beta[, i], col = 1:nrow(beta))
# }

## plot beta estimates: posterior mean and 95% equal-tailed interval of each
## spline coefficient against its simulated value, one panel per species and
## one colour per basis coefficient (c() stacks columns, so the coefficient
## index varies fastest)
p_beta <- data.frame(
  beta_mean  = c(apply(out$beta, c(2, 3), mean)),
  beta_lower = c(apply(out$beta, c(2, 3), quantile, probs = 0.025)),
  beta_upper = c(apply(out$beta, c(2, 3), quantile, probs = 0.975)),
  beta_truth = c(beta),
  species    = factor(rep(seq_len(J - 1), each = ncol(Xbs))),
  knots      = factor(rep(seq_len(ncol(Xbs)), times = J - 1))
) %>%
  ggplot(aes(x = beta_truth, y = beta_mean, color = knots)) +
  scale_color_viridis_d(begin = 0, end = 0.8) +
  geom_errorbar(aes(ymin = beta_lower, ymax = beta_upper)) +
  geom_point() +
  facet_wrap(~ species, nrow = 3) +
  geom_abline(intercept = 0, slope = 1, col = "red")

## plot eta estimates: eta is indexed by observation, not by basis
## coefficient, so there is one row per observation and species.
## Assumes out$eta is (draws x N x (J - 1)) — TODO confirm against pg_lm().
p_eta <- data.frame(
  eta_mean  = c(apply(out$eta, c(2, 3), mean)),
  eta_lower = c(apply(out$eta, c(2, 3), quantile, probs = 0.025)),
  eta_upper = c(apply(out$eta, c(2, 3), quantile, probs = 0.975)),
  eta_truth = c(eta),
  species   = factor(rep(seq_len(J - 1), each = N))
) %>%
  ggplot(aes(x = eta_truth, y = eta_mean)) +
  geom_errorbar(aes(ymin = eta_lower, ymax = eta_upper)) +
  geom_point(alpha = 0.5) +
  facet_wrap(~ species, nrow = 3) +
  geom_abline(intercept = 0, slope = 1, col = "red")

p_beta / p_eta
## Posterior draws of the class probabilities: each draw of eta
## (N x (J - 1)) is mapped to pi (N x J) with eta_to_pi().
n_draws <- dim(out$eta)[1]
n_obs   <- dim(out$eta)[2]
n_eta   <- dim(out$eta)[3]
pi_post <- array(0, dim = c(n_draws, n_obs, n_eta + 1))
for (i in seq_len(n_draws)) {
  # out$eta[i, , ] drops to a plain vector when there is a single eta column
  # (J - 1 == 1), so rebuild the N x (J - 1) matrix explicitly
  pi_post[i, , ] <- eta_to_pi(matrix(out$eta[i, , ], n_obs, n_eta))
}

## posterior mean and 95% equal-tailed interval of pi for every observation and
## species, alongside the simulated truth
dat_pi <- data.frame(
  pi_mean     = c(apply(pi_post, c(2, 3), mean)),
  pi_lower    = c(apply(pi_post, c(2, 3), quantile, probs = 0.025)),
  pi_upper    = c(apply(pi_post, c(2, 3), quantile, probs = 0.975)),
  pi_truth    = c(pi),
  species     = factor(rep(seq_len(J), each = N)),
  observation = factor(rep(seq_len(N), times = J))
)

p_pi <- dat_pi %>%
  ggplot(aes(x = pi_truth, y = pi_mean)) +
  geom_errorbar(aes(ymin = pi_lower, ymax = pi_upper)) +
  geom_point() +
  facet_wrap(~ species, nrow = 3) +
  geom_abline(intercept = 0, slope = 1, col = "red")

p_pi
# Plot the fitted response curves: posterior mean of pi (black) and the true
# pi (blue) over a random subsample of the observed proportions.
# `dat` (observed proportions) and `dat_pi` (posterior summaries of pi) were
# built earlier in the script and are reused here rather than rebuilt, which
# avoids recomputing quantiles over every MCMC draw.
dat_fit <- data.frame(
  pi_mean  = dat_pi[["pi_mean"]],
  pi_lower = dat_pi[["pi_lower"]],
  pi_upper = dat_pi[["pi_upper"]],
  x        = X,
  species  = dat_pi[["species"]]
)

dat %>%
  group_by(species) %>%
  slice_sample(n = min(nrow(Y), 500)) %>%
  ungroup() %>%
  ggplot(aes(y = y, x = x, group = species, color = species)) +
  geom_point(alpha = 0.2) +
  ylab("Proportion of count") +
  geom_line(
    data = dat_fit, aes(y = pi_mean, x = x, group = species),
    color = "black", lwd = 0.5
  ) +
  geom_line(
    data = dat_truth, aes(y = y, x = x, group = species),
    color = "blue", lwd = 0.5
  ) +
  facet_wrap(~ species, ncol = 4) +
  ggtitle("Fitted (black) vs. true (blue) response curves")


jtipton25/pgR documentation built on July 8, 2022, 12:44 a.m.