# R/BFDA_VB.R

# Log-determinant helper used by the Gaussian entropy terms of the lower bound.
#
# A matrix is passed to determinant(), which works on the log scale directly
# instead of computing log(det(x)); note its modulus is log|det(x)| (the sign
# is discarded) and it keeps determinant()'s "logarithm" attribute. Any
# non-matrix input (e.g. a 1 x 1 covariance extracted as a plain number) is
# treated as the determinant itself and simply logged.
ldet <- function(x) {
  if (!is.matrix(x)) {
    log(x)
  } else {
    determinant(x, logarithm = TRUE)[["modulus"]]
  }
}


#' Variational Bayes clustering of functional data with a B-spline basis
#'
#' Clusters the rows of `X` (one curve per row) with a mixture of B-spline
#' regressions fitted by variational Bayes. Each of the at most `H` components
#' has spline coefficients with a Gaussian prior of precision `lambda`; the
#' mixture weights follow a truncated stick-breaking prior whose h-th stick is
#' Beta(1 - sigma, alpha + h * sigma); the residual precision is shared by all
#' components and has a Gamma(a_sigma, b_sigma) (shape, rate) prior.
#'
#' @param X Numeric matrix (or object accepted by `as.matrix()`), one curve per
#'   row and one time point per column. `NA` entries are treated as missing.
#' @param H Maximum number of mixture components.
#' @param time Observation times, one per column of `X`. When `NULL`,
#'   `1:NCOL(X)` is used and a warning is issued.
#' @param alpha,sigma Stick-breaking prior parameters.
#' @param lambda Prior precision of the spline coefficients.
#' @param a_sigma,b_sigma Shape and rate of the Gamma prior on the residual
#'   precision.
#' @param inner_knots Number of equally spaced knots placed over `range(time)`.
#' @param degree Degree of the B-spline basis.
#' @param maxiter Maximum number of variational iterations.
#' @param tol The algorithm stops once the lower bound increases by less than
#'   `tol` between two iterations.
#' @param clust_init Number of groups used by the initial pre-clustering
#'   (`FSplines_clust()`); must be strictly smaller than `H`.
#' @param verbose If `TRUE`, print progress information.
#' @return An object of class `"BFDA_VB"`, a list with elements
#'   \describe{
#'     \item{mu_tilde, Sigma_tilde}{Variational means (`H x K`) and
#'       covariances (`H x K x K`) of the spline coefficients, `K` being the
#'       number of basis functions.}
#'     \item{a_tilde, b_tilde}{Shape and rate of the variational Gamma
#'       distribution of the residual precision.}
#'     \item{pred}{`n x NCOL(X)` fitted curves: component curves averaged with
#'       weights `rho`.}
#'     \item{predH}{`H x NCOL(X)` fitted curve of each component.}
#'     \item{cluster}{Most probable component of each curve, relabelled to
#'       consecutive integers.}
#'     \item{rho}{`n x H` matrix of allocation probabilities.}
#'     \item{lowerbound}{Variational lower bound (up to additive constants) at
#'       the last accepted iteration.}
#'   }
#' @export
BayesFda_VB <- function(X, H = 5, time = NULL,
                        alpha = 1, sigma = 0,
                        lambda = 1, a_sigma = 1, b_sigma = 1,
                        inner_knots = min(round(NCOL(X) / 4), 40), degree = 3,
                        maxiter = 1000, tol = 1e-7, clust_init = 3,
                        verbose = FALSE) {

  if (clust_init >= H) stop("The clust_init value is greater than the upper bound H")

  # Fixed quantities
  n <- NROW(X)
  TT <- NCOL(X)

  if (is.null(time)) {
    time <- 1:TT
    warning("The time dimension was not supplied and equally-spaced intervals are assumed.")
  }
  if (length(time) != TT) {
    stop("The length of time and the dimension of X must coincide.")
  }

  colnames(X) <- time
  X <- as.matrix(X) # If not already done, convert it into a matrix
  index_not_NA <- !is.na(X)
  nTT <- sum(index_not_NA)           # Number of observed values
  n_subject <- rowSums(index_not_NA) # Number of observed values per curve

  # B-spline basis evaluated at `time`: `inner_knots` equally spaced knots over
  # range(time), extended by `degree` extra knots on each side.
  xl <- min(time)
  xr <- max(time)
  dx <- (xr - xl) / (inner_knots - 1)
  knots <- seq(xl - degree * dx, xr + degree * dx, by = dx)
  B <- spline.des(knots, time, degree + 1, 0 * time, outer.ok = TRUE)$design
  K <- NCOL(B)

  # Vectorization: observed values stacked curve by curve, together with the
  # matching rows of the basis (one copy of B per curve, NA positions dropped).
  # Built by indexing rather than rbind() in a loop: the loop is quadratic in n
  # and, with `2:n`, stacks too many copies when n == 1. `drop = FALSE` keeps
  # BB a matrix even if a single value is observed.
  Y <- c(t(X))[c(t(index_not_NA))]
  BB <- B[rep(seq_len(TT), times = n), , drop = FALSE]
  BB <- BB[c(t(index_not_NA)), , drop = FALSE]

  # Prior precision of the coefficients is lambda * DtD, with DtD = I.
  DtD <- diag(K)

  # Diagonal of B (I / lambda) B': contribution of the prior covariance to the
  # expected squared residuals of a component with no allocated mass.
  traceBSigmaB <- rowSums(B %*% diag(1 / lambda, K) * B)

  verbose_step <- 1

  predH <- matrix(0, H, TT)
  E_residuals <- array(0, c(H, n, TT)) # E[(x_it - b_t' beta_h)^2]: component x curve x time
  mu_tilde <- matrix(0, H, K)
  Sigma_tilde <- array(0, c(H, K, K))
  lower3 <- lower5 <- numeric(H)
  lowerbound <- -Inf
  converged <- FALSE
  r <- 0L
  h_seq <- seq_len(H - 1) # Indices of the H - 1 free sticks

  # Initialization settings
  if (verbose) cat("Pre-allocating observations into groups...\n")
  pre_clust <- FSplines_clust(X = X, H = clust_init, time = time, lambda = lambda,
                              a_sigma = a_sigma, b_sigma = b_sigma,
                              inner_knots = inner_knots, degree = degree, dif = 0,
                              verbose = FALSE, prediction = TRUE)
  # Relabel the pre-clusters by decreasing size (1 = largest group).
  G <- as.factor(pre_clust$cluster)
  G <- factor(as.numeric(factor(G, levels(G)[order(table(G), decreasing = TRUE)])),
              levels = 1:clust_init)

  # Initial allocation probabilities: one-hot pre-clustering padded with the
  # H - clust_init empty components, then smoothed and renormalised.
  rho <- cbind(model.matrix(rep(1, n) ~ G - 1), matrix(0, n, H - clust_init))
  # NOTE(review): BayesBFda_VB smooths with `+ 1 / H`, whereas this adds 1,
  # which makes the starting point much closer to uniform. Confirm intent.
  rho <- rho + 1
  rho <- rho / rowSums(rho)
  sums_rho <- colSums(rho)

  # Rough starting value for the residual precision, taken from the
  # pre-clustering fit (the allocation probabilities are not used here).
  a_sigma2_tilde <- a_sigma + nTT / 2
  b_sigma2_tilde <- b_sigma + sum((X - pre_clust$prediction)^2, na.rm = TRUE) / 2
  E_tau <- a_sigma2_tilde / b_sigma2_tilde # Expected value of tau

  if (verbose) cat("Starting the Variational Bayes algorithm...\n")
  for (r in seq_len(maxiter)) {

    # q(beta_h): Gaussian update of each component's coefficients
    for (h in 1:H) {
      if (sum(rho[, h]) < 1e-12) {
        # No allocated mass: q(beta_h) stays at the prior N(0, I / lambda)
        Sigma_tilde[h, , ] <- diag(1 / lambda, K)
        mu_tilde[h, ] <- 0
        predH[h, ] <- 0
        E_residuals[h, , ] <- t(t(X)^2 + traceBSigmaB)
      } else {
        weights <- rep(rho[, h], n_subject) # Weight of each observed value
        Sigma_tilde[h, , ] <- solve(E_tau * crossprod(BB * sqrt(weights)) + lambda * DtD)
        mu_tilde[h, ] <- Sigma_tilde[h, , ] %*% crossprod(BB * weights * E_tau, Y)

        # Predictions (the same for elements in the same cluster)
        predH[h, ] <- as.numeric(B %*% mu_tilde[h, ])

        # E[(x - b' beta)^2] = (x - b' mu)^2 + b' Sigma b, for all curves at once;
        # t(X) has one curve per column, so length-TT vectors recycle correctly.
        var_h <- rowSums(B %*% Sigma_tilde[h, , ] * B)
        E_residuals[h, , ] <- t((t(X) - predH[h, ])^2 + var_h)
      }
    }

    # q(v_h): Beta updates of the stick-breaking weights, then E[log pi_h]
    tail_mass <- rev(cumsum(rev(sums_rho)))[h_seq + 1] # sum of sums_rho[l] over l > h
    v_alpha <- 1 - sigma + sums_rho[h_seq]
    v_beta <- alpha + sigma * h_seq + tail_mass
    E_logv <- digamma(v_alpha) - digamma(v_alpha + v_beta)
    E_log1v <- digamma(v_beta) - digamma(v_alpha + v_beta)
    E_logp <- c(E_logv, 0) + c(0, cumsum(E_log1v))

    # q(tau): Gamma update of the shared residual precision
    res_sums <- rowSums(E_residuals, na.rm = TRUE, dims = 2) # H x n, summed over time
    a_sigma2_tilde <- a_sigma + 0.5 * nTT
    b_sigma2_tilde <- b_sigma + 0.5 * sum(rho * t(res_sums))
    E_tau <- a_sigma2_tilde / b_sigma2_tilde               # Expected value of tau
    E_ltau <- digamma(a_sigma2_tilde) - log(b_sigma2_tilde) # Expected value of its logarithm

    # q(z): allocation probabilities (the offset keeps log(rho) finite)
    rho <- rho_update(X, E_logp, E_residuals, E_tau)
    rho <- rho + 1e-16
    rho <- rho / rowSums(rho)
    sums_rho <- colSums(rho)

    # Lower bound, up to additive constants.
    # Likelihood plus E[log p(z | v)]
    lower1 <- sum(sums_rho * E_logp) + 0.5 * nTT * E_ltau -
      0.5 * E_tau * sum(rho * t(res_sums))
    # Prior and variational terms for tau
    lower4 <- -b_sigma * E_tau + (a_sigma - 1) * E_ltau
    lower7 <- a_sigma2_tilde * log(b_sigma2_tilde) - lgamma(a_sigma2_tilde) -
      b_sigma2_tilde * E_tau + (a_sigma2_tilde - 1) * E_ltau
    # Variational term for the discrete allocations
    lower8 <- sum(rho * log(rho))
    # Prior and variational terms for the sticks
    lower2 <- -sigma * E_logv + (alpha + sigma * h_seq - 1) * E_log1v -
      lbeta(1 - sigma, alpha + sigma * h_seq)
    lower6 <- (v_alpha - 1) * E_logv + (v_beta - 1) * E_log1v -
      lbeta(v_alpha, v_beta)
    # Prior and variational terms for the coefficients
    for (h in 1:H) {
      lower3[h] <- -0.5 * lambda *
        (sum(mu_tilde[h, ]^2) + sum(diag(Sigma_tilde[h, , ])))
      lower5[h] <- -0.5 * ldet(Sigma_tilde[h, , ])
    }

    lowerbound_new <- lower1 + sum(lower2) + sum(lower3) + lower4 -
      sum(lower5) - sum(lower6) - lower7 - lower8

    # Break the loop once the lower bound stops increasing by at least `tol`
    if (lowerbound_new - lowerbound < tol) {
      converged <- TRUE
      if (verbose) {
        cat(paste("Convergence reached after", r, "iterations."), "\n", sep = "")
      }
      break
    }
    lowerbound <- lowerbound_new

    if (verbose && r %% verbose_step == 0) {
      cat(paste("Lower-bound: ", round(lowerbound, 7), ", iteration: ", r, "\n", sep = ""))
    }
  }

  # Track convergence explicitly: `r == maxiter` alone would also flag a run
  # that converged on its very last iteration.
  if (!converged) {
    warning(paste("Convergence has not been reached after", r, "iterations."))
  }

  # Most probable component of each curve, relabelled to 1, 2, ... in
  # increasing order of the occupied components.
  cluster <- as.numeric(factor(apply(rho, 1, which.max)))
  # Individual curves: rho-weighted average of the component curves
  pred <- unname(rho %*% predH)

  structure(
    list(mu_tilde = mu_tilde, Sigma_tilde = Sigma_tilde,
         a_tilde = a_sigma2_tilde, b_tilde = b_sigma2_tilde, pred = pred,
         predH = predH, cluster = cluster, rho = rho, lowerbound = lowerbound),
    class = "BFDA_VB"
  )
}

#' Variational Bayes clustering of functional data with a user-supplied basis
#'
#' Same model as `BayesFda_VB()`, but the basis matrix `B` (evaluated at
#' `time`) is supplied by the caller instead of being built internally.
#' Each of the at most `H` components has coefficients with a Gaussian prior
#' of precision `lambda`; the mixture weights follow a truncated
#' stick-breaking prior whose h-th stick is Beta(1 - sigma, alpha + h * sigma);
#' the residual precision is shared and has a Gamma(a_sigma, b_sigma)
#' (shape, rate) prior.
#'
#' @param X Numeric matrix (or object accepted by `as.matrix()`), one curve per
#'   row and one time point per column. `NA` entries are treated as missing.
#' @param H Maximum number of mixture components.
#' @param B Basis matrix with one row per time point (`length(time)` rows) and
#'   one column per basis function. Required.
#' @param time Observation times, one per column of `X`. When `NULL`,
#'   `1:NCOL(X)` is used and a warning is issued.
#' @param alpha,sigma Stick-breaking prior parameters.
#' @param lambda Prior precision of the basis coefficients.
#' @param a_sigma,b_sigma Shape and rate of the Gamma prior on the residual
#'   precision.
#' @param maxiter Maximum number of variational iterations.
#' @param tol The algorithm stops once the lower bound increases by less than
#'   `tol` between two iterations.
#' @param clust_init Number of groups used by the initial pre-clustering
#'   (`FB_clust()`); must be strictly smaller than `H`.
#' @param verbose If `TRUE`, print progress information.
#' @return An object of class `"BFDA_VB"` with the same elements as the value
#'   of `BayesFda_VB()`: `mu_tilde`, `Sigma_tilde`, `a_tilde`, `b_tilde`,
#'   `pred`, `predH`, `cluster`, `rho` and `lowerbound`.
#' @export
BayesBFda_VB <- function(X, H = 5, B = NULL, time = NULL,
                         alpha = 1, sigma = 0,
                         lambda = 1, a_sigma = 1, b_sigma = 1,
                         maxiter = 1000, tol = 1e-7, clust_init = 3,
                         verbose = FALSE) {

  if (clust_init >= H) stop("The clust_init value is greater than the upper bound H")

  # Fixed quantities
  n <- NROW(X)
  TT <- NCOL(X)

  if (is.null(time)) {
    time <- 1:TT
    warning("The time dimension was not supplied and equally-spaced intervals are assumed.")
  }
  if (length(time) != TT) {
    stop("The length of time and the dimension of X must coincide.")
  }

  # Without this check the default B = NULL fails below with the obscure
  # "argument is of length zero" (nrow(NULL) is NULL).
  if (is.null(B)) {
    stop("The basis matrix B must be supplied.")
  }
  B <- as.matrix(B)
  if (length(time) != nrow(B)) {
    stop("The length of time and the dimension of B must coincide.")
  }
  K <- NCOL(B)

  colnames(X) <- time
  X <- as.matrix(X) # If not already done, convert it into a matrix
  index_not_NA <- !is.na(X)
  nTT <- sum(index_not_NA)           # Number of observed values
  n_subject <- rowSums(index_not_NA) # Number of observed values per curve

  # Vectorization: observed values stacked curve by curve, together with the
  # matching rows of the basis (one copy of B per curve, NA positions dropped).
  # Built by indexing rather than rbind() in a loop: the loop is quadratic in n
  # and, with `2:n`, stacks too many copies when n == 1. `drop = FALSE` keeps
  # BB a matrix even if a single value is observed.
  Y <- c(t(X))[c(t(index_not_NA))]
  BB <- B[rep(seq_len(TT), times = n), , drop = FALSE]
  BB <- BB[c(t(index_not_NA)), , drop = FALSE]

  # Prior precision of the coefficients is lambda * DtD, with DtD = I.
  DtD <- diag(K)

  # Diagonal of B (I / lambda) B': contribution of the prior covariance to the
  # expected squared residuals of a component with no allocated mass.
  traceBSigmaB <- rowSums(B %*% diag(1 / lambda, K) * B)

  verbose_step <- 1

  predH <- matrix(0, H, TT)
  E_residuals <- array(0, c(H, n, TT)) # E[(x_it - b_t' beta_h)^2]: component x curve x time
  mu_tilde <- matrix(0, H, K)
  Sigma_tilde <- array(0, c(H, K, K))
  lower3 <- lower5 <- numeric(H)
  lowerbound <- -Inf
  converged <- FALSE
  r <- 0L
  h_seq <- seq_len(H - 1) # Indices of the H - 1 free sticks

  # Initialization settings
  if (verbose) cat("Pre-allocating observations into groups...\n")
  pre_clust <- FB_clust(X = X, H = clust_init, B = B, time = time,
                        verbose = FALSE, prediction = TRUE)
  # Relabel the pre-clusters by decreasing size (1 = largest group).
  G <- as.factor(pre_clust$cluster)
  G <- factor(as.numeric(factor(G, levels(G)[order(table(G), decreasing = TRUE)])),
              levels = 1:clust_init)

  # Initial allocation probabilities: one-hot pre-clustering padded with the
  # H - clust_init empty components, smoothed by 1/H and renormalised.
  rho <- cbind(model.matrix(rep(1, n) ~ G - 1), matrix(0, n, H - clust_init))
  rho <- rho + 1 / H
  rho <- rho / rowSums(rho)
  sums_rho <- colSums(rho)

  # Rough starting value for the residual precision, taken from the
  # pre-clustering fit (the allocation probabilities are not used here).
  a_sigma2_tilde <- a_sigma + nTT / 2
  b_sigma2_tilde <- b_sigma + sum((X - pre_clust$prediction)^2, na.rm = TRUE) / 2
  E_tau <- a_sigma2_tilde / b_sigma2_tilde # Expected value of tau

  if (verbose) cat("Starting the Variational Bayes algorithm...\n")
  for (r in seq_len(maxiter)) {

    # q(beta_h): Gaussian update of each component's coefficients
    for (h in 1:H) {
      if (sum(rho[, h]) < 1e-12) {
        # No allocated mass: q(beta_h) stays at the prior N(0, I / lambda)
        Sigma_tilde[h, , ] <- diag(1 / lambda, K)
        mu_tilde[h, ] <- 0
        predH[h, ] <- 0
        E_residuals[h, , ] <- t(t(X)^2 + traceBSigmaB)
      } else {
        weights <- rep(rho[, h], n_subject) # Weight of each observed value
        Sigma_tilde[h, , ] <- solve(E_tau * crossprod(BB * sqrt(weights)) + lambda * DtD)
        mu_tilde[h, ] <- Sigma_tilde[h, , ] %*% crossprod(BB * weights * E_tau, Y)

        # Predictions (the same for elements in the same cluster)
        predH[h, ] <- as.numeric(B %*% mu_tilde[h, ])

        # E[(x - b' beta)^2] = (x - b' mu)^2 + b' Sigma b, for all curves at once;
        # t(X) has one curve per column, so length-TT vectors recycle correctly.
        var_h <- rowSums(B %*% Sigma_tilde[h, , ] * B)
        E_residuals[h, , ] <- t((t(X) - predH[h, ])^2 + var_h)
      }
    }

    # q(v_h): Beta updates of the stick-breaking weights, then E[log pi_h]
    tail_mass <- rev(cumsum(rev(sums_rho)))[h_seq + 1] # sum of sums_rho[l] over l > h
    v_alpha <- 1 - sigma + sums_rho[h_seq]
    v_beta <- alpha + sigma * h_seq + tail_mass
    E_logv <- digamma(v_alpha) - digamma(v_alpha + v_beta)
    E_log1v <- digamma(v_beta) - digamma(v_alpha + v_beta)
    E_logp <- c(E_logv, 0) + c(0, cumsum(E_log1v))

    # q(tau): Gamma update of the shared residual precision
    res_sums <- rowSums(E_residuals, na.rm = TRUE, dims = 2) # H x n, summed over time
    a_sigma2_tilde <- a_sigma + 0.5 * nTT
    b_sigma2_tilde <- b_sigma + 0.5 * sum(rho * t(res_sums))
    E_tau <- a_sigma2_tilde / b_sigma2_tilde               # Expected value of tau
    E_ltau <- digamma(a_sigma2_tilde) - log(b_sigma2_tilde) # Expected value of its logarithm

    # q(z): allocation probabilities (the offset keeps log(rho) finite)
    rho <- rho_update(X, E_logp, E_residuals, E_tau)
    rho <- rho + 1e-16
    rho <- rho / rowSums(rho)
    sums_rho <- colSums(rho)

    # Lower bound, up to additive constants.
    # Likelihood plus E[log p(z | v)]
    lower1 <- sum(sums_rho * E_logp) + 0.5 * nTT * E_ltau -
      0.5 * E_tau * sum(rho * t(res_sums))
    # Prior and variational terms for tau
    lower4 <- -b_sigma * E_tau + (a_sigma - 1) * E_ltau
    lower7 <- a_sigma2_tilde * log(b_sigma2_tilde) - lgamma(a_sigma2_tilde) -
      b_sigma2_tilde * E_tau + (a_sigma2_tilde - 1) * E_ltau
    # Variational term for the discrete allocations
    lower8 <- sum(rho * log(rho))
    # Prior and variational terms for the sticks
    lower2 <- -sigma * E_logv + (alpha + sigma * h_seq - 1) * E_log1v -
      lbeta(1 - sigma, alpha + sigma * h_seq)
    lower6 <- (v_alpha - 1) * E_logv + (v_beta - 1) * E_log1v -
      lbeta(v_alpha, v_beta)
    # Prior and variational terms for the coefficients
    for (h in 1:H) {
      lower3[h] <- -0.5 * lambda *
        (sum(mu_tilde[h, ]^2) + sum(diag(Sigma_tilde[h, , ])))
      lower5[h] <- -0.5 * ldet(Sigma_tilde[h, , ])
    }

    lowerbound_new <- lower1 + sum(lower2) + sum(lower3) + lower4 -
      sum(lower5) - sum(lower6) - lower7 - lower8

    # Break the loop once the lower bound stops increasing by at least `tol`
    if (lowerbound_new - lowerbound < tol) {
      converged <- TRUE
      if (verbose) {
        cat(paste("Convergence reached after", r, "iterations."), "\n", sep = "")
      }
      break
    }
    lowerbound <- lowerbound_new

    if (verbose && r %% verbose_step == 0) {
      cat(paste("Lower-bound: ", round(lowerbound, 7), ", iteration: ", r, "\n", sep = ""))
    }
  }

  # Track convergence explicitly: `r == maxiter` alone would also flag a run
  # that converged on its very last iteration.
  if (!converged) {
    warning(paste("Convergence has not been reached after", r, "iterations."))
  }

  # Most probable component of each curve, relabelled to 1, 2, ... in
  # increasing order of the occupied components.
  cluster <- as.numeric(factor(apply(rho, 1, which.max)))
  # Individual curves: rho-weighted average of the component curves
  pred <- unname(rho %*% predH)

  structure(
    list(mu_tilde = mu_tilde, Sigma_tilde = Sigma_tilde,
         a_tilde = a_sigma2_tilde, b_tilde = b_sigma2_tilde, pred = pred,
         predH = predH, cluster = cluster, rho = rho, lowerbound = lowerbound),
    class = "BFDA_VB"
  )
}
# tommasorigon/BayesFda documentation built on May 8, 2019, 3:14 a.m.