R/utils.R

Defines functions calc_distance assess_rank_bias make_comb_ref build_atlas check_raw_counts append_genes plot_rank_bias query_rank_bias find_rank_bias get_unique_column file_marker_parse pos_neg_marker reverse_marker_matrix pos_neg_select marker_select ref_marker_select downsample_matrix plot_pathway_gsea gmt_to_list feature_select_PCA ref_feature_select overcluster_test insert_meta_object parse_loc_object object_loc_lookup clustify_nudge.Seurat clustify_nudge.default clustify_nudge gene_pct_markerm gene_pct cor_to_call_topn assign_ident calculate_pathway_gsea get_common_elements get_best_str get_best_match_matrix percent_clusters average_clusters overcluster is_pkg_available

Documented in append_genes assess_rank_bias assign_ident average_clusters build_atlas calc_distance calculate_pathway_gsea check_raw_counts clustify_nudge clustify_nudge.default clustify_nudge.Seurat cor_to_call_topn downsample_matrix feature_select_PCA file_marker_parse find_rank_bias gene_pct gene_pct_markerm get_best_match_matrix get_best_str get_common_elements get_unique_column gmt_to_list insert_meta_object make_comb_ref marker_select object_loc_lookup overcluster overcluster_test parse_loc_object percent_clusters plot_pathway_gsea plot_rank_bias pos_neg_marker pos_neg_select query_rank_bias ref_feature_select ref_marker_select reverse_marker_matrix

#' Check package is installed
#' @param pkg name of the package to query
#' @param action what to do when the package cannot be loaded: one of
#'  "none" (default), "message", "warn" or "error"
#' @param msg extra text appended to the message, warning or error
#' @return logical(1) indicating if package is available.
#' @noRd
is_pkg_available <- function(
  pkg,
  action = c("none", "message", "warn", "error"),
  msg = ""
) {
  installed <- requireNamespace(pkg, quietly = TRUE)
  action <- match.arg(action)

  # Nothing to report when the namespace could be loaded.
  if (installed) {
    return(installed)
  }

  # Package is missing: signal according to the requested action.
  switch(
    action,
    message = message(pkg, " not installed ", msg),
    warn = warning(pkg, " not installed ", msg, call. = FALSE),
    error = stop(
      pkg,
      " not installed and is required for this function ",
      msg,
      call. = FALSE
    ),
  )
  installed
}


#' Overcluster by kmeans per cluster
#'
#' Every cluster with more than one cell is split further with
#' `stats::kmeans()`; clusters with a single cell are passed through as is.
#'
#' @param mat expression matrix
#' @param cluster_id named list of cell ids per cluster
#' @param power the number of kmeans centers for a cluster is
#'  `as.integer(n_cells^power)`
#' @return new cluster_id list of more clusters; sub-clusters are named
#'  `<original name>_<kmeans cluster number>`
#' @examples
#' res <- overcluster(
#'     mat = pbmc_matrix_small,
#'     cluster_id = split(colnames(pbmc_matrix_small), pbmc_meta$classified)
#' )
#' length(res)
#' @export
overcluster <- function(
  mat,
  cluster_id,
  power = 0.15
) {
  mat <- as.matrix(mat)

  # Returns the (possibly split) entries for one original cluster.
  split_one <- function(label) {
    cells <- cluster_id[[label]]
    if (length(cells) <= 1) {
      return(cluster_id[label])
    }
    n_centers <- as.integer(length(cells)^power)
    fit <- stats::kmeans(t(mat[, cells]), centers = n_centers)
    pieces <- split(names(fit$cluster), fit$cluster)
    names(pieces) <- stringr::str_c(label, names(pieces), sep = "_")
    pieces
  }

  per_cluster <- lapply(names(cluster_id), split_one)
  # Leading empty list keeps the result a list even when there is no input.
  do.call(c, c(list(list()), unname(per_cluster)))
}

#' Average expression values per cluster
#'
#' @param mat expression matrix
#' @param metadata data.frame or vector containing cluster assignments per cell.
#' Order must match column order in supplied matrix. If a data.frame
#' provide the cluster_col parameters.
#' @param if_log input data is natural log,
#' averaging will be done on unlogged data (only used when `method = "mean"`)
#' @param cluster_col column in metadata with cluster number
#' @param cell_col if provided, will reorder matrix first
#' @param low_threshold option to remove clusters with too few cells
#' @param method summary statistic per gene and cluster: "mean" (default),
#' "median", "truncate" (10% truncated mean), "trimean", "max" or "min".
#' Any other value is an error.
#' @param output_log whether to report log results
#' (only used when `method = "mean"`)
#' @param subclusterpower whether to get multiple averages per original cluster
#' @param cut_n set on a limit of genes as expressed, lower ranked genes
#' are set to 0, considered unexpressed
#' @return average expression matrix, with genes for row names, and clusters
#'  for column names
#' @examples
#' mat <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     if_log = FALSE
#' )
#' mat[1:3, 1:3]
#' @importFrom matrixStats rowMaxs rowMedians colRanks
#' @export
average_clusters <- function(
  mat,
  metadata,
  cluster_col = "cluster",
  if_log = TRUE,
  cell_col = NULL,
  low_threshold = 0,
  method = "mean",
  output_log = TRUE,
  subclusterpower = 0,
  cut_n = NULL
) {
  # Fail early with a readable message instead of a late
  # "object 'out' not found" when the method is not recognised.
  supported_methods <- c("mean", "median", "trimean", "truncate", "min", "max")
  if (
    !is.character(method) ||
      length(method) != 1L ||
      !(method %in% supported_methods)
  ) {
    stop(
      "`method` must be one of: ",
      paste(supported_methods, collapse = ", "),
      call. = FALSE
    )
  }

  cluster_info <- metadata
  if (!(is.null(cell_col))) {
    if (!(all(colnames(mat) == cluster_info[[cell_col]]))) {
      mat <- mat[, cluster_info[[cell_col]]]
    }
  }

  if (is.null(colnames(mat))) {
    stop(
      "The input matrix does not have colnames.\n",
      "Check colnames() of input object"
    )
  }

  # A factor is handled like a plain vector of labels, so unused levels
  # never create empty clusters.
  if (is.factor(cluster_info)) {
    cluster_info <- as.character(cluster_info)
  }

  if (is.vector(cluster_info)) {
    if (ncol(mat) != length(cluster_info)) {
      stop(
        "vector of cluster assignments does not match the number of columns in the matrix",
        call. = FALSE
      )
    }
    cluster_ids <- split(colnames(mat), cluster_info)
  } else if (is.data.frame(cluster_info) && !is.null(cluster_col)) {
    if (!(cluster_col %in% colnames(metadata))) {
      stop("given `cluster_col` is not a column in `metadata`", call. = FALSE)
    }

    cluster_labels <- cluster_info[[cluster_col]]
    if (is.factor(cluster_labels)) {
      cluster_labels <- droplevels(cluster_labels)
    }
    cluster_ids <- split(colnames(mat), cluster_labels)
  } else {
    stop(
      "metadata not formatted correctly,
         supply either a vector or a dataframe",
      call. = FALSE
    )
  }

  if (subclusterpower > 0) {
    cluster_ids <-
      overcluster(mat, cluster_ids, power = subclusterpower)
  }

  # Shared post-processing for the matrixStats based summaries:
  # rows without a usable value become 0 and gene names are restored.
  fill_and_name <- function(res, mat_data) {
    res[is.na(res)] <- 0
    names(res) <- rownames(mat_data)
    res
  }

  # One per-gene summary function per supported method; each receives the
  # sub-matrix (genes x cells) of a single cluster.
  summarise_cells <- switch(
    method,
    mean = function(mat_data) {
      if (if_log) {
        mat_data <- expm1(mat_data)
      }
      res <- Matrix::rowMeans(mat_data, na.rm = TRUE)
      if (output_log) {
        res <- log1p(res)
      }
      res
    },
    median = function(mat_data) {
      fill_and_name(
        matrixStats::rowMedians(as.matrix(mat_data), na.rm = TRUE),
        mat_data
      )
    },
    trimean = function(mat_data) {
      dense <- as.matrix(mat_data)
      quartile <- function(p) {
        matrixStats::rowQuantiles(dense, probs = p, na.rm = TRUE)
      }
      # trimean = (Q1 + 2 * median + Q3) / 4
      fill_and_name(
        0.5 * quartile(0.5) + 0.25 * quartile(0.25) + 0.25 * quartile(0.75),
        mat_data
      )
    },
    truncate = function(mat_data) {
      apply(mat_data, 1, function(x) mean(x, trim = 0.1, na.rm = TRUE))
    },
    min = function(mat_data) {
      fill_and_name(
        matrixStats::rowMins(as.matrix(mat_data), na.rm = TRUE),
        mat_data
      )
    },
    max = function(mat_data) {
      fill_and_name(
        matrixStats::rowMaxs(as.matrix(mat_data), na.rm = TRUE),
        mat_data
      )
    }
  )

  out <- lapply(
    cluster_ids,
    function(cell_ids) {
      if (!all(cell_ids %in% colnames(mat))) {
        stop("cell ids not found in input matrix", call. = FALSE)
      }
      summarise_cells(mat[, cell_ids, drop = FALSE])
    }
  )
  out <- do.call(cbind, out)

  cluster_sizes <- vapply(cluster_ids, FUN = length, FUN.VALUE = numeric(1))
  if (low_threshold > 0) {
    fil <- as.vector(cluster_sizes >= low_threshold)
    if (!all(fil)) {
      message(
        "The following clusters have less than ",
        low_threshold,
        " cells for this analysis: ",
        paste(colnames(out)[!fil], collapse = ", "),
        ". They are excluded."
      )
    }
    # drop = FALSE keeps a matrix even when a single cluster survives
    out <- out[, fil, drop = FALSE]
  } else {
    # No filtering requested: only warn about clusters with fewer than 10 cells.
    fil <- as.vector(cluster_sizes >= 10)
    if (!all(fil)) {
      message(
        "The following clusters have less than ",
        10,
        " cells for this analysis: ",
        paste(colnames(out)[!fil], collapse = ", "),
        ". Classification is likely inaccurate."
      )
    }
  }

  if (!(is.null(cut_n))) {
    # Rank genes within each cluster (rank 1 = highest value) and zero out
    # everything ranked below the top `cut_n`.
    expr_mat <- out
    gene_ranks <- t(matrixStats::colRanks(
      -as.matrix(expr_mat),
      ties.method = "average"
    ))
    rownames(gene_ranks) <- rownames(expr_mat)
    colnames(gene_ranks) <- colnames(expr_mat)
    expr_mat[gene_ranks > cut_n] <- 0
    out <- expr_mat
  }

  out
}

#' Percentage detected per cluster
#'
#' Values at or above `cut_num` count as detected (1), everything else as
#' not detected (0); the binarized matrix is then summarised per cluster
#' with [average_clusters()].
#'
#' @param mat expression matrix
#' @param metadata data.frame with cells
#' @param cluster_col column in metadata with cluster number
#' @param cut_num binary cutoff for detection
#' @return matrix of numeric values, with genes for row names,
#' and clusters for column names
percent_clusters <- function(
  mat,
  metadata,
  cluster_col = "cluster",
  cut_num = 0.5
) {
  # Compute the mask once, before anything is overwritten. Thresholding in
  # two passes on the modified matrix zeroed every value when cut_num >= 1,
  # because the freshly written 1s also satisfied `<= cut_num`.
  detected <- mat >= cut_num
  mat[detected] <- 1
  mat[!detected] <- 0

  # NOTE(review): average_clusters() is called with its default
  # output_log = TRUE, so the result is log1p of the detected fraction --
  # confirm this is what downstream callers expect.
  average_clusters(mat, metadata, if_log = FALSE, cluster_col = cluster_col)
}

#' Function to make best call from correlation matrix
#'
#' Marks, for every row, the column(s) holding the row maximum.
#'
#' @param cor_mat correlation matrix
#' @return data.frame with the dimensions of `cor_mat` holding the strings
#'  "1" (row maximum) and "0" (everything else)
get_best_match_matrix <- function(cor_mat) {
  scores <- as.matrix(cor_mat)
  row_peaks <- matrixStats::rowMaxs(scores)

  # Subtracting the row maximum leaves exactly 0 at the best match(es).
  calls <- as.data.frame(scores - row_peaks)
  calls[calls == 0] <- "1"
  calls[calls != "1"] <- "0"

  calls
}

#' Function to make call and attach score
#'
#' @param name name of row to query
#' @param best_mat binarized call matrix
#' @param cor_mat correlation matrix
#' @param carry_cor whether the correlation score gets reported
#' @return string with ident call and possibly cor value; tied calls are
#'  separated by "; " and "?" is returned when the row has no call
get_best_str <- function(
  name,
  best_mat,
  cor_mat,
  carry_cor = TRUE
) {
  calls <- best_mat[name, ]
  hits <- which(calls == 1)

  # No call for this row (also covers the case where nothing equals 1,
  # which previously left the result unassigned).
  if (!(sum(as.numeric(calls)) > 0) || length(hits) == 0) {
    return("?")
  }

  best_names <- colnames(best_mat)[hits]

  if (carry_cor == FALSE) {
    # Build the score-free string directly. Stripping " (...)" afterwards
    # with a greedy regex removed every tied call after the first and
    # truncated identity names that themselves contain parentheses.
    return(paste(best_names, collapse = "; "))
  }

  best_cor <- round(cor_mat[name, hits], 2)
  paste0(best_names, " (", best_cor, ")", collapse = "; ")
}

#' Find entries shared in all vectors
#' @description return entries found in all supplied vectors.
#'  Arguments that are NULL, or a single NA, are left out of the
#'  comparison; NA values inside a longer vector are kept.
#' @param ... vectors
#' @return vector of shared elements
get_common_elements <- function(...) {
  candidates <- Filter(Negate(is.null), list(...))
  # is.na() on a list flags only the entries that are a lone NA
  lone_na <- is.na(candidates)

  Reduce(intersect, candidates[!lone_na])
}

#' Convert expression matrix to GSEA pathway scores
#' (would take a similar place in workflow before average_clusters/binarize)
#'
#' @param mat expression matrix
#' @param pathway_list a list of vectors, each named for a specific pathway,
#' or dataframe
#' @param n_perm Number of permutation for fgsea function. Defaults to 1000.
#' @param scale convert expr_mat into zscores prior to running GSEA?,
#' default = TRUE
#' @param no_warnings suppress warnings from gsea ties
#' @return matrix of GSEA NES values, cell types as row names,
#' pathways as column names
#' @examples
#' gl <- list(
#'     "n" = c("PPBP", "LYZ", "S100A9"),
#'     "a" = c("IGLL5", "GNLY", "FTL")
#' )
#'
#' pbmc_avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified"
#' )
#'
#' calculate_pathway_gsea(
#'     mat = pbmc_avg,
#'     pathway_list = gl
#' )
#' @export
calculate_pathway_gsea <- function(
  mat,
  pathway_list,
  n_perm = 1000,
  scale = TRUE,
  no_warnings = TRUE
) {
  # pathway_list can be user defined or
  # `my_pathways <- fgsea::reactomePathways(rownames(pbmc4k_matrix))`
  pathway_names <- names(pathway_list)

  # Runs GSEA for a single pathway and keeps the third result column.
  score_one_pathway <- function(pathway) {
    single_set <- list()
    single_set[[1]] <- pathway_list[[pathway]]
    names(single_set) <- pathway
    gsea_res <- run_gsea(
      mat,
      single_set,
      n_perm = n_perm,
      scale = scale,
      per_cell = TRUE,
      no_warnings = no_warnings
    )
    gsea_res[, 3, drop = FALSE]
  }

  res <- do.call(cbind, lapply(pathway_names, score_one_pathway))
  colnames(res) <- pathway_names
  res
}

#' manually change idents as needed
#'
#' @param metadata data.frame of metadata holding the cluster and
#'  identity columns
#' @param cluster_col column in metadata containing cluster info
#' @param ident_col column in metadata containing identity assignment
#' @param clusters names of clusters to change, string or
#'  vector of strings
#' @param idents new idents to assign, must be length of 1 or
#' same as clusters
#' @return new dataframe of metadata
assign_ident <- function(
  metadata,
  cluster_col = "cluster",
  ident_col = "type",
  clusters,
  idents
) {
  if (!is.vector(clusters) || !is.vector(idents)) {
    stop("unsupported clusters or idents", call. = FALSE)
  }

  if (length(idents) == 1) {
    idents <- rep(idents, length(clusters))
  } else if (length(idents) != length(clusters)) {
    stop("unsupported lengths pairs of clusters and idents", call. = FALSE)
  }

  for (n in seq_along(clusters)) {
    in_cluster <- metadata[[cluster_col]] == clusters[n]
    # Rows with a missing cluster never match; leaving NA in the index
    # makes the data.frame assignment below fail.
    in_cluster[is.na(in_cluster)] <- FALSE
    metadata[in_cluster, ident_col] <- idents[n]
  }
  metadata
}

#' get top calls for each cluster
#'
#' @param cor_mat input similarity matrix
#' @param metadata input metadata with tsne or umap coordinates
#' and cluster ids
#' @param col metadata column, can be cluster or cellid
#' @param collapse_to_cluster if a column name is provided,
#' takes the most frequent call of entire cluster to color in plot
#' @param threshold minimum correlation coefficent cutoff for calling clusters
#' @param topn number of calls for each cluster
#' @return dataframe of cluster, new potential ident, and r info
#' @examples
#' res <- clustify(
#'     input = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     ref_mat = cbmc_ref,
#'     query_genes = pbmc_vargenes,
#'     cluster_col = "classified"
#' )
#'
#' cor_to_call_topn(
#'     cor_mat = res,
#'     metadata = pbmc_meta,
#'     col = "classified",
#'     collapse_to_cluster = FALSE,
#'     threshold = 0.5
#' )
#' @export
cor_to_call_topn <- function(
  cor_mat,
  metadata = NULL,
  col = "cluster",
  collapse_to_cluster = FALSE,
  threshold = 0,
  topn = 2
) {
  correlation_matrix <- cor_mat
  # Move the row names of the similarity matrix into a column named `col`.
  df_temp <- tibble::as_tibble(correlation_matrix, rownames = col)
  # Reshape to long format: one row per (row name, reference column) pair,
  # with the reference name in `type` and the score in `r`.
  df_temp <-
    tidyr::gather(
      df_temp,
      key = !!dplyr::sym("type"),
      value = !!dplyr::sym("r"),
      -!!col
    )
  # Scores below the threshold are relabelled as unassigned.
  df_temp[["type"]][df_temp$r < threshold] <-
    paste0("r<", threshold, ", unassigned")
  # Keep the `topn` highest scoring rows per group of the first column.
  df_temp <-
    dplyr::top_n(
      dplyr::group_by_at(df_temp, 1),
      topn,
      !!dplyr::sym("r")
    )
  df_temp_full <- df_temp

  # Any value other than FALSE (i.e. a metadata column name) enables the
  # collapse-to-cluster branch.
  if (collapse_to_cluster != FALSE) {
    # If metadata has no `col` column, take it from the row names.
    if (!(col %in% colnames(metadata))) {
      metadata <- tibble::as_tibble(metadata, rownames = col)
    }
    df_temp_full <-
      dplyr::left_join(df_temp_full, metadata, by = col)
    # `type2` holds the grouping to collapse to.
    df_temp_full[, "type2"] <- df_temp_full[[collapse_to_cluster]]
    # Per (type, type2) pair: summed score and number of rows.
    df_temp_full2 <-
      dplyr::group_by(
        df_temp_full,
        !!dplyr::sym("type"),
        !!dplyr::sym("type2")
      )
    df_temp_full2 <-
      dplyr::summarize(df_temp_full2, sum = sum(!!dplyr::sym("r")), n = n())
    df_temp_full2 <-
      dplyr::group_by(df_temp_full2, !!dplyr::sym("type2"))
    # Most frequent call first, summed score as tie breaker.
    df_temp_full2 <-
      dplyr::arrange(df_temp_full2, desc(n), desc(sum))
    # Unassigned rows are not allowed to win a group.
    df_temp_full2 <-
      dplyr::filter(
        df_temp_full2,
        !!dplyr::sym("type") !=
          paste0(
            "r<",
            threshold,
            ", unassigned"
          )
      )
    # Keep the first `topn` calls per `type2` group.
    df_temp_full2 <- dplyr::slice(df_temp_full2, seq_len(topn))
    # Join the collapsed calls back onto the per-row table (its own
    # `type` and `r` columns are dropped first).
    df_temp_full2 <-
      dplyr::right_join(
        df_temp_full2,
        dplyr::select(
          df_temp_full,
          -c(
            !!dplyr::sym("type"),
            !!dplyr::sym("r")
          )
        ),
        by = stats::setNames(collapse_to_cluster, "type2"),
        relationship = "many-to-many"
      )
    # Groups left without any call after the join become unassigned.
    df_temp_full <-
      dplyr::mutate(
        df_temp_full2,
        type = tidyr::replace_na(
          !!dplyr::sym("type"),
          paste0("r<", threshold, ", unassigned")
        )
      )
    df_temp_full <- dplyr::group_by(
      df_temp_full,
      !!dplyr::sym(col)
    )
    # Remove duplicated (type, type2) rows produced by the join.
    df_temp_full <-
      dplyr::distinct(
        df_temp_full,
        !!dplyr::sym("type"),
        !!dplyr::sym("type2"),
        .keep_all = TRUE
      )
    dplyr::arrange(df_temp_full, desc(n), desc(sum), .by_group = TRUE)
  } else {
    # No collapsing: return the top calls per `col`, best score first.
    df_temp_full <- dplyr::group_by(
      df_temp_full,
      !!dplyr::sym(col)
    )
    dplyr::arrange(df_temp_full, desc(!!dplyr::sym("r")), .by_group = TRUE)
  }
}

#' pct of cells in each cluster that express genelist
#'
#' For every cluster, the fraction of cells with a value above 0 is computed
#' per gene and then collapsed over the gene list.
#'
#' @param matrix expression matrix
#' @param genelist vector of marker genes for one identity
#' @param clusters vector of cluster identities
#' @param returning whether to return mean, min,
#' or max of the gene pct in the gene list
#' @return vector of numeric values
gene_pct <- function(
  matrix,
  genelist,
  clusters,
  returning = "mean"
) {
  genelist <- intersect(genelist, rownames(matrix))
  # Cells without a cluster are pooled under the "orig.NA" label.
  if (is.factor(clusters)) {
    clusters <-
      factor(clusters, levels = c(levels(clusters), "orig.NA"))
  }
  clusters[is.na(clusters)] <- "orig.NA"
  unique_clusters <- unique(clusters)

  # Pick how the per-gene fractions are collapsed; any other value of
  # `returning` yields NULL.
  if (returning == "mean") {
    collapse <- mean
  } else if (returning == "min") {
    collapse <- min
  } else if (returning == "max") {
    collapse <- max
  } else {
    return(invisible(NULL))
  }

  vapply(
    unique_clusters,
    function(cl) {
      in_cluster <- clusters == cl
      detected <- matrix[genelist, in_cluster, drop = FALSE]
      detected[detected > 0] <- 1
      collapse(Matrix::rowSums(detected) / ncol(detected))
    },
    FUN.VALUE = numeric(1)
  )
}

#' pct of cells in every cluster that express a series of genelists
#'
#' @param matrix expression matrix
#' @param marker_m matrixized markers
#' @param metadata data.frame, vector or factor containing cluster
#' assignments per cell.
#' Order must match column order in supplied matrix. If a data.frame
#' provide the cluster_col parameters.
#' @param cluster_col column in metadata with cluster number
#' @param norm whether and how the results are normalized: "divide" scales
#' each column by its maximum, "diff" subtracts the column maximum, and a
#' number binarizes against that fraction of the column maximum
#' @return matrix of numeric values, clusters from mat as row names,
#'  cell types from marker_m as column names
#' @examples
#' gene_pct_markerm(
#'     matrix = pbmc_matrix_small,
#'     marker_m = cbmc_m,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified"
#' )
#' @export
gene_pct_markerm <- function(
  matrix,
  marker_m,
  metadata,
  cluster_col = NULL,
  norm = NULL
) {
  cluster_info <- metadata
  if (is.data.frame(cluster_info) && !is.null(cluster_col)) {
    if (!(cluster_col %in% colnames(cluster_info))) {
      stop("given `cluster_col` is not a column in `metadata`", call. = FALSE)
    }
    cluster_info <- cluster_info[[cluster_col]]
  } else if (!is.vector(cluster_info) && !is.factor(cluster_info)) {
    stop(
      "metadata not formatted correctly,
         supply either a  vector or a dataframe",
      call. = FALSE
    )
  }

  # coerce factors in character (a bare factor is accepted as well,
  # matching what average_clusters() allows)
  if (is.factor(cluster_info)) {
    cluster_info <- as.character(cluster_info)
  }

  if (!is.data.frame(marker_m)) {
    marker_m <- as.data.frame(marker_m)
  }

  # One column per marker list, one row per cluster.
  out <- vapply(
    colnames(marker_m),
    function(x) {
      gene_pct(
        matrix,
        marker_m[[x]],
        cluster_info
      )
    },
    FUN.VALUE = numeric(length(unique(cluster_info)))
  )

  if (!(is.null(norm))) {
    col_max <- apply(out, 2, max)
    if (norm == "divide") {
      out <- sweep(out, 2, col_max, "/")
    } else if (norm == "diff") {
      out <- sweep(out, 2, col_max, "-")
    } else {
      # numeric norm: 1 where the value exceeds norm * column maximum
      out <- sweep(out, 2, col_max * norm)
      out[out < 0] <- 0
      out[out > 0] <- 1
    }
  }

  # edge cases where all markers can't be found for a cluster
  out[is.na(out)] <- 0
  out
}

#' Combined function to compare scRNA-seq data to
#'  bulk RNA-seq data and marker list
#'
#' @examples
#'
#' # Seurat
#' so <- so_pbmc()
#' clustify_nudge(
#'     input = so,
#'     ref_mat = cbmc_ref,
#'     marker = cbmc_m,
#'     cluster_col = "seurat_clusters",
#'     threshold = 0.8,
#'     obj_out = FALSE,
#'     mode = "pct",
#'     dr = "umap"
#' )
#'
#' # Matrix
#' clustify_nudge(
#'     input = pbmc_matrix_small,
#'     ref_mat = cbmc_ref,
#'     metadata = pbmc_meta,
#'     marker = as.matrix(cbmc_m),
#'     query_genes = pbmc_vargenes,
#'     cluster_col = "classified",
#'     threshold = 0.8,
#'     call = FALSE,
#'     marker_inmatrix = FALSE,
#'     mode = "pct"
#' )
#' @export
clustify_nudge <- function(input, ...) {
  # S3 generic: dispatches on the class of the first argument, `input`
  UseMethod("clustify_nudge")
}

#' @rdname clustify_nudge
#' @param input express matrix or object
#' @param ref_mat reference expression matrix
#' @param metadata cell cluster assignments, supplied as a vector
#' or data.frame. If
#' data.frame is supplied then `cluster_col` needs to be set.
#' @param marker matrix of markers
#' @param query_genes A vector of genes of interest to compare.
#' If NULL, then common genes between
#' the expr_mat and ref_mat will be used for comparision.
#' @param cluster_col column in metadata that contains cluster ids per cell.
#'  Will default to first
#' column of metadata if not supplied.
#' Not required if running correlation per cell.
#' @param compute_method method(s) for computing similarity scores
#' @param weight relative weight for the gene list scores,
#' when added to correlation score
#' @param dr stored dimension reduction
#' @param ... passed to matrixize_markers
#' @param norm whether and how the results are normalized
#' @param call make call or just return score matrix
#' @param marker_inmatrix whether markers genes are already
#'  in preprocessed matrix form
#' @param mode use marker expression pct or ranked cor score for nudging,
#'  either "pct" or "rank"
#' @param obj_out whether to output object instead of cor matrix
#' @param seurat_out output cor matrix or called seurat object (deprecated, use obj_out)
#' @param rename_prefix prefix to add to type and r column names
#' @param lookuptable if not supplied, will look in built-in
#' table for object parsing
#' @param threshold identity calling minimum score threshold,
#'  only used when obj_out = TRUE
#'
#' @return single cell object, or matrix of numeric values,
#'  clusters from input as row names, cell types from ref_mat as column names
#' @export
clustify_nudge.default <- function(
  input,
  ref_mat,
  marker,
  metadata = NULL,
  cluster_col = NULL,
  query_genes = NULL,
  compute_method = "spearman",
  weight = 1,
  threshold = -Inf,
  dr = "umap",
  norm = "diff",
  call = TRUE,
  marker_inmatrix = TRUE,
  mode = "rank",
  obj_out = FALSE,
  seurat_out = obj_out,
  rename_prefix = NULL,
  lookuptable = NULL,
  ...
) {
  # NOTE(review): `compute_method` and `dr` are accepted but not referenced
  # anywhere in this method -- confirm whether they should be forwarded.

  # Reject unknown modes up front; otherwise the marker score matrix is
  # never created and the failure surfaces as "object 'resb' not found".
  if (length(mode) != 1 || !(mode %in% c("pct", "rank"))) {
    stop("`mode` must be either \"pct\" or \"rank\"", call. = FALSE)
  }

  if (marker_inmatrix != TRUE) {
    marker <- matrixize_markers(
      marker,
      ...
    )
  }

  # Always remember what the caller passed in: it is inspected again below,
  # also when `input` already is a plain matrix or data.frame.
  input_original <- input
  if (!inherits(input, c("matrix", "Matrix", "data.frame"))) {
    temp <- parse_loc_object(
      input,
      type = class(input),
      expr_loc = NULL,
      meta_loc = NULL,
      var_loc = NULL,
      cluster_col = cluster_col,
      lookuptable = lookuptable
    )

    if (!(is.null(temp[["expr"]]))) {
      message("recognized object type - ", class(input))
    }

    # Replace the object by its parsed pieces; explicit arguments win over
    # what was found in the object.
    input <- temp[["expr"]]
    metadata <- temp[["meta"]]
    if (is.null(query_genes)) {
      query_genes <- temp[["var"]]
    }
    if (is.null(cluster_col)) {
      cluster_col <- temp[["col"]]
    }
  }

  # Correlation based scores (clusters x reference cell types).
  resa <- clustify(
    input = input,
    ref_mat = ref_mat,
    metadata = metadata,
    cluster_col = cluster_col,
    query_genes = query_genes,
    obj_out = FALSE,
    per_cell = FALSE
  )

  # Marker based scores used to nudge the correlation scores.
  if (mode == "pct") {
    resb <- gene_pct_markerm(
      input,
      marker,
      metadata,
      cluster_col = cluster_col,
      norm = norm
    )
  } else {
    if (ncol(marker) > 1 && is.character(marker[1, 1])) {
      marker <- pos_neg_marker(marker)
    }
    resb <- pos_neg_select(
      input,
      marker,
      metadata,
      cluster_col = cluster_col,
      cutoff_score = NULL
    )
    # Cell types without marker information contribute a score of 0.
    empty_vec <- setdiff(colnames(resa), colnames(resb))
    empty_mat <-
      matrix(
        0,
        nrow = nrow(resb),
        ncol = length(empty_vec),
        dimnames = list(rownames(resb), empty_vec)
      )
    resb <- cbind(resb, empty_mat)
  }

  # Align both matrices by sorted row/column names before combining.
  res <- resa[order(rownames(resa)), order(colnames(resa))] +
    resb[order(rownames(resb)), order(colnames(resb))] * weight
  obj_out <- seurat_out
  if (
    obj_out &&
      !inherits(input_original, c("matrix", "Matrix", "data.frame"))
  ) {
    df_temp <- cor_to_call(
      res,
      metadata = metadata,
      cluster_col = cluster_col,
      threshold = threshold
    )

    df_temp_full <- call_to_metadata(
      df_temp,
      metadata = metadata,
      cluster_col = cluster_col,
      per_cell = FALSE,
      rename_prefix = rename_prefix
    )

    out <- insert_meta_object(
      input_original,
      df_temp_full,
      lookuptable = lookuptable
    )

    return(out)
  }

  if (call == TRUE) {
    df_temp <- cor_to_call(res, threshold = threshold)
    colnames(df_temp) <- c(cluster_col, "type", "score")
    return(df_temp)
  }
  res
}

#' @rdname clustify_nudge
#' @export
clustify_nudge.Seurat <- function(
  input,
  ref_mat,
  marker,
  cluster_col = NULL,
  query_genes = NULL,
  compute_method = "spearman",
  weight = 1,
  obj_out = TRUE,
  seurat_out = obj_out,
  threshold = -Inf,
  dr = "umap",
  norm = "diff",
  marker_inmatrix = TRUE,
  mode = "rank",
  rename_prefix = NULL,
  ...
) {
  # NOTE(review): `compute_method` is accepted but not referenced in this
  # method -- confirm whether it should be forwarded to clustify().

  # Reject unknown modes up front; otherwise the marker score matrix is
  # never created and the failure surfaces as "object 'resb' not found".
  if (length(mode) != 1 || !(mode %in% c("pct", "rank"))) {
    stop("`mode` must be either \"pct\" or \"rank\"", call. = FALSE)
  }

  if (marker_inmatrix != TRUE) {
    marker <- matrixize_markers(
      marker,
      ...
    )
  }

  # Correlation based scores (clusters x reference cell types).
  resa <- clustify(
    input = input,
    ref_mat = ref_mat,
    cluster_col = cluster_col,
    query_genes = query_genes,
    obj_out = FALSE,
    per_cell = FALSE,
    dr = dr
  )

  # Both modes need the expression data and the metadata of the object.
  expr_data <- object_data(input, "data")
  meta_data <- object_data(input, "meta.data")

  # Marker based scores used to nudge the correlation scores.
  if (mode == "pct") {
    resb <- gene_pct_markerm(
      expr_data,
      marker,
      meta_data,
      cluster_col = cluster_col,
      norm = norm
    )
  } else {
    if (ncol(marker) > 1 && is.character(marker[1, 1])) {
      marker <- pos_neg_marker(marker)
    }
    resb <- pos_neg_select(
      expr_data,
      marker,
      meta_data,
      cluster_col = cluster_col,
      cutoff_score = NULL
    )
    # Cell types without marker information contribute a score of 0.
    empty_vec <- setdiff(colnames(resa), colnames(resb))
    empty_mat <-
      matrix(
        0,
        nrow = nrow(resb),
        ncol = length(empty_vec),
        dimnames = list(rownames(resb), empty_vec)
      )
    resb <- cbind(resb, empty_mat)
  }

  # Align both matrices by sorted row/column names before combining.
  res <- resa[order(rownames(resa)), order(colnames(resa))] +
    resb[order(rownames(resb)), order(colnames(resb))] * weight
  obj_out <- seurat_out
  if (!obj_out) {
    return(res)
  }

  df_temp <- cor_to_call(
    res,
    metadata = meta_data,
    cluster_col = cluster_col,
    threshold = threshold
  )

  df_temp_full <- call_to_metadata(
    df_temp,
    metadata = meta_data,
    cluster_col = cluster_col,
    per_cell = FALSE,
    rename_prefix = rename_prefix
  )

  # Writing the calls back requires the SeuratObject namespace to be loaded.
  if (!("SeuratObject" %in% loadedNamespaces())) {
    message("seurat not loaded, returning cor_mat instead")
    return(res)
  }
  write_meta(input, df_temp_full)
}
#' lookup table for single cell object structures
#' @importFrom SummarizedExperiment colData<-
#' @returns A list populated with standardized functions to
#' access relevant data structures in multiple single cell
#' data formats.
object_loc_lookup <- function() {
  # Each entry describes one object class:
  #   expr     - function returning the expression matrix
  #   meta     - function returning the per-cell metadata
  #   add_meta - function writing metadata back into the object
  #   var      - function returning the variable genes (NULL when the
  #              class has no accessor for them)
  #   col      - default metadata column holding the cluster ids
  # Entries are built with list(): c() silently drops NULL members, which
  # left no `var` accessor at all for some classes.
  l <- list()

  l$SingleCellExperiment <- list(
    expr = function(x) object_data(x, "data"),
    meta = function(x) object_data(x, "meta.data"),
    add_meta = function(x, md) {
      colData(x) <- md
      x
    },
    var = function(x) NULL,
    col = "cell_type1"
  )

  l$Seurat <- list(
    expr = function(x) object_data(x, "data"),
    meta = function(x) object_data(x, "meta.data"),
    add_meta = function(x, md) {
      x@meta.data <- md
      x
    },
    var = function(x) object_data(x, "var.genes"),
    col = "RNA_snn_res.1"
  )

  l$URD <- list(
    expr = function(x) x@logupx.data,
    meta = function(x) x@meta,
    add_meta = function(x, md) {
      x@meta <- md
      x
    },
    var = function(x) x@var.genes,
    col = "cluster"
  )

  l$FunctionalSingleCellExperiment <- list(
    expr = function(x) x@ExperimentList$rnaseq@assays$data$logcounts,
    meta = function(x) x@ExperimentList$rnaseq@colData,
    add_meta = function(x, md) {
      x@ExperimentList$rnaseq@colData <- md
      x
    },
    var = function(x) NULL,
    col = "leiden_cluster"
  )

  l$CellDataSet <- list(
    expr = function(x) {
      # Label the expression rows with the gene short names. The names are
      # read from the object itself; an inner function that reused the
      # name `x` for the matrix used to look for the slot on the matrix.
      expr_mat <- x@assayData$exprs
      row.names(expr_mat) <- x@featureData@data$gene_short_name
      expr_mat
    },
    meta = function(x) as.data.frame(x@phenoData@data),
    add_meta = function(x, md) {
      x@phenoData@data <- md
      x
    },
    var = function(x) {
      as.character(x@featureData@data$gene_short_name[
        x@featureData@data$use_for_ordering == TRUE
      ])
    },
    col = "Main_Cluster"
  )
  l
}

#' more flexible parsing of single cell objects
#'
#' @param input input object
#' @param type look up predefined slots/loc
#' @param expr_loc function that extracts expression matrix
#' @param meta_loc function that extracts metadata
#' @param var_loc function that extracts variable genes
#' @param cluster_col column of clustering from metadata
#' @param lookuptable if not supplied, will use object_loc_lookup() for parsing.
#' @return list of expression, metadata, vargenes, cluster_col info from object
#' @examples
#' so <- so_pbmc()
#' obj <- parse_loc_object(so)
#' length(obj)
#' @export
parse_loc_object <- function(
  input,
  type = class(input),
  expr_loc = NULL,
  meta_loc = NULL,
  var_loc = NULL,
  cluster_col = NULL,
  lookuptable = NULL
) {
  # Resolve the lookup table first so a multi-element class vector can be
  # matched against the types it knows about.
  if (is.null(lookuptable)) {
    lookup <- object_loc_lookup()
  } else {
    lookup <- lookuptable
  }

  # class() can return several classes, but the checks below need one value:
  # prefer the first class that has a lookup entry, else the first class.
  if (length(type) > 1) {
    known <- intersect(type, names(lookup))
    type <- if (length(known) > 0) known[1] else type[1]
  }

  if (!type %in% c("SingleCellExperiment", "Seurat")) {
    warning(
      "Support for ",
      type,
      " objects is deprecated ",
      "and will be removed from clustifyr in the next version"
    )
  }

  if (!is.null(lookuptable)) {
    warning(
      "Support for supplying custom objects is deprecated ",
      "and will be removed from clustifyr in the next version"
    )
  }

  # Start from four empty slots; `x["name"] <- list(value)` is used throughout
  # because `x[["name"]] <- NULL` would delete the slot instead of storing NULL.
  parsed <- list(expr = NULL, meta = NULL, var = NULL, col = NULL)

  if (type %in% names(lookup)) {
    entry <- lookup[[type]]
    parsed["expr"] <- list(entry$expr(input))
    parsed["meta"] <- list(as.data.frame(entry$meta(input)))
    # Some entries have no variable-gene accessor; leave `var` as NULL then
    # rather than trying to call a missing function.
    var_fun <- entry$var
    if (is.function(var_fun)) {
      parsed["var"] <- list(var_fun(input))
    }
    parsed["col"] <- list(entry$col)
  }

  # Explicit accessors / column name supplied by the caller take precedence.
  if (!(is.null(expr_loc))) {
    parsed["expr"] <- list(expr_loc(input))
  }

  if (!(is.null(meta_loc))) {
    parsed["meta"] <- list(as.data.frame(meta_loc(input)))
  }

  if (!(is.null(var_loc))) {
    parsed["var"] <- list(var_loc(input))
  }

  if (!(is.null(cluster_col))) {
    parsed["col"] <- list(cluster_col)
  }

  parsed
}

#' more flexible metadata update of single cell objects
#'
#' @param input input object
#' @param new_meta new metadata table to insert back into object
#' @param type look up predefined slots/loc
#' @param meta_loc metadata location
#' @param lookuptable if not supplied,
#' will look in built-in table for object parsing
#' @return new object with new metadata inserted
#' @examples
#' so <- so_pbmc()
#' insert_meta_object(so, seurat_meta(so, dr = "umap"))
#' @export
insert_meta_object <- function(
  input,
  new_meta,
  type = class(input),
  meta_loc = NULL,
  lookuptable = NULL
) {
  # Use the caller's accessor table when given, else the built-in one.
  # (`meta_loc` is accepted but not consulted here.)
  lookup <- lookuptable
  if (is.null(lookup)) {
    lookup <- object_loc_lookup()
  }

  # Guard clause: only object types with an entry can be updated.
  if (!type %in% names(lookup)) {
    stop("unrecognized object type", call. = FALSE)
  }

  # Delegate to the type-specific writer and hand back the updated object.
  updated <- lookup[[type]]$add_meta(input, new_meta)
  return(updated)
}

#' compare clustering parameters and classification outcomes
#'
#' @param expr expression matrix
#' @param metadata metadata including cluster info and
#' dimension reduction plotting
#' @param ref_mat reference matrix
#' @param cluster_col column of clustering from metadata
#' @param x_col column of metadata for x axis plotting
#' @param y_col column of metadata for y axis plotting
#' @param n expand n-fold for over/under clustering
#' @param ngenes number of genes to use for feature selection,
#' use all genes if NULL
#' @param query_genes vector of genes to use, otherwise genes will be recalculated
#' @param do_label whether to label each cluster at median center
#' @param do_legend whether to draw legend
#' @param newclustering use kmeans if NULL on dr
#' or col name for second column of clustering
#' @param threshold type calling threshold
#' @param combine if TRUE return a single plot with combined panels, if
#' FALSE return list of plots (default: TRUE)
#' @return faceted ggplot object
#' @examples
#' set.seed(42)
#' overcluster_test(
#'     expr = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     ref_mat = cbmc_ref,
#'     cluster_col = "classified",
#'     x_col = "UMAP_1",
#'     y_col = "UMAP_2"
#' )
#' @export
overcluster_test <- function(
  expr,
  metadata,
  ref_mat,
  cluster_col,
  x_col = "UMAP_1",
  y_col = "UMAP_2",
  n = 5,
  ngenes = NULL,
  query_genes = NULL,
  threshold = 0,
  do_label = TRUE,
  do_legend = FALSE,
  newclustering = NULL,
  combine = TRUE
) {
  # Second clustering: an existing metadata column when one is named,
  # otherwise k-means on the two plotting coordinates with n-fold more centers.
  if (!is.null(newclustering)) {
    metadata$new_clusters <- metadata[[newclustering]]
    n <- length(unique(metadata[[newclustering]])) /
      length(unique(metadata[[cluster_col]]))
  } else {
    km <- stats::kmeans(
      metadata[, c(x_col, y_col)],
      centers = n * length(unique(metadata[[cluster_col]]))
    )
    metadata$new_clusters <- as.character(km$cluster)
  }

  # Genes: caller-supplied, else all rows, else the top `ngenes` features.
  genes <- query_genes
  if (is.null(genes)) {
    genes <- if (is.null(ngenes)) {
      rownames(expr)
    } else {
      ref_feature_select(expr, ngenes)
    }
  }

  # Small closures so the original and new clusterings are handled the same.
  correlate_by <- function(col) {
    clustify(
      expr,
      ref_mat,
      metadata,
      query_genes = genes,
      cluster_col = col,
      obj_out = FALSE
    )
  }
  cluster_panel <- function(col) {
    plot_dims(
      metadata,
      feature = col,
      x = x_col,
      y = y_col,
      do_label = FALSE,
      do_legend = FALSE
    )
  }
  call_panel <- function(cor_res, col) {
    plot_best_call(
      cor_res,
      metadata,
      col,
      threshold = threshold,
      do_label = do_label,
      do_legend = do_legend,
      x = x_col,
      y = y_col
    )
  }

  res1 <- correlate_by(cluster_col)
  res2 <- correlate_by("new_clusters")
  o1 <- cluster_panel(cluster_col)
  o2 <- cluster_panel("new_clusters")
  p1 <- call_panel(res1, cluster_col)
  p2 <- call_panel(res2, "new_clusters")

  # Cluster counts shown as panel labels in the combined figure.
  n_orig_clusters <- length(unique(metadata[[cluster_col]]))
  n_new_clusters <- n * length(unique(metadata[[cluster_col]]))

  if (!combine) {
    return(list(
      original_clusters = o1,
      new_clusters = o2,
      original_cell_types = p1,
      new_cell_types = p2
    ))
  }

  g <- suppressWarnings(cowplot::plot_grid(
    o1,
    o2,
    p1,
    p2,
    labels = c(
      n_orig_clusters,
      n_new_clusters
    )
  ))
  return(g)
}

#' feature select from reference matrix
#'
#' @param mat reference matrix
#' @param n number of genes to return
#' @param mode the method of selecting features
#' @param rm.lowvar whether to remove lower variation genes first
#' @return vector of genes
#' @examples
#' pbmc_avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified"
#' )
#'
#' ref_feature_select(
#'     mat = pbmc_avg[1:100, ],
#'     n = 5
#' )
#' @export
ref_feature_select <- function(
  mat,
  n = 3000,
  mode = "var",
  rm.lowvar = TRUE
) {
  # Fail early on an unknown mode instead of hitting an undefined variable.
  if (!mode %in% c("var", "cor")) {
    stop("mode must be either \"var\" or \"cor\"", call. = FALSE)
  }

  if (!(is.matrix(mat))) {
    mat <- as.matrix(mat)
  }

  # Per-gene variance, sorted high to low. It is needed both for the
  # low-variance filter and for the "var" ranking, so "var" mode also works
  # when rm.lowvar = FALSE.
  vars <- NULL
  if (rm.lowvar == TRUE || mode == "var") {
    vars <- matrixStats::rowVars(mat)
    names(vars) <- rownames(mat)
    vars <- vars[order(-vars)]
  }

  if (rm.lowvar == TRUE) {
    # Keep the more variable half of the genes.
    vars <- vars[seq_len(floor(length(vars) / 2))]
    keep <- names(vars)[!is.na(names(vars))]
    # drop = FALSE keeps a matrix even when a single gene remains.
    mat <- mat[keep, , drop = FALSE]
  }

  if (mode == "cor") {
    # Score each gene by its strongest absolute Spearman correlation with
    # any other gene (self-correlation zeroed out).
    cor_mat <- stats::cor(t(as.matrix(mat)), method = "spearman")
    diag(cor_mat) <- rep(0, times = nrow(cor_mat))
    cor_mat <- abs(cor_mat)
    score <- matrixStats::rowMaxs(cor_mat, na.rm = TRUE)
    names(score) <- rownames(cor_mat)
    ranked <- names(score[order(-score)])
  } else {
    ranked <- names(vars)
  }

  # Never index past the available genes; that would pad the result with NA.
  ranked[seq_len(min(n, length(ranked)))]
}

#' Returns a list of variable genes based on PCA
#'
#' @description  Extract genes, i.e. "features", based on the top
#' loadings of principal components
#' formed from the bulk expression data set
#'
#' @param mat Expression matrix. Rownames are genes,
#' colnames are single cell cluster name, and
#' values are average single cell expression (log transformed).
#' @param pcs Precalculated pcs if available, will skip over processing on mat.
#' @param n_pcs Number of PCs to select gene loadings from.
#' See the explore_PCA_corr.Rmd vignette for details.
#' @param percentile Select the percentile of absolute values of
#' PCA loadings to select genes from. E.g. 0.999 would select the
#' top point 1 percent of genes with the largest loadings.
#' @param if_log whether the data is already log transformed
#' @return vector of genes
#' @examples
#' feature_select_PCA(
#'     cbmc_ref,
#'     if_log = FALSE
#' )
#' @export
feature_select_PCA <- function(
  mat = NULL,
  pcs = NULL,
  n_pcs = 10,
  percentile = 0.99,
  if_log = TRUE
) {
  # Get the loadings: use precalculated ones when supplied, otherwise run a
  # PCA on the (log-transformed) matrix with clusters as observations.
  if (is.null(pcs)) {
    if (is.null(mat)) {
      stop("either mat or pcs must be supplied", call. = FALSE)
    }
    if (if_log == FALSE) {
      mat <- log(mat + 1)
    }
    pca <- stats::prcomp(t(as.matrix(mat)))$rotation
  } else {
    pca <- pcs
  }

  # Do not ask for more components than the loading matrix provides.
  n_use <- min(n_pcs, ncol(pca))

  # For each PC keep the genes whose absolute loading reaches the requested
  # percentile. Gene names are indexed directly so that a single hit is not
  # lost to matrix-to-vector dropping.
  picked <- lapply(seq_len(n_use), function(i) {
    loadings <- abs(pca[, i])
    cutoff <- stats::quantile(loadings, probs = percentile)
    rownames(pca)[loadings >= cutoff]
  })

  # Genes are concatenated PC by PC; a gene selected by several PCs appears
  # once per PC, as before.
  genes <- unlist(picked, use.names = FALSE)
  return(genes)
}

#' convert gmt format of pathways to list of vectors
#'
#' @param path gmt file path
#' @param cutoff remove pathways with fewer genes than this cutoff
#' @param sep sep used in file to split path and genes
#' @return list of genes in each pathway
#' @examples
#' gmt_file <- system.file(
#'     "extdata",
#'     "c2.cp.reactome.v6.2.symbols.gmt.gz",
#'     package = "clustifyr"
#' )
#'
#' gene.lists <- gmt_to_list(path = gmt_file)
#' length(gene.lists)
#' @importFrom utils read.csv
#' @export
gmt_to_list <- function(
  path,
  cutoff = 0,
  sep = "\thttp://www.broadinstitute.org/gsea/msigdb/cards/.*?\t"
) {
  # Read every line into a single column named V1.
  raw <- read.csv(path, sep = ",", header = FALSE, col.names = "V1")

  # Split each line into the pathway name and its tab-delimited gene string.
  parts <- tidyr::separate(
    raw,
    !!dplyr::sym("V1"),
    sep = sep,
    into = c("path", "genes")
  )

  pathways <- stringr::str_split(parts$genes, "\t")
  names(pathways) <- stringr::str_replace(parts$path, "REACTOME_", "")

  # Optionally keep only pathways with at least `cutoff` genes.
  if (cutoff > 0) {
    pathways <- pathways[lengths(pathways) >= cutoff]
  }
  return(pathways)
}

#' plot GSEA pathway scores as heatmap,
#'  returns a list containing results and plot.
#'
#' @param mat expression matrix
#' @param pathway_list a list of vectors, each named for a specific pathway,
#' or dataframe
#' @param n_perm Number of permutation for fgsea function. Defaults to 1000.
#' @param scale convert expr_mat into zscores prior to running GSEA?,
#'  default = TRUE
#' @param topn number of top pathways to plot
#' @param returning to return "both" list and plot, or either one
#' @return list of matrix and plot, or just plot, matrix of GSEA NES values,
#' cell types as row names, pathways as column names
#' @examples
#' gl <- list(
#'     "n" = c("PPBP", "LYZ", "S100A9"),
#'     "a" = c("IGLL5", "GNLY", "FTL")
#' )
#'
#' pbmc_avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified"
#' )
#'
#' plot_pathway_gsea(
#'     pbmc_avg,
#'     gl,
#'     5
#' )
#' @export
plot_pathway_gsea <- function(
  mat,
  pathway_list,
  n_perm = 1000,
  scale = TRUE,
  topn = 5,
  returning = "both"
) {
  res <- calculate_pathway_gsea(mat, pathway_list, n_perm, scale = scale)

  # Pathways ranked in the top `topn` for any row become heatmap columns.
  top_calls <- cor_to_call_topn(res, topn = topn, threshold = -Inf)
  coltopn <- unique(top_calls$type)

  # Missing scores are shown as zero.
  res[is.na(res)] <- 0

  heat <- suppressWarnings(ComplexHeatmap::Heatmap(
    res[, coltopn],
    column_names_gp = grid::gpar(fontsize = 6)
  ))

  # "both" -> list(matrix, plot); "plot" -> plot only; anything else -> matrix.
  if (returning == "both") {
    return(list(res, heat))
  }
  if (returning == "plot") {
    return(heat)
  }
  res
}

#' downsample matrix by cluster or completely random
#'
#' @param mat expression matrix
#' @param n number per cluster or fraction to keep
#' @param keep_cluster_proportions whether to subsample within each cluster
#' (TRUE) or draw cells at random from the whole matrix (FALSE)
#' @param metadata data.frame or
#' vector containing cluster assignments per cell.
#' Order must match column order in supplied matrix. If a data.frame
#' provide the cluster_col parameters.
#' @param cluster_col column in metadata with cluster number
#' @return new smaller mat with less cell_id columns
#' @examples
#' set.seed(42)
#' mat <- downsample_matrix(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta$classified,
#'     n = 10,
#'     keep_cluster_proportions = TRUE
#' )
#' mat[1:3, 1:3]
#' @export
downsample_matrix <- function(
  mat,
  n = 1,
  keep_cluster_proportions = TRUE,
  metadata = NULL,
  cluster_col = "cluster"
) {
  cluster_info <- metadata

  if (keep_cluster_proportions == FALSE) {
    # Plain random draw across all cells; n < 1 is read as a fraction.
    cell_ids <- colnames(mat)
    if (n < 1) {
      n <- as.integer(ncol(mat) * n)
    }
    # Never request more cells than exist (sample() would error).
    cluster_ids_new <- sample(cell_ids, min(n, length(cell_ids)))
  } else {
    # Work out the per-cell cluster labels from whichever form was supplied.
    if (is.data.frame(cluster_info) && !is.null(cluster_col)) {
      groups <- cluster_info[[cluster_col]]
    } else if (is.factor(cluster_info)) {
      groups <- as.character(cluster_info)
    } else if (is.vector(cluster_info)) {
      groups <- cluster_info
    } else {
      stop(
        "metadata not formatted correctly,
         supply either a  vector or a dataframe",
        call. = FALSE
      )
    }
    cluster_ids <- split(colnames(mat), groups)

    # n < 1 is read as the fraction of each cluster to keep.
    if (n < 1) {
      n <- vapply(
        cluster_ids,
        function(ids) {
          as.integer(length(ids) * n)
        },
        FUN.VALUE = numeric(1)
      )
    }

    # Sample within each cluster, capping the request at the cluster size so
    # small clusters are kept whole instead of raising an error.
    cluster_ids_new <- mapply(
      function(ids, k) sample(ids, min(k, length(ids))),
      cluster_ids,
      n,
      SIMPLIFY = FALSE
    )
  }

  # drop = FALSE keeps a matrix even when a single cell is selected.
  return(mat[, unlist(cluster_ids_new), drop = FALSE])
}

#' marker selection from reference matrix
#'
#' @param mat reference matrix
#' @param cut an expression minimum cutoff
#' @param arrange whether to arrange (lower means better)
#' @param compto compare max expression to the value of next 1 or more
#' @return dataframe, with gene, cluster, ratio columns
#' @examples
#' ref_marker_select(
#'     cbmc_ref,
#'     cut = 2
#' )
#' @export
ref_marker_select <-
  function(
    mat,
    cut = 0.5,
    arrange = TRUE,
    compto = 1
  ) {
    # Drop rows without a name and rows whose values sum to zero.
    mat <- mat[!is.na(rownames(mat)), ]
    mat <- mat[Matrix::rowSums(mat) != 0, ]
    ref_cols <- colnames(mat)
    # Evaluate every gene (row) with marker_select(); `ref_cols` and `cut`
    # are passed positionally as its `cols` and `cut` arguments.
    res <-
      apply(mat, 1, marker_select, ref_cols, cut, compto = compto)
    # When apply() returns a list, remove the NULL entries (marker_select()
    # returns nothing for a gene whose top value is below `cut`).
    if (is.list(res)) {
      res <- res[!vapply(res, is.null, FUN.VALUE = logical(1))]
    }
    # Transpose so each gene becomes one row.
    resdf <- t(as.data.frame(res, stringsAsFactors = FALSE))
    # NOTE(review): the row names are removed here and rownames_to_column()
    # is called right after, so the "gene" column presumably ends up holding
    # row indices rather than gene names -- confirm this is intended.
    if (tibble::has_rownames(as.data.frame(resdf, stringsAsFactors = FALSE))) {
      resdf <- tibble::remove_rownames(as.data.frame(
        resdf,
        stringsAsFactors = FALSE
      ))
    }
    resdf <- tibble::rownames_to_column(
      resdf,
      "gene"
    )
    colnames(resdf) <- c("gene", "cluster", "ratio")
    # The ratio arrives as text (marker_select() combines it with the
    # cluster name in one vector), so convert it back to numeric.
    resdf <-
      dplyr::mutate(resdf, ratio = as.numeric(!!dplyr::sym("ratio")))
    # Optionally sort by ascending ratio within each cluster.
    if (arrange == TRUE) {
      resdf <- dplyr::group_by(resdf, cluster)
      resdf <-
        dplyr::arrange(resdf, !!dplyr::sym("ratio"), .by_group = TRUE)
      resdf <- dplyr::ungroup(resdf)
    }
    resdf
  }

#' decide for one gene whether it is a marker for a certain cell type
#' @param row1 a numeric vector of expression values (row)
#' @param cols a vector of cell types (column)
#' @param cut an expression minimum cutoff
#' @param compto compare max expression to the value of next 1 or more
#' @return vector of cluster name and ratio value
#' @examples
#' pbmc_avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     if_log = FALSE
#' )
#'
#' marker_select(
#'     row1 = pbmc_avg["PPBP", ],
#'     cols = names(pbmc_avg["PPBP", ])
#' )
#' @export
marker_select <- function(
  row1,
  cols,
  cut = 1,
  compto = 1
) {
  # Order the expression values from highest to lowest (`cols` is unused).
  ranked <- sort(row1, decreasing = TRUE)
  values <- unname(ranked)
  top <- values[1]

  # Genes whose best value is below the cutoff yield no result.
  if (!(top >= cut)) {
    return(invisible(NULL))
  }

  # Name of the top column, plus the ratio of the next value(s) to the top.
  c(names(ranked)[1], values[1 + compto] / top)
}

#' adapt clustify to tweak score for pos and neg markers
#' @param input single-cell expression matrix
#' @param metadata cell cluster assignments,
#' supplied as a vector or data.frame. If
#' data.frame is supplied then `cluster_col` needs to be set.
#'  Not required if running correlation per cell.
#' @param ref_mat reference expression matrix with positive and
#' negative markers(set expression at 0)
#' @param cluster_col column in metadata that contains cluster ids per cell.
#' Will default to first
#' column of metadata if not supplied.
#' Not required if running correlation per cell.
#' @param cutoff_n expression cutoff where genes ranked below n are
#'  considered non-expressing
#' @param cutoff_score positive score lower than this cutoff will be
#' considered as 0 to not influence scores
#' @return matrix of numeric values, clusters from input as row names,
#'  cell types from ref_mat as column names
#' @examples
#' pn_ref <- data.frame(
#'     "Myeloid" = c(1, 0.01, 0),
#'     row.names = c("CD74", "clustifyr0", "CD79A")
#' )
#'
#' pos_neg_select(
#'     input = pbmc_matrix_small,
#'     ref_mat = pn_ref,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     cutoff_score = 0.8
#' )
#' @export
pos_neg_select <- function(
  input,
  ref_mat,
  metadata,
  cluster_col = "cluster",
  cutoff_n = 0,
  cutoff_score = 0.5
) {
  # Per-cell correlation against the marker reference, with an extra
  # constant "clustifyr0" row appended to the query matrix.
  res <- suppressWarnings(
    clustify(
      rbind(input, "clustifyr0" = 0.01),
      ref_mat,
      metadata,
      cluster_col = cluster_col,
      per_cell = TRUE,
      verbose = TRUE,
      query_genes = rownames(ref_mat)
    )
  )
  res[is.na(res)] <- 0

  # Average the per-cell scores within each cluster.
  cluster_scores <- suppressWarnings(
    average_clusters(
      t(res),
      metadata,
      cluster_col = cluster_col,
      if_log = FALSE,
      output_log = FALSE
    )
  )
  res2 <- t(cluster_scores)

  if (is.null(cutoff_score)) {
    return(res2)
  }

  # Per reference type: when the best score exceeds 0.1, zero out positive
  # scores that fall below cutoff_score times that best score.
  apply(res2, 2, function(scores) {
    best <- max(scores)
    if (best > 0.1) {
      scores[scores > 0 & scores < cutoff_score * best] <- 0
    }
    scores
  })
}

#' generate negative markers from a list of exclusive positive markers
#' @param mat matrix or dataframe of markers
#' @return matrix of gene names
#' @examples
#' reverse_marker_matrix(cbmc_m)
#' @export
reverse_marker_matrix <- function(mat) {
  # Pool every marker from every column.
  all_markers <- as.vector(t(mat))

  # For each column keep the pooled markers that the column does not list.
  mat_rev <- apply(mat, 2, function(own) {
    all_markers[!is.element(all_markers, own)]
  })
  as.data.frame(mat_rev)
}

#' generate pos and negative marker expression matrix from a
#' list/dataframe of positive markers
#' @param mat matrix or dataframe of markers
#' @return matrix of gene expression
#' @examples
#' m1 <- pos_neg_marker(cbmc_m)
#' @export
pos_neg_marker <- function(mat) {
  # Normalise the input to a named list of gene vectors.
  if (is.matrix(mat)) {
    mat <- as.data.frame(mat, stringsAsFactors = FALSE)
  }
  if (is.data.frame(mat)) {
    mat <- as.list(mat)
  } else if (!is.list(mat)) {
    stop(
      "unsupported marker format,
             must be dataframe, matrix, or list",
      call. = FALSE
    )
  }

  # Long table with one row per (type, gene) pair, each marked as expressed.
  per_type <- lapply(mat, function(genes) {
    data.frame(gene = genes, stringsAsFactors = FALSE)
  })
  long <- dplyr::bind_rows(per_type, .id = "type")
  long <- dplyr::mutate(long, expression = 1)

  # Spread to a gene x type table; combinations not listed become 0.
  wide <- tidyr::spread(long, key = "type", value = "expression")
  if (tibble::has_rownames(wide)) {
    wide <- tibble::remove_rownames(wide)
  }
  wide <- tibble::column_to_rownames(wide, "gene")
  wide[is.na(wide)] <- 0
  wide
}
#' takes files with positive and negative markers, as described in garnett,
#' and returns list of markers
#' @param filename txt file to load
#' @return list of positive and negative gene markers
#' @examples
#' marker_file <- system.file(
#'     "extdata",
#'     "hsPBMC_markers.txt",
#'     package = "clustifyr"
#' )
#'
#' file_marker_parse(marker_file)
#' @export
file_marker_parse <- function(filename) {
  lines <- readLines(filename)
  count <- 0
  ident_names <- character(0)
  ident_pos <- list()
  ident_neg <- list()

  # Lines are classified by their first character:
  #   ">" starts a new cell type, the rest of the line is its name
  #   "e" is an "expressed: " line, genes start at character 12
  #   "n" is a "not expressed: " line, genes start at character 16
  for (line in lines) {
    tag <- substr(line, 1, 1)
    if (tag == ">") {
      count <- count + 1
      ident_names[count] <- substr(line, 2, nchar(line))
    } else if (tag == "e") {
      ident_pos[count] <-
        strsplit(substr(line, 12, nchar(line)), split = ", ")
    } else if (tag == "n") {
      ident_neg[count] <-
        strsplit(substr(line, 16, nchar(line)), split = ", ")
    }
  }

  # Pad a marker list to one slot per cell type before naming it, so that
  # trailing cell types without an entry get NULL instead of making the
  # names assignment fail. A list with no entries at all stays NULL.
  name_by_ident <- function(markers) {
    if (length(markers) == 0) {
      return(NULL)
    }
    length(markers) <- count
    names(markers) <- ident_names
    markers
  }

  list("pos" = name_by_ident(ident_pos), "neg" = name_by_ident(ident_neg))
}

#' Generate a unique column id for a dataframe
#' @param df dataframe with column names
#' @param id desired id if unique
#' @return character
get_unique_column <- function(df, id = NULL) {
  # Default candidate name is "x" when no id is requested.
  out_id <- if (is.null(id)) "x" else id
  existing <- colnames(df)

  # Append the candidate to the existing names and take its de-duplicated
  # form (the last element). Kept as a function so it is only evaluated
  # when the candidate actually clashes.
  deduplicated <- function() {
    pool <- make.unique(c(existing, out_id))
    pool[length(c(existing, out_id))]
  }

  ifelse(out_id %in% existing, deduplicated(), out_id)
}

#' Find rank bias
#' @param avg_mat average expression matrix
#' @param ref_mat reference expression matrix
#' @param query_genes original vector of genes used to clustify
#' @return list of matrix of rank diff values
#' @examples
#' avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     if_log = FALSE
#' )
#'
#' rankdiff <- find_rank_bias(
#'     avg,
#'     cbmc_ref,
#'     query_genes = pbmc_vargenes
#' )
#' @export
find_rank_bias <- function(
  avg_mat,
  ref_mat,
  query_genes = NULL
) {
  # Restrict to genes present in both matrices (and in query_genes, if given).
  shared <- intersect(
    rownames(avg_mat),
    rownames(ref_mat)
  )
  if (is.null(query_genes)) {
    query_genes <- shared
  } else {
    query_genes <- intersect(query_genes, shared)
  }

  # Rank genes within each column, highest expression first, and return a
  # gene x column matrix carrying the original dimnames.
  rank_columns <- function(m) {
    ranked <- t(matrixStats::colRanks(
      -m[query_genes, ],
      ties.method = "average"
    ))
    rownames(ranked) <- query_genes
    colnames(ranked) <- colnames(m)
    ranked
  }

  r2 <- rank_columns(avg_mat)
  r1 <- rank_columns(ref_mat)

  # One matrix per gene: rank in each cluster minus rank in each ref type.
  gene_ids <- rownames(r1)
  rdiff <- lapply(gene_ids, function(gene) {
    outer(r2[gene, ], r1[gene, ], FUN = "-")
  })
  stats::setNames(rdiff, gene_ids)
}

#' Query rank bias results
#' @param bias_list list of rank diff matrix between cluster and reference cell types
#' @param id_mat name of cluster from average cluster matrix
#' @param id_ref name of cell type in reference matrix
#' @return data.frame rank diff values
#' @examples
#' avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     if_log = FALSE
#' )
#'
#' rankdiff <- find_rank_bias(
#'     avg,
#'     cbmc_ref,
#'     query_genes = pbmc_vargenes
#' )
#'
#' qres <- query_rank_bias(
#'     rankdiff,
#'     "CD14+ Mono",
#'     "CD14+ Mono"
#' )
#' @export
query_rank_bias <- function(
  bias_list,
  id_mat,
  id_ref
) {
  # Pull the (cluster, reference type) cell out of every per-gene matrix.
  per_gene <- lapply(bias_list, function(diff_mat) {
    diff_mat[id_mat, id_ref]
  })

  # One row per gene, with the column labelled by the compared pair.
  resdf <- data.frame(unlist(per_gene))
  colnames(resdf) <- paste0(id_mat, "_vs_ ", id_ref)
  tibble::rownames_to_column(resdf, "gene")
}

#' Plot rank bias results
#' @param bias_df data.frame of rank diff matrix between cluster and reference cell types
#' @param organism for GO term analysis, organism name: human - 'hsapiens', mouse - 'mmusculus'
#' @return ggplot object of distribution and annotated GO terms
#' @examples
#' \dontrun{
#' avg <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     if_log = FALSE
#' )
#'
#' rankdiff <- find_rank_bias(
#'     avg,
#'     cbmc_ref,
#'     query_genes = pbmc_vargenes
#' )
#'
#' qres <- query_rank_bias(
#'     rankdiff,
#'     "CD14+ Mono",
#'     "CD14+ Mono"
#' )
#'
#' g <- plot_rank_bias(
#'     qres
#' )
#' }
#' @export
plot_rank_bias <- function(
  bias_df,
  organism = "hsapiens"
) {
  # Genes whose rank difference exceeds a third of the gene count, in either
  # direction, are treated as strongly biased.
  genes_all <- stats::setNames(bias_df[[2]], bias_df[[1]])
  genes_high <- bias_df[genes_all >= (length(genes_all) * 0.33), ]
  genes_low <- bias_df[genes_all <= -(length(genes_all) * 0.33), ]

  # Up to 10 GO:BP term names for a gene set, joined by newlines; "" when
  # the set is empty or the query returns nothing. Uses the caller's
  # `organism` rather than a fixed species.
  top_go_terms <- function(gene_df) {
    if (nrow(gene_df) == 0) {
      return("")
    }
    enrich <- suppressMessages(gprofiler2::gost(
      query = gene_df$gene,
      organism = organism,
      sources = "GO:BP",
      correction_method = "fdr",
      evcodes = TRUE
    ))
    if (is.null(enrich)) {
      return("")
    }
    paste0(
      dplyr::slice(
        dplyr::filter(enrich[[1]], intersection_size > 1),
        seq_len(10)
      )$term_name,
      collapse = "\n"
    )
  }

  go_high <- top_go_terms(genes_high)
  go_low <- top_go_terms(genes_low)

  # Bar chart of the rank differences, biased genes coloured, with the GO
  # terms written beside each tail.
  col <- colnames(bias_df)[2]
  g <- ggplot2::ggplot(bias_df, ggplot2::aes(!!dplyr::sym(col))) +
    ggplot2::geom_bar() +
    ggplot2::geom_bar(data = genes_high, color = "red", fill = "red") +
    ggplot2::geom_bar(data = genes_low, color = "blue", fill = "blue") +
    cowplot::theme_cowplot() +
    ggplot2::theme(
      axis.line.y = ggplot2::element_blank(),
      axis.title.y = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank()
    ) +
    ggplot2::annotate(
      "text",
      x = max(bias_df[[2]]) * 1.1,
      y = nrow(bias_df) / 70,
      label = go_high,
      color = "red",
      size = 2,
      hjust = 0
    ) +
    ggplot2::annotate(
      "text",
      x = min(bias_df[[2]]) * 1.1,
      y = nrow(bias_df) / 70,
      label = go_low,
      color = "blue",
      size = 2,
      hjust = 1
    ) +
    ggplot2::coord_cartesian(
      clip = "off",
      xlim = c(
        -max(abs(bias_df[[2]])) * 3,
        max(abs(bias_df[[2]])) * 3
      ),
      ylim = c(0, nrow(bias_df) / 50)
    )

  # Returned invisibly, as before, so the plot is not auto-printed.
  invisible(g)
}


#' Given a reference matrix and a list of genes, take the union of
#' all genes in vector and genes in reference matrix
#' and insert zero counts for all remaining genes.
#' @param gene_vector char vector with gene names
#' @param ref_matrix Reference matrix containing cell types vs.
#' gene expression values
#' @return Reference matrix with union of all genes
#' @examples
#' mat <- append_genes(
#'     gene_vector = human_genes_10x,
#'     ref_matrix = cbmc_ref
#' )
#' @export
append_genes <- function(gene_vector, ref_matrix) {
  # Genes requested but absent from the reference get all-zero rows.
  missing_rows <- setdiff(gene_vector, rownames(ref_matrix))

  zero_expression <- matrix(
    0,
    nrow = length(missing_rows),
    ncol = ncol(ref_matrix)
  )
  rownames(zero_expression) <- missing_rows
  colnames(zero_expression) <- colnames(ref_matrix)

  full_matrix <- rbind(ref_matrix, zero_expression)

  # Reorder (and restrict) rows to gene_vector. drop = FALSE keeps a matrix
  # even for a single-column reference or a single requested gene.
  full_matrix[gene_vector, , drop = FALSE]
}

#' Given a count matrix, determine if the matrix has been either
#' log-normalized, normalized, or contains raw counts
#' @param counts_matrix Count matrix containing scRNA-seq read data
#' @param max_log_value Static value to determine if a matrix is normalized
#' @return String either raw counts, log-normalized or normalized
#' @examples
#' check_raw_counts(pbmc_matrix_small)
#' @export
check_raw_counts <- function(counts_matrix, max_log_value = 50) {
  # Anything that is not already a base matrix (sparse matrices, data
  # frames, ...) is converted to one.
  if (!is.matrix(counts_matrix)) {
    counts_matrix <- as.matrix(counts_matrix)
  }

  # Integer storage is taken as raw counts.
  if (is.integer(counts_matrix)) {
    return("raw counts")
  }
  if (!is.double(counts_matrix)) {
    stop("unknown matrix format: ", typeof(counts_matrix))
  }

  # Doubles holding only whole numbers are raw counts as well.
  if (all(counts_matrix == floor(counts_matrix))) {
    return("raw counts")
  }
  # Values above the threshold are taken as normalized but not logged.
  if (max(counts_matrix) > max_log_value) {
    return("normalized")
  }
  if (min(counts_matrix) < 0) {
    stop("negative values detected, likely scaled data")
  }
  "log-normalized"
}

#' Function to combine records into single atlas
#'
#' @param matrix_fns character vector of paths to study matrices stored as .rds files.
#' If a named character vector, then the name will be added as a suffix to the cell type
#' name in the final matrix. If it is not named, then the filename will be used (without .rds)
#' @param genes_fn text file with a single column containing genes and the ordering desired
#' in the output matrix
#' @param matrix_objs Checks to see whether .rds files will be read or R objects in a
#' local environment. A list of environmental objects can be passed to
#' matrix_objs, and its names will be used, otherwise defaults to numbers
#' @param output_fn output filename for .rds file. If NULL the matrix will be returned instead of
#' saving
#' @return Combined matrix with all genes given
#' @examples
#' pbmc_ref_matrix <- average_clusters(
#'     mat = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     cluster_col = "classified",
#'     if_log = TRUE # whether the expression matrix is already log transformed
#' )
#' references_to_combine <- list(pbmc_ref_matrix, cbmc_ref)
#' atlas <- build_atlas(NULL, human_genes_10x, references_to_combine, NULL)
#' @export
build_atlas <- function(
  matrix_fns = NULL,
  genes_fn,
  matrix_objs = NULL,
  output_fn = NULL
) {
  # `genes_fn` is used directly as the vector of genes (and their order).
  genes_vector <- genes_fn

  # Exactly one source of reference matrices must be supplied.
  if (is.null(matrix_objs) && !is.null(matrix_fns)) {
    ref_mats <- lapply(matrix_fns, readRDS)
    if (is.null(names(matrix_fns))) {
      # Study name = file name without its ".rds" extension (dot escaped so
      # only a literal ".rds" suffix is stripped).
      names(ref_mats) <- stringr::str_remove(basename(matrix_fns), "\\.rds$")
    } else {
      names(ref_mats) <- names(matrix_fns)
    }
  } else if (is.null(matrix_fns) && !is.null(matrix_objs)) {
    ref_mats <- matrix_objs
    if (is.null(names(matrix_objs))) {
      names(ref_mats) <- seq_len(length(matrix_objs))
    }
  } else {
    stop(
      "supply exactly one of matrix_fns or matrix_objs",
      call. = FALSE
    )
  }

  new_mats <- lapply(seq_along(ref_mats), function(i) {
    # Standardize the genes (rows) of this study's matrix.
    mat <- append_genes(
      gene_vector = genes_vector,
      ref_matrix = as.matrix(ref_mats[[i]])
    )
    # Append the study name to every cell type name.
    colnames(mat) <- paste0(
      colnames(mat),
      " (",
      names(ref_mats)[i],
      ")"
    )
    mat
  })

  # Column-bind all studies into one atlas.
  atlas <- do.call(cbind, new_mats)

  # Either write the atlas to disk or hand it back to the caller.
  if (!is.null(output_fn)) {
    saveRDS(atlas, output_fn)
  } else {
    return(atlas)
  }
}

#' make combination ref matrix to assess intermixing
#'
#' @param ref_mat reference expression matrix
#' @param if_log whether input data is natural log transformed
#' @param sep separator for name combinations
#' @return expression matrix
#' @examples
#' ref <- make_comb_ref(
#'     cbmc_ref,
#'     sep = "_+_"
#' )
#' ref[1:3, 1:3]
#' @export
make_comb_ref <- function(ref_mat, if_log = TRUE, sep = "_and_") {
  # Work on the un-logged scale so that pair averages are arithmetic means.
  if (if_log == TRUE) {
    ref_mat <- expm1(ref_mat)
  }

  # Every unordered pair of reference cell types.
  pairs <- utils::combn(
    x = colnames(ref_mat),
    m = 2,
    simplify = FALSE
  )

  pair_mean <- function(pair) {
    Matrix::rowMeans(ref_mat[, unlist(pair)])
  }
  pair_label <- function(pair) {
    stringr::str_c(unlist(pair), collapse = sep)
  }

  # One averaged column per pair, named by joining the two type names.
  comb_mat <- vapply(
    pairs,
    FUN = pair_mean,
    FUN.VALUE = numeric(nrow(ref_mat))
  )
  colnames(comb_mat) <- vapply(
    pairs,
    FUN = pair_label,
    FUN.VALUE = character(1)
  )

  # Original columns first, then the combinations; re-log if needed.
  new_mat <- cbind(ref_mat, comb_mat)
  if (if_log == TRUE) {
    new_mat <- log1p(new_mat)
  }
  new_mat
}

#' Find rank bias
#'
#' Computes rank differences between query clusters and reference types with
#' `find_rank_bias()`, looks up the bias for every cluster / assigned type
#' pair in `res` with `query_rank_bias()`, and plots each result with
#' `plot_rank_bias()`.
#'
#' @param avg_mat average expression matrix
#' @param ref_mat reference expression matrix
#' @param query_genes original vector of genes used to clustify
#' @param res dataframe of idents, such as output of cor_to_call. The first
#'   column is read as the cluster id and the second column as the assigned
#'   type.
#' @param organism for GO term analysis, organism name: human - 'hsapiens', mouse - 'mmusculus'
#' @param plot_name name for saved pdf, if NULL then no file is written (default)
#' @param rds_name name for saved rds of rank_diff, if NULL then no file is written (default)
#' @param expand_unassigned test all ref clusters for unassigned results
#' @return list of plots as returned by `plot_rank_bias()`, one per assessed
#'   cluster / reference type pair. If `plot_name` is given the plots are also
#'   written to `<plot_name>.pdf`; if `rds_name` is given the rank bias
#'   results are saved to `<rds_name>.rds`.
#' @examples
#' \dontrun{
#' avg <- average_clusters(
#'     pbmc_matrix_small,
#'     pbmc_meta$seurat_clusters
#' )
#' res <- clustify(
#'     input = pbmc_matrix_small,
#'     metadata = pbmc_meta,
#'     ref_mat = cbmc_ref,
#'     query_genes = pbmc_vargenes,
#'     cluster_col = "seurat_clusters"
#' )
#' top_call <- cor_to_call(
#'     res,
#'     metadata = pbmc_meta,
#'     cluster_col = "seurat_clusters",
#'     collapse_to_cluster = FALSE,
#'     threshold = 0.8
#' )
#' res_rank <- assess_rank_bias(
#'     avg,
#'     cbmc_ref,
#'     res = top_call,
#'     organism = "hsapiens"
#' )
#' }
#' @export
assess_rank_bias <- function(
  avg_mat,
  ref_mat,
  query_genes = NULL,
  res,
  organism,
  plot_name = NULL,
  rds_name = NULL,
  expand_unassigned = FALSE
) {
  rankdiff <- find_rank_bias(
    avg_mat,
    ref_mat,
    query_genes = query_genes
  )

  # Results are only ever appended. Writing the placeholder by row index
  # (rbiases[i]) would overwrite earlier entries once an expanded unassigned
  # cluster has made the list longer than the number of rows processed.
  rbiases <- list()
  for (i in seq_len(nrow(res))) {
    id <- res[[1]][i]
    ct <- res[[2]][i]
    if (ct == "unassigned") {
      # NULL placeholder marks the unassigned cluster itself
      rbiases <- append(rbiases, list(NULL))
      if (expand_unassigned) {
        message("checking unassigned types against every ref type")
        rb <- lapply(colnames(ref_mat), function(x) {
          query_rank_bias(
            rankdiff,
            id,
            x
          )
        })
        rbiases <- append(rbiases, rb)
      }
    } else {
      rb <- query_rank_bias(
        rankdiff,
        id,
        ct
      )
      rbiases <- append(rbiases, list(rb))
    }
  }

  if (!is.null(rds_name)) {
    saveRDS(rbiases, paste0(rds_name, ".rds"))
  }

  message("Using gprofiler2 for GO analyses (internet connection required)")
  plts <- lapply(rbiases, function(x) {
    if (is.null(x)) {
      NULL
    } else {
      plot_rank_bias(x, organism = organism)
    }
  })
  # drop placeholders so only real plots are laid out and returned
  plts <- plts[!vapply(plts, is.null, logical(1))]

  if (!is.null(plot_name)) {
    p <- cowplot::plot_grid(plotlist = plts, ncol = 1)
    ggplot2::ggsave(
      paste0(plot_name, ".pdf"),
      p,
      width = 6,
      # scale the page by the plots actually drawn, not by the placeholders
      height = 4 * max(1L, length(plts)),
      limitsize = FALSE
    )
  }
  plts
}

#' Distance calculations for spatial coord
#'
#' Builds the pairwise distance matrix of `coord` with `stats::dist()` and
#' summarizes it per cluster through `average_clusters()` using
#' `method = "min"`.
#'
#' @param coord dataframe or matrix of spatial coordinates, cell barcode as rownames
#' @param metadata data.frame or vector containing cluster assignments per cell.
#' Order must match the rows of `coord` (which become the columns of the
#' distance matrix handed to `average_clusters()`). If a data.frame
#' provide the cluster_col parameters.
#' @param cluster_col column in metadata with cluster number
#' @param collapse_to_cluster instead of reporting min distance to cluster per cell, summarize to cluster level
#' @return min distance matrix
#' @examples
#' cbs <- paste0("cb_", 1:100)
#'
#' spatial_coords <- data.frame(
#'     row.names = cbs,
#'     X = runif(100),
#'     Y = runif(100)
#' )
#' group_ids <- sample(c("A", "B"), 100, replace = TRUE)
#' dist_res <- calc_distance(
#'     spatial_coords,
#'     group_ids
#' )
#' @export
calc_distance <- function(
  coord,
  metadata,
  cluster_col = "cluster",
  collapse_to_cluster = FALSE
) {
  # both summarization passes use identical settings; only the matrix differs
  min_per_cluster <- function(mat) {
    average_clusters(
      mat,
      metadata,
      cluster_col,
      if_log = FALSE,
      output_log = FALSE,
      method = "min"
    )
  }

  pairwise <- as.matrix(stats::dist(coord))
  per_cell <- min_per_cluster(pairwise)

  if (collapse_to_cluster) {
    # second pass over the transposed result collapses cells to clusters
    return(min_per_cluster(t(per_cell)))
  }
  per_cell
}
rnabioco/clustifyR documentation built on June 13, 2025, 1:42 p.m.