# R/EM_halfsib.R

# Defines functions EM_fit.halfsibdata EM_fit

#' Generic for EM-based fitting
#'
#' S3 generic that dispatches on the class of `data`.
#'
#' @param data Object whose class selects the method.
#' @param ... Further arguments passed on to the selected method.
#' @export
EM_fit <- function(data, ...) {
  # `UseMethod()` dispatches on the first argument (`data`) by default.
  UseMethod("EM_fit")
}

#' Iteratively fit the covariance components of a `halfsibdata` object
#'
#' Each iteration (1) computes conditional covariances and means given the
#' current covariance estimates, (2) fills in the data with `balance()`,
#' (3) builds the M-matrices from `ss_mats.halfsibdata()` plus correction
#' terms weighted by the number of missing individuals per dam, and
#' (4) re-estimates the three covariance matrices with `stepreml_2way_mat()`.
#' Iteration stops once `mat_err()` between two successive estimates is
#' below `err.tol`, or after `max_iter` iterations.
#'
#' @param data A `halfsibdata` object. If `is.dam_balanced(data)` is `FALSE`
#'   it is first passed through `include_unobs_dams()`.
#' @param prior_covs Named list with elements `ind`, `dam` and `sire` giving
#'   the starting covariance matrices. Each defaults to the `q x q` identity
#'   matrix, with `q = data$dims$q`.
#' @param method One of `"ML"`, `"REML"` or `"ML_nofix"`. With `"ML"` the
#'   mean is initialised from the per-dam means and updated every iteration;
#'   with `"REML"` the mean is taken from `cond_mean_reml()` and the REML
#'   variants of the conditional covariances are used; with `"ML_nofix"` the
#'   mean is held at zero. Both ML variants call `stepreml_2way_mat()` with
#'   `method = "ML"`.
#' @param flat_sire If `TRUE`, the sire-level matrix is dropped from the
#'   returned list. Defaults to `TRUE` exactly when `method` is `"ML_nofix"`.
#' @param max_iter Maximum number of iterations; a single finite number of at
#'   least 1.
#' @param err.tol Convergence tolerance compared against `mat_err()`.
#' @param verbose If `TRUE`, print the iteration number and error from the
#'   second iteration onwards.
#'
#' @return A list of covariance matrices named by `data$level_names`
#'   (individual, dam and, unless `flat_sire` is `TRUE`, sire level), each
#'   carrying the dimnames of the individual-level M-matrix.
#' @export
EM_fit.halfsibdata <- function(data,
                               prior_covs = list(
                                 ind  = diag(1, nrow = data$dims$q),
                                 dam  = diag(1, nrow = data$dims$q),
                                 sire = diag(1, nrow = data$dims$q)
                               ),
                               method     = c("ML", "REML", "ML_nofix"),
                               flat_sire  = (method == "ML_nofix"),
                               max_iter   = 1000,
                               err.tol    = 1e-6,
                               verbose    = FALSE) {

  # Must run before `flat_sire` is first evaluated: its default compares the
  # already-matched scalar `method`.
  method <- match.arg(method)

  # `1:max_iter` would iterate over c(1, 0) for `max_iter = 0`, so reject
  # anything that is not a usable iteration count up front.
  if (!is.numeric(max_iter) || length(max_iter) != 1L ||
      !is.finite(max_iter) || max_iter < 1) {
    stop("`max_iter` must be a single finite number >= 1", call. = FALSE)
  }

  if(!is.dam_balanced(data)) {
    data <- include_unobs_dams(data)
  }
  
  I <- data$dims$I
  J <- data$dims$J
  K <- data$dims$K
  
  # Number of individuals missing per dam (indexed by dam name below).
  n_missing <- K - data$n.observed$inds

  if(method == "ML") {
    # Might have 0 observed individuals per dam, in which case the mean should be 0.
    mu <- colMeans(data$dam_sums / pmax(data$n.observed$inds, 1))
  } else {
    mu <- rep(0, data$dims$q)
  }

  # Loop invariants: `data` is not modified inside the iteration.
  # NOTE(review): `sire_names` is read from the *names* of `data$sires`, while
  # `dam_names` is keyed by the *values* of `data$sires`; the loops below
  # index `dam_names[[sire]]`. Confirm this matches the layout of
  # `data$sires`.
  sire_names <- names(data$sires)
  dam_names  <- split(names(data$sires), data$sires)

  # Both ML variants use the "ML" update in `stepreml_2way_mat()`.
  if (method %in% c("ML", "ML_nofix")) {
    E_method <- "ML"
  } else {
    E_method <- method
  }
  
  for(iter in seq_len(floor(max_iter))) {

    # Conditional covariances given the current covariance estimates.
    ccov_raw  <- cond_cov_counts(prior_covs, data)
    ccov      <- cond_cov(ccov_raw, data)
    
    if(method == "REML") {
      ccov_reml <- cond_cov_reml(prior_covs, ccov_raw, ccov, data)
      cmean     <- cond_mean_reml(prior_covs, ccov_reml, data)
      mu        <- cmean$grand
    } else {
      cmean <- cond_mean(prior_covs, ccov, data, prior_mean = mu)
    }

    balanced_data <- balance(data, cmean, globmean = mu)

    if(method == "ML") {
      mu <- colMeans(balanced_data$dam_sums) / data$dims$K
    }
    
    # Naive sum-of-squares M-matrices
    ss_base <- ss_mats.halfsibdata(balanced_data)

    # Correction terms proportional to the current individual-level
    # covariance and the total number of missing individuals.
    M_E <- ss_base$m_ind + sum(n_missing) * (1 - 1/K) * prior_covs$ind * 1/(I * J * (K-1))
    M_B <- ss_base$m_dam + sum(n_missing) * (1 - 1/J) * prior_covs$ind * 1/(I * (J-1) * K)
    M_A <- ss_base$m_sire + sum(n_missing) * (1 - 1/I) * prior_covs$ind * 1/((I-1) * J * K)

    # `ccomp()` sums the conditional-covariance terms for a pair of dams of
    # one sire; `ccomp_reml()` does the same for dams of two (possibly
    # different) sires and is only defined and used for REML.
    if(method == "REML") {
      ccomp_reml <- function(sire1, sire2, dam1, dam2) {
        ccov_reml("group", "group", NA, NA) +
          ccov_reml("group", sire2, NA, "group") + ccov_reml(sire1, "group", "group", NA) +
          ccov_reml("group", sire2, NA, dam2) + ccov_reml(sire1, "group", dam1, NA) +
          ccov_reml(sire1, sire2, "group", "group") +
          ccov_reml(sire1, sire2, "group", dam2) + ccov_reml(sire1, sire2, dam1, "group") +
          ccov_reml(sire1, sire2, dam1, dam2)
      }

      ccomp <- function(sire, dam1, dam2) {
        ccov_reml(sire, sire, dam1, dam2)
      }
    } else {
      ccomp <- function(sire, dam1, dam2) {
        ccov(sire, "group", "group") +
          ccov(sire, "group", dam2) + ccov(sire, dam1, "group") +
          ccov(sire, dam1, dam2)
      }
    }
    
    # Within-sire corrections to all three M-matrices.
    for(sire in sire_names) {
      for(dam in dam_names[[sire]]) {
        M_E <- M_E +
          n_missing[[dam]] * (1 - n_missing[[dam]]/K) * ccomp(sire, dam, dam) *
          1/(I * J * (K-1))
        
        M_B <- M_B +
          n_missing[[dam]]^2 * ccomp(sire, dam, dam) *
          1/K * 1 /(I * (J-1)) 

        for(dam2 in dam_names[[sire]]) {
          M_B <- M_B -
            n_missing[[dam]] * n_missing[[dam2]] * ccomp(sire, dam, dam2) *
            1/(J * K) * 1/(I * (J-1))

          # For REML the `1/I` part is handled by the cross-sire loop below,
          # so the factor reduces to 1 here.
          M_A <- M_A +
            n_missing[[dam]] * n_missing[[dam2]] * ccomp(sire, dam, dam2) *
            (1 - (method != "REML")/I) * 1/(J * K) * 1/(I-1)
        }
      }
    }

    # REML only: subtract the contribution of every pair of sires.
    if(method == "REML") {
      for(sire in sire_names) {
        for(sire2 in sire_names) {
          for(dam in dam_names[[sire]]) {
            for(dam2 in dam_names[[sire2]]) {
              M_A <- M_A -
                n_missing[[dam]] * n_missing[[dam2]] * ccomp_reml(sire, sire2, dam, dam2) *
                1/(I * J * K) * 1/(I-1)
            }
          }
        }
      }
    }

    # With a single sire the `1/(I-1)` factors above divide by zero, so the
    # sire-level matrix is replaced by zeros.
    if (I == 1) {
      M_A <- matrix(0, nrow = data$dims$q, ncol = data$dims$q)
    }
    
    # Re-estimate the covariance matrices from the corrected M-matrices.
    curr_primal <- stepreml_2way_mat(M_E, K, M_B, J, M_A, I,
                                     log_crit = "never",
                                     method = E_method)

    prior_covs <- list(
      ind = curr_primal$S1,
      dam = curr_primal$S2,
      sire = curr_primal$S3
    )
    
    # Convergence can only be assessed once a previous estimate exists.
    if(iter > 1) {
      err <- mat_err(prev_primal, curr_primal, list(I * J * (K-1), I * (J-1), I - (E_method == "ML")))

      if(isTRUE(verbose)) {
        print("------------------")
        print(paste("Iter:", iter))
        print(paste("Err: ", err))
      }
      
      if(err < err.tol) {break}
    } else {
      err <- NA
    }
    
    prev_primal <- curr_primal
  }

  # Name the results after the level names stored in `data`.
  out_covs <- rlang::list2(
    !!data$level_names$ind_name  := prior_covs$ind,
    !!data$level_names$dam_name  := prior_covs$dam,
    !!data$level_names$sire_name := prior_covs$sire
  )

  # Drop the sire-level matrix when requested.
  if(isTRUE(flat_sire)) {
    out_covs[[3]] <- NULL
  }
  
  # Copy the dimnames of the last individual-level M-matrix onto each result.
  lapply(
    out_covs,
    \(A) {dimnames(A) <- dimnames(M_E); A}
  )
}
# damian-t-p/halfsibdesign documentation built on March 14, 2023, 4:55 a.m.