R/operations.R

Defines functions order_bn assign_val bn_de bn_an product sum_out remove_barren

# Order the nodes of a bnlearn network using bnlearn's internal `schedule()`
# with `reverse = TRUE`.
# NOTE(review): `:::` reaches an unexported bnlearn function whose signature
# and semantics may change between bnlearn releases; the exact meaning of the
# reversed ordering is defined by bnlearn -- confirm against its source.
order_bn <- function(bn) bnlearn:::schedule(bn, reverse = TRUE)

# Restrict a factor to the rows consistent with a set of evidence expressions.
#
# f:    factor -- a data frame with one column per variable (named X<digits>)
#       plus a numeric `factor` column.
# expr: evidence, coercible with as.character() to one R expression string per
#       element, e.g. c("X1 == 'a'", "X3 == 'z'").
# Returns `f` keeping only rows where every applicable expression is TRUE
# (NA results are treated as FALSE). An expression is applied only when it
# references at least one `X<digits> ==` variable and all such variables are
# columns of `f`; other expressions are skipped. Always returns a data frame,
# even when `f` has a single column.
#
# SECURITY: each string is evaluated as R code via eval(parse()); never pass
# untrusted input.
assign_val <- function(f, expr) {
  exprs <- as.character(expr)
  # Variable names immediately followed (optionally after spaces) by `==`.
  # The lookahead keeps just the name, not the operator.
  refs <- regmatches(exprs, gregexpr("X[0-9]+(?=\\s*==)", exprs, perl = TRUE))
  f_res <- f
  for (i in seq_along(exprs)) {
    vars <- refs[[i]]
    if (length(vars) == 0 || !all(vars %in% names(f_res))) next
    # Evaluate with the factor's columns in scope (same lookup as with()).
    keep <- eval(parse(text = exprs[i]), envir = f_res)
    f_res <- f_res[keep & !is.na(keep), , drop = FALSE]
  }
  f_res
}

# Collect `X` together with every node reachable from it by following the
# `children` links of the nodes in `bn`.
#
# bn: named list of nodes; each node is a list with a `children` character
#     vector of node names.
# X:  character vector of starting node names.
# Returns an unnamed character vector, in breadth-first order, listing each
# node exactly once. Visited nodes are tracked, so descendants reachable along
# several paths are expanded only once (the naive recursion re-expanded them,
# which grows exponentially on chains of diamonds and produced duplicates).
bn_de <- function(bn, X) {
  found <- unique(unname(X))
  frontier <- found
  while (length(frontier) > 0) {
    kids <- unlist(lapply(bn[frontier], function(node) node$children),
                   use.names = FALSE)
    frontier <- setdiff(kids, found)
    found <- c(found, frontier)
  }
  found
}

# Collect `X` together with every node reachable from it by following the
# `parents` links of the nodes in `bn`.
#
# bn: named list of nodes; each node is a list with a `parents` character
#     vector of node names.
# X:  character vector of starting node names.
# Returns an unnamed character vector, in breadth-first order, listing each
# node exactly once. Visited nodes are tracked, so ancestors reachable along
# several paths are expanded only once (the naive recursion re-expanded them,
# which grows exponentially on chains of diamonds and produced duplicates).
bn_an <- function(bn, X) {
  found <- unique(unname(X))
  frontier <- found
  while (length(frontier) > 0) {
    pars <- unlist(lapply(bn[frontier], function(node) node$parents),
                   use.names = FALSE)
    frontier <- setdiff(pars, found)
    found <- c(found, frontier)
  }
  found
}


# Multiply a list of factors together, left to right.
#
# A factor is a data frame (or tibble) with one column per variable and a
# numeric `factor` column holding its values.
# - Factors sharing variables are combined with a left join on them.
# - Factors with no variables in common are combined by pairing every row of
#   one with every row of the other.
#
# lf: non-empty list of factors.
# Returns the product factor, sorted by its variable columns. When one operand
# has only a `factor` column it is treated as a scalar and the other factor is
# returned rescaled, without re-sorting. Stops with an error if `lf` is empty.
product <- function(lf) {
  if (length(lf) == 0) {
    stop("`lf` must contain at least one factor", call. = FALSE)
  }

  # Multiply two factors.
  product_one <- function(f1, f2) {
    shared <- setdiff(intersect(names(f1), names(f2)), "factor")
    if (length(shared) > 0) {
      # NOTE(review): a left join keeps f1 rows with no match in f2 and gives
      # them an NA product; presumably callers restrict both factors on the
      # same evidence -- confirm.
      f <- dplyr::left_join(f1, f2, by = shared)
      f$factor <- f$factor.x * f$factor.y
      f$factor.x <- NULL
      f$factor.y <- NULL
    } else {
      # A factor with only the `factor` column is a scalar multiplier.
      if (ncol(f1) == 1) {
        f2$factor <- f2$factor * f1$factor
        return(f2)
      }
      if (ncol(f2) == 1) {
        f1$factor <- f1$factor * f2$factor
        return(f1)
      }
      # Cartesian product built from row indices, so no NA-padded rows can
      # leak in as spurious variable values.
      n1 <- nrow(f1)
      n2 <- nrow(f2)
      i1 <- rep(seq_len(n1), times = n2)
      i2 <- rep(seq_len(n2), each = n1)
      f <- f1[i1, setdiff(names(f1), "factor"), drop = FALSE]
      rownames(f) <- NULL
      for (v in setdiff(names(f2), "factor")) {
        f[[v]] <- f2[[v]][i2]
      }
      f$factor <- f1$factor[i1] * f2$factor[i2]
    }
    vars <- setdiff(names(f), "factor")
    dplyr::arrange(f, dplyr::across(dplyr::all_of(vars)))
  }

  Reduce(product_one, lf)
}

# Sum (marginalise) variables out of a factor.
#
# f: factor -- a data frame with one column per variable plus a numeric
#    `factor` column.
# S: character vector of variable names to sum out; names that are not
#    columns of `f` are ignored.
# Returns `f` unchanged when none of `S` is a column of `f`. Otherwise returns
# an ungrouped data frame with one row per combination of the remaining
# variables and the summed `factor`; when no variables remain, a single row.
sum_out <- function(f, S) {
  nm <- names(f)
  if (!any(S %in% nm)) return(f)
  keep <- nm[!nm %in% c(S, "factor")]
  if (length(keep) == 0) {
    # Everything summed out: total mass in one row, no grouping needed.
    return(dplyr::summarise(f, factor = sum(factor)))
  }
  # across(all_of()) replaces the deprecated lazyeval verb group_by_().
  grouped <- dplyr::group_by(f, dplyr::across(dplyr::all_of(keep)))
  dplyr::ungroup(dplyr::summarise(grouped, factor = sum(factor)))
}

# Repeatedly prune barren nodes from a network.
#
# A barren node is one with no children that is neither a query variable
# (`Q_var`) nor an evidence variable (`E_var`). After each pruning round the
# removed names are summed out of every remaining node's `cpt` (via sum_out())
# and dropped from every remaining node's `children`. Rounds repeat until no
# barren node is left.
#
# bn: named list of nodes, each a list with `children` and `cpt` elements.
# Returns `bn` untouched if nothing is barren on the first pass; otherwise the
# pruned network as a plain named list.
remove_barren <- function(bn, Q_var, E_var) {
  repeat {
    pool <- setdiff(names(bn), c(Q_var, E_var))
    is_leaf <- vapply(pool, function(v) length(bn[[v]]$children) == 0,
                      logical(1))
    if (!any(is_leaf)) break
    barren <- pool[is_leaf]
    bn[barren] <- NULL
    bn <- lapply(bn, function(nd) {
      nd$cpt <- sum_out(nd$cpt, barren)
      nd$children <- setdiff(nd$children, barren)
      nd
    })
  }
  bn
}
jtrecenti/ea2 documentation built on May 20, 2019, 3:17 a.m.