R/signatures.R

# == title
# Get signature rows
#
# == param
# -object A `ConsensusPartition-class` object.
# -k Number of subgroups.
# -col Colors for the main heatmap.
# -silhouette_cutoff Cutoff for silhouette scores. Samples with silhouette scores
#        below this cutoff are not used for finding signature rows. For selecting a
#        proper silhouette cutoff, please refer to https://www.stat.berkeley.edu/~s133/Cluster2a.html#tth_tAb1.
# -fdr_cutoff Cutoff for FDR of the difference test between subgroups.
# -top_signatures Number of top signatures with the most significant FDR. Since several rows can share the same FDR,
#          the final number of signatures may differ slightly from this value (e.g. with ``top_signatures = 2``, if the
#          second and third smallest FDRs are tied, three rows are kept).
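# -group_diff Cutoff for the maximal difference between subgroup means (the maximum minus the minimum of the subgroup means,
#          computed on the scaled values if ``scale_rows`` is ``TRUE``). Rows with a smaller difference are removed.
# -scale_rows Whether to apply row scaling. It affects the heatmap, the ``group_diff`` filtering and the k-means clustering of rows.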
# -.scale_mean Internally used.
# -.scale_sd Internally used.
# -row_km Number of groups for performing k-means clustering on rows. By default it is automatically selected.
# -diff_method Methods to get rows which are significantly different between subgroups, see 'Details' section.
# -anno A data frame of annotations for the original matrix columns. 
#       By default it uses the annotations specified in `consensus_partition` or `run_all_consensus_partition_methods`.
# -anno_col A list of colors (color is defined as a named vector) for the annotations. If ``anno`` is a data frame,
#       ``anno_col`` should be a named list where names correspond to the column names in ``anno``.
# -internal Used internally.
# -show_row_dend Whether to show the row dendrogram.
# -show_column_names Whether to show column names in the heatmap.
# -column_names_gp Graphics parameters for column names.
# -use_raster Internally used. It is passed to the ``use_raster`` argument of `ComplexHeatmap::Heatmap`.
# -plot Whether to make the plot.
# -verbose Whether to print messages.
# -seed Random seed.
# -left_annotation Annotation put on the left of the heatmap. It should be a `ComplexHeatmap::HeatmapAnnotation-class` object.
#              The number of items should be the same as the number of rows in the original matrix. Subsetting to the significant
#              rows is performed automatically on the annotation object.
# -right_annotation Annotation put on the right of the heatmap. Same format as ``left_annotation``.
# -simplify Only used internally.
# -prefix Only used internally.
# -enforce The analysis is cached by default, so a repeated call with the same input reuses the cached result
#     instead of rerunning the analysis. Set ``enforce`` to ``TRUE`` to force the function to rerun the analysis.
# -hash Used internally.
# -from_hc Whether the `ConsensusPartition-class` object is a node of a `HierarchicalPartition-class` object.
# -... Other arguments.
# 
# == details 
# The function applies a statistical test to every row for the difference between
# subgroups. The following methods are available:
#
# -ttest First, find the subgroup with the highest mean, compare it to each of the
#        other subgroups with a t-test and take the maximum p-value. Second, find the
#        subgroup with the lowest mean and do the same. The smaller of these two
#        p-values is the final p-value for the row.
# -samr/pamr Use SAM (from the samr package) or PAM (from the pamr package) to find rows that are
#        significantly different between subgroups.
# -Ftest Use an F-test to find rows that are significantly different between subgroups.
# -one_vs_others For each subgroup i, use a t-test to compare the samples in subgroup i
#        to all other samples, giving p_i. The p-value for the row is min(p_i).
# -uniquely_high_in_one_group A row is a signature uniquely up-regulated in subgroup A if it satisfies both criteria:
#          1. in a two-group t-test of A versus all other subgroups merged, the statistic is > 0 (higher in A) and the p-value is significant;
#          2. among the remaining subgroups (excluding A), the t-test between every pair of subgroups is not significant.
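#
# As a rough illustration of the ttest and one_vs_others rules for a single row ``x`` with subgroup
# labels ``class``, the sketch below uses `stats::t.test` with its defaults. It is a simplified sketch,
# not the package's internal ``ttest()`` or ``one_vs_others()`` code: the values are toy numbers, the
# helper ``max_p_vs`` is hypothetical, and no FDR adjustment across rows is applied.
#
#     x = c(5.1, 4.9, 5.3, 2.0, 2.2, 1.9, 3.1, 2.9, 3.0)    # toy values for one row
#     class = rep(c("a", "b", "c"), each = 3)
#     m = tapply(x, class, mean)
#     # maximal p-value when subgroup g0 is compared to each of the other subgroups
#     max_p_vs = function(g0) {
#         others = setdiff(names(m), g0)
#         max(sapply(others, function(g) t.test(x[class == g0], x[class == g])$p.value))
#     }
#     # ttest rule: the smaller of the values for the highest-mean and lowest-mean subgroups
#     p_ttest = min(max_p_vs(names(which.max(m))), max_p_vs(names(which.min(m))))
#     # one_vs_others rule: each subgroup against all remaining samples, take the minimum
#     p_one_vs_others = min(sapply(unique(class), function(g) t.test(x[class == g], x[class != g])$p.value))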
#
# ``diff_method`` can also be a user-defined function. It takes two arguments, the matrix used for the analysis
# and the predicted classes, and should return a vector of FDRs from the difference test, one value per row.
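#
# For example, a user-defined method could apply a Kruskal-Wallis test to every row and adjust the p-values
# with the Benjamini-Hochberg method. This is only a sketch: ``kw_fdr`` is a hypothetical name and the choice
# of test is an assumption, not one of the built-in methods. ``res`` is the object from the example section.
# If the returned vector carries a ``"p_value"`` attribute, it is reported in the ``p_value`` column.
#
#     kw_fdr = function(mat, class) {
#         # one p-value per row of `mat`; `class` labels the columns of `mat`
#         p = apply(mat, 1, function(x) kruskal.test(x, factor(class))$p.value)
#         fdr = p.adjust(p, method = "BH")
#         attr(fdr, "p_value") = p
#         fdr
#     }
#     tb = get_signatures(res, k = 3, diff_method = kw_fdr)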
#
# == return 
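# A data frame with the following columns:
#
# -``which_row``: row index corresponding to the original matrix.
# -``fdr``: the FDR.
# -``p_value``: the p-value, only if the difference test returns one (as the ``"p_value"`` attribute of the FDR vector).
# -``mean_*``: the mean value in each subgroup.
# -``group_diff``: the difference between the maximal and minimal subgroup means.
# -``scaled_mean_*``, ``group_diff_scaled``: the same statistics on the row-scaled matrix, only if ``scale_rows`` is ``TRUE``.
# -``km``: the k-means group of each row, only if rows are split into more than one group by k-means clustering (see ``row_km``).
#
# The data frame also has a ``sample_used`` attribute, a logical vector marking the columns used in the test.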
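#
# Because ``which_row`` indexes the original matrix (including rows not used for partitioning), the
# signature sub-matrix can be pulled out as sketched below, with ``res`` from the example section.
# This is an illustration, not an additional API.
#
#     tb = get_signatures(res, k = 3, plot = FALSE)
#     mat = get_matrix(res, include_all_rows = TRUE)
#     # rows: signatures; columns: samples that passed the silhouette filter
#     sig_mat = mat[tb$which_row, attr(tb, "sample_used"), drop = FALSE]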
# 
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
# == example
# data(golub_cola)
# res = golub_cola["ATC", "skmeans"]
# tb = get_signatures(res, k = 3)
# head(tb)
# get_signatures(res, k = 3, top_signatures = 100)
setMethod(f = "get_signatures",
	signature = "ConsensusPartition",
	definition = function(object, k,
	col = if(scale_rows) c("green", "white", "red") else c("blue", "white", "red"),
	silhouette_cutoff = 0.5, 
	fdr_cutoff = cola_opt$fdr_cutoff, 
	top_signatures = NULL,
	group_diff = cola_opt$group_diff,
	scale_rows = object@scale_rows, .scale_mean = NULL, .scale_sd = NULL,
	row_km = NULL,
	diff_method = c("Ftest", "ttest", "samr", "pamr", "one_vs_others", "uniquely_high_in_one_group"),
	anno = get_anno(object), 
	anno_col = get_anno_col(object),
	internal = FALSE,
	show_row_dend = FALSE,
	show_column_names = FALSE, 
	column_names_gp = gpar(fontsize = 8),
	use_raster = TRUE,
	plot = TRUE, verbose = TRUE, seed = 888,
	left_annotation = NULL, right_annotation = NULL,
	simplify = FALSE, prefix = "", enforce = FALSE, hash = NULL, from_hc = FALSE,
	...) {

	if(from_hc) {
		group_diff = object@.env$group_diff
		fdr_cutoff = object@.env$fdr_cutoff
		.scale_mean = object@.env$global_row_mean
		.scale_sd = object@.env$global_row_sd
		k = attr(object, "best_k")
	} else {
		if(missing(k)) stop_wrap("k needs to be provided.")
	}

	dotdot = list(...)
	p_cutoff = NULL
	if("p_cutoff" %in% names(dotdot)) {
		p_cutoff = dotdot$p_cutoff
	}
	from_down_sampling = FALSE
	if(inherits(object, "DownSamplingConsensusPartition")) {
		from_down_sampling = TRUE
	}
	
	class_df = get_classes(object, k)
	class_ids = class_df$class

	if(is.null(p_cutoff)) {
		data = get_matrix(object, include_all_rows = TRUE)
	} else {
		data = object@.env$data[, object@full_column_index, drop = FALSE]
	}

	if(!from_down_sampling) {
		l = class_df$silhouette >= silhouette_cutoff
	} else {
		l = class_df$p <= p_cutoff
	}
	data2 = data[, l, drop = FALSE]
	class = class_df$class[l]
	column_used_index = which(l)
	tb = table(class)
	l = as.character(class) %in% names(which(tb <= 1))
	data2 = data2[, !l, drop = FALSE]
	class = class[!l]
	column_used_index = column_used_index[!l]
	column_used_logical = rep(FALSE, ncol(data))
	column_used_logical[column_used_index] = TRUE
	has_ambiguous = sum(!column_used_logical)
	n_sample_used = length(class)

	if(!from_down_sampling) {
		if(verbose) qqcat("@{prefix}* @{n_sample_used}/@{nrow(class_df)} samples (in @{length(unique(class))} classes) remain after filtering by silhouette (>= @{silhouette_cutoff}).\n")
	} else {
		if(verbose) qqcat("@{prefix}* @{n_sample_used}/@{nrow(class_df)} samples (in @{length(unique(class))} classes) remain after filtering by p-value (<= @{p_cutoff}).\n")
	}
	
	tb = table(class)
	if(sum(tb > 1) <= 1) {
		if(plot) {
			grid.newpage()
			fontsize = convertUnit(unit(0.1, "npc"), "char", valueOnly = TRUE)*get.gpar("fontsize")$fontsize
			grid.text("not enough samples", gp = gpar(fontsize = fontsize))
		}
		if(verbose) qqcat("@{prefix}* Not enough samples.\n")
		return(invisible(data.frame(which_row = integer(0))))
	}
	if(length(unique(class)) <= 1) {
		if(plot) {
			grid.newpage()
			fontsize = convertUnit(unit(0.1, "npc"), "char", valueOnly = TRUE)*get.gpar("fontsize")$fontsize
			grid.text("not enough classes", gp = gpar(fontsize = fontsize))
		}
		if(verbose) qqcat("@{prefix}* Not enough classes.\n")
		return(invisible(data.frame(which_row = integer(0))))
	}

	do_row_clustering = TRUE
	if(inherits(diff_method, "function")) {
		if(verbose) qqcat("@{prefix}* calculate row difference between subgroups by user-defined function.\n")
		diff_method_fun = diff_method
		diff_method = digest(diff_method)
	} else {
		diff_method = match.arg(diff_method)
	}

	if(!is.null(top_signatures)) {
		if(diff_method %in% c("samr", "pamr")) {
			if(verbose) qqcat("@{prefix}* `top_signatures` is ignored when `diff_method` is set to samr/pamr.\n")
		}
	}

	# cache key: the samples actually used plus every setting that affects the result
	hash_input = list(used_samples = column_used_index,
	                  class = class,
	                  n_group = k,
	                  diff_method = diff_method,
	                  .scale_mean = .scale_mean,
	                  .scale_sd = .scale_sd,
	                  column_index = object@column_index,
	                  group_diff = group_diff,
	                  fdr_cutoff = fdr_cutoff,
	                  top_signatures = top_signatures,
	                  seed = seed)
	if(is.null(hash)) {
		hash = digest(hash_input, algo = "md5")
	} else if(is.null(object@.env[[paste0("signature_fdr_", hash)]])) {
		if(verbose) qqcat("@{prefix}* Warning: cannot find cache with hash: @{hash}, generate a new hash.\n")
		hash = digest(hash_input, algo = "md5")
	}
	
	nm = paste0("signature_fdr_", hash)
	if(verbose) qqcat("@{prefix}* cache hash: @{hash} (seed @{seed}).\n")

	if(enforce) object@.env[[nm]] = NULL

	find_signature = TRUE
	cached = object@.env[[nm]]
	if(!is.null(cached)) {
		# reuse the cached FDRs only if the method and the number of used samples match
		use_cache = cached$diff_method == diff_method && cached$n_sample_used == n_sample_used
		# samr and pamr results additionally depend on the FDR cutoff
		if(use_cache && diff_method %in% c("samr", "pamr")) {
			use_cache = abs(cached$fdr_cutoff - fdr_cutoff) < 1e-10
		}
		if(use_cache) {
			fdr = cached$fdr
			find_signature = FALSE
		}
	}

	if(verbose) qqcat("@{prefix}* calculating row difference between subgroups by @{diff_method}.\n")
	if(find_signature) {
		if(diff_method == "ttest") {
			fdr = ttest(data2, class)
		} else if(diff_method == "samr") {
			fdr = samr(data2, class, fdr.output = fdr_cutoff)
		} else if(diff_method == "Ftest") {
			fdr = Ftest(data2, class)
		} else if(diff_method == "pamr") {
			fdr = pamr(data2, class, fdr.output = fdr_cutoff)
		} else if(diff_method == "one_vs_others") {
			fdr = one_vs_others(data2, class)
		} else {
			fdr = diff_method_fun(data2, class)
		}
	} else {
		if(verbose) qqcat("@{prefix}  - row difference is extracted from cache.\n")
	}

	if(!is.null(top_signatures)) {
		fdr_cutoff = fdr[order(fdr)[min(length(fdr), top_signatures)]]
	}

	if(scale_rows && !is.null(object@.env[[nm]]$row_order_scaled)) {
		row_order = object@.env[[nm]]$row_order_scaled
		if(verbose) qqcat("@{prefix}  - row order for the scaled matrix is extracted from cache.\n")
		do_row_clustering = FALSE
	} else if(!scale_rows && !is.null(object@.env[[nm]]$row_order_unscaled)) {
		row_order = object@.env[[nm]]$row_order_unscaled
		if(verbose) qqcat("@{prefix}  - row order for the unscaled matrix is extracted from cache.\n")
		do_row_clustering = FALSE
	}

	object@.env[[nm]]$diff_method = diff_method
	object@.env[[nm]]$fdr_cutoff = fdr_cutoff
	object@.env[[nm]]$fdr = fdr
	object@.env[[nm]]$n_sample_used = n_sample_used
	object@.env[[nm]]$group_diff = group_diff
	object@.env[[nm]]$.scale_mean = .scale_mean
	object@.env[[nm]]$.scale_sd = .scale_sd

	# filter by fdr
	fdr[is.na(fdr)] = 1
	p_value = attr(fdr, "p_value")

	l_fdr = fdr <= fdr_cutoff
	mat = data[l_fdr, , drop = FALSE]
	fdr2 = fdr[l_fdr]
	if(!is.null(p_value)) {
		p_value = p_value[l_fdr]
	}
	if(!is.null(.scale_mean)) {
		.scale_mean = .scale_mean[l_fdr]
	}
	if(!is.null(.scale_sd)) {
		.scale_sd = .scale_sd[l_fdr]
	}
	
	if(!is.null(left_annotation)) left_annotation = left_annotation[l_fdr, ]
	if(!is.null(right_annotation)) right_annotation = right_annotation[l_fdr, ]

	returned_df = data.frame(which_row = which(l_fdr), fdr = fdr2)
	if(!is.null(p_value)) {
		returned_df$p_value = p_value
	}
	rownames(returned_df) = rownames(object)[l_fdr]
	attr(returned_df, "sample_used") = column_used_logical
	attr(returned_df, "hash") = hash

	# filter by group_diff
	mat1 = mat[, column_used_logical, drop = FALSE]
	if(nrow(mat) == 1) {
		group_mean = rbind(tapply(mat1, class, mean))
	} else {
		group_mean = do.call("cbind", tapply(seq_len(ncol(mat1)), class, function(ind) {
			rowMeans(mat1[, ind, drop = FALSE])
		}))
	}
	colnames(group_mean) = paste0("mean_", colnames(group_mean))
	returned_df = cbind(returned_df, group_mean)
	returned_df$group_diff = apply(group_mean, 1, function(x) max(x) - min(x))

	if(scale_rows) {
		mfoo = mat[, column_used_logical, drop = FALSE]
		if(!is.null(.scale_mean) && !is.null(.scale_sd)) {
			mat1_scaled = (mfoo - .scale_mean)/.scale_sd
		} else {
			mat1_scaled = (mfoo - rowMeans(mfoo))/rowSds(mfoo)
		}
		if(nrow(mat) == 1) {
			group_mean_scaled = rbind(tapply(mat1_scaled, class, mean))
		} else {
			group_mean_scaled = do.call("cbind", tapply(seq_len(ncol(mat1_scaled)), class, function(ind) {
				rowMeans(mat1_scaled[, ind, drop = FALSE])
			}))
		}
		colnames(group_mean_scaled) = paste0("scaled_mean_", colnames(group_mean_scaled))
		returned_df = cbind(returned_df, group_mean_scaled)
		returned_df$group_diff_scaled = apply(group_mean_scaled, 1, function(x) max(x) - min(x))
	}

	if(group_diff > 0) {
		if(scale_rows) {
			l_diff = returned_df$group_diff_scaled >= group_diff
		} else {
			l_diff = returned_df$group_diff >= group_diff
		}
		mat = mat[l_diff, , drop = FALSE]
		mat1 = mat1[l_diff, , drop = FALSE]
		returned_df = returned_df[l_diff, , drop = FALSE]
	}

	returned_obj = returned_df
	rownames(returned_obj) = NULL

	attr(returned_obj, "sample_used") = column_used_logical
	attr(returned_obj, "hash") = hash

	if(simplify && !plot) {
		return(returned_df)
	}

	## add k-means
	row_km_fit = NULL
	if(!internal) {
		if(nrow(mat1) > 10) {
			do_kmeans = TRUE
			if(scale_rows) {
				mat_for_km = t(scale(t(mat1)))
				row_km_fit = object@.env[[nm]]$row_km_fit_scaled
			} else {
				mat_for_km = mat1
				row_km_fit = object@.env[[nm]]$row_km_fit_unscaled
			}

			if(nrow(mat_for_km) > 5000) {
				set.seed(seed)
				mat_for_km2 = mat_for_km[sample(nrow(mat_for_km), 5000), , drop = FALSE]
			} else {
				mat_for_km2 = mat_for_km
			}

			if(!is.null(row_km_fit)) {
				if(is.null(row_km) || identical(as.integer(row_km), length(row_km_fit$size))) {
					returned_obj$km = apply(pdist(row_km_fit$centers, mat_for_km, as.integer(1)), 2, which.min)
					do_kmeans = FALSE
					if(verbose) qqcat("@{prefix}* use k-means partition that are already calculated in previous runs.\n")
				}
			}
			if(do_kmeans) {
				set.seed(seed)
				if(is.null(row_km)) {
					row_km = guess_best_km(mat_for_km2)
					if(length(unique(class)) == 1) row_km = 1
					if(length(unique(class)) == 2) row_km = min(row_km, 2)
				}
				if(row_km > 1) {
					row_km_fit = kmeans(mat_for_km2, centers = row_km)
					returned_obj$km = apply(pdist(row_km_fit$centers, mat_for_km, as.integer(1)), 2, which.min)
					if(scale_rows) {
						object@.env[[nm]]$row_km_fit_scaled = row_km_fit
					} else {
						object@.env[[nm]]$row_km_fit_unscaled = row_km_fit
					}
				}
				if(verbose) qqcat("@{prefix}* split rows into @{row_km} groups by k-means clustering.\n")
			}
		}
	}

	if(verbose) {
		if(is.null(top_signatures)) {
			qqcat("@{prefix}* @{nrow(mat)} signatures (@{sprintf('%.1f',nrow(mat)/nrow(object)*100)}%) under fdr < @{fdr_cutoff}, group_diff > @{group_diff}.\n")
		} else {
			qqcat("@{prefix}* @{nrow(mat)} signatures (@{sprintf('%.1f',nrow(mat)/nrow(object)*100)}%) with most significant fdr, group_diff > @{group_diff}.\n")
		}
	}

	if(nrow(mat) == 0) {
		if(plot) {
			grid.newpage()
			fontsize = convertUnit(unit(0.1, "npc"), "char", valueOnly = TRUE)*get.gpar("fontsize")$fontsize
			grid.text("no signatures", gp = gpar(fontsize = fontsize))
		}
		return(invisible(data.frame(which_row = integer(0))))
	}

	if(!plot) {
		return(invisible(returned_obj))
	}

	set.seed(seed)
	more_than_2k = FALSE
	if(!is.null(object@.env[[nm]]$row_index)) {
		if(verbose) qqcat("@{prefix}  - use the signatures that are already generated in previous runs.\n")
		row_index = object@.env[[nm]]$row_index
		mat1 = mat[row_index, column_used_logical, drop = FALSE]
		mat2 = mat[row_index, !column_used_logical, drop = FALSE]
		more_than_2k = TRUE
		if(!is.null(left_annotation)) left_annotation = left_annotation[row_index, ]
		if(!is.null(right_annotation)) right_annotation = right_annotation[row_index, ]
	} else if(nrow(mat) > 2000) {
		more_than_2k = TRUE
		row_index = sample(1:nrow(mat), 2000)
		object@.env[[nm]]$row_index = row_index
		# mat1 = mat[order(fdr2)[1:top_k_genes], column_used_logical, drop = FALSE]
		# mat2 = mat[order(fdr2)[1:top_k_genes], !column_used_logical, drop = FALSE]
		mat1 = mat[row_index, column_used_logical, drop = FALSE]
		mat2 = mat[row_index, !column_used_logical, drop = FALSE]
		# group2 = group2[order(fdr2)[1:top_k_genes]]
		if(verbose) qqcat("@{prefix}  - randomly sample 2000 signatures.\n")
		if(!is.null(left_annotation)) left_annotation = left_annotation[row_index, ]
		if(!is.null(right_annotation)) right_annotation = right_annotation[row_index, ]
	} else {
		row_index = seq_len(nrow(mat))
		mat1 = mat[, column_used_logical, drop = FALSE]
		mat2 = mat[, !column_used_logical, drop = FALSE]
		
	}
	base_mean = rowMeans(mat1)
	if(nrow(mat) == 1) {
		group_mean = matrix(tapply(mat1, class, mean), nrow = 1)
	} else {
		group_mean = do.call("cbind", tapply(seq_len(ncol(mat1)), class, function(ind) {
			rowMeans(mat1[, ind, drop = FALSE])
		}))
	}
	rel_diff = (rowMaxs(group_mean) - rowMins(group_mean))/base_mean/2

	if(is.null(anno)) {
		bottom_anno1 = NULL
	} else {
		if(is.atomic(anno)) {
			anno_nm = deparse(substitute(anno))
			anno = data.frame(anno)
			colnames(anno) = anno_nm
			if(!is.null(anno_col)) {
				if(is.atomic(anno_col)) {
					anno_col = list(anno_col)
					names(anno_col) = anno_nm
				}
			}
		} else if(ncol(anno) == 1) {
			if(!is.null(anno_col)) {
				if(is.atomic(anno_col)) {
					anno_col = list(anno_col)
					names(anno_col) = colnames(anno)
				}
			}
		}

		if(is.null(anno_col)) {
			bottom_anno1 = HeatmapAnnotation(df = anno[column_used_logical, , drop = FALSE],
				show_annotation_name = !has_ambiguous & !internal, annotation_name_side = "right")
		} else {
			bottom_anno1 = HeatmapAnnotation(df = anno[column_used_logical, , drop = FALSE], col = anno_col,
				show_annotation_name = !has_ambiguous & !internal, annotation_name_side = "right")
		}
	}

	if(scale_rows) {
		scaled_mean = base_mean
		scaled_sd = rowSds(mat1)
		scaled_mat1 = t(scale(t(mat1)))
		scaled_mat2 = mat2
		if(has_ambiguous) {
			for(i in seq_len(nrow(mat2))) {
				scaled_mat2[i, ] = (scaled_mat2[i, ] - scaled_mean[i])/scaled_sd[i]
			}
		}

		use_mat1 = scaled_mat1
		use_mat2 = scaled_mat2
		use_mat1[is.infinite(use_mat1)] = 0
		use_mat1[is.na(use_mat1)] = 0
		use_mat2[is.infinite(use_mat2)] = 0
		use_mat2[is.na(use_mat2)] = 0
		mat_range = quantile(abs(scaled_mat1), 0.95, na.rm = TRUE)
		col_fun = colorRamp2(c(-mat_range, 0, mat_range), col)
		heatmap_name = "z-score"
	} else {
		use_mat1 = mat1
		use_mat2 = mat2
		mat_range = quantile(mat1, c(0.05, 0.95))
		col_fun = colorRamp2(c(mat_range[1], mean(mat_range), mat_range[2]), col)
		heatmap_name = "Value"
	}

	if(has_ambiguous) {
		class2 = class_df$class[!column_used_logical]

		if(is.null(anno)) {
			bottom_anno2 = NULL
		} else {
			anno_col = lapply(bottom_anno1@anno_list, function(anno) {
				if(is.null(anno@color_mapping)) {
					return(NULL)
				} else {
					if(anno@color_mapping@type == "discrete") {
						if(anno@name %in% names(object@anno_col)) {
							object@anno_col[[anno@name]]
						} else {
							anno@color_mapping@colors
						}
					} else {
						anno@color_mapping@col_fun
					}
				}
			})
			names(anno_col) = names(bottom_anno1@anno_list)
			anno_col = anno_col[!sapply(anno_col, is.null)]

			if(!is.null(object@anno_col)) {
				nmd = setdiff(names(object@anno_col), names(anno_col))
				if(length(nmd)) {
					anno_col[nmd] = object@anno_col[nmd]
				}
			}

			bottom_anno2 = HeatmapAnnotation(df = anno[!column_used_logical, , drop = FALSE], col = anno_col,
				show_annotation_name = !internal, annotation_name_side = "right")	
		}
	}
	if(is.null(p_cutoff)) {
		silhouette_range = range(class_df$silhouette)
		silhouette_range[2] = 1
	} else {
		p_range = c(0, 1)
	}

	if(verbose) qqcat("@{prefix}* making heatmaps for signatures.\n")

	row_split = NULL
	if(!internal) {
		if(scale_rows) {
			row_km_fit = object@.env[[nm]]$row_km_fit_scaled
		} else {
			row_km_fit = object@.env[[nm]]$row_km_fit_unscaled
		}
		if(!is.null(row_km_fit)) {
			row_split = factor(returned_obj$km[row_index], levels = sort(unique(returned_obj$km[row_index])))
		}
	}

	# group2 = factor(group2, levels = sort(unique(group2)))
	# ht_list = Heatmap(group2, name = "Group", show_row_names = FALSE, width = unit(5, "mm"), col = cola_opt$color_set_2)
	ht_list = NULL

	membership_mat = get_membership(object, k)
	prop_col_fun = colorRamp2(c(0, 1), c("white", "red"))
	
	if(from_down_sampling) {
		class_df_logp = class_df$p
		class_df_logp[class_df_logp == 0] = 0.001
		class_df_logp = -log10(class_df_logp)
		p_range = range(class_df_logp)
		p_range[1] = 0
	}

	if(internal) {
		ha1 = HeatmapAnnotation(Prob = membership_mat[column_used_logical, ],
				Class = class_df$class[column_used_logical],
				col = list(Class = cola_opt$color_set_2, Prob = prop_col_fun),
				show_annotation_name = !has_ambiguous & !internal,
				annotation_name_side = "right",
				show_legend = TRUE)
	} else {
		if(simplify) {
			if(!from_down_sampling) {
				ha1 = HeatmapAnnotation(
					Class = class_df$class[column_used_logical],
					silhouette = anno_barplot(class_df$silhouette[column_used_logical], ylim = silhouette_range,
						gp = gpar(fill = ifelse(class_df$silhouette[column_used_logical] >= silhouette_cutoff, "black", "#EEEEEE"),
							      col = NA),
						bar_width = 1, baseline = 0, axis = !has_ambiguous, axis_param = list(side= "right"),
						height = unit(15, "mm")),
					col = list(Class = cola_opt$color_set_2),
					show_annotation_name = !has_ambiguous & !internal,
					annotation_name_side = "right",
					show_legend = TRUE)
			} else {
				ha1 = HeatmapAnnotation(
					Class = class_df$class[column_used_logical],
					p_prediction = anno_barplot(class_df_logp[column_used_logical], ylim = p_range,
						gp = gpar(fill = ifelse(class_df_logp[column_used_logical] >= -log10(p_cutoff), "black", "#EEEEEE"),
							      col = NA),
						bar_width = 1, baseline = 0, axis = !has_ambiguous, axis_param = list(side= "right"),
						height = unit(15, "mm")),
					col = list(Class = cola_opt$color_set_2),
					show_annotation_name = !has_ambiguous & !internal,
					annotation_name_side = "right",
					show_legend = TRUE)
			}
		} else {
			if(!from_down_sampling) {
				ha1 = HeatmapAnnotation(Prob = membership_mat[column_used_logical, ],
					Class = class_df$class[column_used_logical],
					silhouette = anno_barplot(class_df$silhouette[column_used_logical], ylim = silhouette_range,
						gp = gpar(fill = ifelse(class_df$silhouette[column_used_logical] >= silhouette_cutoff, "black", "#EEEEEE"),
							      col = NA),
						bar_width = 1, baseline = 0, axis = !has_ambiguous, axis_param = list(side= "right"),
						height = unit(15, "mm")),
					col = list(Class = cola_opt$color_set_2, Prob = prop_col_fun),
					show_annotation_name = !has_ambiguous & !internal & c(TRUE, TRUE, FALSE),
					annotation_name_side = "right",
					show_legend = TRUE)
			} else {
				ha1 = HeatmapAnnotation(
					Class = class_df$class[column_used_logical],
					p_prediction = anno_barplot(class_df_logp[column_used_logical], ylim = p_range,
						gp = gpar(fill = ifelse(class_df_logp[column_used_logical] >= -log10(p_cutoff), "black", "#EEEEEE"),
							      col = NA),
						bar_width = 1, baseline = 0, axis = !has_ambiguous, axis_param = list(side= "right"),
						height = unit(15, "mm")),
					col = list(Class = cola_opt$color_set_2, Prob = prop_col_fun),
					show_annotation_name = !has_ambiguous & !internal & c(TRUE, FALSE),
					annotation_name_side = "right",
					show_legend = TRUE)
			}
		}
	}
	ht_list = ht_list + Heatmap(use_mat1, name = heatmap_name, col = col_fun,
		top_annotation = ha1, row_split = row_split,
		cluster_columns = TRUE, cluster_column_slices = FALSE, cluster_row_slices = FALSE,
		column_split = factor(class_df$class[column_used_logical], levels = sort(unique(class_df$class[column_used_logical]))), 
		show_column_dend = FALSE,
		show_row_names = FALSE, show_row_dend = show_row_dend, column_title = {if(internal) NULL else qq("@{ncol(use_mat1)} confident samples")},
		use_raster = use_raster, raster_by_magick = requireNamespace("magick", quietly = TRUE),
		bottom_annotation = bottom_anno1, show_column_names = show_column_names, column_names_gp = column_names_gp,
		left_annotation = left_annotation, right_annotation = {if(has_ambiguous) NULL else right_annotation})
 	
	all_value_positive = !any(data < 0, na.rm = TRUE)
	if(scale_rows && all_value_positive && !simplify) {
		ht_list = ht_list + Heatmap(base_mean, show_row_names = FALSE, name = "base_mean", width = unit(5, "mm"), show_column_names = !internal, column_names_gp = column_names_gp) +
			Heatmap(rel_diff, col = colorRamp2(c(0, 0.5, 1), c("blue", "white", "red")), 
				show_row_names = FALSE, show_column_names = !internal, column_names_gp = column_names_gp, name = "rel_diff", width = unit(5, "mm"))
	}

	if(has_ambiguous) {
		if(internal) {
			ha2 = HeatmapAnnotation(Prob = membership_mat[!column_used_logical, ,drop = FALSE],
				Class = class_df$class[!column_used_logical],
				col = list(Class = cola_opt$color_set_2, Prob = prop_col_fun),
				show_annotation_name = !internal,
				annotation_name_side = "right",
				show_legend = FALSE)
		} else {
			if(simplify) {
				if(!from_down_sampling) {
					ha2 = HeatmapAnnotation(
						Class = class_df$class[!column_used_logical],
						silhouette2 = anno_barplot(class_df$silhouette[!column_used_logical], ylim = silhouette_range,
							gp = gpar(fill = ifelse(class_df$silhouette[!column_used_logical] >= silhouette_cutoff, "grey", "grey"),
							      col = ifelse(class_df$silhouette[!column_used_logical] >= silhouette_cutoff, "black", NA)),
							bar_width = 1, baseline = 0, axis = TRUE, axis_param = list(side = "right"),
							height = unit(15, "mm")), 
						col = list(Class = cola_opt$color_set_2),
						show_annotation_name = c(TRUE, FALSE) & !internal,
						annotation_name_side = "right",
						show_legend = FALSE)
				} else {
					ha2 = HeatmapAnnotation(
						Class = class_df$class[!column_used_logical],
						p_prediction2 = anno_barplot(class_df_logp[!column_used_logical], ylim = p_range,
							gp = gpar(fill = ifelse(class_df_logp[!column_used_logical] >= -log10(p_cutoff), "grey", "grey"),
							      col = ifelse(class_df_logp[!column_used_logical] >= -log10(p_cutoff), "black", NA)),
							bar_width = 1, baseline = 0, axis = TRUE, axis_param = list(side = "right"),
							height = unit(15, "mm")), 
						col = list(Class = cola_opt$color_set_2),
						show_annotation_name = c(TRUE, FALSE) & !internal,
						annotation_name_side = "right",
						show_legend = FALSE)
				}
			} else {
				if(!from_down_sampling) {
					ha2 = HeatmapAnnotation(Prob = membership_mat[!column_used_logical, ,drop = FALSE],
						Class = class_df$class[!column_used_logical],
						silhouette2 = anno_barplot(class_df$silhouette[!column_used_logical], ylim = silhouette_range,
							gp = gpar(fill = ifelse(class_df$silhouette[!column_used_logical] >= silhouette_cutoff, "grey", "grey"),
							      col = ifelse(class_df$silhouette[!column_used_logical] >= silhouette_cutoff, "black", NA)),
							bar_width = 1, baseline = 0, axis = TRUE, axis_param = list(side = "right"),
							height = unit(15, "mm")), 
						col = list(Class = cola_opt$color_set_2, Prob = prop_col_fun),
						show_annotation_name = c(TRUE, TRUE, FALSE) & !internal,
						annotation_name_side = "right",
						show_legend = FALSE)
				} else {
					ha2 = HeatmapAnnotation(
						Class = class_df$class[!column_used_logical],
						p_prediction2 = anno_barplot(class_df_logp[!column_used_logical], ylim = p_range,
							gp = gpar(fill = ifelse(class_df_logp[!column_used_logical] >= -log10(p_cutoff), "grey", "grey"),
							      col = ifelse(class_df_logp[!column_used_logical] >= -log10(p_cutoff), "black", NA)),
							bar_width = 1, baseline = 0, axis = TRUE, axis_param = list(side = "right"),
							height = unit(15, "mm")), 
						col = list(Class = cola_opt$color_set_2, Prob = prop_col_fun),
						show_annotation_name = c(TRUE, FALSE) & !internal,
						annotation_name_side = "right",
						show_legend = FALSE)
				}
			}
			
		}
		ht_list = ht_list + Heatmap(use_mat2, name = paste0(heatmap_name, 2), col = col_fun,
			top_annotation = ha2,
			cluster_columns = TRUE, show_column_dend = FALSE,
			show_row_names = FALSE, show_row_dend = FALSE, show_heatmap_legend = FALSE,
			use_raster = use_raster, raster_by_magick = requireNamespace("magick", quietly = TRUE),
			bottom_annotation = bottom_anno2, show_column_names = show_column_names, column_names_gp = column_names_gp,
			right_annotation = right_annotation)
	}

	if(has_ambiguous) {
		lgd = Legend(title = "Status (barplots)", labels = c("confident", "ambiguous"), legend_gp = gpar(fill = c("black", "grey")))
		heatmap_legend_list = list(lgd)
	} else {
		heatmap_legend_list = NULL
	}

	# the column title is the same whether rows are re-clustered or taken from the cache
	if(is.null(top_signatures)) {
		column_title = ifelse(internal, "", qq("@{k} subgroups, @{nrow(mat)} signatures (@{sprintf('%.1f',nrow(mat)/nrow(object)*100)}%) with fdr < @{fdr_cutoff}@{ifelse(group_diff > 0, paste0(', group_diff > ', group_diff), '')}"))
	} else {
		column_title = ifelse(internal, "", qq("@{k} subgroups, @{nrow(mat)} signatures (@{sprintf('%.1f',nrow(mat)/nrow(object)*100)}%) with most significant fdr@{ifelse(group_diff > 0, paste0(', group_diff > ', group_diff), '')}"))
	}

	if(do_row_clustering) {
		ht_list = draw(ht_list, main_heatmap = heatmap_name, column_title = column_title,
			show_heatmap_legend = !internal, show_annotation_legend = !internal,
			heatmap_legend_list = heatmap_legend_list
		)
		
		row_order = row_order(ht_list)
		if(!is.list(row_order)) row_order = list(row_order)
		if(scale_rows) {
			object@.env[[nm]]$row_order_scaled = do.call("c", row_order)
		} else {
			object@.env[[nm]]$row_order_unscaled = do.call("c", row_order)
		}
		
	} else {
		if(verbose) qqcat("@{prefix}  - use row order from cache.\n")
		draw(ht_list, main_heatmap = heatmap_name, column_title = column_title,
			show_heatmap_legend = !internal, show_annotation_legend = !internal,
			cluster_rows = FALSE, row_order = row_order, heatmap_legend_list = heatmap_legend_list,
			row_title = {if(length(unique(row_split)) <= 1) NULL else qq("k-means with @{length(unique(row_split))} groups")}
		)
	}
	# Mark the cutoff with a dashed line on the silhouette (or prediction p-value)
	# annotation. For choosing a proper silhouette cutoff, see
	# https://www.stat.berkeley.edu/~s133/Cluster2a.html
	if(!internal) {
		if(is.null(p_cutoff)) {
			if(!has_ambiguous) {
				decorate_annotation("silhouette", slice = length(unique(class_df$class[column_used_logical])), {
					grid.rect(gp = gpar(fill = "transparent"))
					grid.lines(c(0, 1), unit(c(silhouette_cutoff, silhouette_cutoff), "native"), gp = gpar(lty = 2, col = "#CCCCCC"))
					grid.text("Silhouette\nscore", x = unit(1, "npc") + unit(5, "mm"), just = "left", gp = gpar(fontsize = 10))
				})
			} else {
				decorate_annotation("silhouette2", {
					grid.rect(gp = gpar(fill = "transparent"))
					grid.lines(c(0, 1), unit(c(silhouette_cutoff, silhouette_cutoff), "native"), gp = gpar(lty = 2, col = "#CCCCCC"))
					grid.text("Silhouette\nscore", x = unit(1, "npc") + unit(5, "mm"), just = "left", gp = gpar(fontsize = 10))
				})
			}
		} else {
			if(!has_ambiguous) {
				decorate_annotation("p_prediction", slice = length(unique(class_df$class[column_used_logical])), {
					grid.rect(gp = gpar(fill = "transparent"))
					grid.lines(c(0, 1), unit(-log10(c(p_cutoff, p_cutoff)), "native"), gp = gpar(lty = 2, col = "#CCCCCC"))
					grid.text("-log10(prediction\n    p-value)", x = unit(1, "npc") + unit(5, "mm"), just = "left", gp = gpar(fontsize = 10))
				})
			} else {
				decorate_annotation("p_prediction2", {
					grid.rect(gp = gpar(fill = "transparent"))
					grid.lines(c(0, 1), unit(-log10(c(p_cutoff, p_cutoff)), "native"), gp = gpar(lty = 2, col = "#CCCCCC"))
					grid.text("-log10(prediction\n    p-value)", x = unit(1, "npc") + unit(5, "mm"), just = "left", gp = gpar(fontsize = 10))
				})
			}
		}
	}

	return(invisible(returned_obj))
})

# https://stackoverflow.com/questions/2018178/finding-the-best-trade-off-point-on-a-curve
elbow_finder <- function(x_values, y_values) {
  # Endpoints of the reference line: the point with the largest x and the
  # point with the largest y
  max_x_x <- max(x_values)
  max_x_y <- y_values[which.max(x_values)]
  max_y_y <- max(y_values)
  max_y_x <- x_values[which.max(y_values)]
  max_df <- data.frame(x = c(max_y_x, max_x_x), y = c(max_y_y, max_x_y))

  # Straight line through the two endpoints
  fit <- lm(y ~ x, data = max_df)
  intercept <- coef(fit)[1]
  slope <- coef(fit)[2]

  # Perpendicular distance from every point to the line y = slope*x + intercept
  distances <- abs(slope*x_values - y_values + intercept) / sqrt(slope^2 + 1)

  # The elbow is the point farthest from the line
  i_max <- which.max(distances)
  return(c(x_values[i_max], y_values[i_max]))
}
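
# Usage sketch for elbow_finder() (commented out, like test_row_diff_fun()
# further below). The curve y = 100/x is made up to resemble a decreasing WSS
# curve; it is not real data. The largest-y point is (1, 100) and the largest-x
# point is (10, 10), so the reference line is y = -10*x + 110 and each point's
# distance to it is proportional to 110 - 10*x - 100/x, which among x = 1..10
# is largest at x = 3.
# x = 1:10
# y = 100/x
# cola:::elbow_finder(x, y)  # c(3, 100/3)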

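# A simple version of the Kneedle idea, https://raghavan.usc.edu//papers/kneedle-simplex11.pdf
# The chord through the first and last points is y = a*x + b. For a convex,
# decreasing curve (such as a WSS curve) the points lie below the chord, and the
# knee is the x where the curve is farthest below it. `x` is assumed to be sorted.
knee_finder = function(x, y) {
	n = length(x)
	a = (y[n] - y[1])/(x[n] - x[1])  # slope of the chord
	b = y[1] - a*x[1]                # intercept of the chord
	d = a*x + b - y                  # vertical distance below the chord
	x[which.max(d)]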
}
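
# Usage sketch for knee_finder() (commented out). On the same made-up curve
# y = 100/x the chord through the first and last points is again
# y = -10*x + 110, so d = 110 - 10*x - 100/x and the knee is at x = 3, the same
# as elbow_finder(). Whenever the first/last points are also the largest-y and
# largest-x points, elbow_finder()'s perpendicular distance is just
# d / sqrt(a^2 + 1), so both functions pick the same x.
# x = 1:10
# y = 100/x
# cola:::knee_finder(x, y)  # 3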

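# For each row, find the subgroup with the highest (or lowest) mean, t-test it
# against every other subgroup separately and return the largest of those
# p-values, i.e. a row is only significant if that subgroup differs from all
# the others.
compare_to_subgroup = function(mat, class, which = "highest") {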

	check_pkg("genefilter", bioc = TRUE)
	
	od = order(class)
	class = class[od]
	mat = mat[, od, drop = FALSE]
	class = as.numeric(factor(class))

	group = apply(mat, 1, function(x) {
		group_mean = tapply(x, class, mean)
		if(which == "highest") {
			which.max(group_mean)
		} else {
			which.min(group_mean)
		}
	})

	# class and subgroup_index are all numeric
	compare_to_one_subgroup = function(mat, class, subgroup_index) {
		
		oc = setdiff(class, subgroup_index)
		pmat = matrix(NA, nrow = nrow(mat), ncol = length(oc))
		for(i in seq_along(oc)) {
			l = class == subgroup_index | class == oc[i]
			m2 = mat[, l, drop = FALSE]
			fa = factor(class[l])
			pmat[, i] = genefilter::rowttests(m2, fa)[, "p.value"]
		}
		rowMaxs(pmat, na.rm = TRUE)
	}

	p = tapply(seq_len(nrow(mat)), group, function(ind) {
		compare_to_one_subgroup(mat[ind, , drop = FALSE], class, group[ind][1])
	})
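	# p is indexed by the group values that actually occur in `group`, which are
	# not necessarily 1, 2, ..., so match by name instead of by position
	p2 = numeric(length(group))
	for(g in names(p)) {
		p2[group == as.numeric(g)] = p[[g]]
	}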

	return(p2)
}

ttest = function(mat, class) {
	p1 = compare_to_subgroup(mat, class, "highest")
	p2 = compare_to_subgroup(mat, class, "lowest")
	p = pmin(p1, p2, na.rm = TRUE)
	fdr = p.adjust(p, method = "BH")
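	fdr[is.na(fdr)] = Inf
	# p can be -Inf, e.g. from rowMaxs(..., na.rm = TRUE) on a row whose t-tests
	# are all NA; treat such rows, like NA, as not significant
	fdr[is.infinite(fdr)] = Inf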
	attr(fdr, "p_value") = p
	fdr
}
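
# Usage sketch for ttest() (commented out; needs genefilter). The data are
# simulated: 100 rows of noise in three subgroups of 5 samples, with rows 1-10
# shifted up in subgroup "b". For each row, compare_to_subgroup() picks the
# subgroup with the highest (or lowest) mean, t-tests it against each other
# subgroup and keeps the largest p-value; ttest() then takes the smaller of the
# "highest" and "lowest" p-values and applies BH correction.
# set.seed(1)
# class = rep(c("a", "b", "c"), each = 5)
# mat = matrix(rnorm(100 * 15), nrow = 100)
# mat[1:10, class == "b"] = mat[1:10, class == "b"] + 3
# fdr = cola:::ttest(mat, class)
# which(fdr < 0.05)
# head(attr(fdr, "p_value"))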

one_vs_others = function(mat, class) {
	check_pkg("genefilter", bioc = TRUE)
	
	le = unique(class)
	dfl = list()
	for(x in le) {
		fa = as.vector(class)
		fa[class == x] = "a"
		fa[class != x] = "b"
		fa = factor(fa, levels = c("a", "b"))
		# index by name so that numeric class labels are not used as list positions
		dfl[[as.character(x)]] = genefilter::rowttests(mat, fa)[, "p.value"]
	}
	df = do.call("cbind", dfl)
	p = rowMins(df)
	fdr = p.adjust(p, "BH")
	attr(fdr, "p_value") = p
	fdr
}

samr = function(mat, class, ...) {
	check_pkg("samr", bioc = FALSE)
	on.exit(if(sink.number()) sink(NULL))
	class = as.numeric(factor(class))
	n_class = length(unique(class))
	
	tempf = tempfile()
	sink(tempf)
	if(n_class == 2) {
		samfit = samr::SAM(mat, class, resp.type = "Two class unpaired", nperms = 1000, ...)
	} else {
		samfit = samr::SAM(mat, class, resp.type = "Multiclass", nperms = 1000, ...)
	}
	sink(NULL)
	file.remove(tempf)

	sig_index = NULL
	if(!is.null(samfit$siggenes.table$genes.up)) {
		id = samfit$siggenes.table$genes.up
		if(is.null(dim(id))) id = matrix(id, nrow = 1)
		sig_index = c(sig_index, id[, 2])
	}
	if(!is.null(samfit$siggenes.table$genes.lo)) {
		id = samfit$siggenes.table$genes.lo
		if(is.null(dim(id))) id = matrix(id, nrow = 1)
		sig_index = c(sig_index, id[, 2])
	}
	sig_index = as.numeric(sig_index)
	fdr = rep(1, nrow(mat))
	fdr[sig_index] = 0

	return(fdr)
}

pamr = function(mat, class, fdr.cutoff = 0.1, ...) {
	check_pkg("pamr", bioc = FALSE)

	on.exit(if(sink.number()) sink(NULL))

	class = as.numeric(factor(class))
	
	tempf = tempfile()
	sink(tempf)
	# use row indices as gene IDs so that the result does not depend on
	# (possibly missing or duplicated) row names
	mydata <- list(x=mat, y=class, geneid = as.character(seq_len(nrow(mat))))
	mydata.fit <- pamr::pamr.train(mydata)
	mydata.cv <- pamr::pamr.cv(mydata.fit, mydata)
	mydata.fdr <- pamr::pamr.fdr(mydata.fit, mydata)
	threshold = min(mydata.fdr$results[mydata.fdr$results[,"Median FDR"] < fdr.cutoff, "Threshold"])
	mydata.genelist <- pamr::pamr.listgenes(mydata.fit, mydata, threshold = threshold, fitcv=mydata.cv)
	sink(NULL)
	file.remove(tempf)

	fdr = rep(1, nrow(mat))
	fdr[as.integer(mydata.genelist[, "id"])] = 0
	
	return(fdr)
}


Ftest = function(mat, class) {

	check_pkg("genefilter", bioc = TRUE)
	rownames(mat) = NULL

	p = genefilter::rowFtests(mat, factor(class))[, "p.value"]
	fdr = p.adjust(p, "BH")
	fdr[is.na(fdr)] = Inf
	attr(fdr, "p_value") = p
	return(fdr)

}
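
# Sketch comparing Ftest() and one_vs_others() on the same simulated data
# (commented out; needs genefilter). Ftest() runs a one-way ANOVA F-test per
# row via genefilter::rowFtests(); one_vs_others() t-tests each subgroup against
# all remaining samples and keeps the smallest p-value per row. Both return
# BH-adjusted p-values with the raw p-values in the "p_value" attribute. Here
# rows 1-20 are made up to be shifted up in subgroup 2 only.
# set.seed(3)
# class = rep(1:3, each = 6)
# mat = matrix(rnorm(200 * 18), nrow = 200)
# mat[1:20, class == 2] = mat[1:20, class == 2] + 2
# fdr_f = cola:::Ftest(mat, class)
# fdr_o = cola:::one_vs_others(mat, class)
# table(Ftest = fdr_f < 0.05, one_vs_others = fdr_o < 0.05)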

uniquely_high_in_one_group = function(mat, class) {

	check_pkg("genefilter", bioc = TRUE)

	le = unique(class)
	dfl = list()
	for(x in le) {
		# test subgroup x against all the other samples
		fa = as.vector(class)
		fa[class == x] = "a"
		fa[class != x] = "b"
		fa = factor(fa, levels = c("a", "b"))
		df = genefilter::rowttests(mat, fa)
		# keep only rows that are higher in subgroup x
		l = df[, "statistic"] < 0; l[is.na(l)] = FALSE
		df[l, "p.value"] = NA
		
		# then require that the remaining subgroups show no difference
		if(length(le) > 2) {
			l = class != x
			p1 = compare_to_subgroup(mat[, l, drop = FALSE], class[l], "highest")
			p2 = compare_to_subgroup(mat[, l, drop = FALSE], class[l], "lowest")
			p = pmin(p1, p2, na.rm = TRUE)

			df[p < 0.05, "p.value"] = NA
		}

		dfl[[as.character(x)]] = df[, "p.value"]
	}
	df = do.call("cbind", dfl)
	p = rowMins(df, na.rm = TRUE)
	fdr = p.adjust(p, "BH")
	attr(fdr, "p_value") = p
	fdr
}
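
# Usage sketch for uniquely_high_in_one_group() (commented out; needs
# genefilter). Simulated data: rows 1-5 are shifted up only in subgroup 1. A row
# is kept for subgroup x only if it is higher in x than in the other samples
# and, with more than two subgroups, the remaining subgroups do not differ
# (unadjusted p >= 0.05 from compare_to_subgroup()).
# set.seed(2)
# class = rep(1:3, each = 6)
# mat = matrix(rnorm(50 * 18), nrow = 50)
# mat[1:5, class == 1] = mat[1:5, class == 1] + 3
# fdr = cola:::uniquely_high_in_one_group(mat, class)
# which(fdr < 0.05)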

# test_row_diff_fun = function(fun, fdr_cutoff = 0.1) {
# 	set.seed(100)
# 	x = matrix(rnorm(1000 * 20), ncol = 20)
# 	rownames(x) = paste0("gene", 1:1000)
# 	dd = sample(1:1000, size = 100)
# 	u = matrix(2 * rnorm(100), ncol = 10, nrow = 100)
# 	x[dd, 11:20] = x[dd, 11:20] + u
# 	row_diff = rep("no", 1000)
# 	row_diff[dd] = "yes"
# 	y = c(rep(1, 10), rep(2, 10))
# 	fdr = fun(x, y)

# 	ht = Heatmap(x, top_annotation = HeatmapAnnotation(foo = as.character(y), col = list(foo = c("1" = "blue", "2" = "red"))), show_row_names = FALSE) +
# 	Heatmap(row_diff, name = "diff", col = c("yes" = "red", "no" = "white"), width = unit(5, "mm")) +
# 	Heatmap(fdr, name = "fdr", width = unit(5, "mm"), show_row_names = FALSE)
# 	draw(ht, split = fdr < fdr_cutoff)
# }




# title
# Density for the signatures
#
# == param
# -object A `ConsensusPartition-class` object. 
# -k Number of subgroups.
# -... Other arguments passed to `get_signatures,ConsensusPartition-method`.
#
# == details
# The function plots the density distributions of signatures in all columns.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
# setMethod(f = "signature_density",
# 	signature = "ConsensusPartition",
# 	definition = function(object, k, ...) {

# 	cl = get_class(object, k = k)$class
# 	data = object@.env$data[, object@column_index, drop = FALSE]

# 	all_den_list = lapply(seq_len(ncol(data)), function(i) {
# 		x = data[, i]
# 		density(x)
# 	})
# 	x_range = range(unlist(lapply(all_den_list, function(x) x$x)))
# 	y_range = range(unlist(lapply(all_den_list, function(x) x$y)))

# 	x = get_signatures(object, k = k, plot = FALSE, verbose = FALSE, ...)
# 	gp_tb = table(x$df$group)
# 	n_gp = sum(gp_tb > 5)
# 	gp_tb = gp_tb[gp_tb > 5]

# 	op = par(no.readonly = TRUE)
# 	par(mfrow = c(n_gp + 1, 1), mar = c(2, 4, 1, 3))
# 	plot(NULL, type = "n", xlim = x_range, ylim = y_range, ylab = "density", xlab = NULL)
# 	for(i in 1:ncol(data)) {
# 		lines(all_den_list[[i]], col = cola_opt$color_set_2[cl[i]], lwd = 1)
# 	}
# 	mtext("all rows", side = 4, line = 1)

# 	gp = x$df$group
# 	for(j in as.numeric(names(gp_tb))) {
# 		gp2 = gp[gp == as.character(j)]
# 		all_den_list = lapply(seq_len(ncol(data)), function(i) density(data[names(gp2), i]))
# 		# x_range = range(unlist(lapply(all_den_list, function(x) x$x)))
# 		y_range = range(unlist(lapply(all_den_list, function(x) x$y)))

# 		plot(NULL, type = "n", xlim = x_range, ylim = y_range, ylab = "density", xlab = NULL)
# 		for(i in 1:ncol(data)) {
# 			lines(all_den_list[[i]], col = cola_opt$color_set_2[cl[i]], lwd = ifelse(cl[i] == j, 2, 0.5))
# 		}
# 		mtext(qq("subgroup @{j}/@{k}"), side = 4, line = 1)
# 	}
# 	par(op)
# })


# == title
# Compare Signatures from Different k
#
# == param
# -object A `ConsensusPartition-class` object. 
# -k Number of subgroups. It can be a vector of several values.
# -verbose Whether to print messages.
# -... Other arguments passed to `get_signatures,ConsensusPartition-method`.
#
# == details
# It plots an Euler diagram showing the overlap of signatures from different k.
#
# == example
# \donttest{
# data(golub_cola)
# res = golub_cola["ATC", "skmeans"]
# compare_signatures(res)
# }
setMethod(f = "compare_signatures",
	signature = "ConsensusPartition",
	definition = function(object, k = object@k, verbose = interactive(), ...) {

	check_pkg("eulerr", bioc = FALSE)
	
	# lapply() instead of sapply(): the result must stay a list even when all k
	# give the same number of signatures
	sig_list = lapply(k, function(x) {
		tb = get_signatures(object, k = x, verbose = verbose, ..., plot = FALSE)
		if(is.null(tb)) {
			return(integer(0))
		} else {
			return(tb$which_row)
		}
	})

	l = sapply(sig_list, length) > 0
	if(any(!l) && verbose) {
		qqcat("Following k have no signature found: \"@{paste(k[!l], collapse=', ')}\"\n")
	}
	sig_list = sig_list[l]

	names(sig_list) = paste(k[l], "-group", sep = "")

	plot(eulerr::euler(sig_list), legend = TRUE, quantities = TRUE, main = "Signatures from different k")

})


# == title
# Find the best k for k-means clustering
#
# == param
# -mat A matrix. k-means clustering is applied to its rows.
# -max_km Maximal k to try.
#
# == details
# The best k is determined by looking for the knee/elbow of the WSS curve (within-cluster sum of squares).
# Both an elbow finder and a knee finder are applied and the smaller of the two k is returned.
#
# Note that this function only gives a rough and quick estimate of the best k.
#
find_best_km = function(mat, max_km = 15) {
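	# WSS for k = 1 is the total sum of squares
	wss = (nrow(mat)-1)*sum(apply(mat,2,var))
	max_km = min(c(nrow(mat) - 1, max_km))
	# with max_km < 2 (e.g. fewer than three rows) there is no curve to search
	if(max_km < 2) return(1)
	for (i in 2:max_km) wss[i] = sum(kmeans(mat, centers = i, iter.max = 50)$withinss)
	row_km = min(elbow_finder(1:max_km, wss)[1], knee_finder(1:max_km, wss)[1])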
	return(row_km)
}
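
# Usage sketch for find_best_km() (commented out). The rows are simulated around
# three made-up 2-D centers. Because kmeans() uses random starts, the returned k
# can vary between runs, which fits the "rough and quick" caveat above.
# set.seed(123)
# centers = matrix(c(0, 0, 5, 5, 10, 0), ncol = 2, byrow = TRUE)
# mat = centers[rep(1:3, each = 30), ] + matrix(rnorm(90 * 2, sd = 0.5), ncol = 2)
# cola:::find_best_km(mat, max_km = 10)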
jokergoo/cola documentation built on Feb. 29, 2024, 1:41 a.m.