R/sparsesolve.R

Defines functions sparsesolve pass_flow backsubstitute build_jt countfillin

Documented in sparsesolve

#' Solve a sparse linear system
#'
#' Solves \code{A x = b} by junction-tree Gaussian elimination. The entries of
#' \code{A} and \code{b} are distributed over the cliques of a junction tree.
#' Schur-complement flows are passed from the leaves towards the root, and the
#' reduced system at the root is solved densely. The remaining unknowns are
#' then recovered by back-substitution from the root towards the leaves.
#'
#' @param A square matrix of the system.
#' @param b right-hand side vector, of length \code{nrow(A)}.
#' @param jt junction tree of the graph corresponding to A (as returned by
#'   \code{build_jt}); built from the sparsity pattern of \code{A} when
#'   \code{NULL}.
#' @param forward order of elimination of the forward flow passing procedure:
#'   a 2-row matrix whose columns \code{(from, to)} are clique indices, one
#'   column per tree edge, leaves first. Derived from \code{root} when
#'   \code{NULL}.
#' @param root index of the root clique of the junction tree (choice is
#'   arbitrary); defaults to \code{max(1, floor(jt$N / 2))}.
#' @return The solution to the linear problem \code{AX = b}, named and ordered
#'   by variable index.
#' @references Paskin, M. and Lawrence D., \emph{Junction tree algorithms for
#'   solving sparse linear systems}, 2003.
#'
#' @export
sparsesolve <- function(A, b, jt = NULL, forward = NULL, root = NULL) {
  # Variables are identified by their index (as character) from here on.
  dimnames(A) <- list(seq_len(nrow(A)), seq_len(ncol(A)))
  names(b) <- seq_along(b)
  # Define junction tree from the sparsity pattern of A. The diagonal carries
  # no structure; self-loops would only distort build_jt's fill-in counts.
  if (is.null(jt)) {
    g <- graph_from_adjacency_matrix(A != 0, mode = "undirected",
                                     diag = FALSE)
    jt <- build_jt(g)
  }
  tree <- jt$graph
  n_cliques <- length(V(tree))
  # Define root of tree (choice is arbitrary); never 0, even for one clique.
  if (is.null(root)) root <- max(1, floor(jt$N / 2))
  # Define flow order: each path from the root ends with one tree edge
  # (parent, child); reversing the path list visits children before parents.
  if (is.null(forward)) {
    paths <- rev(all_simple_paths(tree, from = V(tree)[root]))
    forward <- vapply(paths, function(a) {
      c(which(V(tree) %in% a[length(a)]),
        which(V(tree) %in% a[length(a) - 1]))
    }, integer(2))
  }
  # Always a 2 x K matrix, also when K is 0 (one clique) or 1 (two cliques).
  forward <- matrix(forward, nrow = 2)
  # Distribute A and b over the cliques so that every entry is counted exactly
  # once: an entry is kept only in the first clique containing both its row
  # and its column variables. `owners` maps a variable to the earlier cliques
  # that contain it, so only overlapping cliques are inspected.
  mats <- vector("list", n_cliques)
  vecs <- vector("list", n_cliques)
  owners <- list()
  for (i in seq_len(n_cliques)) {
    members <- ach(jt$C[[i]])
    mats[[i]] <- A[members, members, drop = FALSE]
    vecs[[i]] <- b[members]
    for (l in unique(unlist(owners[members], use.names = FALSE))) {
      shared <- intersect(members, ach(jt$C[[l]]))
      mats[[i]][shared, shared] <- 0
      vecs[[i]][shared] <- 0
    }
    for (v in members) owners[[v]] <- c(owners[[v]], i)
  }
  # Pass flows towards root
  for (k in seq_len(ncol(forward))) {
    res <- pass_flow(forward[1, k], forward[2, k], mats, vecs, jt)
    mats <- res$matrix
    vecs <- res$vector
  }
  # Solve partial linear system at root (result is named by variable)
  sol <- solve(mats[[root]], vecs[[root]])
  # Backsubstitute from the root towards the leaves
  backward <- forward[, rev(seq_len(ncol(forward))), drop = FALSE]
  for (k in seq_len(ncol(backward))) {
    sol <- c(sol, backsubstitute(backward[1, k], backward[2, k],
                                 mats, vecs, jt, sol))
  }
  sol[order(as.numeric(names(sol)))]
}
#' Pass a flow between two adjacent cliques
#'
#' Eliminates the variables of clique \code{index_c} that are not shared with
#' clique \code{index_d}. The resulting Schur complement, and the matching
#' reduced right-hand side, are added to the shared block of clique
#' \code{index_d}.
#'
#' @param index_c index of the sending clique.
#' @param index_d index of the receiving clique.
#' @param matrix list of per-clique matrices, with variable dimnames.
#' @param vector list of per-clique named right-hand side vectors.
#' @param jt junction tree; \code{jt$C[[i]]} holds the variables of clique i.
#' @return A list with the updated \code{matrix} and \code{vector} lists.
#' @export
pass_flow <- function(index_c, index_d, matrix, vector, jt) {
  c_vars <- jt$C[[index_c]]
  d_vars <- jt$C[[index_d]]
  u <- as.character(setdiff(c_vars, d_vars))
  s <- as.character(intersect(c_vars, d_vars))
  m_c <- matrix[[index_c]]
  v_c <- vector[[index_c]]
  msg_matrix <- m_c[s, s, drop = FALSE]
  msg_vector <- v_c[s]
  # When clique c lies inside clique d there is nothing to eliminate, and
  # solve() would reject the 0 x 0 system.
  if (length(u) > 0) {
    m_su <- m_c[s, u, drop = FALSE]
    # A single solve with M_uu serves both the matrix and vector flows.
    z <- solve(m_c[u, u, drop = FALSE],
               cbind(m_c[u, s, drop = FALSE], v_c[u]))
    msg_matrix <- msg_matrix - m_su %*% z[, seq_along(s), drop = FALSE]
    msg_vector <- msg_vector - as.vector(m_su %*% z[, length(s) + 1L])
  }
  matrix[[index_d]][s, s] <- matrix[[index_d]][s, s] + msg_matrix
  vector[[index_d]][s] <- vector[[index_d]][s] + msg_vector
  list(matrix = matrix, vector = vector)
}
#' Recover the eliminated variables of a clique by back-substitution
#'
#' Solves \code{M_uu x_u = v_u - M_us x_s} for the variables \code{u} of
#' clique \code{index_c} that are not shared with clique \code{index_d}. The
#' values \code{x_s} of the shared variables are taken from \code{sol}.
#'
#' @param index_c index of the clique whose own variables are recovered.
#' @param index_d index of the adjacent clique closer to the root.
#' @param matrix list of per-clique matrices, with variable dimnames.
#' @param vector list of per-clique named right-hand side vectors.
#' @param jt junction tree; \code{jt$C[[i]]} holds the variables of clique i.
#' @param sol named vector of the variables solved so far.
#' @return Named numeric vector of the recovered variables (empty when clique
#'   \code{index_c} is contained in clique \code{index_d}).
#' @export
backsubstitute <- function(index_c, index_d, matrix, vector, jt, sol) {
  c_vars <- jt$C[[index_c]]
  d_vars <- jt$C[[index_d]]
  u <- as.character(setdiff(c_vars, d_vars))
  s <- as.character(intersect(c_vars, d_vars))
  # Nothing to recover; solve() would reject the 0 x 0 system.
  if (length(u) == 0) {
    return(structure(numeric(0), names = character(0)))
  }
  m_c <- matrix[[index_c]]
  rhs <- vector[[index_c]][u] - m_c[u, s, drop = FALSE] %*% sol[s]
  # solve() names the result after the column names of M_uu, i.e. u.
  solve(m_c[u, u, drop = FALSE], as.vector(rhs))
}
#' Build a junction tree by greedy vertex elimination
#'
#' Repeatedly picks the active vertex whose closed neighbourhood needs the
#' fewest fill-in edges. That neighbourhood is recorded as a clique and made
#' complete. Every vertex whose neighbourhood lies inside the clique is then
#' disconnected and deactivated. Finally, each clique is linked to the first
#' later clique that contains its separator.
#'
#' @param g an igraph graph (the sparsity graph of the system).
#' @param verbose report progress after each elimination step.
#' @return A list with the cliques \code{C}, the separators \code{S}, the
#'   elimination order \code{elim.order}, \code{connect} (index of the clique
#'   each clique is linked to), the junction tree \code{graph}, the number of
#'   cliques \code{N}, and \code{treewidth} (computed here as the largest
#'   clique size).
#' @export
build_jt <- function(g, verbose = FALSE) {
  stopifnot(vcount(g) > 0)
  N <- 0
  elim_order <- NULL
  C <- S <- list()
  active <- rep(TRUE, vcount(g))
  # main loop: one clique per elimination step
  while (any(active)) {
    # closed neighbourhoods and their induced subgraphs
    vois <- neighborhood(g, order = 1)
    gvois <- graph.neighborhood(g, order = 1)
    # missing edges in each neighbourhood
    fillin <- unlist(lapply(gvois, countfillin))
    # eliminate the active vertex of minimal fill-in (first one on ties)
    elim <- which(active)[which.min(fillin[active])]
    clique <- vois[[elim]]
    # removed set: vertices whose whole neighbourhood lies in the clique
    removed <- NULL
    for (v in clique) {
      if (all(is.element(vois[[v]], clique))) removed <- c(removed, v)
    }
    # add fill-in so that the clique is complete
    for (v in clique) g[v, setdiff(clique, v)] <- TRUE
    # disconnect and deactivate the removed set
    elim_order <- c(elim_order, removed)
    g[removed, ] <- FALSE
    active[removed] <- FALSE
    # record clique and separator
    N <- N + 1
    C[[N]] <- clique
    S[[N]] <- setdiff(clique, removed)
    if (verbose) message(N, " cliques, ", sum(active), " nodes left")
  }
  # link each clique to the first later clique containing its separator
  connect <- vector("list", N)
  edges <- integer(0)
  for (i in seq_len(N - 1)) {
    for (j in (i + 1):N) {
      if (all(is.element(S[[i]], C[[j]]))) {
        connect[[i]] <- j
        edges <- c(edges, i, j)
        break
      }
    }
  }
  connect[[N]] <- numeric(0)
  # build the JT; n = N keeps one vertex per clique, even when N == 1
  jt <- graph(edges, n = N, directed = FALSE)
  vertex_names <- V(g)$name
  labels <- vapply(C, function(z) {
    ids <- as.numeric(z)
    if (is.null(vertex_names)) {
      paste0(ids, collapse = ",")
    } else {
      paste0(vertex_names[ids], collapse = ",")
    }
  }, character(1))
  V(jt)$name <- paste0("C", seq_len(N), "={", labels, "}")
  list(C = C, S = S, elim.order = elim_order,
       connect = connect, graph = jt, N = N,
       treewidth = max(lengths(C)))
}
#' Count the edges missing for a graph to be complete
#'
#' @param g an igraph graph (in \code{build_jt}, the neighbourhood of a vertex).
#' @return \code{n (n - 1) / 2} minus the number of distinct adjacent vertex
#'   pairs, i.e. the number of fill-in edges needed to make \code{g} complete.
#' @export
countfillin <- function(g) {
  n <- vcount(g)
  # Self-loops (e.g. from a non-zero diagonal) and duplicate edges are not
  # vertex pairs; counting them would make the fill-in too small.
  m <- ecount(simplify(g))
  n * (n - 1) / 2 - m
}
#' @export
ach <- function(x, ...) {
  # Shorthand for as.character(); used to index matrices and vectors by
  # variable name.
  as.character(x, ...)
}
goepp/sparsesolve documentation built on May 30, 2019, 12:01 a.m.