R/importance.R

Defines functions amostrar_n_importance amostrar_node_importance inferencia_importance_sampling_ev inferencia_importance_sampling

###############################################################################
##################### IMPORTANCE SAMPLING #####################################
###############################################################################

# Draw `n` joint samples from the query network and return their importance
# weights.
#
# bn_q  : query network; nodes are visited in the order given by
#         bnlearn:::schedule() and each node's `$parents` is read.
# bn    : original network, passed through to amostrar_node_importance().
# n     : number of samples drawn for root nodes.
# E_var : names of the evidence nodes.
#
# Returns the element-wise product of the `prob` weights returned by
# amostrar_node_importance() for every node (the scalar 1 when no node
# contributes a non-unit weight).
amostrar_n_importance <- function(bn_q, bn, n, E_var) {
  # NOTE(review): `schedule` is an unexported bnlearn internal (`:::`); it may
  # change between bnlearn releases.
  ordem <- bnlearn:::schedule(bn_q)
  amostras <- list()
  pesos <- 1
  for (no in ordem) {
    # Parents are always sampled before their children in this ordering.
    pais <- amostras[bn_q[[no]]$parents]
    # Root nodes receive the sample size; other nodes receive a data frame
    # with one row per sample holding the sampled parent values.
    entrada <- if (length(pais) == 0) n else dplyr::as_tibble(pais)
    v <- amostrar_node_importance(bn_q[[no]], entrada, bn, E_var)
    amostras[[no]] <- v$val
    pesos <- pesos * v$prob
  }
  pesos
}

# Sample (or clamp) one node of the network for every sample.
#
# node  : element of the query network; uses `node$node` (the node's name) and
#         `node$cpt`, a data frame with the parent columns, the node's own
#         column and a `factor` column. A CPT with exactly two columns is
#         treated as a root node (no parents).
# pa    : for root nodes, the number of samples to draw; otherwise a data frame
#         with one row per sample holding the sampled parent values.
# bn    : original network; its CPTs supply the weights of evidence nodes.
# E_var : names of the evidence nodes.
#
# Returns list(val, prob):
#   * evidence node: `val` is the observed value repeated once per sample and
#     `prob` is P(node = value | parents) looked up in `bn`'s CPT;
#   * other node: `val` is drawn from the rows of `node$cpt` and `prob` is 1.
amostrar_node_importance <- function(node, pa, bn, E_var) {
  raiz <- ncol(node$cpt) == 2
  if (node$node %in% E_var) {
    # In the query CPT `factor == 1` marks the rows consistent with the
    # evidence; the first such row carries the observed value.
    val <- head(dplyr::filter(node$cpt, factor == 1), 1)[[node$node]]
    if (length(val) == 0) {
      stop("no value of ", node$node, " is consistent with the evidence",
           call. = FALSE)
    }
    n_amostras <- if (raiz) pa else nrow(pa)
    d_ev <- setNames(dplyr::tibble(a = rep(val, n_amostras)), node$node)
    # Root nodes are looked up by their own value only; other nodes by the
    # (parents, value) combination.
    d_obs <- if (raiz) d_ev else dplyr::bind_cols(pa, d_ev)
    probs_obs <- dplyr::inner_join(d_obs, bn[[node$node]]$cpt,
                                   by = names(d_obs))
    return(list(val = d_ev[[node$node]], prob = probs_obs$factor))
  }
  classes <- unique(node$cpt[[node$node]])
  if (raiz) {
    # Root CPT has one row per class, so `factor` is aligned with `classes`.
    cl <- classes[LaplacesDemonCpp::rcat(pa, node$cpt$factor)]
  } else {
    # Widen the CPT to one row per parent configuration and one probability
    # column per class, then align it with the sampled parents.
    probs <- tidyr::pivot_wider(node$cpt,
                                names_from = dplyr::all_of(node$node),
                                values_from = "factor")
    d <- dplyr::select(dplyr::inner_join(pa, probs, by = names(pa)),
                       dplyr::one_of(classes))
    cl <- classes[LaplacesDemonCpp::rcat(nrow(d), as.matrix(d))]
  }
  list(val = cl, prob = 1)
}

# Importance-sampling (likelihood-weighting) estimate of the probability of
# the evidence E.
#
# bn : the network; each node has a `$cpt` data frame with a `factor` column.
# E  : vector of evidence expressions, one `Xk == value` comparison per
#      element (e.g. expression(X1 == "a", X3 == "b")).
# n  : number of samples.
#
# A copy of `bn` is built in which each evidence node's CPT `factor` is
# replaced by the 0/1 indicator of its evidence term. Samples are drawn by
# amostrar_n_importance() and the mean of their weights is returned.
inferencia_importance_sampling_ev <- function(bn, E, n = 100) {
  # Extract the variable named in each evidence term, e.g. `X3 == "a"` -> "X3".
  # Each term must name exactly one variable: the sampler clamps an evidence
  # node to a single value, so compound terms cannot be honoured.
  E_var <- vapply(
    stringr::str_match_all(as.character(E), '(X[0-9]+) ?=='),
    function(m) {
      if (nrow(m) != 1) {
        stop("each evidence term must be a single `Xk == value` comparison",
             call. = FALSE)
      }
      m[1, 2]
    },
    character(1)
  )
  bn_q <- lapply(bn, as.list)
  for (i in seq_along(E_var)) {
    no <- E_var[[i]]
    # Evaluate the evidence term against the CPT columns: 1 on consistent
    # rows, 0 elsewhere.
    bn_q[[no]]$cpt$factor <- with(bn_q[[no]]$cpt, ifelse(eval(E[i]), 1, 0))
  }
  class(bn_q) <- class(bn)
  mean(amostrar_n_importance(bn_q, bn, n, E_var))
}

# Importance-sampling estimate of a (conditional) probability.
#
# bn : the network.
# Q  : optional query expressions.
# E  : optional evidence expressions.
# n  : number of samples per estimate.
#
# Returns NULL when neither Q nor E is supplied. When both are supplied,
# returns P(Q, E) / P(E), each estimated from its own independent sample.
# Otherwise returns the estimate for whichever one was supplied.
inferencia_importance_sampling <- function(bn, Q, E, n = 100) {
  tem_q <- !missing(Q)
  tem_e <- !missing(E)
  if (!tem_q && !tem_e) {
    return(NULL)
  }
  if (tem_q && tem_e) {
    # The joint estimate is computed first, then the evidence estimate.
    p_conjunta <- inferencia_importance_sampling_ev(bn, c(Q, E), n)
    p_evidencia <- inferencia_importance_sampling_ev(bn, E, n)
    return(p_conjunta / p_evidencia)
  }
  # Exactly one of Q and E was supplied.
  inferencia_importance_sampling_ev(bn, if (tem_e) E else Q, n)
}
jtrecenti/ea2 documentation built on May 20, 2019, 3:17 a.m.