R/gibbs.R

Defines functions: inferencia_gibbs_sampling, amostrar_mb, inferencia_gibbs_sampling_ev

# Estimate a probability in the Bayesian network `bn` by Gibbs sampling.
#
# Q and E are condition vectors passed through to
# inferencia_gibbs_sampling_ev() (strings of the form 'X<k> == "<value>"').
#   - Q and E both given: estimates P(Q | E) as P(Q, E) / P(E); each term
#     comes from its own run of `n` sweeps.
#   - only one of them given: estimates the probability of those conditions.
#   - neither given: returns NULL.
# When the estimate of P(E) is zero the ratio is undefined, so a warning is
# issued and NA_real_ is returned (instead of a meaningless Inf/NaN).
inferencia_gibbs_sampling <- function(bn, Q, E, n = 1000) {
  has_q <- !missing(Q)
  has_e <- !missing(E)
  if (!has_q && !has_e) {
    return(NULL)
  }
  if (has_q && has_e) {
    # Keep the original call order (joint first) so the RNG stream matches.
    p_joint <- inferencia_gibbs_sampling_ev(bn, c(Q, E), n)
    p_evidence <- inferencia_gibbs_sampling_ev(bn, E, n)
    if (isTRUE(p_evidence == 0)) {
      warning("Estimated P(E) is zero, so P(Q | E) is undefined; ",
              "try a larger `n`.", call. = FALSE)
      return(NA_real_)
    }
    return(p_joint / p_evidence)
  }
  # Exactly one of Q / E was supplied.
  inferencia_gibbs_sampling_ev(bn, if (has_q) Q else E, n)
}

# Draw one value of node `z` from its conditional given its Markov blanket.
#
# f       : data frame with a column per variable plus a `factor` column of
#           non-negative weights (presumably built by product(); only the
#           columns `z`, `mbZ` and `factor` are read here).
# z       : name of the node being resampled.
# mbZ     : names of the nodes in z's Markov blanket.
# mbZ_val : current values of those nodes, in the same order as `mbZ`.
#
# Rows of `f` agreeing with every Markov-blanket value are kept, one of them
# is drawn with probability proportional to `factor`, and its value of `z`
# is returned. The row filter uses plain vector comparisons instead of a
# parsed string, so values containing quotes work and an empty blanket
# simply keeps every row. Errors if no row matches.
amostrar_mb <- function(f, z, mbZ, mbZ_val) {
  keep <- rep(TRUE, nrow(f))
  for (k in seq_along(mbZ)) {
    # Compare as character so factor and character columns behave alike.
    keep <- keep & as.character(f[[mbZ[[k]]]]) == as.character(mbZ_val[[k]])
  }
  # which() drops NA comparisons, mirroring dplyr::filter().
  ff <- f[which(keep), , drop = FALSE]
  if (nrow(ff) == 0L) {
    stop("No row of the factor for '", z,
         "' matches its Markov blanket values.", call. = FALSE)
  }
  # sample.int() normalises `prob`, so the weights need not sum to one.
  ff[[z]][sample.int(nrow(ff), size = 1L, prob = ff[["factor"]])]
}

# Estimate the probability of the evidence `EV` by Gibbs sampling over `bn`.
#
# bn : fitted network; each node `x` is read as bn[[x]]$cpt (a table with a
#      column named `x`) and bn[[x]]$children. NOTE(review): `cpt` is not a
#      standard bnlearn field -- confirm against where `bn` is built.
# EV : character vector of conditions 'X<k> == "<value>"'; an element may
#      contain several such conditions.
# n  : number of Gibbs sweeps; must be >= 1.
#
# Evidence nodes start at their observed value and every other node at a
# value drawn from its `cpt` column. Each sweep resamples every node from its
# Markov-blanket conditional via amostrar_mb(). The result is the fraction
# of sweeps in which all evidence nodes hold their observed values. No
# burn-in is discarded.
inferencia_gibbs_sampling_ev <- function(bn, EV, n) {
  stopifnot(length(n) == 1L, n >= 1)

  # One regex captures variable and value together so the two stay aligned,
  # including when a single element holds several conditions.
  conds <- stringr::str_match_all(as.character(EV), '(X[0-9]+) ?== "([^"]*)"')
  E_var <- as.character(unlist(lapply(conds, function(m) m[, 2])))
  E_val <- as.character(unlist(lapply(conds, function(m) m[, 3])))
  names(E_val) <- E_var
  E_val <- E_val[sort(names(E_val))]
  ev_names <- names(E_val)

  o <- names(bn)
  X <- sapply(o, function(x) {
    if (x %in% E_var) E_val[[x]] else sample(bn[[x]]$cpt[[x]], 1)
  })

  # Per node: its Markov blanket and the product of the factors mentioning
  # it (children's CPTs, then its own). Both are constant across sweeps.
  mbs <- lapply(o, function(z) bnlearn::mb(bn, z))
  produtos_mb <- lapply(o, function(z) {
    lf <- lapply(bn[[z]]$children, function(x) bn[[x]]$cpt)
    product(c(lf, list(bn[[z]]$cpt)))
  })

  Nx <- 0
  for (i in seq_len(n)) {
    for (j in seq_along(o)) {
      X[o[j]] <- amostrar_mb(produtos_mb[[j]], o[j], mbs[[j]], X[mbs[[j]]])
    }
    # E_val is already ordered by name, so compare X in that same order.
    Nx <- Nx + identical(X[ev_names], E_val)
  }
  Nx / n
}
jtrecenti/ea2 documentation built on May 20, 2019, 3:17 a.m.