R/HZINB_ind_two_gamma.R

Defines functions HZINB_ind_two_gamma grid_HZINB_two_gamma

Documented in grid_HZINB_two_gamma HZINB_ind_two_gamma

#' HGZIPS - HZINB with two gamma components (assuming independence)
#'
#' This \code{HZINB_ind_two_gamma} function finds hyperparameter estimates by implementing the Expectation-Maximization (EM) algorithm and hierarchical zero-inflated negative binomial model with two gamma components.
#'
#' @name HZINB_ind_two_gamma
#' @import pscl
#' @import stats
#' @import emdbook
#'
#' @param grid_a1 alpha1 value grid
#' @param grid_b1 beta1 value grid
#' @param grid_a2 alpha2 value grid
#' @param grid_b2 beta2 value grid
#' @param grid_pi pi value grid
#' @param grid_omega omega value grid
#' @param init_pi_k1 initial probability of each alpha1 value for implementing the EM algorithm
#' @param init_pi_l1 initial probability of each beta1 value for implementing the EM algorithm
#' @param init_pi_k2 initial probability of each alpha2 value for implementing the EM algorithm
#' @param init_pi_l2 initial probability of each beta2 value for implementing the EM algorithm
#' @param init_pi_m initial probability of each pi value for implementing the EM algorithm
#' @param init_pi_h initial probability of each omega value for implementing the EM algorithm
#' @param dataset a list of squashed datasets that include N_ij, E_ij and weights for each drug (j). This dataset list can be generated by the rawProcessing function in this package.
#' @param iteration number of EM algorithm iterations to run
#' @param Loglik whether to return the loglikelihood of each iteration or not (TRUE or FALSE)
#' @param zeroes A logical scalar specifying if zero counts should be included.
#' @param N_star the minimum Nij count size to be used for hyperparameter estimation. If zeroes are included in Nij vector, please set N_star = NULL
#'


# +-x +-x +-x +-x +-x +-x +-x +-x
#  assuming independence
# +-x +-x +-x +-x +-x +-x +-x +-x





#' @rdname HZINB_ind_two_gamma
#' @param a_j,b_j numeric vectors of candidate alpha and beta values. The alpha
#'   and beta grids are log-spaced, starting at the minimum and stepping towards
#'   (never reaching) the 99th percentile of the values.
#' @param omega_j numeric vector of candidate omega values. The omega grid is
#'   log-spaced strictly between its minimum and maximum.
#' @param K1,L1,K2,L2,M,H number of grid points for alpha1, beta1, alpha2,
#'   beta2, pi and omega respectively.
#' @return \code{grid_HZINB_two_gamma} a data frame with columns \code{a1_j},
#' \code{b1_j}, \code{a2_j}, \code{b2_j}, \code{pi_j} and \code{omega_j} and
#' \code{max(K1, L1, K2, L2, M, H)} rows. Each column is filled from the top
#' with its grid values and padded with \code{NA}. The \code{pi_j} values are
#' \code{seq(0.001, 0.99, by = 1/M)}.
#' @export
#'
grid_HZINB_two_gamma <- function(a_j, b_j, omega_j, K1, L1, K2, L2, M, H) {
  grid <- as.data.frame(matrix(NA, max(K1, L1, K2, L2, M, H), 6))
  colnames(grid) <- c("a1_j", "b1_j", "a2_j", "b2_j", "pi_j", "omega_j")

  # Points lo * (hi / lo)^(steps / denom), computed in log space.
  log_spaced <- function(lo, hi, steps, denom) {
    exp(log(lo) + steps / denom * (log(hi) - log(lo)))
  }

  # Grid end points are loop-invariant, so compute them once.
  a_min <- a_j[which.min(a_j)]
  b_min <- b_j[which.min(b_j)]
  a_q99 <- quantile(a_j, 0.99, names = FALSE, na.rm = TRUE)
  b_q99 <- quantile(b_j, 0.99, names = FALSE, na.rm = TRUE)

  # omega: H interior points strictly between min(omega_j) and max(omega_j).
  grid[seq_len(H), 6] <- log_spaced(omega_j[which.min(omega_j)],
                                    omega_j[which.max(omega_j)],
                                    seq_len(H), H + 1)

  # alpha/beta: start at the minimum and step towards the 99th percentile.
  grid[seq_len(K1), 1] <- log_spaced(a_min, a_q99, seq_len(K1) - 1, K1 + 1)
  grid[seq_len(L1), 2] <- log_spaced(b_min, b_q99, seq_len(L1) - 1, L1 + 1)
  grid[seq_len(K2), 3] <- log_spaced(a_min, a_q99, seq_len(K2) - 1, K2 + 1)
  grid[seq_len(L2), 4] <- log_spaced(b_min, b_q99, seq_len(L2) - 1, L2 + 1)

  # pi: fill only as many rows as seq() produces and leave the rest NA, like
  # the other columns. Assigning to the whole column would recycle the values,
  # or fail when the row count is not a multiple of their length.
  pi_vals <- seq(0.001, 0.99, by = 1 / M)
  grid[seq_along(pi_vals), 5] <- pi_vals

  grid
}


#' HZINB_independence
#' @rdname HZINB_ind_two_gamma
#' @return \code{HZINB_ind_two_gamma} a list with one probability matrix per
#' hyperparameter grid: \code{pi_K1} (alpha1), \code{pi_L1} (beta1),
#' \code{pi_K2} (alpha2), \code{pi_L2} (beta2), \code{pi_pi} (pi) and
#' \code{pi_omega} (omega). Row 1 holds the initial probabilities and row
#' \code{i + 1} the estimate after EM iteration \code{i}. A matrix with a
#' single row or column is returned as a vector. \code{Loglik} holds the
#' log-likelihood evaluated at each row's probabilities when
#' \code{Loglik = TRUE}, and is \code{NULL} otherwise.
#' @export
#'
HZINB_ind_two_gamma <- function(grid_a1, grid_a2, grid_b1, grid_b2, grid_pi,
                                grid_omega, init_pi_k1, init_pi_l1,
                                init_pi_k2, init_pi_l2, init_pi_m, init_pi_h,
                                dataset, iteration, Loglik = FALSE,
                                zeroes = FALSE, N_star = 1) {
  stopifnot(length(dataset) > 0)

  # Drop counts below N_star; N_star = NULL keeps every row (zero-inflated fit).
  if (!is.null(N_star)) {
    dataset <- lapply(dataset, function(d) {
      d[which(d$N >= N_star), , drop = FALSE]
    })
  }

  # The six independent hyperparameter factors, in one fixed order shared by
  # `grids`, `inits`, `combos`, `idx` and the returned matrices.
  grids <- list(grid_a1, grid_b1, grid_a2, grid_b2, grid_pi, grid_omega)
  inits <- list(init_pi_k1, init_pi_l1, init_pi_k2, init_pi_l2, init_pi_m,
                init_pi_h)
  sizes <- lengths(grids)
  n_rows <- iteration + 1

  # Row m of `combos` holds the hyperparameter values of combination m, and
  # row m of `idx` holds their positions in the grids. expand.grid enumerates
  # both in the same order.
  combos <- expand.grid(a1 = grid_a1, b1 = grid_b1, a2 = grid_a2,
                        b2 = grid_b2, p = grid_pi, omega = grid_omega)
  idx <- expand.grid(lapply(sizes, seq_len))

  # loglik[m, j]: weighted log-likelihood of drug j under combination m.
  # It does not depend on the prior, so it is computed once.
  loglik <- .hzinb_two_gamma_loglik(combos, dataset, zeroes, N_star)

  probs <- lapply(seq_along(grids), function(f) {
    mat <- matrix(NA_real_, n_rows, sizes[f])
    mat[1, ] <- inits[[f]]
    mat
  })
  llh <- if (Loglik) rep(NA_real_, n_rows) else NULL

  for (i in seq_len(n_rows)) {
    # Log prior of each combination: product of its six factor probabilities.
    log_prior <- Reduce(`+`, lapply(seq_along(grids), function(f) {
      log(probs[[f]][i, idx[[f]]])
    }))
    log_joint <- log_prior + loglik

    # Per-drug log marginal likelihood; also the E-step normaliser.
    log_den <- apply(log_joint, 2, .hzinb_lse)
    if (Loglik) {
      llh[i] <- sum(log_den)
    }
    # The last returned row only needs its log-likelihood, not an update.
    if (i == n_rows) {
      break
    }

    # E-step: log posterior of each combination per drug, summed over drugs.
    log_post <- sweep(log_joint, 2, log_den)
    log_resp <- apply(log_post, 1, .hzinb_lse)
    overall <- .hzinb_lse(log_resp)

    # M-step: each factor's new probabilities are the normalised marginals of
    # the summed posterior. Levels without posterior mass get 0.
    for (f in seq_along(grids)) {
      by_level <- split(log_resp, factor(idx[[f]], levels = seq_len(sizes[f])))
      probs[[f]][i + 1, ] <- exp(vapply(by_level, .hzinb_lse, numeric(1)) -
                                   overall)
    }
  }

  names(probs) <- c("pi_K1", "pi_L1", "pi_K2", "pi_L2", "pi_pi", "pi_omega")
  # `[rows, ]` drops a single row or column to a vector, so callers keep
  # seeing the same shapes as before.
  probs <- lapply(probs, function(mat) mat[seq_len(n_rows), ])

  # Element order is kept backward compatible: the truncated fit used to end
  # with pi_omega and the zero-inflated fit with Loglik. Both now carry both.
  if (zeroes) {
    c(probs[1:5], list(Loglik = llh), probs[6])
  } else {
    c(probs, list(Loglik = llh))
  }
}

# Weighted log-likelihood of every drug under every hyperparameter combination.
# `combos` has columns a1, b1, a2, b2, p, omega. Each dataset element must
# provide N, E and weight. Returns a nrow(combos) x length(dataset) matrix.
# zeroes = FALSE: two-component NB mixture, each component truncated to
#   N >= N_star (no truncation when N_star is NULL).
# zeroes = TRUE: zero-inflated mixture, where P(0) gains extra mass omega.
.hzinb_two_gamma_loglik <- function(combos, dataset, zeroes, N_star) {
  out <- matrix(NA_real_, nrow(combos), length(dataset))
  truncate <- !zeroes && !is.null(N_star)

  for (j in seq_along(dataset)) {
    n <- dataset[[j]]$N
    e <- dataset[[j]]$E
    w <- dataset[[j]]$weight
    zero_idx <- which(n == 0)

    out[, j] <- vapply(seq_len(nrow(combos)), function(m) {
      a1 <- combos$a1[m]
      a2 <- combos$a2[m]
      prob1 <- combos$b1[m] / (e + combos$b1[m])
      prob2 <- combos$b2[m] / (e + combos$b2[m])
      nb1 <- dnbinom(n, size = a1, prob = prob1, log = TRUE)
      nb2 <- dnbinom(n, size = a2, prob = prob2, log = TRUE)

      if (truncate) {
        # Renormalise each component by log P(N >= N_star).
        nb1 <- nb1 - pnbinom(N_star - 1, size = a1, prob = prob1,
                             lower.tail = FALSE, log.p = TRUE)
        nb2 <- nb2 - pnbinom(N_star - 1, size = a2, prob = prob2,
                             lower.tail = FALSE, log.p = TRUE)
      }

      p <- combos$p[m]
      ll <- .hzinb_log_add_exp(log(p) + nb1, log1p(-p) + nb2)

      if (zeroes) {
        omega <- combos$omega[m]
        # log((1 - omega) * mixture), plus log(omega) mass on zero counts.
        ll <- log1p(-omega) + ll
        ll[zero_idx] <- .hzinb_log_add_exp(log(omega), ll[zero_idx])
      }

      sum(w * ll)
    }, numeric(1))
  }

  out
}

# Numerically stable log(sum(exp(x))). NA/NaN entries are ignored; an empty
# input gives -Inf, and an all -Inf input gives -Inf rather than NaN.
.hzinb_lse <- function(x) {
  x <- x[!is.na(x)]
  if (length(x) == 0L) {
    return(-Inf)
  }
  hi <- max(x)
  if (!is.finite(hi)) {
    return(hi)
  }
  hi + log(sum(exp(x - hi)))
}

# Element-wise log(exp(a) + exp(b)); returns -Inf (not NaN) where both are -Inf.
.hzinb_log_add_exp <- function(a, b) {
  hi <- pmax(a, b)
  out <- hi + log1p(exp(-abs(a - b)))
  out[which(hi == -Inf)] <- -Inf
  out
}
sidiwang/hgzips documentation built on Jan. 19, 2021, 4:09 p.m.