R/writeJagsModel.R

Defines functions writeJagsModel_dpois survreg_tau writeJagsModel_dnorm_survreg writeJagsModel_dnorm_default writeJagsModel_dnorm writeJagsModel_determ writeJagsModel_dcat writeJagsModel_dbern writeJagsModel_default writeJagsModel

Documented in writeJagsModel writeJagsModel_dbern writeJagsModel_dcat writeJagsModel_default writeJagsModel_determ writeJagsModel_dnorm writeJagsModel_dnorm_default writeJagsModel_dpois

#' @name writeJagsModel
#' 
#' @title Write a Node's JAGS Model
#' @description Constructs the JAGS code that designates the model for the 
#'   node conditioned on its parents.  The parameters for the model may 
#'   be user supplied or estimated from a given data set.
#'  
#' @param network A network of class HydeNetwork
#' @param node A node within \code{network}
#' @param node_str A character string giving the name of a node within \code{network}.
#'   This is usually generated within \code{writeJagsModel} and passed to 
#'   a specific method.
#' @param node_params A vector of parameters for the node.  Generated by 
#'   \code{writeJagsModel} and passed to a specific method.
#' 
#' @details The manipulations are performed on the \code{nodeParams} element
#'   of the \code{Hyde} network.  A string of JAGS code is returned suitable
#'   for inclusion in the Bayesian analysis.  
#'   
#'   The function will (eventually) travel through a series of \code{if} 
#'   statements until it finds the right node type.  It will then match
#'   the appropriate arguments to the inputs based on user supplied values or
#'   estimating them from the data.
#'   
#' @author Jarrod Dalton and Benjamin Nutter
#' @seealso \code{\link{writeJagsFormula}}
#' 
#' @examples
#' \dontrun{
#' #* NOTE: writeJagsModel isn't an exported function
#' data(PE, package='HydeNet')
#' Net <- HydeNetwork(~ wells + 
#'                      pe | wells + 
#'                      d.dimer | pregnant*pe + 
#'                      angio | pe + 
#'                      treat | d.dimer*angio + 
#'                      death | pe*treat,
#'                      data = PE)
#' HydeNet:::writeJagsModel(Net, 'pe')
#' HydeNet:::writeJagsModel(Net, 'treat')
#' }
#' 


writeJagsModel <- function(network, node) {
  # Builds the JAGS model statement for a single node by dispatching on the
  # node's type (network[["nodeType"]]) to the matching writeJagsModel_*
  # method.
  #
  # network: object whose nodeParams and nodeType elements are read here.
  # node:    the node; a character value is used as is, anything else is
  #          replaced by the deparsed argument expression.
  #
  # Returns the value produced by the selected writeJagsModel_* method.
  node_str <- if (is.character(node)) {
    node
  } else {
    as.character(substitute(node))
  }

  node_params <- network[["nodeParams"]][[node_str]]

  # Use the normalised string name for the lookup.  The raw `node` argument
  # is only guaranteed to be usable as a name when it is already a character
  # string, in which case it is identical to `node_str`.
  params <- expectedParameters_(network = network,
                                node = node_str,
                                returnVector = TRUE)

  names(node_params) <- names(params)

  # Pick the writer for this node type; unknown types fall through to the
  # generic default writer.
  writer <- switch(
    network[["nodeType"]][[node_str]],
    "dbern" = writeJagsModel_dbern,
    "dcat" = writeJagsModel_dcat,
    "dnorm" = writeJagsModel_dnorm,
    "dpois" = writeJagsModel_dpois,
    "determ" = writeJagsModel_determ,
    writeJagsModel_default
  )

  writer(network = network,
         node_str = node_str,
         node_params = node_params)
}

#' @rdname writeJagsModel

writeJagsModel_default <- function(network, node_str, node_params) {
  # Generic writer for node types without a dedicated method: pastes the
  # parameters, comma separated, into "<node> ~ <type>(<params>)".
  # Placeholder parameters (fromData() / fromFormula()) cannot be resolved
  # here, so they raise an error.
  has_placeholder <- any(node_params %in% c(fromData(), fromFormula()))
  if (has_placeholder) {
    stop("nodeType '", network[["nodeType"]][[node_str]],
         "' does not currently support fromData() or fromFormula()")
  }

  param_text <- paste(node_params, collapse = ", ")
  sprintf("%s ~ %s(%s)",
          node_str,
          network[["nodeType"]][[node_str]],
          param_text)
}

#' @rdname writeJagsModel

writeJagsModel_dbern <- function(network, node_str, node_params) {
  # Writes "<node> ~ <type>(<params>)" for a Bernoulli node.
  #
  # If any parameter is flagged fromData(), the node's fitter is run on the
  # node-specific data, or on the network-wide data when the node has none.
  # A "p" flagged fromData() or fromFormula() is replaced by the JAGS
  # expression derived from the fit or from the node formula respectively.
  if (fromData() %in% node_params) {
    model_data <- network[["nodeData"]][[node_str]]
    if (is.null(model_data)) {
      model_data <- network[["data"]]
    }

    fitter_args <- c(
      list(formula = network[["nodeFormula"]][[node_str]],
           data = model_data),
      network[["nodeFitterArgs"]][[node_str]]
    )
    fit <- do.call(what = network[["nodeFitter"]][[node_str]],
                   args = fitter_args)
  }

  if (node_params["p"] %in% c(fromData(), fromFormula())) {
    if (node_params["p"] == fromData()) {
      # Names of all Bernoulli nodes in the network, passed on to the
      # formula writer.
      node_types <- network[["nodeType"]]
      is_bern <- vapply(node_types,
                        function(type) type == "dbern",
                        logical(1))
      node_params["p"] <- writeJagsFormula(fit = fit,
                                           nodes = network[["nodes"]],
                                           bern = names(node_types)[is_bern])
    } else if (node_params["p"] == fromFormula()) {
      node_params["p"] <- rToJags(network[["nodeFormula"]][[node_str]])
    }

    # Drop the first two elements of the formula's character form; for a
    # two-sided formula that leaves only the right-hand side.
    formula_parts <- as.character(as.formula(node_params[["p"]]))
    node_params["p"] <- formula_parts[-c(1L, 2L)]
  }

  param_text <- paste(node_params, collapse = ", ")
  sprintf("%s ~ %s(%s)",
          node_str,
          network[["nodeType"]][[node_str]],
          param_text)
}

#' @rdname writeJagsModel  

writeJagsModel_dcat <- function(network, node_str, node_params) {
  # Writes the JAGS statements for a categorical ("dcat") node.
  #
  # Three cases, checked in order:
  #   1. the node's fitter is "cpt": pi.<node> is read from the row of
  #      cpt.<node> indexed by the parents;
  #   2. a parameter is flagged fromData(): the node's fitter is run and
  #      its JAGS formula is emitted ahead of the dcat statement;
  #   3. otherwise: the user-supplied "pi" definition is emitted ahead of
  #      the dcat statement.
  #
  # Returns a single character string.
  fitter <- network[["nodeFitter"]][[node_str]]

  if (!is.null(fitter) && fitter == "cpt") {
    parents <- network[["parents"]][[node_str]]
    is_bern_parent <- vapply(
      X = parents,
      FUN = function(p) network[["nodeType"]][[p]] == "dbern",
      FUN.VALUE = logical(1)
    )
    # Bernoulli parents are written as "(parent+1)" when used as an index
    # into the cpt array.
    parents[is_bern_parent] <- sprintf("(%s+1)", parents[is_bern_parent])

    return(
      sprintf("pi.%s <- cpt.%s[%s, ]\n   %s ~ dcat(pi.%s)",
              node_str,
              node_str,
              paste0(parents, collapse = ", "),
              node_str,
              node_str)
    )
  }

  if (fromData() %in% node_params) {
    node_params["pi"] <- sprintf("pi.%s", node_str)

    model_data <- network[["nodeData"]][[node_str]]
    if (is.null(model_data)) {
      model_data <- network[["data"]]
    }

    # Forward the node's fitter arguments, as the dbern/dnorm/dpois writers
    # do.  When none are stored this appends NULL and changes nothing.
    fit <- do.call(
      what = fitter,
      args = c(list(formula = network[["nodeFormula"]][[node_str]],
                    data = model_data),
               network[["nodeFitterArgs"]][[node_str]])
    )

    node_types <- network[["nodeType"]]
    is_bern <- vapply(node_types, function(type) type == "dbern", logical(1))
    pi_definition <- writeJagsFormula(fit = fit,
                                      nodes = network[["nodes"]],
                                      bern = names(node_types)[is_bern])

    return(
      sprintf("%s \n%s ~ %s(%s)",
              pi_definition,
              node_str,
              network[["nodeType"]][[node_str]],
              paste(node_params, collapse = ", "))
    )
  }

  sprintf("%s \n   %s ~ %s(pi.%s)",
          network[["nodeParams"]][[node_str]]["pi"],
          node_str,
          network[["nodeType"]][[node_str]],
          node_str)
}

#' @rdname writeJagsModel

writeJagsModel_determ <- function(network, node_str, node_params) {
  # Writes the assignment "<lhs> <- <definition>" for a deterministic node.
  # Without a fromFormula() flag the stored "define" parameter is used as
  # the definition; with it, both sides come from the node's formula.
  if (!(fromFormula() %in% node_params)) {
    return(
      sprintf("%s <- %s",
              node_str,
              network[["nodeParams"]][[node_str]][["define"]])
    )
  }

  # Elements 2 and 3 of the formula's character form are used as the
  # left-hand and right-hand side of the assignment.
  formula_parts <- as.character(network[["nodeFormula"]][[node_str]])
  sprintf("%s <- %s",
          formula_parts[2],
          formula_parts[3])
}

#' @rdname writeJagsModel

writeJagsModel_dnorm <- function(network, node_str, node_params) {
  # Routes a normal ("dnorm") node to the survreg-specific writer when its
  # fitter is "survreg", and to the default dnorm writer otherwise
  # (including when no fitter is stored for the node).
  fitter <- network[["nodeFitter"]][[node_str]]

  if (is.null(fitter) || fitter != "survreg") {
    return(writeJagsModel_dnorm_default(network, node_str, node_params))
  }

  writeJagsModel_dnorm_survreg(network, node_str, node_params)
}

#' @rdname writeJagsModel 

writeJagsModel_dnorm_default <- function(network, node_str, node_params) {
  # Writes "<node> ~ <type>(<params>)" for a normal node whose fitter is not
  # "survreg".
  #
  # "mu" flagged fromData() / fromFormula() is replaced by the JAGS
  # expression derived from the fit or from the node formula.  "tau"
  # flagged fromData() becomes 1 / sigma^2 of the fit's summary, rounded to
  # the "Hyde_maxDigits" option; "tau" flagged fromFormula() is an error.
  if (fromData() %in% node_params) {
    model_data <- network[["nodeData"]][[node_str]]
    if (is.null(model_data)) {
      model_data <- network[["data"]]
    }

    fitter_args <- c(
      list(formula = network[["nodeFormula"]][[node_str]],
           data = model_data),
      network[["nodeFitterArgs"]][[node_str]]
    )
    fit <- do.call(what = network[["nodeFitter"]][[node_str]],
                   args = fitter_args)
  }

  if (node_params["mu"] %in% c(fromData(), fromFormula())) {
    if (node_params["mu"] == fromData()) {
      node_types <- network[["nodeType"]]
      is_bern <- vapply(node_types,
                        function(type) type == "dbern",
                        logical(1))
      node_params["mu"] <- writeJagsFormula(fit = fit,
                                            nodes = network[["nodes"]],
                                            bern = names(node_types)[is_bern])
    } else if (node_params["mu"] == fromFormula()) {
      node_params["mu"] <- rToJags(network[["nodeFormula"]][[node_str]])
    }

    # Drop the first two elements of the formula's character form; for a
    # two-sided formula that leaves only the right-hand side.
    mu_parts <- as.character(as.formula(node_params[["mu"]]))
    node_params["mu"] <- mu_parts[-c(1L, 2L)]
  }

  if (node_params["tau"] %in% c(fromData(), fromFormula())) {
    if (node_params["tau"] == fromData()) {
      sigma <- summary(fit)[["sigma"]]
      node_params["tau"] <- round(x = 1 / (sigma ^ 2),
                                  digits = getOption("Hyde_maxDigits"))
    } else if (node_params["tau"] == fromFormula()) {
      stop("parameter 'tau' can not be estimated from a formula.")
    }
  }

  param_text <- paste(node_params, collapse = ", ")
  sprintf("%s ~ %s(%s)",
          node_str,
          network[["nodeType"]][[node_str]],
          param_text)
}

writeJagsModel_dnorm_survreg <- function(network, node_str, node_params) {
  # Writes the JAGS statements for a normal node whose fitter is "survreg".
  #
  # The result is a preliminary block followed by
  # "<node> ~ <type>(<param 1>,\n<padding><param 2>)".
  # The preliminary block is network[["nodePrelim"]][[node_str]] when that
  # is set; otherwise it is the block generated by survreg_tau() (only
  # available when "tau" is flagged fromData()); otherwise it is empty.
  #
  # "mu" flagged fromData() / fromFormula() is replaced by the JAGS
  # expression from the fit or from the node formula.  "tau" flagged
  # fromFormula() is an error.
  surv_tau <- NULL

  if (fromData() %in% node_params) {
    model_data <- network[["nodeData"]][[node_str]]
    if (is.null(model_data)) {
      model_data <- network[["data"]]
    }

    fit <- do.call(
      what = network[["nodeFitter"]][[node_str]],
      args = c(list(formula = network[["nodeFormula"]][[node_str]],
                    data = model_data),
               network[["nodeFitterArgs"]][[node_str]])
    )
  }

  if (node_params["mu"] == fromData()) {
    node_types <- network[["nodeType"]]
    is_bern <- vapply(node_types, function(type) type == "dbern", logical(1))
    node_params["mu"] <- writeJagsFormula(fit = fit,
                                          nodes = network[["nodes"]],
                                          bern = names(node_types)[is_bern])
  } else if (node_params["mu"] == fromFormula()) {
    node_params["mu"] <- rToJags(network[["nodeFormula"]][[node_str]])
  }

  # Only a "mu" that still looks like a formula is reduced: the first two
  # elements of its character form are dropped, which for a two-sided
  # formula leaves the right-hand side.
  if (grepl("~", node_params["mu"])) {
    mu_parts <- as.character(as.formula(node_params[["mu"]]))
    node_params["mu"] <- mu_parts[-c(1L, 2L)]
  }

  if (node_params["tau"] %in% c(fromData(), fromFormula())) {
    if (node_params["tau"] == fromData()) {
      surv_tau <- survreg_tau(fit = fit,
                              node_str = node_str,
                              parents = network[["parents"]][[node_str]])
      node_params["tau"] <- surv_tau[["tau"]]
    } else if (node_params["tau"] == fromFormula()) {
      stop("parameter 'tau' can not be estimated from a formula.")
    }
  }

  prelim <- network[["nodePrelim"]][[node_str]]
  if (is.null(prelim)) {
    prelim <- surv_tau[["prelim"]]
  }
  # sprintf() returns character(0) when any argument has length zero, which
  # would silently drop the whole model statement.  Use an empty block when
  # there is neither a stored nor a generated preliminary block.
  if (is.null(prelim)) {
    prelim <- ""
  }

  # The padding lines the second parameter up under the first one:
  # 3 leading spaces + node name + " ~ " + node type + "(".
  node_type <- network[["nodeType"]][[node_str]]
  padding <- paste0(rep(" ", nchar(node_str) + nchar(node_type) + 7),
                    collapse = "")

  sprintf("%s\n   %s ~ %s(%s,\n%s%s)",
          prelim,
          node_str,
          node_type,
          node_params[1],
          padding,
          node_params[2])
}

## Utility function: produces tau for the survreg prediction

survreg_tau <- function(fit, node_str, parents) {
  # Produces the JAGS text that defines tau for a prediction from `fit`.
  #
  # fit:      fitted model; coef(fit), fit$var and stats::model.frame(fit)
  #           are read.
  # node_str: node name, used as the suffix of the generated JAGS names
  #           (vv.<node>, xmat.<node>, xmatprime.<node>).
  # parents:  not used by this function.
  #
  # Returns a list with two character strings:
  #   prelim - assignments filling vv.<node> (the coefficient block of
  #            fit$var), xmat.<node> (a 1 x k row) and xmatprime.<node>
  #            (a k x 1 column);
  #   tau    - the expression 1 / (xmat %*% vv %*% xmatprime).
  #
  # Values are rounded to getOption("Hyde_maxDigits").
  coef_index <- seq_len(length(coef(fit)))

  # drop = FALSE keeps a 1 x 1 matrix for an intercept-only fit; without it
  # the subset collapses to a scalar and nrow()/ncol() return NULL.
  vv <- fit$var[coef_index, coef_index, drop = FALSE]

  # All model frame columns except the first one.
  mframe <- stats::model.frame(fit)[, -1, drop = FALSE]

  # A factor with more than two levels contributes one indicator term per
  # level from the second onward; every other column contributes its name.
  term_list <- lapply(
    names(mframe),
    function(x) {
      column <- mframe[[x]]
      if (is.factor(column) && nlevels(column) > 2) {
        paste0("(", x, " == ", 2:nlevels(column), ")")
      } else {
        x
      }
    }
  )
  design_terms <- c(1, unlist(term_list))

  # One assignment per matrix cell, ordered row by row.
  row_id <- rep(seq_len(nrow(vv)), each = ncol(vv))
  col_id <- rep(seq_len(ncol(vv)), times = nrow(vv))
  cell_values <- round(vv[cbind(row_id, col_id)],
                       digits = getOption("Hyde_maxDigits"))
  vv_str <- paste0(
    sprintf("vv.%s[%d,%d] <- %s", node_str, row_id, col_id, cell_values),
    collapse = "; "
  )

  xmat <- paste0("xmat.", node_str, "[1,", coef_index, "] <- ",
                 design_terms,
                 collapse = "; ")

  xmat_prime <- paste0("xmatprime.", node_str, "[", coef_index, ",1] <- ",
                       design_terms,
                       collapse = "; ")

  list(
    prelim = paste0(vv_str, "\n   ",
                    xmat, "\n   ",
                    xmat_prime, "   \n"),
    tau = paste0("1 / (xmat.", node_str, "[,] %*% vv.", node_str,
                 "[,] %*% xmatprime.", node_str, "[,])")
  )
}


#' @rdname writeJagsModel

writeJagsModel_dpois <- function(network, node_str, node_params) {
  # Writes "<node> ~ <type>(<params>)" for a Poisson node.
  #
  # If any parameter is flagged fromData(), the node's fitter is run on the
  # node-specific data, or on the network-wide data when the node has none.
  # A "lambda" flagged fromData() or fromFormula() is replaced by the JAGS
  # expression derived from the fit or from the node formula respectively.
  if (fromData() %in% node_params) {
    model_data <- network[["nodeData"]][[node_str]]
    if (is.null(model_data)) {
      model_data <- network[["data"]]
    }

    fitter_args <- c(
      list(formula = network[["nodeFormula"]][[node_str]],
           data = model_data),
      network[["nodeFitterArgs"]][[node_str]]
    )
    fit <- do.call(what = network[["nodeFitter"]][[node_str]],
                   args = fitter_args)
  }

  if (node_params["lambda"] %in% c(fromData(), fromFormula())) {
    if (node_params["lambda"] == fromData()) {
      node_types <- network[["nodeType"]]
      is_bern <- vapply(node_types,
                        function(type) type == "dbern",
                        logical(1))
      node_params["lambda"] <- writeJagsFormula(
        fit = fit,
        nodes = network[["nodes"]],
        bern = names(node_types)[is_bern]
      )
    } else if (node_params["lambda"] == fromFormula()) {
      node_params["lambda"] <- rToJags(network[["nodeFormula"]][[node_str]])
    }

    # Drop the first two elements of the formula's character form; for a
    # two-sided formula that leaves only the right-hand side.
    lambda_parts <- as.character(as.formula(node_params[["lambda"]]))
    node_params["lambda"] <- lambda_parts[-c(1L, 2L)]
  }

  param_text <- paste0(node_params, collapse = ", ")
  sprintf("%s ~ %s(%s)",
          node_str,
          network[["nodeType"]][[node_str]],
          param_text)
}

Try the HydeNet package in your browser

Any scripts or data that you put into this service are public.

HydeNet documentation built on July 8, 2020, 5:15 p.m.