# R/gmbn.R

# Defines functions gmbn

# Documented in gmbn

#' Create a Gaussian mixture Bayesian network
#'
#' This function creates a Gaussian mixture Bayesian network as an object of S3
#' class \code{gmbn}. A Bayesian network is a probabilistic graphical model that
#' represents the conditional dependencies and independencies between random
#' variables by a directed acyclic graph. It encodes a global joint distribution
#' over the nodes, which decomposes into a product of local conditional
#' distributions:
#' \deqn{p(X_1, \dots , X_n) = \prod_{i = 1}^n p(X_i | Pa(X_i))}
#' where \eqn{Pa(X_i)} is the set of parents of \eqn{X_i} in the graph. In a
#' Gaussian mixture Bayesian network, each local joint distribution over a node
#' and its parents is described by a Gaussian mixture model, which means that
#' the global distribution is a product of local conditional Gaussian mixture
#' models (Davies and Moore, 2000). The \code{gmbn} class can be extended to the
#' temporal dimension by regarding the nodes as the state of the system at a
#' given time slice \eqn{t} (denoted by \eqn{X^{(t)}}) and allowing them to have
#' parents at previous time slices. This makes it possible to create a
#' (\eqn{k + 1})-slice temporal Bayesian network that encodes the transition
#' distribution \eqn{p(X^{(t)} | X^{(t - 1)}, \dots , X^{(t - k)})}
#' (Hulst, 2006). Finally,
#' note that a Gaussian mixture Bayesian network can be created with functions
#' \code{\link{add_nodes}} (by passing \code{NULL} as argument \code{gmgm}) and
#' \code{\link{add_arcs}}, which makes it possible to quickly initialize a
#' \code{gmbn} object that can be passed to a learning function.
#'
#' @param \dots Objects of class \code{gmm} describing the local joint
#' distributions over the nodes and their parents. Each \code{gmm} object must
#' be named after the node whose distribution it describes and contain variables
#' named after this node and its parents. A node name must start with a letter
#' or with a period not followed by a digit, contain only letters, digits,
#' \code{_} and \code{.}, and must not end with a period followed by a positive
#' integer (this suffix denotes a time lag). Two types of parents are accepted:
#' other nodes (whose \code{gmm} objects must be defined) and instantiations of
#' nodes at previous time slices (if the created \code{gmbn} object is a
#' temporal Bayesian network). In the second case, the time lag must be added at
#' the end of the variable name after a period \code{.} (e.g. the instantiation
#' of a node \code{X} at time slice \eqn{t - 1} is represented by the variable
#' \code{X.1}).
#'
#' @return A list of class \code{gmbn} containing the \code{gmm} objects passed
#' as arguments, sorted by node name.
#'
#' @references
#' Davies, S. and Moore, A. (2000). Mix-nets: Factored Mixtures of Gaussians in
#' Bayesian Networks with Mixed Continuous And Discrete Variables. \emph{In
#' Proceedings of the 16th Conference on Uncertainty in Artificial
#' Intelligence}, pages 168--175, Stanford, CA, USA.
#'
#' Hulst, J. (2006). \emph{Modeling physiological processes with dynamic
#' Bayesian networks}. Master's thesis, Delft University of Technology.
#'
#' @seealso \code{\link{gmdbn}}, \code{\link{gmm}}
#'
#' @examples
#' data(data_body)
#' gmbn_1 <- gmbn(
#'   AGE = split_comp(add_var(NULL, data_body[, "AGE"]), n_sub = 3),
#'   FAT = split_comp(add_var(NULL,
#'                            data_body[, c("FAT", "GENDER", "HEIGHT", "WEIGHT")]),
#'                    n_sub = 2),
#'   GENDER = split_comp(add_var(NULL, data_body[, "GENDER"]), n_sub = 2),
#'   GLYCO = split_comp(add_var(NULL, data_body[, c("GLYCO", "AGE", "WAIST")]),
#'                      n_sub = 2),
#'   HEIGHT = split_comp(add_var(NULL, data_body[, c("HEIGHT", "GENDER")])),
#'   WAIST = split_comp(add_var(NULL,
#'                              data_body[, c("WAIST", "AGE", "FAT", "HEIGHT",
#'                                            "WEIGHT")]),
#'                      n_sub = 3),
#'   WEIGHT = split_comp(add_var(NULL, data_body[, c("WEIGHT", "HEIGHT")]),
#'                       n_sub = 2)
#' )
#'
#' library(dplyr)
#' data(data_air)
#' data <- data_air %>%
#'   group_by(DATE) %>%
#'   mutate(NO2.1 = lag(NO2), O3.1 = lag(O3), TEMP.1 = lag(TEMP),
#'          WIND.1 = lag(WIND)) %>%
#'   ungroup()
#' gmbn_2 <- gmbn(
#'   NO2 = split_comp(add_var(NULL, data[, c("NO2", "NO2.1", "WIND")]), n_sub = 3),
#'   O3 = split_comp(add_var(NULL,
#'                           data[, c("O3", "NO2", "NO2.1", "O3.1", "TEMP",
#'                                    "TEMP.1")]),
#'                   n_sub = 3),
#'   TEMP = split_comp(add_var(NULL, data[, c("TEMP", "TEMP.1")]), n_sub = 3),
#'   WIND = split_comp(add_var(NULL, data[, c("WIND", "WIND.1")]), n_sub = 3)
#' )
#'
#' @export

gmbn <- function(...) {
  gmbn <- list(...)

  if (length(gmbn) == 0) {
    "no argument is passed" %>%
      stop()
  }

  nodes <- gmbn %>%
    names()

  # names() is NULL only when no argument is named at all.
  if (is.null(nodes)) {
    "the arguments have no names" %>%
      stop()
  }

  # In a partially named call, the unnamed arguments get the name "". Report
  # this explicitly: otherwise several unnamed arguments would be misreported
  # by the duplicate check below (they all share the name ""), and a single one
  # by the node name validity check.
  if (any(nodes == "")) {
    "some arguments have no names" %>%
      stop()
  }

  if (any(duplicated(nodes))) {
    "arguments have the same name" %>%
      stop()
  }

  # A node name must start with a letter or with a period not followed by a
  # digit, contain only letters, digits, "_" and ".", and must not end with a
  # period followed by a positive integer: that suffix is reserved for time
  # lags (e.g. variable "X.1" is node X at time slice t - 1).
  if (any(!str_detect(nodes,
                      "^(\\.([A-Za-z_\\.]|$)|[A-Za-z])[A-Za-z0-9_\\.]*$") |
          str_detect(nodes, "\\.[1-9][0-9]*$"))) {
    "argument names are invalid node names" %>%
      stop()
  }

  # Process the nodes in alphabetical order. For each node, validate its gmm
  # object, pass reorder() the variable order (node first, then its parents
  # sorted by name and time lag) and collect the arcs of lag 0 pointing to it.
  res_gmbn <- gmbn[sort(nodes)] %>%
    imap(function(gmm, node) {
      if (!inherits(gmm, "gmm")) {
        node %>%
          str_c(" is not of class \"gmm\"") %>%
          stop()
      }

      var_gmm <- gmm$mu %>%
        rownames()

      if (!(node %in% var_gmm)) {
        node %>%
          str_c(" has no variable \"", node, "\"") %>%
          stop()
      }

      # The other variables are the parents. Split each one into the parent
      # node and its time lag: "X.2" -> ("X", "2"), "X" -> ("X", "").
      var_gmm <- var_gmm %>%
        setdiff(node)
      arcs <- var_gmm %>%
        str_split_fixed("\\.(?=[1-9][0-9]*$)", 2)
      colnames(arcs) <- c("from", "lag")
      # An empty lag ("" -> NA) means a parent in the same time slice (lag 0).
      # `order` records the position of each parent in var_gmm before sorting.
      arcs <- arcs %>%
        as_tibble() %>%
        mutate(lag = replace_na(as.numeric(lag), 0),
               order = seq_len(n())) %>%
        arrange(from, lag)

      # Every parent, lagged or not, must refer to a defined node.
      if (any(!(arcs$from %in% nodes))) {
        "variables of " %>%
          str_c(node, " are not related to defined nodes") %>%
          stop()
      }

      gmm <- gmm %>%
        reorder(c(node, var_gmm[arcs$order]))
      arcs_0 <- arcs %>%
        filter(lag == 0) %>%
        mutate(to = node) %>%
        select(from, to)
      list(gmm = gmm, arcs_0 = arcs_0) %>%
        return()
    }) %>%
    transpose()
  arcs_0 <- res_gmbn$arcs_0 %>%
    bind_rows()
  n_arcs_0 <- arcs_0 %>%
    nrow()

  # Cycle detection, performed on the arcs of lag 0 only. Each pass keeps the
  # arcs whose origin still has an incoming arc and whose target still has an
  # outgoing arc, i.e. it drops arcs leaving a source or entering a sink of the
  # remaining graph. An acyclic graph loses at least one arc per pass until it
  # is empty, so a pass that removes nothing while arcs remain reveals a cycle.
  while (n_arcs_0 > 0) {
    n_arcs_0_old <- n_arcs_0
    arcs_0 <- arcs_0 %>%
      filter(from %in% to, to %in% from)
    n_arcs_0 <- arcs_0 %>%
      nrow()

    if (n_arcs_0 == n_arcs_0_old) {
      "the defined structure has cycles" %>%
        stop()
    }
  }

  gmbn <- res_gmbn$gmm
  class(gmbn) <- "gmbn"
  gmbn %>%
    return()
}

# Try the gmgm package in your browser

# Any scripts or data that you put into this service are public.

# gmgm documentation built on Sept. 9, 2022, 1:07 a.m.