R/Heatmap-class.R

Defines functions make_cluster Heatmap

Documented in Heatmap

###############################
# class for single heatmap
#

# == title
# Class for a Single Heatmap
#
# == details
# The `Heatmap-class` is not responsible for the heatmap legend and annotation legends. The `draw,Heatmap-method` method
# constructs a `HeatmapList-class` object which only contains one single heatmap
# and calls `draw,HeatmapList-method` to make the complete heatmap.
#
# == methods
# The `Heatmap-class` provides following methods:
#
# - `Heatmap`: constructor method.
# - `draw,Heatmap-method`: draw a single heatmap.
# - `add_heatmap,Heatmap-method`: append heatmaps and annotations to a list of heatmaps.
# - `row_order,HeatmapList-method`: get order of rows
# - `column_order,HeatmapList-method`: get order of columns
# - `row_dend,HeatmapList-method`: get row dendrograms
# - `column_dend,HeatmapList-method`: get column dendrograms
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
Heatmap = setClass("Heatmap",
    contains = "AdditiveUnit",
    # Slots are declared as a named character vector: slot name -> class of the slot.
    slots = c(
        # -- identity --
        name = "character",

        # -- main matrix and its color mapping / legend --
        # one or more matrix which are spliced by rows
        matrix = "matrix",
        matrix_param = "list",
        matrix_color_mapping = "ANY",
        matrix_legend_param = "ANY",

        # -- titles --
        row_title = "ANY",
        row_title_param = "list",
        column_title = "ANY",
        column_title_param = "list",

        # -- row clustering and ordering --
        # one or more row clusters
        row_dend_list = "list",
        row_dend_slice = "ANY",
        # parameters for row cluster
        row_dend_param = "list",
        row_order_list = "list",
        row_order = "numeric",

        # -- column clustering and ordering --
        column_dend_list = "list",
        column_dend_slice = "ANY",
        # parameters for column cluster
        column_dend_param = "list",
        column_order_list = "list",
        column_order = "numeric",

        # -- row/column names --
        row_names_param = "list",
        column_names_param = "list",

        # -- annotations on the four sides --
        # each `*_annotation` slot is NULL or a `HeatmapAnnotation` object
        top_annotation = "ANY",
        top_annotation_param = "list",
        bottom_annotation = "ANY",
        bottom_annotation_param = "list",
        left_annotation = "ANY",
        left_annotation_param = "list",
        right_annotation = "ANY",
        right_annotation_param = "list",

        # -- heatmap-level parameters --
        heatmap_param = "list",

        # -- layout information --
        layout = "list"
    )
)



# == title
# Constructor method for Heatmap class
#
# == param
# -matrix A matrix. Either numeric or character. If it is a simple vector, it will be
#         converted to a one-column matrix.
# -col A vector of colors if the color mapping is discrete or a color mapping 
#      function if the matrix is continuous numbers (should be generated by `circlize::colorRamp2`). If the matrix is continuous,
#      the value can also be a vector of colors so that colors can be interpolated. Pass to `ColorMapping`. For more details
#      and examples, please refer to https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#colors .
# -name Name of the heatmap. By default the heatmap name is used as the title of the heatmap legend.
# -na_col Color for ``NA`` values.
# -rect_gp Graphic parameters for drawing rectangles (for heatmap body). The value should be specified by `grid::gpar` and ``fill`` parameter is ignored.
# -color_space The color space in which colors are interpolated. Only used if ``matrix`` is numeric and 
#            ``col`` is a vector of colors. Pass to `circlize::colorRamp2`.
# -border Whether draw border. The value can be logical or a string of color.
# -border_gp Graphic parameters for the borders. If you want to set different parameters for different heatmap slices,
#           please consider to use `decorate_heatmap_body`.
# -cell_fun Self-defined function to add graphics on each cell. Seven parameters will be passed into 
#           this function: ``j``, ``i``, ``x``, ``y``, ``width``, ``height``, ``fill`` which are column index,
#           row index in ``matrix``, coordinate of the cell,
#           the width and height of the cell and the filled color. ``x``, ``y``, ``width`` and ``height`` are all `grid::unit` objects.
# -layer_fun Similar as ``cell_fun``, but is vectorized. Check https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#customize-the-heatmap-body .
# -jitter Random shifts added to the matrix. The value can be logical or a single numeric value. If it is ``TRUE``, random 
#      values from uniform distribution between 0 and 1e-10 are generated. If it is a numeric value,
#      the range for the uniform distribution is (0, ``jitter``). It is mainly to solve the problem of "Error: node stack overflow"
#      when there are too many identical rows/columns for plotting the dendrograms. Note: from version 2.5.6, the error of node stack overflow
#      has been fixed and this argument is now ignored.
# -row_title Title on the row.
# -row_title_side Will the title be put on the left or right of the heatmap?
# -row_title_gp Graphic parameters for row title.
# -row_title_rot Rotation of row title.
# -column_title Title on the column.
# -column_title_side Will the title be put on the top or bottom of the heatmap?
# -column_title_gp Graphic parameters for column title.
# -column_title_rot Rotation of column titles.
# -cluster_rows If the value is a logical, it controls whether to make cluster on rows. The value can also
#               be a `stats::hclust` or a `stats::dendrogram` which already contains clustering.
#               Check https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#clustering .
# -cluster_row_slices If rows are split into slices, whether to perform clustering on the slice means?
# -clustering_distance_rows It can be a pre-defined character which is in 
#                ("euclidean", "maximum", "manhattan", "canberra", "binary", 
#                "minkowski", "pearson", "spearman", "kendall"). It can also be a function.
#                If the function has one argument, the input argument should be a matrix and 
#                the returned value should be a `stats::dist` object. If the function has two arguments,
#                the input arguments are two vectors and the function calculates distance between these
#                two vectors.
# -clustering_method_rows Method to perform hierarchical clustering, pass to `stats::hclust`.
# -row_dend_side Should the row dendrogram be put on the left or right of the heatmap?
# -row_dend_width Width of the row dendrogram, should be a `grid::unit` object.
# -show_row_dend Whether show row dendrogram?
# -row_dend_gp Graphic parameters for the dendrogram segments. If users already provide a `stats::dendrogram`
#                object with edges rendered, this argument will be ignored.
# -row_dend_reorder Apply reordering on row dendrograms. The value can be a logical value or a vector which contains weight 
#               which is used to reorder rows. The reordering is applied by `stats::reorder.dendrogram`.
# -cluster_columns Whether make cluster on columns? Same settings as ``cluster_rows``.
# -cluster_column_slices If columns are split into slices, whether to perform clustering on the slice means?
# -clustering_distance_columns Same setting as ``clustering_distance_rows``.
# -clustering_method_columns Method to perform hierarchical clustering, pass to `stats::hclust`.
# -column_dend_side Should the column dendrogram be put on the top or bottom of the heatmap?
# -column_dend_height height of the column cluster, should be a `grid::unit` object.
# -show_column_dend Whether show column dendrogram?
# -column_dend_gp Graphic parameters for dendrogram segments. Same settings as ``row_dend_gp``.
# -column_dend_reorder Apply reordering on column dendrograms. Same settings as ``row_dend_reorder``.
# -row_order Order of rows. Manually setting row order turns off clustering.
# -column_order Order of column.
# -row_labels Optional row labels which are put as row names in the heatmap.
# -row_names_side Should the row names be put on the left or right of the heatmap?
# -show_row_names Whether show row names.
# -row_names_max_width Maximum width of row names viewport.
# -row_names_gp Graphic parameters for row names.
# -row_names_rot Rotation of row names.
# -row_names_centered Should row names put centered?
# -column_labels Optional column labels which are put as column names in the heatmap.
# -column_names_side Should the column names be put on the top or bottom of the heatmap?
# -column_names_max_height Maximum height of column names viewport.
# -show_column_names Whether show column names.
# -column_names_gp Graphic parameters for drawing text.
# -column_names_rot Rotation of column names.
# -column_names_centered Should column names put centered?
# -top_annotation A `HeatmapAnnotation` object.
# -bottom_annotation A `HeatmapAnnotation` object.
# -left_annotation It should be specified by `rowAnnotation`.
# -right_annotation It should be specified by `rowAnnotation`.
# -km Apply k-means clustering on rows. If the value is larger than 1, the heatmap will be split by rows according to the k-means clustering.
#     For each row slice, hierarchical clustering is still applied with parameters above.
# -split A vector or a data frame by which the rows are split. But if ``cluster_rows`` is a clustering object, ``split`` can be a single number
#        indicating to split the dendrogram by `stats::cutree`.
# -row_km Same as ``km``.
# -row_km_repeats Number of k-means runs to get a consensus k-means clustering. Note if ``row_km_repeats`` is set to more than one, the final number
#                of groups might be smaller than ``row_km``, but this might mean the original ``row_km`` is not a good choice.
# -row_split Same as ``split``.
# -column_km K-means clustering on columns.
# -column_km_repeats Number of k-means runs to get a consensus k-means clustering. Similar as ``row_km_repeats``.
# -column_split Split on columns. For heatmap splitting, please refer to https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#heatmap-split .
# -gap Gap between row slices if the heatmap is split by rows. The value should be a `grid::unit` object.
# -row_gap Same as ``gap``.
# -column_gap Gap between column slices.
# -show_parent_dend_line When heatmap is split, whether to add a dashed line to mark parent dendrogram and children dendrograms?
# -width Width of the heatmap body.
# -height Height of the heatmap body.
# -heatmap_width Width of the whole heatmap (including heatmap components)
# -heatmap_height Height of the whole heatmap (including heatmap components). Check https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#size-of-the-heatmap .
# -show_heatmap_legend Whether show heatmap legend?
# -heatmap_legend_param A list contains parameters for the heatmap legends. See `color_mapping_legend,ColorMapping-method` for all available parameters.
# -use_raster Whether render the heatmap body as a raster image. It helps to reduce file size when the matrix is huge. If number of rows or columns is more than 2000, it is by default turned on. Note if ``cell_fun``
#       is set, ``use_raster`` is enforced to be ``FALSE``.
# -raster_device Graphic device which is used to generate the raster image.
# -raster_quality A value larger than 1.
# -raster_device_param A list of further parameters for the selected graphic device. For raster image support, please check https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#heatmap-as-raster-image .
# -raster_resize_mat Whether resize the matrix to let the dimension of the matrix the same as the dimension of the raster image?
#          The value can be logical. If it is ``TRUE``, `base::mean` is used to summarize the sub matrix which corresponds to a single pixel.
#          The value can also be a summary function, e.g. `base::max`.
# -raster_by_magick Whether to use `magick::image_resize` to scale the image.
# -raster_magick_filter Pass to ``filter`` argument of `magick::image_resize`. A character scalar and all possible values
#          are in `magick::filter_types`. The default is ``"Lanczos"``.
# -post_fun A function which will be executed after the heatmap list is drawn.
#
# == details
# The initialization function only applies parameter checking and fills values into the slots with some validation.
# 
# Following methods can be applied to the `Heatmap-class` object:
#
# - `show,Heatmap-method`: draw a single heatmap with default parameters
# - `draw,Heatmap-method`: draw a single heatmap.
# - ``+`` or `\%v\%` append heatmaps and annotations to a list of heatmaps.
#
# The constructor function pretends to be a high-level graphic function because the ``show`` method
# of the `Heatmap-class` object actually plots the graphics.
#
# == seealso
# https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
Heatmap = function(matrix, col, name, 
    na_col = "grey", 
    color_space = "LAB",
    rect_gp = gpar(col = NA), 
    border = NA,
    border_gp = gpar(col = "black"),
    cell_fun = NULL,
    layer_fun = NULL,
    jitter = FALSE,

    row_title = character(0), 
    row_title_side = c("left", "right"), 
    row_title_gp = gpar(fontsize = 13.2), 
    row_title_rot = switch(row_title_side[1], "left" = 90, "right" = 270),
    column_title = character(0), 
    column_title_side = c("top", "bottom"), 
    column_title_gp = gpar(fontsize = 13.2), 
    column_title_rot = 0,

    cluster_rows = TRUE, 
    cluster_row_slices = TRUE,
    clustering_distance_rows = "euclidean",
    clustering_method_rows = "complete", 
    row_dend_side = c("left", "right"),
    row_dend_width = unit(10, "mm"), 
    show_row_dend = TRUE, 
    row_dend_reorder = is.logical(cluster_rows) || is.function(cluster_rows),
    row_dend_gp = gpar(), 
    cluster_columns = TRUE, 
    cluster_column_slices = TRUE,
    clustering_distance_columns = "euclidean", 
    clustering_method_columns = "complete",
    column_dend_side = c("top", "bottom"), 
    column_dend_height = unit(10, "mm"), 
    show_column_dend = TRUE, 
    column_dend_gp = gpar(), 
    column_dend_reorder = is.logical(cluster_columns) || is.function(cluster_columns),

    row_order = NULL, 
    column_order = NULL,

    row_labels = rownames(matrix),
    row_names_side = c("right", "left"), 
    show_row_names = TRUE, 
    row_names_max_width = unit(6, "cm"), 
    row_names_gp = gpar(fontsize = 12), 
    row_names_rot = 0,
    row_names_centered = FALSE,
    column_labels = colnames(matrix),
    column_names_side = c("bottom", "top"), 
    show_column_names = TRUE, 
    column_names_max_height = unit(6, "cm"), 
    column_names_gp = gpar(fontsize = 12),
    column_names_rot = 90,
    column_names_centered = FALSE,

    top_annotation = NULL,
    bottom_annotation = NULL,
    left_annotation = NULL,
    right_annotation = NULL,

    km = 1, 
    split = NULL, 
    row_km = km,
    row_km_repeats = 1,
    row_split = split,
    column_km = 1,
    column_km_repeats = 1,
    column_split = NULL,
    gap = unit(1, "mm"),
    row_gap = unit(1, "mm"),
    column_gap = unit(1, "mm"),
    show_parent_dend_line = ht_opt$show_parent_dend_line,

    heatmap_width = unit(1, "npc"),
    width = NULL,
    heatmap_height = unit(1, "npc"), 
    height = NULL,

    show_heatmap_legend = TRUE,
    heatmap_legend_param = list(title = name),

    use_raster = NULL, 
    raster_device = c("png", "jpeg", "tiff", "CairoPNG", "CairoJPEG", "CairoTIFF", "agg_png"),
    raster_quality = 1,
    raster_device_param = list(),
    raster_resize_mat = FALSE,
    raster_by_magick = requireNamespace("magick", quietly = TRUE),
    raster_magick_filter = NULL,

    post_fun = NULL) {

    dev.null()
    on.exit(dev.off2())

    verbose = ht_opt("verbose")

    .Object = new("Heatmap")
    if(missing(name)) {
        name = paste0("matrix_", get_heatmap_index() + 1)
        increase_heatmap_index()
    } else if(is.null(name)) {
        name = paste0("matrix_", get_heatmap_index() + 1)
        increase_heatmap_index()
    }
    if(name == "") {
        stop_wrap("Heatmap name cannot be empty string.")
    }
    .Object@name = name

    # re-define some of the argument values according to global settings
    called_args = names(as.list(match.call())[-1])
    for(opt_name in c("row_names_gp", "column_names_gp", "row_title_gp", "column_title_gp")) {
        opt_name2 = paste0("heatmap_", opt_name)
        if(! opt_name %in% called_args) { # if this argument is not called
            if(!is.null(ht_opt(opt_name2))) {
                if(verbose) qqcat("re-assign @{opt_name} with `ht_opt('@{opt_name2}'')`\n")
                assign(opt_name, ht_opt(opt_name2))
            }
        }
    }

    if("top_annotation_height" %in% called_args) {
        stop_wrap("`top_annotation_height` is removed. Set the height directly in `HeatmapAnnotation()`.")
    }
    if("bottom_annotation_height" %in% called_args) {
        stop_wrap("`bottom_annotation_height` is removed. Set the height directly in `HeatmapAnnotation()`.")
    }
    if("combined_name_fun" %in% called_args) {
        stop_wrap("`combined_name_fun` is removed. Please directly set `row_names_title`. See https://jokergoo.github.io/ComplexHeatmap-reference/book/a-single-heatmap.html#titles-for-splitting")
    }

    if("heatmap_legend_param" %in% called_args) {
        for(opt_name in setdiff(c("title_gp", "title_position", "labels_gp", "grid_width", "grid_height", "border"), names(heatmap_legend_param))) {
            opt_name2 = paste0("legend_", opt_name)
            if(!is.null(ht_opt(opt_name2)))
                if(verbose) qqcat("re-assign heatmap_legend_param$@{opt_name} with `ht_opt('@{opt_name2}'')`\n")
                heatmap_legend_param[[opt_name]] = ht_opt(opt_name2)
        }
    } else {
        for(opt_name in c("title_gp", "title_position", "labels_gp", "grid_width", "grid_height", "border")) {
            opt_name2 = paste0("legend_", opt_name)
            if(!is.null(ht_opt(opt_name2)))
                if(verbose) qqcat("re-assign heatmap_legend_param$@{opt_name} with `ht_opt('@{opt_name2}'')`\n")
                heatmap_legend_param[[opt_name]] = ht_opt(opt_name2)
        }
    }

    if(is.data.frame(matrix)) {
        if(verbose) qqcat("convert data frame to matrix\n")
        warning_wrap("The input is a data frame-like object, convert it to a matrix.")
        if(!all(sapply(matrix, is.numeric))) {
            warning_wrap("Note: not all columns in the data frame are numeric. The data frame will be converted into a character matrix.")
        }
        matrix = as.matrix(matrix)
    }
    fa_level = NULL
    if(!is.matrix(matrix)) {
        if(is.atomic(matrix)) {
            if(is.factor(matrix)) {
                fa_level = levels(matrix)
            }
            rn = names(matrix)
            matrix = matrix(matrix, ncol = 1)
            if(!is.null(rn)) rownames(matrix) = rn
            if(!missing(name)) colnames(matrix) = name
            if(verbose) qqcat("convert simple vector to one-column matrix\n")
        } else {
            stop_wrap("If input is not a matrix, it should be a simple vector.")
        }
    }

    # if(ncol(matrix) == 0 || nrow(matrix) == 0) {
    #     show_heatmap_legend = FALSE
    #     .Object@heatmap_param$show_heatmap_legend = FALSE
    # }
    
    # if(ncol(matrix) == 0 && (!is.null(left_annotation) || !is.null(right_annotation))) {
    #     message_wrap("If you have row annotations for a zeor-column matrix, please directly use in form of `rowAnnotation(...) + NULL`")
    #     return(invisible(NULL))
    # }
    # if(nrow(matrix) == 0 && (!is.null(top_annotation) || !is.null(bottom_annotation))) {
    #     message_wrap("If you have column annotations for a zero-row matrix, please directly use in form of `HeatmapAnnotation(...) %v% NULL`")
    #     return(invisible(NULL))
    # }

    ### normalize km/split and row_km/row_split
    if(missing(row_km)) row_km = km
    if(is.null(row_km)) row_km = 1
    if(missing(row_split)) row_split = split
    if(missing(row_gap)) row_gap = gap
    if(is.null(column_km)) column_km = 1

    ####### zero and one column matrix ########
    if(ncol(matrix) == 0 || nrow(matrix) == 0) {
        if(!inherits(cluster_columns, c("dendrogram", "hclust"))) {
            cluster_columns = FALSE
            show_column_dend = FALSE
        }
        if(!inherits(cluster_rows, c("dendrogram", "hclust"))) {
            cluster_rows = FALSE
            show_row_dend = FALSE
        }
        row_km = 1
        column_km = 1
        if(verbose) qqcat("zero row/column matrix, set cluster_columns/rows to FALSE\n")
    }
    if(ncol(matrix) == 1) {
        if(!inherits(cluster_columns, c("dendrogram", "hclust"))) {
            cluster_columns = FALSE
            show_column_dend = FALSE
        }
        column_km = 1
        if(verbose) qqcat("one-column matrix, set cluster_columns to FALSE\n")
    }
    if(nrow(matrix) == 1) {
        if(!inherits(cluster_rows, c("dendrogram", "hclust"))) {
            cluster_rows = FALSE
            show_row_dend = FALSE
        }
        row_km = 1
        if(verbose) qqcat("one-row matrix, set cluster_rows to FALSE\n")
    }

    if(is.character(matrix)) {
        called_args = names(match.call()[-1])
        if("clustering_distance_rows" %in% called_args) {
        } else if(inherits(cluster_rows, c("dendrogram", "hclust"))) {
        } else {
            cluster_rows = FALSE
            show_row_dend = FALSE
        }
        row_dend_reorder = FALSE
        cluster_row_slices = FALSE

        if(inherits(cluster_rows, c("dendrogram", "hclust")) && length(row_split) == 1) {
            if(!"cluster_row_slices" %in% called_args) {
                cluster_row_slices = TRUE
            }
        }

        if("clustering_distance_columns" %in% called_args) {
        } else if(inherits(cluster_columns, c("dendrogram", "hclust"))) {
        } else {
            cluster_columns = FALSE
            show_column_dend = FALSE
        }
        column_dend_reorder = FALSE
        cluster_column_slices = FALSE

        if(inherits(cluster_columns, c("dendrogram", "hclust")) && length(column_split) == 1) {
            if(!"cluster_column_slices" %in% called_args) {
                cluster_column_slices = TRUE
            }
        }

        row_km = 1
        column_km = 1
        if(verbose) qqcat("matrix is character. Do not cluster unless distance method is provided.\n")
    }
    class(matrix) = "matrix"
    .Object@matrix = matrix

    .Object@matrix_param$row_km = row_km
    .Object@matrix_param$row_km_repeats = row_km_repeats
    .Object@matrix_param$row_gap = row_gap
    .Object@matrix_param$column_km = column_km
    .Object@matrix_param$column_km_repeats = column_km_repeats
    .Object@matrix_param$column_gap = column_gap
    .Object@matrix_param$jitter = jitter

    ### check row_split and column_split ###
    if(!is.null(row_split)) {
        if(inherits(cluster_rows, c("dendrogram", "hclust"))) {
            if(is.numeric(row_split) && length(row_split) == 1) {
                .Object@matrix_param$row_split = row_split
            } else {
                stop_wrap("When `cluster_rows` is a dendrogram, `row_split` can only be a single number.")
            }
        } else {
            if(identical(cluster_rows, TRUE) && is.numeric(row_split) && length(row_split) == 1) {

            } else {
                if(!is.data.frame(row_split)) row_split = data.frame(row_split)
                if(nrow(row_split) != nrow(matrix)) {
                    stop_wrap("Length or nrow of `row_split` should be same as nrow of `matrix`.")
                }
            }
        }
    }
    .Object@matrix_param$row_split = row_split

    if(!is.null(column_split)) {
        if(inherits(cluster_columns, c("dendrogram", "hclust"))) {
            if(is.numeric(column_split) && length(column_split) == 1) {
                .Object@matrix_param$column_split = column_split
            } else {
               stop_wrap("When `cluster_columns` is a dendrogram, `column_split` can only be a single number.")
            }
        } else {
            if(identical(cluster_columns, TRUE) && is.numeric(column_split) && length(column_split) == 1) {

            } else {
                if(!is.data.frame(column_split)) column_split = data.frame(column_split)
                if(nrow(column_split) != ncol(matrix)) {
                    stop_wrap("Length or ncol of `column_split` should be same as ncol of `matrix`.")
                }
            }
        }
    }
    .Object@matrix_param$column_split = column_split


    ### parameters for heatmap body ###
    .Object@matrix_param$gp = check_gp(rect_gp)
    if(missing(border)) {
        if(!is.null(ht_opt$heatmap_border)) border = ht_opt$heatmap_border
    }
    if(!missing(border_gp) && missing(border)) border = TRUE
    .Object@matrix_param$border = border
    .Object@matrix_param$border_gp = border_gp

    if(!is.null(cell_fun)) {
        global_vars = codetools::findGlobals(cell_fun, merge = FALSE)$variables

        ee = new.env(parent = environment(cell_fun))
        for(v in global_vars) {
            assign(v, value = get(v, envir = environment(cell_fun)), envir = ee)
        }
        environment(cell_fun) = ee
    }
    if(!is.null(layer_fun)) {
        global_vars = codetools::findGlobals(layer_fun, merge = FALSE)$variables

        ee = new.env(parent = environment(layer_fun))
        for(v in global_vars) {
            assign(v, value = get(v, envir = environment(layer_fun)), envir = ee)
        }
        environment(layer_fun) = ee
    }

    .Object@matrix_param$cell_fun = cell_fun
    .Object@matrix_param$layer_fun = layer_fun

    if(nrow(matrix) > 100 || ncol(matrix) > 100) {
        if(!is.null(cell_fun)) {
            warning_wrap("You defined `cell_fun` for a heatmap with more than 100 rows or columns, which might be very slow to draw. Consider to use the vectorized version `layer_fun`.")
        }
    }

    ### color for main matrix #########
    if(ncol(matrix) > 0 && nrow(matrix) > 0) {
        if(missing(col)) {
            col = default_col(matrix, main_matrix = TRUE)
            if(!is.null(fa_level)) {
                col = col[fa_level]
            }
            if(verbose) qqcat("color is not specified, use randomly generated colors\n")
        }
        if(is.null(col)) {
            col = default_col(matrix, main_matrix = TRUE)
            if(!is.null(fa_level)) {
                col = col[fa_level]
            }
            if(verbose) qqcat("color is not specified, use randomly generated colors\n")
        }
        if(is.function(col)) {
            if(is.null(attr(col, "breaks"))) {
                breaks = seq(min(matrix, na.rm = TRUE), max(matrix, na.rm = TRUE), length.out = 5)
                rg = range(breaks)
                diff = rg[2] - rg[1]
                rg[1] = rg[1] + diff*0.05
                rg[2] = rg[2] - diff*0.05

                le = pretty(rg, n = 3)
                .Object@matrix_color_mapping = ColorMapping(col_fun = col, name = name, breaks = le, na_col = na_col)
            } else {
                .Object@matrix_color_mapping = ColorMapping(col_fun = col, name = name, na_col = na_col)
            }
            if(verbose) qqcat("input color is a color mapping function\n")
        } else if(inherits(col, "ColorMapping")){
            .Object@matrix_color_mapping = col
            if(verbose) qqcat("input color is a ColorMapping object\n")
        } else {

            if(is.null(names(col))) {
                if(length(col) == length(unique(as.vector(matrix)))) {
                    if(length(col) >= 6) {
                        message_wrap(qq("There are @{length(col)} unique colors in the vector `col` and @{length(col)} unique values in `matrix`. `Heatmap()` will treat it as an exact discrete one-to-one mapping. If this is not what you want, slightly change the number of colors, e.g. by adding one more color or removing a color."))
                    }
                    if(is.null(fa_level)) {
                        if(is.numeric(matrix)) {
                            names(col) = sort(unique(as.vector(matrix)))
                            col = rev(col)
                        } else {
                            names(col) = sort(unique(as.vector(matrix)))
                        }
                    } else {
                        names(col) = fa_level
                    }
                    .Object@matrix_color_mapping = ColorMapping(colors = col, name = name, na_col = na_col)
                    if(verbose) qqcat("input color is a vector with no names, treat it as discrete color mapping\n")
                } else if(is.numeric(matrix)) {
                    col = colorRamp2(seq(min(matrix, na.rm = TRUE), max(matrix, na.rm = TRUE), length.out = length(col)),
                                     col, space = color_space)
                    .Object@matrix_color_mapping = ColorMapping(col_fun = col, name = name, na_col = na_col)
                    if(verbose) qqcat("input color is a vector with no names, treat it as continuous color mapping\n")
                } else {
                    stop_wrap("`col` should have names to map to values in `mat`.")
                }
            } else {
                full_col = col
                # note here col can be reduced
                if(is.null(fa_level)) {
                    col = col[intersect(c(names(col), "_NA_"), as.character(matrix))]
                } else {
                    col = col[intersect(c(fa_level, "_NA_"), names(col))]
                }
                if(!is.null(heatmap_legend_param) && !identical(.Object@matrix_param$gp$type, "none")) {
                    if(!is.null(heatmap_legend_param$at) && !is.null(heatmap_legend_param[["labels"]])) {
                        l = heatmap_legend_param$at %in% names(col)
                        heatmap_legend_param$at = heatmap_legend_param$at[l]
                        heatmap_legend_param[["labels"]] = heatmap_legend_param[["labels"]][l]
                    } else if(is.null(heatmap_legend_param$at) && !is.null(heatmap_legend_param$labels)) {
                        l = heatmap_legend_param[["labels"]] %in% names(col)
                        heatmap_legend_param[["labels"]] = heatmap_legend_param[["labels"]][l]
                    } else if(!is.null(heatmap_legend_param$at) && is.null(heatmap_legend_param[["labels"]])) {
                        l = heatmap_legend_param$at %in% names(col)
                        heatmap_legend_param$at = heatmap_legend_param$at[l]
                    }
                }
                .Object@matrix_color_mapping = ColorMapping(colors = col, name = name, na_col = na_col, full_col = full_col)
                if(verbose) qqcat("input color is a named vector\n")
            }
        }
        .Object@matrix_legend_param = heatmap_legend_param
    }

    ##### titles, should also consider titles after row splitting #####
    if(identical(row_title, NA) || identical(row_title, "")) {
        row_title = character(0)
    }
    .Object@row_title = row_title
    .Object@row_title_param$rot = row_title_rot %% 360
    .Object@row_title_param$side = match.arg(row_title_side)[1]
    .Object@row_title_param$gp = check_gp(row_title_gp)  # if the number of settings is same as number of row-splits, gp will be adjusted by `make_row_dend`
    .Object@row_title_param$just = get_text_just(rot = row_title_rot, side = .Object@row_title_param$side)

    if(identical(column_title, NA) || identical(column_title, "")) {
        column_title = character(0)
    }
    .Object@column_title = column_title
    .Object@column_title_param$rot = column_title_rot %% 360
    .Object@column_title_param$side = match.arg(column_title_side)[1]
    .Object@column_title_param$gp = check_gp(column_title_gp)
    .Object@column_title_param$just = get_text_just(rot = column_title_rot, side = .Object@column_title_param$side)

    ### row labels/column labels ###
    if(is.null(rownames(matrix))) {
        if(is.null(row_labels)) {
            show_row_names = FALSE
        }
    }
    .Object@row_names_param$labels = row_labels
    .Object@row_names_param$side = match.arg(row_names_side)[1]
    .Object@row_names_param$show = show_row_names
    .Object@row_names_param$gp = check_gp(row_names_gp)
    .Object@row_names_param$rot = row_names_rot
    .Object@row_names_param$centered = row_names_centered
    .Object@row_names_param$max_width = row_names_max_width + unit(2, "mm")
    # we use anno_text to draw row/column names because it already takes care of text rotation
    if(show_row_names) {
        if(length(row_labels) != nrow(matrix)) {
            stop_wrap("Length of `row_labels` should be the same as the nrow of matrix.")
        }
        if(row_names_centered) {
            row_names_anno = anno_text(row_labels, which = "row", gp = row_names_gp, rot = row_names_rot,
                location = 0.5, 
                just = "center")
        } else {
            row_names_anno = anno_text(row_labels, which = "row", gp = row_names_gp, rot = row_names_rot,
                location = ifelse(.Object@row_names_param$side == "left", 1, 0), 
                just = ifelse(.Object@row_names_param$side == "left", "right", "left"))
        }
        .Object@row_names_param$anno = row_names_anno
    }

    if(is.null(colnames(matrix))) {
        if(is.null(column_labels)) {
            show_column_names = FALSE
        }
    }
    .Object@column_names_param$labels = column_labels
    .Object@column_names_param$side = match.arg(column_names_side)[1]
    .Object@column_names_param$show = show_column_names
    .Object@column_names_param$gp = check_gp(column_names_gp)
    .Object@column_names_param$rot = column_names_rot
    .Object@column_names_param$centered = column_names_centered
    .Object@column_names_param$max_height = column_names_max_height + unit(2, "mm")
    if(show_column_names) {
        if(length(column_labels) != ncol(matrix)) {
            stop_wrap("Length of `column_labels` should be the same as the ncol of matrix.")
        }
        if(column_names_centered) {
            column_names_anno = anno_text(column_labels, which = "column", gp = column_names_gp, rot = column_names_rot,
            location = 0.5, 
            just = "center")
        } else {
            column_names_anno = anno_text(column_labels, which = "column", gp = column_names_gp, rot = column_names_rot,
                location = ifelse(.Object@column_names_param$side == "top", 0, 1), 
                just = ifelse(.Object@column_names_param$side == "top", 
                         ifelse(.Object@column_names_param$rot >= 0, "left", "right"),
                         ifelse(.Object@column_names_param$rot >= 0, "right", "left")
                        ))
        }
        .Object@column_names_param$anno = column_names_anno
    }

    #### dendrograms ########
    if(missing(cluster_rows) && !missing(row_order)) {
        cluster_rows = FALSE
    }
    if(is.logical(cluster_rows)) {
        if(!cluster_rows) {
            row_dend_width = unit(0, "mm")
            show_row_dend = FALSE
        }
        .Object@row_dend_param$cluster = cluster_rows
    } else if(inherits(cluster_rows, "dendrogram") || inherits(cluster_rows, "hclust")) {
        .Object@row_dend_param$obj = cluster_rows
        .Object@row_dend_param$cluster = TRUE
    } else if(inherits(cluster_rows, "function")) {
        .Object@row_dend_param$fun = cluster_rows
        .Object@row_dend_param$cluster = TRUE
    } else {
        oe = try(cluster_rows <- as.dendrogram(cluster_rows), silent = TRUE)
        if(!inherits(oe, "try-error")) {
            .Object@row_dend_param$obj = cluster_rows
            .Object@row_dend_param$cluster = TRUE
        } else {
            stop_wrap("`cluster_rows` should be a logical value, a clustering function or a clustering object.")
        }
    }
    if(!show_row_dend) {
        row_dend_width = unit(0, "mm")
    }
    .Object@row_dend_list = list()
    .Object@row_dend_param$distance = clustering_distance_rows
    .Object@row_dend_param$method = clustering_method_rows
    .Object@row_dend_param$side = match.arg(row_dend_side)[1]
    .Object@row_dend_param$width = row_dend_width + ht_opt$DENDROGRAM_PADDING  # append the gap
    .Object@row_dend_param$show = show_row_dend
    .Object@row_dend_param$gp = check_gp(row_dend_gp)
    .Object@row_dend_param$reorder = row_dend_reorder
    .Object@row_order_list = list() # default order
    if(is.null(row_order)) {
        .Object@row_order = seq_len(nrow(matrix))
    }  else {
        if(is.character(row_order)) {
            row_order = structure(seq_len(nrow(matrix)), names = rownames(matrix))[row_order]
        }
        if(any(is.na(row_order))) {
            stop_wrap("`row_order` should not contain NA values.")
        }
        if(length(row_order) != nrow(matrix)) {
            stop_wrap("length of `row_order` should be same as the number of marix rows.")
        }
        .Object@row_order = row_order
    }
    .Object@row_dend_param$cluster_slices = cluster_row_slices

    if(missing(cluster_columns) && !missing(column_order)) {
        cluster_columns = FALSE
    }
    if(is.logical(cluster_columns)) {
        if(!cluster_columns) {
            column_dend_height = unit(0, "mm")
            show_column_dend = FALSE
        }
        .Object@column_dend_param$cluster = cluster_columns
    } else if(inherits(cluster_columns, "dendrogram") || inherits(cluster_columns, "hclust")) {
        .Object@column_dend_param$obj = cluster_columns
        .Object@column_dend_param$cluster = TRUE
    } else if(inherits(cluster_columns, "function")) {
        .Object@column_dend_param$fun = cluster_columns
        .Object@column_dend_param$cluster = TRUE
    } else {
        oe = try(cluster_columns <- as.dendrogram(cluster_columns), silent = TRUE)
        if(!inherits(oe, "try-error")) {
            .Object@column_dend_param$obj = cluster_columns
            .Object@column_dend_param$cluster = TRUE
        } else {
            stop_wrap("`cluster_columns` should be a logical value, a clustering function or a clustering object.")
        }
    }
    if(!show_column_dend) {
        column_dend_height = unit(0, "mm")
    }
    .Object@column_dend_list = list()
    .Object@column_dend_param$distance = clustering_distance_columns
    .Object@column_dend_param$method = clustering_method_columns
    .Object@column_dend_param$side = match.arg(column_dend_side)[1]
    .Object@column_dend_param$height = column_dend_height + ht_opt$DENDROGRAM_PADDING  # append the gap
    .Object@column_dend_param$show = show_column_dend
    .Object@column_dend_param$gp = check_gp(column_dend_gp)
    .Object@column_dend_param$reorder = column_dend_reorder
    if(is.null(column_order)) {
        .Object@column_order = seq_len(ncol(matrix))
    } else {
        if(is.character(column_order)) {
            column_order = structure(seq_len(ncol(matrix)), names = colnames(matrix))[column_order]
        }
        if(any(is.na(column_order))) {
            stop_wrap("`column_order` should not contain NA values.")
        }
        if(length(column_order) != ncol(matrix)) {
            stop_wrap("length of `column_order` should be same as the number of marix columns")
        }
        .Object@column_order = column_order
    }
    .Object@column_dend_param$cluster_slices = cluster_column_slices

    ######### annotations #############
    .Object@top_annotation = top_annotation # a `HeatmapAnnotation` object
    if(is.null(top_annotation)) {
        .Object@top_annotation_param$height = unit(0, "mm")    
    } else {
        if(inherits(top_annotation, "AnnotationFunction")) {
            stop_wrap("The annotation function `anno_*()` should be put inside `HeatmapAnnotation()`.")
        }
        .Object@top_annotation_param$height = height(top_annotation) + ht_opt$COLUMN_ANNO_PADDING  # append the gap
    }
    if(!is.null(top_annotation)) {
        if(length(top_annotation) > 0) {
            if(!.Object@top_annotation@which == "column") {
                stop_wrap("`which` in `top_annotation` should only be `column`.")
            }
        }
        nb = nobs(top_annotation)
        if(!is.na(nb)) {
            if(nb != ncol(.Object@matrix)) {
                stop_wrap("number of observations in top annotation should be as same as ncol of the matrix.")
            }
        }
    }
    if(!is.null(top_annotation)) {
        validate_anno_names_with_matrix(matrix, top_annotation, "column")
    }
    
    .Object@bottom_annotation = bottom_annotation # a `HeatmapAnnotation` object
    if(is.null(bottom_annotation)) {
        .Object@bottom_annotation_param$height = unit(0, "mm")
    } else {
        if(inherits(bottom_annotation, "AnnotationFunction")) {
            stop_wrap("The annotation function `anno_*()` should be put inside `HeatmapAnnotation()`.")
        }
        .Object@bottom_annotation_param$height = height(bottom_annotation) + ht_opt$COLUMN_ANNO_PADDING  # append the gap
    }
    if(!is.null(bottom_annotation)) {
        if(length(bottom_annotation) > 0) {
            if(!.Object@bottom_annotation@which == "column") {
                stop_wrap("`which` in `bottom_annotation` should only be `column`.")
            }
        }
        nb = nobs(bottom_annotation)
        if(!is.na(nb)) {
            if(nb != ncol(.Object@matrix)) {
                stop_wrap("number of observations in bottom annotation should be as same as ncol of the matrix.")
            }
        }
    }
    if(!is.null(bottom_annotation)) {
        validate_anno_names_with_matrix(matrix, bottom_annotation, "column")
    }

    .Object@left_annotation = left_annotation # a `rowAnnotation` object
    if(is.null(left_annotation)) {
        .Object@left_annotation_param$width = unit(0, "mm")
    } else {
        if(inherits(left_annotation, "AnnotationFunction")) {
            stop_wrap("The annotation function `anno_*()` should be put inside `rowAnnotation()`.")
        }
        .Object@left_annotation_param$width = width(left_annotation) + ht_opt$ROW_ANNO_PADDING  # append the gap
    }
    if(!is.null(left_annotation)) {
        if(length(left_annotation) > 0) {
            if(!.Object@left_annotation@which == "row") {
                stop_wrap("`which` in `left_annotation` should only be `row`, or consider using `rowAnnotation()`.")
            }
        }
        nb = nobs(left_annotation)
        if(!is.na(nb)) {
            if(nb != nrow(.Object@matrix)) {
                stop_wrap("number of observations in left annotation should be same as nrow of the matrix.")
            }
        }
    }
    if(!is.null(left_annotation)) {
        validate_anno_names_with_matrix(matrix, left_annotation, "row")
    }

    .Object@right_annotation = right_annotation # a `rowAnnotation` object
    if(is.null(right_annotation)) {
        .Object@right_annotation_param$width = unit(0, "mm")
    } else {
        if(inherits(right_annotation, "AnnotationFunction")) {
            stop_wrap("The annotation function `anno_*()` should be put inside `rowAnnotation()`.")
        }
        .Object@right_annotation_param$width = width(right_annotation) + ht_opt$ROW_ANNO_PADDING  # append the gap
    }
    if(!is.null(right_annotation)) {
        if(length(right_annotation) > 0) {
            if(!.Object@right_annotation@which == "row") {
                stop_wrap("`which` in `right_annotation` should only be `row`, or consider using `rowAnnotation()`.")
            }
        }
        nb = nobs(right_annotation)
        if(!is.na(nb)) {
            if(nb != nrow(.Object@matrix)) {
                stop_wrap("number of observations in right annotation should be same as nrow of the matrix.")
            }
        }
    }
    if(!is.null(right_annotation)) {
        validate_anno_names_with_matrix(matrix, right_annotation, "row")
    }

    .Object@layout = list(
        layout_size = list(
            column_title_top_height = unit(0, "mm"),
            column_dend_top_height = unit(0, "mm"),
            column_anno_top_height = unit(0, "mm"),
            column_names_top_height = unit(0, "mm"),
            column_title_bottom_height = unit(0, "mm"),
            column_dend_bottom_height = unit(0, "mm"),
            column_anno_bottom_height = unit(0, "mm"),
            column_names_bottom_height = unit(0, "mm"),

            row_title_left_width = unit(0, "mm"),
            row_dend_left_width = unit(0, "mm"),
            row_names_left_width = unit(0, "mm"),
            row_dend_right_width = unit(0, "mm"),
            row_names_right_width = unit(0, "mm"),
            row_title_right_width = unit(0, "mm"),
            row_anno_left_width = unit(0, "mm"),
            row_anno_right_width = unit(0, "mm")
        ),

        layout_index = matrix(nrow = 0, ncol = 2),
        graphic_fun_list = list(),
        initialized = FALSE
    )

    if(is.null(width)) {
        width = unit(ncol(matrix), "null")
    } else if(is.numeric(width) && !inherits(width, "unit")) {
        width = unit(width, "null")
    } else if(!inherits(width, "unit")) {
        stop_wrap("`width` should be a `unit` object or a single number.")
    }

    if(is.null(height)) {
        height = unit(nrow(matrix), "null")
    } else if(is.numeric(height) && !inherits(height, "unit")) {
        height = unit(height, "null")
    } else if(!inherits(height, "unit")) {
        stop_wrap("`height` should be a `unit` object or a single number.")
    }

    if(!is.null(width) && !is.null(heatmap_width)) {
        if(is_abs_unit(width) && is_abs_unit(heatmap_width)) {
            stop_wrap("`heatmap_width` and `width` should not all be the absolute units.")
        }
    }
    if(!is.null(height) && !is.null(heatmap_height)) {
        if(is_abs_unit(height) && is_abs_unit(heatmap_height)) {
            stop_wrap("`heatmap_height` and `height` should not all be the absolute units.")
        }
    }

    if(is.null(use_raster)) {
        if(nrow(matrix) > 2000 && ncol(matrix) > 10) {
            use_raster = TRUE
            if(ht_opt$message) {
                message_wrap("`use_raster` is automatically set to TRUE for a matrix with more than 2000 rows. You can control `use_raster` argument by explicitly setting TRUE/FALSE to it.\n\nSet `ht_opt$message = FALSE` to turn off this message.")
            }
        } else if(ncol(matrix) > 2000 && nrow(matrix) > 10) {
            use_raster = TRUE
            if(ht_opt$message) {
                message_wrap("`use_raster` is automatically set to TRUE for a matrix with more than 2000 columns You can control `use_raster` argument by explicitly setting TRUE/FALSE to it.\n\nSet `ht_opt$message = FALSE` to turn off this message.")
            }
        } else {
            use_raster = FALSE
        }
    }

    if(use_raster) {
        if(missing(raster_by_magick)) {
            if(!raster_by_magick) {
                if(ht_opt$message) {
                    message_wrap("'magick' package is suggested to install to give better rasterization.\n\nSet `ht_opt$message = FALSE` to turn off this message.")
                }
            }
        }
    }
    
    .Object@matrix_param$width = width
    .Object@matrix_param$height = height

    .Object@heatmap_param$width = heatmap_width
    .Object@heatmap_param$height = heatmap_height
    .Object@heatmap_param$show_heatmap_legend = show_heatmap_legend
    .Object@heatmap_param$use_raster = use_raster

    if(missing(raster_device)) {
        if(requireNamespace("Cairo", quietly = TRUE)) {
            raster_device = "CairoPNG"
        } else {
            raster_device = "png"
        }
    } else {
        raster_device = match.arg(raster_device)[1]
    }
    .Object@heatmap_param$raster_device = raster_device
    .Object@heatmap_param$raster_quality = raster_quality
    .Object@heatmap_param$raster_device_param = raster_device_param
    .Object@heatmap_param$raster_resize_mat = raster_resize_mat
    .Object@heatmap_param$raster_by_magick = raster_by_magick
    .Object@heatmap_param$raster_magick_filter = raster_magick_filter
    .Object@heatmap_param$verbose = verbose
    .Object@heatmap_param$post_fun = post_fun
    .Object@heatmap_param$calling_env = parent.frame()
    .Object@heatmap_param$show_parent_dend_line = show_parent_dend_line

    if(nrow(matrix) == 0) {
        .Object@matrix_param$height = unit(0, "mm")
    }
    if(ncol(matrix) == 0) {
        .Object@matrix_param$width = unit(0, "mm")
    }

    return(.Object)

}


# == title
# Make Cluster on Rows
#
# == param
# -object A `Heatmap-class` object.
#
# == details
# The function will fill or adjust ``row_dend_list``, ``row_order_list``, ``row_title`` and ``matrix_param`` slots.
#
# If ``row_order`` was specified in `Heatmap` while ``cluster_rows`` was left unset, no hierarchical clustering is applied on rows.
#
# This function is only for internal use.
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "make_row_cluster",
    signature = "Heatmap",
    definition = function(object) {

    # All clustering / slicing work for rows is delegated to the shared helper.
    object = make_cluster(object, "row")

    # A row title with more than one element must provide exactly one
    # title per row slice; a single (or empty) title is always accepted.
    n_title = length(object@row_title)
    if(n_title > 1 && n_title != length(object@row_order_list)) {
        stop_wrap("If `row_title` is set with length > 1, the length should be as same as the number of row slices.")
    }

    object
})

# == title
# Make Cluster on Columns
#
# == param
# -object A `Heatmap-class` object.
#
# == details
# The function will fill or adjust ``column_dend_list``,
# ``column_order_list``, ``column_title`` and ``matrix_param`` slots.
#
# If ``column_order`` was specified in `Heatmap` while ``cluster_columns`` was left unset, no hierarchical clustering is applied on columns.
#
# This function is only for internal use.
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "make_column_cluster",
    signature = "Heatmap",
    definition = function(object) {

    # All clustering / slicing work for columns is delegated to the shared helper.
    object = make_cluster(object, "column")

    # Zero or one column title needs no further validation.
    title_len = length(object@column_title)
    if(title_len <= 1) {
        return(object)
    }

    # With several titles, there must be exactly one per column slice.
    if(title_len != length(object@column_order_list)) {
        stop_wrap("If `column_title` is set with length > 1, the length should be as same as the number of column slices.")
    }
    object
})

make_cluster = function(object, which = c("row", "column")) {

    which = match.arg(which)[1]

    verbose = object@heatmap_param$verbose

    if(ht_opt("fast_hclust")) {
        hclust = fastcluster::hclust
        if(verbose) qqcat("apply hclust by fastcluster::hclust\n")
    } else {
        hclust = stats::hclust
    }

    mat = object@matrix
    jitter = object@matrix_param$jitter
    if(is.numeric(mat)) {
        if(is.logical(jitter)) {
            if(jitter) {
                mat = mat + runif(length(mat), min = 0, max = 1e-10)
            }
        } else {
            mat = mat + runif(length(mat), min = 0, max = jitter + 0)
        }
    }

    distance = slot(object, paste0(which, "_dend_param"))$distance
    method = slot(object, paste0(which, "_dend_param"))$method
    order = slot(object, paste0(which, "_order"))  # pre-defined row order
    km = getElement(object@matrix_param, paste0(which, "_km"))
    km_repeats = getElement(object@matrix_param, paste0(which, "_km_repeats"))
    split = getElement(object@matrix_param, paste0(which, "_split"))
    reorder = slot(object, paste0(which, "_dend_param"))$reorder
    cluster = slot(object, paste0(which, "_dend_param"))$cluster
    cluster_slices = slot(object, paste0(which, "_dend_param"))$cluster_slices
    gap = getElement(object@matrix_param, paste0(which, "_gap"))

    dend_param = slot(object, paste0(which, "_dend_param"))
    dend_list = slot(object, paste0(which, "_dend_list"))
    dend_slice = slot(object, paste0(which, "_dend_slice"))
    order_list = slot(object, paste0(which, "_order_list"))
    order = slot(object, paste0(which, "_order"))

    names_param = slot(object, paste0(which, "_names_param"))

    dend_param$split_by_cutree = FALSE

    if(!is.null(dend_param$obj)) {
        if(inherits(dend_param$obj, "hclust")) {
            ncl = length(dend_param$obj$order)
        } else {
            ncl = nobs(dend_param$obj)
        }

        if(which == "row") {
            if(ncl != nrow(mat)) {
                stop_wrap("The length of the row clustering object is not the same as the number of matrix rows.")
            }
        } else {
            if(ncl != ncol(mat)) {
                stop_wrap("The length of the column clustering object is not the same as the number of matrix columns")
            }
        }
    }

    if(cluster) {

        if(is.numeric(split) && length(split) == 1) {
            if(is.null(dend_param$obj)) {
                if(verbose) qqcat("split @{which}s by cutree, apply hclust on the entire @{which}s\n")
                if(which == "row") {
                    dend_param$obj = hclust(get_dist(mat, distance), method = method)
                } else {
                    dend_param$obj = hclust(get_dist(t(mat), distance), method = method)
                }
            }
        }

        if(!is.null(dend_param$obj)) {
            if(km > 1) {
                stop_wrap("You can not perform k-means clustering since you have already specified a clustering object.")
            }

            if(inherits(dend_param$obj, "hclust")) {
                dend_param$obj = as.dendrogram(dend_param$obj)
                if(verbose) qqcat("convert hclust object to dendrogram object\n")
            }

            if(is.null(split)) {
                dend_list = list(dend_param$obj)
                order_list = list(get_dend_order(dend_param$obj))
                if(verbose) qqcat("since you provided a clustering object and @{which}_split is null, the entrie clustering object is taken as an one-element list.\n")
            } else {
                if(length(split) > 1 || !is.numeric(split)) {
                    stop_wrap(qq("Since you specified a clustering object, you can only split @{which}s by providing a number (number of @{which} slices)."))
                }
                if(split < 2) {
                    stop_wrap(qq("`@{which}_split` should be >= 2."))
                }
                dend_param$split_by_cutree = TRUE
                
                ct = cut_dendrogram(dend_param$obj, split)
                dend_list = ct$lower
                dend_slice = ct$upper
                sth = tapply(order.dendrogram(dend_param$obj), 
                    rep(seq_along(dend_list), times = sapply(dend_list, nobs)), 
                    function(x) x, simplify = FALSE)
                attributes(sth) = NULL
                order_list = sth
                if(verbose) qqcat("cut @{which} dendrogram into @{split} slices.\n")
            }

            ### do reordering if specified
            if(identical(reorder, NULL)) {
                if(is.numeric(mat)) {
                    reorder = TRUE
                } else {
                    reorder = FALSE
                }
            }

            do_reorder = TRUE
            if(identical(reorder, NA) || identical(reorder, FALSE)) {
                do_reorder = FALSE
            }
            if(identical(reorder, TRUE)) {
                do_reorder = TRUE
                if(which == "row") {
                    reorder = -rowMeans(mat, na.rm = TRUE)
                } else {
                    reorder = -colMeans(mat, na.rm = TRUE)
                }
            }

            if(do_reorder) {

                if(which == "row") {
                    if(length(reorder) != nrow(mat)) {
                        stop_wrap("weight of reordering should have same length as number of rows.\n")
                    }
                } else {
                    if(length(reorder) != ncol(mat)) {
                        stop_wrap("weight of reordering should have same length as number of columns\n")
                    }
                }
                
                for(i in seq_along(dend_list)) {
                    if(length(order_list[[i]]) > 1) {
                        sub_ind = sort(order_list[[i]])
                        dend_list[[i]] = reorder(dend_list[[i]], reorder[sub_ind], mean)
                        # the order of object@row_dend_list[[i]] is the order corresponding to the big dendrogram
                        order_list[[i]] = order.dendrogram(dend_list[[i]])
                    }
                }
            }

            dend_list = lapply(dend_list, adjust_dend_by_x)

            slot(object, paste0(which, "_order")) = unlist(order_list)
            slot(object, paste0(which, "_order_list")) = order_list
            slot(object, paste0(which, "_dend_list")) = dend_list
            slot(object, paste0(which, "_dend_param")) = dend_param
            slot(object, paste0(which, "_dend_slice")) = dend_slice

            if(!is.null(split)) {
                if(is.null(attr(dend_list[[1]], ".class_label"))) {
                    split = data.frame(rep(seq_along(order_list), times = sapply(order_list, length)))
                } else {
                    split = data.frame(rep(sapply(dend_list, function(x) attr(x, ".class_label")), times = sapply(order_list, length)))
                }
                object@matrix_param[[ paste0(which, "_split") ]] = split

                # adjust row_names_param$gp if the length of some elements is the same as row slices
                for(i in seq_along(names_param$gp)) {
                    if(length(names_param$gp[[i]]) == length(order_list)) {
                        gp_temp = NULL
                        for(j in seq_along(order_list)) {
                            gp_temp[ order_list[[j]] ] = names_param$gp[[i]][j]
                        }
                        names_param$gp[[i]] = gp_temp
                    }
                }
                if(!is.null(names_param$anno)) {
                    names_param$anno@var_env$gp = names_param$gp
                }
                slot(object, paste0(which, "_names_param")) = names_param

                n_slice = length(order_list)
                if(length(gap) == 1) {
                    gap = rep(gap, n_slice)
                } else if(length(gap) == n_slice - 1) {
                    gap = unit.c(gap, unit(0, "mm"))
                } else if(length(gap) != n_slice) {
                    stop_wrap(qq("Length of `gap` should be 1 or number of @{which} slices."))
                }
                object@matrix_param[[ paste0(which, "_gap") ]] = gap # adjust title

                title = slot(object, paste0(which, "_title"))
                if(!is.null(split)) {
                    if(length(title) == 0 && !is.null(title)) { ## default title
                        title = apply(unique(split), 1, paste, collapse = ",")
                    } else if(length(title) == 1) {
                        if(grepl("%s", title)) {
                            title = apply(unique(split), 1, function(x) {
                                lt = lapply(x, function(x) x)
                                lt$fmt = title
                                do.call(sprintf, lt)
                            })
                        } else if(grepl("@\\{.+\\}", title)) {
                            title = apply(unique(split), 1, function(x) {
                                x = x
                                envir = environment()
                                title = get("title")
                                op = parent.env(envir)
                                calling_env = object@heatmap_param$calling_env
                                parent.env(envir) = calling_env
                                title = GetoptLong::qq(title, envir = envir)
                                parent.env(envir) = op
                                return(title)
                            })
                        } else if(grepl("\\{.+\\}", title)) {
                            if(!requireNamespace("glue")) {
                                stop_wrap("You need to install glue package.")
                            }
                            title = apply(unique(split), 1, function(x) {
                                x = x
                                envir = environment()
                                title = get("title")
                                op = parent.env(envir)
                                calling_env = object@heatmap_param$calling_env
                                parent.env(envir) = calling_env
                                title = glue::glue(title, envir = calling_env)
                                parent.env(envir) = op
                                return(title)
                            })
                        }
                    }
                }
                slot(object, paste0(which, "_title")) = title
            }
            return(object)
        }

    } else {
        if(verbose) qqcat("no clustering is applied/exists on @{which}s\n")
    }

    if(verbose) qq("clustering object is not pre-defined, clustering is applied to each @{which} slice\n")
    # make k-means clustering to add a split column
    consensus_kmeans = function(mat, centers, km_repeats) {
        partition_list = lapply(seq_len(km_repeats), function(i) {
            as.cl_hard_partition(kmeans(mat, centers, iter.max = 50))
        })
        partition_list = cl_ensemble(list = partition_list)
        partition_consensus = cl_consensus(partition_list)
        as.vector(cl_class_ids(partition_consensus)) 
    }
    if(km > 1 && is.numeric(mat)) {
        if(which == "row") {
            # km.fit = kmeans(mat, centers = km)
            # cl = km.fit$cluster
            cl = consensus_kmeans(mat, km, km_repeats)
            meanmat = lapply(sort(unique(cl)), function(i) {
                colMeans(mat[cl == i, , drop = FALSE], na.rm = TRUE)
            })
        } else {
            # km.fit = kmeans(t(mat), centers = km)
            # cl = km.fit$cluster
            cl = consensus_kmeans(t(mat), km, km_repeats)
            meanmat = lapply(sort(unique(cl)), function(i) {
                rowMeans(mat[, cl == i, drop = FALSE], na.rm = TRUE)
            })
        }

        meanmat = do.call("cbind", meanmat)
        # if `reorder` is a vector, the slice dendrogram is reordered by the mean of reorder in each slice
        # or else, weighted by the mean of `meanmat`.
        if(length(reorder) > 1) {
            weight = tapply(reorder, cl, mean)
        } else {
            weight = colMeans(meanmat)
        }
        if(cluster_slices) {
            hc = hclust(dist(t(meanmat)))
            hc = as.hclust(reorder(as.dendrogram(hc), weight, mean))
        } else {
            hc = list(order = order(weight))
        }

        cl2 = numeric(length(cl))
        for(i in seq_along(hc$order)) {
            cl2[cl == hc$order[i]] = i
        }
        cl2 = factor(cl2, levels = seq_along(hc$order))

        if(is.null(split)) {
            split = data.frame(cl2)
        } else if(is.matrix(split)) {
            split = as.data.frame(split)
            split = cbind(cl2, split)
        } else if(is.null(ncol(split))) {
            split = data.frame(cl2, split)
        } else {
            split = cbind(cl2, split)
        }
        if(verbose) qqcat("apply k-means (@{km} groups) on @{which}s, append to the `split` data frame\n")
            
    }

    # split the original order into a list according to split
    order_list = list()
    if(is.null(split)) {
        order_list[[1]] = order
    } else {

        if(verbose) cat("process `split` data frame\n")
        if(is.null(ncol(split))) split = data.frame(split)
        if(is.matrix(split)) split = as.data.frame(split)

        for(i in seq_len(ncol(split))) {
            if(is.numeric(split[[i]])) {
                split[[i]] = factor(as.character(split[[i]]), levels = as.character(sort(unique(split[[i]]))))
            } else if(!is.factor(split[[i]])) {
                split[[i]] = factor(split[[i]])
            } else {
                # re-factor
                split[[i]] = factor(split[[i]], levels = intersect(levels(split[[i]]), unique(split[[i]])))
            }
        }

        split_name = apply(as.matrix(split), 1, paste, collapse = ",")

        order2 = do.call("order", split)
        level = unique(split_name[order2])
        for(k in seq_along(level)) {
            l = split_name == level[k]
            order_list[[k]] = intersect(order, which(l))
        }
        names(order_list) = level
    }

    slice_od = seq_along(order_list)
    # make dend in each slice
    if(cluster) {
        if(verbose) qqcat("apply clustering on each slice (@{length(order_list)} slices)\n")
        dend_list = rep(list(NULL), length(order_list))
        for(i in seq_along(order_list)) {
            if(which == "row") {
                submat = mat[ order_list[[i]], , drop = FALSE]
            } else {
                submat = mat[, order_list[[i]], drop = FALSE]
            }
            nd = 0
            if(which == "row") nd = nrow(submat) else nd = ncol(submat)
            if(nd > 1) {
                if(!is.null(dend_param$fun)) {
                    if(which == "row") {
                        obj = dend_param$fun(submat)
                    } else {
                        obj = dend_param$fun(t(submat))
                    }
                    if(inherits(obj, "dendrogram") || inherits(obj, "hclust")) {
                        dend_list[[i]] = obj
                    } else {
                        oe = try(obj <- as.dendrogram(obj), silent = TRUE)
                        if(inherits(oe, "try-error")) {
                            stop_wrap("the clustering function must return a `dendrogram` object or a object that can be coerced to `dendrogram` class.")
                        }
                        dend_list[[i]] = obj
                    }
                    order_list[[i]] = order_list[[i]][ get_dend_order(dend_list[[i]]) ]
                } else {

                        if(which == "row") {
                            dend_list[[i]] = hclust(get_dist(submat, distance), method = method)
                        } else {
                            dend_list[[i]] = hclust(get_dist(t(submat), distance), method = method)
                        }
                        order_list[[i]] = order_list[[i]][ get_dend_order(dend_list[[i]]) ]
                    #}
                }
            } else {
                # a dendrogram with one leaf
                dend_list[[i]] = structure(1, members = 1, height = 0, leaf = TRUE, class = "dendrogram")
                order_list[[i]] = order_list[[i]][1]
            }
        }
        names(dend_list) = names(order_list)

        for(i in seq_along(dend_list)) {
            if(inherits(dend_list[[i]], "hclust")) {
                dend_list[[i]] = as.dendrogram(dend_list[[i]])
            }
        }

        if(identical(reorder, NULL)) {
            if(is.numeric(mat)) {
                reorder = TRUE
            } else {
                reorder = FALSE
            }
        }

        do_reorder = TRUE
        if(identical(reorder, NA) || identical(reorder, FALSE)) {
            do_reorder = FALSE
        }
        if(identical(reorder, TRUE)) {
            do_reorder = TRUE
            if(which == "row") {
                reorder = -rowMeans(mat, na.rm = TRUE)
            } else {
                reorder = -colMeans(mat, na.rm = TRUE)
            }
        }

        if(do_reorder) {

            if(which == "row") {
                if(length(reorder) != nrow(mat)) {
                    stop_wrap("weight of reordering should have same length as number of rows\n")
                }
            } else {
                if(length(reorder) != ncol(mat)) {
                    stop_wrap("weight of reordering should have same length as number of columns\n")
                }
            }
            for(i in seq_along(dend_list)) {
                if(length(order_list[[i]]) > 1) {
                    sub_ind = sort(order_list[[i]])
                    dend_list[[i]] = reorder(dend_list[[i]], reorder[sub_ind], mean)
                    order_list[[i]] = sub_ind[ order.dendrogram(dend_list[[i]]) ]
                }
            }
            if(verbose) qqcat("reorder dendrograms in each @{which} slice\n")
        }

        if(length(order_list) > 1 && cluster_slices) {
            if(which == "row") {
                slice_mean = sapply(order_list, function(ind) colMeans(mat[ind, , drop = FALSE], na.rm = TRUE))
            } else {
                slice_mean = sapply(order_list, function(ind) rowMeans(mat[, ind, drop = FALSE], na.rm = TRUE))
            }
            if(!is.matrix(slice_mean)) {
                slice_mean = matrix(slice_mean, nrow = 1)
            }
            dend_slice = as.dendrogram(hclust(dist(t(slice_mean))))
            dend_slice = reorder(dend_slice, slice_mean, mean)
            if(verbose) qqcat("perform clustering on mean of @{which} slices\n")

            slice_od = order.dendrogram(dend_slice)
            order_list = order_list[slice_od]
            dend_list = dend_list[slice_od]
        }
    }

    dend_list = lapply(dend_list, adjust_dend_by_x)

    slot(object, paste0(which, "_order")) = unlist(order_list)
    slot(object, paste0(which, "_order_list")) = order_list
    slot(object, paste0(which, "_dend_list")) = dend_list
    slot(object, paste0(which, "_dend_param")) = dend_param
    slot(object, paste0(which, "_dend_slice")) = dend_slice
    object@matrix_param[[ paste0(which, "_split") ]] = split

    if(which == "row") {
        if(nrow(mat) != length(order)) {
            stop_wrap(qq("Number of rows in the matrix are not the same as the length of the cluster or the @{which} orders."))
        }
    } else {
        if(ncol(mat) != length(order)) {
            stop_wrap(qq("Number of columns in the matrix are not the same as the length of the cluster or the @{which} orders."))
        }
    }

    # adjust names_param$gp if the length of some elements is the same as slices
    for(i in seq_along(names_param$gp)) {
        if(length(names_param$gp[[i]]) == length(order_list)) {
            gp_temp = NULL
            for(j in seq_along(order_list)) {
                gp_temp[ order_list[[j]] ] = names_param$gp[[i]][j]
            }
            names_param$gp[[i]] = gp_temp   
        }
    }
    if(!is.null(names_param$anno)) {
        names_param$anno@var_env$gp = names_param$gp
    }
    slot(object, paste0(which, "_names_param")) = names_param

    n_slice = length(order_list)
    if(length(gap) == 1) {
        gap = rep(gap, n_slice)
    } else if(length(gap) == n_slice - 1) {
        gap = unit.c(gap, unit(0, "mm"))
    } else if(length(gap) != n_slice) {
        stop_wrap(qq("Length of `gap` should be 1 or number of @{which} slices."))
    }
    object@matrix_param[[ paste0(which, "_gap") ]] = gap

    # adjust title
    title = slot(object, paste0(which, "_title"))
    if(!is.null(split)) {
        if(length(title) == 0 && !is.null(title)) { ## default title
            title = names(order_list)
        } else if(length(title) == 1) {
            if(grepl("%s", title)) {
                title = apply(unique(split[order2, , drop = FALSE]), 1, function(x) {
                    lt = lapply(x, function(x) x)
                    lt$fmt = title
                    do.call(sprintf, lt)
                })[slice_od]
            } else if(grepl("@\\{.+\\}", title)) {
                title = apply(unique(split[order2, , drop = FALSE]), 1, function(x) {
                    x = x
                    envir = environment()
                    title = get("title")
                    op = parent.env(envir)
                    calling_env = object@heatmap_param$calling_env
                    parent.env(envir) = calling_env
                    title = GetoptLong::qq(title, envir = envir)
                    parent.env(envir) = op
                    return(title)
                })[slice_od]
            } else if(grepl("\\{.+\\}", title)) {
                if(!requireNamespace("glue")) {
                    stop_wrap("You need to install glue package.")
                }
                title = apply(unique(split[order2, , drop = FALSE]), 1, function(x) {
                    x = x
                    envir = environment()
                    title = get("title")
                    op = parent.env(envir)
                    calling_env = object@heatmap_param$calling_env
                    parent.env(envir) = calling_env
                    title = glue::glue(title, envir = calling_env)
                    parent.env(envir) = op
                    return(title)
                })[slice_od]
            }
        }
    }
    slot(object, paste0(which, "_title")) = title
    # check whether height of the dendrogram is zero
    # if(all(sapply(dend_list, dend_heights) == 0)) {
    #     slot(object, paste0(which, "_dend_param"))$show = FALSE
    # }
    return(object)

}

# == title
# Draw a Single Heatmap
#
# == param
# -object A `Heatmap-class` object.
# -internal If ``TRUE``, only the heatmap itself is drawn, without legends. This mode is used internally by
#           `draw,HeatmapList-method`, which takes care of drawing the legends.
# -test Only for testing. If it is ``TRUE``, the heatmap body is directly drawn.
# -... Pass to `draw,HeatmapList-method`.
#
# == detail
# The function creates a `HeatmapList-class` object which only contains a single heatmap
# and calls `draw,HeatmapList-method` to make the final heatmap.
#
# There are some arguments which control some settings of the heatmap such as legends.
# Please go to `draw,HeatmapList-method` for these arguments.
#
# == value
# A `HeatmapList-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw",
    signature = "Heatmap",
    definition = function(object, internal = FALSE, test = FALSE, ...) {

    # Three mutually exclusive modes, checked in this order:
    #   test     -> prepare the heatmap and draw it directly on a new page
    #   internal -> draw the prepared heatmap components into the current viewport
    #   default  -> wrap the heatmap into a `HeatmapList` and delegate to its `draw` method
    if(test) {
        object = prepare(object)
        grid.newpage()

        # an absolute unit is used as it is, anything else is replaced by 0.8
        # (a plain number, interpreted by `viewport()` in its default units)
        pick_size = function(size) {
            if(is_abs_unit(size)) size else 0.8
        }
        vp_width = pick_size(object@heatmap_param$width)
        vp_height = pick_size(object@heatmap_param$height)

        pushViewport(viewport(width = vp_width, height = vp_height))
        draw(object, internal = TRUE)
        upViewport()
    } else if(internal) {  # a heatmap without legend
        cell_index = object@layout$layout_index
        # no component registered in the layout, nothing to draw
        if(nrow(cell_index) == 0) return(invisible(NULL))

        # one layout cell for every possible heatmap component
        component_layout = grid.layout(
            nrow = length(HEATMAP_LAYOUT_COLUMN_COMPONENT),
            ncol = length(HEATMAP_LAYOUT_ROW_COMPONENT),
            widths = component_width(object),
            heights = component_height(object))
        pushViewport(viewport(layout = component_layout))

        draw_fun_list = object@layout$graphic_fun_list
        body_row = HEATMAP_LAYOUT_COLUMN_COMPONENT["heatmap_body"]
        body_col = HEATMAP_LAYOUT_ROW_COMPONENT["heatmap_body"]

        for(j in seq_len(nrow(cell_index))) {
            pos_row = cell_index[j, 1]
            pos_col = cell_index[j, 2]

            # only the cell holding the heatmap body gets a named viewport
            if(body_row %in% pos_row && body_col %in% pos_col) {
                cell_vp = viewport(layout.pos.row = pos_row, layout.pos.col = pos_col,
                    name = paste(object@name, "heatmap_body_wrap", sep = "_"))
            } else {
                cell_vp = viewport(layout.pos.row = pos_row, layout.pos.col = pos_col)
            }

            pushViewport(cell_vp)
            draw_fun_list[[j]](object)
            upViewport()
        }
        upViewport()
    } else {
        # a list with this single heatmap; `...` goes to the `draw` method of the list
        single_ht_list = new("HeatmapList")
        single_ht_list = add_heatmap(single_ht_list, object)
        draw(single_ht_list, ...)
    }
})

# == title
# Prepare the Heatmap
#
# == param
# -object A `Heatmap-class` object.
# -process_rows Whether to process rows of the heatmap.
# -process_columns Whether to process columns of the heatmap.
#
# == detail
# The preparation of the heatmap includes the following steps:
#
# - making clustering on rows (by calling `make_row_cluster,Heatmap-method`)
# - making clustering on columns (by calling `make_column_cluster,Heatmap-method`)
# - making the layout of the heatmap (by calling `make_layout,Heatmap-method`)
#
# This function is only for internal use.
#
# == value
# The `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "prepare",
    signature = "Heatmap",
    definition = function(object, process_rows = TRUE, process_columns = TRUE) {

    if(object@layout$initialized) {
        # the layout was already made, the object is handed back unchanged
        prepared = object
    } else {
        prepared = object
        # optional clustering (rows first, then columns) is done before the layout is made
        if(process_rows) prepared = make_row_cluster(prepared)
        if(process_columns) prepared = make_column_cluster(prepared)
        prepared = make_layout(prepared)
    }
    prepared

})
# jokergoo/ComplexHeatmap documentation built on Nov. 17, 2023, 11:27 a.m.