## GEMfit: fit the parameters of a parameterized network with a
## generalized EM loop.
##
## Each cycle runs an E-step (calcExpTables), an M-step
## (maxAllTableParams), rebuilds the tables (BuildAllTables) and then
## re-evaluates the log likelihood (calcPnetLLike).  The loop stops when
## the absolute change in log likelihood drops below `tol`, or after
## `maxit` cycles.
##
## Arguments:
##   net      the network to fit; passed unchanged to the helpers above.
##   cases    the case data; passed to calcPnetLLike and calcExpTables.
##   tol      convergence tolerance on the absolute change in log
##            likelihood; also forwarded to the E- and M-steps.
##   maxit    maximum number of EM cycles.
##   Estepit  forwarded to calcExpTables.
##   Mstepit  forwarded to maxAllTableParams.
##   trace    if TRUE, log the log likelihood after every evaluation.
##   debugNo  once the evaluation counter reaches this value the logging
##            threshold is switched to DEBUG.
##
## Value: a list with components
##   converged  TRUE if the convergence test was met.
##   iter       number of log likelihood evaluations performed (the
##              initial one plus one per EM cycle).
##   llikes     the log likelihoods, one per evaluation.
GEMfit <- function (net, cases, tol=sqrt(.Machine$double.eps), maxit=100,
                    Estepit=1, Mstepit=30,trace=FALSE,debugNo=maxit+1) {
  ## The logging threshold is global state: restore it however we leave
  ## this function, including when an E- or M-step raises an error.
  oldThreshold <- flog.threshold()
  on.exit(flog.threshold(oldThreshold), add=TRUE)

  ## Base case: likelihood of the starting parameters.
  converged <- FALSE
  llike <- rep(NA,maxit+1)
  iter <- 1
  if (iter >= debugNo) flog.threshold(DEBUG)
  BuildAllTables(net)
  llike[iter] <- calcPnetLLike(net,cases)
  if (trace)
    flog.info("Iteration %d; Log Likelihood %e.",iter,llike[iter])

  while(!converged && iter <= maxit) {
    ## E-step
    calcExpTables(net, cases, Estepit=Estepit, tol=tol)
    ## M-step
    maxAllTableParams(net, Mstepit=Mstepit, tol=tol)
    ## Update parameters & convergence test
    iter <- iter + 1
    if (iter >= debugNo) flog.threshold(DEBUG)
    BuildAllTables(net)
    llike[iter] <- calcPnetLLike(net,cases)
    if (trace) flog.info("Iteration %d; Log likelihood %e.",iter,llike[iter])
    change <- abs(llike[iter]-llike[iter-1])
    ## An NA/NaN change cannot be compared with tol; fail with a clear
    ## message instead of a "missing value where TRUE/FALSE needed" error.
    if (is.na(change))
      stop("Log likelihood change is not a number at iteration ", iter,
           "; cannot test convergence.")
    converged <- (change < tol)
  }
  flog.info("GEMfit %s after %d iterations.",
            if (converged) "converged" else "did not converge", iter)
  list(converged=converged,iter=iter,
       llikes=llike[seq_len(iter)])
}
### Build CPTs from parameters
## PnodeBuildTable: generic for (re)building a node's table from its
## parameters.  The body only dispatches on the class of `node`; the
## actual work lives in the class-specific methods.
PnodeBuildTable <- function (node)
  UseMethod("PnodeBuildTable")

## Also make the function available as an S4 generic.
setGeneric("PnodeBuildTable")
## PnodeProbs(node) <- CPTtools::calcDPCTable(PnodeParentStates(node),PnodeStates(node),
## PnodeLnAlphas(node), PnodeBetas(node),
## PnodeRules(node),PnodeLink(node),
## PnodeLinkScale(node),PnodeParentTvals(node))
## PnodePriorWeight(node) <- GetPriorWeight(node)
## invisible(node)
## calcPnetLLike: generic for evaluating the log likelihood of `cases`
## under the network `net`.  Dispatches on the class of `net`.
calcPnetLLike <- function (net,cases)
  UseMethod("calcPnetLLike")

## Also make the function available as an S4 generic.
setGeneric("calcPnetLLike")
## calcExpTables: generic for the E-step.  Dispatches on the class of
## `net`; `cases`, `Estepit` and `tol` are handed on to the method.
calcExpTables <- function (net, cases, Estepit=1,
                           tol=sqrt(.Machine$double.eps))
  UseMethod("calcExpTables")

## Also make the function available as an S4 generic.
setGeneric("calcExpTables")
## maxAllTableParams: M-step.  Calls maxCPTParam on every node returned
## by PnetPnodes(net).
##
## Arguments:
##   net      the network whose nodes are updated.
##   Mstepit  forwarded to maxCPTParam.
##   tol      forwarded to maxCPTParam.
##   debug    if TRUE, call recover() as soon as a node fails.
##
## Every node is attempted even if an earlier one fails (each call is
## wrapped in flog.try).  If any node failed, an error is raised after
## all nodes have been tried.
##
## Value: `net`, invisibly.
maxAllTableParams <- function (net, Mstepit=5,
                               tol=sqrt(.Machine$double.eps),
                               debug=FALSE) {
  netnm <- PnetName(net)
  ## Return each node's outcome from the closure.  Appending to a list
  ## from inside the closure would only change a closure-local copy, so
  ## failures would never reach the check below.
  results <- lapply(PnetPnodes(net),
                    function (nd) {
                      ndnm <- PnodeName(nd)
                      flog.debug("Updating params for node %s in net %s.",ndnm,netnm)
                      out <- flog.try(maxCPTParam(nd,Mstepit,tol),
                                      context=sprintf("Updating params for node %s in net %s.",
                                                      ndnm, netnm))
                      if (inherits(out,"try-error") && debug) recover()
                      out
                    })
  failed <- vapply(results, inherits, logical(1), what="try-error")
  if (any(failed))
    stop("Errors encountered while updating parameters for ",netnm)
  invisible(net)
}
## Also make the function available as an S4 generic.
setGeneric("maxAllTableParams")
## maxCPTParam: M-step for a single node.
##
## Forms pseudo-counts by scaling the node's probability table by its
## posterior weights, passes them to CPTtools::mapDPC, and stores the
## returned lnAlphas, betas and linkScale back on the node.
##
## Arguments:
##   node     the node to update.
##   Mstepit  passed to mapDPC as control$maxit.
##   tol      passed to mapDPC as control$reltol.
##
## Value: `node`, invisibly.
maxCPTParam <- function (node, Mstepit=5,
                         tol=sqrt(.Machine$double.eps)) {
  postWeights <- PnodePostWeights(node)
  cpa <- as.CPA(PnodeProbs(node))
  ## All dimensions but the last are treated as parent dimensions.
  nParentDims <- length(dim(cpa)) - 1L

  ## A single weight (or a table without parent dimensions) scales the
  ## whole table; otherwise the weights are swept across the parent
  ## dimensions.
  counts <- if (nParentDims == 0L || length(postWeights) == 1L) {
    cpa * postWeights
  } else {
    sweep(cpa, seq_len(nParentDims), postWeights, "*")
  }

  ## Warnings raised during the fit are handed to `muffler`.
  ## NOTE(review): ParentStates/NodeStates lack the Pnode prefix used by
  ## the other accessors here -- confirm they are the intended functions.
  est <- withCallingHandlers(
    CPTtools::mapDPC(counts,ParentStates(node),NodeStates(node),
                     PnodeLnAlphas(node), PnodeBetas(node),
                     PnodeRules(node),PnodeLink(node),
                     PnodeLinkScale(node),PnodeQ(node),
                     PnodeParentTvals(node),
                     control=list(reltol=tol,maxit=Mstepit)),
    warning=muffler)

  PnodeLnAlphas(node) <- est$lnAlphas
  PnodeBetas(node) <- est$betas
  PnodeLinkScale(node) <- est$linkScale
  invisible(node)
}
## Also make the function available as an S4 generic.
setGeneric("maxCPTParam")
## Add the following code to your website.
## For more information on customizing the embed code, read Embedding Snippets.