# R/estimate.R

#' Estimate model parameters iteratively
#'
#' \code{estimate} estimates the parameters and, optionally, prints the
#' details of the estimation.
#'
#' Starting from \code{init}, every iteration runs \code{forwardAlg()} and
#' \code{backwardAlg()} on each observed chain, derives \code{L} and \code{H}
#' from them with \code{computeL()} and \code{computeH()}, updates the
#' transition parameters with \code{update_trans()}, and updates \code{(v, m)}
#' one state at a time with \code{update_mv()}. The objective \code{qfun()} is
#' evaluated before and after the update; iteration stops once its improvement
#' is smaller than \code{tol}.
#'
#' @export
#' @param X     list of vectors of observed chain x.
#' @param RNA   list of 0-1 vectors. 1 if next 3-base is stop codon.
#' @param E     vector of normalizing constant for observed chain x.
#' @param init  list of initial parameters. 3 variables: \code{v} and \code{m}
#'              (each laid out as 21 values, one per state) and \code{trans}
#'              (3 values).
#' @param tol   scalar of tolerance. Default is 0.001. It is also used to
#'              derive the number of digits shown in the printed details.
#' @param detail logical; if TRUE, the details of estimation are printed.
#'              Coloured output uses the \pkg{crayon} package when it is
#'              installed and falls back to plain text otherwise.
#' @param ...  forwarded to \code{update_mv()}; user can put a list of control
#'             parameters that would be passed to optim().
#'             E.g.: control=list(maxit=1e4, reltol=1e-9)
#' @return      On convergence, a list of class \code{estConverge} containing:
#'                \itemize{
#'                  \item v_hat: estimated shape parameters
#'                  \item m_hat: estimated mean parameters
#'                  \item trans_hat: estimated transition probability
#'                        c(rho_u, rho, delta)
#'                  \item v_history, m_history, trans_history: matrices with
#'                        one row per parameter set visited; row 1 holds the
#'                        initial values
#'                  \item iter: number of iterations used
#'                }
#'              If an iteration makes the objective worse, a message is printed
#'              and the numeric value \code{0} is returned instead.
#' @examples
#' k <- 5  # number of transcripts from each group (see data ?uORF)
#' ii <- c()
#' for (i in 1:5){
#'   ii <- c(ii, ((i-1)*100+1):((i-1)*100+k))
#' }
#' print(ii)
#'
#' df <- uORF[ii]
#' X <- RNA <- list(); E <- c()
#' for (i in 1:length(df)){
#'   X[[i]] <-  df[[i]]$x
#'   RNA[[i]] <- df[[i]]$RNA
#'   E[i] <- df[[i]]$E
#' }
#'
#' init <- init_generate(df, cutoff=10, r=4)
#' res <- estimate(X, RNA, E, init, tol=0.001, detail=TRUE)
#' print(res)
#' print(df[[1]]$trans)
#'
#' m <- df[[1]]$m; v <- df[[1]]$v
#' m_hat <- res$m_hat
#' v_hat <- res$v_hat
#'
#' cbind(c(m[1:10],NA), c(m_hat[1:10],NA), c(m_hat[1:10],NA)-c(m[1:10],NA),
#'       m[11:21],      m_hat[11:21],      m_hat[11:21]-m[11:21])
estimate <- function(X, RNA, E, init, tol=0.001, detail=TRUE, ...){
  # Highlighting helpers for the progress log. crayon is accessed through its
  # namespace so nothing is attached to the caller's search path, and the log
  # degrades to plain text when the package is not installed.
  if (detail && requireNamespace("crayon", quietly = TRUE)) {
    hl <- crayon::red
    hl_state <- crayon::red$underline
  } else {
    hl <- hl_state <- function(...) paste0(...)
  }

  if (detail){
    cat(rep('-',49), '\n', sep='')
    fun_s_t <- Sys.time()
    cat('current time: ', toString(fun_s_t), '\n')
  }

  # Number of digits used when printing, derived from the tolerance.
  dig <- -log10(tol)
  M <- 21
  num <- length(X)

  # Parameter histories: row i holds the values entering iteration i.
  v <- matrix(init$v, 1, M)
  m <- matrix(init$m, 1, M)
  trans <- matrix(init$trans, 1, 3)

  # Print a label followed by entries 2:10 and 12:20 of a length-M parameter
  # vector as a two-row table.
  show_states <- function(label, vals){
    cat(label)
    prmatrix(round(rbind(vals[2:10], vals[12:20]), dig), rowlab=rep("",2), collab=rep("",10))
  }

  if (detail){
    cat('Initial transition:',  trans[1,], '\n')
    show_states('Initial v:', v[1,])
    show_states('Initial m:', m[1,])
  }

  la <- lb <- L <- H <- list()
  iter <- 1
  repeat {
    if (detail){
      str_t <- Sys.time()
      cat('\n', rep('-',49), '\n', sep='')
      cat(rep('-',26), sprintf(' Entering iteration: %2d', iter), '\n', sep='')
    }

    # Per-chain forward/backward quantities under the current parameters.
    beta <- v[iter,]/m[iter,]
    for (i in seq_len(num)){
      la[[i]] <- forwardAlg(X[[i]], RNA[[i]], trans[iter,], v[iter,], beta, E[i])
      lb[[i]] <- backwardAlg(X[[i]], RNA[[i]], trans[iter,], v[iter,], beta, E[i])
      L[[i]]  <- computeL(la[[i]], lb[[i]])
      H[[i]]  <- computeH(X[[i]], RNA[[i]], trans[iter,], v[iter,], beta, E[i], la[[i]], lb[[i]])
    }

    if (detail){
      cat(rep('-',15), ' la, lb, L, H calculation finished', '\n', sep='')
    }

    # Objective at the current parameters, used below as the baseline.
    curr_pars <- c(v[iter,], beta, trans[iter,])
    curr_Q <- qfun(curr_pars,X, E, L, H)

    #++++++++++++++++++++++++++++  update transition ++++++++++++++++++++++++++++#
    trans <- rbind(trans, update_trans(num, H))

    if (detail){
      cat(rep('-',9), sprintf(' Updating trans on iteration %2d finished', iter), '\n', sep='')
      print(round(unname(trans[iter+1,]), dig+2))
    }
    #+++++++++++++++++++++++++++++++  update v, m +++++++++++++++++++++++++++++++#
    # vm stacks the M shape values followed by the M mean values; state k
    # therefore owns positions k and M + k.
    vm <- c(v[iter,], m[iter,])
    newvm <- rep(0, 2 * M)

    if (detail) cat(rep('-',24), ' Updating v, m for state:\n', sep='')

    for (k in seq_len(M)){
      pair <- c(k, M + k)
      newvm[pair] <- update_mv(vm[pair], k, X, E, L, ...)[1:2]
      if (detail) cat(hl_state('',k))
    }
    if (detail) cat('\n')

    # The appended rows are passed as expressions (not bare symbols) so that
    # rbind() does not attach a row name to the history matrices.
    v <- rbind(v, newvm[seq_len(M)])
    m <- rbind(m, newvm[M + seq_len(M)])

    if (detail){
      show_states(sprintf('Iter %d.   v:', iter), v[iter+1,])
      show_states(sprintf('Iter %d.   m:', iter), m[iter+1,])

      end_t <- Sys.time()
      cat(rep('-',36),
          sprintf(' Iteration %d took: ', iter),
          hl(round(difftime(end_t, str_t, units = 'mins'), 1)),
          ' mins\n', sep='')
    }

    # converge?
    # The updated parameters are scored against the same L and H as curr_Q.
    new_pars <- c(v[iter+1,], v[iter+1,]/m[iter+1,], trans[iter+1,])
    new_Q <- qfun(new_pars,X, E, L, H)

    if (detail){
      cat(sprintf('------ New Q = %.*f, Old Q = %.*f, increased by',
                  dig, -new_Q, dig, -curr_Q),
          hl(round(curr_Q-new_Q, dig)), '\n')
    }

    if (new_Q > curr_Q){  # we compute Q with negative sign
      # The update made the objective worse: report it and return the
      # sentinel value 0 instead of a result object.
      cat('Error in optim(): new_Q < curr_Q. \n')
      return(0)
    }

    # Stop once the improvement drops below the tolerance.
    if (curr_Q - new_Q < tol) break
    iter <- iter + 1
  }

  if (detail){
    cat(rep('*', 64),'\n', sep='')
    cat('Tolerance satisfied. Optimization finished','\n')
    fun_e_t <- Sys.time()
    cat('current time: ', toString(fun_e_t), '\n')
    cat(rep('-',36),
        ' Total time used: ',
        hl(round(difftime(fun_e_t, fun_s_t, units = 'mins'), 1)),
        ' mins\n', sep='')
    cat(rep('*', 64),'\n', sep='')
  }

  # Row iter + 1 of each history holds the estimates from the last iteration.
  res <- list(v_hat=v[iter+1,], m_hat=m[iter+1,], trans_hat=trans[iter+1,],
    v_history=v, m_history=m, trans_history=trans, iter=iter)
  class(res) <- c('estConverge', class(res))

  res
}
# shimlab/riboHMM2 documentation built on May 19, 2019, 6:23 p.m.