## File: R/EM.R

## Generalized EM fit of the parameters of a Pnet to a set of cases.
##
## Alternates an E-step (calcExpTables) and an M-step (maxAllTableParams).
## After each cycle it rebuilds the CPTs (BuildAllTables) and recomputes the
## log likelihood (calcPnetLLike). Iteration stops when the absolute change
## in log likelihood between successive evaluations is below `tol`, or after
## `maxit` E/M cycles.
##
## Arguments:
##   net     -- the Pnet being fitted.  The steps below are called for their
##              effect on `net`; their return values are not used.
##   cases   -- data passed to calcPnetLLike() and calcExpTables().
##   tol     -- convergence tolerance on the log-likelihood change; also
##              forwarded as `tol` to the E- and M-steps.
##   maxit   -- maximum number of E/M cycles.
##   Estepit -- forwarded as `Estepit` to calcExpTables().
##   Mstepit -- forwarded as `Mstepit` to maxAllTableParams().
##   trace   -- if TRUE, log the log likelihood at every evaluation.
##   debugNo -- once the evaluation counter `iter` reaches this value, the
##              logging threshold is lowered to DEBUG.
##              NOTE(review): `iter` can reach maxit+1, so the default turns
##              DEBUG on for the final cycle of a non-converging run.
##              Confirm this is intended.
##
## Value: a list with
##   converged -- TRUE if the tolerance was met.
##   iter      -- number of log-likelihood evaluations (completed cycles + 1).
##   llikes    -- the log likelihoods from those evaluations.
##
## The logging threshold in effect on entry is restored on exit, including
## when one of the steps raises an error.
GEMfit <- function (net, cases, tol=sqrt(.Machine$double.eps), maxit=100,
                   Estepit=1, Mstepit=30,trace=FALSE,debugNo=maxit+1) {
  oldThreshold <- flog.threshold()
  ## Restore the caller's threshold on every exit path (errors included).
  on.exit(flog.threshold(oldThreshold), add = TRUE)

  ## Base case: likelihood of the starting parameters.
  converged <- FALSE
  llike <- rep(NA_real_, maxit + 1)
  iter <- 1
  if (iter >= debugNo) flog.threshold(DEBUG)
  BuildAllTables(net)
  llike[iter] <- calcPnetLLike(net, cases)
  if (trace) {
    flog.info("Iteration %d; Log Likelihood %e.", iter, llike[iter])
  }

  while (!converged && iter <= maxit) {
    calcExpTables(net, cases, Estepit = Estepit, tol = tol)    # E-step
    maxAllTableParams(net, Mstepit = Mstepit, tol = tol)       # M-step

    ## Update parameters & convergence test.
    iter <- iter + 1
    if (iter >= debugNo) flog.threshold(DEBUG)
    BuildAllTables(net)
    llike[iter] <- calcPnetLLike(net, cases)
    if (trace) {
      flog.info("Iteration %d; Log likelihood %e.", iter, llike[iter])
    }
    ## isTRUE() keeps an NA/NaN likelihood from breaking the while() test;
    ## such a run simply reports non-convergence.
    converged <- isTRUE(abs(llike[iter] - llike[iter - 1]) < tol)
  }
  flog.info("GEMfit %s after %d iterations.",
            if (converged) "converged" else "did not converge", iter)
  list(converged = converged, iter = iter,
       llikes = llike[seq_len(iter)])
}


### Build CPTs from parameters

## Generic: build `node`'s conditional probability table from its
## parameters.  Methods supply the implementation.
##
## Dispatch is S3 (UseMethod on `node`).  setGeneric() then promotes the
## function to an S4 generic whose default method is this S3 dispatcher, so
## both S3 and S4 methods can be defined.
PnodeBuildTable <- function(node) UseMethod("PnodeBuildTable")
setGeneric("PnodeBuildTable")
  ## PnodeProbs(node) <- CPTtools::calcDPCTable(PnodeParentStates(node),PnodeStates(node),
  ##                        PnodeLnAlphas(node), PnodeBetas(node),
  ##                        PnodeRules(node),PnodeLink(node),
  ##                        PnodeLinkScale(node),PnodeParentTvals(node))
  ## PnodePriorWeight(node) <- GetPriorWeight(node)
  ## invisible(node)



## Generic: log likelihood of `cases` under `net`.  GEMfit() uses it as its
## convergence criterion.  Methods supply the implementation.
##
## Dispatch is S3 (UseMethod).  setGeneric() then wraps it as an S4 generic
## whose default method is this S3 dispatcher.
calcPnetLLike <- function(net, cases) UseMethod("calcPnetLLike")
setGeneric("calcPnetLLike")

## Generic: the E-step of GEMfit(), which computes expected tables for `net`
## from `cases`.  Methods supply the implementation.
##
## Arguments:
##   net, cases -- the network and data.
##   Estepit    -- E-step iteration count (interpreted by methods).
##   tol        -- tolerance (interpreted by methods).
##
## Dispatch is S3 (UseMethod).  setGeneric() then wraps it as an S4 generic
## whose default method is this S3 dispatcher.
calcExpTables <- function(net, cases, Estepit = 1,
                          tol = sqrt(.Machine$double.eps))
  UseMethod("calcExpTables")
setGeneric("calcExpTables")


## M-step for a whole Pnet: call maxCPTParam() on every node of `net`.
##
## Each node update is wrapped in flog.try(), so a failure on one node is
## logged and the remaining nodes are still processed.  If any node failed,
## an error is raised once all nodes have been tried.
##
## Arguments:
##   net     -- the Pnet whose nodes (PnetPnodes(net)) are updated.
##   Mstepit -- passed to maxCPTParam() for each node.
##   tol     -- passed to maxCPTParam() for each node.
##   debug   -- if TRUE, call recover() right after a node fails.
##
## Value: `net`, invisibly.
maxAllTableParams <- function (net, Mstepit=5,
                               tol=sqrt(.Machine$double.eps),
                               debug=FALSE) {
  netnm <- PnetName(net)
  ## Each node reports success/failure through vapply()'s return value.
  ## Appending to an outer list from inside the closure (the previous
  ## approach) only modified a local copy, so failures were never reported.
  failed <- vapply(PnetPnodes(net), function (nd) {
    ndnm <- PnodeName(nd)
    flog.debug("Updating params for node %s in net %s.", ndnm, netnm)
    out <- flog.try(maxCPTParam(nd, Mstepit, tol),
                    context=sprintf("Updating params for node %s in net %s.",
                                    ndnm, netnm))
    bad <- inherits(out, "try-error")
    if (bad && debug) recover()
    bad
  }, logical(1))
  if (any(failed))
    stop("Errors encountered while updating parameters for ", netnm)
  invisible(net)
}
setGeneric("maxAllTableParams")


## M-step for one node: refit the node's parameters to its expected counts.
##
## The posterior pseudo-counts are the node's CPT (as.CPA(PnodeProbs(node)))
## scaled by PnodePostWeights(node):
##   - multiplied directly when the table has no parent dimensions or there
##     is a single weight;
##   - otherwise swept across the leading parent dimensions.
## CPTtools::mapDPC() then fits lnAlphas, betas and linkScale.  Its warnings
## are routed to `muffler`, and the fitted values are written back to `node`.
##
## Arguments:
##   node    -- the node to update.
##   Mstepit -- optimizer iteration limit (control$maxit for mapDPC).
##   tol     -- optimizer relative tolerance (control$reltol for mapDPC).
##
## Value: the updated `node`, invisibly.
maxCPTParam <- function (node, Mstepit = 5,
                         tol = sqrt(.Machine$double.eps)) {
  postWeights <- PnodePostWeights(node)
  cpa <- as.CPA(PnodeProbs(node))
  nParentDims <- length(dim(cpa)) - 1L
  pseudoCounts <-
    if (nParentDims == 0L || length(postWeights) == 1L) {
      cpa * postWeights
    } else {
      sweep(cpa, 1L:nParentDims, postWeights, "*")
    }
  ## NOTE(review): the commented-out table builder in this file uses
  ## PnodeParentStates()/PnodeStates(); confirm that ParentStates()/
  ## NodeStates() are the intended accessors here.
  fit <- withCallingHandlers(
    CPTtools::mapDPC(pseudoCounts, ParentStates(node), NodeStates(node),
                     PnodeLnAlphas(node), PnodeBetas(node),
                     PnodeRules(node), PnodeLink(node),
                     PnodeLinkScale(node), PnodeQ(node),
                     PnodeParentTvals(node),
                     control = list(reltol = tol, maxit = Mstepit)),
    warning = muffler)
  PnodeLnAlphas(node) <- fit$lnAlphas
  PnodeBetas(node) <- fit$betas
  PnodeLinkScale(node) <- fit$linkScale
  invisible(node)
}
setGeneric("maxCPTParam")
## Source: ralmond/Peanut documentation, built on Sept. 19, 2023, 8:27 a.m.