R/sum_product.R

Defines functions: scope, factor_graph, inferencia_sum_product

# Variables in a node's scope: every column name of the node's CPT except
# the probability column, which is named 'factor'.
scope <- function(node) {
  cpt_cols <- names(node$cpt)
  is_variable <- !(cpt_cols %in% 'factor')
  cpt_cols[is_variable]
}

# Build the factor graph of a network as a directed edge list.
#
# Each node `Xi` of `bn` contributes one factor node (its name with 'X'
# replaced by 'phi', e.g. 'X3' -> 'phi3'), linked to every variable in its
# CPT scope (see scope()). Each link is emitted in both directions so that
# it can carry one message each way.
#
# bn: named list of nodes; each node has a `cpt` data frame whose columns
#     other than 'factor' are the scope variables.
# Returns a tibble with character columns `from`, `to` and `type`
# ('phi_to_X' for factor -> variable, 'X_to_phi' for variable -> factor),
# sorted by type, from, to. An empty `bn` gives a zero-row tibble.
factor_graph <- function(bn) {
  # Factor names are derived from node names, so the list must be named.
  if (length(bn) > 0 && is.null(names(bn))) {
    stop('`bn` must be a named list of nodes', call. = FALSE)
  }
  scopes <- lapply(bn, scope)
  phis <- gsub('X', 'phi', names(scopes))
  # One (factor, variable) pair per scope entry. Flat vectors keep the shape
  # independent of how many variables each scope has.
  factor_end <- rep(phis, lengths(scopes))
  variable_end <- as.character(unlist(scopes, use.names = FALSE))
  # The direction is known at construction time, so set `type` directly
  # instead of guessing it from the node name.
  d <- dplyr::bind_rows(
    dplyr::tibble(from = factor_end, to = variable_end, type = 'phi_to_X'),
    dplyr::tibble(from = variable_end, to = factor_end, type = 'X_to_phi')
  )
  dplyr::arrange(d, type, from, to)
}

# Sum-product message passing on the factor graph of `bn`.
#
# bn:      named list of nodes (X1, X2, ...), each with a `cpt` data frame
#          holding its scope columns plus a 'factor' column.
# E:       evidence; only the variable names captured by the pattern
#          'X<digits> ==' in as.character(E) are used.
# n:       number of sweeps over all edges of the factor graph (0 allowed).
# verbose: if TRUE, report each sweep with message().
#
# Every message starts uniform over the states of the variable at its
# variable end and is renormalised to sum to 1 after each update. Updates
# are applied in place, so later edges in a sweep already see messages
# updated earlier in that same sweep.
#
# Returns product() of all messages whose destination is a variable named
# in E.
#
# NOTE(review): the observed values in E are never applied to any factor
# here; E only selects the nodes kept by remove_barren() and the messages
# multiplied at the end. Confirm this is intended.
inferencia_sum_product <- function(bn, E, n = 1000, verbose = FALSE) {
  # Evidence variable names, e.g. "X1 == 1" -> "X1". unlist() keeps this a
  # flat character vector whatever shape as.character(E) has (string,
  # formula, list of calls, several conditions).
  matches <- stringr::str_match_all(as.character(E), '(X[0-9]+) ?==')
  E_var <- unique(unlist(lapply(matches, function(x) x[, 2]),
                         use.names = FALSE))
  bn <- remove_barren(bn, E_var, E_var)
  fgraph <- factor_graph(bn)
  # Edges are visited in (from, to, type) order during each sweep.
  fgraph <- dplyr::arrange(fgraph, from, to, type)

  from_v <- fgraph$from
  to_v <- fgraph$to
  type_v <- fgraph$type
  n_edges <- length(from_v)

  # Initial messages: uniform over the states of the edge's variable end.
  psi <- lapply(seq_len(n_edges), function(i) {
    x <- if (type_v[i] == 'phi_to_X') to_v[i] else from_v[i]
    s <- unique(bn[[x]]$cpt[[x]])
    setNames(dplyr::tibble(x = s, f = 1 / length(s)), c(x, 'factor'))
  })

  # The graph never changes, so find once which messages feed each edge:
  # every edge arriving at this edge's source, except its reverse edge.
  incoming <- lapply(seq_len(n_edges), function(i) {
    which(to_v == from_v[i] & from_v != to_v[i])
  })

  for (k in seq_len(n)) {
    if (verbose) {
      message('sum-product sweep ', k, '/', n)
    }
    for (i in seq_len(n_edges)) {
      msgs <- psi[incoming[[i]]]
      if (type_v[i] == 'phi_to_X') {
        # Factor -> variable: multiply the factor's CPT by the other
        # incoming messages, then sum out everything but the target.
        # A factor with no other neighbours (e.g. a root prior) sends its
        # CPT as is instead of staying uniform forever.
        cpt <- bn[[gsub('phi', 'X', from_v[i])]]$cpt
        pr <- if (length(msgs) > 0) product(c(msgs, list(phi = cpt))) else cpt
        rm_var <- setdiff(names(pr), c(to_v[i], 'factor'))
        new_msg <- if (length(rm_var) > 0) sum_out(pr, rm_var) else pr
      } else if (length(msgs) > 0) {
        # Variable -> factor: product of the other incoming messages.
        new_msg <- product(msgs)
      } else {
        # A variable with no other neighbours keeps its uniform message.
        new_msg <- psi[[i]]
      }
      new_msg$factor <- new_msg$factor / sum(new_msg$factor)
      psi[[i]] <- new_msg
    }
  }
  # %in% (not ==) so that several evidence variables select all their
  # incoming messages instead of being recycled against the edge list.
  product(psi[to_v %in% E_var])
}
jtrecenti/ea2 documentation built on May 20, 2019, 3:17 a.m.