# R/fcts.R

require(Rcpp)
require(RcppArmadillo)
require(mvtnorm)
require(gtools)
require(ggplot2)
require(latex2exp)
require(RColorBrewer)
require(LaplacesDemon)
require(plyr)
require(dplyr) # needed for the pipe `%>%` and `tally` used below

`%notin%` <- Negate(`%in%`)


#' Descriptive plots 
#'
#' Plot the accuracy of the raw data.
#' 
#' @param data dataframe with the following columns:
#' * subject: vector of size n containing the participant labels
#' * block: vector of size n containing the training blocks (longitudinal units)
#' * s: vector of size n containing the stimuli
#' * d: vector of size n containing the decisions
#' * r_time: vector of size n containing the response times
#' * cens: vector of size n containing the censoring indicators (1 censored, 0 non censored)
#' @return A ggplot of the individual- and population-level raw accuracies
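#' @examples
#' \dontrun{
#' # Illustrative sketch on simulated data (hypothetical toy dataframe, not
#' # shipped with the package):
#' toy <- data.frame(subject = rep(1:5, each = 40),
#'                   block   = rep(rep(1:4, each = 10), 5),
#'                   s       = sample(1:2, 200, replace = TRUE),
#'                   d       = sample(1:2, 200, replace = TRUE),
#'                   r_time  = rexp(200) + 0.3,
#'                   cens    = 0)
#' plot_accuracy(toy)
#' }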
plot_accuracy <- function(data){
  
  block = d = s = cens = subject = tally = n = freq = NULL
  
  data_aggr <- data %>%
    dplyr::filter(cens == 0) %>%
    dplyr::group_by(block, s, d) %>% 
    tally %>% 
    dplyr::mutate(freq = n / sum(n), 
                  s = factor(s), 
                  d = factor(d)) %>% 
    dplyr::ungroup() %>% 
    dplyr::select(block, s, d, freq)
  
  data %>%
    dplyr::filter(cens == 0) %>%
    dplyr::group_by(subject, block, s, d) %>% 
    tally %>% 
    dplyr::mutate(freq = n / sum(n), 
                  subject = factor(subject), 
                  s = factor(s), 
                  d = factor(d)) %>% 
    dplyr::ungroup() %>% 
    ggplot() +
    geom_point(aes(x = block, y = freq, group = subject, col = d), alpha = 0.2) +
    geom_line(aes(x = block, y = freq, group = interaction(subject, d), col = d), alpha = 0.2) +
    geom_line(aes(x = block, y = freq, col = d), size = 1.5, data = data_aggr) +
    facet_wrap( ~ s, nrow = 2, ncol = 2) +
    labs(x = 'block', y = 'probability') +
    scale_color_brewer(name = "", palette = 'Set1') +
    theme(plot.title = element_text(hjust = 0.5)) +
    scale_x_continuous(breaks=seq(1, max(data$block), by = 1)) +
    theme(legend.position = "top")
}


#' Descriptive plots 
#'
#' Plot the mean response times of the raw data.
#' 
#' @param data dataframe with the following columns:
#' * subject: vector of size n containing the participant labels
#' * block: vector of size n containing the training blocks (longitudinal units)
#' * s: vector of size n containing the stimuli
#' * d: vector of size n containing the decisions
#' * r_time: vector of size n containing the response times
#' * cens: vector of size n containing the censoring indicators (1 censored, 0 non censored)
#' @return A ggplot of the population-level mean raw response times
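#' @examples
#' \dontrun{
#' # Illustrative sketch on a hypothetical toy dataframe in the format above:
#' toy <- data.frame(subject = rep(1:5, each = 40),
#'                   block   = rep(rep(1:4, each = 10), 5),
#'                   s       = sample(1:2, 200, replace = TRUE),
#'                   d       = sample(1:2, 200, replace = TRUE),
#'                   r_time  = rexp(200) + 0.3,
#'                   cens    = 0)
#' plot_RT(toy)
#' }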
plot_RT <- function(data){
  
  mean_r_time = block = d = s = cens = r_time = NULL
  
  data %>%
    dplyr::filter(cens == 0) %>%
    dplyr::group_by(block, s, d) %>%
    dplyr::summarise(mean_r_time = mean(r_time)) %>%
    dplyr::ungroup() %>% 
    ggplot() +
    geom_line(aes(x = block, y = mean_r_time, col = factor(d)), size = 1.5) +
    facet_wrap( ~ s) +
    labs(x = 'block', y = 'response time') +
    scale_color_brewer(name = "", palette = "Set1") +
    theme(plot.title = element_text(hjust = 0.5)) +
    scale_x_continuous(breaks=seq(1, max(data$block), by = 1)) +
    theme(legend.position = "top")
}




# The Half Cauchy Distribution
# 
# Probability density function for the half Cauchy distribution
# 
# @param x vector of points where to evaluate the pdf
# @param sigma vector of scale parameters
# @param log should the result be returned on the log scale?
# @return Probability density function for the half Cauchy distribution with location 0 and scale sigma
dhalfcauhy <- function(x, sigma, log = T){
  
  out <- log(2) + log(sigma) - log(pi) - log(sigma^2 + x^2)
  if (!log){
    out <- exp(out)
  }
  
  return(out)
}
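
# A quick sanity check (illustrative; wrapped so it does not run on source):
# on the positive half line, the half Cauchy density equals twice a Cauchy
# density with the same scale
if (FALSE) {
  stopifnot(all.equal(dhalfcauhy(1.3, sigma = 2, log = FALSE),
                      2 * dcauchy(1.3, scale = 2)))
}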


#' Spline Penalty Matrix
#' 
#' Construct the penalty matrix P of the smoothness inducing prior for the
#' spline coefficients (P enters the prior as a precision, up to a scale factor)
#' 
#' @param K Number of spline knots
#' @return Penalty matrix of the smoothness inducing prior (penalizing first
#' differences in the spline coefficients)
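#' @examples
#' \dontrun{
#' # Small illustration: the first-difference penalty for K = 4 coefficients
#' # has diagonal (1, 2, 2, 1) and -1 on the first off-diagonals
#' P_smooth1(4)
#' }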
P_smooth1 <- function(K){
  D <- diag(rep(1,K))
  D <- diff(D)
  P <- crossprod(D)
  
  return(P)
}


# Locally Informed Moves
# 
# Computes the g() function necessary for locally informed moves.
# 
# @param x input of the function
# @param log should the result be computed on a log scale?
# @return g(x)
g_HB <- function(x, log = T){
  # To turn off the locally informed version, set out to 0 instead
  out <- x/2
  
  if (log){
    return(out)
  }
  else{
    return(exp(out))
  }
}


#' Spline Basis Functions
#'
#' Construct the J basis functions for the splines evaluated on a grid.
#' 
#' @param xgrid grid where we want to evaluate the spline functions (vector of length n)
#' @param knots vector of knots for the splines (vector of length K)
#' @return An n x (K+1) matrix holding the value of each basis function evaluated on xgrid
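#' @examples
#' \dontrun{
#' # The quadratic basis functions form a partition of unity over the knot
#' # range, so the rows of B sum to one
#' B <- B_basis(seq(1, 10, by = 0.1), knots = 1:10)
#' range(rowSums(B))
#' }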
B_basis <- function(xgrid, knots){
  n <- length(xgrid) # number of grid points where to evaluate the spline
  K <- length(knots) # number of knots
  delta <- knots[2] - knots[1]
  
  B <- array(0, dim = c(n, K + 1))
  for (j in 1:(K-1)){
    act_idx <- (1:n)[(xgrid >= knots[j]) & (xgrid <= knots[j+1])]
    act_x <- xgrid[act_idx]
    resc_x <- (act_x - knots[j]) / (knots[j+1] - knots[j])
    
    B[act_idx,j] <- (1/2) * (1 - resc_x)^2
    B[act_idx,j+1] <- -(resc_x^2) + resc_x + 1/2
    B[act_idx,j+2] <- (resc_x^2)/2
  }
  
  return(B)
}


#' Hamming Ball
#'
#' Computes the Hamming Ball centered at x with radius r.
#' 
#' @param x center of the Hamming Ball
#' @param S number of states 
#' @param r radius of the Hamming Ball
#' @return Hamming Ball
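#' @examples
#' \dontrun{
#' # All configurations within Hamming distance 1 of the center (1, 2) when
#' # each of the two chains can take S = 3 states: the center plus the four
#' # configurations that differ in exactly one coordinate
#' H_ball(c(1, 2), S = 3, r = 1)
#' }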
H_ball <- function(x, S, r){
  K <- length(x) # number of chains
  card_m <- (S - 1)^(0:r) * choose(K, 0:r) 
  # sum(card_m) is the cardinality of each HB with radius r
  
  # First, put the current vector in its HB
  HB_temp <- matrix(x, K, 1)
  
  for (m in 1:r){ # loop the possible HB radius
    HB_diff <- matrix(rep(x, card_m[m+1]), K, card_m[m+1])
    index <- utils::combn(K, m) # what elements of the vector to change?
    
    for (j in 1:ncol(index)){
      vec <- NULL
      for (i in 1:nrow(index)){ 
        vec <- c(vec, (1:S)[-x[index[i,j]]])
      }
      prop_col <- t(gtools::permutations(n = length(unique(vec)), r = m, 
                                         v = vec, repeats.allowed = TRUE))
      keep_col <- prop_col[,which(colSums(prop_col != x[index[,j]]) == m)]
      HB_diff[index[,j], ((j-1)*(card_m[m+1]/ncol(index))+1):(j*card_m[m+1]/ncol(index))] <- keep_col
    }
    HB_temp <- cbind(HB_temp,HB_diff) # save these in the Hamming Ball
  }
  
  return (HB_temp)
}


# Hamming Ball Sampler
#
# Samples a configuration uniformly within the HB.
# 
# @param x center of the Hamming Ball
# @param S number of states 
# @param r radius of the Hamming Ball
# @return sampled state
H_ball_unif <- function(x, S, r){
  HB_temp <- H_ball(x, S, r)
  HB_samp <- HB_temp[,sample(1:ncol(HB_temp), 1)]
  
  return (HB_samp)
}


# Random effects hyperparameters update
#
# Metropolis-Hastings step to update the variance parameter and the smoothness
# parameter for the random effects.
#
# @param sigma2_ua_old variance parameter at the previous iteration
# @param sigma2_us_old smoothness parameter at the previous iteration
# @param beta_u_old random effects at the current iteration
# @param P_smooth matrix penalizing the first order differences of the coefficients
# @param n_ind number of participants (random effects)
# @return sampled variance and smoothness parameters
sample_smooth_var <- function(sigma2_ua_old, sigma2_us_old, 
                              beta_u_old, P_smooth, n_ind){
  
  J <- ncol(beta_u_old)
  # Update \sigma_{u,a}^{2}
  sigma2_ua_prop <- exp(rnorm(1, mean = log(sigma2_ua_old), sd = 0.2))
  lu <- log(runif(1))
  log_prop <- 0.5 * n_ind * log(det(diag(J)/sigma2_ua_prop + P_smooth/sigma2_us_old)) -
    0.5 * sum(diag(beta_u_old %*% tcrossprod(diag(J)/sigma2_ua_prop, beta_u_old))) -
    log(1 + sigma2_ua_prop^2)
  log_old <- 0.5 * n_ind * log(det(diag(J)/sigma2_ua_old + P_smooth/sigma2_us_old)) -
    0.5 * sum(diag(beta_u_old %*% tcrossprod(diag(J)/sigma2_ua_old, beta_u_old))) -
    log(1 + sigma2_ua_old^2)
  alpha <- min(c(0, log_prop + log(sigma2_ua_prop) -
                   log_old - log(sigma2_ua_old)))
  if (lu < alpha){
    sigma2_ua_old <- sigma2_ua_prop
  }
  
  # Update \sigma_{u,s}^{2}
  sigma2_us_prop <- exp(rnorm(1, mean = log(sigma2_us_old), sd = 0.1))
  while (sigma2_us_prop > 0.2){
    sigma2_us_prop <- exp(rnorm(1, mean = log(sigma2_us_old), sd = 0.1))
  }
  lu <- log(runif(1))
  log_prop <- 0.5 * n_ind * log(det(diag(J)/sigma2_ua_old + P_smooth/sigma2_us_prop)) -
    0.5 * sum(diag(beta_u_old %*% tcrossprod(P_smooth/sigma2_us_prop, beta_u_old))) -
    log(1 + sigma2_us_prop^2)
  log_old <- 0.5 * n_ind * log(det(diag(J)/sigma2_ua_old + P_smooth/sigma2_us_old)) -
    0.5 * sum(diag(beta_u_old %*% tcrossprod(P_smooth/sigma2_us_old, beta_u_old))) -
    log(1 + sigma2_us_old^2)
  alpha <- min(c(0, log_prop + log(sigma2_us_prop) -
                   log_old - log(sigma2_us_old)))
  if (lu < alpha){
    sigma2_us_old <- sigma2_us_prop
  }
  
  return(list('sigma2_us_old' = sigma2_us_old, 
              'sigma2_ua_old' = sigma2_ua_old))
  
}
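
# A minimal sketch of how this update is invoked (illustrative values only;
# wrapped so it does not run on source). Here `u` plays the role of the
# n_ind x J matrix of random effect spline coefficients
if (FALSE) {
  P <- P_smooth1(5)
  u <- matrix(rnorm(3 * 5, sd = 0.1), 3, 5)
  sample_smooth_var(0.01, 0.01, u, P, n_ind = 3)
}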


#' Drift Diffusion Model Fit
#'
#' Main function for the Gibbs sampler for the drift-diffusion model. Note that 
#' priors are noninformative and calibrated so that, for the most stable 
#' performance, the response times (variable `r_time` in the `data` dataframe) 
#' should lie between 0 and 10.
#'
#' @param data dataframe with the following columns:
#' * subject: vector of size n containing the participant labels
#' * block: vector of size n containing the training blocks (longitudinal units)
#' * s: vector of size n containing the stimuli
#' * d: vector of size n containing the decisions
#' * r_time: vector of size n containing the response times. To avoid numerical 
#'           issues, the unit of measurement should be chosen so that most 
#'           response times lie between 0 and 10
#' * cens: vector of size n containing the censoring indicators (1 censored, 0 non censored)
#' @param hypers hyperparameters of the MCMC: a list containing "s_sigma_mu" and "s_sigma_b", 
#'               the smoothness parameters for drifts and boundaries, respectively
#' @param boundaries whether to fit the unrestricted model (flexible), assume constant 
#'                   boundaries over time (constant) or fix the boundaries to the same level 
#'                   across predictors (fixed)
#' @param Niter total number of iterations
#' @param burnin burnin of the chain
#' @param thin thinning factor
#' @return List with the following MCMC posterior samples: 
#' * post_mean_delta: posterior samples for the population offset parameters
#' * post_mean_mu: posterior samples for the population drift parameters
#' * post_mean_b: posterior samples for the population boundary parameters
#' * post_ind_delta: posterior samples for the individual offset parameters
#' * post_ind_mu: posterior samples for the individual drift parameters
#' * post_ind_b: posterior samples for the individual boundary parameters
#' * sigma2_mu_us: posterior samples for the random effects drift smoothness parameters
#' * sigma2_mu_ua: posterior samples for the random effects drift variance parameters
#' * sigma2_b_us: posterior samples for the random effects boundary smoothness parameters
#' * sigma2_b_ua: posterior samples for the random effects boundary variance parameters
#' * sigma2_1_mu: posterior samples for the drift smoothness parameters
#' * sigma2_1_b: posterior samples for the boundary smoothness parameters
#' * pred_ans: predicted population-level categories
#' * pred_time: predicted population-level response times
#' * pred_ans_ind: predicted individual-level categories
#' * pred_time_ind: predicted individual-level response times
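#' @examples
#' \dontrun{
#' # Illustrative call (sketch only): `df` stands for a hypothetical dataframe 
#' # in the format documented above
#' hypers <- list(s_sigma_mu = 0.1, s_sigma_b = 0.1)
#' fit <- LDDMM(df, hypers, boundaries = 'flexible', 
#'              Niter = 5000, burnin = 2000, thin = 5)
#' }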
LDDMM <- function(data, hypers, boundaries = 'flexible', 
                  Niter = 5000, burnin = 2000, thin = 5){
  
  # Check the inputs
  if ( is.null(hypers[["s_sigma_mu"]]) || is.null(hypers[["s_sigma_b"]]) ){
    cat('The hyperparameters did not contain s_sigma_mu or s_sigma_b: running with their default values.\n\n')
    hypers$s_sigma_mu <- hypers$s_sigma_b <- 0.1
  }
  if ( ("subject" %notin% colnames(data)) || 
       ("block" %notin% colnames(data)) || 
       ("s" %notin% colnames(data)) || 
       ("d" %notin% colnames(data)) || 
       ("r_time" %notin% colnames(data)) || 
       ("cens" %notin% colnames(data)) ){
    stop('The data is not in the correct format!')
  }
  
  # Call one of the main functions
  if (boundaries == 'flexible'){
    fit <- LDDMM_full(data, hypers, Niter, burnin, thin)
  }
  else if (boundaries == 'constant') {
    fit <- LDDMM_const_bound(data, hypers, Niter, burnin, thin)
  }
  else if (boundaries == 'fixed') {
    fit <- LDDMM_fix_bound(data, hypers, Niter, burnin, thin)
  }
  else {
    stop("The argument boundaries can be only one of the following: flexible, constant, or fixed")
  }
  
  return (fit)
}


LDDMM_full <- function(data, hypers, Niter = 5000, burnin = 2000, thin = 5){
  
  # Place the spline knots (default: one knot at every block, from the 
  # beginning to the end of the study)
  T_min <- min(data$block)
  T_max <- max(data$block)
  knots <- T_min:T_max
  K <- length(knots)
  P_smooth <- P_smooth1(K + 1)
  
  # Grid on which to evaluate the curves resulting from the spline basis 
  # expansion (used for plotting)
  xgrid <- seq(T_min, T_max, by = 1)
  
  # Extract quantities of interest
  tau <- data$r_time
  ind <- data$subject
  time <- data$block
  cens <- data$cens
  D <- cbind(data$s, data$d)
  
  n_ind <- length(unique(ind))
  ind <- as.numeric(plyr::mapvalues(factor(ind),
                                    from = levels(factor(ind)),
                                    to = 1:n_ind))
  B <- B_basis(data$block, knots)
  Bgrid <- B_basis(xgrid, knots)
  
  samp_size <- (Niter - burnin)/thin # sample size
  p <- ncol(D) # number of covariates
  d_j <- rep(0, p) # Number of levels for each covariate
  for (j in 1:p){
    d_j[j] <- length(unique(D[!is.na(D[,j]),j]))
  }
  J <- ncol(B) # number of locations
  n <- nrow(B) # total number of observations
  n_ind <- length(unique(ind)) # number of individuals
  
  
  # Rescale the time steps to {1, ..., T_max}
  T_max <- max(time)
  T_min <- min(time)
  time <- time - T_min + 1
  T_max <- max(time)
  T_min <- min(time)
  idx_xy <- t(apply(expand.grid(y = 1:d_j[2], x = 1:d_j[1]), 1, rev))
  colnames(idx_xy) <- NULL
  
  
  # Set HB structures
  r_HB <- 1
  Z_max <- min(prod(d_j), 6)
  dim_HB <- sum((Z_max - 1)^(0:r_HB) * choose(d_j[2], 0:r_HB))
  
  # Set MCMC objects
  z <- array(NA, dim = c(prod(d_j), J, samp_size))
  post_mean_delta <- array(NA, dim = c(samp_size, d_j[1]))
  post_mean_mu <- array(NA, dim = c(nrow(Bgrid), d_j, samp_size))
  post_mean_b <- array(NA, dim = c(nrow(Bgrid), d_j, samp_size))
  post_ind_delta <- array(NA, dim = c(d_j[1], n_ind, samp_size))
  post_ind_mu <- array(NA, dim = c(nrow(Bgrid), n_ind, d_j, samp_size))
  post_ind_b <- array(NA, dim = c(nrow(Bgrid), n_ind, d_j, samp_size))
  sigma2_mu_us <- array(NA, dim = samp_size)
  sigma2_mu_ua <- array(NA, dim = samp_size)
  sigma2_b_us <- array(NA, dim = samp_size)
  sigma2_b_ua <- array(NA, dim = samp_size)
  sigma2_1_mu <- array(NA, dim = samp_size)
  sigma2_1_b <- array(NA, dim = samp_size)
  pred_time <- array(NA, dim = c(T_max, d_j[1], samp_size))
  pred_ans <- array(NA, dim = c(T_max, d_j[1], samp_size))
  pred_time_ind <- array(NA, dim = c(n_ind, T_max, d_j[1], samp_size))
  pred_ans_ind <- array(NA, dim = c(n_ind, T_max, d_j[1], samp_size))
  loglik_chain <- array(NA, dim = c(samp_size, n))
  
  # Set initial values
  delta_old <- array(NA, dim = c(d_j[1], n_ind))
  beta_mu_old <- array(NA, dim = c(J, d_j))
  beta_b_old <- array(NA, dim = c(J, d_j))
  delta_dat <- array(NA, dim = n)
  for (s_temp in 1:d_j[1]){
    for (i_temp in 1:n_ind){
      idx_temp <- which((D[,1] == s_temp) & (ind == i_temp))
      if (length(idx_temp) > 0){
        delta_old[s_temp,i_temp] <- min(tau[idx_temp])/2
        delta_dat[idx_temp] <- delta_old[s_temp,i_temp]
      }
      else{
        delta_old[s_temp,i_temp] <- 1E-3
      }
    }
    for(j in 1:d_j[2]){
      beta_mu_old[,s_temp,j] <- rep(0.5*log(mean(tau[D[,1] == s_temp])) - log(sd(tau[D[,1] == s_temp])), J)
      beta_b_old[,s_temp,j] <- rep(1.5*log(mean(tau[D[,1] == s_temp])) - log(sd(tau[D[,1] == s_temp])), J)
    }
  }
  
  
  low_bound_mu <- min(beta_mu_old) - 1.5
  upp_bound_mu <- max(beta_mu_old) + 1
  low_bound_b <- min(beta_b_old) - 1.5
  upp_bound_b <- max(beta_b_old) + 1
  
  beta_mu_star_prop <- array(NA, dim = c(J, Z_max))
  beta_mu_u_old <- array(0, dim = c(n_ind, J))
  beta_b_star_prop <- array(NA, dim = c(J, Z_max))
  beta_b_u_old <- array(0, dim = c(n_ind, J))
  sigma2_1_mu_old <- 0.005
  sigma2_1_b_old <- 0.005
  sigma2_b_us_old <- 0.005
  sigma2_mu_us_old <- 0.005
  sigma2_b_ua_old <- 0.005
  sigma2_mu_ua_old <- 0.005
  
  
  # Message passing structures
  beta_mess <- array(NA, dim = c(J, dim_HB))
  z_old <- list()
  
  for (j in 1:d_j[1]){
    z_old[[j]] <- array(NA, dim = c(d_j[2], J))
    for (jj in 1:d_j[2]){
      if (j == jj){
        z_old[[j]][jj,] <- j
      }
      else{
        z_old[[j]][jj,] <- sample((d_j[1] + 1):Z_max, 1)
      }
    }
  }
  z_temp <- do.call(rbind, z_old)
  
  
  beta_mu_star_old <- array(NA, dim = c(J, Z_max))
  beta_b_star_old <- array(NA, dim = c(J, Z_max))
  for (i in 1:Z_max){
    beta_mu_star_old[,i] <- beta_mu_old[1,idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],1],
                                        idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],2]] 
    beta_b_star_old[,i] <- beta_b_old[1,idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],1],
                                      idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],2]] 
  }
  
  # Trick to get the linear indices of the diagonal entries (s == d, i.e. the
  # correct responses): element (i,j) is compared against rand_mat[i,i]
  rand_mat <- array(rnorm(prod(d_j)), dim = d_j)
  idx_succ <- which(rand_mat == diag(rand_mat))
  
  v_old <- array(NA, dim = c(d_j[2], J))
  z_prop <- list()
  
  # Transition dynamics objects
  alpha_S_old <- alpha_F_old <- 1
  Q_S_old <- array(NA, dim = c(Z_max, Z_max))
  Q_F_old <- array(NA, dim = c(Z_max, Z_max))
  pi_S_0 <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + 
                         table_int(z_temp[idx_succ,1], Z_max))
  pi_F_0 <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + 
                         table_int(z_temp[-idx_succ,1], Z_max))
  tr_count_S <- count_assign(z_temp[idx_succ,], Z_max)
  tr_count_F <- count_assign(z_temp[-idx_succ,], Z_max)
  for (h in 1:Z_max){
    Q_S_old[h,] <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + tr_count_S[h,])
    Q_F_old[h,] <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + tr_count_F[h,])
  }
  
  
  # Auxiliary variables
  prob_mat <- array(NA, dim = c(dim_HB, dim_HB))
  prob_vec <- array(NA, dim = dim_HB)
  
  # MH proposal parameters
  sd_MH_delta <- array(0.3, dim = c(d_j[1], n_ind))
  sd_MH_beta_mu <- 0.4
  sd_MH_beta_b <- array(0.25, dim = c(J, Z_max))
  sd_beta_mu_u <-  array(0.4, dim = c(n_ind, J))
  sd_beta_b_u <-  array(0.4, dim = c(n_ind, J))
  acc_delta <- array(0, dim = c(d_j[1], n_ind))
  acc_beta_mu <- array(0, dim = c(J, Z_max))
  acc_beta_b <- array(0, dim = c(J, Z_max))
  acc_beta_mu_u <- array(0, dim = c(n_ind, J))
  acc_beta_b_u <- array(0, dim = c(n_ind, J))
  n_batch <- 0
  
  # Auxiliary variables
  B_beta_mu_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_mu_u_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_b_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_b_u_dat <- array(0, dim = c(n, d_j[1]))
  mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
  b_dat <- exp(B_beta_b_dat + B_beta_b_u_dat)
  
  
  # Gibbs Sampler
  it <- 1
  pb <- txtProgressBar(style = 3)
  for (iter in 1:Niter){
    
    # (0) Adaptively tune the MH proposal sds for \delta_{s,i}, beta_u_mu, 
    #     beta_u_b, targeting the usual 0.44 acceptance rate of univariate MH
    if (iter %% 20 == 0){
      n_batch <- n_batch + 1
      delta_n <- min(0.01, n_batch^(-0.5))
      
      for (i in 1:n_ind){
        for (x_temp in 1:d_j[1]){
          if (acc_delta[x_temp,i]/iter > 0.44){
            sd_MH_delta[x_temp,i] <- exp(log(sd_MH_delta[x_temp,i]) + delta_n)
          }
          else{
            sd_MH_delta[x_temp,i] <- exp(log(sd_MH_delta[x_temp,i]) - delta_n)
          }
        }
        for (k in 1:J){
          if (acc_beta_mu_u[i,k]/iter > 0.44){
            sd_beta_mu_u[i,k] <- exp(log(sd_beta_mu_u[i,k]) + delta_n)
          }
          else{
            sd_beta_mu_u[i,k] <- exp(log(sd_beta_mu_u[i,k]) - delta_n)
          }
          if (acc_beta_b_u[i,k]/iter > 0.44){
            sd_beta_b_u[i,k] <- exp(log(sd_beta_b_u[i,k]) + delta_n)
          }
          else{
            sd_beta_b_u[i,k] <- exp(log(sd_beta_b_u[i,k]) - delta_n)
          }
        }
      }
    }
    
    
    # (1) Update the offset parameters \delta_{s,i}: MH with log-normal 
    #     proposal
    for (s_temp in 1:d_j[1]){
      for (i_temp in 1:n_ind){
        idx_temp <- which((D[,1] == s_temp) & (ind == i_temp))
        tau_temp <- tau[idx_temp]
        cens_temp <- cens[idx_temp]
        D_temp <- D[idx_temp,]
        
        # log-normal proposal distribution centered on the current value
        delta_prop <- exp(rnorm(1, log(delta_old[s_temp,i_temp]), sd_MH_delta[s_temp,i_temp]))
        while (delta_prop > min(tau[idx_temp])){
          delta_prop <- exp(rnorm(1, log(delta_old[s_temp,i_temp]), sd_MH_delta[s_temp,i_temp]))
        }
        loglik_prop <- log_likelihood(tau_temp, mu_dat[idx_temp,], 
                                      b_dat[idx_temp,], 
                                      rep(delta_prop, length(idx_temp)), 
                                      cens_temp, D_temp, TRUE)
        loglik_old <- log_likelihood(tau_temp, mu_dat[idx_temp,], 
                                     b_dat[idx_temp,], 
                                     rep(delta_old[s_temp,i_temp], length(idx_temp)), 
                                     cens_temp, D_temp, TRUE)
        
        alpha_acc <- min(0, loglik_prop + log(delta_prop) -
                           loglik_old - log(delta_old[s_temp,i_temp]))
        l_u <- log(runif(1))
        
        if (l_u < alpha_acc){
          delta_old[s_temp,i_temp] <- delta_prop
          delta_dat[idx_temp] <- delta_old[s_temp,i_temp]
          acc_delta[s_temp,i_temp] <- acc_delta[s_temp,i_temp] + 1
        }
        
      }
    }
    
    
    # (2) Joint update of the b and mu parameters: b_{x,y}^{(i)}(t), \mu_{x,y}^{(i)}(t)
    for (k in 1:J){ # loop over locations
      if (k == 1){ # only data at t = 1 influence the first coefficient
        idx_time <- T_min
      }
      else if (k == J){ # only data at t = T influence the last coefficient
        idx_time <- T_max
      }
      else { # data at t = {k-1,k} influence the kth coefficient
        idx_time <- (k-1):k
      }
      
      for (h in 1:Z_max){ # loop over possible latent values
        # tuples (combinations of covariates) that are clustered together via 
        # the latent h
        idx_cov <- matrix(idx_xy[which(z_temp[,k] == h),], length(which(z_temp[,k] == h)), p)
        
        X_1k <- unique(idx_cov[,1]) # all possible values of x in this cluster
        X_2k <- unique(idx_cov[,2]) # all possible values of y in this cluster
        
        if (length(X_1k) > 0){ # h \in \mathcal{Z}_{j,k}: posterior update
          # Pick data with covariate levels of x clustered in group h and 
          # at the correct locations
          idx_i <- which( (D[,1] %in% X_1k) & (time %in% idx_time) )
          tau_temp <- tau[idx_i]
          cens_temp <- cens[idx_i]
          D_temp <- D[which((D[,1] %in% X_1k) & (time %in% idx_time)),]
          
          
          # Normal proposal distribution centered on the current value
          if (k == 1){
            beta_b_star_prop[,h] <- beta_b_star_old[,h]
            beta_b_star_prop[k,h] <- rnorm(1, beta_b_star_old[k,h], sd_MH_beta_b[k,h])
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_mu_star_old[k,h], sd_MH_beta_mu)
          }
          # Normal proposal from the prior
          else if (k == J){
            beta_pre <- beta_b_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            beta_pre1 <- beta_mu_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            
            beta_b_star_prop[,h] <- beta_b_star_old[,h]
            beta_b_star_prop[k,h] <- rnorm(1, beta_pre, sqrt(sigma2_1_b_old))
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_pre1, sqrt(sigma2_1_mu_old))
          }
          # Normal proposal from the prior
          else {
            beta_pre <- beta_b_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            beta_pre1 <- beta_mu_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            
            beta_b_star_prop[,h] <- beta_b_star_old[,h]
            beta_b_star_prop[k,h] <- rnorm(1, beta_pre, sqrt(sigma2_1_b_old))
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_pre1, sqrt(sigma2_1_mu_old))
          }
          B_beta_b_prop_dat <- B_beta_b_dat[idx_i,]
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
          
          # Modify the proposed values in the corresponding positions
          for (hh in 1:nrow(idx_cov)){
            B_beta_b_prop_dat[which(D_temp[,1] == idx_cov[hh,1]),idx_cov[hh,2]] <- 
              B[idx_i[which(D_temp[,1] == idx_cov[hh,1])],] %*% beta_b_star_prop[,h]
            B_beta_mu_prop_dat[which(D_temp[,1] == idx_cov[hh,1]),idx_cov[hh,2]] <- 
              B[idx_i[which(D_temp[,1] == idx_cov[hh,1])],] %*% beta_mu_star_prop[,h]
          }
          
          # This is the proposed value for \b_{x,y}^{(i)}(t), \mu_{x,y}^{(i)}(t)
          b_prop_dat <- exp(B_beta_b_prop_dat + B_beta_b_u_dat[idx_i,])
          mu_prop_dat <- exp(B_beta_mu_prop_dat + B_beta_mu_u_dat[idx_i,])
          
          
          if (k == 1){
            beta_post <- beta_b_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            beta_post1 <- beta_mu_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_prop_dat, delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_b_old * sum((beta_b_star_prop[k,h] - beta_post)^2) - 
              0.5/sigma2_1_mu_old * sum((beta_mu_star_prop[k,h] - beta_post1)^2)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_b_old * sum((beta_b_star_old[k,h] - beta_post)^2) - 
              0.5/sigma2_1_mu_old * sum((beta_mu_star_old[k,h] - beta_post1)^2)
          }
          else if (k == J){
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_prop_dat, delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE)
          }
          else {
            beta_post <- beta_b_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            beta_post1 <- beta_mu_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_prop_dat, delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_b_old * sum((beta_b_star_prop[k,h] - beta_post)^2) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_prop[k,h] - beta_post1)^2)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_b_old * sum((beta_b_star_old[k,h] - beta_post)^2) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_old[k,h] - beta_post1)^2)
          }
          
          alpha_acc <- min(0, logpost_prop - logpost_old)
          l_u <- log(runif(1))
          
          if (l_u < alpha_acc){
            beta_b_star_old[k,h] <- beta_b_star_prop[k,h]
            B_beta_b_dat[idx_i,] <- B_beta_b_prop_dat
            b_dat[idx_i,] <- b_prop_dat
            acc_beta_b[k,h] <- acc_beta_b[k,h] + 1
            
            beta_mu_star_old[k,h] <- beta_mu_star_prop[k,h]
            B_beta_mu_dat[idx_i,] <- B_beta_mu_prop_dat
            mu_dat[idx_i,] <- mu_prop_dat
            acc_beta_mu[k,h] <- acc_beta_mu[k,h] + 1
          }
        }
        else { # h \notin \mathcal{Z}_{1,k}: prior sampling
          beta_b_star_old[k,h] <- runif(1, low_bound_b, upp_bound_b)
          beta_mu_star_old[k,h] <- runif(1, low_bound_mu, upp_bound_mu)
        }
      }
    }
    
    
    # (3) Update the cluster assignments
    for (x_temp in 1:d_j[1]){ # loop over the levels of the first covariate
      beta_mess <- array(-Inf, dim = c(J, dim_HB))
      beta_mess[J,] <- 1/dim_HB
      
      v_old[,J] <- H_ball_unif(z_old[[x_temp]][,J], S = Z_max, r = r_HB)
      z_prop[[J]] <- H_ball(v_old[,J], S = Z_max, r = r_HB)
      for (k in (J - 1):1){
        idx_i <- which( (time == k) & (D[,1] == x_temp) )
        tau_temp <- tau[idx_i]
        cens_temp <- cens[idx_i]
        D_temp <- D[idx_i,]
        
        # (i) Sample the auxiliary variables
        v_temp <- H_ball(z_old[[x_temp]][,k], S = Z_max, r = r_HB)
        
        probs <- rep(-Inf, dim_HB)
        for (h in 1:dim_HB){
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
          B_beta_b_prop_dat <- B_beta_b_dat[idx_i,]
          
          B_beta_mu_prop_dat <- 0.5 * (beta_mu_star_old[k,v_temp[,h]] +
                                         beta_mu_star_old[k+1,z_old[[x_temp]][,k+1]])
          B_beta_b_prop_dat <- 0.5 * (beta_b_star_old[k,v_temp[,h]] +
                                        beta_b_star_old[k+1,z_old[[x_temp]][,k+1]])
          
          mu_dat_prop <- exp(t(B_beta_mu_prop_dat + t(B_beta_mu_u_dat[idx_i,])))
          b_dat_prop <- exp(t(B_beta_b_prop_dat + t(B_beta_b_u_dat[idx_i,])))
          
          probs[h] <- g_HB(log_likelihood(tau_temp, mu_dat_prop, b_dat_prop,
                                          delta_dat[idx_i], cens_temp, D_temp, TRUE))
        }
        probs <- as.numeric(normalise_log(probs))
        
        v_old[,k] <- v_temp[,sample(1:dim_HB, 1, prob = probs)]
        z_prop[[k]] <- H_ball(v_old[,k], S = Z_max, r = r_HB)
        
        
        # (ii) Pass messages backwards only in the restricted state space given
        #      by the slice
        z_kp1_temp <- which(beta_mess[k+1,] > 0)
        prob_mat <- array(-Inf, dim = c(dim_HB, dim_HB))
        
        for (h1 in z_kp1_temp){
          for (h2 in 1:dim_HB){
            
            B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
            B_beta_b_prop_dat <- B_beta_b_dat[idx_i,]
            B_beta_mu_prop_dat_1 <- 0.5 * (beta_mu_star_old[k,z_prop[[k]][,h2]] +
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])
            B_beta_b_prop_dat_1 <- 0.5 * (beta_b_star_old[k,z_prop[[k]][,h2]] +
                                            beta_b_star_old[k+1,z_prop[[k+1]][,h1]])
            B_beta_mu_prop_dat_2 <- 0.5 * (beta_mu_star_old[k,v_old[,k]] +
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])
            B_beta_b_prop_dat_2 <- 0.5 * (beta_b_star_old[k,v_old[,k]] +
                                            beta_b_star_old[k+1,z_prop[[k+1]][,h1]])
            
            mu_dat_prop_1 <- exp(t(B_beta_mu_prop_dat_1 + t(B_beta_mu_u_dat[idx_i,])))
            b_dat_prop_1 <- exp(t(B_beta_b_prop_dat_1 + t(B_beta_b_u_dat[idx_i,])))
            mu_dat_prop_2 <- exp(t(B_beta_mu_prop_dat_2 + t(B_beta_mu_u_dat[idx_i,])))
            b_dat_prop_2 <- exp(t(B_beta_b_prop_dat_2 + t(B_beta_b_u_dat[idx_i,])))
            
            prob_mat[h2,h1] <- log(beta_mess[k+1,h1]) -
              0.5 / sigma2_1_mu_old * sum((beta_mu_star_old[k,z_prop[[k]][,h2]] -
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])^2) -
              0.5 / sigma2_1_b_old * sum((beta_b_star_old[k,z_prop[[k]][,h2]] -
                                            beta_b_star_old[k+1,z_prop[[k+1]][,h1]])^2) +
              log_likelihood(tau_temp, mu_dat_prop_1, b_dat_prop_1,
                             delta_dat[idx_i], cens_temp, D_temp, TRUE) +
              g_HB(log_likelihood(tau_temp, mu_dat_prop_2, b_dat_prop_2,
                                  delta_dat[idx_i], cens_temp, D_temp, TRUE)) +
              sum(log(Q_F_old[cbind(z_prop[[k]][-x_temp,h2],z_prop[[k+1]][-x_temp,h1])])) +
              log(Q_S_old[z_prop[[k]][x_temp,h2],z_prop[[k+1]][x_temp,h1]])
          }
        }
        if ( sum(is.infinite(sum_rows_log(prob_mat))) == dim_HB){
          beta_mess[k,] <- 1/dim_HB
        }
        else{
          beta_mess[k,] <- as.numeric(sum_rows_log(prob_mat))
          beta_mess[k,] <- as.numeric(normalise_log(beta_mess[k,]))
        }
      }
      
      
      # (iii) Sample states forward (only on allowed states)
      idx_fail <- (1:d_j[2])[-x_temp]
      # Sample z_1
      prob_vec <- log(beta_mess[1,]) + log(pi_S_0[z_prop[[1]][x_temp,]]) +
        colSums(matrix(log(pi_F_0[z_prop[[1]][-x_temp,]]), d_j[2] - 1, dim_HB))
      prob_vec <- as.numeric(normalise_log(prob_vec))
      
      idx_samp <- sample(1:dim_HB, 1, FALSE, prob_vec)
      z_old[[x_temp]][,1] <- z_prop[[1]][,idx_samp]
      
      # Sample z_k
      for (k in 2:J){
        idx_km1 <- which( (time == k - 1) & (D[,1] == x_temp) )
        tau_temp <- tau[idx_km1]
        cens_temp <- cens[idx_km1]
        D_temp <- D[idx_km1,]
        
        prob_vec <- log(beta_mess[k,]) + 
          log(Q_S_old[cbind(z_old[[x_temp]][x_temp,k-1], z_prop[[k]][x_temp,])])
        for (kkk in idx_fail){
          prob_vec <- prob_vec + log(Q_F_old[cbind(z_old[[x_temp]][kkk,k-1], z_prop[[k]][kkk,])])
        }
        
        for (z_k_temp in which(is.finite(prob_vec))){
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_km1,]
          B_beta_b_prop_dat <- B_beta_b_dat[idx_km1,]
          
          B_beta_mu_prop_dat_1 <- 0.5 * (beta_mu_star_old[k-1,z_old[[x_temp]][,k-1]] +
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])
          B_beta_b_prop_dat_1 <- 0.5 * (beta_b_star_old[k-1,z_old[[x_temp]][,k-1]] +
                                          beta_b_star_old[k,z_prop[[k]][,z_k_temp]])
          B_beta_mu_prop_dat_2 <- 0.5 * (beta_mu_star_old[k-1,v_old[,k-1]] +
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])
          B_beta_b_prop_dat_2 <- 0.5 * (beta_b_star_old[k-1,v_old[,k-1]] +
                                          beta_b_star_old[k,z_prop[[k]][,z_k_temp]])
          
          mu_dat_prop_1 <- exp(t(B_beta_mu_prop_dat_1 + t(B_beta_mu_u_dat[idx_km1,])))
          b_dat_prop_1 <- exp(t(B_beta_b_prop_dat_1 + t(B_beta_b_u_dat[idx_km1,])))
          mu_dat_prop_2 <- exp(t(B_beta_mu_prop_dat_2 + t(B_beta_mu_u_dat[idx_km1,])))
          b_dat_prop_2 <- exp(t(B_beta_b_prop_dat_2 + t(B_beta_b_u_dat[idx_km1,])))
          
          prob_vec[z_k_temp] <- prob_vec[z_k_temp] -
            0.5 / sigma2_1_mu_old * sum((beta_mu_star_old[k-1,z_old[[x_temp]][,k-1]] -
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])^2) -
            0.5 / sigma2_1_b_old * sum((beta_b_star_old[k-1,z_old[[x_temp]][,k-1]] -
                                          beta_b_star_old[k,z_prop[[k]][,z_k_temp]])^2) +
            log_likelihood(tau_temp, mu_dat_prop_1, b_dat_prop_1,
                           delta_dat[idx_km1], cens_temp, D_temp, TRUE) +
            g_HB(log_likelihood(tau_temp, mu_dat_prop_2, b_dat_prop_2,
                                delta_dat[idx_km1], cens_temp, D_temp, TRUE))
        }
        prob_vec <- as.numeric(normalise_log(prob_vec))
        
        idx_samp <- sample(1:dim_HB, 1, FALSE, prob_vec)
        z_old[[x_temp]][,k] <- z_prop[[k]][,idx_samp]
      }
      
      
      # (4) Assign the cluster specific curves f_{\mu}
      for (y_temp in 1:d_j[2]){
        beta_mu_old[,x_temp,y_temp] <- beta_mu_star_old[cbind(1:J, z_old[[x_temp]][y_temp,])]
        beta_b_old[,x_temp,y_temp] <- beta_b_star_old[cbind(1:J, z_old[[x_temp]][y_temp,])]
      }
      B_beta_mu_dat[(D[,1] == x_temp),] <- B[D[,1] == x_temp,] %*% beta_mu_old[,x_temp,]
      B_beta_b_dat[(D[,1] == x_temp),] <- B[D[,1] == x_temp,] %*% beta_b_old[,x_temp,]
    }
    z_temp <- do.call(rbind, z_old)
    mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
    b_dat <- exp(B_beta_b_dat + B_beta_b_u_dat)
    
    
    
    # (5) Update the transition probabilities
    pi_S_0 <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + table_int(z_temp[idx_succ,1], Z_max))
    pi_F_0 <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + table_int(z_temp[-idx_succ,1], Z_max))
    tr_count_S <- count_assign(z_temp[idx_succ,], Z_max)
    tr_count_F <- count_assign(z_temp[-idx_succ,], Z_max)
    for (h in 1:Z_max){
      Q_S_old[h,] <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + 
                                  tr_count_S[h,])
      Q_F_old[h,] <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + 
                                  tr_count_F[h,])
    }
    
    
    alpha_S_prop <- exp(rnorm(1, log(alpha_S_old), 0.1))
    while ((alpha_S_prop < 0.01) | (alpha_S_prop > 10)){
      alpha_S_prop <- exp(rnorm(1, log(alpha_S_old), 0.1))
    }
    alpha_acc <- dgamma(alpha_S_prop, 1, 1, log = T) +
      Z_max * lgamma(alpha_S_prop) -
      Z_max^2 * lgamma(alpha_S_prop/Z_max) +
      (alpha_S_prop/Z_max - 1) * sum(log(Q_S_old)) - (
        dgamma(alpha_S_old, 1, 1, log = T) +
          Z_max * lgamma(alpha_S_old) -
          Z_max^2 * lgamma(alpha_S_old/Z_max) +
          (alpha_S_old/Z_max - 1) * sum(log(Q_S_old)))
    
    l_u <- log(runif(1))
    if (!is.na(alpha_acc)){
      if (l_u < alpha_acc){
        alpha_S_old <- alpha_S_prop
      }
    }
    
    alpha_F_prop <- exp(rnorm(1, log(alpha_F_old), 0.1))
    while ((alpha_F_prop < 0.01) | (alpha_F_prop > 10)){
      alpha_F_prop <- exp(rnorm(1, log(alpha_F_old), 0.1))
    }
    alpha_acc <- dgamma(alpha_F_prop, 1, 1, log = T) +
      Z_max * lgamma(alpha_F_prop) -
      Z_max^2 * lgamma(alpha_F_prop/Z_max) +
      (alpha_F_prop/Z_max - 1) * sum(log(Q_F_old)) - (
        dgamma(alpha_F_old, 1, 1, log = T) +
          Z_max * lgamma(alpha_F_old) -
          Z_max^2 * lgamma(alpha_F_old/Z_max) +
          (alpha_F_old/Z_max - 1) * sum(log(Q_F_old)))
    
    l_u <- log(runif(1))
    if (!is.na(alpha_acc)){
      if (l_u < alpha_acc){
        alpha_F_old <- alpha_F_prop
      }
    }
    
    
    # (6) Identifiability correction for the random effects: center them at 
    #     zero and transfer their mean to the population-level coefficients
    corr_mu <- colMeans(beta_mu_u_old)
    corr_b <- colMeans(beta_b_u_old)
    beta_mu_u_old <- t(t(beta_mu_u_old) - corr_mu)
    beta_b_u_old <- t(t(beta_b_u_old) - corr_b)
    
    for (k in 1:J){
      if (k == 1){
        idx_time <- which(time == T_min)
      }
      else if (k == J){
        idx_time <- which(time == T_max)
      }
      else {
        idx_time <- which(time %in% (k-1):k)
      }
      beta_mu_old[k,,] <- beta_mu_old[k,,] + corr_mu[k]
      beta_b_old[k,,] <- beta_b_old[k,,] + corr_b[k]
      beta_mu_star_old[k,] <- beta_mu_star_old[k,] + corr_mu[k]
      beta_b_star_old[k,] <- beta_b_star_old[k,] + corr_b[k]
      B_beta_mu_dat[idx_time,] <- B_beta_mu_dat[idx_time,] + 
        as.numeric(B[idx_time,] %*% corr_mu)
      B_beta_mu_u_dat[idx_time,] <- B_beta_mu_u_dat[idx_time,] - 
        as.numeric(B[idx_time,] %*% corr_mu)
      B_beta_b_dat[idx_time,] <- B_beta_b_dat[idx_time,] + 
        as.numeric(B[idx_time,] %*% corr_b)
      B_beta_b_u_dat[idx_time,] <- B_beta_b_u_dat[idx_time,] -
        as.numeric(B[idx_time,] %*% corr_b)
    }
    
    
    # (7) Update the cluster specific smoothness parameters
    RSS_mu <- RSS_b <- 0
    for (h1 in 1:d_j[1]){
      for (h2 in 1:d_j[2]){
        RSS_mu <- RSS_mu + as.numeric(crossprod(beta_mu_old[,h1,h2], P_smooth) %*% 
                                        beta_mu_old[,h1,h2])
        RSS_b <- RSS_b + as.numeric(crossprod(beta_b_old[,h1,h2], P_smooth) %*% 
                                      beta_b_old[,h1,h2])
      }
    }
    
    sigma2_1_mu_temp <- exp(rnorm(1, log(sigma2_1_mu_old), 0.1))
    while(sigma2_1_mu_temp > 0.2){
      sigma2_1_mu_temp <- exp(rnorm(1, log(sigma2_1_mu_old), 0.1))
    }
    l_alpha <- min(c(0, log(sigma2_1_mu_temp) - 
                       0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_mu_temp) - 
                       0.5/sigma2_1_mu_temp * RSS_mu + 
                       dhalfcauhy(sqrt(sigma2_1_mu_temp), hypers$s_sigma_mu, T) -
                       (log(sigma2_1_mu_old) - 
                          0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_mu_old) -
                          0.5/sigma2_1_mu_old * RSS_mu +
                          dhalfcauhy(sqrt(sigma2_1_mu_old), hypers$s_sigma_mu, T))))
    l_u <- log(runif(1))
    if (l_u < l_alpha){
      sigma2_1_mu_old <- sigma2_1_mu_temp
    }
    
    sigma2_1_b_temp <- exp(rnorm(1, log(sigma2_1_b_old), 0.1))
    while(sigma2_1_b_temp > 0.2){
      sigma2_1_b_temp <- exp(rnorm(1, log(sigma2_1_b_old), 0.1))
    }
    l_alpha <- min(c(0, log(sigma2_1_b_temp) - 
                       0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_b_temp) - 
                       0.5/sigma2_1_b_temp * RSS_b + 
                       dhalfcauhy(sqrt(sigma2_1_b_temp), hypers$s_sigma_b, T) -
                       (log(sigma2_1_b_old) - 
                          0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_b_old) -
                          0.5/sigma2_1_b_old * RSS_b +
                          dhalfcauhy(sqrt(sigma2_1_b_old), hypers$s_sigma_b, T))))
    l_u <- log(runif(1))
    if (l_u < l_alpha){
      sigma2_1_b_old <- sigma2_1_b_temp
    }
    
    
    # (8a) Update random effects for b
    reff <- sample_reff_b(tau, D, cens, beta_b_u_old, delta_dat, B_beta_b_dat,
                          b_dat, mu_dat, B, P_smooth, ind, time,
                          sigma2_b_us_old, sigma2_b_ua_old, sd_beta_b_u, 
                          acc_beta_b_u)
    beta_b_u_old <- reff$beta_u_old
    B_beta_b_u_dat <- reff$B_beta_u_dat
    b_dat <- exp(B_beta_b_dat + B_beta_b_u_dat)
    acc_beta_b_u <- reff$acc_beta_u
    
    
    # (8b) Update random effects variances: MH with log normal proposal
    ls_var <- sample_smooth_var(sigma2_b_ua_old, sigma2_b_us_old,
                                beta_b_u_old, P_smooth, n_ind)
    sigma2_b_ua_old <- ls_var$sigma2_ua_old
    sigma2_b_us_old <- ls_var$sigma2_us_old
    
    
    # (9a) Update random effects for mu
    reff <- sample_reff_mu(tau, D, cens, beta_mu_u_old, delta_dat, b_dat,
                           B_beta_mu_dat, mu_dat, B, P_smooth, ind, time,
                           sigma2_mu_us_old, sigma2_mu_ua_old, sd_beta_mu_u,
                           acc_beta_mu_u)
    beta_mu_u_old <- reff$beta_u_old
    B_beta_mu_u_dat <- reff$B_beta_u_dat
    mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
    acc_beta_mu_u <- reff$acc_beta_u
    
    
    # (9b) Update random effects variances: MH with log normal proposal
    ls_var <- sample_smooth_var(sigma2_mu_ua_old, sigma2_mu_us_old,
                                beta_mu_u_old, P_smooth, n_ind)
    sigma2_mu_ua_old <- ls_var$sigma2_ua_old
    sigma2_mu_us_old <- ls_var$sigma2_us_old
    
    
    
    # After burnin, save parameters in the chain
    if ( (iter > burnin) && (iter %% thin == 0) ){
      # This is the correction for the random effects integration: we need to 
      # compute the variance of the random effects as detailed in the Supplementary
      # Materials
      cov_reff_mu <- solve(diag(J) / sigma2_mu_ua_old + P_smooth / sigma2_mu_us_old)
      corr_term_gr_mu <- rep(0, nrow(Bgrid))
      corr_term_mu <- rep(0, T_max)
      cov_reff_b <- solve(diag(J) / sigma2_b_ua_old + P_smooth / sigma2_b_us_old)
      corr_term_gr_b <- rep(0, nrow(Bgrid))
      corr_term_b <- rep(0, T_max)
      for (k in 1:J){
        corr_term_gr_mu <- corr_term_gr_mu + rowSums(Bgrid[,k] * t(t(Bgrid) * cov_reff_mu[,k]))
        corr_term_mu <- corr_term_mu + rowSums(B_basis(1:T_max, knots)[,k] * 
                                                 t(t(B_basis(1:T_max, knots)) * cov_reff_mu[,k]))
        corr_term_gr_b <- corr_term_gr_b + rowSums(Bgrid[,k] * t(t(Bgrid) * cov_reff_b[,k]))
        corr_term_b <- corr_term_b + rowSums(B_basis(1:T_max, knots)[,k] * 
                                               t(t(B_basis(1:T_max, knots)) * cov_reff_b[,k]))
      }
      
      
      
      for (h1 in 1:d_j[1]){
        for (h2 in 1:d_j[2]){
          post_mean_mu[,h1,h2,it] <- exp(Bgrid %*% beta_mu_old[,h1,h2] + 0.5 * corr_term_gr_mu)
          post_mean_b[,h1,h2,it] <- exp(Bgrid %*% beta_b_old[,h1,h2] + 0.5 * corr_term_gr_b)
          
          for (i in 1:n_ind){
            post_ind_mu[,i,h1,h2,it] <- exp(Bgrid %*% beta_mu_old[,h1,h2] + Bgrid %*% beta_mu_u_old[i,])
            post_ind_b[,i,h1,h2,it] <- exp(Bgrid %*% beta_b_old[,h1,h2] + Bgrid %*% beta_b_u_old[i,])
          }
        }
      }
      post_ind_delta[,,it] <- delta_old
      
      
      # Sample from the predictive distributions of response category and 
      # response times (so we can use predictive checks to evaluate goodness of fit)
      for (t in 1:T_max){
        for (x_temp in 1:d_j[1]){
          mu_temp <- as.numeric(exp(B_basis(t, knots) %*% beta_mu_old[,x_temp,] + 
                                      0.5 * corr_term_mu[t]))
          b_temp <- as.numeric(exp(B_basis(t, knots) %*% beta_b_old[,x_temp,] + 
                                     0.5 * corr_term_b[t]))
          delta_temp <- mean(delta_old[x_temp,])
          pred_temp <- delta_temp + LaplacesDemon::rinvgaussian(d_j[2], b_temp/mu_temp, 
                                                                b_temp^2)
          pred_ans[t,x_temp,it] <- which.min(pred_temp)
          pred_time[t,x_temp,it] <- min(pred_temp)
          
          for (i in 1:n_ind){
            mu_temp <- as.numeric(exp(B_basis(t, knots) %*% (beta_mu_old[,x_temp,] + 
                                                               beta_mu_u_old[i,])))
            b_temp <- as.numeric(exp(B_basis(t, knots) %*% (beta_b_old[,x_temp,] + 
                                                              beta_b_u_old[i,])))
            delta_temp <- delta_old[x_temp,i]
            pred_temp <- delta_temp + LaplacesDemon::rinvgaussian(d_j[2], b_temp/mu_temp, 
                                                                  b_temp^2)
            pred_ans_ind[i,t,x_temp,it] <- which.min(pred_temp)
            pred_time_ind[i,t,x_temp,it] <- min(pred_temp)
          }
        }
      }
      
      
      # Save MCMC objects
      post_mean_delta[it,] <- rowMeans(delta_old)
      sigma2_mu_us[it] <- sigma2_mu_us_old
      sigma2_mu_ua[it] <- sigma2_mu_ua_old
      sigma2_b_us[it] <- sigma2_b_us_old
      sigma2_b_ua[it] <- sigma2_b_ua_old
      sigma2_1_mu[it] <- sigma2_1_mu_old
      sigma2_1_b[it] <- sigma2_1_b_old
      loglik_chain[it,] <- log_likelihood_ind(tau, mu_dat, b_dat, delta_dat, cens, D)
      # z[,,it] <- z_temp		
      
      it <- it + 1
    }
    setTxtProgressBar(pb, iter/Niter)
  }
  
  return(list(
    # 'Z' = z, 
    'post_mean_delta' = post_mean_delta, 
    'post_mean_mu' = post_mean_mu,
    'post_mean_b' = post_mean_b,
    'post_ind_delta' = post_ind_delta,
    'post_ind_mu' = post_ind_mu,
    'post_ind_b' = post_ind_b,
    'sigma2_mu_us' = sigma2_mu_us, 
    'sigma2_mu_ua' = sigma2_mu_ua,
    'sigma2_b_us' = sigma2_b_us, 
    'sigma2_b_ua' = sigma2_b_ua,
    'sigma2_1_mu' = sigma2_1_mu, 
    'sigma2_1_b' = sigma2_1_b, 
    'pred_ans' = pred_ans, 
    'pred_time' = pred_time,
    'pred_ans_ind' = pred_ans_ind, 
    'pred_time_ind' = pred_time_ind,
    'loglik' = loglik_chain
  ))
}


LDDMM_const_bound <- function(data, hypers, Niter = 5000, burnin = 2000, thin = 5){
  
  # Place the spline knots (default: one knot at every block, from the 
  # beginning to the end of the study)
  T_min <- min(data$block)
  T_max <- max(data$block)
  knots <- T_min:T_max
  K <- length(knots)
  P_smooth <- P_smooth1(K + 1)
  
  # Grid on which to evaluate the curves resulting from the spline basis 
  # expansion (used for plotting)
  xgrid <- seq(T_min, T_max, by = 1)
  
  # Extract quantities of interest
  tau <- data$r_time
  ind <- data$subject
  time <- data$block
  cens <- data$cens
  D <- cbind(data$s, data$d)
  
  n_ind <- length(unique(ind))
  ind <- as.numeric(plyr::mapvalues(factor(ind),
                                    from = levels(factor(ind)),
                                    to = 1:n_ind))
  B <- B_basis(data$block, knots)
  Bgrid <- B_basis(xgrid, knots)
  
  samp_size <- (Niter - burnin)/thin # sample size
  p <- ncol(D) # number of covariates
  d_j <- rep(0, p) # Number of levels for each covariate
  for (j in 1:p){
    d_j[j] <- length(unique(D[!is.na(D[,j]),j]))
  }
  J <- ncol(B) # number of locations
  n <- nrow(B) # total number of observations
  n_ind <- length(unique(ind)) # number of individuals
  
  
  # Rescale the time steps to {1, ..., T_max}
  T_max <- max(time)
  T_min <- min(time)
  time <- time - T_min + 1
  T_max <- max(time)
  T_min <- min(time)
  idx_xy <- t(apply(expand.grid(y = 1:d_j[2], x = 1:d_j[1]), 1, rev))
  colnames(idx_xy) <- NULL
  
  
  # Set HB structures
  r_HB <- 1
  Z_max <- min(prod(d_j), 6)
  dim_HB <- sum((Z_max - 1)^(0:r_HB) * choose(d_j[2], 0:r_HB))
  
  # Set MCMC objects
  z <- array(NA, dim = c(prod(d_j), J, samp_size))
  post_mean_delta <- array(NA, dim = c(samp_size, d_j[1]))
  post_mean_mu <- array(NA, dim = c(nrow(Bgrid), d_j, samp_size))
  post_mean_b <- array(NA, dim = c(nrow(Bgrid), d_j, samp_size))
  post_ind_delta <- array(NA, dim = c(d_j[1], n_ind, samp_size))
  post_ind_mu <- array(NA, dim = c(nrow(Bgrid), n_ind, d_j, samp_size))
  post_ind_b <- array(NA, dim = c(nrow(Bgrid), n_ind, d_j, samp_size))
  sigma2_mu_us <- array(NA, dim = samp_size)
  sigma2_mu_ua <- array(NA, dim = samp_size)
  sigma2_b_us <- array(NA, dim = samp_size)
  sigma2_b_ua <- array(NA, dim = samp_size)
  sigma2_1_mu <- array(NA, dim = samp_size)
  sigma2_1_b <- array(NA, dim = samp_size)
  pred_time <- array(NA, dim = c(T_max, d_j[1], samp_size))
  pred_ans <- array(NA, dim = c(T_max, d_j[1], samp_size))
  pred_time_ind <- array(NA, dim = c(n_ind, T_max, d_j[1], samp_size))
  pred_ans_ind <- array(NA, dim = c(n_ind, T_max, d_j[1], samp_size))
  loglik_chain <- array(NA, dim = c(samp_size, n))
  
  # Set initial values
  delta_old <- array(NA, dim = c(d_j[1], n_ind))
  beta_mu_old <- array(NA, dim = c(J, d_j))
  delta_dat <- array(NA, dim = n)
  b_old <- array(NA, dim = d_j)
  B_beta_b_dat <- array(0, dim = c(n, d_j[1]))
  for (s_temp in 1:d_j[1]){
    for (i_temp in 1:n_ind){
      idx_temp <- which((D[,1] == s_temp) & (ind == i_temp))
      if (length(idx_temp) > 0){
        delta_old[s_temp,i_temp] <- min(tau[idx_temp])/2
        delta_dat[idx_temp] <- delta_old[s_temp,i_temp]
      }
      else{
        delta_old[s_temp,i_temp] <- 1E-3
      }
    }
    for(j in 1:d_j[2]){
      beta_mu_old[,s_temp,j] <- rep(0.5*log(mean(tau[D[,1] == s_temp])) - log(sd(tau[D[,1] == s_temp])), J)
      b_old[s_temp,j] <- 0.5*log(mean(tau[D[,1] == s_temp])) - log(sd(tau[D[,1] == s_temp]))
      B_beta_b_dat[which((D[,1] == s_temp)),j] <- b_old[s_temp,j]
    }
  }
  # b_old <- 1.5*log(mean(tau)) - log(sd(tau))
  # B_beta_b_dat <- array(b_old, dim = c(n, d_j[1]))
  
  
  low_bound_mu <- min(beta_mu_old) - 1.5
  upp_bound_mu <- max(beta_mu_old) + 1
  
  beta_mu_star_prop <- array(NA, dim = c(J, Z_max))
  beta_mu_u_old <- array(0, dim = c(n_ind, J))
  sigma2_1_mu_old <- 0.005
  sigma2_mu_us_old <- 0.005
  sigma2_mu_ua_old <- 0.005
  
  
  # Message passing structures
  beta_mess <- array(NA, dim = c(J, dim_HB))
  z_old <- list()
  
  for (j in 1:d_j[1]){
    z_old[[j]] <- array(NA, dim = c(d_j[2], J))
    for (jj in 1:d_j[2]){
      if (j == jj){
        z_old[[j]][jj,] <- j
      }
      else{
        z_old[[j]][jj,] <- sample((d_j[1] + 1):Z_max, 1)
      }
    }
  }
  z_temp <- do.call(rbind, z_old)
  
  
  beta_mu_star_old <- array(NA, dim = c(J, Z_max))
  for (i in 1:Z_max){
    beta_mu_star_old[,i] <- beta_mu_old[1,idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],1],
                                        idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],2]]
  }
  
  # Trick to get the linear indices of the diagonal entries (s == d, i.e. the
  # correct responses): element (i,j) is compared against rand_mat[i,i]
  rand_mat <- array(rnorm(prod(d_j)), dim = d_j)
  idx_succ <- which(rand_mat == diag(rand_mat))
  
  v_old <- array(NA, dim = c(d_j[2], J))
  z_prop <- list()
  
  # Transition dynamics objects
  alpha_S_old <- alpha_F_old <- 1
  Q_S_old <- array(NA, dim = c(Z_max, Z_max))
  Q_F_old <- array(NA, dim = c(Z_max, Z_max))
  pi_S_0 <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + 
                         table_int(z_temp[idx_succ,1], Z_max))
  pi_F_0 <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + 
                         table_int(z_temp[-idx_succ,1], Z_max))
  tr_count_S <- count_assign(z_temp[idx_succ,], Z_max)
  tr_count_F <- count_assign(z_temp[-idx_succ,], Z_max)
  for (h in 1:Z_max){
    Q_S_old[h,] <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + tr_count_S[h,])
    Q_F_old[h,] <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + tr_count_F[h,])
  }
  
  
  # Auxiliary variables
  prob_mat <- array(NA, dim = c(dim_HB, dim_HB))
  prob_vec <- array(NA, dim = dim_HB)
  
  # MH proposal parameters
  sd_MH_delta <- array(0.3, dim = c(d_j[1], n_ind))
  sd_MH_beta_mu <- 0.4
  sd_MH_beta_b <- array(0.05, dim = d_j)
  acc_b <- array(0, dim = d_j)
  sd_beta_mu_u <-  array(0.4, dim = c(n_ind, J))
  sd_beta_b_u <-  array(0.4, dim = c(n_ind, J))
  acc_delta <- array(0, dim = c(d_j[1], n_ind))
  acc_beta_mu_u <- array(0, dim = c(n_ind, J))
  acc_beta_b_u <- array(0, dim = c(n_ind, J))
  n_batch <- 0
  
  # Auxiliary variables
  B_beta_mu_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_mu_u_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_b_prop_dat <- array(0, dim = c(n, d_j[1]))
  mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
  b_dat <- matrix(exp(b_old), n, d_j[1])
  
  
  # Gibbs Sampler
  it <- 1
  pb <- txtProgressBar(style = 3)
  for (iter in 1:Niter){
    
    # (0) Adaptively tune the MH proposal sds for \delta_{s,i}, beta_u_mu, 
    #     beta_u_b, targeting the usual 0.44 acceptance rate of univariate MH
    if (iter %% 20 == 0){
      n_batch <- n_batch + 1
      delta_n <- min(0.01, n_batch^(-0.5))
      
      for (i in 1:n_ind){
        for (x_temp in 1:d_j[1]){
          if (acc_delta[x_temp,i]/iter > 0.44){
            sd_MH_delta[x_temp,i] <- exp(log(sd_MH_delta[x_temp,i]) + delta_n)
          }
          else{
            sd_MH_delta[x_temp,i] <- exp(log(sd_MH_delta[x_temp,i]) - delta_n)
          }
        }
        for (k in 1:J){
          if (acc_beta_mu_u[i,k]/iter > 0.44){
            sd_beta_mu_u[i,k] <- exp(log(sd_beta_mu_u[i,k]) + delta_n)
          }
          else{
            sd_beta_mu_u[i,k] <- exp(log(sd_beta_mu_u[i,k]) - delta_n)
          }
          if (acc_beta_b_u[i,k]/iter > 0.44){
            sd_beta_b_u[i,k] <- exp(log(sd_beta_b_u[i,k]) + delta_n)
          }
          else{
            sd_beta_b_u[i,k] <- exp(log(sd_beta_b_u[i,k]) - delta_n)
          }
        }
      }
      for (d1 in 1:d_j[1]){
        for (d2 in 1:d_j[2]){
          if (acc_b[d1,d2]/iter > 0.44){
            sd_MH_beta_b[d1,d2] <- exp(log(sd_MH_beta_b[d1,d2]) + delta_n)
          }
          else{
            sd_MH_beta_b[d1,d2] <- exp(log(sd_MH_beta_b[d1,d2]) - delta_n)
          }
          
        }
      }
    }
    
    
    # (1) Update the offset parameters \delta_{s,i}: MH with log-normal 
    #     proposal
    for (s_temp in 1:d_j[1]){
      for (i_temp in 1:n_ind){
        idx_temp <- which((D[,1] == s_temp) & (ind == i_temp))
        tau_temp <- tau[idx_temp]
        cens_temp <- cens[idx_temp]
        D_temp <- D[idx_temp,]
        
        # log-normal proposal distribution centered on the current value
        delta_prop <- exp(rnorm(1, log(delta_old[s_temp,i_temp]), sd_MH_delta[s_temp,i_temp]))
        while (delta_prop > min(tau[idx_temp])){
          delta_prop <- exp(rnorm(1, log(delta_old[s_temp,i_temp]), sd_MH_delta[s_temp,i_temp]))
        }
        loglik_prop <- log_likelihood(tau_temp, mu_dat[idx_temp,], 
                                      b_dat[idx_temp,], 
                                      rep(delta_prop, length(idx_temp)), 
                                      cens_temp, D_temp, TRUE)
        loglik_old <- log_likelihood(tau_temp, mu_dat[idx_temp,], 
                                     b_dat[idx_temp,], 
                                     rep(delta_old[s_temp,i_temp], length(idx_temp)), 
                                     cens_temp, D_temp, TRUE)
        
        alpha_acc <- min(0, loglik_prop + log(delta_prop) -
                           loglik_old - log(delta_old[s_temp,i_temp]))
        l_u <- log(runif(1))
        
        if (l_u < alpha_acc){
          delta_old[s_temp,i_temp] <- delta_prop
          delta_dat[idx_temp] <- delta_old[s_temp,i_temp]
          acc_delta[s_temp,i_temp] <- acc_delta[s_temp,i_temp] + 1
        }
        
      }
    }
    
    
    # (2) Update of mu parameters: \mu_{x,y}^{(i)}(t)
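    # Each cluster-specific spline coefficient beta_mu_star[k,h] gets a
    # Metropolis-within-Gibbs update: a random-walk proposal at the first
    # location, and proposals from the conditional prior (centered on the
    # coefficient at the previous location) elsewhere, evaluated on the data
    # in the blocks adjacent to location k.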
    for (k in 1:J){ # loop over locations
      if (k == 1){ # only data at t = 1 influence the first coefficient
        idx_time <- T_min
      }
      else if (k == J){ # only data at t = T influence the last coefficient
        idx_time <- T_max
      }
      else { # data at t = {k-1,k} influence the kth coefficient
        idx_time <- (k-1):k
      }
      
      for (h in 1:Z_max){ # loop over possible latent values
        # tuples (combinations of covariates) that are clustered together via 
        # the latent h
        idx_cov <- matrix(idx_xy[which(z_temp[,k] == h),], length(which(z_temp[,k] == h)), p)
        
        X_1k <- unique(idx_cov[,1]) # all possible values of x in this cluster
        X_2k <- unique(idx_cov[,2]) # all possible values of y in this cluster
        
        if (length(X_1k) > 0){ # h \in \mathcal{Z}_{j,k}: posterior update
          # Pick data with covariate levels of x clustered in group h and 
          # at the correct locations
          idx_i <- which( (D[,1] %in% X_1k) & (time %in% idx_time) )
          tau_temp <- tau[idx_i]
          cens_temp <- cens[idx_i]
          D_temp <- D[idx_i,]
          
          
          # Normal proposal distribution centered on the current value
          if (k == 1){
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_mu_star_old[k,h], sd_MH_beta_mu)
          }
          # Normal proposal from the prior
          else if (k == J){
            beta_pre1 <- beta_mu_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_pre1, sqrt(sigma2_1_mu_old))
          }
          # Normal proposal from the prior
          else {
            beta_pre1 <- beta_mu_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_pre1, sqrt(sigma2_1_mu_old))
          }
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
          
          # Modify the proposed values in the corresponding positions
          for (hh in 1:nrow(idx_cov)){
            B_beta_mu_prop_dat[which(D_temp[,1] == idx_cov[hh,1]),idx_cov[hh,2]] <- 
              B[idx_i[which(D_temp[,1] == idx_cov[hh,1])],] %*% beta_mu_star_prop[,h]
          }
          
          # This is the proposed value for \mu_{x,y}^{(i)}(t)
          mu_prop_dat <- exp(B_beta_mu_prop_dat + B_beta_mu_u_dat[idx_i,])
          
          
          if (k == 1){
            beta_post1 <- beta_mu_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_dat[idx_i,], delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_prop[k,h] - beta_post1)^2)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_old[k,h] - beta_post1)^2)
          }
          else if (k == J){
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_dat[idx_i,], delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE)
          }
          else {
            beta_post1 <- beta_mu_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_dat[idx_i,], delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_prop[k,h] - beta_post1)^2)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_old[k,h] - beta_post1)^2)
          }
          
          alpha_acc <- min(0, logpost_prop - logpost_old)
          l_u <- log(runif(1))
          
          if (l_u < alpha_acc){
            
            beta_mu_star_old[k,h] <- beta_mu_star_prop[k,h]
            B_beta_mu_dat[idx_i,] <- B_beta_mu_prop_dat
            mu_dat[idx_i,] <- mu_prop_dat
          }
        }
        else { # h \notin \mathcal{Z}_{j,k}: prior sampling
          beta_mu_star_old[k,h] <- runif(1, low_bound_mu, upp_bound_mu)
        }
      }
    }
    
    # 2(a) Update the constant boundary parameters
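    # The boundaries are constant over blocks in this sampler: each
    # log-boundary b_old[d1,d2] gets a random-walk MH update against all data
    # with stimulus d1.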
    for (d1 in 1:d_j[1]){
      idx_i <- which( (D[,1] == d1) )
      tau_temp <- tau[idx_i]
      cens_temp <- cens[idx_i]
      D_temp <- D[idx_i,]
      
      for (d2 in 1:d_j[2]){
        b_prop <- rnorm(1, b_old[d1,d2], sd_MH_beta_b[d1,d2])
        
        # Modify the proposed values in the corresponding positions
        B_beta_b_prop_dat <- B_beta_b_dat[idx_i,]
        B_beta_b_prop_dat[,d2] <- b_prop
        
        # This is the proposed value for b
        b_prop_dat <- exp(B_beta_b_prop_dat)
        
        logpost_prop <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                       b_prop_dat, delta_dat[idx_i], 
                                       cens_temp, D_temp, TRUE)
        logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                      b_dat[idx_i,], delta_dat[idx_i], 
                                      cens_temp, D_temp, TRUE)
        
        alpha_acc <- min(0, logpost_prop - logpost_old)
        l_u <- log(runif(1))
        
        if (l_u < alpha_acc){
          b_old[d1,d2] <- b_prop
          B_beta_b_dat[idx_i,] <- B_beta_b_prop_dat
          b_dat[idx_i,] <- b_prop_dat
          acc_b[d1,d2] = acc_b[d1,d2] + 1
        }
      }
    }
    
    
    # (3) Update the cluster assignments
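    # Hamming-ball slice sampler: (i) draw an auxiliary configuration v
    # uniformly from a Hamming ball around the current assignments, (ii) pass
    # messages backwards over the restricted state space H_ball(v), and
    # (iii) sample the assignments forward from the resulting conditionals.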
    for (x_temp in 1:d_j[1]){ # loop over stimulus levels
      beta_mess <- array(-Inf, dim = c(J, dim_HB))
      beta_mess[J,] <- 1/dim_HB
      
      v_old[,J] <- H_ball_unif(z_old[[x_temp]][,J], S = Z_max, r = r_HB)
      z_prop[[J]] <- H_ball(v_old[,J], S = Z_max, r = r_HB)
      for (k in (J - 1):1){
        idx_i <- which( (time == k) & (D[,1] == x_temp) )
        tau_temp <- tau[idx_i]
        cens_temp <- cens[idx_i]
        D_temp <- D[idx_i,]
        
        # (i) Sample the auxiliary variables
        v_temp <- H_ball(z_old[[x_temp]][,k], S = Z_max, r = r_HB)
        
        probs <- rep(-Inf, dim_HB)
        for (h in 1:dim_HB){
          B_beta_mu_prop_dat <- 0.5 * (beta_mu_star_old[k,v_temp[,h]] +
                                         beta_mu_star_old[k+1,z_old[[x_temp]][,k+1]])
          
          mu_dat_prop <- exp(t(B_beta_mu_prop_dat + t(B_beta_mu_u_dat[idx_i,])))
          
          probs[h] <- g_HB(log_likelihood(tau_temp, mu_dat_prop, b_dat[idx_i,],
                                          delta_dat[idx_i], cens_temp, D_temp, TRUE))
        }
        probs <- as.numeric(normalise_log(probs))
        
        v_old[,k] <- v_temp[,sample(1:dim_HB, 1, prob = probs)]
        z_prop[[k]] <- H_ball(v_old[,k], S = Z_max, r = r_HB)
        
        
        # (ii) Pass messages backwards only in the restricted state space given
        #      by the slice
        z_kp1_temp <- which(beta_mess[k+1,] > 0)
        prob_mat <- array(-Inf, dim = c(dim_HB, dim_HB))
        
        for (h1 in z_kp1_temp){
          for (h2 in 1:dim_HB){
            
            B_beta_mu_prop_dat_1 <- 0.5 * (beta_mu_star_old[k,z_prop[[k]][,h2]] +
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])
            B_beta_mu_prop_dat_2 <- 0.5 * (beta_mu_star_old[k,v_old[,k]] +
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])
            
            mu_dat_prop_1 <- exp(t(B_beta_mu_prop_dat_1 + t(B_beta_mu_u_dat[idx_i,])))
            mu_dat_prop_2 <- exp(t(B_beta_mu_prop_dat_2 + t(B_beta_mu_u_dat[idx_i,])))
            
            prob_mat[h2,h1] <- log(beta_mess[k+1,h1]) -
              0.5 / sigma2_1_mu_old * sum((beta_mu_star_old[k,z_prop[[k]][,h2]] -
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])^2) +
              log_likelihood(tau_temp, mu_dat_prop_1, b_dat[idx_i,],
                             delta_dat[idx_i], cens_temp, D_temp, TRUE) +
              g_HB(log_likelihood(tau_temp, mu_dat_prop_2, b_dat[idx_i,],
                                  delta_dat[idx_i], cens_temp, D_temp, TRUE)) +
              sum(log(Q_F_old[cbind(z_prop[[k]][-x_temp,h2],z_prop[[k+1]][-x_temp,h1])])) +
              log(Q_S_old[z_prop[[k]][x_temp,h2],z_prop[[k+1]][x_temp,h1]])
          }
        }
        if ( sum(is.infinite(sum_rows_log(prob_mat))) == dim_HB){
          beta_mess[k,] <- 1/dim_HB
        }
        else{
          beta_mess[k,] <- as.numeric(sum_rows_log(prob_mat))
          beta_mess[k,] <- as.numeric(normalise_log(beta_mess[k,]))
        }
      }
      
      
      # (iii) Sample states forward (only on allowed states)
      idx_fail <- (1:d_j[2])[-x_temp]
      # Sample z_1
      prob_vec <- log(beta_mess[1,]) + log(pi_S_0[z_prop[[1]][x_temp,]]) +
        colSums(matrix(log(pi_F_0[z_prop[[1]][-x_temp,]]), d_j[2] - 1, dim_HB))
      prob_vec <- as.numeric(normalise_log(prob_vec))
      
      idx_samp <- sample(1:dim_HB, 1, FALSE, prob_vec)
      z_old[[x_temp]][,1] <- z_prop[[1]][,idx_samp]
      
      # Sample z_k
      for (k in 2:J){
        idx_km1 <- which( (time == k - 1) & (D[,1] == x_temp) )
        tau_temp <- tau[idx_km1]
        cens_temp <- cens[idx_km1]
        D_temp <- D[idx_km1,]
        
        prob_vec <- log(beta_mess[k,]) + 
          log(Q_S_old[cbind(z_old[[x_temp]][x_temp,k-1], z_prop[[k]][x_temp,])])
        for (kkk in idx_fail){
          prob_vec <- prob_vec + log(Q_F_old[cbind(z_old[[x_temp]][kkk,k-1], z_prop[[k]][kkk,])])
        }
        
        for (z_k_temp in which(is.finite(prob_vec))){
          B_beta_mu_prop_dat_1 <- 0.5 * (beta_mu_star_old[k-1,z_old[[x_temp]][,k-1]] +
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])
          B_beta_mu_prop_dat_2 <- 0.5 * (beta_mu_star_old[k-1,v_old[,k-1]] +
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])
          
          mu_dat_prop_1 <- exp(t(B_beta_mu_prop_dat_1 + t(B_beta_mu_u_dat[idx_km1,])))
          mu_dat_prop_2 <- exp(t(B_beta_mu_prop_dat_2 + t(B_beta_mu_u_dat[idx_km1,])))
          
          prob_vec[z_k_temp] <- prob_vec[z_k_temp] -
            0.5 / sigma2_1_mu_old * sum((beta_mu_star_old[k-1,z_old[[x_temp]][,k-1]] -
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])^2) +
            log_likelihood(tau_temp, mu_dat_prop_1, b_dat[idx_km1,],
                           delta_dat[idx_km1], cens_temp, D_temp, TRUE) +
            g_HB(log_likelihood(tau_temp, mu_dat_prop_2, b_dat[idx_km1,],
                                delta_dat[idx_km1], cens_temp, D_temp, TRUE))
        }
        prob_vec <- as.numeric(normalise_log(prob_vec))
        
        idx_samp <- sample(1:dim_HB, 1, FALSE, prob_vec)
        z_old[[x_temp]][,k] <- z_prop[[k]][,idx_samp]
      }
      
      
      # (4) Assign the cluster specific curves f_{\mu}
      for (y_temp in 1:d_j[2]){
        beta_mu_old[,x_temp,y_temp] <- beta_mu_star_old[cbind(1:J, z_old[[x_temp]][y_temp,])]
      }
      B_beta_mu_dat[(D[,1] == x_temp),] <- B[D[,1] == x_temp,] %*% beta_mu_old[,x_temp,]
    }
    z_temp <- do.call(rbind, z_old)
    mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
    
    
    # (5) Update the transition probabilities
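    # Conjugate Dirichlet updates: initial-state probabilities and transition
    # matrices are resampled from Dirichlet posteriors, with separate chains
    # for the success (s = d) and failure tuples.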
    pi_S_0 <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + table_int(z_temp[idx_succ,1], Z_max))
    pi_F_0 <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + table_int(z_temp[-idx_succ,1], Z_max))
    tr_count_S <- count_assign(z_temp[idx_succ,], Z_max)
    tr_count_F <- count_assign(z_temp[-idx_succ,], Z_max)
    for (h in 1:Z_max){
      Q_S_old[h,] <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + 
                                  tr_count_S[h,])
      Q_F_old[h,] <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + 
                                  tr_count_F[h,])
    }
    
    
    # Update the Dirichlet concentration alpha_S: random-walk MH on the log
    # scale, with proposals restricted to [0.01, 10]; the log(alpha) terms
    # are the Jacobian of the log-normal proposal (cf. the delta update)
    alpha_S_prop <- exp(rnorm(1, log(alpha_S_old), 0.1))
    while ((alpha_S_prop < 0.01) | (alpha_S_prop > 10)){
      alpha_S_prop <- exp(rnorm(1, log(alpha_S_old), 0.1))
    }
    alpha_acc <- dgamma(alpha_S_prop, 1, 1, log = T) + log(alpha_S_prop) +
      Z_max * lgamma(alpha_S_prop) -
      Z_max^2 * lgamma(alpha_S_prop/Z_max) +
      (alpha_S_prop/Z_max - 1) * sum(log(Q_S_old)) - (
        dgamma(alpha_S_old, 1, 1, log = T) + log(alpha_S_old) +
          Z_max * lgamma(alpha_S_old) -
          Z_max^2 * lgamma(alpha_S_old/Z_max) +
          (alpha_S_old/Z_max - 1) * sum(log(Q_S_old)))
    
    l_u <- log(runif(1))
    if (!is.na(alpha_acc)){
      if (l_u < alpha_acc){
        alpha_S_old <- alpha_S_prop
      }
    }
    
    # Same log-normal MH update for alpha_F
    alpha_F_prop <- exp(rnorm(1, log(alpha_F_old), 0.1))
    while ((alpha_F_prop < 0.01) | (alpha_F_prop > 10)){
      alpha_F_prop <- exp(rnorm(1, log(alpha_F_old), 0.1))
    }
    alpha_acc <- dgamma(alpha_F_prop, 1, 1, log = T) + log(alpha_F_prop) +
      Z_max * lgamma(alpha_F_prop) -
      Z_max^2 * lgamma(alpha_F_prop/Z_max) +
      (alpha_F_prop/Z_max - 1) * sum(log(Q_F_old)) - (
        dgamma(alpha_F_old, 1, 1, log = T) + log(alpha_F_old) +
          Z_max * lgamma(alpha_F_old) -
          Z_max^2 * lgamma(alpha_F_old/Z_max) +
          (alpha_F_old/Z_max - 1) * sum(log(Q_F_old)))
    
    l_u <- log(runif(1))
    if (!is.na(alpha_acc)){
      if (l_u < alpha_acc){
        alpha_F_old <- alpha_F_prop
      }
    }
    
    
    # (6) Correction term for random effects.
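    # Center the random effects to have mean zero across subjects; the
    # subtracted mean curve is absorbed into the population-level
    # coefficients, so the likelihood is unchanged.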
    corr_mu <- colMeans(beta_mu_u_old)
    beta_mu_u_old <- t(t(beta_mu_u_old) - corr_mu)
    
    for (k in 1:J){
      if (k == 1){
        idx_time <- which(time == T_min)
      }
      else if (k == J){
        idx_time <- which(time == T_max)
      }
      else {
        idx_time <- which(time %in% (k-1):k)
      }
      beta_mu_old[k,,] <- beta_mu_old[k,,] + corr_mu[k]
      beta_mu_star_old[k,] <- beta_mu_star_old[k,] + corr_mu[k]
      B_beta_mu_dat[idx_time,] <- B_beta_mu_dat[idx_time,] + 
        as.numeric(B[idx_time,] %*% corr_mu)
      B_beta_mu_u_dat[idx_time,] <- B_beta_mu_u_dat[idx_time,] - 
        as.numeric(B[idx_time,] %*% corr_mu)
    }
    
    
    # (7) Update the cluster specific smoothness parameters
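    # RSS_mu is the quadratic form of the spline coefficients under the
    # smoothing penalty P_smooth; sigma2_1_mu is then updated by MH with a
    # log-normal proposal truncated at 0.2 and a half-Cauchy prior on its
    # square root.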
    RSS_mu <- 0
    for (h1 in 1:d_j[1]){
      for (h2 in 1:d_j[2]){
        RSS_mu <- RSS_mu + as.numeric(crossprod(beta_mu_old[,h1,h2], P_smooth) %*% 
                                        beta_mu_old[,h1,h2])
      }
    }
    
    sigma2_1_mu_temp <- exp(rnorm(1, log(sigma2_1_mu_old), 0.1))
    while(sigma2_1_mu_temp > 0.2){
      sigma2_1_mu_temp <- exp(rnorm(1, log(sigma2_1_mu_old), 0.1))
    }
    l_alpha <- min(c(0, log(sigma2_1_mu_temp) - 
                       0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_mu_temp) - 
                       0.5/sigma2_1_mu_temp * RSS_mu + 
                       dhalfcauhy(sqrt(sigma2_1_mu_temp), hypers$s_sigma_mu, T) -
                       (log(sigma2_1_mu_old) - 
                          0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_mu_old) -
                          0.5/sigma2_1_mu_old * RSS_mu +
                          dhalfcauhy(sqrt(sigma2_1_mu_old), hypers$s_sigma_mu, T))))
    l_u <- log(runif(1))
    if (l_u < l_alpha){
      sigma2_1_mu_old <- sigma2_1_mu_temp
    }
    
    
    # (9a) Update random effects for mu
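    # (delegated to the helper sample_reff_mu, which returns the updated
    # subject-specific coefficients, their basis expansion, and the MH
    # acceptance counters)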
    reff <- sample_reff_mu(tau, D, cens, beta_mu_u_old, delta_dat, b_dat,
                           B_beta_mu_dat, mu_dat, B, P_smooth, ind, time,
                           sigma2_mu_us_old, sigma2_mu_ua_old, sd_beta_mu_u,
                           acc_beta_mu_u)
    beta_mu_u_old <- reff$beta_u_old
    B_beta_mu_u_dat <- reff$B_beta_u_dat
    mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
    acc_beta_mu_u <- reff$acc_beta_u
    
    
    # (9b) Update random effects variances: MH with log normal proposal
    ls_var <- sample_smooth_var(sigma2_mu_ua_old, sigma2_mu_us_old,
                                beta_mu_u_old, P_smooth, n_ind)
    sigma2_mu_ua_old <- ls_var$sigma2_ua_old
    sigma2_mu_us_old <- ls_var$sigma2_us_old
    
    
    
    # After burnin, save parameters in the chain
    if ( (iter > burnin) && (iter %% thin == 0) ){
      # This is the correction for the random effects integration: we need to 
      # compute the variance of the random effects as detailed in the Supplementary
      # Materials
      cov_reff_mu <- solve(diag(J) / sigma2_mu_ua_old + P_smooth / sigma2_mu_us_old)
      corr_term_gr_mu <- rep(0, nrow(Bgrid))
      corr_term_mu <- rep(0, T_max)
      # Evaluate the basis over 1:T_max once per saved draw, not inside the loop
      B_time <- B_basis(1:T_max, knots)
      for (k in 1:J){
        corr_term_gr_mu <- corr_term_gr_mu + rowSums(Bgrid[,k] * t(t(Bgrid) * cov_reff_mu[,k]))
        corr_term_mu <- corr_term_mu + rowSums(B_time[,k] * t(t(B_time) * cov_reff_mu[,k]))
      }
      
      
      
      for (h1 in 1:d_j[1]){
        for (h2 in 1:d_j[2]){
          post_mean_mu[,h1,h2,it] <- exp(Bgrid %*% beta_mu_old[,h1,h2] + 0.5 * corr_term_gr_mu)
          post_mean_b[,h1,h2,it] <- exp(b_old[h1,h2])
          
          for (i in 1:n_ind){
            post_ind_mu[,i,h1,h2,it] <- exp(Bgrid %*% beta_mu_old[,h1,h2] + Bgrid %*% beta_mu_u_old[i,])
            post_ind_b[,i,h1,h2,it] <- exp(b_old[h1,h2])
          }
        }
      }
      post_ind_delta[,,it] <- delta_old
      
      
      # Sample from the predictive distributions of response category and 
      # response times (so we can use predictive checks to evaluate goodness of fit)
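      # Each of the d_j[2] decision accumulators races with an inverse Gaussian
      # first-passage time (mean b/mu, shape b^2); the winner gives the
      # predicted response and, after adding the offset delta, its time.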
      for (t in 1:T_max){
        for (x_temp in 1:d_j[1]){
          mu_temp <- as.numeric(exp(B_basis(t, knots) %*% beta_mu_old[,x_temp,] + 
                                      0.5 * corr_term_mu[t]))
          b_temp <- exp(b_old[x_temp,])
          delta_temp <- mean(delta_old[x_temp,])
          pred_temp <- delta_temp + LaplacesDemon::rinvgaussian(d_j[2], b_temp/mu_temp, 
                                                                b_temp^2)
          pred_ans[t,x_temp,it] <- which.min(pred_temp)
          pred_time[t,x_temp,it] <- min(pred_temp)
          
          for (i in 1:n_ind){
            mu_temp <- as.numeric(exp(B_basis(t, knots) %*% (beta_mu_old[,x_temp,] + 
                                                               beta_mu_u_old[i,])))
            b_temp <- exp(b_old[x_temp,])
            delta_temp <- delta_old[x_temp,i]
            pred_temp <- delta_temp + LaplacesDemon::rinvgaussian(d_j[2], b_temp/mu_temp, 
                                                                  b_temp^2)
            pred_ans_ind[i,t,x_temp,it] <- which.min(pred_temp)
            pred_time_ind[i,t,x_temp,it] <- min(pred_temp)
          }
        }
      }
      
      
      # Save MCMC objects
      post_mean_delta[it,] <- rowMeans(delta_old)
      sigma2_mu_us[it] <- sigma2_mu_us_old
      sigma2_mu_ua[it] <- sigma2_mu_ua_old
      sigma2_1_mu[it] <- sigma2_1_mu_old
      # z[,,it] <- z_temp		
      loglik_chain[it,] <- log_likelihood_ind(tau, mu_dat, b_dat, delta_dat, cens, D)
      
      it <- it + 1
    }
    setTxtProgressBar(pb, iter/Niter)
  }
  
  return(list(
    # 'Z' = z, 
    'post_mean_delta' = post_mean_delta, 
    'post_mean_mu' = post_mean_mu,
    'post_mean_b' = post_mean_b,
    'post_ind_delta' = post_ind_delta,
    'post_ind_mu' = post_ind_mu,
    'post_ind_b' = post_ind_b,
    'sigma2_mu_us' = sigma2_mu_us, 
    'sigma2_mu_ua' = sigma2_mu_ua,
    'sigma2_1_mu' = sigma2_1_mu, 
    'pred_ans' = pred_ans, 
    'pred_time' = pred_time,
    'pred_ans_ind' = pred_ans_ind, 
    'pred_time_ind' = pred_time_ind, 
    'loglik' = loglik_chain
  ))
}




#' Parameter posterior means
#'
#' Function to extract the posterior means of the parameters of interest from an lddmm fit object.
#'
#' @param data dataframe with the following columns:
#' * subject: vector of size n containing the participant labels
#' * block: vector of size n containing the training blocks (longitudinal units)
#' * s: vector of size n containing the stimuli
#' * d: vector of size n containing the decisions
#' * r_time: vector of size n containing the response times
#' * cens: vector of size n containing the censoring indicators (1 censored, 0 non censored)
#' @param fit fit from the lddmm function
#' @param par parameter to output ('drift' or 'boundary')
#' @return Dataframe with the following columns:
#' * subject: participant labels
#' * block: training blocks
#' * par_s_d, ...: posterior means for the requested parameters
extract_post_mean = function(data, fit, par = c('drift', 'boundary')){
  
  block = d = s = subject = value = NULL
  
  if (par == 'drift'){
    post_mean <- apply(fit$post_ind_mu, 1:4, mean)
  }
  else if (par == 'boundary'){
    post_mean <- apply(fit$post_ind_b, 1:4, mean)
  }
  
  T_min <- min(data$block)
  T_max <- max(data$block)
  
  estimated_pars <- dplyr::tibble(reshape2::melt(post_mean, varnames = c('block', 'subject', 's', 'd'))) %>% 
    dplyr::filter(s == d) %>% 
    dplyr::mutate(block = (T_min:T_max)[block], 
                  subject = plyr::mapvalues(factor(subject), from = levels(factor(subject)), to = levels(factor(data$subject)))) %>%  
    tidyr::pivot_wider(names_from = c(s, d), 
                values_from = value, 
                names_prefix = paste0(par, '_'))
  
  return (estimated_pars)
}
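
# A minimal usage sketch (assuming `fit` is the list returned by LDDMM and
# `data` is the dataframe it was fit to):
#   post_drift <- extract_post_mean(data, fit, par = 'drift')
#   head(post_drift)  # one row per subject/block, one column per s = d pair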



#' Parameter posterior draws
#'
#' Function to extract the posterior draws of the parameters of interest from an lddmm fit object.
#'
#' @param data dataframe with the following columns:
#' * subject: vector of size n containing the participant labels
#' * block: vector of size n containing the training blocks (longitudinal units)
#' * s: vector of size n containing the stimuli
#' * d: vector of size n containing the decisions
#' * r_time: vector of size n containing the response times
#' * cens: vector of size n containing the censoring indicators (1 censored, 0 non censored)
#' @param fit fit from the lddmm function
#' @param par parameter to output ('drift' or 'boundary')
#' @return Dataframe with the following columns:
#' * subject: participant labels
#' * block: training blocks
#' * draw: iteration of the MCMC estimates
#' * par_s_d, ...: posterior draws for the requested parameters
extract_post_draws = function(data, fit, par = c('drift', 'boundary')){
  
  block = d = s = subject = value = NULL
  
  if (par == 'drift'){
    post_draw <- fit$post_ind_mu
  }
  else if (par == 'boundary'){
    post_draw <- fit$post_ind_b
  }
  d1 = dim(post_draw)[3]
  samp_size = dim(post_draw)[5]
  
  T_min <- min(data$block)
  T_max <- max(data$block)
  
  draw_pars <- dplyr::tibble(reshape2::melt(post_draw, varnames = c('block', 'subject', 's', 'd', 'draw'))) %>% 
    dplyr::filter(s == d) %>% 
    dplyr::mutate(block = (T_min:T_max)[block], 
                  subject = plyr::mapvalues(factor(subject), from = levels(factor(subject)), to = levels(factor(data$subject)))) %>%  
    tidyr::pivot_wider(names_from = c(s, d), 
                values_from = value, 
                names_prefix = paste0(par, '_'))
  
  return (draw_pars)
}
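
# Usage mirrors extract_post_mean, with an extra `draw` column indexing the
# saved MCMC iterations, e.g. extract_post_draws(data, fit, par = 'boundary').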


#' Plot posterior estimates
#'
#' Function to plot the posterior means and credible intervals of the parameters of interest from an lddmm fit object.
#'
#' @param data dataframe with the following columns:
#' * subject: vector of size n containing the participant labels
#' * block: vector of size n containing the training blocks (longitudinal units)
#' * s: vector of size n containing the stimuli
#' * d: vector of size n containing the decisions
#' * r_time: vector of size n containing the response times
#' * cens: vector of size n containing the censoring indicators (1 censored, 0 non censored)
#' @param fit fit from the lddmm function
#' @param par parameter to plot ('drift' or 'boundary')
#' @return Posterior means and 90% credible intervals (pointwise 0.05 and 0.95 posterior quantiles)
plot_post_pars = function(data, fit, par = c('drift', 'boundary')){
  
  Var1 = Var2 = Var3 = low = upp = value = NULL 
  
  # Resolve the requested parameter ('drift' or 'boundary')
  par <- match.arg(par)
  
  if (par == 'drift'){
    post_draw <- fit$post_mean_mu
    tex_string <- "$\\mu_{d,s}(t)$"
  }
  else {
    post_draw <- fit$post_mean_b
    tex_string <- "$b_{d,s}(t)$"
  }
  
  K <- max(data$block) - min(data$block) + 1
  xgrid <- seq(min(data$block), max(data$block), by = 1)
  
  post_mean <- reshape2::melt(apply(post_draw, 1:3, mean))
  post_quant <- reshape2::melt(apply(post_draw, 1:3, quantile, probs = 0.05))
  post_mean$Var1 <- rep(xgrid, nrow(post_mean)/length(xgrid))
  post_mean$low <- post_quant$value
  post_quant <- reshape2::melt(apply(post_draw, 1:3, quantile, probs = 0.95))
  post_mean$upp <- post_quant$value
  
  
  ggplot(post_mean) +
    geom_line(aes(x = Var1, y = value, col = factor(Var3)), size = 1.5) +
    geom_ribbon(aes(x = Var1, ymin = low, ymax = upp, fill = factor(Var3)), alpha = 0.4) +
    facet_wrap( ~ Var2, nrow = 2, ncol = 2) +
    labs(x = 'block', y = TeX(tex_string)) +
    scale_color_brewer(name = "", palette = "Set1") +
    scale_fill_brewer(name = "", palette = "Set1") +
    theme(plot.title = element_text(hjust = 0.5)) +
    scale_x_continuous(breaks = xgrid) +
    theme(legend.position = "none")
}
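
# Usage sketch (not run), with the same placeholder `data` and `fit` as above:
#
# plot_post_pars(data, fit, par = 'drift')     # population-level drift curves
# plot_post_pars(data, fit, par = 'boundary')  # population-level boundary curves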


#' Calculate WAIC
#'
#' Function to compute the Watanabe-Akaike information criterion (WAIC; Gelman,
#' Hwang and Vehtari, 2014), which estimates the expected out-of-sample 
#' prediction error using a bias-corrected adjustment of the within-sample error.
#'
#' @param model_fit results of a model fit from the lddmm function
#' @return A scalar indicating the WAIC (smaller WAIC denotes better fit)
compute_WAIC = function(model_fit) {
  
  # log pointwise predictive density 
  LPPD = sum(log(apply(exp(model_fit$loglik), 2, mean))) 
  
  # Penalty term
  p_WAIC = sum(apply(model_fit$loglik, 2, var))   
  
  WAIC = -2 * (LPPD - p_WAIC)
  
  return (WAIC) 
}
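
# In symbols, with log-likelihood draws l[m, i] for MCMC draw m and
# observation i (the rows and columns of model_fit$loglik):
#   lppd   = sum_i log( mean_m exp(l[m, i]) )
#   p_WAIC = sum_i var_m( l[m, i] )
#   WAIC   = -2 * (lppd - p_WAIC)
# Usage sketch (not run): `fit1` and `fit2` are placeholder fits of two
# competing models to the same data; the smaller WAIC is preferred.
#
# c(WAIC_1 = compute_WAIC(fit1), WAIC_2 = compute_WAIC(fit2))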



# Internal MCMC routine (no roxygen documentation, unlike the functions above):
# samples the LDDMM variant in which the boundary curves \b_{d}(t) depend on
# the decision only, while the drift curves \mu_{s,d}(t) are locally clustered
# across blocks.
LDDMM_fix_bound <- function(data, hypers, Niter = 5000, burnin = 2000, thin = 5){
  
  # Choose the number of knots (default is between the beginning and the end of 
  # the study, at every block)
  T_min <- min(data$block)
  T_max <- max(data$block)
  knots <- T_min:T_max
  K <- length(knots)
  P_smooth <- P_smooth1(K + 1)
  
  # Choose a fine grid in order to plot the curves resulting from the spline basis
  # expansion
  xgrid <- seq(T_min, T_max, by = 1)
  
  # Extract quantities of interest
  tau <- data$r_time
  ind <- data$subject
  time <- data$block
  cens <- data$cens
  D <- cbind(data$s, data$d)
  
  n_ind <- length(unique(ind))
  ind <- as.numeric(plyr::mapvalues(factor(ind),
                                    from = levels(factor(ind)),
                                    to = 1:n_ind))
  B <- B_basis(data$block, knots)
  Bgrid <- B_basis(xgrid, knots)
  
  samp_size <- (Niter - burnin)/thin # sample size
  p <- ncol(D) # number of covariates
  d_j <- rep(0, p) # Number of levels for each covariate
  for (j in 1:p){
    d_j[j] <- length(unique(D[!is.na(D[,j]),j]))
  }
  J <- ncol(B) # number of locations
  n <- nrow(B) # total number of observations
  n_ind <- length(unique(ind)) # number of individuals
  
  
  # Rescale the time steps in \{1, ..., T_max\}
  T_max <- max(time)
  T_min <- min(time)
  time <- time - T_min + 1
  T_max <- max(time)
  T_min <- min(time)
  idx_xy <- t(apply(expand.grid(y = 1:d_j[2], x = 1:d_j[1]), 1, rev))
  colnames(idx_xy) <- NULL
  
  
  # Set HB structures
  r_HB <- 1
  Z_max <- min(prod(d_j), 6)
  dim_HB <- sum((Z_max - 1)^(0:r_HB) * choose(d_j[2], 0:r_HB))
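  # dim_HB counts the label configurations within Hamming distance r_HB of a
  # given one: for each number m = 0, ..., r_HB of flipped positions there are
  # choose(d_j[2], m) position sets and (Z_max - 1)^m alternative labels each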
  
  # Set MCMC objects
  z <- array(NA, dim = c(prod(d_j), J, samp_size))
  post_mean_delta <- array(NA, dim = c(samp_size, d_j[1]))
  post_mean_mu <- array(NA, dim = c(nrow(Bgrid), d_j, samp_size))
  post_mean_b <- array(NA, dim = c(nrow(Bgrid), d_j, samp_size))
  post_ind_delta <- array(NA, dim = c(d_j[1], n_ind, samp_size))
  post_ind_mu <- array(NA, dim = c(nrow(Bgrid), n_ind, d_j, samp_size))
  post_ind_b <- array(NA, dim = c(nrow(Bgrid), n_ind, d_j, samp_size))
  sigma2_mu_us <- array(NA, dim = samp_size)
  sigma2_mu_ua <- array(NA, dim = samp_size)
  sigma2_b_us <- array(NA, dim = samp_size)
  sigma2_b_ua <- array(NA, dim = samp_size)
  sigma2_1_mu <- array(NA, dim = samp_size)
  sigma2_1_b <- array(NA, dim = samp_size)
  pred_time <- array(NA, dim = c(T_max, d_j[1], samp_size))
  pred_ans <- array(NA, dim = c(T_max, d_j[1], samp_size))
  pred_time_ind <- array(NA, dim = c(n_ind, T_max, d_j[1], samp_size))
  pred_ans_ind <- array(NA, dim = c(n_ind, T_max, d_j[1], samp_size))
  loglik_chain <- array(NA, dim = c(samp_size, n))
  
  # Set initial values
  delta_old <- array(NA, dim = c(d_j[1], n_ind))
  beta_mu_old <- array(NA, dim = c(J, d_j))
  beta_b_old <- beta_b_prop <- array(NA, dim = c(J, d_j[2]))
  delta_dat <- array(NA, dim = n)
  for (s_temp in 1:d_j[1]){
    for (i_temp in 1:n_ind){
      idx_temp <- which((D[,1] == s_temp) & (ind == i_temp))
      if (length(idx_temp) > 0){
        delta_old[s_temp,i_temp] <- min(tau[idx_temp])/2
        delta_dat[idx_temp] <- delta_old[s_temp,i_temp]
      }
      else{
        delta_old[s_temp,i_temp] <- 1E-3
      }
    }
    for(j in 1:d_j[2]){
      beta_mu_old[,s_temp,j] <- rep(0.5*log(mean(tau[D[,1] == s_temp])) - log(sd(tau[D[,1] == s_temp])), J)
      beta_b_old[,j] <- rep(1.5*log(mean(tau)) - log(sd(tau)), J)
    }
  }
  
  
  low_bound_mu <- min(beta_mu_old) - 1.5
  upp_bound_mu <- max(beta_mu_old) + 1
  low_bound_b <- min(beta_b_old) - 1.5
  upp_bound_b <- max(beta_b_old) + 1
  
  beta_mu_star_prop <- array(NA, dim = c(J, Z_max))
  beta_mu_u_old <- array(0, dim = c(n_ind, J))
  beta_b_u_old <- array(0, dim = c(n_ind, J))
  sigma2_1_mu_old <- 0.005
  sigma2_1_b_old <- 0.005
  sigma2_b_us_old <- 0.005
  sigma2_mu_us_old <- 0.005
  sigma2_b_ua_old <- 0.005
  sigma2_mu_ua_old <- 0.005
  
  
  # Message passing structures
  beta_mess <- array(NA, dim = c(J, dim_HB))
  z_old <- list()
  
  for (j in 1:d_j[1]){
    z_old[[j]] <- array(NA, dim = c(d_j[2], J))
    for (jj in 1:d_j[2]){
      if (j == jj){
        z_old[[j]][jj,] <- j
      }
      else{
        z_old[[j]][jj,] <- sample((d_j[1] + 1):Z_max, 1)
      }
    }
  }
  z_temp <- do.call(rbind, z_old)
  
  
  beta_mu_star_old <- array(NA, dim = c(J, Z_max))
  for (i in 1:Z_max){
    beta_mu_star_old[,i] <- beta_mu_old[1,idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],1],
                                        idx_xy[which(apply(z_temp == i, 1, prod) == 1)[1],2]] 
  }
  
  rand_mat <- array(rnorm(prod(d_j)), dim = d_j)
  idx_succ <- which(rand_mat == diag(rand_mat))
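  # idx_succ holds the linear indices of the diagonal of a d_j[1] x d_j[2]
  # array, i.e. the s == d ("success") combinations: comparing a random array
  # with its own recycled diagonal is true (almost surely) only on the diagonal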
  
  v_old <- array(NA, dim = c(d_j[2], J))
  z_prop <- list()
  
  # Transition dynamics objects
  alpha_S_old <- alpha_F_old <- 1
  Q_S_old <- array(NA, dim = c(Z_max, Z_max))
  Q_F_old <- array(NA, dim = c(Z_max, Z_max))
  pi_S_0 <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + 
                         table_int(z_temp[idx_succ,1], Z_max))
  pi_F_0 <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + 
                         table_int(z_temp[-idx_succ,1], Z_max))
  tr_count_S <- count_assign(z_temp[idx_succ,], Z_max)
  tr_count_F <- count_assign(z_temp[-idx_succ,], Z_max)
  for (h in 1:Z_max){
    Q_S_old[h,] <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + tr_count_S[h,])
    Q_F_old[h,] <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + tr_count_F[h,])
  }
  
  
  # Auxiliary variables
  prob_mat <- array(NA, dim = c(dim_HB, dim_HB))
  prob_vec <- array(NA, dim = dim_HB)
  
  # MH proposal parameters
  sd_MH_delta <- array(0.3, dim = c(d_j[1], n_ind))
  sd_MH_beta_mu <- 0.4
  sd_MH_beta_b <- 0.4
  sd_beta_mu_u <-  array(0.4, dim = c(n_ind, J))
  sd_beta_b_u <-  array(0.4, dim = c(n_ind, J))
  acc_delta <- array(0, dim = c(d_j[1], n_ind))
  acc_beta_mu <- array(0, dim = c(J, Z_max))
  acc_beta_mu_u <- array(0, dim = c(n_ind, J))
  acc_beta_b <- array(0, dim = c(J, d_j[2]))
  acc_beta_b_u <- array(0, dim = c(n_ind, J))
  n_batch <- 0
  
  # Auxiliary variables
  B_beta_mu_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_mu_u_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_b_dat <- array(0, dim = c(n, d_j[1]))
  B_beta_b_u_dat <- array(0, dim = c(n, d_j[1]))
  mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
  b_dat <- exp(B_beta_b_dat + B_beta_b_u_dat)
  
  
  # Gibbs Sampler
  it <- 1
  pb <- txtProgressBar(style = 3)
  for (iter in 1:Niter){
    
    # (0) Adaptively tune the MH variance for the proposals of \delta_{s, i}, 
    # beta_u_mu, beta_u_b
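    # Standard diminishing-adaptation scheme: every 20 iterations each proposal
    # sd is moved on the log scale by delta_n (shrinking in n_batch) towards an
    # empirical acceptance rate of 0.44, the usual univariate random-walk target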
    if (iter %% 20 == 0){
      n_batch <- n_batch + 1
      delta_n <- min(0.01, n_batch^(-0.5))
      
      for (i in 1:n_ind){
        for (x_temp in 1:d_j[1]){
          if (acc_delta[x_temp,i]/iter > 0.44){
            sd_MH_delta[x_temp,i] <- exp(log(sd_MH_delta[x_temp,i]) + delta_n)
          }
          else{
            sd_MH_delta[x_temp,i] <- exp(log(sd_MH_delta[x_temp,i]) - delta_n)
          }
        }
        for (k in 1:J){
          if (acc_beta_mu_u[i,k]/iter > 0.44){
            sd_beta_mu_u[i,k] <- exp(log(sd_beta_mu_u[i,k]) + delta_n)
          }
          else{
            sd_beta_mu_u[i,k] <- exp(log(sd_beta_mu_u[i,k]) - delta_n)
          }
          if (acc_beta_b_u[i,k]/iter > 0.44){
            sd_beta_b_u[i,k] <- exp(log(sd_beta_b_u[i,k]) + delta_n)
          }
          else{
            sd_beta_b_u[i,k] <- exp(log(sd_beta_b_u[i,k]) - delta_n)
          }
        }
      }
    }
    
    
    # (1) Update of the delta parameter: \delta_{s,i}: MH with log normal 
    #     proposal
    for (s_temp in 1:d_j[1]){
      for (i_temp in 1:n_ind){
        idx_temp <- which((D[,1] == s_temp) & (ind == i_temp))
        tau_temp <- tau[idx_temp]
        cens_temp <- cens[idx_temp]
        D_temp <- D[idx_temp,]
        
        # log-normal proposal distribution centered on the current value
        delta_prop <- exp(rnorm(1, log(delta_old[s_temp,i_temp]), sd_MH_delta[s_temp,i_temp]))
        while (delta_prop > min(tau[idx_temp])){
          delta_prop <- exp(rnorm(1, log(delta_old[s_temp,i_temp]), sd_MH_delta[s_temp,i_temp]))
        }
        loglik_prop <- log_likelihood(tau_temp, mu_dat[idx_temp,], 
                                      b_dat[idx_temp,], 
                                      rep(delta_prop, length(idx_temp)), 
                                      cens_temp, D_temp, TRUE)
        loglik_old <- log_likelihood(tau_temp, mu_dat[idx_temp,], 
                                     b_dat[idx_temp,], 
                                     rep(delta_old[s_temp,i_temp], length(idx_temp)), 
                                     cens_temp, D_temp, TRUE)
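        
        # The log(delta_prop) - log(delta_old) term below is the Hastings
        # correction for the log-normal random walk: the proposal density
        # ratio q(old | prop) / q(prop | old) equals delta_prop / delta_old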
        
        alpha_acc <- min(0, loglik_prop + log(delta_prop) -
                           loglik_old - log(delta_old[s_temp,i_temp]))
        l_u <- log(runif(1))
        
        if (l_u < alpha_acc){
          delta_old[s_temp,i_temp] <- delta_prop
          delta_dat[idx_temp] <- delta_old[s_temp,i_temp]
          acc_delta[s_temp,i_temp] <- acc_delta[s_temp,i_temp] + 1
        }
        
      }
    }
    
    
    # (2) Update of mu parameters: \mu_{x,y}^{(i)}(t)
    for (k in 1:J){ # loop over locations
      if (k == 1){ # only data at t = 1 influence the first coefficient
        idx_time <- T_min
      }
      else if (k == J){ # only data at t = T influence the last coefficient
        idx_time <- T_max
      }
      else { # data at t = {k-1,k} influence the kth coefficient
        idx_time <- (k-1):k
      }
      
      for (h in 1:Z_max){ # loop over possible latent values
        # tuples (combinations of covariates) that are clustered together via 
        # the latent h
        idx_cov <- matrix(idx_xy[which(z_temp[,k] == h),], length(which(z_temp[,k] == h)), p)
        
        X_1k <- unique(idx_cov[,1]) # all possible values of x in this cluster
        X_2k <- unique(idx_cov[,2]) # all possible values of y in this cluster
        
        if (length(X_1k) > 0){ # h \in \mathcal{Z}_{j,k}: posterior update
          # Pick data with covariate levels of x clustered in group h and 
          # at the correct locations
          idx_i <- which( (D[,1] %in% X_1k) & (time %in% idx_time) )
          tau_temp <- tau[idx_i]
          cens_temp <- cens[idx_i]
          D_temp <- D[which((D[,1] %in% X_1k) & (time %in% idx_time)),]
          
          
          # Normal proposal distribution centered on the current value
          if (k == 1){
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_mu_star_old[k,h], sd_MH_beta_mu)
          }
          # Normal proposal from the prior
          else if (k == J){
            beta_pre1 <- beta_mu_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_pre1, sqrt(sigma2_1_mu_old))
          }
          # Normal proposal from the prior
          else {
            beta_pre1 <- beta_mu_star_old[k-1, unique(z_temp[which(z_temp[,k] == h),k-1])]
            
            beta_mu_star_prop[,h] <- beta_mu_star_old[,h]
            beta_mu_star_prop[k,h] <- rnorm(1, beta_pre1, sqrt(sigma2_1_mu_old))
          }
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
          
          # Modify the proposed values in the corresponding positions
          for (hh in 1:nrow(idx_cov)){
            B_beta_mu_prop_dat[which(D_temp[,1] == idx_cov[hh,1]),idx_cov[hh,2]] <- 
              B[idx_i[which(D_temp[,1] == idx_cov[hh,1])],] %*% beta_mu_star_prop[,h]
          }
          
          # This is the proposed value for \mu_{x,y}^{(i)}(t)
          mu_prop_dat <- exp(B_beta_mu_prop_dat + B_beta_mu_u_dat[idx_i,])
          
          
          if (k == 1){
            beta_post1 <- beta_mu_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_dat[idx_i,], delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_prop[k,h] - beta_post1)^2)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_old[k,h] - beta_post1)^2)
          }
          else if (k == J){
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_dat[idx_i,], delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE)
          }
          else {
            beta_post1 <- beta_mu_star_old[k+1, unique(z_temp[which(z_temp[,k] == h),k+1])]
            
            logpost_prop <- log_likelihood(tau_temp, mu_prop_dat, 
                                           b_dat[idx_i,], delta_dat[idx_i], 
                                           cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_prop[k,h] - beta_post1)^2)
            logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                          b_dat[idx_i,], delta_dat[idx_i], 
                                          cens_temp, D_temp, TRUE) -
              0.5/sigma2_1_mu_old * sum((beta_mu_star_old[k,h] - beta_post1)^2)
          }
          
          alpha_acc <- min(0, logpost_prop - logpost_old)
          l_u <- log(runif(1))
          
          if (l_u < alpha_acc){
            beta_mu_star_old[k,h] <- beta_mu_star_prop[k,h]
            B_beta_mu_dat[idx_i,] <- B_beta_mu_prop_dat
            mu_dat[idx_i,] <- mu_prop_dat
            acc_beta_mu[k,h] <- acc_beta_mu[k,h] + 1
          }
        }
        else { # h \notin \mathcal{Z}_{1,k}: prior sampling
          beta_mu_star_old[k,h] <- runif(1, low_bound_mu, upp_bound_mu)
        }
      }
    }
    
    
    # (2) Update of b parameters: \b_{y}(t)
    for (k in 1:J){ # loop over locations
      if (k == 1){ # only data at t = 1 influence the first coefficient
        idx_time <- T_min
      }
      else if (k == J){ # only data at t = T influence the last coefficient
        idx_time <- T_max
      }
      else { # data at t = {k-1,k} influence the kth coefficient
        idx_time <- (k-1):k
      }
      
      # Pick data with decision level d at the relevant locations
      
      for (d in 1:d_j[2]){
        idx_i <- which( (D[,2] == d) & (time %in% idx_time) )
        tau_temp <- tau[idx_i]
        cens_temp <- cens[idx_i]
        D_temp <- D[which( (D[,2] == d) & (time %in% idx_time) ),]
        
        # Normal proposal distribution centered on the current value
        if (k == 1){
          beta_b_prop[k,] <- beta_b_old[k,]
          beta_b_prop[k,d] <- rnorm(1, beta_b_old[k,d], sd_MH_beta_b)
        }
        # Normal proposal from the prior
        else if (k == J){
          beta_b_prop[k,] <- beta_b_old[k,]
          beta_b_prop[k,d] <- rnorm(1, beta_b_old[k-1,d], sd_MH_beta_b)
        }
        # Normal proposal from the prior
        else {
          beta_b_prop[k,] <- beta_b_old[k,]
          beta_b_prop[k,d] <- rnorm(1, beta_b_old[k-1,d], sd_MH_beta_b)
        }
        B_beta_b_prop_dat <- B_beta_b_dat[idx_i,]
        
        # Modify the proposed values in the corresponding positions
        B_beta_b_prop_dat[,d] <- B[idx_i,] %*% beta_b_prop[,d]
        
        
        # This is the proposed value for \b_{y}(t)
        b_prop_dat <- exp(B_beta_b_prop_dat + B_beta_b_u_dat[idx_i,])
        
        if (k == 1){
          beta_post1 <- beta_b_old[k+1,d]
          
          logpost_prop <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                         b_prop_dat, delta_dat[idx_i], 
                                         cens_temp, D_temp, TRUE) -
            0.5/sigma2_1_b_old * sum((beta_b_prop[k,d] - beta_post1)^2)
          logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                        b_dat[idx_i,], delta_dat[idx_i], 
                                        cens_temp, D_temp, TRUE) -
            0.5/sigma2_1_b_old * sum((beta_b_old[k,d] - beta_post1)^2)
        }
        else if (k == J){
          logpost_prop <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                         b_prop_dat, delta_dat[idx_i], 
                                         cens_temp, D_temp, TRUE)
          logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                        b_dat[idx_i,], delta_dat[idx_i], 
                                        cens_temp, D_temp, TRUE)
        }
        else {
          beta_post1 <- beta_b_old[k+1,d]
          
          logpost_prop <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                         b_prop_dat, delta_dat[idx_i], 
                                         cens_temp, D_temp, TRUE) -
            0.5/sigma2_1_b_old * sum((beta_b_prop[k,d] - beta_post1)^2)
          logpost_old <- log_likelihood(tau_temp, mu_dat[idx_i,], 
                                        b_dat[idx_i,], delta_dat[idx_i], 
                                        cens_temp, D_temp, TRUE) -
            0.5/sigma2_1_b_old * sum((beta_b_old[k,d] - beta_post1)^2)
        }
        
        alpha_acc <- min(0, logpost_prop - logpost_old)
        l_u <- log(runif(1))
        
        if (l_u < alpha_acc){
          beta_b_old[k,d] <- beta_b_prop[k,d]
          # Store the whole proposed block (only column d differs from the
          # current value); assigning just column d into the full submatrix
          # would recycle it across all columns
          B_beta_b_dat[idx_i,] <- B_beta_b_prop_dat
          b_dat[idx_i,] <- b_prop_dat
          acc_beta_b[k,d] <- acc_beta_b[k,d] + 1
        }
      }
    }
    
    
    # (3) Update the cluster assignments
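    # Hamming-ball sampler: auxiliary configurations v are drawn from the
    # Hamming ball around the current labels (uniformly at k = J, weighted by
    # g_HB otherwise), the state space is restricted to the balls around v,
    # and the labels are resampled within that slice by backward message
    # passing (k = J-1, ..., 1) followed by forward sampling (k = 1, ..., J)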
    for (x_temp in 1:d_j[1]){ # loop over stimuli
      beta_mess <- array(-Inf, dim = c(J, dim_HB))
      beta_mess[J,] <- 1/dim_HB
      
      v_old[,J] <- H_ball_unif(z_old[[x_temp]][,J], S = Z_max, r = r_HB)
      z_prop[[J]] <- H_ball(v_old[,J], S = Z_max, r = r_HB)
      for (k in (J - 1):1){
        idx_i <- which( (time == k) & (D[,1] == x_temp) )
        tau_temp <- tau[idx_i]
        cens_temp <- cens[idx_i]
        D_temp <- D[idx_i,]
        
        # (i) Sample the auxiliary variables
        v_temp <- H_ball(z_old[[x_temp]][,k], S = Z_max, r = r_HB)
        
        probs <- rep(-Inf, dim_HB)
        for (h in 1:dim_HB){
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
          
          B_beta_mu_prop_dat <- 0.5 * (beta_mu_star_old[k,v_temp[,h]] +
                                         beta_mu_star_old[k+1,z_old[[x_temp]][,k+1]])
          
          mu_dat_prop <- exp(t(B_beta_mu_prop_dat + t(B_beta_mu_u_dat[idx_i,])))
          
          probs[h] <- g_HB(log_likelihood(tau_temp, mu_dat_prop, b_dat[idx_i,],
                                          delta_dat[idx_i], cens_temp, D_temp, TRUE))
        }
        probs <- as.numeric(normalise_log(probs))
        
        v_old[,k] <- v_temp[,sample(1:dim_HB, 1, prob = probs)]
        z_prop[[k]] <- H_ball(v_old[,k], S = Z_max, r = r_HB)
        
        
        # (ii) Pass messages backwards only in the restricted state space given
        #      by the slice
        z_kp1_temp <- which(beta_mess[k+1,] > 0)
        prob_mat <- array(-Inf, dim = c(dim_HB, dim_HB))
        
        for (h1 in z_kp1_temp){
          for (h2 in 1:dim_HB){
            
            B_beta_mu_prop_dat <- B_beta_mu_dat[idx_i,]
            B_beta_mu_prop_dat_1 <- 0.5 * (beta_mu_star_old[k,z_prop[[k]][,h2]] +
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])
            B_beta_mu_prop_dat_2 <- 0.5 * (beta_mu_star_old[k,v_old[,k]] +
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])
            
            mu_dat_prop_1 <- exp(t(B_beta_mu_prop_dat_1 + t(B_beta_mu_u_dat[idx_i,])))
            mu_dat_prop_2 <- exp(t(B_beta_mu_prop_dat_2 + t(B_beta_mu_u_dat[idx_i,])))
            
            prob_mat[h2,h1] <- log(beta_mess[k+1,h1]) -
              0.5 / sigma2_1_mu_old * sum((beta_mu_star_old[k,z_prop[[k]][,h2]] -
                                             beta_mu_star_old[k+1,z_prop[[k+1]][,h1]])^2) +
              log_likelihood(tau_temp, mu_dat_prop_1, b_dat[idx_i,],
                             delta_dat[idx_i], cens_temp, D_temp, TRUE) +
              g_HB(log_likelihood(tau_temp, mu_dat_prop_2, b_dat[idx_i,],
                                  delta_dat[idx_i], cens_temp, D_temp, TRUE)) +
              sum(log(Q_F_old[cbind(z_prop[[k]][-x_temp,h2],z_prop[[k+1]][-x_temp,h1])])) +
              log(Q_S_old[z_prop[[k]][x_temp,h2],z_prop[[k+1]][x_temp,h1]])
          }
        }
        if ( sum(is.infinite(sum_rows_log(prob_mat))) == dim_HB){
          beta_mess[k,] <- 1/dim_HB
        }
        else{
          beta_mess[k,] <- as.numeric(sum_rows_log(prob_mat))
          beta_mess[k,] <- as.numeric(normalise_log(beta_mess[k,]))
        }
      }
      
      
      # (iii) Sample states forward (only on allowed states)
      idx_fail <- (1:d_j[2])[-x_temp]
      # Sample z_1
      prob_vec <- log(beta_mess[1,]) + log(pi_S_0[z_prop[[1]][x_temp,]]) +
        colSums(matrix(log(pi_F_0[z_prop[[1]][-x_temp,]]), d_j[2] - 1, dim_HB))
      prob_vec <- as.numeric(normalise_log(prob_vec))
      
      idx_samp <- sample(1:dim_HB, 1, FALSE, prob_vec)
      z_old[[x_temp]][,1] <- z_prop[[1]][,idx_samp]
      
      # Sample z_k
      for (k in 2:J){
        idx_km1 <- which( (time == k - 1) & (D[,1] == x_temp) )
        tau_temp <- tau[idx_km1]
        cens_temp <- cens[idx_km1]
        D_temp <- D[idx_km1,]
        
        prob_vec <- log(beta_mess[k,]) + 
          log(Q_S_old[cbind(z_old[[x_temp]][x_temp,k-1], z_prop[[k]][x_temp,])])
        for (kkk in idx_fail){
          prob_vec <- prob_vec + log(Q_F_old[cbind(z_old[[x_temp]][kkk,k-1], z_prop[[k]][kkk,])])
        }
        
        for (z_k_temp in which(is.finite(prob_vec))){
          B_beta_mu_prop_dat <- B_beta_mu_dat[idx_km1,]
          
          B_beta_mu_prop_dat_1 <- 0.5 * (beta_mu_star_old[k-1,z_old[[x_temp]][,k-1]] +
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])
          B_beta_mu_prop_dat_2 <- 0.5 * (beta_mu_star_old[k-1,v_old[,k-1]] +
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])
          
          mu_dat_prop_1 <- exp(t(B_beta_mu_prop_dat_1 + t(B_beta_mu_u_dat[idx_km1,])))
          mu_dat_prop_2 <- exp(t(B_beta_mu_prop_dat_2 + t(B_beta_mu_u_dat[idx_km1,])))
          
          prob_vec[z_k_temp] <- prob_vec[z_k_temp] -
            0.5 / sigma2_1_mu_old * sum((beta_mu_star_old[k-1,z_old[[x_temp]][,k-1]] -
                                           beta_mu_star_old[k,z_prop[[k]][,z_k_temp]])^2) +
            log_likelihood(tau_temp, mu_dat_prop_1, b_dat[idx_km1,],
                           delta_dat[idx_km1], cens_temp, D_temp, TRUE) +
            g_HB(log_likelihood(tau_temp, mu_dat_prop_2, b_dat[idx_km1,],
                                delta_dat[idx_km1], cens_temp, D_temp, TRUE))
        }
        prob_vec <- as.numeric(normalise_log(prob_vec))
        
        idx_samp <- sample(1:dim_HB, 1, FALSE, prob_vec)
        z_old[[x_temp]][,k] <- z_prop[[k]][,idx_samp]
      }
      
      
      # (4) Assign the cluster specific curves f_{\mu}
      for (y_temp in 1:d_j[2]){
        beta_mu_old[,x_temp,y_temp] <- beta_mu_star_old[cbind(1:J, z_old[[x_temp]][y_temp,])]
      }
      B_beta_mu_dat[(D[,1] == x_temp),] <- B[D[,1] == x_temp,] %*% beta_mu_old[,x_temp,]
    }
    z_temp <- do.call(rbind, z_old)
    mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
    
    
    
    # (5) Update the transition probabilities
    pi_S_0 <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + table_int(z_temp[idx_succ,1], Z_max))
    pi_F_0 <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + table_int(z_temp[-idx_succ,1], Z_max))
    tr_count_S <- count_assign(z_temp[idx_succ,], Z_max)
    tr_count_F <- count_assign(z_temp[-idx_succ,], Z_max)
    for (h in 1:Z_max){
      Q_S_old[h,] <- rdirichlet(rep(alpha_S_old / Z_max, Z_max) + 
                                  tr_count_S[h,])
      Q_F_old[h,] <- rdirichlet(rep(alpha_F_old / Z_max, Z_max) + 
                                  tr_count_F[h,])
    }
    
    
    alpha_S_prop <- exp(rnorm(1, log(alpha_S_old), 0.1))
    while ((alpha_S_prop < 0.01) | (alpha_S_prop > 10)){
      alpha_S_prop <- exp(rnorm(1, log(alpha_S_old), 0.1))
    }
    # Gamma(1, 1) prior plus the Dirichlet likelihood of the transition rows;
    # log(alpha_S_prop) - log(alpha_S_old) is the Hastings correction for the
    # log-normal random walk (as in the delta and sigma2 updates)
    alpha_acc <- dgamma(alpha_S_prop, 1, 1, log = T) + log(alpha_S_prop) +
      Z_max * lgamma(alpha_S_prop) -
      Z_max^2 * lgamma(alpha_S_prop/Z_max) +
      (alpha_S_prop/Z_max - 1) * sum(log(Q_S_old)) - (
        dgamma(alpha_S_old, 1, 1, log = T) + log(alpha_S_old) +
          Z_max * lgamma(alpha_S_old) -
          Z_max^2 * lgamma(alpha_S_old/Z_max) +
          (alpha_S_old/Z_max - 1) * sum(log(Q_S_old)))
    
    l_u <- log(runif(1))
    if (!is.na(alpha_acc)){
      if (l_u < alpha_acc){
        alpha_S_old <- alpha_S_prop
      }
    }
    
    alpha_F_prop <- exp(rnorm(1, log(alpha_F_old), 0.1))
    while ((alpha_F_prop < 0.01) | (alpha_F_prop > 10)){
      alpha_F_prop <- exp(rnorm(1, log(alpha_F_old), 0.1))
    }
    alpha_acc <- dgamma(alpha_F_prop, 1, 1, log = T) + log(alpha_F_prop) +
      Z_max * lgamma(alpha_F_prop) -
      Z_max^2 * lgamma(alpha_F_prop/Z_max) +
      (alpha_F_prop/Z_max - 1) * sum(log(Q_F_old)) - (
        dgamma(alpha_F_old, 1, 1, log = T) + log(alpha_F_old) +
          Z_max * lgamma(alpha_F_old) -
          Z_max^2 * lgamma(alpha_F_old/Z_max) +
          (alpha_F_old/Z_max - 1) * sum(log(Q_F_old)))
    
    l_u <- log(runif(1))
    if (!is.na(alpha_acc)){
      if (l_u < alpha_acc){
        alpha_F_old <- alpha_F_prop
      }
    }
    
    
    # (6) Correction term for random effects.
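    # Center the subject-specific spline coefficients across subjects and move
    # the subtracted means into the population-level coefficients: the additive
    # decomposition is only identified up to this shift, and the relocation
    # leaves every linear predictor (hence the likelihood) unchanged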
    corr_mu <- colMeans(beta_mu_u_old)
    corr_b <- colMeans(beta_b_u_old)
    beta_mu_u_old <- t(t(beta_mu_u_old) - corr_mu)
    beta_b_u_old <- t(t(beta_b_u_old) - corr_b)
    
    for (k in 1:J){
      if (k == 1){
        idx_time <- which(time == T_min)
      }
      else if (k == J){
        idx_time <- which(time == T_max)
      }
      else {
        idx_time <- which(time %in% (k-1):k)
      }
      beta_mu_old[k,,] <- beta_mu_old[k,,] + corr_mu[k]
      beta_b_old[k,] <- beta_b_old[k,] + corr_b[k]
      beta_mu_star_old[k,] <- beta_mu_star_old[k,] + corr_mu[k]
      B_beta_mu_dat[idx_time,] <- B_beta_mu_dat[idx_time,] + 
        as.numeric(B[idx_time,] %*% corr_mu)
      B_beta_mu_u_dat[idx_time,] <- B_beta_mu_u_dat[idx_time,] - 
        as.numeric(B[idx_time,] %*% corr_mu)
      B_beta_b_dat[idx_time,] <- B_beta_b_dat[idx_time,] + 
        as.numeric(B[idx_time,] %*% corr_b)
      B_beta_b_u_dat[idx_time,] <- B_beta_b_u_dat[idx_time,] -
        as.numeric(B[idx_time,] %*% corr_b)
    }
    
    
    # (7) Update the cluster specific smoothness parameters
    RSS_mu <- RSS_b <- 0
    for (h2 in 1:d_j[2]){
      RSS_b <- RSS_b + as.numeric(crossprod(beta_b_old[,h2], P_smooth) %*% 
                                    beta_b_old[,h2])
      for (h1 in 1:d_j[1]){
        RSS_mu <- RSS_mu + as.numeric(crossprod(beta_mu_old[,h1,h2], P_smooth) %*% 
                                        beta_mu_old[,h1,h2])
      }
    }
    
    sigma2_1_mu_temp <- exp(rnorm(1, log(sigma2_1_mu_old), 0.1))
    while(sigma2_1_mu_temp > 0.2){
      sigma2_1_mu_temp <- exp(rnorm(1, log(sigma2_1_mu_old), 0.1))
    }
    l_alpha <- min(c(0, log(sigma2_1_mu_temp) - 
                       0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_mu_temp) - 
                       0.5/sigma2_1_mu_temp * RSS_mu + 
                       dhalfcauhy(sqrt(sigma2_1_mu_temp), hypers$s_sigma_mu, T) -
                       (log(sigma2_1_mu_old) - 
                          0.5 * (prod(d_j)*(J-1)) * log(sigma2_1_mu_old) -
                          0.5/sigma2_1_mu_old * RSS_mu +
                          dhalfcauhy(sqrt(sigma2_1_mu_old), hypers$s_sigma_mu, T))))
    l_u <- log(runif(1))
    if (l_u < l_alpha){
      sigma2_1_mu_old <- sigma2_1_mu_temp
    }
    
    sigma2_1_b_temp <- exp(rnorm(1, log(sigma2_1_b_old), 0.1))
    while(sigma2_1_b_temp > 0.2){
      sigma2_1_b_temp <- exp(rnorm(1, log(sigma2_1_b_old), 0.1))
    }
    l_alpha <- min(c(0, log(sigma2_1_b_temp) - 
                       0.5 * (d_j[2]*(J-1)) * log(sigma2_1_b_temp) - 
                       0.5/sigma2_1_b_temp * RSS_b + 
                       dhalfcauhy(sqrt(sigma2_1_b_temp), hypers$s_sigma_b, T) -
                       (log(sigma2_1_b_old) - 
                          0.5 * (d_j[2]*(J-1)) * log(sigma2_1_b_old) -
                          0.5/sigma2_1_b_old * RSS_b +
                          dhalfcauhy(sqrt(sigma2_1_b_old), hypers$s_sigma_b, T))))
    l_u <- log(runif(1))
    if (l_u < l_alpha){
      sigma2_1_b_old <- sigma2_1_b_temp
    }
    
    
    # (8a) Update random effects for b
    reff <- sample_reff_b(tau, D, cens, beta_b_u_old, delta_dat, B_beta_b_dat,
                          b_dat, mu_dat, B, P_smooth, ind, time,
                          sigma2_b_us_old, sigma2_b_ua_old, sd_beta_b_u, 
                          acc_beta_b_u)
    beta_b_u_old <- reff$beta_u_old
    B_beta_b_u_dat <- reff$B_beta_u_dat
    b_dat <- exp(B_beta_b_dat + B_beta_b_u_dat)
    acc_beta_b_u <- reff$acc_beta_u
    
    
    # (8b) Update random effects variances: MH with log normal proposal
    ls_var <- sample_smooth_var(sigma2_b_ua_old, sigma2_b_us_old,
                                beta_b_u_old, P_smooth, n_ind)
    sigma2_b_ua_old <- ls_var$sigma2_ua_old
    sigma2_b_us_old <- ls_var$sigma2_us_old
    
    
    # (9a) Update random effects for mu
    reff <- sample_reff_mu(tau, D, cens, beta_mu_u_old, delta_dat, b_dat,
                           B_beta_mu_dat, mu_dat, B, P_smooth, ind, time,
                           sigma2_mu_us_old, sigma2_mu_ua_old, sd_beta_mu_u,
                           acc_beta_mu_u)
    beta_mu_u_old <- reff$beta_u_old
    B_beta_mu_u_dat <- reff$B_beta_u_dat
    mu_dat <- exp(B_beta_mu_dat + B_beta_mu_u_dat)
    acc_beta_mu_u <- reff$acc_beta_u
    
    
    # (9b) Update random effects variances: MH with log normal proposal
    ls_var <- sample_smooth_var(sigma2_mu_ua_old, sigma2_mu_us_old,
                                beta_mu_u_old, P_smooth, n_ind)
    sigma2_mu_ua_old <- ls_var$sigma2_ua_old
    sigma2_mu_us_old <- ls_var$sigma2_us_old
    
    
    
    # After burnin, save parameters in the chain
    if ( (iter > burnin) && (iter %% thin == 0) ){
      # This is the correction for the random effects integration: we need to 
      # compute the variance of the random effects as detailed in the Supplementary
      # Materials
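      # For Gaussian subject effects u(t) on the log scale, E[exp(u(t))] =
      # exp(0.5 * Var(u(t))): cov_reff_* is the covariance of the random-effect
      # coefficients implied by the current variance parameters, and
      # corr_term_* the pointwise variance of u(t) it induces on the grid, so
      # adding 0.5 * corr_term_* below gives marginal means over subjects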
      cov_reff_mu <- solve(diag(J) / sigma2_mu_ua_old + P_smooth / sigma2_mu_us_old)
      corr_term_gr_mu <- rep(0, nrow(Bgrid))
      corr_term_mu <- rep(0, T_max)
      cov_reff_b <- solve(diag(J) / sigma2_b_ua_old + P_smooth / sigma2_b_us_old)
      corr_term_gr_b <- rep(0, nrow(Bgrid))
      corr_term_b <- rep(0, T_max)
      for (k in 1:J){
        corr_term_gr_mu <- corr_term_gr_mu + rowSums(Bgrid[,k] * t(t(Bgrid) * cov_reff_mu[,k]))
        corr_term_mu <- corr_term_mu + rowSums(B_basis(1:T_max, knots)[,k] * 
                                                 t(t(B_basis(1:T_max, knots)) * cov_reff_mu[,k]))
        corr_term_gr_b <- corr_term_gr_b + rowSums(Bgrid[,k] * t(t(Bgrid) * cov_reff_b[,k]))
        corr_term_b <- corr_term_b + rowSums(B_basis(1:T_max, knots)[,k] * 
                                               t(t(B_basis(1:T_max, knots)) * cov_reff_b[,k]))
      }
      
      
      
      for (h1 in 1:d_j[1]){
        for (h2 in 1:d_j[2]){
          post_mean_mu[,h1,h2,it] <- exp(Bgrid %*% beta_mu_old[,h1,h2] + 0.5 * corr_term_gr_mu)
          post_mean_b[,h1,h2,it] <- exp(Bgrid %*% beta_b_old[,h2] + 0.5 * corr_term_gr_b)
          
          for (i in 1:n_ind){
            post_ind_mu[,i,h1,h2,it] <- exp(Bgrid %*% beta_mu_old[,h1,h2] + Bgrid %*% beta_mu_u_old[i,])
            post_ind_b[,i,h1,h2,it] <- exp(Bgrid %*% beta_b_old[,h2] + Bgrid %*% beta_b_u_old[i,])
          }
        }
      }
      post_ind_delta[,,it] <- delta_old
      
      
      # Sample from the predictive distributions of response category and 
      # response times (so we can use predictive checks to evaluate goodness of fit)
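      # Race formulation of the drift-diffusion model: each of the d_j[2]
      # accumulators hits its boundary after an inverse Gaussian first-passage
      # time with mean b/mu and shape b^2; the predicted choice is the first
      # accumulator to finish, and the predicted RT adds the offset delta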
      for (t in 1:T_max){
        for (x_temp in 1:d_j[1]){
          mu_temp <- as.numeric(exp(B_basis(t, knots) %*% beta_mu_old[,x_temp,] + 
                                      0.5 * corr_term_mu[t]))
          b_temp <- as.numeric(exp(B_basis(t, knots) %*% beta_b_old + 
                                     0.5 * corr_term_b[t]))
          delta_temp <- mean(delta_old[x_temp,])
          pred_temp <- delta_temp + LaplacesDemon::rinvgaussian(d_j[2], b_temp/mu_temp, 
                                                                b_temp^2)
          pred_ans[t,x_temp,it] <- which.min(pred_temp)
          pred_time[t,x_temp,it] <- min(pred_temp)
          
          for (i in 1:n_ind){
            mu_temp <- as.numeric(exp(B_basis(t, knots) %*% (beta_mu_old[,x_temp,] + 
                                                               beta_mu_u_old[i,])))
            b_temp <- as.numeric(exp(B_basis(t, knots) %*% (beta_b_old + 
                                                              beta_b_u_old[i,])))
            delta_temp <- delta_old[x_temp,i]
            pred_temp <- delta_temp + LaplacesDemon::rinvgaussian(d_j[2], b_temp/mu_temp, 
                                                                  b_temp^2)
            pred_ans_ind[i,t,x_temp,it] <- which.min(pred_temp)
            pred_time_ind[i,t,x_temp,it] <- min(pred_temp)
          }
        }
      }
      
      
      # Save MCMC objects
      post_mean_delta[it,] <- rowMeans(delta_old)
      sigma2_mu_us[it] <- sigma2_mu_us_old
      sigma2_mu_ua[it] <- sigma2_mu_ua_old
      sigma2_b_us[it] <- sigma2_b_us_old
      sigma2_b_ua[it] <- sigma2_b_ua_old
      sigma2_1_mu[it] <- sigma2_1_mu_old
      sigma2_1_b[it] <- sigma2_1_b_old
      loglik_chain[it,] <- log_likelihood_ind(tau, mu_dat, b_dat, delta_dat, cens, D)
      # z[,,it] <- z_temp		
      
      it <- it + 1
    }
    setTxtProgressBar(pb, iter/Niter)
  }
  close(pb)
  
  return(list(
    # 'Z' = z, 
    'post_mean_delta' = post_mean_delta, 
    'post_mean_mu' = post_mean_mu,
    'post_mean_b' = post_mean_b,
    'post_ind_delta' = post_ind_delta,
    'post_ind_mu' = post_ind_mu,
    'post_ind_b' = post_ind_b,
    'sigma2_mu_us' = sigma2_mu_us, 
    'sigma2_mu_ua' = sigma2_mu_ua,
    'sigma2_b_us' = sigma2_b_us, 
    'sigma2_b_ua' = sigma2_b_ua,
    'sigma2_1_mu' = sigma2_1_mu, 
    'sigma2_1_b' = sigma2_1_b, 
    'pred_ans' = pred_ans, 
    'pred_time' = pred_time,
    'pred_ans_ind' = pred_ans_ind, 
    'pred_time_ind' = pred_time_ind,
    'loglik' = loglik_chain
  ))
}
