R/distance.R

Defines functions GetDistance FilterSplits GroupDistance InternalDistance group_distance internal_distance single_linkage_distance complete_linkage_distance average_linkage_distance centroid_linkage_distance ward_distance mahalanobis_distance internal_linkage_distance internal_mahalanobis_distance PlotGroupMeans plot_distance PlotDistance AverageDistance PlotAverageDistance PlotInternalDistance TotalDistance PlotTotalDistance group_mean standard_error average_distance total_distance total_linkage cell_total_linkage RunCramer

Documented in average_distance AverageDistance average_linkage_distance cell_total_linkage centroid_linkage_distance complete_linkage_distance FilterSplits GetDistance group_distance GroupDistance group_mean internal_distance InternalDistance internal_linkage_distance internal_mahalanobis_distance mahalanobis_distance PlotAverageDistance plot_distance PlotDistance PlotGroupMeans PlotInternalDistance PlotTotalDistance RunCramer single_linkage_distance standard_error total_distance TotalDistance total_linkage ward_distance

#####################################################
# Seurat Distance Methods
#####################################################


#' Get distance data from seurat object
#'
#' @description
#' Simple getter method to extract data frame of distances in seurat object.
#'
#' @details
#' This does not calculate the distances for any categories or cells. This is
#' purely for indexing information from an object. The table is read from the
#' \code{tools} slot under a key made of a method-specific prefix followed by
#' \code{group} and \code{split_by} (when supplied), joined by underscores.
#' An error is raised when \code{method} is not one of the listed options or
#' when nothing is stored under the resulting key.
#'
#' @param seurat_object A seurat object. Contains the distance data.
#' @param group Category that the distances were calculated between
#' @param split_by Factor that the distances were split by
#' @param method The distance function that was used. Options are "default" or
#' "cluster" (GroupDistance()), "average" (AverageDistance()), "total"
#' (TotalDistance), "internal" (InternalDistance)
#'
#' @return Distance data frame.
#' @export
GetDistance <- function(seurat_object,
                        group = NULL,
                        split_by = NULL,
                        method = "default") {
  if (!is.character(method) || length(method) != 1L || is.na(method)) {
    stop("`method` must be a single string: 'default', 'cluster', ",
         "'average', 'total' or 'internal'.")
  }

  # Map the method onto the prefix used when the table was stored. An
  # unrecognised method previously fell through every branch and surfaced as
  # "object 'name' not found".
  prefix <- switch(method,
    average = "average_distance",
    total = "total_distance",
    default = ,
    cluster = "cluster_distance",
    internal = "internal_distance",
    stop("Unknown method '", method, "'. Options are 'default', 'cluster', ",
         "'average', 'total' or 'internal'.")
  )

  # Build the storage key
  if (method == "internal") {
    # Mirror how the internal tables are written: without a split only the
    # group is appended, so split_by must not be pasted when it is NULL.
    if (is.null(split_by)) {
      name <- paste(prefix, group, sep = "_")
    } else {
      name <- paste(prefix, group, split_by, sep = "_")
    }
  } else {
    name <- prefix
    if (!is.null(group)) {
      name <- paste(name, group, sep = "_")
    }
    if (!is.null(split_by)) {
      name <- paste(name, split_by, sep = "_")
    }
  }

  # get the data
  data <- seurat_object@tools[[name]]

  if (is.null(data)) {
    stop(paste("There is no data stored at the location", name, sep = " "))
  }
  data
}


#' Filters out split factors that are not present in all of your groups.
#'
#' @description
#' If you attempt to run the distances on groups split by a second factor you
#' will get an error if the split factors are not present in each group.
#' One can imagine if your groups were treatment groups and split was by cell
#' type. If your treatment groups did not contain all the cell types this will
#' throw an error. This function will remove any factors of the split that are
#' not present in all groups.
#'
#' @details
#' A temporary metadata column named \code{keep_cells} is added to mark the
#' cells to retain and is removed again before the object is returned.
#'
#' @inheritParams GroupDistance
#'
#' @return The input seurat object with irrelevant cells filtered out.
#' @export
FilterSplits <- function(seurat_object,
                         group = "default",
                         split_by = NULL) {
    # split_by defaults to NULL but is required; fail with a clear message
    # instead of an obscure subscript error further down.
    if (is.null(split_by)) {
        stop("`split_by` must name the metadata column to filter on.")
    }

    meta_data <- seurat_object@meta.data

    # Start from every split level and keep only those seen in every group
    allowable_levels <- levels(factor(meta_data[[split_by]]))

    # Find all the factors that are present in splitby for each of the groups
    group_levels <- levels(as.factor(as.matrix(seurat_object[[group]])))

    # seq_along() so that an empty set of groups does not iterate over c(1, 0)
    for (i in seq_along(group_levels)) {
        # find the splits that are present in that group
        cells <- meta_data[meta_data[[group]] %in% group_levels[i], ]
        # what levels does their split_by have
        split_levels <- levels(factor(cells[[split_by]]))
        allowable_levels <- allowable_levels[allowable_levels %in% split_levels]
    }

    # Filter out cells whose split level is missing from at least one group
    keep <- as.matrix(seurat_object[[split_by]]) %in% allowable_levels
    seurat_object[["keep_cells"]] <- keep
    seurat_object <- subset(seurat_object,
                            subset = keep_cells == TRUE)
    seurat_object[["keep_cells"]] <- NULL
    return(seurat_object)
}

#' Calculates the distance between Categories of single cells
#'
#' @description
#' Takes in an Seurat s4 object with categorically labelled cells and calculates
#' the 'distance' between the cells in each category through a choice of
#' different methods. In this method each category is defined as a cluster and
#' hierarchical clustering methods are used to quantify the distances between
#' the clusters.
#'
#' This includes multiple methods to measure distances
#' between the groups/categories themselves (centroid vs average
#' distance) as well as more fundamentally in terms of the distance between
#' the points (euclidean vs manhattan distance).
#'
#' @param seurat_object A seurat object
#' @param group Seurat Categories or groups for which the distances between
#' are calculated. Cell wise data.
#' @param reduction Dimensionality reduction data to use
#' @param dims Which dimensions to use
#' @param method Cluster Distance methods to use. Options are "single",
#' "complete", "average", "centroid", "ward", "mahalanobis". Further explanation
#' for these methods is given in the details.
#' @param distance Point to point distance to use. Options are "euclidean",
#' "maximum", "manhattan", "canberra", "binary", "minkowski"
#' @param split_by Second seurat category to split the calculations over.
#' @param output Output type. Default is "seurat". Options are "seurat" (Seurat
#' S4 object), "table" (Table of distance data), "list" (List of tables by
#' split_by), "seurat_list" (List of Seurat objects). Seurat objects returned
#' have the distance data stored internally.
#'
#' @details
#' It is possible if calculating distances split by a factor to input a list of
#' Seurat S4 objects.
#'
#' Cluster distance methods available for method:
#'
#'
#' \describe{
#'   \item{single}{Shortest between any two points in each cluster}
#'   \item{complete}{Longest distance between any two points in each cluster}
#'   \item{average}{The average distance between any two points in each cluster}
#'   \item{centroid}{The distance between the centroid of each cluster}
#'   \item{ward}{Distance is defined as the increase in variance if two clusters
#'     are merged}
#'   \item{mahalanobis}{Similar to Mahalanobis distance. Calculates distance
#'     between means of each cluster, weighted by the covariance matrices.}
#' }
#' Point distance methods available for distance:
#'
#'\describe{
#'   \item{euclidean}{Usual distance between the two vectors sqrt(sum((x_i - y_i)^2))}
#'   \item{maximum}{Maximum distance between two components of x and y}
#'   \item{manhattan}{Absolute distance between two vectors}
#'   \item{canberra}{sum(|x_i - y_i| / (|x_i| + |y_i|))}
#'   \item{binary}{The vectors are regarded as binary bits, so non-zero elements
#'     are ‘on’ and zero elements are ‘off’. The distance is the proportion of
#'     bits in which only one is on amongst those in which at least one is on.}
#'   \item{minkowski}{The p norm, the pth root of the sum of the pth powers of
#'   the differences of the components.}
#'}
#' @return Returns distance data based on output parameter.
#' @export
GroupDistance <- function(seurat_object,
                          group,
                          reduction = "pca",
                          dims = 1:30,
                          method,
                          distance = "euclidean",
                          split_by = NULL,
                          output = "seurat") {
  is_seurat_input <- typeof(seurat_object) == "S4"

  # A single object without a split needs no bookkeeping: delegate directly.
  if (is_seurat_input && is.null(split_by)) {
    return(group_distance(seurat_object,
                          reduction = reduction,
                          group = group,
                          dims = dims,
                          method = method,
                          distance = distance,
                          output = output))
  }

  # Validate the output type before any distance is computed. Both the
  # documented "seurat_list" and the historical "seurat list" are accepted.
  wants_seurat_list <- output %in% c("seurat list", "seurat_list")
  if (!wants_seurat_list && !output %in% c("seurat", "table", "list")) {
    stop("Choose output = 'table', 'list', 'seurat' or 'seurat_list'")
  }
  if (output == "seurat" && !is_seurat_input) {
    stop("output = 'seurat' needs a single Seurat object as input; use ",
         "'table', 'list' or 'seurat_list' when passing a list of objects.")
  }

  # Does the object need to be split
  if (is_seurat_input) {
    seurat_list <- Seurat::SplitObject(seurat_object, split.by = split_by)

    # If the split labels are numeric, ensure order. as.numeric() warns on
    # non-numeric labels, in which case the original order is kept.
    tryCatch({
      seurat_list <- seurat_list[order(as.numeric(names(seurat_list)))]
    }, warning = function(c) {
      message("Was not able to order the split_by")
    })
  } else {
    seurat_list <- seurat_object
  }
  if (length(seurat_list) == 0) {
    stop("There are no Seurat objects to calculate distances for.")
  }

  # Per-split results are Seurat objects only when a list of them is wanted
  per_split_output <- if (wants_seurat_list) "seurat" else "table"

  result <- vector("list", length(seurat_list))
  for (i in seq_along(seurat_list)) {
    result[[i]] <- group_distance(seurat_list[[i]],
                                  reduction = reduction,
                                  group = group,
                                  dims = dims,
                                  method = method,
                                  distance = distance,
                                  output = per_split_output)
  }

  if (wants_seurat_list) {
    return(result)
  }
  if (output == "list") {
    names(result) <- names(seurat_list)
    return(result)
  }

  # "table" and "seurat" both need one combined table, labelled by split.
  split_labels <- names(seurat_list)
  if (is.null(split_labels)) {
    # Unnamed list input: fall back to the position in the list
    split_labels <- as.character(seq_along(result))
  }
  for (i in seq_along(result)) {
    # rep() keeps this valid for a split that produced no rows
    result[[i]]$split.by <- rep(split_labels[i], nrow(result[[i]]))
  }
  combined <- do.call(rbind, result)

  # Make split.by the first column
  combined <- combined[, c("split.by",
                           setdiff(colnames(combined), "split.by"))]

  if (output == "table") {
    return(combined)
  }

  # output == "seurat": store the table inside the original object
  location <- paste("cluster_distance", group, split_by, sep = "_")
  seurat_object@tools[[location]] <- combined
  seurat_object
}

#' Calculates internal distance
#'
#' @description
#' Summarises the internal distances of the cells in a Seurat object,
#' optionally per level of \code{group} and separately for every level of
#' \code{split_by}.
#'
#' @inheritParams GroupDistance
#' @param method Internal distance methods to use. Options are "linkage" and
#' "mahalanobis".
#' @param output Output type. Default is "seurat". Options are "seurat" (Seurat
#' S4 object with the data stored within), "table" (table containing the
#' information), "list" (list of tables by split_by), "seurat_list" (list of
#' Seurat objects).
#'
#' @return Table or Seurat object with internal distance information
#'
#' @details
#' For more information about the implementation of this method see the other
#' methods in this function family. When \code{split_by} is NULL the work is
#' delegated to a single call of \code{internal_distance}; otherwise the
#' object is split and the per-split tables are combined with a leading
#' \code{split.by} column.
#'
#' @family internal_distance
#' @export
InternalDistance <- function(seurat_object,
                             group = NULL,
                             reduction = "pca",
                             dims = 1:30,
                             method = "linkage",
                             distance = "euclidean",
                             split_by = NULL,
                             output = "seurat") {
  is_seurat_input <- typeof(seurat_object) == "S4"

  # A single object without a split needs no bookkeeping: delegate directly.
  if (is_seurat_input && is.null(split_by)) {
    return(internal_distance(seurat_object,
                             reduction = reduction,
                             group = group,
                             dims = dims,
                             method = method,
                             distance = distance,
                             output = output))
  }

  # Validate the output type before any distance is computed. Both the
  # documented "seurat_list" and the historical "seurat list" are accepted.
  wants_seurat_list <- output %in% c("seurat list", "seurat_list")
  if (!wants_seurat_list && !output %in% c("seurat", "table", "list")) {
    stop("Choose output = 'table', 'list', 'seurat' or 'seurat_list'")
  }
  if (output == "seurat" && !is_seurat_input) {
    stop("output = 'seurat' needs a single Seurat object as input; use ",
         "'table', 'list' or 'seurat_list' when passing a list of objects.")
  }

  # Does the object need to be split
  if (is_seurat_input) {
    seurat_list <- Seurat::SplitObject(seurat_object, split.by = split_by)

    # If the split labels are numeric, ensure order. as.numeric() warns on
    # non-numeric labels, in which case the original order is kept.
    tryCatch({
      seurat_list <- seurat_list[order(as.numeric(names(seurat_list)))]
    }, warning = function(c) {
      message("Was not able to order the split_by")
    })
  } else {
    seurat_list <- seurat_object
  }
  if (length(seurat_list) == 0) {
    stop("There are no Seurat objects to calculate distances for.")
  }

  # Per-split results are Seurat objects only when a list of them is wanted
  per_split_output <- if (wants_seurat_list) "seurat" else "table"

  result <- vector("list", length(seurat_list))
  for (i in seq_along(seurat_list)) {
    result[[i]] <- internal_distance(seurat_list[[i]],
                                     reduction = reduction,
                                     group = group,
                                     dims = dims,
                                     method = method,
                                     distance = distance,
                                     output = per_split_output)
  }

  if (wants_seurat_list) {
    return(result)
  }
  if (output == "list") {
    names(result) <- names(seurat_list)
    return(result)
  }

  # "table" and "seurat" both need one combined table, labelled by split.
  split_labels <- names(seurat_list)
  if (is.null(split_labels)) {
    # Unnamed list input: fall back to the position in the list
    split_labels <- as.character(seq_along(result))
  }
  for (i in seq_along(result)) {
    # rep() keeps this valid for a split that produced no rows
    result[[i]]$split.by <- rep(split_labels[i], nrow(result[[i]]))
  }
  combined <- do.call(rbind, result)

  # Make split.by the first column
  combined <- combined[, c("split.by",
                           setdiff(colnames(combined), "split.by"))]

  if (output == "table") {
    return(combined)
  }

  # output == "seurat": store the table inside the original object
  location <- paste("internal_distance", group, split_by, sep = "_")
  seurat_object@tools[[location]] <- combined
  seurat_object
}

###################################################
# Distance Helper Functions
###################################################

#' Calculates the distance between groups in seurat object
#'
#' @inheritParams GroupDistance
#'
#' @description
#' Computes, for every unordered pair of levels of \code{group}, each of the
#' requested cluster distances on the selected dimensions of the reduction.
#'
#' @details
#' The distance columns always appear in the order "single", "complete",
#' "average", "centroid", "ward", "mahalanobis", whatever the order of
#' \code{method}. If more dimensions are requested than the reduction holds,
#' all available dimensions are used instead. Distances are stored as numeric
#' columns; \code{start} and \code{destination} hold the group names as
#' character.
#'
#'@param output Output type. Default is "table" (Table of distance data),
#' "seurat" (Seurat S4 object with the table stored in the tools slot under
#' "cluster_distance_<group>").
#'
#' @return Distance data
group_distance <- function(seurat_object,
                          group,
                          reduction = "pca",
                          dims = 1:30,
                          method,
                          distance = "euclidean",
                          output = "table") {
  known_methods <- c("single", "complete", "average", "centroid", "ward",
                     "mahalanobis")

  # Reject bad arguments before any expensive work. An unknown method used to
  # produce a table whose columns and names did not line up.
  unknown_methods <- setdiff(method, known_methods)
  if (length(unknown_methods) > 0) {
    stop("Unknown cluster distance method(s): ",
         paste(unknown_methods, collapse = ", "),
         ". Options are: ", paste(known_methods, collapse = ", "), ".")
  }
  if (!output %in% c("table", "seurat")) {
    stop("Please use the output parameters of either 'table', or 'seurat'.")
  }
  # Fixed column order, independent of the order given in `method`
  methods_used <- known_methods[known_methods %in% method]

  # Ensure group is a factor
  seurat_object[[group]] <- factor(as.matrix(seurat_object[[group]]))

  # Correcting number of dimensions if too many were chosen
  embeddings <- seurat_object@reductions[[reduction]]@cell.embeddings
  if (length(dims) > ncol(embeddings)) {
    dims <- seq_len(ncol(embeddings))
  }

  # Group label of every cell and the distinct groups to compare
  cell_groups <- as.character(seurat_object[[group]][, 1])
  group_names <- unique(cell_groups)

  # Dispatch one cluster distance by name
  cluster_distance <- function(name, cells_1, cells_2) {
    switch(name,
      single = single_linkage_distance(cells_1, cells_2, distance),
      complete = complete_linkage_distance(cells_1, cells_2, distance),
      average = average_linkage_distance(cells_1, cells_2, distance),
      centroid = centroid_linkage_distance(cells_1, cells_2, distance),
      ward = ward_distance(cells_1, cells_2, distance),
      mahalanobis = mahalanobis_distance(cells_1, cells_2, distance)
    )
  }

  rows <- list()
  for (i in seq_along(group_names)) {
    # drop = FALSE keeps a matrix even when a group holds a single cell
    group_1 <- embeddings[cell_groups == group_names[i], dims, drop = FALSE]

    for (j in seq_along(group_names)) {
      # Each unordered pair is only visited once
      if (j <= i) {
        next
      }
      group_2 <- embeddings[cell_groups == group_names[j], dims, drop = FALSE]

      # One-row data frame so the distances stay numeric instead of being
      # coerced to character alongside the group names
      new_line <- data.frame(start = group_names[i],
                             destination = group_names[j],
                             stringsAsFactors = FALSE)
      for (method_name in methods_used) {
        new_line[[method_name]] <- cluster_distance(method_name,
                                                    group_1,
                                                    group_2)
      }
      rows[[length(rows) + 1L]] <- new_line
    }
  }

  if (length(rows) > 0) {
    result <- do.call(rbind, rows)
  } else {
    # Fewer than two groups: empty table with the expected columns
    result <- data.frame(start = character(0),
                         destination = character(0),
                         stringsAsFactors = FALSE)
    for (method_name in methods_used) {
      result[[method_name]] <- numeric(0)
    }
  }

  # Output format
  if (output == "table") {
    return(result)
  }
  seurat_object@tools[[paste("cluster_distance", group, sep = "_")]] <- result
  seurat_object
}

#' Calculate the distance of all the points within a single cluster to the
#' centre of that cluster
#'
#' @description
#' Enables calculating internal mean, sd and se of the distance that
#' points in a cluster have to the centre of a cluster.
#'
#' @details
#' This is an internal helper method for the
#' function \code{\link{InternalDistance}}. When \code{group} is NULL all
#' cells are treated as one cluster; otherwise one summary is produced per
#' level of \code{group}. In both cases only the dimensions in \code{dims}
#' are used. If more dimensions are requested than the reduction holds, all
#' available dimensions are used and a message is emitted.
#'
#' @seealso \code{\link{InternalDistance}}
#'
#' @inheritParams InternalDistance
#'
#' @family internal_distance
#'
#' @return table of mean, sd and se of internal cluster distances (with a
#' leading \code{group} column when \code{group} is given), or the Seurat
#' object with that table stored in its tools slot when
#' \code{output = "seurat"}.
#'
internal_distance <- function(seurat_object,
                              group = NULL,
                              reduction = "pca",
                              dims = 1:30,
                              method,
                              distance = "euclidean",
                              output = "table") {
  # Reject bad arguments before any distance is computed
  known_methods <- c("linkage", "mahalanobis")
  methods_used <- known_methods[known_methods %in% method]
  if (length(methods_used) == 0) {
    stop("Please use a valid internal distance method. Either 'linkage' or
           'mahalanobis'.")
  }
  if (!output %in% c("table", "seurat")) {
    stop("Please use the output parameters of either 'table', or 'seurat'.")
  }

  # Isolate all the cell embeddings
  embeddings <- Embeddings(seurat_object, reduction = reduction)

  # Correcting number of dimensions if too many were chosen
  if (length(dims) > ncol(embeddings)) {
    dims <- seq_len(ncol(embeddings))
    # Report the count: pasting the whole dims vector repeated the sentence
    # once per dimension.
    message(paste("Too many dimensions were chosen now reduced it down to",
                  length(dims)))
  }

  # Mean, sd and se of the requested internal distances for one set of cells
  summarise_cells <- function(cell_embeddings) {
    rows <- lapply(methods_used, function(method_name) {
      distances <- switch(method_name,
        linkage = internal_linkage_distance(cell_embeddings,
                                            distance = distance,
                                            output = "raw"),
        mahalanobis = internal_mahalanobis_distance(cell_embeddings,
                                                    distance,
                                                    output = "raw")
      )
      data.frame(method = method_name,
                 mean = mean(distances),
                 sd = stats::sd(distances),
                 se = stats::sd(distances) / sqrt(length(distances)))
    })
    do.call(rbind, rows)
  }

  # If required split, based on group.
  if (is.null(group)) {
    # drop = FALSE keeps a matrix even when a single dimension is selected
    result <- summarise_cells(embeddings[, dims, drop = FALSE])
  } else {
    # Ensure group is a factor
    seurat_object[[group]] <- factor(as.matrix(seurat_object[[group]]))

    cell_groups <- seurat_object[[group]][, 1]
    unique_groups <- unique(cell_groups)

    rows <- vector("list", length(unique_groups))
    for (i in seq_along(unique_groups)) {
      # Get embeddings for that group, restricted to the chosen dimensions
      # (dims used to be ignored in this branch).
      group_key <- as.character(cell_groups) == as.character(unique_groups[i])
      group_embeddings <- embeddings[group_key, dims, drop = FALSE]

      rows[[i]] <- data.frame(group = unique_groups[i],
                              summarise_cells(group_embeddings))
    }

    if (length(rows) > 0) {
      result <- do.call(rbind, rows)
    } else {
      # No groups present: empty table with the expected columns
      result <- data.frame(matrix(ncol = 5, nrow = 0))
      colnames(result) <- c("group", "method", "mean", "sd", "se")
    }
  }

  # Outputs
  if (output == "table") {
    return(result)
  }
  location <- paste("internal_distance", group, sep = "_")
  seurat_object@tools[[location]] <- result
  seurat_object
}

#' Calculates the single linkage distance between two clusters
#'
#' @description
#' Uses data from two matrices, representing a cluster each and calculates
#' the single linkage distance between each cluster. The single linkage
#' distance is the shortest distance between any two points between the
#' two matrices.
#'
#' @details
#' Every row of \code{matrix_1} is compared with every row of
#' \code{matrix_2} using \code{stats::dist}. A progress bar advances once per
#' row of \code{matrix_1}. An error is raised if either input is not a matrix
#' with at least one row.
#'
#' @param matrix_1 Data for cluster 1. Cell by reduced dimension components
#' @param matrix_2 Data for cluster 2 . Cell by reduced dimension components
#' @param distance Point to point distance method to use. e.g. 'euclidean'.
#'
#' @family cluster_distance
#'
#' @return Real positive number representing the distance
single_linkage_distance <- function(matrix_1,
                                    matrix_2,
                                    distance = "euclidean") {
    n_1 <- nrow(matrix_1)
    n_2 <- nrow(matrix_2)
    # An empty cluster used to make 1:nrow() iterate over c(1, 0) and fail
    # with "subscript out of bounds"; report the real problem instead.
    if (!isTRUE(n_1 > 0) || !isTRUE(n_2 > 0)) {
        stop("Both clusters must be matrices with at least one row (cell).")
    }

    # smallest distance
    smallest_distance <- .Machine$double.xmax

    # Progress Bar
    pb <- progress::progress_bar$new(format = ":what [:bar] :current/:total (:percent) eta: :eta",
                           total = n_1)
    pb$tick(0)

    for (i in seq_len(n_1)) {
        cell_1 <- matrix_1[i, ]
        pb$tick(tokens = list(what = "single linkage"))
        for (j in seq_len(n_2)) {
            pair <- rbind(cell_1, matrix_2[j, ])
            pair_distance <- as.matrix(stats::dist(pair, method = distance))[1, 2]
            if (pair_distance < smallest_distance) {
                smallest_distance <- pair_distance
            }
        }
    }
    return(smallest_distance)
}

#' Calculates the complete linkage distance between two clusters
#'
#' @description
#' Uses data from two matrices, representing a cluster each and calculates
#' the complete linkage distance between each cluster.
#'
#' The complete linkage distance is the furthest distance between any two
#' points between the two matrices.
#'
#' @details
#' Every row of \code{matrix_1} is compared with every row of
#' \code{matrix_2} using \code{stats::dist}. A progress bar advances once per
#' row of \code{matrix_1}. An error is raised if either input is not a matrix
#' with at least one row.
#'
#' @inheritParams single_linkage_distance
#'
#' @family cluster_distance
#'
#' @return Real positive number representing the distance
complete_linkage_distance <- function(matrix_1,
                                      matrix_2,
                                      distance = "euclidean") {
    n_1 <- nrow(matrix_1)
    n_2 <- nrow(matrix_2)
    # An empty cluster used to make 1:nrow() iterate over c(1, 0) and fail
    # with "subscript out of bounds"; report the real problem instead.
    if (!isTRUE(n_1 > 0) || !isTRUE(n_2 > 0)) {
        stop("Both clusters must be matrices with at least one row (cell).")
    }

    # largest distance
    largest_distance <- 0

    # Progress Bar
    pb <- progress::progress_bar$new(format = ":what [:bar] :current/:total (:percent) eta: :eta",
                           total = n_1)
    pb$tick(0)

    for (i in seq_len(n_1)) {
        cell_1 <- matrix_1[i, ]
        pb$tick(tokens = list(what = "complete linkage"))
        for (j in seq_len(n_2)) {
            pair <- rbind(cell_1, matrix_2[j, ])
            pair_distance <- as.matrix(stats::dist(pair, method = distance))[1, 2]
            if (pair_distance > largest_distance) {
                largest_distance <- pair_distance
            }
        }
    }
    return(largest_distance)
}

#' Calculates the average linkage distance between two clusters
#'
#' @description
#' For every row of \code{matrix_1} the mean distance to all rows of
#' \code{matrix_2} is calculated with \code{stats::dist}. The result is
#' either the mean of those per-row means or the per-row means themselves.
#'
#' @details
#' A progress bar advances once per row of \code{matrix_1}. The
#' \code{output} argument and the inputs are checked before any distance is
#' calculated.
#'
#' @inheritParams single_linkage_distance
#' @param output Whether to output the mean of all the distances ("mean") or
#' a vector holding, for each row of \code{matrix_1}, its mean distance to
#' the rows of \code{matrix_2} ("raw").
#'
#' @family cluster_distance
#'
#' @return Real positive number representing the distance. Or vector of all
#' distances
average_linkage_distance <- function(matrix_1,
                                     matrix_2,
                                     distance = "euclidean",
                                     output = "mean") {
    # Validate first so a bad argument does not cost a full pairwise pass
    if (!isTRUE(output %in% c("mean", "raw"))) {
        stop("Please choose output of 'mean' or 'raw'")
    }
    n_1 <- nrow(matrix_1)
    n_2 <- nrow(matrix_2)
    if (!isTRUE(n_1 > 0) || !isTRUE(n_2 > 0)) {
        stop("Both clusters must be matrices with at least one row (cell).")
    }

    # progress bar
    pb <- progress::progress_bar$new(format = ":what [:bar] :current/:total (:percent) eta: :eta",
                           total = n_1)
    pb$tick(0)

    # Mean distance from each cell of cluster 1 to cluster 2, preallocated
    # rather than grown with c() on every iteration
    cluster_distance <- numeric(n_1)

    for (i in seq_len(n_1)) {
        cell_1 <- matrix_1[i, ]

        # Progress bar
        pb$tick(tokens = list(what = "average linkage"))

        # Distance from this cell to every cell in cluster 2
        cell_distances <- vapply(seq_len(n_2), function(j) {
            pair <- rbind(cell_1, matrix_2[j, ])
            as.matrix(stats::dist(pair, method = distance))[1, 2]
        }, numeric(1))

        cluster_distance[i] <- mean(cell_distances)
    }

    if (output == "mean") {
        return(mean(cluster_distance))
    }
    cluster_distance
}

#' Calculates the centroid linkage distance between two clusters
#'
#' @description
#' Takes two matrices, each holding the cells of one cluster, and returns
#' the centroid linkage distance between them.
#'
#' The centroid of a cluster is the vector of its column means; the centroid
#' linkage distance is the point to point distance between the two centroids.
#'
#' @inheritParams single_linkage_distance
#'
#' @family cluster_distance
#'
#' @return Real positive number representing the distance.
centroid_linkage_distance <- function(matrix_1,
                                      matrix_2,
                                      distance = "euclidean") {
    # Stack the two centroids as the rows of a 2-row matrix
    centroids <- rbind(colMeans(matrix_1), colMeans(matrix_2))

    # The only off-diagonal entry is the distance between the centroids
    centroid_gap <- stats::dist(centroids, method = distance)
    as.matrix(centroid_gap)[1, 2]
}

#' Calculates the ward distance between two clusters
#' @description
#' Uses data from two matrices, representing a cluster each and calculates
#' the ward distance between each cluster.
#'
#' The ward distance is the increase in variance that results in merging
#' two clusters. The greater the variance cost of merging clusters, the
#' greater the difference between the two clusters.
#'
#' @details
#' Computed as \code{n_1 * n_2 / (n_1 + n_2) * d^2}, where \code{n_1} and
#' \code{n_2} are the numbers of rows and \code{d} is the distance between
#' the column means of the two matrices.
#'
#' @inheritParams single_linkage_distance
#'
#' @return Real positive number representing the distance.
#'
#' @family cluster_distance
ward_distance <- function(matrix_1,
                          matrix_2,
                          distance = "euclidean") {
    # Number of cells per group. nrow() returns integers, and the product of
    # two large cluster sizes overflows integer arithmetic to NA, so the
    # counts are converted to doubles first.
    n_1 <- as.numeric(nrow(matrix_1))
    n_2 <- as.numeric(nrow(matrix_2))

    # Cluster means
    mean_1 <- colMeans(matrix_1)
    mean_2 <- colMeans(matrix_2)

    # Squared distance between the cluster means
    pair <- rbind(mean_1, mean_2)
    pair_distance <- as.matrix(stats::dist(pair, method = distance))[1, 2]
    mean_variance <- pair_distance^2

    # Ward formula
    (n_1 * n_2) / (n_1 + n_2) * mean_variance
}

#' Calculates the mahalanobis distance between two clusters
#'
#' @description
#' Takes two matrices, each holding the cells of one cluster, and returns a
#' mahalanobis style distance between them.
#'
#' The difference between the column means of the two clusters is weighted
#' by the inverse of the sum of the two cluster covariance matrices, giving
#' the quadratic form \code{t(d) \%*\% S \%*\% d}.
#'
#' @details
#' The \code{distance} argument is accepted so the signature matches the
#' other cluster distances, but it is not used in the calculation.
#'
#' @inheritParams single_linkage_distance
#'
#' @return Real positive number representing the distance.
#'
#' @family cluster_distance
mahalanobis_distance <- function(matrix_1,
                                 matrix_2,
                                 distance = "euclidean") {
    # Inverse of the summed covariance matrices of the two clusters
    summed_covariance <- stats::cov(matrix_1) + stats::cov(matrix_2)
    weights <- solve(summed_covariance)

    # Gap between the two cluster centroids
    centroid_gap <- colMeans(matrix_1) - colMeans(matrix_2)

    # Quadratic form; the product is a 1 x 1 matrix, so pull out the scalar
    quadratic_form <- t(centroid_gap) %*% weights %*% centroid_gap
    quadratic_form[1, 1]
}

#' Internal distance points in a cluster and the cluster centroid.
#'
#' @description
#' Calculates the distance between every point in a cluster and the mean/centroid
#' of that cluster. Can either output an entire list of all the distances to
#' the centroid or a mean of the distances. Used to measure how dispersed
#' all the points in a cluster are.
#'
#' @inheritParams single_linkage_distance
#' @param output Can be either "mean" or "raw". mean will return the mean of all
#' the distances from the centroid. Raw will return a vector of all the distances
#' from every point to the centroid.
#'
#' @family internal_distance
#'
#' @export
#'
#' @return Mean or a vector of all the distances to the centroid.
internal_linkage_distance <- function(matrix_1,
                                      distance = "euclidean",
                                      output = "mean") {
  # Reject an unsupported output mode before any distance is computed, so a
  # typo does not cost a full pass over the cluster.
  if (length(output) != 1 || !(output %in% c("mean", "raw"))) {
    stop("Output parameter invalid. Please use either 'mean' or 'raw'")
  }

  # Calculate centroid
  centroid <- colMeans(matrix_1)

  # Progress bar with one tick per cell (row of matrix_1)
  cell_count <- nrow(matrix_1)
  pb <- progress::progress_bar$new(format = ":what [:bar] :current/:total (:percent) eta: :eta",
                                   total = cell_count)
  pb$tick(0)

  # Distance from every cell to the centroid. vapply fills a result of known
  # length instead of growing a vector, and seq_len() avoids iterating over
  # 1:0 when there are no rows.
  cell_distances <- vapply(seq_len(cell_count), function(i) {
    pb$tick(tokens = list(what = "internal linkage"))

    # stats::dist on the two-row matrix (cell, centroid) gives their distance
    pair <- rbind(matrix_1[i, ], centroid)
    as.matrix(stats::dist(pair, method = distance))[1, 2]
  }, numeric(1))

  # Outputs
  if (output == "mean") {
    return(mean(cell_distances))
  }
  cell_distances
}

#' Mahalanobis distance of points in a cluster and the cluster centroid.
#'
#' @description
#' Calculates the distance between every point in a cluster and the mean/centroid
#' of that cluster. The mahalanobis distance can be interpreted as how many
#' standard deviations away is the point from the mean. The value also takes
#' into account dimensions where there is not much variability, resulting in a
#' larger distance value if points vary in the aforementioned dimensions.
#'
#' @details
#' Can either output an entire list of all the distances to
#' the centroid or a mean of the distances. Used to measure how dispersed
#' all the points in a cluster are.
#'
#' @inheritParams single_linkage_distance
#' @param output Can be either "mean" or "raw". mean will return the mean of all
#' the distances from the centroid. Raw will return a vector of all the distances
#' from every point to the centroid.
#'
#' @family internal_distance
#'
#' @import matlib
#'
#' @export
#' @return Mahalanobis distance
internal_mahalanobis_distance <- function(matrix_1,
                                      distance = "euclidean",
                                      output = "mean") {
  # `distance` is accepted so the signature matches the other internal
  # distance functions; it is not used by this calculation.

  # Reject an unsupported output mode before any work is done (including the
  # matrix inversion below).
  if (length(output) != 1 || !(output %in% c("mean", "raw"))) {
    stop("Output parameter invalid. Please use either 'mean' or 'raw'")
  }

  # Calculate centroid and the inverse covariance matrix of the cluster
  centroid <- colMeans(matrix_1)
  inverse_covariance <- matlib::inv(stats::cov(matrix_1))

  # Progress bar with one tick per cell (row of matrix_1)
  cell_count <- nrow(matrix_1)
  pb <- progress::progress_bar$new(format = ":what [:bar] :current/:total (:percent) eta: :eta",
                                   total = cell_count)
  pb$tick(0)

  # Quadratic form (x - mean)' S (x - mean) for every cell. vapply fills a
  # result of known length instead of growing a vector, and seq_len() avoids
  # iterating over 1:0 when there are no rows.
  cell_distances <- vapply(seq_len(cell_count), function(i) {
    pb$tick(tokens = list(what = "internal linkage"))

    # Column vector holding the offset of this cell from the centroid
    offset <- as.matrix(matrix_1[i, ] - centroid)
    (t(offset) %*% inverse_covariance %*% offset)[1, 1]
  }, numeric(1))

  # Outputs
  if (output == "mean") {
    return(mean(cell_distances))
  }
  cell_distances
}

###################################################
# Seurat Distance Visualisation Methods
###################################################

#' Plots the groups as well as their cluster means
#'
#' @param seurat_object Object or a list of Seurat objects containing data
#' @param group Seurat cell wise categories to calculate the pairwise distances between
#' @param reduction Dimensionality reduction data to use
#' @param split_by Categorical variable to split the data between
#' @param scales Force X and Y axis to a certain range
#'
#' @return ggplot dot plot of cells in reduced dimension space.
#' @export
#'
#' @family distance_plot
PlotGroupMeans <- function(seurat_object,
                           group,
                           reduction,
                           split_by = NULL,
                           scales = NULL) {
  # `split_by` is part of the signature but is not used in this function.

  # First two components of the chosen reduction, one row per cell
  embeddings <- seurat_object@reductions[[reduction]]@cell.embeddings
  cell.data <- as.data.frame(embeddings[,1:2])
  colnames(cell.data) <- c("comp_1", "comp_2")

  # Per-group means over the same two components
  means <- group_mean(seurat_object, group, reduction, dims = 1:2)
  colnames(means) <- c("comp_1", "comp_2")

  # The names `cell.data` and `means` are referenced inside aes(), so they
  # determine the default axis labels and must stay as they are.

  # One point per cell, coloured by the grouping variable
  cell_layer <- ggplot2::geom_point(
    alpha = 0.7,
    ggplot2::aes(color = as.matrix(seurat_object[[group]]))
  )

  # One large outlined point per group mean, filled by group name
  mean_layer <- ggplot2::geom_point(
    data = means, size = 5, color = "black", pch=21,
    ggplot2::aes(x = means$comp_1, y = means$comp_2, fill = rownames(means))
  )

  p <- ggplot2::ggplot(cell.data, ggplot2::aes(x = cell.data$comp_1,
                                               y = cell.data$comp_2)) +
    ggplot2::theme_bw() +
    cell_layer +
    mean_layer +
    ggplot2::labs(color = group, fill = "Group Means")

  # Without explicit limits the default axis ranges are kept
  if (is.null(scales)) {
    return(p)
  }

  # Apply the same limits to both axes
  p +
    ggplot2::scale_x_continuous(limits=scales) +
    ggplot2::scale_y_continuous(limits=scales)
}

#' Plots the distance between two groups using a number of different methods.
#'
#' @inheritParams PlotDistance
#'
#' @return ggplot of multiple methods
plot_distance <- function(seurat_object,
                          group,
                          group_start,
                          group_destination,
                          method = "all") {
  # Stored distance table for this group (no split)
  data <- GetDistance(seurat_object, group = group)
  if (is.null(data)) {
    stop("This data does not exist. Run GroupDistance() with the same parameters
         first before trying to plot them.")
  }

  # Keep the rows for the requested pair, in either direction
  forward <- data$start == group_start & data$destination == group_destination
  backward <- data$start == group_destination & data$destination == group_start
  data <- data[forward | backward,]
  if (nrow(data) == 0) {
    stop("Either the start or destination group were not found")
  }

  # Choose the specified methods (the first two columns are kept as-is)
  if (!("all" %in% method)){
    data <- cbind(data[,1:2], data[,method])
  }

  # A single selected column loses its name in cbind(), so restore it.
  # `&&` short-circuits: with several methods the comparison against "all"
  # is a vector, and using it as an `if` condition is an error in R >= 4.2.
  if (length(method) == 1 && method != "all") {
    colnames(data)[3] <- method
  }

  # Make the table long format for the methods
  data <- suppressWarnings(tidyr::gather(data,
                                         key = "method",
                                         value = "distance",
                                         3:ncol(data)))

  # One bar per distance method
  p <- ggplot2::ggplot(data = data, ggplot2::aes(x = data$method,
                                                 y = data$distance)) +
    ggplot2::theme_bw() +
    ggplot2::geom_bar(stat = "identity", ggplot2::aes(fill = method)) +
    ggplot2::labs(x = "method", y = "distance")
  return(p)
}

#' Plots the distances next to each other
#'
#' @description
#' Create a bar plot of how the distances between two factors in categorical
#' variable differ across another categorical variable. Also possible to
#' select multiple distance measures and plot them simultaneously. Additionally
#' each method can be z score normalised/rescaled to have a mean of 1, allowing
#' comparison across methods.
#'
#' @inheritParams GroupDistance
#' @param group_start Factor to plot the distance from
#' @param group_destination Factor to plot the distance to
#' @param rescale Boolean. z score normalise across distance methods.
#'
#' @return ggplot bar chart of distances
#' @export
#'
#' @family distance_plot
PlotDistance <- function(seurat_object,
                              group,
                              group_start,
                              group_destination,
                              split_by = NULL,
                              method = "all",
                              rescale = TRUE) {
  # Without a split there is nothing to compare across; use the plain plot
  if (is.null(split_by)) {
    p <- plot_distance(seurat_object,
                  group,
                  group_start,
                  group_destination,
                  method = method)
    return(p)
  }

  # Extract the data
  data <- GetDistance(seurat_object, group = group, split_by = split_by)
  if (is.null(data)) {
    stop("Either the values for group or split_by do not exist. Make sure you
         have already run GroupDistance() with the same options")
  }

  # Make split.by a factor that keeps the order of first appearance.
  # unique() is required: factor() rejects duplicated levels, and a split
  # value repeats whenever the table holds more than one row per split.
  data$split.by <- factor(data$split.by, levels = unique(data$split.by))

  # Keep the rows for the requested pair, in either direction
  forward <- data$start == group_start & data$destination == group_destination
  backward <- data$start == group_destination & data$destination == group_start
  data <- data[forward | backward,]
  if (nrow(data) == 0) {
    stop("Either the start or destination group were not found")
  }

  # Choose the specified methods (the first three columns are kept as-is)
  if (!("all" %in% method)){
    data <- cbind(data[,1:3], data[,method])
  }

  # A single selected column loses its name in cbind(), so restore it.
  # `&&` short-circuits: with several methods the comparison against "all"
  # is a vector, and using it as an `if` condition is an error in R >= 4.2.
  if (length(method) == 1 && method != "all") {
    colnames(data)[4] <- method
  }

  # Scale the different methods: z score per method column, shifted by 1 so
  # that every method has a mean of 1
  if (rescale == TRUE) {
    for (i in 4:ncol(data)) {
      data[,i] <- scale(as.numeric(as.character(data[,i]))) + 1
    }
  }

  # Make the table long format for the methods
  data <- suppressWarnings(tidyr::gather(data,
                                         key = "method",
                                         value = "distance",
                                         4:ncol(data)))

  # Height of the reference line: 1 by construction when rescaled, otherwise
  # the mean of the plotted distances. Kept in its own variable so the base
  # function mean() is not shadowed.
  if (rescale == TRUE) {
    mean_line <- 1
  } else {
    mean_line <- mean(as.numeric(data$distance))
  }

  # Plotting: dodged bars per split, one colour per method, plus mean line
  p <- ggplot2::ggplot(data = data, ggplot2::aes(x = data$split.by,
                                                 y = data$distance)) +
    ggplot2::theme_bw() +
    ggplot2::geom_bar(stat = "identity", ggplot2::aes(fill = method),
                      position = "dodge") +
    ggplot2::labs(x = split_by) +
    ggplot2::geom_hline(ggplot2::aes(yintercept= mean_line, linetype = "mean"),
                        colour= 'blue') +
    ggplot2::scale_linetype_manual(name = "Statistics", values = c(2, 2),
                          guide = ggplot2::guide_legend(override.aes = list(color = c("blue"))))
  return(p)
}


#' Calculates the average distance between groups for a list of seurat objects
#'
#' @description
#' The average distance is the average of the distance from every individual
#' cell in group 1 to every individual cell in group 2.
#'
#' @inheritParams GroupDistance
#' @param output Output type. Default is "seurat". Options are
#' "seurat" (Seurat S4 object), "list" (List of tables by split_by),
#' "seurat_list" (List of Seurat objects). Seurat objects returned have the
#' distance data stored internally.
#' @param parallel It is possible to run this process in parallel across
#' multiple computers using BiocParallel. TRUE will use and FALSE will not.
#' If this is set to TRUE then it will not be able to print out progress.
#'
#' @return Average distance data. Fundamentally in the form of a vector of
#' distances
#' @export
#' @import BiocParallel
AverageDistance <- function(seurat_object,
                                 group,
                                 reduction = "pca",
                                 dims = 1:30,
                                 distance = "euclidian",
                                 split_by = NULL,
                                 output = "seurat",
                                parallel = FALSE) {
  # NOTE(review): the default spelling "euclidian" is forwarded unchanged;
  # stats::dist() names this method "euclidean" - confirm that the linkage
  # helpers accept the spelling used here.

  # A single Seurat object is an S4 object; anything else is treated as an
  # already split list of objects.
  is_single_object <- typeof(seurat_object) == "S4"

  # No split requested: compute directly on the whole object
  if (is_single_object && is.null(split_by)) {
    output.object <- average_distance(seurat_object = seurat_object,
                      group = group,
                      reduction = reduction,
                      dims= dims,
                      distance = distance,
                      output = output)
    return(output.object)
  }

  # Validate the output mode before any distance is computed. The documented
  # spelling "seurat_list" is accepted as an alias of "seurat list".
  valid_outputs <- c("seurat", "list", "seurat list", "seurat_list")
  if (length(output) != 1 || !(output %in% valid_outputs)) {
    stop("Choose output = 'list', 'seurat list' or 'seurat'")
  }
  if (output == "seurat_list") {
    output <- "seurat list"
  }
  if (output == "seurat" && !is_single_object) {
    stop("To output to seurat object, the input must be a seurat object")
  }

  if (is_single_object) {
    seurat_list <- Seurat::SplitObject(seurat_object, split.by = split_by)

    # If the split labels are numeric, ensure numeric order. A coercion
    # warning means they are not numeric; the original order is then kept.
    seurat_list <- tryCatch({
      numeric_names <- as.numeric(names(seurat_list))
      seurat_list[order(numeric_names)]
    }, warning = function(c){
      message("Was not able to order the split_by")
      seurat_list
    })
  } else {
    seurat_list <- seurat_object
  }

  # Per-split calls return Seurat objects only for the "seurat list" mode;
  # every other mode collects plain lists.
  if (output == "seurat list") {
    output.type <- "seurat"
  } else {
    output.type <- "list"
  }

  message(paste("Calculating average distance for every factor in", split_by))

  if (parallel == TRUE) {
    # Run it with BiocParallel
    message("Computing distances in parallel")
    result <- BiocParallel::bplapply(seurat_list, average_distance,
                                     group = group,
                                     reduction = reduction,
                                     dims = dims,
                                     distance = distance,
                                     output = output.type)
  } else {
    # Anything other than TRUE/FALSE falls back to sequential processing
    if (parallel != FALSE) {
      warning("Please use either TRUE or FALSE for parallel")
    }
    result <- vector("list", length(seurat_list))
    for (i in seq_along(seurat_list)) {
      message(paste("Calculating for split:", names(seurat_list[i])))
      result[[i]] <- average_distance(seurat_list[[i]],
                                      group = group,
                                      reduction = reduction,
                                      dims= dims,
                                      distance = distance,
                                      output = output.type)
    }
  }

  # Format output types ("seurat list" is returned as collected)
  if (output == "list") {
    names(result) <- names(seurat_list)
  } else if (output == "seurat") {
    # Store the per-split results inside the original object
    names(result) <- names(seurat_list)
    location <- paste("average_distance", group, split_by, sep = "_")
    seurat_object@tools[[location]] <- result
    result <- seurat_object
  }
  return(result)
}


#' Plots the average distance between two factors of a group
#'
#' @description
#' Can choose to identify the groups, split and method. The distance between
#' group 1 and group 2
#'
#' @details
#' Need to run AverageDistance() beforehand.
#'
#' Error bars can be: "sd" or "se"
#' Mode can be: "density" or "bar"
#'
#' @inheritParams PlotDistance
#' @param error_bars Method for plot error bars. Error bars can be NULL
#' (no error bars), "se" (standard error), "sd" (standard deviation).
#' @param mode Whether to plot a bar or density plot. Options are "bar"
#' (bar plot), "density" (density plot)
#' @param rotate_x The rotation of the x axis labels. Default is 90
#' (perpendicular) to the x axis. This is only relevant for bar plots.
#'
#' @return ggplot of distances
#' @export
#'
#' @family distance_plot
PlotAverageDistance <- function(seurat_object,
                                group,
                                group_start,
                                group_destination,
                                split_by = NULL,
                                error_bars = "sd",
                                mode = "density",
                                rotate_x = 90) {
  # An unknown mode used to fall through every branch and fail with
  # "object 'p' not found"; report it directly instead.
  if (length(mode) != 1 || !(mode %in% c("density", "bar"))) {
    stop("mode must be either 'density' or 'bar'")
  }
  if (mode == "bar" && is.null(split_by)) {
    stop("If mode = bar then split_by can not be NULL. There needs to be an
         x axis category. ")
  }

  # Error bars are only drawn for the two supported kinds; NULL or any other
  # value means bars without error bars.
  show_errors <- !is.null(error_bars) && length(error_bars) == 1 &&
    error_bars %in% c("sd", "se")

  # The pair may be stored under either ordering of the two names; keep
  # whichever entry is longer (a missing entry has length 0).
  forward_key <- paste(group_start, group_destination, sep = "_")
  backward_key <- paste(group_destination, group_start, sep = "_")
  select_pair <- function(distance_list) {
    data.1 <- distance_list[[forward_key]]
    data.2 <- distance_list[[backward_key]]
    if (length(data.1) > length(data.2)) data.1 else data.2
  }

  if (is.null(split_by)) {
    # Isolate the data for no split
    data <-  GetDistance(seurat_object, group = group, method = "average")
    # The variable must be called `distance`: as.data.frame() names the
    # column of a plain vector after the expression it was given.
    distance <- select_pair(data)
    data <- as.data.frame(distance)
  } else {
    # Isolate the data for a split
    average.data <- GetDistance(seurat_object, group, split_by, method = "average")

    # One data frame per split, tagged with the split name, bound once at
    # the end instead of growing a data frame inside a loop
    pieces <- lapply(seq_along(average.data), function(i) {
      distance <- select_pair(average.data[[i]])
      data.split <- as.data.frame(distance)
      data.split$split <- names(average.data)[i]
      data.split
    })
    data <- do.call(rbind, pieces)
  }

  # Density plot
  if (mode == "density") {
    if (is.null(split_by)) {
      p <- ggplot2::ggplot() +
        ggplot2::theme_bw() +
        ggplot2::geom_density(data = data, mapping = ggplot2::aes(x = distance))
    } else {
      p <- ggplot2::ggplot() +
        ggplot2::theme_bw() +
        ggplot2::geom_density(data = data, mapping = ggplot2::aes(x = distance,
                                                color = split,
                                                fill = split),
                     alpha = 0.2) +
        ggplot2::labs(fill = split_by, color = split_by)
    }
    return(p)
  }

  # Bar plot (split_by is known to be non-NULL here)

  # Order the splits as they are stored. This has to happen before the
  # summary is computed, otherwise the bars are ordered alphabetically.
  data$split <- factor(data$split, levels = names(average.data))

  # Calculate the mean (and error, if requested) of each split
  if (!show_errors) {
    mu <- plyr::ddply(data, "split", plyr::summarise,
                      distance=mean(distance))
  } else if (error_bars == "sd") {
    mu <- plyr::ddply(data, "split", plyr::summarise,
                error = stats::sd(distance), distance=mean(distance))
  } else {
    # Standard error of the mean
    mu <- plyr::ddply(data, "split", plyr::summarise,
                error = stats::sd(distance)/sqrt(length(distance)), distance=mean(distance))
  }

  p <- ggplot2::ggplot(data = mu, ggplot2::aes(x = split, y = distance)) +
    ggplot2::theme_bw() +
    ggplot2::geom_bar(stat = "identity", ggplot2::aes(fill = split)) +
    ggplot2::labs(fill = split_by, x = split_by)
  if (show_errors) {
    p <- p +
      ggplot2::geom_errorbar(data = mu, ggplot2::aes(ymin=mu$distance-mu$error,
                                            ymax=mu$distance+mu$error),
                             width= 0.2)
  }
  p <- p + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = rotate_x, hjust = 1))
  return(p)
}

#' Plots the internal distances
#'
#' @description
#' Creates a bar plot of internal distances.
#'
#' @inheritParams PlotDistance
#'
#' @return Bar plot of distances
#' @export
PlotInternalDistance <- function(seurat_object,
                                 group = NULL,
                                 split_by = NULL,
                                 method = "linkage",
                                 error_bars = "sd",
                                 rotate_x = 90) {
  # Extract distance data
  distance.data <- GetDistance(seurat_object, group = group, split_by = split_by,
                               method = "internal")

  # Subset based on internal distance method
  distance.data <- distance.data[distance.data$method == method,]

  has_group <- !is.null(group)
  has_split <- !is.null(split_by)

  # Pick the x axis and fill according to which categories were supplied.
  # The bare names inside aes() (group, split.by, mean, method) are left
  # exactly as they were.
  if (has_group && !has_split) {
    p <- ggplot2::ggplot(distance.data, ggplot2::aes(x = group,
                                                     y = mean,
                                                     fill = group)) +
      ggplot2::geom_bar(stat = "identity", position = "dodge") +
      ggplot2::labs(x = group, fill = group)
  } else if (!has_group && has_split) {
    # The x axis shows the split here, so it is labelled with split_by
    # (`group` is NULL in this branch).
    p <- ggplot2::ggplot(distance.data, ggplot2::aes(x = split.by,
                                                     y = mean,
                                                     fill = split.by)) +
      ggplot2::geom_bar(stat = "identity", position = "dodge") +
      ggplot2::labs(x = split_by, fill = split_by)
  } else if (!has_group && !has_split) {
    p <- ggplot2::ggplot(distance.data, ggplot2::aes(x = method, y = mean)) +
      ggplot2::geom_bar(stat = "identity", position = "dodge")
  } else {
    p <- ggplot2::ggplot(distance.data, ggplot2::aes(x = split.by,
                                                     y = mean,
                                                     fill = group)) +
      ggplot2::geom_bar(stat = "identity", position = "dodge") +
      ggplot2::labs(x = split_by, fill = group)
  }

  # Error bars: "sd" or "se"; NULL or any other value draws none
  if (!is.null(error_bars) && length(error_bars) == 1) {
    if (error_bars == "sd") {
      p <- p + ggplot2::geom_errorbar(width = 0.5,
                                      position = ggplot2::position_dodge(width = 0.9),
                                      ggplot2::aes(ymin = mean - sd, ymax = mean + sd))
    } else if (error_bars == "se") {
      p <- p + ggplot2::geom_errorbar(width = 0.5,
                                      position = ggplot2::position_dodge(width = 0.9),
                                      ggplot2::aes(ymin = mean - se, ymax = mean + se))
    }
  }

  # Rotate axis labels
  p <- p + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = rotate_x, hjust = 1))
  return(p)
}


#' Calculates the total distance between groups for a list of seurat objects
#'
#' @description
#' The total distance is the distance between every individual cell in each
#' group. This includes every cell within its own group.
#'
#' @details
#' The total distance calculates the distance to every other cell in
#' group 1 to every other cell in group 1 as well as from every cell in
#' group 1 to every other cell in group 2. Subsequently this creates a
#' distribution of distances for both group 1 and group 2. This means that
#' while this method cannot result in a single number distance result it is
#' useful for plotting the distances between the two groups.
#'
#' It is possible to change the
#' method to only measure the linkage for one cell instead of all the cells.
#' This could be chosen in order to speed up the calculation process. If
#' "one" is chosen then only one cell in group 1 will be used. This cell is
#' chosen because it is furthest away from the mean of group 2.
#'
#' @inheritParams AverageDistance
#' @param method The number of cells in group 1 to test. This is done because
#' calculating the total distance is often very computationally expensive.
#' And by reducing the number of cells in group one can give an approximation
#' of the final result without having to test all cells. Options are "one" (
#' calculate for one cell in group 1), "all" (calculate for all cells)
#'
#' @return List of data frames containing distances between groups per row
#' @export
#'
#' @family TotalDistance
TotalDistance <- function(seurat_object,
                               group,
                               reduction = "pca",
                               dims = 1:30,
                               distance = "euclidian",
                               split_by = NULL,
                               output = "seurat",
                               method = "all",
                                parallel = FALSE) {
  # A single Seurat object is an S4 object; anything else is treated as an
  # already split list of objects.
  is_single_object <- typeof(seurat_object) == "S4"

  # No split requested: compute directly on the whole object
  if (is_single_object && is.null(split_by)) {
    output.object <- total_distance(seurat_object,
                            group = group,
                            reduction = reduction,
                            dims= dims,
                            distance = distance,
                            output = output,
                            method = method)
    return(output.object)
  }

  # Validate the output mode before any distance is computed. The spelling
  # "seurat_list" is accepted as an alias of "seurat list".
  valid_outputs <- c("seurat", "list", "seurat list", "seurat_list")
  if (length(output) != 1 || !(output %in% valid_outputs)) {
    stop("Choose output = 'list', 'seurat list' or 'seurat'")
  }
  if (output == "seurat_list") {
    output <- "seurat list"
  }
  if (output == "seurat" && !is_single_object) {
    stop("To output to seurat object, the input must be a seurat object")
  }

  if (is_single_object) {
    seurat_list <- Seurat::SplitObject(seurat_object, split.by = split_by)

    # If the split labels are numeric, ensure numeric order. A coercion
    # warning means they are not numeric; the original order is then kept.
    seurat_list <- tryCatch({
      numeric_names <- as.numeric(names(seurat_list))
      seurat_list[order(numeric_names)]
    }, warning = function(c){
      message("Was not able to order the split_by")
      seurat_list
    })
  } else {
    seurat_list <- seurat_object
  }

  # Per-split calls return Seurat objects only for the "seurat list" mode;
  # every other mode collects plain lists.
  if (output == "seurat list") {
    output.type <- "seurat"
  } else {
    output.type <- "list"
  }

  if (parallel == TRUE) {
    # The parallel path must run the same computation as the sequential
    # one: total_distance with the requested `method`.
    result <- BiocParallel::bplapply(seurat_list, total_distance,
                                     group = group,
                                     reduction = reduction,
                                     dims = dims,
                                     distance = distance,
                                     output = output.type,
                                     method = method)
  } else if (parallel == FALSE) {
    result <- vector("list", length(seurat_list))
    for (i in seq_along(seurat_list)) {
      message(paste("Calculating for split:", names(seurat_list)[i]))
      result[[i]] <- total_distance(seurat_list[[i]],
                                    group = group,
                                    reduction = reduction,
                                    dims= dims,
                                    distance = distance,
                                    output = output.type,
                                    method = method)
    }
  } else {
    stop("Please select either TRUE or FALSE for parallel")
  }

  # Format output types ("seurat list" is returned as collected)
  if (output == "list") {
    names(result) <- names(seurat_list)
  } else if (output == "seurat") {
    # Store the per-split results inside the original object
    names(result) <- names(seurat_list)
    location <- paste("total_distance", group, split_by, sep = "_")
    seurat_object@tools[[location]] <- result
    result <- seurat_object
  }
  return(result)
}

#' Plots the total distance between two factors of a group as a density plot
#'
#' @description
#' Plots the total distance data on a density plot. This should visualise the
#' distance between the two groups as well as their distributions.
#'
#' @details
#' Need to run TotalDistance beforehand.
#'
#' @inheritParams PlotAverageDistance
#'
#' @return ggplot density plot
#' @export
#'
#' @family distance_plot
#' @family total_distance
PlotTotalDistance <- function(seurat_object,
                              group,
                              group_start,
                              group_destination,
                              split_by = NULL) {
  # Isolate data
  data <- GetDistance(seurat_object, group = group,
              split_by = split_by, method = "total")

  # The pair may be stored under either ordering of the two names; keep
  # whichever entry is longer (a missing entry has length 0).
  forward_key <- paste(group_start, group_destination, sep = "_")
  backward_key <- paste(group_destination, group_start, sep = "_")
  select_pair <- function(distance_list) {
    data.1 <- distance_list[[forward_key]]
    data.2 <- distance_list[[backward_key]]
    if (length(data.1) > length(data.2)) data.1 else data.2
  }

  # Density plot for one pair. The same as.data.frame() coercion is applied
  # with and without a split, so geom_density() always receives a data
  # frame. The parameter must be called `distance`: as.data.frame() names
  # the column of a plain vector after the expression it was given.
  density_plot <- function(distance) {
    pair_data <- as.data.frame(distance)
    ggplot2::ggplot() +
      ggplot2::theme_bw() +
      ggplot2::geom_density(data = pair_data,
                            mapping = ggplot2::aes(x = distance,
                                                   color = group,
                                                   fill = group),
                            alpha = 0.2) +
      ggplot2::labs(fill = group, color = group)
  }

  # No split: a single plot
  if (is.null(split_by)) {
    return(density_plot(select_pair(data)))
  }

  # Split: one plot per split, named like the stored data
  result <- lapply(seq_along(data), function(i) {
    density_plot(select_pair(data[[i]]))
  })
  names(result) <- names(data)
  return(result)
}


# Distance Visualisation Helper Methods  ------------------------------

#' Calculates the means of each factor in a categorical variable
#'
#' @description
#' In reduced dimension space, calculate the means of each group in a seurat
#' object
#'
#' @inheritParams GroupDistance
#'
#' @return Data frame containing means of groups per row
group_mean <- function(seurat_object,
                             group,
                             reduction = "pca",
                             dims = 1:30) {
  # Isolate the embeddings and the group label of every cell
  embeddings <- seurat_object@reductions[[reduction]]@cell.embeddings
  labels <- as.vector(as.matrix(seurat_object[[group]]))
  group.factors <- unique(labels)

  # Correct dimensions if too large
  if (length(dims) > ncol(embeddings)) {
    dims <- seq_len(ncol(embeddings))
  }
  # drop = FALSE keeps a matrix even when a single dimension is selected
  embeddings <- embeddings[, dims, drop = FALSE]

  # Mean of every selected dimension per group, in order of first appearance.
  # drop = FALSE keeps a matrix for groups that contain a single cell.
  centroids <- lapply(group.factors, function(level) {
    colMeans(embeddings[labels == level, , drop = FALSE])
  })
  output <- as.data.frame(do.call(rbind, centroids))

  # Set names. The embeddings were already subset by `dims`, so their column
  # names are used directly; indexing them with `dims` a second time gives
  # NA or wrong names whenever `dims` is not 1:k.
  if (!is.null(colnames(embeddings))) {
    colnames(output) <- colnames(embeddings)
  }
  rownames(output) <- group.factors

  return(output)
}

#' Calculates the standard error of a vector
#'
#' @param x Vector of numbers
#'
#' @return Standard error of the input vector
standard_error <- function(x) {
  # Sample standard deviation divided by the square root of the sample size
  sample_size <- length(x)
  spread <- stats::sd(x)
  spread / sqrt(sample_size)
}

#' Calculates the average distance between groups in seurat object
#'
#' @description
#' Calculates the distances between the points of every pair of groups, using
#' the average linkage distance in its raw (unaveraged) form.
#'
#' @inheritParams AverageDistance
#' @return Distances between groups
average_distance <- function(seurat_object,
                             group,
                             reduction = "pca",
                             dims = 1:30,
                             distance = "euclidian",
                             output = "seurat") {
  # Reject an unsupported output mode before any distance is computed
  if (length(output) != 1 || !(output %in% c("list", "seurat"))) {
    stop("Please select a valid output mode: 'list' or 'seurat'")
  }

  # Ensure group is a factor
  seurat_object[[group]] <- as.factor(as.matrix(seurat_object[[group]]))

  # Correcting number of dimensions if too many were chosen
  data <- seurat_object@reductions[[reduction]]@cell.embeddings
  if (length(dims) > ncol(data)) {
    dims <- 1:ncol(data)
  }
  data <- data[,dims]

  # Group label of every cell and the distinct groups, computed once
  cell_labels <- seurat_object[[group]][,1]
  group_names <- unique(as.character(cell_labels))

  # Output list: one entry per unordered pair of groups, named "<g1>_<g2>"
  result <- list()

  for (i in seq_along(group_names)) {
    # Isolate rows belonging to group 1
    group_1_name <- group_names[i]
    group_1 <- data[cell_labels == group_1_name,]

    for (j in seq_along(group_names)) {
      # Every pair is computed once, so skip j <= i
      if (j <= i) {
        next
      }
      # Isolate rows belonging to group 2. `data` is already restricted to
      # `dims`; subsetting the columns a second time fails or reorders them
      # whenever `dims` is not 1:k, so both groups are taken the same way.
      group_2_name <- group_names[j]
      group_2 <- data[cell_labels == group_2_name,]

      distance_name <- paste(group_1_name, group_2_name, sep = "_")

      # Raw average linkage distances between the two groups
      distance.data <- average_linkage_distance(group_1,
                                                group_2,
                                                distance,
                                                output = "raw")
      result[[distance_name]] <- distance.data
    }
  }

  if (output == "list") {
    return(result)
  }

  # output == "seurat": store the distances inside the object
  location <- paste("average_distance", group, sep = "_")
  seurat_object@tools[[location]] <- result
  return(seurat_object)
}

#' Calculates the total distance between groups in seurat object
#'
#' @description
#' Calculates the distance between every point in both clusters, including
#' points within their own cluster.
#'
#' @details
#' Every unordered pair of groups is passed to `total_linkage()`. The result
#' for a pair is stored under the name `"<group 1>_<group 2>"`. With
#' `output = "seurat"` the list is written to the `tools` slot of the object
#' under `"total_distance_<group>"`.
#'
#' @inheritParams TotalDistance
#'
#' @return List containing distances between groups per row when
#' `output = "list"`, otherwise the seurat object with that list stored in its
#' `tools` slot.
#'
#' @family total_distance
total_distance <- function(seurat_object,
                           group,
                           reduction = "pca",
                           dims = 1:30,
                           distance = "euclidian",
                           output = "seurat",
                           method = "all") {
  # Fail before the (potentially long) calculation instead of silently
  # returning NULL afterwards
  if (length(output) != 1 || !(output %in% c("list", "seurat"))) {
    stop("Please select a valid output mode: 'list' or 'seurat'")
  }

  # Correct dimensions if too large
  data <- seurat_object@reductions[[reduction]]@cell.embeddings
  if (length(dims) > ncol(data)) {
    dims <- seq_len(ncol(data))
  }
  # drop = FALSE keeps a matrix even when a single dimension is requested
  data <- data[, dims, drop = FALSE]

  # Group label of every cell and the distinct groups to test. The groups are
  # taken straight from the metadata column, so the object itself never has to
  # be split (and copied) per group.
  labels <- seurat_object[[group]][, 1]
  unique_groups <- unique(as.matrix(labels))
  group_count <- length(unique_groups)

  output.list <- list()

  for (i in seq_len(group_count)) {
    group_1_name <- unique_groups[i]
    # drop = FALSE keeps a one-row matrix for groups with a single cell
    group_1 <- data[labels == group_1_name, , drop = FALSE]

    # Only visit j > i so that every pair is measured once
    for (j in seq_len(group_count)[-seq_len(i)]) {
      group_2_name <- unique_groups[j]
      group_2 <- data[labels == group_2_name, , drop = FALSE]

      # Calculate the total linkage between the two groups
      pair_name <- paste(group_1_name, group_2_name, sep = "_")
      output.list[[pair_name]] <- total_linkage(group_1, group_2,
                                                distance,
                                                group_1_name, group_2_name,
                                                method = method)
    }
  }

  if (output == "list") {
    return(output.list)
  }

  # output == "seurat": store the result in the tools slot of the object
  location <- paste("total_distance", group, sep = "_")
  seurat_object@tools[[location]] <- output.list
  seurat_object
}

#' Calculates the Total linkage between two groups
#'
#' @description
#' Total linkage shows the distance between two groups. Will result in
#' information about the current group as well as the opposing group.
#'
#' @details
#' Total linkage measures the distance to every point in the dataset from
#' every point in one group of the dataset. Each tested cell of group 1 is
#' compared against all cells of group 2 and against the cells of group 1
#' that have not been tested yet, so within-group pairs are measured once.
#'
#' It is possible to change the method to only measure the linkage for one
#' cell instead of all the cells. This could be chosen in order to speed up
#' the calculation process. If "one" is chosen then only one cell in group 1
#' will be used. This cell is chosen because it is furthest away from the mean
#' of group 2 (the first such cell when there are ties). Any other value of
#' `method` behaves like "all".
#'
#' @inheritParams single_linkage_distance
#' @param name_1 name of matrix_1 group. When `NULL` the placeholder
#' "group_1" is used.
#' @param name_2 name of matrix_2 group. When `NULL` the placeholder
#' "group_2" is used.
#' @param method Cells in matrix_1 to test. "all" or "one". This is primarily
#' done to speed up the calculation process.
#'
#' @return Dataframe where column 1 is group label and column 2 is distance
#' information
#'
#' @family total_distance
total_linkage <- function(matrix_1,
                          matrix_2,
                          distance = "euclidean",
                          name_1 = NULL,
                          name_2 = NULL,
                          method = "all") {
  # Assigning NULL to a data frame column is a no-op, which would leave the
  # label column out and make the first coordinate look like the label.
  if (is.null(name_1)) {
    name_1 <- "group_1"
  }
  if (is.null(name_2)) {
    name_2 <- "group_2"
  }

  # Attach the group label and move it in front of the coordinates
  add_label <- function(cells, label) {
    cells <- as.data.frame(cells)
    cells$group <- label
    cells[, c(ncol(cells), seq_len(ncol(cells) - 1)), drop = FALSE]
  }
  matrix_1 <- add_label(matrix_1, name_1)
  matrix_2 <- add_label(matrix_2, name_2)

  # Default: test every cell of group 1
  group_1.test <- matrix_1

  # Change size of group 1 if method is "one"
  if (method == "one") {
    # Mean of group 2 and the coordinates of group 1 (label column removed)
    group_2.mean <- colMeans(matrix_2[, -1, drop = FALSE])
    group_1.coords <- as.matrix(matrix_1[, -1, drop = FALSE])

    # Distance from every cell in group 1 to the mean of group 2
    to.mean <- vapply(seq_len(nrow(group_1.coords)), function(i) {
      pair <- rbind(group_1.coords[i, ], group_2.mean)
      as.numeric(stats::dist(pair, method = distance))
    }, numeric(1))

    # Keep only the cell with the greatest distance to the group 2 mean
    group_1.test <- matrix_1[which.max(to.mean), , drop = FALSE]
  }

  total <- nrow(group_1.test)
  if (total == 0) {
    # Nothing to measure
    return(data.frame(group = character(0), distance = numeric(0)))
  }

  # Progress bar
  pb <- progress::progress_bar$new(format = "[:bar] :current/:total (:percent) eta: :eta",
                         total = total)
  pb$tick(0)

  # One data frame of distances per tested cell, combined once at the end
  results <- vector("list", total)
  tested.names <- rownames(group_1.test)

  for (i in seq_len(total)) {
    current.cell <- group_1.test[i, , drop = FALSE]

    # Cells that have already been measured (including the current one) are
    # not repeated. The names come from the tested cells so that the current
    # cell is never compared against itself.
    completed.cells <- tested.names[seq_len(i)]
    group_1.include <- matrix_1[!(rownames(matrix_1) %in% completed.cells), ,
                                drop = FALSE]

    test.against <- rbind(group_1.include, matrix_2)
    results[[i]] <- cell_total_linkage(current.cell, test.against,
                                       distance = distance)

    # Progress bar
    pb$tick(1)
  }

  do.call(rbind, results)
}

#' Calculates the distance from this cell to all other cells
#'
#' @description
#' Iterates over every cell in the other cells table and calculates the
#' distance from `cell_1` to each of those cells.
#'
#' @details
#' All the cells in the other group should be labelled by group as well. This
#' will allow the function to correctly assign the distances to group. The
#' first column of both inputs is treated as the group label and is not used
#' as a coordinate; the remaining columns are matched by position.
#'
#' @param cell_1 A data frame with one row for the one cell
#' @param other_cells  cells x dimensions data frame of destination cells
#' @inheritParams GroupDistance
#'
#' @return Data frame containing labelled distance to every cell in other cells
#' data frame. Column "group" holds the label of the destination cell and
#' column "distance" the distance to it. An empty `other_cells` gives a data
#' frame with zero rows.
cell_total_linkage <- function(cell_1, other_cells, distance = "euclidean") {
  # Nothing to compare against
  if (nrow(other_cells) == 0) {
    return(data.frame(group = character(0), distance = numeric(0)))
  }

  # Coordinates only: drop the leading label column
  origin <- as.numeric(unlist(cell_1[-1]))
  targets <- as.matrix(other_cells[, -1, drop = FALSE])

  # Distance from cell_1 to each destination cell. stats::dist is called per
  # pair so that every method it supports stays available without building
  # the full pairwise matrix of all destination cells.
  distances <- vapply(seq_len(nrow(targets)), function(i) {
    pair <- rbind(origin, targets[i, ])
    as.numeric(stats::dist(pair, method = distance))
  }, numeric(1))

  # Build the result in one go instead of growing it row by row
  data.frame("group" = other_cells[, 1], "distance" = distances)
}

########################################################
# Two sample test
########################################################

#' Cramer test on different groups in seurat object
#'
#' @description
#' Calculates whether groups within a seurat object are drawn from different
#' distributions.
#'
#' @details
#' Uses the Cramer package. Each test compares two groups; when the grouping
#' variable has more than two groups, every unordered pair of groups is
#' tested. Results are named `"<group 1>_<group 2>"`.
#'
#' @inheritParams GroupDistance
#'
#' @return List of cramer objects with information about the test
#' @export
RunCramer <- function(seurat_object,
                      group,
                      reduction = "pca",
                      dims = 1:30) {
  # Group label of every cell as text. Only the labels are needed, so the
  # seurat object itself is left untouched.
  labels <- as.character(as.matrix(seurat_object[[group]]))

  # Correcting number of dimensions if too many were chosen
  data <- seurat_object@reductions[[reduction]]@cell.embeddings
  if (length(dims) > ncol(data)) {
    dims <- seq_len(ncol(data))
  }
  # Select the requested dimensions exactly once. Subsetting by `dims` a second
  # time would index into the already reduced matrix and fail (or pick the
  # wrong columns) whenever `dims` does not start at 1.
  data <- data[, dims, drop = FALSE]

  # Distinct groups in order of first appearance
  group_names <- unique(labels)
  group_count <- length(group_names)

  # One test result per unordered pair of groups
  result <- list()

  for (i in seq_len(group_count)) {
    # Rows belonging to group 1; drop = FALSE keeps a matrix for 1-cell groups
    group_1_name <- group_names[i]
    group_1 <- data[labels == group_1_name, , drop = FALSE]

    # Only visit j > i so that every pair is tested once
    for (j in seq_len(group_count)[-seq_len(i)]) {
      # Rows belonging to group 2
      group_2_name <- group_names[j]
      group_2 <- data[labels == group_2_name, , drop = FALSE]

      result[[paste(group_1_name, group_2_name, sep = "_")]] <-
        cramer::cramer.test(group_1, group_2)
    }
  }
  result
}
sbrn3/disscat documentation built on Dec. 12, 2019, 7:54 a.m.