R/local_average.R

Defines functions local_average

Documented in local_average

#' Local network averages
#'
#' Computes a local average of a vertex attribute for every vertex of a graph.
#' Numeric and logical values are handled as continuous values; factor and
#' character values are handled as categories, yielding one result column per
#' category.
#'
#' @param graph_id the graph for which to compute the local averages. Should
#'   be of type \code{chickenwire}.
#' @param vertex_ids a vector with vertex ids. When omitted it is assumed that
#'   the vertex values (\code{vertex_values}) are in the same order as used
#'   internally.
#' @param vertex_values a vector with vertex values. Can be numeric, logical,
#'   factor or character values. Must contain one value per vertex of the
#'   graph and no missing values.
#' @param vertex_weights a vector with vertex weights. Should be numeric,
#'   without missing values, and have either length 1 or the same length as
#'   \code{vertex_values}. Optional.
#' @param alpha probability of continuing the random walk (1-restart
#'   probability). Should be between 0 and 1.
#' @param nstep_max maximum number of iterations. Should be positive.
#' @param nworkers number of threads to use during the calculation.
#' @param precision when updates in the local averages are smaller than
#'   \code{precision} the algorithm is stopped. Should be larger than 0 and at
#'   most 1.
#'
#' @return
#' A data.frame. The rows correspond to the vertices. The first column,
#' \code{vertex_id}, contains the vertex ids. For numeric and logical
#' \code{vertex_values} there is one additional column, named after the
#' expression passed as \code{vertex_values}. For factor and character
#' \code{vertex_values} there is one additional column per level. The
#' attribute \code{nstep} of the result of the underlying computation is
#' copied to the returned data.frame.
#'
#' @export
local_average <- function(graph_id, vertex_ids, vertex_values, vertex_weights = 1.0, alpha = 0.85, 
   nstep_max = 200, nworkers = available_cores(), precision = 1E-5) {

  stopifnot(methods::is(graph_id, "chickenwire"))
  stopifnot(is.integer(graph_id) && length(graph_id) == 1)

  # vertex_id
  # `order` maps each supplied vertex to its position in the internal vertex
  # order; it is used below to put the results in the order of the caller.
  # NOTE(review): only the results are reordered with `order`; vertex_values
  # and vertex_weights are passed on in the order supplied by the caller.
  # Confirm that the compiled routines do not expect them in internal order.
  order <- seq_along(vertex_values)
  if (!missing(vertex_ids)) {
    stopifnot(length(vertex_values) == length(vertex_ids))
    ids <- attr(graph_id, "vertex_ids")
    if (!is.null(ids)) {
      order <- match(vertex_ids, ids)
      if (anyNA(order)) stop("Unknown vertex ids.")
    } else {
      warning("graph_id doesn't store vertex ids. Therefore, provided argument vertex_ids is ignored.")
    }
  } else {
    # Default to zero-based positions unless the graph stores its own ids.
    vertex_ids <- order - 1L
    ids <- attr(graph_id, "vertex_ids")
    if (!is.null(ids)) vertex_ids <- ids
  }
  # vertex_weights
  stopifnot(is.numeric(vertex_weights) && length(vertex_weights) >= 1)
  # Either one weight shared by all vertices or one weight per vertex. (The
  # comparison with 1 has to be on the length, not inside length().)
  stopifnot(length(vertex_weights) == 1 || length(vertex_weights) == length(vertex_values))
  stopifnot(!anyNA(vertex_weights))
  # vertex_values
  # Capture the expression of the argument before vertex_values is
  # overwritten; it becomes the column name for continuous values.
  value_name <- deparse(substitute(vertex_values))
  value_factor <- FALSE
  if (is.factor(vertex_values) || is.character(vertex_values)) {
    vertex_values <- as.factor(vertex_values)
    value_name <- levels(vertex_values)
    # Zero-based category codes.
    vertex_values <- as.integer(vertex_values) - 1
    value_factor <- TRUE
  } else if (is.logical(vertex_values)) {
    vertex_values <- 1.0 * vertex_values
  }
  stopifnot(is.numeric(vertex_values) && length(vertex_values) == nvertices(graph_id))
  stopifnot(!anyNA(vertex_values))
  # alpha
  stopifnot(is.numeric(alpha) && length(alpha) == 1)
  stopifnot(alpha >= 0 && alpha <= 1)
  # nstep_max
  stopifnot(is.numeric(nstep_max) && length(nstep_max) == 1)
  stopifnot(nstep_max > 0)
  # precision
  stopifnot(is.numeric(precision) && length(precision) == 1)
  stopifnot(precision > 0 && precision <= 1)
  # random_walk
  if (value_factor) {
    res <- rcpp_local_average_cat(graph_id, vertex_values, vertex_weights, alpha, nworkers, nstep_max, precision)
    nstep <- attr(res, "nstep")
    res <- as.data.frame(res)
    # when the highest levels are missing from the data set these are not 
    # included in the results. Fix this
    if (ncol(res) < length(value_name)) {
      for (col in seq(ncol(res) + 1L, length(value_name))) {
        res[[col]] <- 0.0
      }
    }
    names(res) <- value_name
    res <- cbind(data.frame(vertex_id = vertex_ids), res[order, , drop = FALSE])
    attr(res, "nstep") <- nstep
  } else {
    res <- rcpp_local_average_cont(graph_id, vertex_values, vertex_weights, alpha, nworkers, nstep_max, precision)
    nstep <- attr(res, "nstep")
    res <- data.frame(vertex_id = vertex_ids, res = res[order])
    names(res)[2] <- value_name
    attr(res, "nstep") <- nstep
  }
  res
}
djvanderlaan/chickenwire-r documentation built on July 19, 2022, 1:16 a.m.