R/k_tcherry_step.R

Defines functions k_tcherry_step

Documented in k_tcherry_step

#' Determine a k'th order t-cherry tree from data
#'
#' @description Determine the structure of a k'th order t-cherry tree
#' from data based on a greedy stepwise approach.
#'
#' @param data The data the tree structure should be based on. Must be a
#' data frame or a matrix, and all columns must be characters or factors.
#' @param k The order of the t-cherry tree. A single integer that is at
#' least 2 and at most the number of columns in \code{data}.
#' @param ... Additional arguments passed to \code{MIk}.
#'
#' @details Notice that for \eqn{k = 3} it is the same as using
#' \code{tcherry_step} and for \eqn{k = 2} it is the same as using
#' \code{ChowLiu}.
#'
#' The algorithm for constructing the t-cherry tree from
#' data is based on an attempt to minimize the Kullback-Leibler
#' divergence. The first cherry is chosen as the k variables with
#' highest mutual information. This is the preliminary t-cherry
#' tree. Then all possible new cherries are added stepwise to this
#' tree and the weight \deqn{\sum MI(clique) - \sum MI(separator)} is
#' calculated.
#' The first sum is over the cliques and the second over the
#' separators of the junction tree of the preliminary t-cherry tree.
#' The one with the highest weight is chosen as the new preliminary
#' t-cherry tree, and the procedure is repeated until all variables
#' have been added.
#'
#' A warning is issued if \code{data} contains NA values.
#'
#' @return A list containing the following components:
#' \itemize{
#' \item \code{adj_matrix} The adjacency matrix for the k'th order
#' t-cherry tree.
#' \item \code{weight} The weight of the final k'th order t-cherry tree.
#' \item \code{cliques} A list containing the cliques of
#'  the k'th order t-cherry tree.
#' \item \code{separators} A list containing the separators of a
#' junction tree for the k'th order t-cherry tree.
#' \item \code{n_edges} The number of edges in the resulting graph.
#' }
#'
#' @author
#' Katrine Kirkeby, \email{enir_tak@@hotmail.com}
#'
#' Maria Knudsen, \email{mariaknudsen@@hotmail.dk}
#'
#' Ninna Vihrs, \email{ninnavihrs@@hotmail.dk}
#'
#' @seealso \code{\link{MIk}} for mutual
#' information of k variables.
#'
#' @examples
#' set.seed(43)
#' var1 <- c(sample(c(1, 2), 100, replace = TRUE))
#' var2 <- var1 + c(sample(c(1, 2), 100, replace = TRUE))
#' var3 <- var1 + c(sample(c(0, 1), 100, replace = TRUE,
#'                         prob = c(0.9, 0.1)))
#' var4 <- c(sample(c(1, 2), 100, replace = TRUE))
#' var5 <- var2 + var3
#' var6 <- var1 - var4 + c(sample(c(1, 2), 100, replace = TRUE))
#' var7 <- c(sample(c(1, 2), 100, replace = TRUE))
#'
#' data <- data.frame("var1" = as.character(var1),
#'                    "var2" = as.character(var2),
#'                    "var3" = as.character(var3),
#'                    "var4" = as.character(var4),
#'                    "var5" = as.character(var5),
#'                    "var6" = as.character(var6),
#'                    "var7" = as.character(var7))
#'
#' # smooth used in MIk
#' (tch <- k_tcherry_step(data, 3, smooth = 0.1))
#' @export

k_tcherry_step <- function(data, k, ...) {

  # Input validation ----
  # Check the container type first so the NA scan below only ever runs on
  # a data frame or matrix.
  if (!(is.data.frame(data) || is.matrix(data))) {
    stop("data must be a data frame or a matrix.")
  }

  if (any(is.na(data))) {
    warning(paste("The data contains NA values.",
                  "Theese will be excluded from tables,",
                  "which may be problematic.",
                  "It is highly recommended to manually take",
                  "care of NA values before using the data as input.",
                  sep = " "))
  }

  data <- as.data.frame(data)

  is_discrete <- vapply(data, function(x) {
    is.character(x) || is.factor(x)
  }, logical(1))
  if (!all(is_discrete)) {
    stop("Some columns are not characters or factors.")
  }

  if (length(k) != 1) {
    stop("k must be a single positive integer.")
  }

  # Scalar `||` short-circuits, so an NA or non-numeric k reaches this
  # message instead of failing inside `%%` or the `if` condition.
  if (!is.numeric(k) || is.na(k) || k %% 1 != 0 || k <= 1) {
    stop("k must be a positive integer and at least 2.")
  }

  nodes <- colnames(data)
  n_var <- length(nodes)

  # Without this check utils::combn() fails with an opaque "n < m" error.
  if (k > n_var) {
    stop("k must not be larger than the number of variables in data.")
  }

  adj_matrix <- matrix(0, nrow = n_var, ncol = n_var,
                       dimnames = list(nodes, nodes))
  # The first cherry plus one cherry per remaining node gives
  # n_var - k + 1 cliques and n_var - k separators; every slot is filled.
  cliques <- vector("list", n_var - (k - 1))
  separators <- vector("list", n_var - k)

  # Adding first cherry ----
  # Start from the k variables with the highest mutual information.
  poss_cliq <- utils::combn(nodes, k, simplify = FALSE)
  MI <- vapply(poss_cliq, MIk, numeric(1), data = data, ...)
  idx.max <- which.max(MI)

  tcherry_nodes <- poss_cliq[[idx.max]]
  cliques[[1]] <- tcherry_nodes
  nodes_remaining <- setdiff(nodes, tcherry_nodes)
  # Every (k-1)-subset of the first clique is a possible separator.
  tcherry_hyperedges <- utils::combn(tcherry_nodes, k - 1, simplify = FALSE)

  adj_matrix[tcherry_nodes, tcherry_nodes] <- 1
  diag(adj_matrix[tcherry_nodes, tcherry_nodes]) <- 0

  # MI of the clique actually chosen (equals max(MI) when no MI is NA).
  weight <- MI[[idx.max]]

  # Adding remaining cherries ----
  # `dat_new_poss` holds every candidate cherry: a separator (a
  # (k-1)-subset of an existing clique) plus one node not yet in the tree.
  # With r nodes remaining, 1 + (k - 1) * (n_var - r - (k - 1)) separators
  # have been generated, so r times that many candidates are live at once.
  # That product peaks at r = floor((n_var - k) / 2 + 1), so the table is
  # preallocated with exactly that many rows.
  idx.dat <- 1
  idx.list <- 1

  n_nodes_remaining_median <- floor((n_var - k) / 2 + 1)
  n_hyp_edges_median <- 1 + (k - 1) * ((n_var - n_nodes_remaining_median)
                                       - (k - 1))
  n_rows <- n_nodes_remaining_median * n_hyp_edges_median

  empty_col <- rep(NA, n_rows)
  dat_new_poss <- data.frame(new_cliq = I(as.list(empty_col)),
                             new_sep = I(as.list(empty_col)),
                             new_var = empty_col,
                             MI_cliq = empty_col,
                             MI_sep = empty_col,
                             weight_increase = empty_col)

  while (length(tcherry_nodes) != n_var) {
    # Score the candidates built on the separators created in the previous
    # step (on the first pass: all (k-1)-subsets of the first clique).
    for (sep in tcherry_hyperedges) {
      # The separator's MI does not depend on the added node, so compute
      # it once per separator rather than once per candidate.
      mi_sep <- MIk(sep, data, ...)
      for (var in nodes_remaining) {
        cliq <- c(sep, var)
        mi_cliq <- MIk(cliq, data, ...)

        dat_new_poss$new_sep[[idx.dat]] <- sep
        dat_new_poss$new_cliq[[idx.dat]] <- cliq
        dat_new_poss$new_var[idx.dat] <- var
        dat_new_poss$MI_sep[idx.dat] <- mi_sep
        dat_new_poss$MI_cliq[idx.dat] <- mi_cliq
        dat_new_poss$weight_increase[idx.dat] <- mi_cliq - mi_sep
        idx.dat <- idx.dat + 1
      }
    }

    # Greedy step: add the cherry with the largest weight increase.
    idx_max_weight <- which.max(dat_new_poss$weight_increase)
    weight <- weight + dat_new_poss$weight_increase[idx_max_weight]

    new_clique <- dat_new_poss$new_cliq[[idx_max_weight]]
    new_sep <- dat_new_poss$new_sep[[idx_max_weight]]
    new_var <- dat_new_poss$new_var[idx_max_weight]

    tcherry_nodes <- c(tcherry_nodes, new_var)
    nodes_remaining <- setdiff(nodes, tcherry_nodes)

    cliques[[idx.list + 1]] <- new_clique
    separators[[idx.list]] <- new_sep
    idx.list <- idx.list + 1

    # The new clique is new_sep plus new_var, so its only (k-1)-subset
    # without new_var is new_sep itself, whose candidates are already in
    # the table. Only the subsets containing new_var are new separators.
    tcherry_hyperedges <- Filter(function(e) new_var %in% e,
                                 utils::combn(new_clique, k - 1,
                                              simplify = FALSE))

    adj_matrix[new_clique, new_clique] <- 1
    diag(adj_matrix[new_clique, new_clique]) <- 0

    # Candidates that would add new_var are now invalid. Blank them, then
    # sort so all blank rows (NA weight) sit at the end and the next batch
    # of candidates is written from the first blank row onward.
    dat_new_poss[dat_new_poss$new_var %in% new_var, ] <- NA
    dat_new_poss <- dat_new_poss[order(dat_new_poss$weight_increase,
                                       decreasing = TRUE), ]
    idx.dat <- which(is.na(dat_new_poss$new_var))[1]
  }

  n_edges_graph <- sum(adj_matrix) / 2

  list("adj_matrix" = adj_matrix,
       "weight" = weight,
       "cliques" = cliques,
       "separators" = separators,
       "n_edges" = n_edges_graph)
}
nvihrs14/tcherry documentation built on Aug. 1, 2020, 6:25 p.m.