# R/ggbrain_slices.R

#' R6 class for managing slice data for ggbrain plots
#'
#' Holds the per-slice data produced by \code{ggbrain_images$get_slices()} (one element per slice,
#' each a named list of layer data.frames) and provides methods to compute contrasts across layers,
#' summarize numeric ranges, and enumerate the unique values of labeled layers.
#' @importFrom dplyr bind_rows full_join group_by across mutate summarize filter reframe
#' @importFrom tidyr pivot_wider pivot_longer unnest
#' @importFrom tibble tibble
#' @importFrom tidyselect matches
#' @importFrom data.table melt rbindlist setDF :=
#' @author Michael Hallquist
#' @keywords internal
ggbrain_slices <- R6::R6Class(
  classname="ggbrain_slices",
  private = list(
    pvt_slice_index = NULL, # slice indices; also used as positions into pvt_slice_data
    pvt_coord_input = NULL, # the 'coord_input' column of the input slice_df
    pvt_coord_label = NULL, # the 'coord_label' column of the input slice_df
    pvt_plane = NULL, # the 'plane' column of the input slice_df
    pvt_slice_number = NULL, # the 'slice_number' column of the input slice_df
    pvt_slice_data = list(), # per slice: a named list of data.frames, one per layer
    pvt_slice_matrix = list(),
    pvt_slice_labels = list(),
    pvt_is_contrast = NULL, # named logical: whether each layer in slice_data is a contrast or an image
    pvt_layer_names = NULL, # names of layers/images within each slice
    pvt_contrast_definitions = NULL, # named list of contrast definitions already computed

    # Helper that stacks layer data across slices into a single long data.frame.
    #   slice_indices: optional positions into pvt_slice_data (defaults to pvt_slice_index)
    #   only_labeled: if TRUE, keep only layers that carry a "label_columns" attribute
    # Any layer with a "label_columns" attribute is melted so that its label columns become
    # .label_col/.label_val key-value pairs. Each output row is tagged with `layer` and with a
    # `slice_index` column produced by rbindlist's idcol. Returns NULL if any requested slice has
    # no layers left.
    # TODO: sort out how to avoid calculating this twice for labels and numeric values
    get_combined_data = function(slice_indices = NULL, only_labeled=FALSE) {
      if (is.null(slice_indices)) {
        slice_indices <- private$pvt_slice_index
      } else {
        checkmate::assert_integerish(slice_indices, lower = 1, upper = max(private$pvt_slice_index), unique = TRUE)
      }

      img_slice <- private$pvt_slice_data[slice_indices]
      if (isTRUE(only_labeled)) {
        # which layers are labeled (all slices share the same layers, so inspect the first)
        has_labels <- vapply(
          private$pvt_slice_data[[1L]],
          function(x) !is.null(attr(x, "label_columns")),
          logical(1)
        )
        stopifnot(sum(has_labels) > 0L) # callers should only request labeled data when labels exist

        # subset img_slice to only labeled layers
        img_slice <- lapply(img_slice, function(ll) ll[has_labels])
      }

      # ensure that we still have data to return
      if (any(lengths(img_slice) == 0L)) return(NULL)

      nlayers <- length(img_slice[[1L]]) # all slices have the same layers (checked in $initialize)
      layer_names <- names(img_slice[[1L]])

      img_data <- lapply(seq_len(nlayers), function(ll) {
        label_columns <- attr(img_slice[[1L]][[ll]], "label_columns")
        ll_df <- lapply(img_slice, "[[", ll) %>%
          data.table::rbindlist(idcol="slice_index")

        # if label columns are present, gather them into a single key-value pair
        if (!is.null(label_columns)) {
          ll_df <- data.table::melt(ll_df, measure.vars=label_columns, variable.name=".label_col", value.name=".label_val")
        }

        ll_df[, layer := layer_names[ll]] # tag rows with their layer (in-place data.table update)
        return(ll_df)
      }) %>%
        data.table::rbindlist(fill=TRUE) %>% # fill=TRUE: layers may have different columns
        data.table::setDF()

      return(img_data)
    }
  ),

  # these active bindings create read-only access to class properties
  active = list(
    #' @field slice_index read-only access to the slice_index containing the slice numbers
    slice_index = function(value) {
      if (missing(value)) private$pvt_slice_index
      else stop("Cannot assign slice_index")
    },
    #' @field coord_input the input string used to lookup the slices
    coord_input = function(value) {
      if (missing(value)) private$pvt_coord_input
      else stop("Cannot assign coord_input")
    },
    #' @field coord_label the calculated x, y, or z coordinate of the relevant slice
    coord_label = function(value) {
      if (missing(value)) private$pvt_coord_label
      else stop("Cannot assign coord_label")
    },
    #' @field slice_number the slice number along the relevant axis of the 3D image matrix
    slice_number = function(value) {
      if (missing(value)) private$pvt_slice_number
      else stop("Cannot assign slice_number")
    },
    #' @field slice_data a nested list of data.frames where each element contains all data relevant to
    #'   that slice and the list elements within are each a given image
    slice_data = function(value) {
      if (missing(value)) private$pvt_slice_data
      else stop("Cannot assign slice_data")
    },

    #' @field slice_matrix the slice data in matrix form
    slice_matrix = function(value) {
      if (missing(value)) private$pvt_slice_matrix
      else stop("Cannot assign slice_matrix")
    },

    #' @field slice_labels a data.frame for each slice containing the coordinates of labels available to be drawn
    slice_labels = function(value) {
      if (missing(value)) private$pvt_slice_labels
      else stop("Cannot assign slice_labels")
    },

    #' @field layer_names a character vector of layer names within each slice
    layer_names = function(value) {
      if (missing(value)) private$pvt_layer_names
      else stop("Cannot assign layer_names")
    }
  ),
  public = list(
    #' @description create a ggbrain_slices object based on a data.frame of slice data
    #' @param slice_df a data.frame generated by ggbrain_images$get_slices(). Recognized columns are
    #'   coord_input, coord_label, plane, slice_index, slice_number, slice_data, slice_labels, and
    #'   slice_matrix. Missing columns are filled with one empty list per row, except slice_index,
    #'   which defaults to \code{seq_len(nrow(slice_df))}.
    #' @details If this becomes a user-facing/exported class, we may want a more friendly constructor
    initialize = function(slice_df = NULL) {
      checkmate::assert_data_frame(slice_df)

      df_names <- names(slice_df)
      n_slices <- nrow(slice_df)

      # ensure that layers within each slice match
      # ([[ avoids tibble warnings and data.frame partial matching when the column is absent)
      nm <- lapply(slice_df[["slice_data"]], names)
      all_match <- all(vapply(nm, function(x) identical(x, nm[[1L]]), logical(1)))
      if (!all_match) stop("Names of layers in $slice_data are not identical")

      # all slice_data provided at $initialize are treated as primary image data, not contrasts
      # (no slice_data column means no layers, rather than an out-of-bounds error)
      first_names <- if (length(nm) > 0L) nm[[1L]] else character(0)
      private$pvt_layer_names <- first_names
      private$pvt_is_contrast <- rep(FALSE, length(first_names)) %>% setNames(first_names)

      # empty lists for populating unused fields -- match length of slice_df for consistency
      empty_list <- lapply(seq_len(n_slices), function(i) list())

      # return column `col` of slice_df if present, otherwise `default`
      col_or_default <- function(col, default = empty_list) {
        if (col %in% df_names) slice_df[[col]] else default
      }

      private$pvt_coord_input <- col_or_default("coord_input")
      private$pvt_coord_label <- col_or_default("coord_label")
      private$pvt_plane <- col_or_default("plane")
      # one index per slice (row), not per column of slice_df
      private$pvt_slice_index <- col_or_default("slice_index", seq_len(n_slices))
      private$pvt_slice_number <- col_or_default("slice_number")
      private$pvt_slice_data <- col_or_default("slice_data")
      private$pvt_slice_labels <- col_or_default("slice_labels")
      private$pvt_slice_matrix <- col_or_default("slice_matrix")
    },

    #' @description computes contrasts of the sliced image data
    #' @param contrast_list a named list or character vector containing contrasts to be computed.
    #'   The names of the list form the contrast names, while the values should be character strings
    #'   that use standard R syntax for logical tests, subsetting, and arithmetic
    #' @return the object itself, with the new contrasts added as layers of $slice_data
    compute_contrasts = function(contrast_list=NULL) {
      if (is.null(contrast_list)) return(self) # skip out if no contrasts to compute

      if (checkmate::test_class(contrast_list, "character")) {
        contrast_list <- as.list(contrast_list) # tolerate named character vector input
      }

      # force unique names of input contrasts
      checkmate::assert_list(contrast_list, names = "unique")
      if (length(private$pvt_slice_data) == 0L) {
        stop("Cannot use $compute_contrasts() if there are no slice_data in the object")
      }

      # quietly drop contrasts already computed under the same name with an identical definition.
      # The mask is computed up front: deleting elements while looping over seq_along() shifts
      # the remaining positions and causes some contrasts to be skipped.
      existing_defs <- private$pvt_contrast_definitions
      already_computed <- vapply(names(contrast_list), function(cname) {
        cname %in% names(existing_defs) &&
          isTRUE(trimws(existing_defs[[cname]]) == trimws(contrast_list[[cname]]))
      }, logical(1), USE.NAMES = FALSE)
      contrast_list <- contrast_list[!already_computed]

      if (length(contrast_list) == 0L) return(self) # skip out if no contrasts to compute

      # look up layer types by name rather than relying on positional alignment with layer_names
      image_layers <- names(private$pvt_is_contrast)[!private$pvt_is_contrast]
      contrast_layers <- names(private$pvt_is_contrast)[private$pvt_is_contrast]

      # check overlap with non-contrast data
      img_overlap <- intersect(names(contrast_list), image_layers)
      if (length(img_overlap) > 0L) {
        warning(
          "The following contrast(s) overlap with the primary $slice_data (from images): ", paste(img_overlap, collapse = ", "), ".",
          " This will overwrite the original data with the contrast."
        )
      }

      # check overlap with contrast data
      con_overlap <- intersect(names(contrast_list), contrast_layers)
      if (length(con_overlap) > 0L) {
        warning("Existing contrast data will be replaced for the following contrasts: ", paste(con_overlap, collapse = ", "))
      }

      # convert slice data to wide format to allow contrasts to be parsed
      wide <- lapply(private$pvt_slice_data, function(slc_xx) {
        # add image name as prefix to value and labeled columns to allow for cbind
        slc_xx <- lapply(slc_xx, function(img_ii) {
          img_name <- img_ii$image[1L]
          nmr <- match(c("value", attr(img_ii, "label_columns")), names(img_ii))
          new_names <- paste(img_name, names(img_ii)[nmr], sep=".")
          names(img_ii)[nmr] <- new_names
          img_ii <- img_ii[, c("dim1", "dim2", new_names)] # subset to only key columns
          attr(img_ii, "image") <- img_name
          return(img_ii)
        })

        # recursively column bind result, omitting redundant dim columns from second data.frame
        cbind_attr <- function(x1, x2) {
          # x1 %>% dplyr::bind_cols(subset(x2, select=c(-dim1, -dim2))) # pretty, but slower
          # x1 %>% dplyr::bind_cols(x2[, c(-1,-2)]) # faster, but riskier
          y <- x1 %>% dplyr::bind_cols(x2[, -match(c("dim1", "dim2"), names(x2)), drop=FALSE])
          attr(y, "image") <- c(attr(x1, "image"), attr(x2, "image")) # pass through image names for disambiguating columns
          return(y)
        }

        ss <- Reduce(cbind_attr, slc_xx)

        # safer to inner join instead of cbind, but this is unnecessary compute given that they all pass through mat2df
        # about 8x slower than bind_cols
        # c_df <- Reduce(function(x, y) inner_join(x, y, by = c("dim1", "dim2")), ff)

        return(ss)
      })

      # loop over slices in the wide structure
      for (ww in seq_along(wide)) {
        c_data <- lapply(seq_along(contrast_list), function(cc) {
          df <- contrast_parser(contrast_list[[cc]], data = wide[[ww]]) %>%
            mutate(image = names(contrast_list)[cc]) # tag contrasts with a label column

          # if user passes a simple subset operation, keep all other columns from original image in contrast
          # this helps preserve labels when we use a subsetting operation
          # merge on value as well so that we only add labels for subset values retained in the contrast
          if (!is.null(attr(df, "img_source"))) {
            src_df <- private$pvt_slice_data[[ww]][[attr(df, "img_source")]] %>% dplyr::select(-image)
            df <- df %>%
              dplyr::left_join(src_df, by = c("dim1", "dim2", "value"))
            attr(df, "label_columns") <- attr(src_df, "label_columns") # copy through label columns for subset contrast
          }

          return(df)
        }) %>% setNames(names(contrast_list))

        private$pvt_slice_data[[ww]][names(c_data)] <- c_data # update/set relevant elements of slice data
      }

      private$pvt_layer_names <- names(private$pvt_slice_data[[1]]) # update object with new names
      private$pvt_is_contrast[names(contrast_list)] <- TRUE
      private$pvt_contrast_definitions[names(contrast_list)] <- contrast_list # keep track of definitions
      return(self)
    },

    #' @description convert the slices object into a data.frame with list-columns for slice data elements
    #' @return a tibble with one row per slice. The layer names and the per-layer contrast flags are
    #'   attached as the "layer_names" and "is_contrast" attributes.
    as_tibble = function() {
      tb <- tibble::tibble(
        coord_input=private$pvt_coord_input,
        coord_label=private$pvt_coord_label,
        plane=private$pvt_plane,
        slice_index=private$pvt_slice_index,
        slice_number=private$pvt_slice_number,
        slice_data=private$pvt_slice_data,
        slice_labels=private$pvt_slice_labels,
        slice_matrix=private$pvt_slice_matrix
      )

      attr(tb, "layer_names") <- private$pvt_layer_names
      attr(tb, "is_contrast") <- private$pvt_is_contrast
      return(tb)
    },

    #' @description calculates the numeric ranges of each image/contrast in this object, across all
    #'   constituent slices. This is useful for setting scale limits that are shared across panels
    #' @param slice_indices an optional integer vector of slice indices to be used as a subset in the calculation
    #' @return a tibble keyed by 'layer' with overall low and high values, as well as split by pos/neg
    get_ranges = function(slice_indices = NULL) {
      img_data <- private$get_combined_data(slice_indices)

      img_ranges <- img_data %>%
        dplyr::group_by(layer) %>%
        dplyr::summarize(low = min(value, na.rm = TRUE), high = max(value, na.rm = TRUE), .groups = "drop")

      # for bisided layers, we need pos and neg ranges -- N.B. this does not support arbitrary cutpoints!
      img_ranges_posneg <- img_data %>%
        dplyr::filter(value > 2 * .Machine$double.eps | value < -2 * .Machine$double.eps) %>% # filter exact zeros so that we get true > and <
        dplyr::mutate(above_zero = factor(value > 0, levels = c(TRUE, FALSE), labels = c("pos", "neg"))) %>%
        dplyr::group_by(layer, above_zero, .drop=FALSE) %>% # .drop=FALSE keeps empty pos/neg cells
        dplyr::summarize(
          low = suppressWarnings(min(value, na.rm = TRUE)), # empty groups yield +/-Inf, reset to NA below
          high = suppressWarnings(max(value, na.rm = TRUE)), .groups = "drop") %>%
        tidyr::pivot_wider(id_cols="layer", names_from="above_zero", values_from=c(low, high))

      # join the overall ranges with the pos/neg split
      img_ranges <- img_ranges %>%
        dplyr::full_join(img_ranges_posneg, by = "layer") %>%
        dplyr::mutate(across(matches("low|high"), ~ dplyr::if_else(is.infinite(.x), NA_real_, .x))) # set Inf to NA

      return(img_ranges)
    },

    #' @description returns a data.frame with the unique values for each label layer, across all
    #'   constituent slices
    #' @param slice_indices an optional integer vector of slice indices to be used as a subset in the calculation
    #' @param add_labels an optional named list indicating the label columns to add to a given layer. The names
    #'   of the list specify the layer to which labels are added and the values should be character vectors
    #'   specifying the names of columns that serve as labels for the layer. These are always *added* to any
    #'   existing label columns for the layer.
    #' @return a data.frame with columns layer, .label_col, and uvals (sorted unique non-missing values),
    #'   or an empty data.frame if no layer has label columns
    #' @details
    #'   \code{add_labels} is provided here for any categorical columns that were specified inline in the
    #'   layer definition using factor or as.factor, such as aes(fill=factor(value)). Otherwise, it's best
    #'   to input data with labels in the first place so that labels are described in the data structure itself.
    #'   Note that \code{add_labels} persistently updates the "label_columns" attribute of the stored slice data.
    get_uvals = function(slice_indices = NULL, add_labels = NULL) {
      # handle late-breaking labels (coming through geom_ layer aes specification)
      if (!is.null(add_labels)) {
        checkmate::assert_list(add_labels, names = "unique")
        checkmate::assert_subset(names(add_labels), private$pvt_layer_names)
        for (ii in seq_along(private$pvt_slice_data)) {
          for (jj in seq_along(add_labels)) {
            lname <- names(add_labels)[jj]
            # verify that label columns exist
            stopifnot(all(add_labels[[jj]] %in% names(private$pvt_slice_data[[ii]][[lname]])))

            attr(private$pvt_slice_data[[ii]][[lname]], "label_columns") <-
              union(attr(private$pvt_slice_data[[ii]][[lname]], "label_columns"), add_labels[[jj]])
          }
        }
      }

      # examine first slice to see if any layers have labels (reasonably assumes all slices have same layers)
      has_labels <- vapply(
        private$pvt_slice_data[[1]],
        function(x) !is.null(attr(x, "label_columns")),
        logical(1)
      )
      if (!any(has_labels)) {
        return(data.frame()) # return empty data.frame
      }

      img_data <- private$get_combined_data(slice_indices, only_labeled = TRUE)
      img_uvals <- img_data %>%
        dplyr::group_by(layer, .label_col) %>%
        dplyr::reframe(uvals = sort(unique(.label_val))) %>%
        na.omit()

      return(img_uvals)
    }
  )
)
# michaelhallquist/ggbrain documentation built on June 11, 2025, 12:47 a.m.